Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ requires-python = ">=3.10"
dependencies = [
"numpy>=1.22.0",
"scipy>=1.13.1",
"kirin-toolchain~=0.14.0",
"kirin-toolchain~=0.16.0",
"rich>=13.9.4",
"pydantic>=1.3.0,<2.11.0",
"pandas>=2.2.3",
Expand Down Expand Up @@ -137,3 +137,6 @@ dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"

[tool.coverage.run]
include = ["src/bloqade/*"]

[tool.pytest.ini_options]
testpaths = "test/"
8 changes: 4 additions & 4 deletions src/bloqade/noise/native/stmts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from kirin import ir, types
from kirin import ir, types, lowering
from kirin.decl import info, statement
from kirin.dialects import ilist

Expand All @@ -10,7 +10,7 @@
@statement(dialect=dialect)
class PauliChannel(ir.Statement):

traits = frozenset({ir.FromPythonCall()})
traits = frozenset({lowering.FromPythonCall()})

px: float = info.attribute(types.Float)
py: float = info.attribute(types.Float)
Expand All @@ -24,7 +24,7 @@ class PauliChannel(ir.Statement):
@statement(dialect=dialect)
class CZPauliChannel(ir.Statement):

traits = frozenset({ir.FromPythonCall()})
traits = frozenset({lowering.FromPythonCall()})

paired: bool = info.attribute(types.Bool)
px_ctrl: float = info.attribute(types.Float)
Expand All @@ -40,7 +40,7 @@ class CZPauliChannel(ir.Statement):
@statement(dialect=dialect)
class AtomLossChannel(ir.Statement):

traits = frozenset({ir.FromPythonCall()})
traits = frozenset({lowering.FromPythonCall()})

prob: float = info.attribute(types.Float)
qargs: ir.SSAValue = info.argument(ilist.IListType[QubitType])
7 changes: 2 additions & 5 deletions src/bloqade/pyqrack/noise/native.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
from typing import TYPE_CHECKING, List
from typing import List

from kirin import interp

from bloqade.noise import native
from bloqade.pyqrack import PyQrackInterpreter, reg

if TYPE_CHECKING:
from pyqrack import QrackSimulator


@native.dialect.register(key="pyqrack")
class PyQrackMethods(interp.MethodTable):
Expand Down Expand Up @@ -90,7 +87,7 @@ def atom_loss_channel(
frame: interp.Frame,
stmt: native.AtomLossChannel,
):
qargs: List[reg.PyQrackQubit["QrackSimulator"]] = frame.get(stmt.qargs)
qargs: List[reg.PyQrackQubit] = frame.get(stmt.qargs)

active_qubits = (qarg for qarg in qargs if qarg.is_active())

Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/qasm2/dialects/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from . import emit as emit, address as address, typeinfer as typeinfer
from . import _emit as _emit, address as address, _typeinfer as _typeinfer
from .stmts import * # noqa: F403
from ._dialect import dialect as dialect
16 changes: 8 additions & 8 deletions src/bloqade/qasm2/dialects/core/stmts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from kirin import ir, types
from kirin import ir, types, lowering
from kirin.decl import info, statement

from bloqade.qasm2.types import BitType, CRegType, QRegType, QubitType
Expand All @@ -11,7 +11,7 @@ class QRegNew(ir.Statement):
"""Create a new quantum register."""

name = "qreg.new"
traits = frozenset({ir.FromPythonCall()})
traits = frozenset({lowering.FromPythonCall()})
n_qubits: ir.SSAValue = info.argument(types.Int)
"""n_qubits: The number of qubits in the register."""
result: ir.ResultValue = info.result(QRegType)
Expand All @@ -23,7 +23,7 @@ class CRegNew(ir.Statement):
"""Create a new classical register."""

name = "creg.new"
traits = frozenset({ir.FromPythonCall()})
traits = frozenset({lowering.FromPythonCall()})
n_bits: ir.SSAValue = info.argument(types.Int)
"""n_bits (Int): The number of bits in the register."""
result: ir.ResultValue = info.result(CRegType)
Expand All @@ -35,7 +35,7 @@ class Reset(ir.Statement):
"""Reset a qubit to the |0> state."""

name = "reset"
traits = frozenset({ir.FromPythonCall()})
traits = frozenset({lowering.FromPythonCall()})
qarg: ir.SSAValue = info.argument(QubitType)
"""qarg (Qubit): The qubit to reset."""

Expand All @@ -45,7 +45,7 @@ class Measure(ir.Statement):
"""Measure a qubit and store the result in a bit."""

name = "measure"
traits = frozenset({ir.FromPythonCall()})
traits = frozenset({lowering.FromPythonCall()})
qarg: ir.SSAValue = info.argument(QubitType)
"""qarg (Qubit): The qubit to measure."""
carg: ir.SSAValue = info.argument(BitType)
Expand All @@ -57,7 +57,7 @@ class CRegEq(ir.Statement):
"""Check if two classical registers are equal."""

name = "eq"
traits = frozenset({ir.Pure(), ir.FromPythonCall()})
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
lhs: ir.SSAValue = info.argument(types.Int | CRegType | BitType)
"""lhs (CReg): The first register."""
rhs: ir.SSAValue = info.argument(types.Int | CRegType | BitType)
Expand All @@ -71,7 +71,7 @@ class QRegGet(ir.Statement):
"""Get a qubit from a quantum register."""

name = "qreg.get"
traits = frozenset({ir.FromPythonCall(), ir.Pure()})
traits = frozenset({lowering.FromPythonCall(), ir.Pure()})
reg: ir.SSAValue = info.argument(QRegType)
"""reg (QReg): The quantum register."""
idx: ir.SSAValue = info.argument(types.Int)
Expand All @@ -85,7 +85,7 @@ class CRegGet(ir.Statement):
"""Get a bit from a classical register."""

name = "creg.get"
traits = frozenset({ir.FromPythonCall(), ir.Pure()})
traits = frozenset({lowering.FromPythonCall(), ir.Pure()})
reg: ir.SSAValue = info.argument(CRegType)
"""reg (CReg): The classical register."""
idx: ir.SSAValue = info.argument(types.Int)
Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/qasm2/dialects/expr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from . import emit as emit, interp as interp, lowering as lowering
from . import _emit as _emit, _interp as _interp, _from_python as _from_python
from .stmts import * # noqa: F403
from ._dialect import dialect as dialect
Original file line number Diff line number Diff line change
@@ -1,31 +1,29 @@
import ast

from kirin import ir, types
from kirin.lowering import Result, FromPythonAST, LoweringState
from kirin.exceptions import DialectLoweringError
from kirin import ir, types, lowering

from . import stmts
from ._dialect import dialect


@dialect.register
class QASMUopLowering(FromPythonAST):
class QASMUopLowering(lowering.FromPythonAST):

def lower_Name(self, state: LoweringState, node: ast.Name) -> Result:
def lower_Name(self, state: lowering.State, node: ast.Name):
name = node.id
if isinstance(node.ctx, ast.Load):
value = state.current_frame.get(name)
if value is None:
raise DialectLoweringError(f"{name} is not defined")
return Result(value)
raise lowering.BuildError(f"{name} is not defined")
return value
elif isinstance(node.ctx, ast.Store):
raise DialectLoweringError("unhandled store operation")
raise lowering.BuildError("unhandled store operation")
else: # Del
raise DialectLoweringError("unhandled del operation")
raise lowering.BuildError("unhandled del operation")

def lower_Assign(self, state: LoweringState, node: ast.Assign) -> Result:
def lower_Assign(self, state: lowering.State, node: ast.Assign):
# NOTE: QASM only expects one value on right hand side
rhs = state.visit(node.value).expect_one()
rhs = state.lower(node.value).expect_one()
current_frame = state.current_frame
match node:
case ast.Assign(targets=[ast.Name(lhs_name, ast.Store())], value=_):
Expand All @@ -35,27 +33,26 @@ def lower_Assign(self, state: LoweringState, node: ast.Assign) -> Result:
target_syntax = ", ".join(
ast.unparse(target) for target in node.targets
)
raise DialectLoweringError(f"unsupported target syntax {target_syntax}")
return Result() # python assign does not have value
raise lowering.BuildError(f"unsupported target syntax {target_syntax}")

def lower_Expr(self, state: LoweringState, node: ast.Expr) -> Result:
return state.visit(node.value)
def lower_Expr(self, state: lowering.State, node: ast.Expr):
return state.parent.visit(state, node.value)

def lower_Constant(self, state: LoweringState, node: ast.Constant) -> Result:
def lower_Constant(self, state: lowering.State, node: ast.Constant):
if isinstance(node.value, int):
stmt = stmts.ConstInt(value=node.value)
return Result(state.append_stmt(stmt))
return state.current_frame.push(stmt)
elif isinstance(node.value, float):
stmt = stmts.ConstFloat(value=node.value)
return Result(state.append_stmt(stmt))
return state.current_frame.push(stmt)
else:
raise DialectLoweringError(
raise lowering.BuildError(
f"unsupported QASM 2.0 constant type {type(node.value)}"
)

def lower_BinOp(self, state: LoweringState, node: ast.BinOp) -> Result:
lhs = state.visit(node.left).expect_one()
rhs = state.visit(node.right).expect_one()
def lower_BinOp(self, state: lowering.State, node: ast.BinOp):
lhs = state.lower(node.left).expect_one()
rhs = state.lower(node.right).expect_one()
if isinstance(node.op, ast.Add):
stmt = stmts.Add(lhs, rhs)
elif isinstance(node.op, ast.Sub):
Expand All @@ -67,9 +64,9 @@ def lower_BinOp(self, state: LoweringState, node: ast.BinOp) -> Result:
elif isinstance(node.op, ast.Pow):
stmt = stmts.Pow(lhs, rhs)
else:
raise DialectLoweringError(f"unsupported QASM 2.0 binop {node.op}")
raise lowering.BuildError(f"unsupported QASM 2.0 binop {node.op}")
stmt.result.type = self.__promote_binop_type(lhs, rhs)
return Result(state.append_stmt(stmt))
return state.current_frame.push(stmt)

def __promote_binop_type(
self, lhs: ir.SSAValue, rhs: ir.SSAValue
Expand All @@ -78,12 +75,12 @@ def __promote_binop_type(
return types.Float
return types.Int

def lower_UnaryOp(self, state: LoweringState, node: ast.UnaryOp) -> Result:
def lower_UnaryOp(self, state: lowering.State, node: ast.UnaryOp):
if isinstance(node.op, ast.USub):
value = state.visit(node.operand).expect_one()
value = state.lower(node.operand).expect_one()
stmt = stmts.Neg(value)
return Result(state.append_stmt(stmt))
return state.current_frame.push(stmt)
elif isinstance(node.op, ast.UAdd):
return state.visit(node.operand)
return state.lower(node.operand).expect_one()
else:
raise DialectLoweringError(f"unsupported QASM 2.0 unaryop {node.op}")
raise lowering.BuildError(f"unsupported QASM 2.0 unaryop {node.op}")
Loading