Skip to content

Commit 5884d0b

Browse files
committed
Rewrite py.Mult to squin.Mult in squin kernel
1 parent 4dbe384 commit 5884d0b

File tree

5 files changed

+96
-1
lines changed

5 files changed

+96
-1
lines changed

src/bloqade/squin/groups.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
from kirin.dialects import ilist
44

55
from . import op, wire, qubit
6+
from .op.rewrite import PyMultToSquinMult
67

78

89
@ir.dialect_group(structural_no_opt.union([op, qubit]))
910
def kernel(self):
1011
fold_pass = passes.Fold(self)
1112
typeinfer_pass = passes.TypeInfer(self)
1213
ilist_desugar_pass = ilist.IListDesugar(self)
14+
py_mult_to_mult_pass = PyMultToSquinMult(self)
1315

1416
def run_pass(method: ir.Method, *, fold=True, typeinfer=True):
1517
method.verify()
@@ -18,7 +20,10 @@ def run_pass(method: ir.Method, *, fold=True, typeinfer=True):
1820

1921
if typeinfer:
2022
typeinfer_pass(method)
23+
2124
ilist_desugar_pass(method)
25+
py_mult_to_mult_pass(method)
26+
2227
if typeinfer:
2328
typeinfer_pass(method) # fix types after desugaring
2429
method.verify_type()

src/bloqade/squin/op/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from kirin.prelude import structural_no_opt as _structural_no_opt
33
from kirin.lowering import wraps as _wraps
44

5-
from . import stmts as stmts, types as types
5+
from . import stmts as stmts, types as types, rewrite as rewrite
66
from .traits import Unitary as Unitary, MaybeUnitary as MaybeUnitary
77
from ._dialect import dialect as dialect
88

src/bloqade/squin/op/rewrite.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
"""Rewrite py.binop.mult to Mult stmt"""
2+
3+
from kirin import ir
4+
from kirin.passes import Pass
5+
from kirin.rewrite import Walk
6+
from kirin.dialects import py
7+
from kirin.rewrite.abc import RewriteRule, RewriteResult
8+
9+
from .stmts import Mult, Scale
10+
from .types import OpType
11+
12+
13+
class _PyMultToSquinMult(RewriteRule):
14+
15+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
16+
if not isinstance(node, py.Mult):
17+
return RewriteResult()
18+
19+
lhs_is_op = node.lhs.type.is_subseteq(OpType)
20+
rhs_is_op = node.lhs.type.is_subseteq(OpType)
21+
22+
if not lhs_is_op and not rhs_is_op:
23+
return RewriteResult()
24+
25+
if lhs_is_op and rhs_is_op:
26+
mult = Mult(node.lhs, node.rhs)
27+
node.replace_by(mult)
28+
return RewriteResult(has_done_something=True)
29+
30+
if lhs_is_op:
31+
scale = Scale(node.lhs, node.rhs)
32+
node.replace_by(scale)
33+
return RewriteResult(has_done_something=True)
34+
35+
if rhs_is_op:
36+
scale = Scale(node.rhs, node.lhs)
37+
node.replace_by(scale)
38+
return RewriteResult(has_done_something=True)
39+
40+
raise ValueError(
41+
"Rewrite of py.binop.mult failed. This exception should not be reachable, please report this issue."
42+
)
43+
44+
45+
class PyMultToSquinMult(Pass):
46+
47+
def unsafe_run(self, mt: ir.Method) -> RewriteResult:
48+
return Walk(_PyMultToSquinMult()).rewrite(mt.code)

src/bloqade/squin/op/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,8 @@ class Op:
66
def __matmul__(self, other: "Op") -> "Op":
77
raise NotImplementedError("@ can only be used within a squin kernel program")
88

9+
def __mul__(self, other: "Op") -> "Op":
10+
raise NotImplementedError("@ can only be used within a squin kernel program")
11+
912

1013
OpType = types.PyClass(Op)

test/squin/test_mult_rewrite.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from kirin.dialects import py, func
2+
3+
from bloqade import squin
4+
5+
6+
def test_mult_rewrite():
7+
8+
@squin.kernel
9+
def helper(x: squin.op.types.Op, y: squin.op.types.Op):
10+
return x * y
11+
12+
@squin.kernel
13+
def main():
14+
q = squin.qubit.new(1)
15+
x = squin.op.x()
16+
y = squin.op.y()
17+
z = x * y
18+
t = helper(x, z)
19+
20+
squin.qubit.apply(t, q)
21+
return q
22+
23+
helper.print()
24+
25+
assert isinstance(helper.code, func.Function)
26+
27+
helper_stmts = list(helper.code.body.stmts())
28+
assert len(helper_stmts) == 2 # [Mult(), Return()]
29+
assert isinstance(helper_stmts[0], squin.op.stmts.Mult)
30+
31+
assert isinstance(main.code, func.Function)
32+
33+
count_mults_in_main = 0
34+
for stmt in main.code.body.stmts():
35+
assert not isinstance(stmt, py.Mult)
36+
37+
count_mults_in_main += isinstance(stmt, squin.op.stmts.Mult)
38+
39+
assert count_mults_in_main == 1

0 commit comments

Comments
 (0)