Skip to content

Commit c0a0593

Browse files
committed
add unit tests, implement name changes suggested by Roger
1 parent accb460 commit c0a0593

File tree

6 files changed

+256
-26
lines changed

6 files changed

+256
-26
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
11
# Need this for impl registration to work properly!
22
from . import impls as impls
3+
from .lattice import (
4+
NoSites as NoSites,
5+
AnySites as AnySites,
6+
HasNSites as HasNSites,
7+
)
38
from .analysis import NSitesAnalysis as NSitesAnalysis

src/bloqade/squin/analysis/nsites/analysis.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55
from kirin.analysis.forward import ForwardFrame
66

77
from bloqade.squin.op.types import OpType
8-
from bloqade.squin.op.traits import Sites, HasNSitesTrait
8+
from bloqade.squin.op.traits import NSites, HasNSitesTrait
99

10-
from .lattice import NSites, NoSites, HasNSites
10+
from .lattice import Sites, NoSites, HasNSites
1111

1212

13-
class NSitesAnalysis(Forward[NSites]):
13+
class NSitesAnalysis(Forward[Sites]):
1414

1515
keys = ["op.nsites"]
16-
lattice = NSites
16+
lattice = Sites
1717

1818
# Take a page from const prop in Kirin,
1919
# I can get the data I want from the SizedTrait
@@ -28,16 +28,16 @@ def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement):
2828
has_n_sites_trait = stmt.get_trait(HasNSitesTrait)
2929
sites = has_n_sites_trait.get_sites(stmt)
3030
return (HasNSites(sites=sites),)
31-
elif stmt.has_trait(Sites):
32-
sites_trait = stmt.get_trait(Sites)
31+
elif stmt.has_trait(NSites):
32+
sites_trait = stmt.get_trait(NSites)
3333
return (HasNSites(sites=sites_trait.data),)
3434
else:
3535
return (NoSites(),)
3636

3737
# For when no implementation is found for the statement
3838
def eval_stmt_fallback(
39-
self, frame: ForwardFrame[NSites], stmt: ir.Statement
40-
) -> tuple[NSites, ...]: # some form of Shape will go back into the frame
39+
self, frame: ForwardFrame[Sites], stmt: ir.Statement
40+
) -> tuple[Sites, ...]: # some form of Shape will go back into the frame
4141
return tuple(
4242
(
4343
self.lattice.top()
@@ -47,6 +47,6 @@ def eval_stmt_fallback(
4747
for result in stmt.results
4848
)
4949

50-
def run_method(self, method: ir.Method, args: tuple[NSites, ...]):
50+
def run_method(self, method: ir.Method, args: tuple[Sites, ...]):
5151
# NOTE: we do not support dynamic calls here, thus no need to propagate method object
5252
return self.run_callable(method.code, (self.lattice.bottom(),) + args)

src/bloqade/squin/analysis/nsites/lattice.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,40 +10,40 @@
1010

1111

1212
@dataclass
13-
class NSites(
14-
SimpleJoinMixin["NSites"], SimpleMeetMixin["NSites"], BoundedLattice["NSites"]
13+
class Sites(
14+
SimpleJoinMixin["Sites"], SimpleMeetMixin["Sites"], BoundedLattice["Sites"]
1515
):
1616
@classmethod
17-
def bottom(cls) -> "NSites":
17+
def bottom(cls) -> "Sites":
1818
return NoSites()
1919

2020
@classmethod
21-
def top(cls) -> "NSites":
21+
def top(cls) -> "Sites":
2222
return AnySites()
2323

2424

2525
@final
2626
@dataclass
27-
class NoSites(NSites, metaclass=SingletonMeta):
27+
class NoSites(Sites, metaclass=SingletonMeta):
2828

29-
def is_subseteq(self, other: NSites) -> bool:
29+
def is_subseteq(self, other: Sites) -> bool:
3030
return True
3131

3232

3333
@final
3434
@dataclass
35-
class AnySites(NSites, metaclass=SingletonMeta):
35+
class AnySites(Sites, metaclass=SingletonMeta):
3636

37-
def is_subseteq(self, other: NSites) -> bool:
38-
return isinstance(other, NSites)
37+
def is_subseteq(self, other: Sites) -> bool:
38+
return isinstance(other, Sites)
3939

4040

4141
@final
4242
@dataclass
43-
class HasNSites(NSites):
43+
class HasNSites(Sites):
4444
sites: int
4545

46-
def is_subseteq(self, other: NSites) -> bool:
46+
def is_subseteq(self, other: Sites) -> bool:
4747
if isinstance(other, HasNSites):
4848
return self.sites == other.sites
4949
return False

src/bloqade/squin/op/stmts.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from kirin.decl import info, statement
33

44
from .types import OpType
5-
from .traits import Sites, Unitary, MaybeUnitary, HasNSitesTrait
5+
from .traits import NSites, Unitary, MaybeUnitary, HasNSitesTrait
66
from .complex import Complex
77
from ._dialect import dialect
88

@@ -84,14 +84,14 @@ class Identity(CompositeOp):
8484

8585
@statement
8686
class ConstantOp(PrimitiveOp):
87-
traits = frozenset({ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Sites(1)})
87+
traits = frozenset({ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), NSites(1)})
8888
result: ir.ResultValue = info.result(OpType)
8989

9090

9191
@statement
9292
class ConstantUnitary(ConstantOp):
9393
traits = frozenset(
94-
{ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Unitary(), Sites(1)}
94+
{ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Unitary(), NSites(1)}
9595
)
9696

9797

@@ -105,7 +105,7 @@ class PhaseOp(PrimitiveOp):
105105
$$
106106
"""
107107

108-
traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), Sites(1)})
108+
traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), NSites(1)})
109109
theta: ir.SSAValue = info.argument(types.Float)
110110
result: ir.ResultValue = info.result(OpType)
111111

@@ -120,7 +120,7 @@ class ShiftOp(PrimitiveOp):
120120
$$
121121
"""
122122

123-
traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), Sites(1)})
123+
traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), NSites(1)})
124124
theta: ir.SSAValue = info.argument(types.Float)
125125
result: ir.ResultValue = info.result(OpType)
126126

src/bloqade/squin/op/traits.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
@dataclass(frozen=True)
8-
class Sites(ir.StmtTrait):
8+
class NSites(ir.StmtTrait):
99
data: int
1010

1111

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
from kirin import ir, types
2+
from kirin.passes import Fold
3+
from kirin.dialects import py, func
4+
5+
from bloqade import squin
6+
from bloqade.squin.analysis import nsites
7+
8+
9+
def as_int(value: int):
10+
return py.constant.Constant(value=value)
11+
12+
13+
def as_float(value: float):
14+
return py.constant.Constant(value=value)
15+
16+
17+
def gen_func_from_stmts(stmts):
18+
19+
squin_with_py = squin.groups.wired.add(py)
20+
21+
block = ir.Block(stmts)
22+
block.args.append_from(types.MethodType[[], types.NoneType], "main_self")
23+
func_wrapper = func.Function(
24+
sym_name="main",
25+
signature=func.Signature(inputs=(), output=squin.op.types.OpType),
26+
body=ir.Region(blocks=block),
27+
)
28+
29+
constructed_method = ir.Method(
30+
mod=None,
31+
py_func=None,
32+
sym_name="main",
33+
dialects=squin_with_py,
34+
code=func_wrapper,
35+
arg_names=[],
36+
)
37+
38+
fold_pass = Fold(squin_with_py)
39+
fold_pass(constructed_method)
40+
41+
return constructed_method
42+
43+
44+
def test_primitive_ops():
45+
pass
46+
47+
48+
# Kron, Mult, Control, Rot, and Scale all have methods defined for handling them in impls,
49+
# The following should ensure the code paths are properly exercised
50+
51+
52+
def test_control():
53+
# Control doesn't have an impl but it is handled in the eval_stmt of the interpreter
54+
# because it has a HasNSitesTrait future statements might have
55+
56+
stmts: list[ir.Statement] = [
57+
(h0 := squin.op.stmts.H()),
58+
(controlled_h := squin.op.stmts.Control(op=h0.result, n_controls=1)),
59+
(func.Return(controlled_h.result)),
60+
]
61+
62+
constructed_method = gen_func_from_stmts(stmts)
63+
64+
nsites_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis(
65+
constructed_method, no_raise=False
66+
)
67+
68+
has_n_sites = []
69+
for nsites_type in nsites_frame.entries.values():
70+
if isinstance(nsites_type, nsites.HasNSites):
71+
has_n_sites.append(nsites_type)
72+
73+
assert len(has_n_sites) == 2
74+
assert has_n_sites[0].sites == 1
75+
assert has_n_sites[1].sites == 2
76+
77+
78+
def test_kron():
79+
80+
stmts: list[ir.Statement] = [
81+
(h0 := squin.op.stmts.H()),
82+
(h1 := squin.op.stmts.H()),
83+
(hh := squin.op.stmts.Kron(h0.result, h1.result)),
84+
(func.Return(hh.result)),
85+
]
86+
87+
constructed_method = gen_func_from_stmts(stmts)
88+
89+
nsites_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis(
90+
constructed_method, no_raise=False
91+
)
92+
93+
has_n_sites = []
94+
for nsites_type in nsites_frame.entries.values():
95+
if isinstance(nsites_type, nsites.HasNSites):
96+
has_n_sites.append(nsites_type)
97+
98+
assert len(has_n_sites) == 3
99+
assert has_n_sites[0].sites == 1
100+
assert has_n_sites[1].sites == 1
101+
assert has_n_sites[2].sites == 3
102+
103+
104+
def test_mult_square_same_sites():
105+
# Ensure that two operators of the same size produce
106+
# a valid operator as their result
107+
108+
stmts: list[ir.Statement] = [
109+
(h0 := squin.op.stmts.H()),
110+
(h1 := squin.op.stmts.H()),
111+
(h2 := squin.op.stmts.Mult(h0.result, h1.result)),
112+
(func.Return(h2.result)),
113+
]
114+
115+
constructed_method = gen_func_from_stmts(stmts)
116+
117+
nsites_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis(
118+
constructed_method, no_raise=False
119+
)
120+
121+
has_n_sites = []
122+
for nsites_type in nsites_frame.entries.values():
123+
if isinstance(nsites_type, nsites.HasNSites):
124+
has_n_sites.append(nsites_type)
125+
126+
# should be three HasNSites types
127+
assert len(has_n_sites) == 3
128+
# the first 2 HasNSites will have 1 site but
129+
# the Kron-produced operator should have 2 sites
130+
assert has_n_sites[0].sites == 1
131+
assert has_n_sites[1].sites == 1
132+
assert has_n_sites[2].sites == 2
133+
134+
135+
def test_mult_square_different_sites():
136+
# Ensure that two operators of different sizes produce
137+
# NoSites as a type. Note that a better solution would be
138+
# to implement a special error type in the type lattice
139+
# but this would introduce some complexity later on
140+
141+
stmts: list[ir.Statement] = [
142+
(h0 := squin.op.stmts.H()),
143+
(h1 := squin.op.stmts.H()),
144+
# Kron to make nsites = 2 operator
145+
(hh := squin.op.stmts.Kron(h0.result, h1.result)),
146+
# apply Mult on HasNSites(2) and HasNSites(1)
147+
(invalid_op := squin.op.stmts.Mult(hh.result, h1.result)),
148+
(func.Return(invalid_op.result)),
149+
]
150+
151+
constructed_method = gen_func_from_stmts(stmts)
152+
153+
nsites_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis(
154+
constructed_method, no_raise=False
155+
)
156+
157+
nsites_types = list(nsites_frame.entries.values())
158+
159+
has_n_sites = []
160+
no_sites = []
161+
for nsite_type in nsites_types:
162+
if isinstance(nsite_type, nsites.HasNSites):
163+
has_n_sites.append(nsite_type)
164+
elif isinstance(nsite_type, nsites.NoSites):
165+
no_sites.append(nsite_type)
166+
167+
assert len(has_n_sites) == 3
168+
# HasNSites(1) for Hadamards, 2 for Kron result
169+
assert has_n_sites[0].sites == 1
170+
assert has_n_sites[1].sites == 1
171+
assert has_n_sites[2].sites == 2
172+
# One from function itself, another from invalid mult
173+
assert len(no_sites) == 2
174+
175+
176+
def test_rot():
177+
178+
stmts: list[ir.Statement] = [
179+
(h0 := squin.op.stmts.H()),
180+
(angle := as_float(0.2)),
181+
(rot_h := squin.op.stmts.Rot(axis=h0.result, angle=angle.result)),
182+
(func.Return(rot_h.result)),
183+
]
184+
185+
constructed_method = gen_func_from_stmts(stmts)
186+
187+
nsites_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis(
188+
constructed_method, no_raise=False
189+
)
190+
191+
has_n_sites = []
192+
for nsites_type in nsites_frame.entries.values():
193+
if isinstance(nsites_type, nsites.HasNSites):
194+
has_n_sites.append(nsites_type)
195+
196+
assert len(has_n_sites) == 2
197+
# Rot should just propagate whatever Sites type is there
198+
assert has_n_sites[0].sites == 1
199+
assert has_n_sites[1].sites == 1
200+
201+
202+
def test_scale():
203+
204+
stmts: list[ir.Statement] = [
205+
(h0 := squin.op.stmts.H()),
206+
(factor := as_float(0.2)),
207+
(rot_h := squin.op.stmts.Scale(op=h0.result, factor=factor.result)),
208+
(func.Return(rot_h.result)),
209+
]
210+
211+
constructed_method = gen_func_from_stmts(stmts)
212+
213+
nsites_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis(
214+
constructed_method, no_raise=False
215+
)
216+
217+
has_n_sites = []
218+
for nsites_type in nsites_frame.entries.values():
219+
if isinstance(nsites_type, nsites.HasNSites):
220+
has_n_sites.append(nsites_type)
221+
222+
assert len(has_n_sites) == 2
223+
# Rot should just propagate whatever Sites type is there
224+
assert has_n_sites[0].sites == 1
225+
assert has_n_sites[1].sites == 1

0 commit comments

Comments
 (0)