diff --git a/pyproject.toml b/pyproject.toml index d842e484..20367a5c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ requires-python = ">=3.10" dependencies = [ "numpy>=1.22.0", "scipy>=1.13.1", - "kirin-toolchain~=0.14.0", + "kirin-toolchain~=0.16.0", "rich>=13.9.4", "pydantic>=1.3.0,<2.11.0", "pandas>=2.2.3", @@ -137,3 +137,6 @@ dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" [tool.coverage.run] include = ["src/bloqade/*"] + +[tool.pytest.ini_options] +testpaths = "test/" \ No newline at end of file diff --git a/src/bloqade/noise/native/stmts.py b/src/bloqade/noise/native/stmts.py index 1e78e01a..d969e71c 100644 --- a/src/bloqade/noise/native/stmts.py +++ b/src/bloqade/noise/native/stmts.py @@ -1,4 +1,4 @@ -from kirin import ir, types +from kirin import ir, types, lowering from kirin.decl import info, statement from kirin.dialects import ilist @@ -10,7 +10,7 @@ @statement(dialect=dialect) class PauliChannel(ir.Statement): - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) px: float = info.attribute(types.Float) py: float = info.attribute(types.Float) @@ -24,7 +24,7 @@ class PauliChannel(ir.Statement): @statement(dialect=dialect) class CZPauliChannel(ir.Statement): - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) paired: bool = info.attribute(types.Bool) px_ctrl: float = info.attribute(types.Float) @@ -40,7 +40,7 @@ class CZPauliChannel(ir.Statement): @statement(dialect=dialect) class AtomLossChannel(ir.Statement): - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) prob: float = info.attribute(types.Float) qargs: ir.SSAValue = info.argument(ilist.IListType[QubitType]) diff --git a/src/bloqade/pyqrack/noise/native.py b/src/bloqade/pyqrack/noise/native.py index df51c504..940038ca 100644 --- a/src/bloqade/pyqrack/noise/native.py +++ b/src/bloqade/pyqrack/noise/native.py @@ -1,13 +1,10 @@ -from typing import TYPE_CHECKING, List +from typing import List from kirin import interp from bloqade.noise import native from bloqade.pyqrack import PyQrackInterpreter, reg -if TYPE_CHECKING: - from pyqrack import QrackSimulator - @native.dialect.register(key="pyqrack") class PyQrackMethods(interp.MethodTable): @@ -90,7 +87,7 @@ def atom_loss_channel( frame: interp.Frame, stmt: native.AtomLossChannel, ): - qargs: List[reg.PyQrackQubit["QrackSimulator"]] = frame.get(stmt.qargs) + qargs: List[reg.PyQrackQubit] = frame.get(stmt.qargs) active_qubits = (qarg for qarg in qargs if qarg.is_active()) diff --git a/src/bloqade/qasm2/dialects/core/__init__.py b/src/bloqade/qasm2/dialects/core/__init__.py index ec24d224..0c65c8a6 100644 --- a/src/bloqade/qasm2/dialects/core/__init__.py +++ b/src/bloqade/qasm2/dialects/core/__init__.py @@ -1,3 +1,3 @@ -from . import emit as emit, address as address, typeinfer as typeinfer +from . import _emit as _emit, address as address, _typeinfer as _typeinfer from .stmts import * # noqa: F403 from ._dialect import dialect as dialect diff --git a/src/bloqade/qasm2/dialects/core/emit.py b/src/bloqade/qasm2/dialects/core/_emit.py similarity index 100% rename from src/bloqade/qasm2/dialects/core/emit.py rename to src/bloqade/qasm2/dialects/core/_emit.py diff --git a/src/bloqade/qasm2/dialects/core/typeinfer.py b/src/bloqade/qasm2/dialects/core/_typeinfer.py similarity index 100% rename from src/bloqade/qasm2/dialects/core/typeinfer.py rename to src/bloqade/qasm2/dialects/core/_typeinfer.py diff --git a/src/bloqade/qasm2/dialects/core/stmts.py b/src/bloqade/qasm2/dialects/core/stmts.py index 661f97cf..7231bada 100644 --- a/src/bloqade/qasm2/dialects/core/stmts.py +++ b/src/bloqade/qasm2/dialects/core/stmts.py @@ -1,4 +1,4 @@ -from kirin import ir, types +from kirin import ir, types, lowering from kirin.decl import info, statement from bloqade.qasm2.types import BitType, CRegType, QRegType, QubitType @@ -11,7 +11,7 @@ class QRegNew(ir.Statement): """Create a new quantum register.""" name = "qreg.new" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) n_qubits: ir.SSAValue = info.argument(types.Int) """n_qubits: The number of qubits in the register.""" result: ir.ResultValue = info.result(QRegType) @@ -23,7 +23,7 @@ class CRegNew(ir.Statement): """Create a new classical register.""" name = "creg.new" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) n_bits: ir.SSAValue = info.argument(types.Int) """n_bits (Int): The number of bits in the register.""" result: ir.ResultValue = info.result(CRegType) @@ -35,7 +35,7 @@ class Reset(ir.Statement): """Reset a qubit to the |0> state.""" name = "reset" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) qarg: ir.SSAValue = info.argument(QubitType) """qarg (Qubit): The qubit to reset.""" @@ -45,7 +45,7 @@ class Measure(ir.Statement): """Measure a qubit and store the result in a bit.""" name = "measure" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) qarg: ir.SSAValue = info.argument(QubitType) """qarg (Qubit): The qubit to measure.""" carg: ir.SSAValue = info.argument(BitType) @@ -57,7 +57,7 @@ class CRegEq(ir.Statement): """Check if two classical registers are equal.""" name = "eq" - traits = frozenset({ir.Pure(), ir.FromPythonCall()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall()}) lhs: ir.SSAValue = info.argument(types.Int | CRegType | BitType) """lhs (CReg): The first register.""" rhs: ir.SSAValue = info.argument(types.Int | CRegType | BitType) @@ -71,7 +71,7 @@ class QRegGet(ir.Statement): """Get a qubit from a quantum register.""" name = "qreg.get" - traits = frozenset({ir.FromPythonCall(), ir.Pure()}) + traits = frozenset({lowering.FromPythonCall(), ir.Pure()}) reg: ir.SSAValue = info.argument(QRegType) """reg (QReg): The quantum register.""" idx: ir.SSAValue = info.argument(types.Int) @@ -85,7 +85,7 @@ class CRegGet(ir.Statement): """Get a bit from a classical register.""" name = "creg.get" - traits = frozenset({ir.FromPythonCall(), ir.Pure()}) + traits = frozenset({lowering.FromPythonCall(), ir.Pure()}) reg: ir.SSAValue = info.argument(CRegType) """reg (CReg): The classical register.""" idx: ir.SSAValue = info.argument(types.Int) diff --git a/src/bloqade/qasm2/dialects/expr/__init__.py b/src/bloqade/qasm2/dialects/expr/__init__.py index 06e9d475..15d0bbae 100644 --- a/src/bloqade/qasm2/dialects/expr/__init__.py +++ b/src/bloqade/qasm2/dialects/expr/__init__.py @@ -1,3 +1,3 @@ -from . import emit as emit, interp as interp, lowering as lowering +from . import _emit as _emit, _interp as _interp, _from_python as _from_python from .stmts import * # noqa: F403 from ._dialect import dialect as dialect diff --git a/src/bloqade/qasm2/dialects/expr/emit.py b/src/bloqade/qasm2/dialects/expr/_emit.py similarity index 100% rename from src/bloqade/qasm2/dialects/expr/emit.py rename to src/bloqade/qasm2/dialects/expr/_emit.py diff --git a/src/bloqade/qasm2/dialects/expr/lowering.py b/src/bloqade/qasm2/dialects/expr/_from_python.py similarity index 52% rename from src/bloqade/qasm2/dialects/expr/lowering.py rename to src/bloqade/qasm2/dialects/expr/_from_python.py index f34e2092..d987ce28 100644 --- a/src/bloqade/qasm2/dialects/expr/lowering.py +++ b/src/bloqade/qasm2/dialects/expr/_from_python.py @@ -1,31 +1,29 @@ import ast -from kirin import ir, types -from kirin.lowering import Result, FromPythonAST, LoweringState -from kirin.exceptions import DialectLoweringError +from kirin import ir, types, lowering from . import stmts from ._dialect import dialect @dialect.register -class QASMUopLowering(FromPythonAST): +class QASMUopLowering(lowering.FromPythonAST): - def lower_Name(self, state: LoweringState, node: ast.Name) -> Result: + def lower_Name(self, state: lowering.State, node: ast.Name): name = node.id if isinstance(node.ctx, ast.Load): value = state.current_frame.get(name) if value is None: - raise DialectLoweringError(f"{name} is not defined") - return Result(value) + raise lowering.BuildError(f"{name} is not defined") + return value elif isinstance(node.ctx, ast.Store): - raise DialectLoweringError("unhandled store operation") + raise lowering.BuildError("unhandled store operation") else: # Del - raise DialectLoweringError("unhandled del operation") + raise lowering.BuildError("unhandled del operation") - def lower_Assign(self, state: LoweringState, node: ast.Assign) -> Result: + def lower_Assign(self, state: lowering.State, node: ast.Assign): # NOTE: QASM only expects one value on right hand side - rhs = state.visit(node.value).expect_one() + rhs = state.lower(node.value).expect_one() current_frame = state.current_frame match node: case ast.Assign(targets=[ast.Name(lhs_name, ast.Store())], value=_): @@ -35,27 +33,26 @@ def lower_Assign(self, state: LoweringState, node: ast.Assign) -> Result: target_syntax = ", ".join( ast.unparse(target) for target in node.targets ) - raise DialectLoweringError(f"unsupported target syntax {target_syntax}") - return Result() # python assign does not have value + raise lowering.BuildError(f"unsupported target syntax {target_syntax}") - def lower_Expr(self, state: LoweringState, node: ast.Expr) -> Result: - return state.visit(node.value) + def lower_Expr(self, state: lowering.State, node: ast.Expr): + return state.parent.visit(state, node.value) - def lower_Constant(self, state: LoweringState, node: ast.Constant) -> Result: + def lower_Constant(self, state: lowering.State, node: ast.Constant): if isinstance(node.value, int): stmt = stmts.ConstInt(value=node.value) - return Result(state.append_stmt(stmt)) + return state.current_frame.push(stmt) elif isinstance(node.value, float): stmt = stmts.ConstFloat(value=node.value) - return Result(state.append_stmt(stmt)) + return state.current_frame.push(stmt) else: - raise DialectLoweringError( + raise lowering.BuildError( f"unsupported QASM 2.0 constant type {type(node.value)}" ) - def lower_BinOp(self, state: LoweringState, node: ast.BinOp) -> Result: - lhs = state.visit(node.left).expect_one() - rhs = state.visit(node.right).expect_one() + def lower_BinOp(self, state: lowering.State, node: ast.BinOp): + lhs = state.lower(node.left).expect_one() + rhs = state.lower(node.right).expect_one() if isinstance(node.op, ast.Add): stmt = stmts.Add(lhs, rhs) elif isinstance(node.op, ast.Sub): @@ -67,9 +64,9 @@ def lower_BinOp(self, state: LoweringState, node: ast.BinOp) -> Result: elif isinstance(node.op, ast.Pow): stmt = stmts.Pow(lhs, rhs) else: - raise DialectLoweringError(f"unsupported QASM 2.0 binop {node.op}") + raise lowering.BuildError(f"unsupported QASM 2.0 binop {node.op}") stmt.result.type = self.__promote_binop_type(lhs, rhs) - return Result(state.append_stmt(stmt)) + return state.current_frame.push(stmt) def __promote_binop_type( self, lhs: ir.SSAValue, rhs: ir.SSAValue @@ -78,12 +75,12 @@ def __promote_binop_type( return types.Float return types.Int - def lower_UnaryOp(self, state: LoweringState, node: ast.UnaryOp) -> Result: + def lower_UnaryOp(self, state: lowering.State, node: ast.UnaryOp): if isinstance(node.op, ast.USub): - value = state.visit(node.operand).expect_one() + value = state.lower(node.operand).expect_one() stmt = stmts.Neg(value) - return Result(state.append_stmt(stmt)) + return state.current_frame.push(stmt) elif isinstance(node.op, ast.UAdd): - return state.visit(node.operand) + return state.lower(node.operand).expect_one() else: - raise DialectLoweringError(f"unsupported QASM 2.0 unaryop {node.op}") + raise lowering.BuildError(f"unsupported QASM 2.0 unaryop {node.op}") diff --git a/src/bloqade/qasm2/dialects/expr/interp.py b/src/bloqade/qasm2/dialects/expr/_interp.py similarity index 100% rename from src/bloqade/qasm2/dialects/expr/interp.py rename to src/bloqade/qasm2/dialects/expr/_interp.py diff --git a/src/bloqade/qasm2/dialects/expr/stmts.py b/src/bloqade/qasm2/dialects/expr/stmts.py index dfdd92e7..4e9c0bf3 100644 --- a/src/bloqade/qasm2/dialects/expr/stmts.py +++ b/src/bloqade/qasm2/dialects/expr/stmts.py @@ -1,4 +1,4 @@ -from kirin import ir, types +from kirin import ir, types, lowering from kirin.decl import info, statement from kirin.print.printer import Printer from kirin.dialects.func.attrs import Signature @@ -50,7 +50,7 @@ class ConstInt(ir.Statement): """IR Statement representing a constant integer value.""" name = "constant.int" - traits = frozenset({ir.Pure(), ir.ConstantLike(), ir.FromPythonCall()}) + traits = frozenset({ir.Pure(), ir.ConstantLike(), lowering.FromPythonCall()}) value: int = info.attribute(types.Int) """value (int): The constant integer value.""" result: ir.ResultValue = info.result(types.Int) @@ -70,7 +70,7 @@ class ConstFloat(ir.Statement): """IR Statement representing a constant float value.""" name = "constant.float" - traits = frozenset({ir.Pure(), ir.ConstantLike(), ir.FromPythonCall()}) + traits = frozenset({ir.Pure(), ir.ConstantLike(), lowering.FromPythonCall()}) value: float = info.attribute(types.Float) """value (float): The constant float value.""" result: ir.ResultValue = info.result(types.Float) @@ -91,7 +91,7 @@ class ConstPI(ir.Statement): # this is marked as constant but not pure. name = "constant.pi" - traits = frozenset({ir.ConstantLike(), ir.FromPythonCall()}) + traits = frozenset({ir.ConstantLike(), lowering.FromPythonCall()}) result: ir.ResultValue = info.result(types.Float) """result (ConstPI): The result value.""" @@ -113,7 +113,7 @@ class Neg(ir.Statement): """Negate a number.""" name = "neg" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) value: ir.SSAValue = info.argument(PyNum) """value (Union[int, float]): The number to negate.""" result: ir.ResultValue = info.result(PyNum) @@ -125,7 +125,7 @@ class Sin(ir.Statement): """Take the sine of a number.""" name = "sin" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) value: ir.SSAValue = info.argument(PyNum) """value (Union[int, float]): The number to take the sine of.""" result: ir.ResultValue = info.result(PyNum) @@ -137,7 +137,7 @@ class Cos(ir.Statement): """Take the cosine of a number.""" name = "cos" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) value: ir.SSAValue = info.argument(PyNum) """value (Union[int, float]): The number to take the cosine of.""" result: ir.ResultValue = info.result(PyNum) @@ -149,7 +149,7 @@ class Tan(ir.Statement): """Take the tangent of a number.""" name = "tan" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) value: ir.SSAValue = info.argument(PyNum) """value (Union[int, float]): The number to take the tangent of.""" result: ir.ResultValue = info.result(PyNum) @@ -161,7 +161,7 @@ class Exp(ir.Statement): """Take the exponential of a number.""" name = "exp" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) value: ir.SSAValue = info.argument(PyNum) """value (Union[int, float]): The number to take the exponential of.""" result: ir.ResultValue = info.result(PyNum) @@ -173,7 +173,7 @@ class Log(ir.Statement): """Take the natural log of a number.""" name = "ln" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) value: ir.SSAValue = info.argument(PyNum) """value (Union[int, float]): The number to take the natural log of.""" result: ir.ResultValue = info.result(PyNum) @@ -185,7 +185,7 @@ class Sqrt(ir.Statement): """Take the square root of a number.""" name = "sqrt" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) value: ir.SSAValue = info.argument(PyNum) """value (Union[int, float]): The number to take the square root of.""" result: ir.ResultValue = info.result(PyNum) @@ -197,7 +197,7 @@ class Add(ir.Statement): """Add two numbers.""" name = "add" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) lhs: ir.SSAValue = info.argument(PyNum) """lhs (Union[int, float]): The left-hand side of the addition.""" rhs: ir.SSAValue = info.argument(PyNum) @@ -211,7 +211,7 @@ class Sub(ir.Statement): """Subtract two numbers.""" name = "sub" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) lhs: ir.SSAValue = info.argument(PyNum) """lhs (Union[int, float]): The left-hand side of the subtraction.""" rhs: ir.SSAValue = info.argument(PyNum) @@ -225,7 +225,7 @@ class Mul(ir.Statement): """Multiply two numbers.""" name = "mul" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) lhs: ir.SSAValue = info.argument(PyNum) """lhs (Union[int, float]): The left-hand side of the multiplication.""" rhs: ir.SSAValue = info.argument(PyNum) @@ -239,7 +239,7 @@ class Pow(ir.Statement): """Take the power of a number.""" name = "pow" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) lhs: ir.SSAValue = info.argument(PyNum) """lhs (Union[int, float]): The base.""" rhs: ir.SSAValue = info.argument(PyNum) @@ -253,7 +253,7 @@ class Div(ir.Statement): """Divide two numbers.""" name = "div" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) lhs: ir.SSAValue = info.argument(PyNum) """lhs (Union[int, float]): The numerator.""" rhs: ir.SSAValue = info.argument(PyNum) diff --git a/src/bloqade/qasm2/dialects/glob.py b/src/bloqade/qasm2/dialects/glob.py index eb9f308e..04defbfa 100644 --- a/src/bloqade/qasm2/dialects/glob.py +++ b/src/bloqade/qasm2/dialects/glob.py @@ -1,4 +1,4 @@ -from kirin import ir, types, interp +from kirin import ir, types, interp, lowering from kirin.decl import info, statement from kirin.dialects import ilist @@ -13,7 +13,7 @@ @statement(dialect=dialect) class UGate(ir.Statement): name = "ugate" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) registers: ir.SSAValue = info.argument(ilist.IListType[QRegType]) theta: ir.SSAValue = info.argument(types.Float) phi: ir.SSAValue = info.argument(types.Float) @@ -33,7 +33,8 @@ class GlobEmit(interp.MethodTable): @interp.impl(UGate) def ugate(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: UGate): registers = [ - emit.assert_node(ast.Name, reg) for reg in frame.get(stmt.registers) + emit.assert_node(ast.Name, reg) + for reg in frame.get_casted(stmt.registers, ilist.IList) ] theta = emit.assert_node(ast.Expr, frame.get(stmt.theta)) phi = emit.assert_node(ast.Expr, frame.get(stmt.phi)) diff --git a/src/bloqade/qasm2/dialects/indexing.py b/src/bloqade/qasm2/dialects/indexing.py index 9bc913f1..942232f1 100644 --- a/src/bloqade/qasm2/dialects/indexing.py +++ b/src/bloqade/qasm2/dialects/indexing.py @@ -7,9 +7,7 @@ import ast -from kirin import ir, types -from kirin.lowering import Result, FromPythonAST, LoweringState -from kirin.exceptions import DialectLoweringError +from kirin import ir, types, lowering from bloqade.qasm2.types import BitType, CRegType, QRegType, QubitType from bloqade.qasm2.dialects import core @@ -18,35 +16,35 @@ @dialect.register -class QASMCoreLowering(FromPythonAST): - def lower_Compare(self, state: LoweringState, node: ast.Compare) -> Result: - lhs = state.visit(node.left).expect_one() +class QASMCoreLowering(lowering.FromPythonAST): + def lower_Compare(self, state: lowering.State, node: ast.Compare): + lhs = state.lower(node.left).expect_one() if len(node.ops) != 1: - raise DialectLoweringError( + raise lowering.BuildError( "only one comparison operator and == is supported for qasm2 lowering" ) - rhs = state.visit(node.comparators[0]).expect_one() + rhs = state.lower(node.comparators[0]).expect_one() if isinstance(node.ops[0], ast.Eq): stmt = core.CRegEq(lhs, rhs) else: - raise DialectLoweringError( + raise lowering.BuildError( f"unsupported comparison operator {node.ops[0]} only Eq is supported." ) - return Result(state.append_stmt(stmt)) + return state.current_frame.push(stmt) - def lower_Subscript(self, state: LoweringState, node: ast.Subscript) -> Result: - value = state.visit(node.value).expect_one() - index = state.visit(node.slice).expect_one() + def lower_Subscript(self, state: lowering.State, node: ast.Subscript): + value = state.lower(node.value).expect_one() + index = state.lower(node.slice).expect_one() if not index.type.is_subseteq(types.Int): - raise DialectLoweringError( + raise lowering.BuildError( f"unsupported subscript index type {index.type}," " only integer indices are supported in QASM 2.0" ) if not isinstance(node.ctx, ast.Load): - raise DialectLoweringError( + raise lowering.BuildError( f"unsupported subscript context {node.ctx}," " cannot write to subscript in QASM 2.0" ) @@ -58,9 +56,9 @@ def lower_Subscript(self, state: LoweringState, node: ast.Subscript) -> Result: stmt = core.CRegGet(reg=value, idx=index) stmt.result.type = BitType else: - raise DialectLoweringError( + raise lowering.BuildError( f"unsupported subscript value type {value.type}," " only QReg and CReg are supported in QASM 2.0" ) - return Result(state.append_stmt(stmt)) + return state.current_frame.push(stmt) diff --git a/src/bloqade/qasm2/dialects/inline.py b/src/bloqade/qasm2/dialects/inline.py index b00d9b35..290ab3c1 100644 --- a/src/bloqade/qasm2/dialects/inline.py +++ b/src/bloqade/qasm2/dialects/inline.py @@ -10,37 +10,49 @@ from kirin import ir, types, lowering from kirin.decl import info, statement from kirin.print import Printer -from kirin.exceptions import DialectLoweringError dialect = ir.Dialect("qasm2.inline") @dataclass(frozen=True) -class InlineQASMLowering(ir.FromPythonCall): +class InlineQASMLowering(lowering.FromPythonCall): def lower( - self, stmt: type, state: lowering.LoweringState, node: ast.Call + self, stmt: type, state: lowering.State, node: ast.Call ) -> lowering.Result: from bloqade.qasm2.parse import loads - from bloqade.qasm2.parse.lowering import LoweringQASM + from bloqade.qasm2.parse.lowering import QASM2 if len(node.args) != 1 or node.keywords: - raise DialectLoweringError("InlineQASM takes 1 positional argument") + raise lowering.BuildError("InlineQASM takes 1 positional argument") text = node.args[0] # 1. string literal if isinstance(text, ast.Constant) and isinstance(text.value, str): value = text.value elif isinstance(text, ast.Name) and isinstance(text.ctx, ast.Load): - value = state.get_global(text.id).expect(str) + value = state.get_global(text).expect(str) else: - raise DialectLoweringError( + raise lowering.BuildError( "InlineQASM takes a string literal or global string" ) + from kirin.dialects import ilist + + from bloqade.qasm2.groups import main + from bloqade.qasm2.dialects import glob, noise, parallel + raw = textwrap.dedent(value) - qasm_lowering = LoweringQASM(state) - qasm_lowering.visit(loads(raw)) - return lowering.Result() + qasm_lowering = QASM2(main.union([ilist, glob, noise, parallel])) + region = qasm_lowering.run(loads(raw)) + for qasm_stmt in region.blocks[0].stmts: + qasm_stmt.detach() + state.current_frame.push(qasm_stmt) + + for block in region.blocks: + for qasm_stmt in block.stmts: + qasm_stmt.detach() + state.current_frame.push(qasm_stmt) + state.current_frame.jump_next_block() # NOTE: this is a dummy statement that won't appear in IR. diff --git a/src/bloqade/qasm2/dialects/noise.py b/src/bloqade/qasm2/dialects/noise.py index 962b2e03..56125a16 100644 --- a/src/bloqade/qasm2/dialects/noise.py +++ b/src/bloqade/qasm2/dialects/noise.py @@ -1,4 +1,4 @@ -from kirin import ir, types +from kirin import ir, types, lowering from kirin.decl import info, statement from bloqade.qasm2.types import QubitType @@ -9,7 +9,7 @@ @statement(dialect=dialect) class Pauli1(ir.Statement): name = "pauli_1" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) px: ir.SSAValue = info.argument(types.Float) py: ir.SSAValue = info.argument(types.Float) pz: ir.SSAValue = info.argument(types.Float) diff --git a/src/bloqade/qasm2/dialects/parallel.py b/src/bloqade/qasm2/dialects/parallel.py index 8313e572..216b4d91 100644 --- a/src/bloqade/qasm2/dialects/parallel.py +++ b/src/bloqade/qasm2/dialects/parallel.py @@ -1,6 +1,6 @@ from typing import Any -from kirin import ir, types, interp +from kirin import ir, types, interp, lowering from kirin.decl import info, statement from kirin.analysis import ForwardFrame from kirin.dialects import ilist @@ -18,7 +18,7 @@ @statement(dialect=dialect) class CZ(ir.Statement): name = "cz" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) ctrls: ir.SSAValue = info.argument(ilist.IListType[QubitType, N]) qargs: ir.SSAValue = info.argument(ilist.IListType[QubitType, N]) @@ -26,7 +26,7 @@ class CZ(ir.Statement): @statement(dialect=dialect) class UGate(ir.Statement): name = "u" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) qargs: ir.SSAValue = info.argument(ilist.IListType[QubitType]) theta: ir.SSAValue = info.argument(types.Float) phi: ir.SSAValue = info.argument(types.Float) @@ -36,7 +36,7 @@ class UGate(ir.Statement): @statement(dialect=dialect) class RZ(ir.Statement): name = "rz" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) qargs: ir.SSAValue = info.argument(ilist.IListType[QubitType]) theta: ir.SSAValue = info.argument(types.Float) diff --git a/src/bloqade/qasm2/dialects/uop/__init__.py b/src/bloqade/qasm2/dialects/uop/__init__.py index 30256fe5..7ab47ef6 100644 --- a/src/bloqade/qasm2/dialects/uop/__init__.py +++ b/src/bloqade/qasm2/dialects/uop/__init__.py @@ -1,4 +1,4 @@ -from . import emit as emit, stmts as stmts +from . import _emit as _emit, stmts as stmts from .stmts import * # noqa: F403 from ._dialect import dialect as dialect from .schedule import * # noqa: F403 diff --git a/src/bloqade/qasm2/dialects/uop/emit.py b/src/bloqade/qasm2/dialects/uop/_emit.py similarity index 100% rename from src/bloqade/qasm2/dialects/uop/emit.py rename to src/bloqade/qasm2/dialects/uop/_emit.py diff --git a/src/bloqade/qasm2/dialects/uop/stmts.py b/src/bloqade/qasm2/dialects/uop/stmts.py index f0740be3..755b50e5 100644 --- a/src/bloqade/qasm2/dialects/uop/stmts.py +++ b/src/bloqade/qasm2/dialects/uop/stmts.py @@ -1,4 +1,4 @@ -from kirin import ir +from kirin import ir, lowering from kirin.decl import info, statement from bloqade.qasm2.types import QubitType @@ -12,14 +12,14 @@ class SingleQubitGate(ir.Statement): """Base class for single qubit gates.""" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) qarg: ir.SSAValue = info.argument(QubitType) """qarg (Qubit): The qubit argument.""" @statement class TwoQubitCtrlGate(ir.Statement): - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) ctrl: ir.SSAValue = info.argument(QubitType) """ctrl (Qubit): The control qubit.""" qarg: ir.SSAValue = info.argument(QubitType) @@ -51,7 +51,7 @@ class Barrier(ir.Statement): """Apply the Barrier statement.""" name = "barrier" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) qargs: tuple[ir.SSAValue, ...] = info.argument(QubitType) """qargs: tuple of qubits to apply the barrier to.""" @@ -223,7 +223,7 @@ class CCX(ir.Statement): """Apply the doubly controlled X gate.""" name = "ccx" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) ctrl1: ir.SSAValue = info.argument(QubitType) """ctrl1 (Qubit): The first control qubit.""" ctrl2: ir.SSAValue = info.argument(QubitType) @@ -237,7 +237,7 @@ class CSwap(ir.Statement): """Apply the controlled swap gate.""" name = "ccx" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) ctrl: ir.SSAValue = info.argument(QubitType) """ctrl (Qubit): The control qubit.""" qarg1: ir.SSAValue = info.argument(QubitType) diff --git a/src/bloqade/qasm2/groups.py b/src/bloqade/qasm2/groups.py index 84361eff..280638c0 100644 --- a/src/bloqade/qasm2/groups.py +++ b/src/bloqade/qasm2/groups.py @@ -43,7 +43,7 @@ def run_pass( fold_pass(method) typeinfer_pass(method) - method.code.typecheck() + method.verify_type() return run_pass @@ -76,7 +76,7 @@ def run_pass( fold_pass(method) typeinfer_pass(method) - method.code.typecheck() + method.verify_type() return run_pass @@ -115,6 +115,6 @@ def run_pass( indexing_desugar_pass(mt) if typeinfer: typeinfer_pass(mt) # fix types after desugaring - mt.code.typecheck() + mt.verify_type() return run_pass diff --git a/src/bloqade/qasm2/parse/lowering.py b/src/bloqade/qasm2/parse/lowering.py index 4d18a64c..5b1d9d59 100644 --- a/src/bloqade/qasm2/parse/lowering.py +++ b/src/bloqade/qasm2/parse/lowering.py @@ -1,24 +1,100 @@ -from dataclasses import dataclass +from typing import Any +from dataclasses import field, dataclass -from kirin import ir, lowering +from kirin import ir, types, lowering from kirin.dialects import cf, func, ilist -from kirin.lowering import LoweringState -from kirin.exceptions import DialectLoweringError from bloqade.qasm2.types import CRegType, QRegType from bloqade.qasm2.dialects import uop, core, expr, glob, noise, parallel from . import ast -from .visitor import Visitor @dataclass -class LoweringQASM(Visitor[lowering.Result]): - state: LoweringState - extension: str | None = None +class QASM2(lowering.LoweringABC[ast.Node]): + max_lines: int = field(default=3, kw_only=True) + hint_indent: int = field(default=2, kw_only=True) + hint_show_lineno: bool = field(default=True, kw_only=True) + stacktrace: bool = field(default=True, kw_only=True) + + def run( + self, + stmt: ast.Node, + *, + source: str | None = None, + globals: dict[str, Any] | None = None, + file: str | None = None, + lineno_offset: int = 0, + col_offset: int = 0, + compactify: bool = True, + ) -> ir.Region: + # TODO: add source info + state = lowering.State( + self, + file=file, + lineno_offset=lineno_offset, + col_offset=col_offset, + ) + with state.frame( + [stmt], + globals=globals, + ) as frame: + try: + self.visit(state, stmt) + except lowering.BuildError as e: + hint = state.error_hint( + e, + max_lines=self.max_lines, + indent=self.hint_indent, + show_lineno=self.hint_show_lineno, + ) + if self.stacktrace: + raise Exception( + f"{e.args[0]}\n\n{hint}", + *e.args[1:], + ) from e + else: + e.args = (hint,) + raise e + + region = frame.curr_region + + if compactify: + from kirin.rewrite import Walk, CFGCompactify + + Walk(CFGCompactify()).rewrite(region) + return region + + def visit(self, state: lowering.State[ast.Node], node: ast.Node) -> lowering.Result: + name = node.__class__.__name__ + return getattr(self, f"visit_{name}", self.generic_visit)(state, node) + + def generic_visit( + self, state: lowering.State[ast.Node], node: ast.Node + ) -> lowering.Result: + if isinstance(node, ast.Node): + raise lowering.BuildError( + f"Cannot lower {node.__class__.__name__} node: {node}" + ) + raise lowering.BuildError( + f"Unexpected `{node.__class__.__name__}` node: {repr(node)} is not an AST node" + ) - def visit_MainProgram(self, node: ast.MainProgram) -> lowering.Result: - allowed = {dialect.name for dialect in self.state.dialects} + def lower_literal(self, state: lowering.State[ast.Node], value) -> ir.SSAValue: + if isinstance(value, int): + stmt = expr.ConstInt(value=value) + elif isinstance(value, float): + stmt = expr.ConstFloat(value=value) + state.current_frame.push(stmt) + return stmt.result + + def lower_global( + self, state: lowering.State[ast.Node], node: ast.Node + ) -> lowering.LoweringABC.Result: + raise lowering.BuildError("Global variables are not supported in QASM 2.0") + + def visit_MainProgram(self, state: lowering.State[ast.Node], node: ast.MainProgram): + allowed = {dialect.name for dialect in self.dialects} if isinstance(node.header, ast.OPENQASM) and node.header.version.major == 2: dialects = ["qasm2.core", "qasm2.uop", "qasm2.expr"] elif isinstance(node.header, ast.Kirin): @@ -26,106 +102,163 @@ def visit_MainProgram(self, node: ast.MainProgram) -> lowering.Result: for dialect in dialects: if dialect not in allowed: - raise DialectLoweringError( + raise lowering.BuildError( f"Dialect {dialect} not found, allowed: {', '.join(allowed)}" ) for stmt in node.statements: - self.visit(stmt) - return lowering.Result() + state.lower(stmt) - def visit_QReg(self, node: ast.QReg) -> lowering.Result: + def visit_QReg(self, state: lowering.State[ast.Node], node: ast.QReg): reg = core.QRegNew( - self.state.append_stmt(expr.ConstInt(value=node.size)).result + state.current_frame.push(expr.ConstInt(value=node.size)).result ) - self.state.append_stmt(reg) - self.state.current_frame.defs[node.name] = reg.result - return lowering.Result() + state.current_frame.push(reg) + state.current_frame.defs[node.name] = reg.result - def visit_CReg(self, node: ast.CReg) -> lowering.Result: + def visit_CReg(self, state: lowering.State[ast.Node], node: ast.CReg): reg = core.CRegNew( - self.state.append_stmt(expr.ConstInt(value=node.size)).result + state.current_frame.push(expr.ConstInt(value=node.size)).result ) - self.state.append_stmt(reg) - self.state.current_frame.defs[node.name] = reg.result - return lowering.Result() + state.current_frame.push(reg) + state.current_frame.defs[node.name] = reg.result - def visit_Barrier(self, node: ast.Barrier) -> lowering.Result: - self.state.append_stmt( + def visit_Barrier(self, state: lowering.State[ast.Node], node: ast.Barrier): + state.current_frame.push( uop.Barrier( - qargs=tuple(self.visit(qarg).expect_one() for qarg in node.qargs) + qargs=tuple(state.lower(qarg).expect_one() for qarg in node.qargs) ) ) - return lowering.Result() - def visit_CXGate(self, node: ast.CXGate) -> lowering.Result: - self.state.append_stmt( + def visit_CXGate(self, state: lowering.State[ast.Node], node: ast.CXGate): + state.current_frame.push( uop.CX( - ctrl=self.visit(node.ctrl).expect_one(), - qarg=self.visit(node.qarg).expect_one(), + ctrl=state.lower(node.ctrl).expect_one(), + qarg=state.lower(node.qarg).expect_one(), ) ) - return lowering.Result() - def visit_Measure(self, node: ast.Measure) -> lowering.Result: - self.state.append_stmt( + def visit_Measure(self, state: lowering.State[ast.Node], node: ast.Measure): + state.current_frame.push( core.Measure( - qarg=self.visit(node.qarg).expect_one(), - carg=self.visit(node.carg).expect_one(), + qarg=state.lower(node.qarg).expect_one(), + carg=state.lower(node.carg).expect_one(), ) ) - return lowering.Result() - def visit_UGate(self, node: ast.UGate) -> lowering.Result: - self.state.append_stmt( + def visit_UGate(self, state: lowering.State[ast.Node], node: ast.UGate): + state.current_frame.push( uop.UGate( - theta=self.visit(node.theta).expect_one(), - phi=self.visit(node.phi).expect_one(), - lam=self.visit(node.lam).expect_one(), - qarg=self.visit(node.qarg).expect_one(), + theta=state.lower(node.theta).expect_one(), + phi=state.lower(node.phi).expect_one(), + lam=state.lower(node.lam).expect_one(), + qarg=state.lower(node.qarg).expect_one(), ) ) - return lowering.Result() - def visit_Reset(self, node: ast.Reset) -> lowering.Result: - self.state.append_stmt(core.Reset(qarg=self.visit(node.qarg).expect_one())) - return lowering.Result() + def visit_Reset(self, state: lowering.State[ast.Node], node: ast.Reset): + state.current_frame.push(core.Reset(qarg=state.lower(node.qarg).expect_one())) - def visit_IfStmt(self, node: ast.IfStmt) -> lowering.Result: + # TODO: clean this up? copied from cf dialect with a small modification + def visit_IfStmt(self, state: lowering.State[ast.Node], node: ast.IfStmt): cond_stmt = core.CRegEq( - lhs=self.visit(node.cond.lhs).expect_one(), - rhs=self.visit(node.cond.rhs).expect_one(), + lhs=state.lower(node.cond.lhs).expect_one(), + rhs=state.lower(node.cond.rhs).expect_one(), ) - cond = self.state.append_stmt(cond_stmt).result - frame = self.state.current_frame + cond = state.current_frame.push(cond_stmt).result + frame = state.current_frame before_block = frame.curr_block - if frame.exit_block is None: - raise DialectLoweringError("code block is not exiting") - else: - before_block_next = frame.exit_block - if_block = self.state.current_frame.append_block() - for stmt in node.body: - self.visit(stmt) - if_block.stmts.append( + with state.frame(node.body, region=frame.curr_region) as if_frame: + true_cond = if_frame.entr_block.args.append_from(types.Bool, cond.name) + if cond.name: + if_frame.defs[cond.name] = true_cond + + if_frame.exhaust() + self.branch_next_if_not_terminated(if_frame) + + with state.frame([], region=frame.curr_region) as else_frame: + true_cond = else_frame.entr_block.args.append_from(types.Bool, cond.name) + if cond.name: + else_frame.defs[cond.name] = true_cond + else_frame.exhaust() + self.branch_next_if_not_terminated(else_frame) + + with state.frame(frame.stream.split(), region=frame.curr_region) as after_frame: + after_frame.defs.update(frame.defs) + phi: set[str] = set() + for name in if_frame.defs.keys(): + if frame.get(name): + phi.add(name) + elif name in else_frame.defs: + phi.add(name) + + for name in else_frame.defs.keys(): + if frame.get(name): # not defined in if_frame + phi.add(name) + + for name in phi: + after_frame.defs[name] = after_frame.entr_block.args.append_from( + types.Any, name + ) + + after_frame.exhaust() + self.branch_next_if_not_terminated(after_frame) + after_frame.next_block.stmts.append( + cf.Branch(arguments=(), successor=frame.next_block) + ) + + if_args = [] + for name in phi: + if value := if_frame.get(name): + if_args.append(value) + else: + raise lowering.BuildError(f"undefined variable {name} in if branch") + + else_args = [] + for name in phi: + if value := else_frame.get(name): + else_args.append(value) + else: + raise lowering.BuildError(f"undefined variable {name} in else branch") + + if_frame.next_block.stmts.append( cf.Branch( - arguments=(), - successor=before_block_next, + arguments=tuple(if_args), + successor=after_frame.entr_block, + ) + ) + else_frame.next_block.stmts.append( + cf.Branch( + arguments=tuple(else_args), + successor=after_frame.entr_block, ) ) before_block.stmts.append( cf.ConditionalBranch( cond=cond, - then_arguments=(), - then_successor=if_block, - else_arguments=(), - else_successor=before_block_next, + then_arguments=(cond,), + then_successor=if_frame.entr_block, + else_arguments=(cond,), + else_successor=else_frame.entr_block, ) ) - frame.curr_block = before_block_next - return lowering.Result() + frame.defs.update(after_frame.defs) + frame.jump_next_block() + + def branch_next_if_not_terminated(self, frame: lowering.Frame): + """Branch to the next block if the current block is not terminated. + + This must be used after exhausting the current frame and before popping the frame. + """ + if not frame.curr_block.last_stmt or not frame.curr_block.last_stmt.has_trait( + ir.IsTerminator + ): + frame.curr_block.stmts.append( + cf.Branch(arguments=(), successor=frame.next_block) + ) - def visit_BinOp(self, node: ast.BinOp) -> lowering.Result: + def visit_BinOp(self, state: lowering.State[ast.Node], node: ast.BinOp): if node.op == "+": stmt_type = expr.Add elif node.op == "-": @@ -135,295 +268,286 @@ def visit_BinOp(self, node: ast.BinOp) -> lowering.Result: else: stmt_type = expr.Div - stmt = self.state.append_stmt( + return state.current_frame.push( stmt_type( - lhs=self.visit(node.lhs).expect_one(), - rhs=self.visit(node.rhs).expect_one(), + lhs=state.lower(node.lhs).expect_one(), + rhs=state.lower(node.rhs).expect_one(), ) ) - return lowering.Result(stmt.result) - def visit_UnaryOp(self, node: ast.UnaryOp) -> lowering.Result: + def visit_UnaryOp(self, state: lowering.State[ast.Node], node: ast.UnaryOp): if node.op == "-": - stmt = expr.Neg(value=self.visit(node.operand).expect_one()) - return lowering.Result(stmt.result) + stmt = expr.Neg(value=state.lower(node.operand).expect_one()) + return stmt.result else: - return lowering.Result(self.visit(node.operand).expect_one()) + return state.lower(node.operand).expect_one() - def visit_Bit(self, node: ast.Bit) -> lowering.Result: - if node.name.id not in self.state.current_frame.defs: + def visit_Bit(self, state: lowering.State[ast.Node], node: ast.Bit): + if node.name.id not in state.current_frame.defs: raise ValueError(f"Bit {node.name} not found") - addr = self.state.append_stmt(expr.ConstInt(value=node.addr)) - reg = self.state.current_frame.get_local(node.name.id) + addr = state.current_frame.push(expr.ConstInt(value=node.addr)) + reg = state.current_frame.get_local(node.name.id) if reg is None: - raise DialectLoweringError(f"{node.name.id} is not defined") + raise lowering.BuildError(f"{node.name.id} is not defined") if reg.type.is_subseteq(QRegType): stmt = core.QRegGet(reg, addr.result) elif reg.type.is_subseteq(CRegType): stmt = core.CRegGet(reg, addr.result) - return lowering.Result(self.state.append_stmt(stmt).result) + return state.current_frame.push(stmt).result - def visit_Call(self, node: ast.Call) -> lowering.Result: + def visit_Call(self, state: lowering.State[ast.Node], node: ast.Call): if node.name == "cos": - stmt = expr.Cos(self.visit(node.args[0]).expect_one()) + stmt = expr.Cos(state.lower(node.args[0]).expect_one()) elif node.name == "sin": - stmt = expr.Sin(self.visit(node.args[0]).expect_one()) + stmt = expr.Sin(state.lower(node.args[0]).expect_one()) elif node.name == "tan": - stmt = expr.Tan(self.visit(node.args[0]).expect_one()) + stmt = expr.Tan(state.lower(node.args[0]).expect_one()) elif node.name == "exp": - stmt = expr.Exp(self.visit(node.args[0]).expect_one()) + stmt = expr.Exp(state.lower(node.args[0]).expect_one()) elif node.name == "log": - stmt = expr.Log(self.visit(node.args[0]).expect_one()) + stmt = expr.Log(state.lower(node.args[0]).expect_one()) elif node.name == "sqrt": - stmt = expr.Sqrt(self.visit(node.args[0]).expect_one()) + stmt = expr.Sqrt(state.lower(node.args[0]).expect_one()) else: raise ValueError(f"Unknown function {node.name}") - self.state.append_stmt(stmt) - return lowering.Result(stmt.result) + state.current_frame.push(stmt) + return stmt.result + + def visit_Name(self, state: lowering.State[ast.Node], node: ast.Name): + if (value := state.current_frame.get_local(node.id)) is not None: + return value + raise ValueError(f"name {node.id} not found") + + def visit_ParaCZGate(self, state: lowering.State[ast.Node], node: ast.ParaCZGate): + ctrls: list[ir.SSAValue] = [] + qargs: list[ir.SSAValue] = [] + for pair in node.qargs: + if len(pair) != 2: + raise ValueError("CZ gate requires exactly two qargs") + ctrl, qarg = pair + ctrls.append(state.lower(ctrl).expect_one()) + qargs.append(state.lower(qarg).expect_one()) + + ctrls_stmt = ilist.New(values=ctrls) + qargs_stmt = ilist.New(values=qargs) + state.current_frame.push(ctrls_stmt) + state.current_frame.push(qargs_stmt) + state.current_frame.push( + parallel.CZ(ctrls=ctrls_stmt.result, qargs=qargs_stmt.result) + ) + + def visit_ParaRZGate(self, state: lowering.State[ast.Node], node: ast.ParaRZGate): + qargs: list[ir.SSAValue] = [] + for pair in node.qargs: + if len(pair) != 1: + raise ValueError("Rz gate requires exactly one qarg") + qargs.append(state.lower(pair[0]).expect_one()) + + qargs_stmt = ilist.New(values=qargs) + state.current_frame.push(qargs_stmt) + state.current_frame.push( + parallel.RZ( + theta=state.lower(node.theta).expect_one(), + qargs=qargs_stmt.result, + ) + ) + + def visit_ParaU3Gate(self, state: lowering.State[ast.Node], node: ast.ParaU3Gate): + qargs: list[ir.SSAValue] = [] + for pair in node.qargs: + if len(pair) != 1: + raise ValueError("U3 gate requires exactly one qarg") + qargs.append(state.lower(pair[0]).expect_one()) + + qargs_stmt = ilist.New(values=qargs) + state.current_frame.push(qargs_stmt) + state.current_frame.push( + parallel.UGate( + theta=state.lower(node.theta).expect_one(), + phi=state.lower(node.phi).expect_one(), + lam=state.lower(node.lam).expect_one(), + qargs=qargs_stmt.result, + ) + ) + + def visit_GlobUGate(self, state: lowering.State[ast.Node], node: ast.GlobUGate): + registers: list[ir.SSAValue] = [] + + for register in node.registers: # These will all be ast.Names + registers.append(state.lower(register).expect_one()) + + registers_stmt = ilist.New(values=registers) + state.current_frame.push(registers_stmt) + state.current_frame.push( + # all the stuff going into the args should be SSA values + glob.UGate( + registers=registers_stmt.result, # expect_one = a singular SSA value + theta=state.lower(node.theta).expect_one(), + phi=state.lower(node.phi).expect_one(), + lam=state.lower(node.lam).expect_one(), + ) + ) + + def visit_NoisePAULI1(self, state: lowering.State[ast.Node], node: ast.NoisePAULI1): + state.current_frame.push( + noise.Pauli1( + px=state.lower(node.px).expect_one(), + py=state.lower(node.py).expect_one(), + pz=state.lower(node.pz).expect_one(), + qarg=state.lower(node.qarg).expect_one(), + ) + ) + + def visit_Number(self, state: lowering.State[ast.Node], node: ast.Number): + if isinstance(node.value, int): + stmt = expr.ConstInt(value=node.value) + else: + stmt = expr.ConstFloat(value=node.value) + state.current_frame.push(stmt) + return stmt - def visit_Instruction(self, node: ast.Instruction) -> lowering.Result: - params = [self.visit(param).expect_one() for param in node.params] - qargs = [self.visit(qarg).expect_one() for qarg in node.qargs] - visit_inst = getattr(self, "visit_instruction_" + node.name.id, None) + def visit_Pi(self, state: lowering.State[ast.Node], node: ast.Pi): + return state.current_frame.push(expr.ConstPI()).result + + def visit_Include(self, state: lowering.State[ast.Node], node: ast.Include): + if node.filename not in ["qelib1.inc"]: + raise lowering.BuildError(f"Include {node.filename} not found") + + def visit_Gate(self, state: lowering.State[ast.Node], node: ast.Gate): + raise NotImplementedError("Gate lowering not supported") + + def visit_Instruction(self, state: lowering.State[ast.Node], node: ast.Instruction): + params = [state.lower(param).expect_one() for param in node.params] + qargs = [state.lower(qarg).expect_one() for qarg in node.qargs] + visit_inst = getattr(self, "visit_Instruction_" + node.name.id, None) if visit_inst is not None: - self.state.append_stmt(visit_inst(node, params, qargs)) + state.current_frame.push(visit_inst(params, qargs)) else: - value = self.state.get_global(node.name.id).expect(ir.Method) + value = state.get_global(node.name).expect(ir.Method) # NOTE: QASM expects the return type to be known at call site if value.return_type is None: raise ValueError(f"Unknown return type for {node.name.id}") - self.state.append_stmt( + state.current_frame.push( func.Invoke( callee=value, inputs=tuple(params + qargs), kwargs=tuple(), ) ) - return lowering.Result() - def visit_instruction_id(self, node: ast.Instruction, params, qargs): + def visit_Instruction_id(self, params, qargs): return uop.Id(qarg=qargs[0]) - def visit_instruction_x(self, node: ast.Instruction, params, qargs): + def visit_Instruction_x(self, params, qargs): return uop.X(qarg=qargs[0]) - def visit_instruction_y(self, node: ast.Instruction, params, qargs): + def visit_Instruction_y(self, params, qargs): return uop.Y(qarg=qargs[0]) - def visit_instruction_z(self, node: ast.Instruction, params, qargs): + def visit_Instruction_z(self, params, qargs): return uop.Z(qarg=qargs[0]) - def visit_instruction_h(self, node: ast.Instruction, params, qargs): + def visit_Instruction_h(self, params, qargs): return uop.H(qarg=qargs[0]) - def visit_instruction_s(self, node: ast.Instruction, params, qargs): + def visit_Instruction_s(self, params, qargs): return uop.S(qarg=qargs[0]) - def visit_instruction_sdg(self, node: ast.Instruction, params, qargs): + def visit_Instruction_sdg(self, params, qargs): return uop.Sdag(qarg=qargs[0]) - def visit_instruction_sx(self, node: ast.Instruction, params, qargs): + def visit_Instruction_sx(self, params, qargs): return uop.SX(qarg=qargs[0]) - def visit_instruction_sxdg(self, node: ast.Instruction, params, qargs): + def visit_Instruction_sxdg(self, params, qargs): return uop.SXdag(qarg=qargs[0]) - def visit_instruction_t(self, node: ast.Instruction, params, qargs): + def visit_Instruction_t(self, params, qargs): return uop.T(qarg=qargs[0]) - def visit_instruction_tdg(self, node: ast.Instruction, params, qargs): + def visit_Instruction_tdg(self, params, qargs): return uop.Tdag(qarg=qargs[0]) - def visit_instruction_rx(self, node: ast.Instruction, params, qargs): + def visit_Instruction_rx(self, params, qargs): return uop.RX(theta=params[0], qarg=qargs[0]) - def visit_instruction_ry(self, node: ast.Instruction, params, qargs): + def visit_Instruction_ry(self, params, qargs): return uop.RY(theta=params[0], qarg=qargs[0]) - def visit_instruction_rz(self, node: ast.Instruction, params, qargs): + def visit_Instruction_rz(self, params, qargs): return uop.RZ(theta=params[0], qarg=qargs[0]) - def visit_instruction_p(self, node: ast.Instruction, params, qargs): + def visit_Instruction_p(self, params, qargs): return uop.U1(lam=params[0], qarg=qargs[0]) - def visit_instruction_u(self, node: ast.Instruction, params, qargs): + def visit_Instruction_u(self, params, qargs): return uop.UGate(theta=params[0], phi=params[1], lam=params[2], qarg=qargs[0]) - def visit_instruction_u1(self, node: ast.Instruction, params, qargs): + def visit_Instruction_u1(self, params, qargs): return uop.U1(lam=params[0], qarg=qargs[0]) - def visit_instruction_u2(self, node: ast.Instruction, params, qargs): + def visit_Instruction_u2(self, params, qargs): return uop.U2(phi=params[0], lam=params[1], qarg=qargs[0]) - def visit_instruction_u3(self, node: ast.Instruction, params, qargs): + def visit_Instruction_u3(self, params, qargs): return uop.UGate(theta=params[0], phi=params[1], lam=params[2], qarg=qargs[0]) - def visit_instruction_CX(self, node: ast.Instruction, params, qargs): + def visit_Instruction_CX(self, params, qargs): return uop.CX(ctrl=qargs[0], qarg=qargs[1]) - def visit_instruction_cx(self, node: ast.Instruction, params, qargs): + def visit_Instruction_cx(self, params, qargs): return uop.CX(ctrl=qargs[0], qarg=qargs[1]) - def visit_instruction_cy(self, node: ast.Instruction, params, qargs): + def visit_Instruction_cy(self, params, qargs): return uop.CY(ctrl=qargs[0], qarg=qargs[1]) - def visit_instruction_cz(self, node: ast.Instruction, params, qargs): + def visit_Instruction_cz(self, params, qargs): return uop.CZ(ctrl=qargs[0], qarg=qargs[1]) - def visit_instruction_ch(self, node: ast.Instruction, params, qargs): + def visit_Instruction_ch(self, params, qargs): return uop.CH(ctrl=qargs[0], qarg=qargs[1]) - def visit_instruction_crx(self, node: ast.Instruction, params, qargs): + def visit_Instruction_crx(self, params, qargs): return uop.CRX(lam=params[0], ctrl=qargs[0], qarg=qargs[1]) - def visit_instruction_cry(self, node: ast.Instruction, params, qargs): + def visit_Instruction_cry(self, params, qargs): return uop.CRY(lam=params[0], ctrl=qargs[0], qarg=qargs[1]) - def visit_instruction_crz(self, node: ast.Instruction, params, qargs): + def visit_Instruction_crz(self, params, qargs): return uop.CRZ(lam=params[0], ctrl=qargs[0], qarg=qargs[1]) - def visit_instruction_ccx(self, node: ast.Instruction, params, qargs): + def visit_Instruction_ccx(self, params, qargs): return uop.CCX(ctrl1=qargs[0], ctrl2=qargs[1], qarg=qargs[2]) - def visit_instruction_csx(self, node: ast.Instruction, params, qargs): + def visit_Instruction_csx(self, params, qargs): return uop.CSX(ctrl=qargs[0], qarg=qargs[1]) - def visit_instruction_cswap(self, node: ast.Instruction, params, qargs): + def visit_Instruction_cswap(self, params, qargs): return uop.CSwap(ctrl=qargs[0], qarg1=qargs[1], qarg2=qargs[2]) - def visit_instruction_cp(self, node: ast.Instruction, params, qargs): + def visit_Instruction_cp(self, params, qargs): return uop.CU1(lam=params[0], ctrl=qargs[0], qarg=qargs[1]) - def visit_instruction_cu1(self, node: ast.Instruction, params, qargs): + def visit_Instruction_cu1(self, params, qargs): return uop.CU1(lam=params[0], ctrl=qargs[0], qarg=qargs[1]) - def visit_instruction_cu3(self, node: ast.Instruction, params, qargs): + def visit_Instruction_cu3(self, params, qargs): return uop.CU3( theta=params[0], phi=params[1], lam=params[2], ctrl=qargs[0], qarg=qargs[1] ) - def visit_instruction_cu(self, node: ast.Instruction, params, qargs): + def visit_Instruction_cu(self, params, qargs): return uop.CU3( theta=params[0], phi=params[1], lam=params[2], ctrl=qargs[0], qarg=qargs[1] ) - def visit_instruction_rxx(self, node: ast.Instruction, params, qargs): + def visit_Instruction_rxx(self, params, qargs): return uop.RXX(theta=params[0], ctrl=qargs[0], qarg=qargs[1]) - def visit_instruction_rzz(self, node: ast.Instruction, params, qargs): + def visit_Instruction_rzz(self, params, qargs): return uop.RZZ(theta=params[0], ctrl=qargs[0], qarg=qargs[1]) - def visit_instruction_swap(self, node: ast.Instruction, params, qargs): + def visit_Instruction_swap(self, params, qargs): return uop.Swap(ctrl=qargs[0], qarg=qargs[1]) - - def visit_Number(self, node: ast.Number) -> lowering.Result: - if isinstance(node.value, int): - stmt = expr.ConstInt(value=node.value) - else: - stmt = expr.ConstFloat(value=node.value) - return lowering.Result(self.state.append_stmt(stmt).result) - - def visit_Pi(self, node: ast.Pi) -> lowering.Result: - return lowering.Result(self.state.append_stmt(expr.ConstPI()).result) - - def visit_Include(self, node: ast.Include) -> lowering.Result: - if node.filename not in ["qelib1.inc"]: - raise DialectLoweringError(f"Include {node.filename} not found") - - return lowering.Result() - - def visit_Gate(self, node: ast.Gate) -> lowering.Result: - raise NotImplementedError("Gate lowering not supported") - - def visit_Name(self, node: ast.Name) -> lowering.Result: - if (value := self.state.current_frame.get_local(node.id)) is not None: - return lowering.Result(value) - raise ValueError(f"name {node.id} not found") - - def visit_ParaCZGate(self, node: ast.ParaCZGate) -> lowering.Result: - ctrls: list[ir.SSAValue] = [] - qargs: list[ir.SSAValue] = [] - for pair in node.qargs: - if len(pair) != 2: - raise ValueError("CZ gate requires exactly two qargs") - ctrl, qarg = pair - ctrls.append(self.visit(ctrl).expect_one()) - qargs.append(self.visit(qarg).expect_one()) - - ctrls_stmt = ilist.New(values=ctrls) - qargs_stmt = ilist.New(values=qargs) - self.state.append_stmt(ctrls_stmt) - self.state.append_stmt(qargs_stmt) - self.state.append_stmt( - parallel.CZ(ctrls=ctrls_stmt.result, qargs=qargs_stmt.result) - ) - return lowering.Result() - - def visit_ParaRZGate(self, node: ast.ParaRZGate) -> lowering.Result: - qargs: list[ir.SSAValue] = [] - for pair in node.qargs: - if len(pair) != 1: - raise ValueError("Rz gate requires exactly one qarg") - qargs.append(self.visit(pair[0]).expect_one()) - - qargs_stmt = ilist.New(values=qargs) - self.state.append_stmt(qargs_stmt) - self.state.append_stmt( - parallel.RZ( - theta=self.visit(node.theta).expect_one(), - qargs=qargs_stmt.result, - ) - ) - return lowering.Result() - - def visit_ParaU3Gate(self, node: ast.ParaU3Gate) -> lowering.Result: - qargs: list[ir.SSAValue] = [] - for pair in node.qargs: - if len(pair) != 1: - raise ValueError("U3 gate requires exactly one qarg") - qargs.append(self.visit(pair[0]).expect_one()) - - qargs_stmt = ilist.New(values=qargs) - self.state.append_stmt(qargs_stmt) - self.state.append_stmt( - parallel.UGate( - theta=self.visit(node.theta).expect_one(), - phi=self.visit(node.phi).expect_one(), - lam=self.visit(node.lam).expect_one(), - qargs=qargs_stmt.result, - ) - ) - return lowering.Result() - - def visit_GlobUGate(self, node: ast.GlobUGate) -> lowering.Result: - - registers: list[ir.SSAValue] = [] - - for register in node.registers: # These will all be ast.Names - registers.append(self.visit(register).expect_one()) - - registers_stmt = ilist.New(values=registers) - self.state.append_stmt(registers_stmt) - self.state.append_stmt( - # all the stuff going into the args should be SSA values - glob.UGate( - registers=registers_stmt.result, # expect_one = a singular SSA value - theta=self.visit(node.theta).expect_one(), - phi=self.visit(node.phi).expect_one(), - lam=self.visit(node.lam).expect_one(), - ) - ) - return lowering.Result() - - def visit_NoisePAULI1(self, node: ast.NoisePAULI1) -> lowering.Result: - self.state.append_stmt( - noise.Pauli1( - px=self.visit(node.px).expect_one(), - py=self.visit(node.py).expect_one(), - pz=self.visit(node.pz).expect_one(), - qarg=self.visit(node.qarg).expect_one(), - ) - ) - return lowering.Result() diff --git a/src/bloqade/qasm2/passes/py2qasm.py b/src/bloqade/qasm2/passes/py2qasm.py index b87c52e0..a1478864 100644 --- a/src/bloqade/qasm2/passes/py2qasm.py +++ b/src/bloqade/qasm2/passes/py2qasm.py @@ -32,11 +32,12 @@ class _Py2QASM(RewriteRule): def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: if isinstance(node, py.Constant): - if isinstance(node.value, int): - node.replace_by(expr.ConstInt(value=node.value)) + value = node.value.unwrap() + if isinstance(value, int): + node.replace_by(expr.ConstInt(value=value)) return RewriteResult(has_done_something=True) - elif isinstance(node.value, float): - node.replace_by(expr.ConstFloat(value=node.value)) + elif isinstance(value, float): + node.replace_by(expr.ConstFloat(value=value)) return RewriteResult(has_done_something=True) elif isinstance(node, py.BinOp): if (pystmt := self.BINARY_OPS.get(type(node))) is not None: diff --git a/src/bloqade/qasm2/passes/qasm2py.py b/src/bloqade/qasm2/passes/qasm2py.py index 9b0fa794..e052d891 100644 --- a/src/bloqade/qasm2/passes/qasm2py.py +++ b/src/bloqade/qasm2/passes/qasm2py.py @@ -17,11 +17,11 @@ class _QASM2Py(RewriteRule): UNARY_OPS = { expr.Neg: py.USub, - expr.Sin: math.sin, - expr.Cos: math.cos, - expr.Tan: math.tan, - expr.Exp: math.exp, - expr.Sqrt: math.sqrt, + expr.Sin: math.stmts.sin, + expr.Cos: math.stmts.cos, + expr.Tan: math.stmts.tan, + expr.Exp: math.stmts.exp, + expr.Sqrt: math.stmts.sqrt, } BINARY_OPS = { diff --git a/src/bloqade/qbraid/lowering.py b/src/bloqade/qbraid/lowering.py index 45fed722..0d53ac8e 100644 --- a/src/bloqade/qbraid/lowering.py +++ b/src/bloqade/qbraid/lowering.py @@ -29,7 +29,7 @@ def run_pass( fold_pass(method) typeinfer_pass(method) - method.code.typecheck() + method.code.verify_type() return run_pass diff --git a/src/bloqade/squin/op/stmts.py b/src/bloqade/squin/op/stmts.py index dcffd93a..9a2437f4 100644 --- a/src/bloqade/squin/op/stmts.py +++ b/src/bloqade/squin/op/stmts.py @@ -1,4 +1,4 @@ -from kirin import ir, types +from kirin import ir, types, lowering from kirin.decl import info, statement from .types import OpType @@ -31,19 +31,19 @@ class BinaryOp(CompositeOp): @statement(dialect=dialect) class Kron(BinaryOp): - traits = frozenset({ir.Pure(), ir.FromPythonCall(), MaybeUnitary()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()}) is_unitary: bool = info.attribute(default=False) @statement(dialect=dialect) class Mult(BinaryOp): - traits = frozenset({ir.Pure(), ir.FromPythonCall(), MaybeUnitary()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()}) is_unitary: bool = info.attribute(default=False) @statement(dialect=dialect) class Adjoint(CompositeOp): - traits = frozenset({ir.Pure(), ir.FromPythonCall(), MaybeUnitary()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()}) is_unitary: bool = info.attribute(default=False) op: ir.SSAValue = info.argument(OpType) result: ir.ResultValue = info.result(OpType) @@ -51,7 +51,7 @@ class Adjoint(CompositeOp): @statement(dialect=dialect) class Scale(CompositeOp): - traits = frozenset({ir.Pure(), ir.FromPythonCall(), MaybeUnitary()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()}) is_unitary: bool = info.attribute(default=False) op: ir.SSAValue = info.argument(OpType) factor: ir.SSAValue = info.argument(Complex) @@ -60,7 +60,7 @@ class Scale(CompositeOp): @statement(dialect=dialect) class Control(CompositeOp): - traits = frozenset({ir.Pure(), ir.FromPythonCall(), MaybeUnitary()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()}) is_unitary: bool = info.attribute(default=False) op: ir.SSAValue = info.argument(OpType) n_controls: int = info.attribute() @@ -69,7 +69,7 @@ class Control(CompositeOp): @statement(dialect=dialect) class Rot(CompositeOp): - traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary()}) axis: ir.SSAValue = info.argument(OpType) angle: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(OpType) @@ -77,7 +77,7 @@ class Rot(CompositeOp): @statement(dialect=dialect) class Identity(CompositeOp): - traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), HasSites()}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), HasSites()}) sites: int = info.attribute() result: ir.ResultValue = info.result(OpType) @@ -85,7 +85,7 @@ class Identity(CompositeOp): @statement class ConstantOp(PrimitiveOp): traits = frozenset( - {ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), FixedSites(1)} + {ir.Pure(), lowering.FromPythonCall(), ir.ConstantLike(), FixedSites(1)} ) result: ir.ResultValue = info.result(OpType) @@ -93,7 +93,13 @@ class ConstantOp(PrimitiveOp): @statement class ConstantUnitary(ConstantOp): traits = frozenset( - {ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Unitary(), FixedSites(1)} + { + ir.Pure(), + lowering.FromPythonCall(), + ir.ConstantLike(), + Unitary(), + FixedSites(1), + } ) @@ -107,7 +113,7 @@ class PhaseOp(PrimitiveOp): $$ """ - traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), FixedSites(1)}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), FixedSites(1)}) theta: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(OpType) @@ -122,7 +128,7 @@ class ShiftOp(PrimitiveOp): $$ """ - traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), FixedSites(1)}) + traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), FixedSites(1)}) theta: ir.SSAValue = info.argument(types.Float) result: ir.ResultValue = info.result(OpType) diff --git a/src/bloqade/squin/qubit.py b/src/bloqade/squin/qubit.py index 9ed80d31..f54882f6 100644 --- a/src/bloqade/squin/qubit.py +++ b/src/bloqade/squin/qubit.py @@ -9,7 +9,7 @@ from typing import Any -from kirin import ir, types +from kirin import ir, types, lowering from kirin.decl import info, statement from kirin.dialects import ilist from kirin.lowering import wraps @@ -22,28 +22,28 @@ @statement(dialect=dialect) class New(ir.Statement): - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) n_qubits: ir.SSAValue = info.argument(types.Int) result: ir.ResultValue = info.result(ilist.IListType[QubitType, types.Any]) @statement(dialect=dialect) class Apply(ir.Statement): - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) operator: ir.SSAValue = info.argument(OpType) qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType]) @statement(dialect=dialect) class Measure(ir.Statement): - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType]) result: ir.ResultValue = info.result(types.Int) @statement(dialect=dialect) class MeasureAndReset(ir.Statement): - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType]) result: ir.ResultValue = info.result(types.Int) diff --git a/src/bloqade/squin/wire.py b/src/bloqade/squin/wire.py index 07e41728..ec46c957 100644 --- a/src/bloqade/squin/wire.py +++ b/src/bloqade/squin/wire.py @@ -6,7 +6,7 @@ dialect. """ -from kirin import ir, types, interp +from kirin import ir, types, interp, lowering from kirin.decl import info, statement from bloqade.types import QubitType @@ -34,7 +34,7 @@ class Wire: # no return value for `wrap` @statement(dialect=dialect) class Wrap(ir.Statement): - traits = frozenset({ir.FromPythonCall(), WireTerminator()}) + traits = frozenset({lowering.FromPythonCall(), WireTerminator()}) wire: ir.SSAValue = info.argument(WireType) qubit: ir.SSAValue = info.argument(QubitType) @@ -43,7 +43,7 @@ class Wrap(ir.Statement): # Unwrap(Qubit) -> Wire @statement(dialect=dialect) class Unwrap(ir.Statement): - traits = frozenset({ir.FromPythonCall(), ir.Pure()}) + traits = frozenset({lowering.FromPythonCall(), ir.Pure()}) qubit: ir.SSAValue = info.argument(QubitType) result: ir.ResultValue = info.result(WireType) @@ -52,7 +52,7 @@ class Unwrap(ir.Statement): # In this case though we just need to indicate that an operator is applied to list[wires] @statement(dialect=dialect) class Apply(ir.Statement): # apply(op, w1, w2, ...) - traits = frozenset({ir.FromPythonCall(), ir.Pure()}) + traits = frozenset({lowering.FromPythonCall(), ir.Pure()}) operator: ir.SSAValue = info.argument(OpType) inputs: tuple[ir.SSAValue, ...] = info.argument(WireType) @@ -73,14 +73,14 @@ def __init__(self, operator: ir.SSAValue, *args: ir.SSAValue): # the user in the wire dialect. @statement(dialect=dialect) class Measure(ir.Statement): - traits = frozenset({ir.FromPythonCall(), WireTerminator()}) + traits = frozenset({lowering.FromPythonCall(), WireTerminator()}) wire: ir.SSAValue = info.argument(WireType) result: ir.ResultValue = info.result(types.Int) @statement(dialect=dialect) class MeasureAndReset(ir.Statement): - traits = frozenset({ir.FromPythonCall(), WireTerminator()}) + traits = frozenset({lowering.FromPythonCall(), WireTerminator()}) wire: ir.SSAValue = info.argument(WireType) result: ir.ResultValue = info.result(types.Int) out_wire: ir.ResultValue = info.result(WireType) @@ -88,7 +88,7 @@ class MeasureAndReset(ir.Statement): @statement(dialect=dialect) class Reset(ir.Statement): - traits = frozenset({ir.FromPythonCall(), WireTerminator()}) + traits = frozenset({lowering.FromPythonCall(), WireTerminator()}) wire: ir.SSAValue = info.argument(WireType) diff --git a/src/bloqade/stim/dialects/aux/lowering.py b/src/bloqade/stim/dialects/aux/lowering.py index 42cff6ff..a901965d 100644 --- a/src/bloqade/stim/dialects/aux/lowering.py +++ b/src/bloqade/stim/dialects/aux/lowering.py @@ -1,17 +1,16 @@ import ast -from kirin.lowering import Result, FromPythonAST, LoweringState -from kirin.exceptions import DialectLoweringError +from kirin import lowering from . import stmts from ._dialect import dialect @dialect.register -class StimAuxLowering(FromPythonAST): +class StimAuxLowering(lowering.FromPythonAST): def _const_stmt( - self, state: LoweringState, value: int | float | str | bool + self, state: lowering.State, value: int | float | str | bool ) -> stmts.ConstInt | stmts.ConstFloat | stmts.ConstStr | stmts.ConstBool: if isinstance(value, bool): @@ -23,19 +22,19 @@ def _const_stmt( elif isinstance(value, str): return stmts.ConstStr(value=value) else: - raise DialectLoweringError(f"unsupported Stim constant type {type(value)}") + raise lowering.BuildError(f"unsupported Stim constant type {type(value)}") - def lower_Constant(self, state: LoweringState, node: ast.Constant) -> Result: + def lower_Constant(self, state: lowering.State, node: ast.Constant): stmt = self._const_stmt(state, node.value) - return Result(state.append_stmt(stmt)) + return state.current_frame.push(stmt) - def lower_Expr(self, state: LoweringState, node: ast.Expr) -> Result: - return state.visit(node.value) + def lower_Expr(self, state: lowering.State, node: ast.Expr): + return state.parent.visit(state, node.value) # just forward the visit - def lower_UnaryOp(self, state: LoweringState, node: ast.UnaryOp) -> Result: + def lower_UnaryOp(self, state: lowering.State, node: ast.UnaryOp): if isinstance(node.op, ast.USub): - value = state.visit(node.operand).expect_one() + value = state.lower(node.operand).expect_one() stmt = stmts.Neg(operand=value) - return Result(state.append_stmt(stmt)) + return state.current_frame.push(stmt) else: - raise DialectLoweringError(f"unsupported Stim unaryop {node.op}") + raise lowering.BuildError(f"unsupported Stim unaryop {node.op}") diff --git a/src/bloqade/stim/dialects/aux/stmts/annotate.py b/src/bloqade/stim/dialects/aux/stmts/annotate.py index 9351832b..eacd881f 100644 --- a/src/bloqade/stim/dialects/aux/stmts/annotate.py +++ b/src/bloqade/stim/dialects/aux/stmts/annotate.py @@ -1,4 +1,4 @@ -from kirin import ir, types +from kirin import ir, types, lowering from kirin.decl import info, statement from ..types import RecordType, PauliStringType @@ -10,7 +10,7 @@ @statement(dialect=dialect) class GetRecord(ir.Statement): name = "get_rec" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) id: ir.SSAValue = info.argument(type=types.Int) result: ir.ResultValue = info.result(type=RecordType) @@ -18,7 +18,7 @@ class GetRecord(ir.Statement): @statement(dialect=dialect) class Detector(ir.Statement): name = "detector" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) coord: tuple[ir.SSAValue, ...] = info.argument(PyNum) targets: tuple[ir.SSAValue, ...] = info.argument(RecordType) @@ -26,7 +26,7 @@ class Detector(ir.Statement): @statement(dialect=dialect) class ObservableInclude(ir.Statement): name = "obs.include" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) idx: ir.SSAValue = info.argument(type=types.Int) targets: tuple[ir.SSAValue, ...] = info.argument(RecordType) @@ -34,13 +34,13 @@ class ObservableInclude(ir.Statement): @statement(dialect=dialect) class Tick(ir.Statement): name = "tick" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) @statement(dialect=dialect) class NewPauliString(ir.Statement): name = "new_pauli_string" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) string: tuple[ir.SSAValue, ...] = info.argument(types.String) flipped: tuple[ir.SSAValue, ...] = info.argument(types.Bool) targets: tuple[ir.SSAValue, ...] = info.argument(types.Int) diff --git a/src/bloqade/stim/dialects/aux/stmts/const.py b/src/bloqade/stim/dialects/aux/stmts/const.py index dc21ed97..42481fd1 100644 --- a/src/bloqade/stim/dialects/aux/stmts/const.py +++ b/src/bloqade/stim/dialects/aux/stmts/const.py @@ -1,4 +1,4 @@ -from kirin import ir, types +from kirin import ir, types, lowering from kirin.decl import info, statement from kirin.print import Printer @@ -10,7 +10,7 @@ class ConstInt(ir.Statement): """IR Statement representing a constant integer value.""" name = "constant.int" - traits = frozenset({ir.Pure(), ir.ConstantLike(), ir.FromPythonCall()}) + traits = frozenset({ir.Pure(), ir.ConstantLike(), lowering.FromPythonCall()}) value: int = info.attribute(types.Int) """value (int): The constant integer value.""" result: ir.ResultValue = info.result(types.Int) @@ -30,7 +30,7 @@ class ConstFloat(ir.Statement): """IR Statement representing a constant float value.""" name = "constant.float" - traits = frozenset({ir.Pure(), ir.ConstantLike(), ir.FromPythonCall()}) + traits = frozenset({ir.Pure(), ir.ConstantLike(), lowering.FromPythonCall()}) value: float = info.attribute(types.Float) """value (float): The constant float value.""" result: ir.ResultValue = info.result(types.Float) @@ -50,7 +50,7 @@ class ConstBool(ir.Statement): """IR Statement representing a constant float value.""" name = "constant.bool" - traits = frozenset({ir.Pure(), ir.ConstantLike(), ir.FromPythonCall()}) + traits = frozenset({ir.Pure(), ir.ConstantLike(), lowering.FromPythonCall()}) value: bool = info.attribute(types.Bool) """value (float): The constant float value.""" result: ir.ResultValue = info.result(types.Bool) @@ -70,7 +70,7 @@ class ConstStr(ir.Statement): """IR Statement representing a constant str value.""" name = "constant.str" - traits = frozenset({ir.Pure(), ir.ConstantLike(), ir.FromPythonCall()}) + traits = frozenset({ir.Pure(), ir.ConstantLike(), lowering.FromPythonCall()}) value: str = info.attribute(types.String) """value (str): The constant str value.""" result: ir.ResultValue = info.result(types.String) @@ -90,6 +90,6 @@ class Neg(ir.Statement): """IR Statement representing a negation operation.""" name = "neg" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) operand: ir.SSAValue = info.argument(types.Int) result: ir.ResultValue = info.result(types.Int) diff --git a/src/bloqade/stim/dialects/collapse/stmts/measure.py b/src/bloqade/stim/dialects/collapse/stmts/measure.py index 519f7310..bb22be79 100644 --- a/src/bloqade/stim/dialects/collapse/stmts/measure.py +++ b/src/bloqade/stim/dialects/collapse/stmts/measure.py @@ -1,4 +1,4 @@ -from kirin import ir, types +from kirin import ir, types, lowering from kirin.decl import info, statement from .._dialect import dialect @@ -7,7 +7,7 @@ @statement class Measurement(ir.Statement): name = "measurement" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) p: ir.SSAValue = info.argument(types.Float) """probability of noise introduced by measurement. For example 0.01 means 1% the measurement will be flipped""" targets: tuple[ir.SSAValue, ...] = info.argument(types.Int) diff --git a/src/bloqade/stim/dialects/collapse/stmts/pp_measure.py b/src/bloqade/stim/dialects/collapse/stmts/pp_measure.py index d9bc2fa3..af16679b 100644 --- a/src/bloqade/stim/dialects/collapse/stmts/pp_measure.py +++ b/src/bloqade/stim/dialects/collapse/stmts/pp_measure.py @@ -1,4 +1,4 @@ -from kirin import ir, types +from kirin import ir, types, lowering from kirin.decl import info, statement from .._dialect import dialect @@ -8,7 +8,7 @@ @statement(dialect=dialect) class PPMeasurement(ir.Statement): name = "MPP" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) p: ir.SSAValue = info.argument(types.Float) """probability of noise introduced by measurement. For example 0.01 means 1% the measurement will be flipped""" targets: tuple[ir.SSAValue, ...] = info.argument(PauliStringType) diff --git a/src/bloqade/stim/dialects/collapse/stmts/reset.py b/src/bloqade/stim/dialects/collapse/stmts/reset.py index 16be409c..aa182447 100644 --- a/src/bloqade/stim/dialects/collapse/stmts/reset.py +++ b/src/bloqade/stim/dialects/collapse/stmts/reset.py @@ -1,4 +1,4 @@ -from kirin import ir, types +from kirin import ir, types, lowering from kirin.decl import info, statement from .._dialect import dialect @@ -7,7 +7,7 @@ @statement class Reset(ir.Statement): name = "reset" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) targets: tuple[ir.SSAValue, ...] = info.argument(types.Int) diff --git a/src/bloqade/stim/dialects/gate/stmts/base.py b/src/bloqade/stim/dialects/gate/stmts/base.py index cae7beba..f60ba9da 100644 --- a/src/bloqade/stim/dialects/gate/stmts/base.py +++ b/src/bloqade/stim/dialects/gate/stmts/base.py @@ -1,4 +1,4 @@ -from kirin import ir, types +from kirin import ir, types, lowering from kirin.decl import info, statement from bloqade.stim.dialects.aux import RecordType @@ -7,7 +7,7 @@ @statement class Gate(ir.Statement): name = "stim_gate" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) targets: tuple[ir.SSAValue, ...] = info.argument(types.Int) dagger: bool = info.attribute(default=False) diff --git a/src/bloqade/stim/dialects/gate/stmts/pp.py b/src/bloqade/stim/dialects/gate/stmts/pp.py index 6653dc4d..d9ab4a90 100644 --- a/src/bloqade/stim/dialects/gate/stmts/pp.py +++ b/src/bloqade/stim/dialects/gate/stmts/pp.py @@ -1,4 +1,4 @@ -from kirin import ir, types +from kirin import ir, types, lowering from kirin.decl import info, statement from .._dialect import dialect @@ -10,6 +10,6 @@ @statement(dialect=dialect) class SPP(ir.Statement): name = "SPP" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) dagger: bool = info.attribute(types.Bool, default=False) targets: tuple[ir.SSAValue, ...] = info.argument(PauliStringType) diff --git a/src/bloqade/stim/dialects/noise/stmts.py b/src/bloqade/stim/dialects/noise/stmts.py index f3697cb6..78231c6f 100644 --- a/src/bloqade/stim/dialects/noise/stmts.py +++ b/src/bloqade/stim/dialects/noise/stmts.py @@ -1,4 +1,4 @@ -from kirin import ir, types +from kirin import ir, types, lowering from kirin.decl import info, statement from ._dialect import dialect @@ -7,7 +7,7 @@ @statement(dialect=dialect) class Depolarize1(ir.Statement): name = "Depolarize1" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) p: ir.SSAValue = info.argument(types.Float) targets: tuple[ir.SSAValue, ...] = info.argument(types.Int) @@ -15,7 +15,7 @@ class Depolarize1(ir.Statement): @statement(dialect=dialect) class Depolarize2(ir.Statement): name = "Depolarize2" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) p: ir.SSAValue = info.argument(types.Float) targets: tuple[ir.SSAValue, ...] = info.argument(types.Int) @@ -23,7 +23,7 @@ class Depolarize2(ir.Statement): @statement(dialect=dialect) class PauliChannel1(ir.Statement): name = "PauliChannel1" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) px: ir.SSAValue = info.argument(types.Float) py: ir.SSAValue = info.argument(types.Float) pz: ir.SSAValue = info.argument(types.Float) @@ -34,7 +34,7 @@ class PauliChannel1(ir.Statement): class PauliChannel2(ir.Statement): name = "PauliChannel2" # TODO custom lowering to make sugar for this - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) pix: ir.SSAValue = info.argument(types.Float) piy: ir.SSAValue = info.argument(types.Float) piz: ir.SSAValue = info.argument(types.Float) @@ -56,7 +56,7 @@ class PauliChannel2(ir.Statement): @statement(dialect=dialect) class XError(ir.Statement): name = "X_ERROR" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) p: ir.SSAValue = info.argument(types.Float) targets: tuple[ir.SSAValue, ...] = info.argument(types.Int) @@ -64,7 +64,7 @@ class XError(ir.Statement): @statement(dialect=dialect) class YError(ir.Statement): name = "Y_ERROR" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) p: ir.SSAValue = info.argument(types.Float) targets: tuple[ir.SSAValue, ...] = info.argument(types.Int) @@ -72,6 +72,6 @@ class YError(ir.Statement): @statement(dialect=dialect) class ZError(ir.Statement): name = "Z_ERROR" - traits = frozenset({ir.FromPythonCall()}) + traits = frozenset({lowering.FromPythonCall()}) p: ir.SSAValue = info.argument(types.Float) targets: tuple[ir.SSAValue, ...] = info.argument(types.Int) diff --git a/test/pyqrack/test_target.py b/test/pyqrack/test_target.py index 5be96998..e4774cfa 100644 --- a/test/pyqrack/test_target.py +++ b/test/pyqrack/test_target.py @@ -1,7 +1,6 @@ import math from bloqade import qasm2 -from pyqrack import QrackSimulator from bloqade.pyqrack import PyQrack, reg @@ -22,7 +21,6 @@ def ghz(): q = target.run(ghz) assert isinstance(q, reg.PyQrackReg) - assert isinstance(q.sim_reg, QrackSimulator) out = q.sim_reg.out_ket() @@ -40,7 +38,3 @@ def ghz(): assert math.isclose(out[0].real, val, abs_tol=abs_tol) assert math.isclose(out[-1].real, val, abs_tol=abs_tol) assert all(math.isclose(ele.real, 0.0, abs_tol=abs_tol) for ele in out[1:-1]) - - -if __name__ == "__main__": - test_target() diff --git a/test/qasm2/passes/test_heuristic_noise.py b/test/qasm2/passes/test_heuristic_noise.py index 553a2058..8265eedc 100644 --- a/test/qasm2/passes/test_heuristic_noise.py +++ b/test/qasm2/passes/test_heuristic_noise.py @@ -12,6 +12,8 @@ class NoiseTestModel(native.MoveNoiseModelABC): + + @classmethod def parallel_cz_errors(cls, ctrls, qargs, rest): return {(0.01, 0.01, 0.01, 0.01): ctrls + qargs + rest} @@ -290,7 +292,9 @@ def test_method(): q0 := core.QRegGet(reg0.result, zero.result), reg1 := core.QRegNew(n_qubits.result), q1 := core.QRegGet(reg1.result, zero.result), - reg_list := ilist.New(values=[reg0.result, reg1.result]), + reg_list := ilist.New( + values=[reg0.result, reg1.result], elem_type=reg0.result.type + ), theta := constant.Constant(0.1), phi := constant.Constant(0.2), lam := constant.Constant(0.3), diff --git a/test/qasm2/test_inline.py b/test/qasm2/test_inline.py index a53aeb65..8dda1100 100644 --- a/test/qasm2/test_inline.py +++ b/test/qasm2/test_inline.py @@ -64,27 +64,3 @@ def qasm2_inline_code(): qasm2.inline(lines) qasm2_inline_code.print() - - -if __name__ == "__main__": - # test_inline() - - lines = textwrap.dedent( - """ - KIRIN {qasm2.glob, qasm2.uop}; - include "qelib1.inc"; - - qreg q1[2]; - qreg q2[3]; - - glob.U(1.0, 2.0, 3.0) {q1, q2} - """ - ) - - print(lines) - - @qasm2.extended.add(inline) - def qasm2_inline_code(): - qasm2.inline(lines) - - qasm2_inline_code.print() diff --git a/test/qasm2/test_lowering.py b/test/qasm2/test_lowering.py new file mode 100644 index 00000000..21e716b9 --- /dev/null +++ b/test/qasm2/test_lowering.py @@ -0,0 +1,22 @@ +import textwrap + +from bloqade import qasm2 +from bloqade.qasm2.parse.lowering import QASM2 + +lines = textwrap.dedent( + """ +OPENQASM 2.0; + +qreg q[2]; +creg c[2]; + +h q[0]; +CX q[0], q[1]; +barrier q[0], q[1]; +CX q[0], q[1]; +rx(pi/2) q[0]; +""" +) +ast = qasm2.parse.loads(lines) +code = QASM2(qasm2.main).run(ast) +code.print()