Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/bloqade/noise/native/stmts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from kirin import ir, types
from kirin import ir, types, lowering
from kirin.decl import info, statement
from kirin.dialects import ilist

Expand All @@ -10,7 +10,7 @@
@statement(dialect=dialect)
class PauliChannel(ir.Statement):

traits = frozenset({ir.FromPythonCall()})
traits = frozenset({lowering.FromPythonCall()})

px: float = info.attribute(types.Float)
py: float = info.attribute(types.Float)
Expand All @@ -24,7 +24,7 @@ class PauliChannel(ir.Statement):
@statement(dialect=dialect)
class CZPauliChannel(ir.Statement):

traits = frozenset({ir.FromPythonCall()})
traits = frozenset({lowering.FromPythonCall()})

paired: bool = info.attribute(types.Bool)
px_ctrl: float = info.attribute(types.Float)
Expand All @@ -40,7 +40,7 @@ class CZPauliChannel(ir.Statement):
@statement(dialect=dialect)
class AtomLossChannel(ir.Statement):

traits = frozenset({ir.FromPythonCall()})
traits = frozenset({lowering.FromPythonCall()})

prob: float = info.attribute(types.Float)
qargs: ir.SSAValue = info.argument(ilist.IListType[QubitType])
16 changes: 8 additions & 8 deletions src/bloqade/qasm2/dialects/core/stmts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from kirin import ir, types
from kirin import ir, types, lowering
from kirin.decl import info, statement

from bloqade.qasm2.types import BitType, CRegType, QRegType, QubitType
Expand All @@ -11,7 +11,7 @@ class QRegNew(ir.Statement):
"""Create a new quantum register."""

name = "qreg.new"
traits = frozenset({ir.FromPythonCall()})
traits = frozenset({lowering.FromPythonCall()})
n_qubits: ir.SSAValue = info.argument(types.Int)
"""n_qubits: The number of qubits in the register."""
result: ir.ResultValue = info.result(QRegType)
Expand All @@ -23,7 +23,7 @@ class CRegNew(ir.Statement):
"""Create a new classical register."""

name = "creg.new"
traits = frozenset({ir.FromPythonCall()})
traits = frozenset({lowering.FromPythonCall()})
n_bits: ir.SSAValue = info.argument(types.Int)
"""n_bits (Int): The number of bits in the register."""
result: ir.ResultValue = info.result(CRegType)
Expand All @@ -35,7 +35,7 @@ class Reset(ir.Statement):
"""Reset a qubit to the |0> state."""

name = "reset"
traits = frozenset({ir.FromPythonCall()})
traits = frozenset({lowering.FromPythonCall()})
qarg: ir.SSAValue = info.argument(QubitType)
"""qarg (Qubit): The qubit to reset."""

Expand All @@ -45,7 +45,7 @@ class Measure(ir.Statement):
"""Measure a qubit and store the result in a bit."""

name = "measure"
traits = frozenset({ir.FromPythonCall()})
traits = frozenset({lowering.FromPythonCall()})
qarg: ir.SSAValue = info.argument(QubitType)
"""qarg (Qubit): The qubit to measure."""
carg: ir.SSAValue = info.argument(BitType)
Expand All @@ -57,7 +57,7 @@ class CRegEq(ir.Statement):
"""Check if two classical registers are equal."""

name = "eq"
traits = frozenset({ir.Pure(), ir.FromPythonCall()})
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
lhs: ir.SSAValue = info.argument(types.Int | CRegType | BitType)
"""lhs (CReg): The first register."""
rhs: ir.SSAValue = info.argument(types.Int | CRegType | BitType)
Expand All @@ -71,7 +71,7 @@ class QRegGet(ir.Statement):
"""Get a qubit from a quantum register."""

name = "qreg.get"
traits = frozenset({ir.FromPythonCall(), ir.Pure()})
traits = frozenset({lowering.FromPythonCall(), ir.Pure()})
reg: ir.SSAValue = info.argument(QRegType)
"""reg (QReg): The quantum register."""
idx: ir.SSAValue = info.argument(types.Int)
Expand All @@ -85,7 +85,7 @@ class CRegGet(ir.Statement):
"""Get a bit from a classical register."""

name = "creg.get"
traits = frozenset({ir.FromPythonCall(), ir.Pure()})
traits = frozenset({lowering.FromPythonCall(), ir.Pure()})
reg: ir.SSAValue = info.argument(CRegType)
"""reg (CReg): The classical register."""
idx: ir.SSAValue = info.argument(types.Int)
Expand Down
55 changes: 26 additions & 29 deletions src/bloqade/qasm2/dialects/expr/lowering.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,29 @@
import ast

from kirin import ir, types
from kirin.lowering import Result, FromPythonAST, LoweringState
from kirin.exceptions import DialectLoweringError
from kirin import ir, types, lowering

from . import stmts
from ._dialect import dialect


@dialect.register
class QASMUopLowering(FromPythonAST):
class QASMUopLowering(lowering.FromPythonAST):

def lower_Name(self, state: LoweringState, node: ast.Name) -> Result:
def lower_Name(self, state: lowering.State, node: ast.Name):
name = node.id
if isinstance(node.ctx, ast.Load):
value = state.current_frame.get(name)
if value is None:
raise DialectLoweringError(f"{name} is not defined")
return Result(value)
raise lowering.BuildError(f"{name} is not defined")
return value
elif isinstance(node.ctx, ast.Store):
raise DialectLoweringError("unhandled store operation")
raise lowering.BuildError("unhandled store operation")
else: # Del
raise DialectLoweringError("unhandled del operation")
raise lowering.BuildError("unhandled del operation")

def lower_Assign(self, state: LoweringState, node: ast.Assign) -> Result:
def lower_Assign(self, state: lowering.State, node: ast.Assign):
# NOTE: QASM only expects one value on right hand side
rhs = state.visit(node.value).expect_one()
rhs = state.lower(node.value).expect_one()
current_frame = state.current_frame
match node:
case ast.Assign(targets=[ast.Name(lhs_name, ast.Store())], value=_):
Expand All @@ -35,27 +33,26 @@ def lower_Assign(self, state: LoweringState, node: ast.Assign) -> Result:
target_syntax = ", ".join(
ast.unparse(target) for target in node.targets
)
raise DialectLoweringError(f"unsupported target syntax {target_syntax}")
return Result() # python assign does not have value
raise lowering.BuildError(f"unsupported target syntax {target_syntax}")

def lower_Expr(self, state: LoweringState, node: ast.Expr) -> Result:
return state.visit(node.value)
def lower_Expr(self, state: lowering.State, node: ast.Expr):
return state.lower(node.value)

def lower_Constant(self, state: LoweringState, node: ast.Constant) -> Result:
def lower_Constant(self, state: lowering.State, node: ast.Constant):
if isinstance(node.value, int):
stmt = stmts.ConstInt(value=node.value)
return Result(state.append_stmt(stmt))
return state.current_frame.push(stmt)
elif isinstance(node.value, float):
stmt = stmts.ConstFloat(value=node.value)
return Result(state.append_stmt(stmt))
return state.current_frame.push(stmt)
else:
raise DialectLoweringError(
raise lowering.BuildError(
f"unsupported QASM 2.0 constant type {type(node.value)}"
)

def lower_BinOp(self, state: LoweringState, node: ast.BinOp) -> Result:
lhs = state.visit(node.left).expect_one()
rhs = state.visit(node.right).expect_one()
def lower_BinOp(self, state: lowering.State, node: ast.BinOp):
lhs = state.lower(node.left).expect_one()
rhs = state.lower(node.right).expect_one()
if isinstance(node.op, ast.Add):
stmt = stmts.Add(lhs, rhs)
elif isinstance(node.op, ast.Sub):
Expand All @@ -67,9 +64,9 @@ def lower_BinOp(self, state: LoweringState, node: ast.BinOp) -> Result:
elif isinstance(node.op, ast.Pow):
stmt = stmts.Pow(lhs, rhs)
else:
raise DialectLoweringError(f"unsupported QASM 2.0 binop {node.op}")
raise lowering.BuildError(f"unsupported QASM 2.0 binop {node.op}")
stmt.result.type = self.__promote_binop_type(lhs, rhs)
return Result(state.append_stmt(stmt))
return state.current_frame.push(stmt)

def __promote_binop_type(
self, lhs: ir.SSAValue, rhs: ir.SSAValue
Expand All @@ -78,12 +75,12 @@ def __promote_binop_type(
return types.Float
return types.Int

def lower_UnaryOp(self, state: LoweringState, node: ast.UnaryOp) -> Result:
def lower_UnaryOp(self, state: lowering.State, node: ast.UnaryOp):
if isinstance(node.op, ast.USub):
value = state.visit(node.operand).expect_one()
value = state.lower(node.operand).expect_one()
stmt = stmts.Neg(value)
return Result(state.append_stmt(stmt))
return state.current_frame.push(stmt)
elif isinstance(node.op, ast.UAdd):
return state.visit(node.operand)
return state.lower(node.operand)
else:
raise DialectLoweringError(f"unsupported QASM 2.0 unaryop {node.op}")
raise lowering.BuildError(f"unsupported QASM 2.0 unaryop {node.op}")
32 changes: 16 additions & 16 deletions src/bloqade/qasm2/dialects/expr/stmts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from kirin import ir, types
from kirin import ir, types, lowering
from kirin.decl import info, statement
from kirin.print.printer import Printer
from kirin.dialects.func.attrs import Signature
Expand Down Expand Up @@ -50,7 +50,7 @@ class ConstInt(ir.Statement):
"""IR Statement representing a constant integer value."""

name = "constant.int"
traits = frozenset({ir.Pure(), ir.ConstantLike(), ir.FromPythonCall()})
traits = frozenset({ir.Pure(), ir.ConstantLike(), lowering.FromPythonCall()})
value: int = info.attribute(types.Int)
"""value (int): The constant integer value."""
result: ir.ResultValue = info.result(types.Int)
Expand All @@ -70,7 +70,7 @@ class ConstFloat(ir.Statement):
"""IR Statement representing a constant float value."""

name = "constant.float"
traits = frozenset({ir.Pure(), ir.ConstantLike(), ir.FromPythonCall()})
traits = frozenset({ir.Pure(), ir.ConstantLike(), lowering.FromPythonCall()})
value: float = info.attribute(types.Float)
"""value (float): The constant float value."""
result: ir.ResultValue = info.result(types.Float)
Expand All @@ -91,7 +91,7 @@ class ConstPI(ir.Statement):

# this is marked as constant but not pure.
name = "constant.pi"
traits = frozenset({ir.ConstantLike(), ir.FromPythonCall()})
traits = frozenset({ir.ConstantLike(), lowering.FromPythonCall()})
result: ir.ResultValue = info.result(types.Float)
"""result (ConstPI): The result value."""

Expand All @@ -113,7 +113,7 @@ class Neg(ir.Statement):
"""Negate a number."""

name = "neg"
traits = frozenset({ir.FromPythonCall()})
traits = frozenset({lowering.FromPythonCall()})
value: ir.SSAValue = info.argument(PyNum)
"""value (Union[int, float]): The number to negate."""
result: ir.ResultValue = info.result(PyNum)
Expand All @@ -125,7 +125,7 @@ class Sin(ir.Statement):
"""Take the sine of a number."""

name = "sin"
traits = frozenset({ir.FromPythonCall()})
traits = frozenset({lowering.FromPythonCall()})
value: ir.SSAValue = info.argument(PyNum)
"""value (Union[int, float]): The number to take the sine of."""
result: ir.ResultValue = info.result(PyNum)
Expand All @@ -137,7 +137,7 @@ class Cos(ir.Statement):
"""Take the cosine of a number."""

name = "cos"
traits = frozenset({ir.FromPythonCall()})
traits = frozenset({lowering.FromPythonCall()})
value: ir.SSAValue = info.argument(PyNum)
"""value (Union[int, float]): The number to take the cosine of."""
result: ir.ResultValue = info.result(PyNum)
Expand All @@ -149,7 +149,7 @@ class Tan(ir.Statement):
"""Take the tangent of a number."""

name = "tan"
traits = frozenset({ir.FromPythonCall()})
traits = frozenset({lowering.FromPythonCall()})
value: ir.SSAValue = info.argument(PyNum)
"""value (Union[int, float]): The number to take the tangent of."""
result: ir.ResultValue = info.result(PyNum)
Expand All @@ -161,7 +161,7 @@ class Exp(ir.Statement):
"""Take the exponential of a number."""

name = "exp"
traits = frozenset({ir.FromPythonCall()})
traits = frozenset({lowering.FromPythonCall()})
value: ir.SSAValue = info.argument(PyNum)
"""value (Union[int, float]): The number to take the exponential of."""
result: ir.ResultValue = info.result(PyNum)
Expand All @@ -173,7 +173,7 @@ class Log(ir.Statement):
"""Take the natural log of a number."""

name = "ln"
traits = frozenset({ir.FromPythonCall()})
traits = frozenset({lowering.FromPythonCall()})
value: ir.SSAValue = info.argument(PyNum)
"""value (Union[int, float]): The number to take the natural log of."""
result: ir.ResultValue = info.result(PyNum)
Expand All @@ -185,7 +185,7 @@ class Sqrt(ir.Statement):
"""Take the square root of a number."""

name = "sqrt"
traits = frozenset({ir.FromPythonCall()})
traits = frozenset({lowering.FromPythonCall()})
value: ir.SSAValue = info.argument(PyNum)
"""value (Union[int, float]): The number to take the square root of."""
result: ir.ResultValue = info.result(PyNum)
Expand All @@ -197,7 +197,7 @@ class Add(ir.Statement):
"""Add two numbers."""

name = "add"
traits = frozenset({ir.FromPythonCall()})
traits = frozenset({lowering.FromPythonCall()})
lhs: ir.SSAValue = info.argument(PyNum)
"""lhs (Union[int, float]): The left-hand side of the addition."""
rhs: ir.SSAValue = info.argument(PyNum)
Expand All @@ -211,7 +211,7 @@ class Sub(ir.Statement):
"""Subtract two numbers."""

name = "sub"
traits = frozenset({ir.FromPythonCall()})
traits = frozenset({lowering.FromPythonCall()})
lhs: ir.SSAValue = info.argument(PyNum)
"""lhs (Union[int, float]): The left-hand side of the subtraction."""
rhs: ir.SSAValue = info.argument(PyNum)
Expand All @@ -225,7 +225,7 @@ class Mul(ir.Statement):
"""Multiply two numbers."""

name = "mul"
traits = frozenset({ir.FromPythonCall()})
traits = frozenset({lowering.FromPythonCall()})
lhs: ir.SSAValue = info.argument(PyNum)
"""lhs (Union[int, float]): The left-hand side of the multiplication."""
rhs: ir.SSAValue = info.argument(PyNum)
Expand All @@ -239,7 +239,7 @@ class Pow(ir.Statement):
"""Take the power of a number."""

name = "pow"
traits = frozenset({ir.FromPythonCall()})
traits = frozenset({lowering.FromPythonCall()})
lhs: ir.SSAValue = info.argument(PyNum)
"""lhs (Union[int, float]): The base."""
rhs: ir.SSAValue = info.argument(PyNum)
Expand All @@ -253,7 +253,7 @@ class Div(ir.Statement):
"""Divide two numbers."""

name = "div"
traits = frozenset({ir.FromPythonCall()})
traits = frozenset({lowering.FromPythonCall()})
lhs: ir.SSAValue = info.argument(PyNum)
"""lhs (Union[int, float]): The numerator."""
rhs: ir.SSAValue = info.argument(PyNum)
Expand Down
4 changes: 2 additions & 2 deletions src/bloqade/qasm2/dialects/glob.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from kirin import ir, types, interp
from kirin import ir, types, interp, lowering
from kirin.decl import info, statement
from kirin.dialects import ilist

Expand All @@ -13,7 +13,7 @@
@statement(dialect=dialect)
class UGate(ir.Statement):
name = "ugate"
traits = frozenset({ir.FromPythonCall()})
traits = frozenset({lowering.FromPythonCall()})
registers: ir.SSAValue = info.argument(ilist.IListType[QRegType])
theta: ir.SSAValue = info.argument(types.Float)
phi: ir.SSAValue = info.argument(types.Float)
Expand Down
Loading
Loading