diff --git a/src/bloqade/squin/groups.py b/src/bloqade/squin/groups.py index dda7a5b1..9b94b63d 100644 --- a/src/bloqade/squin/groups.py +++ b/src/bloqade/squin/groups.py @@ -4,6 +4,7 @@ from kirin.rewrite.walk import Walk from . import op, wire, qubit +from .op.rewrite import PyMultToSquinMult from .rewrite.measure_desugar import MeasureDesugarRule @@ -13,16 +14,21 @@ def kernel(self): typeinfer_pass = passes.TypeInfer(self) ilist_desugar_pass = ilist.IListDesugar(self) measure_desugar_pass = Walk(MeasureDesugarRule()) + py_mult_to_mult_pass = PyMultToSquinMult(self) def run_pass(method: ir.Method, *, fold=True, typeinfer=True): method.verify() if fold: fold_pass.fixpoint(method) + py_mult_to_mult_pass(method) + if typeinfer: typeinfer_pass(method) measure_desugar_pass.rewrite(method.code) + ilist_desugar_pass(method) + if typeinfer: typeinfer_pass(method) # fix types after desugaring method.verify_type() @@ -32,7 +38,9 @@ def run_pass(method: ir.Method, *, fold=True, typeinfer=True): @ir.dialect_group(structural_no_opt.union([op, wire])) def wired(self): + py_mult_to_mult_pass = PyMultToSquinMult(self) + def run_pass(method): - pass + py_mult_to_mult_pass(method) return run_pass diff --git a/src/bloqade/squin/op/__init__.py b/src/bloqade/squin/op/__init__.py index 77b07c64..42cd426a 100644 --- a/src/bloqade/squin/op/__init__.py +++ b/src/bloqade/squin/op/__init__.py @@ -2,7 +2,7 @@ from kirin.prelude import structural_no_opt as _structural_no_opt from kirin.lowering import wraps as _wraps -from . import stmts as stmts, types as types +from . import stmts as stmts, types as types, rewrite as rewrite from .traits import Unitary as Unitary, MaybeUnitary as MaybeUnitary from ._dialect import dialect as dialect diff --git a/src/bloqade/squin/op/complex.py b/src/bloqade/squin/op/complex.py deleted file mode 100644 index 10e0d630..00000000 --- a/src/bloqade/squin/op/complex.py +++ /dev/null @@ -1,6 +0,0 @@ -# Stopgap Measure, squin dialect needs Complex type but -# this is only available in Kirin 0.15.x - -from kirin.ir.attrs.types import PyClass - -Complex = PyClass(complex) diff --git a/src/bloqade/squin/op/number.py b/src/bloqade/squin/op/number.py new file mode 100644 index 00000000..3cdc12b1 --- /dev/null +++ b/src/bloqade/squin/op/number.py @@ -0,0 +1,5 @@ +import numbers + +from kirin.ir.attrs.types import PyClass + +NumberType = PyClass(numbers.Number) diff --git a/src/bloqade/squin/op/rewrite.py b/src/bloqade/squin/op/rewrite.py new file mode 100644 index 00000000..64000343 --- /dev/null +++ b/src/bloqade/squin/op/rewrite.py @@ -0,0 +1,46 @@ +"""Rewrite py.binop.mult to Mult stmt""" + +from kirin import ir +from kirin.passes import Pass +from kirin.rewrite import Walk +from kirin.dialects import py +from kirin.rewrite.abc import RewriteRule, RewriteResult + +from .stmts import Mult, Scale +from .types import OpType + + +class _PyMultToSquinMult(RewriteRule): + + def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: + if not isinstance(node, py.Mult): + return RewriteResult() + + lhs_is_op = node.lhs.type.is_subseteq(OpType) + rhs_is_op = node.rhs.type.is_subseteq(OpType) + + if not lhs_is_op and not rhs_is_op: + return RewriteResult() + + if lhs_is_op and rhs_is_op: + mult = Mult(node.lhs, node.rhs) + node.replace_by(mult) + return RewriteResult(has_done_something=True) + + if lhs_is_op: + scale = Scale(node.lhs, node.rhs) + node.replace_by(scale) + return RewriteResult(has_done_something=True) + + if rhs_is_op: + scale = Scale(node.rhs, node.lhs) + node.replace_by(scale) + return RewriteResult(has_done_something=True) + + return RewriteResult() + + +class PyMultToSquinMult(Pass): + + def unsafe_run(self, mt: ir.Method) -> RewriteResult: + return Walk(_PyMultToSquinMult()).rewrite(mt.code) diff --git a/src/bloqade/squin/op/stmts.py b/src/bloqade/squin/op/stmts.py index 9f948510..94994c5e 100644 --- a/src/bloqade/squin/op/stmts.py +++ b/src/bloqade/squin/op/stmts.py @@ -2,8 +2,8 @@ from kirin.decl import info, statement from .types import OpType +from .number import NumberType from .traits import Unitary, HasSites, FixedSites, MaybeUnitary -from .complex import Complex from ._dialect import dialect @@ -54,7 +54,7 @@ class Scale(CompositeOp): traits = frozenset({ir.Pure(), lowering.FromPythonCall(), MaybeUnitary()}) is_unitary: bool = info.attribute(default=False) op: ir.SSAValue = info.argument(OpType) - factor: ir.SSAValue = info.argument(Complex) + factor: ir.SSAValue = info.argument(NumberType) result: ir.ResultValue = info.result(OpType) diff --git a/src/bloqade/squin/op/types.py b/src/bloqade/squin/op/types.py index 0c4564e6..4f4388c4 100644 --- a/src/bloqade/squin/op/types.py +++ b/src/bloqade/squin/op/types.py @@ -1,3 +1,5 @@ +from typing import overload + from kirin import types @@ -6,5 +8,17 @@ class Op: def __matmul__(self, other: "Op") -> "Op": raise NotImplementedError("@ can only be used within a squin kernel program") + @overload + def __mul__(self, other: "Op") -> "Op": ... + + @overload + def __mul__(self, other: complex) -> "Op": ... + + def __mul__(self, other) -> "Op": + raise NotImplementedError("@ can only be used within a squin kernel program") + + def __rmul__(self, other: complex) -> "Op": + raise NotImplementedError("@ can only be used within a squin kernel program") + OpType = types.PyClass(Op) diff --git a/test/squin/test_mult_rewrite.py b/test/squin/test_mult_rewrite.py new file mode 100644 index 00000000..8a0bbaea --- /dev/null +++ b/test/squin/test_mult_rewrite.py @@ -0,0 +1,166 @@ +from kirin.types import PyClass +from kirin.dialects import py, func + +from bloqade import squin + + +def test_mult_rewrite(): + + @squin.kernel + def helper(x: squin.op.types.Op, y: squin.op.types.Op): + return x * y + + @squin.kernel + def main(): + q = squin.qubit.new(1) + x = squin.op.x() + y = squin.op.y() + z = x * y + t = helper(x, z) + + squin.qubit.apply(t, q) + return q + + helper.print() + main.print() + + assert isinstance(helper.code, func.Function) + + helper_stmts = list(helper.code.body.stmts()) + assert len(helper_stmts) == 2 # [Mult(), Return()] + assert isinstance(helper_stmts[0], squin.op.stmts.Mult) + + assert isinstance(main.code, func.Function) + + count_mults_in_main = 0 + for stmt in main.code.body.stmts(): + assert not isinstance(stmt, py.Mult) + + count_mults_in_main += isinstance(stmt, squin.op.stmts.Mult) + + assert count_mults_in_main == 1 + + +def test_scale_rewrite(): + + @squin.kernel + def simple_rmul(): + x = squin.op.x() + y = 2 * x + return y + + simple_rmul.print() + + assert isinstance(simple_rmul.code, func.Function) + + simple_rmul_stmts = list(simple_rmul.code.body.stmts()) + assert any( + map(lambda stmt: isinstance(stmt, squin.op.stmts.Scale), simple_rmul_stmts) + ) + assert not any( + map(lambda stmt: isinstance(stmt, squin.op.stmts.Mult), simple_rmul_stmts) + ) + assert not any(map(lambda stmt: isinstance(stmt, py.Mult), simple_rmul_stmts)) + + @squin.kernel + def simple_lmul(): + x = squin.op.x() + y = x * 2 + return y + + simple_lmul.print() + + assert isinstance(simple_lmul.code, func.Function) + + simple_lmul_stmts = list(simple_lmul.code.body.stmts()) + assert any( + map(lambda stmt: isinstance(stmt, squin.op.stmts.Scale), simple_lmul_stmts) + ) + assert not any( + map(lambda stmt: isinstance(stmt, squin.op.stmts.Mult), simple_lmul_stmts) + ) + assert not any(map(lambda stmt: isinstance(stmt, py.Mult), simple_lmul_stmts)) + + @squin.kernel + def scale_mult(): + x = squin.op.x() + y = squin.op.y() + return 2 * (x * y) + + assert isinstance(scale_mult.code, func.Function) + + scale_mult_stmts = list(scale_mult.code.body.stmts()) + assert ( + sum(map(lambda stmt: isinstance(stmt, squin.op.stmts.Scale), scale_mult_stmts)) + == 1 + ) + assert ( + sum(map(lambda stmt: isinstance(stmt, squin.op.stmts.Mult), scale_mult_stmts)) + == 1 + ) + + @squin.kernel + def scale_mult2(): + x = squin.op.x() + y = squin.op.y() + return 2 * x * y + + scale_mult2.print() + + assert isinstance(scale_mult2.code, func.Function) + + scale_mult2_stmts = list(scale_mult2.code.body.stmts()) + assert ( + sum(map(lambda stmt: isinstance(stmt, squin.op.stmts.Scale), scale_mult2_stmts)) + == 1 + ) + assert ( + sum(map(lambda stmt: isinstance(stmt, squin.op.stmts.Mult), scale_mult2_stmts)) + == 1 + ) + + +def test_scale_types(): + @squin.kernel + def simple_lmul(): + x = squin.op.x() + y = x * (2 + 0j) + return y + + @squin.kernel + def simple_rmul(): + x = squin.op.x() + y = 2.1 * x + return y + + @squin.kernel + def nested_rmul(): + x = squin.op.x() + y = squin.op.y() + return 2 * x * y + + @squin.kernel + def nested_rmul2(): + x = squin.op.x() + y = squin.op.y() + return 2 * (x * y) + + @squin.kernel + def nested_lmul(): + x = squin.op.x() + y = squin.op.y() + return x * y * 2.0j + + def check_stmt_type(code, typ): + for stmt in code.body.stmts(): + if isinstance(stmt, func.Return): + continue + is_op = stmt.result.type.is_subseteq(squin.op.types.OpType) + is_num = stmt.result.type.is_equal(PyClass(typ)) + assert is_op or is_num + + check_stmt_type(simple_lmul.code, complex) + check_stmt_type(simple_rmul.code, float) + check_stmt_type(nested_rmul.code, int) + check_stmt_type(nested_rmul2.code, int) + check_stmt_type(nested_lmul.code, complex)