Skip to content

Commit accb460

Browse files
committed
rename to use sites instead of shape
1 parent 7af6fb0 commit accb460

File tree

10 files changed

+152
-156
lines changed

10 files changed

+152
-156
lines changed

squin_op_playground.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from kirin.dialects import py, func
44

55
from bloqade import squin
6-
from bloqade.squin.analysis import shape
6+
from bloqade.squin.analysis import nsites
77

88

99
def as_int(value: int):
@@ -51,7 +51,7 @@ def as_int(value: int):
5151
constructed_method.print(analysis=address_frame.entries)
5252
"""
5353

54-
shape_frame, _ = shape.ShapeAnalysis(constructed_method.dialects).run_analysis(
54+
shape_frame, _ = nsites.NSitesAnalysis(constructed_method.dialects).run_analysis(
5555
constructed_method, no_raise=False
5656
)
5757

src/bloqade/analysis/address/analysis.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ class AddressAnalysis(Forward[Address]):
2222
def initialize(self):
2323
super().initialize()
2424
self.next_address: int = 0
25-
print(self.registry.statements)
2625
return self
2726

2827
@property
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
# Need this for impl registration to work properly!
22
from . import impls as impls
3-
from .analysis import ShapeAnalysis as ShapeAnalysis
3+
from .analysis import NSitesAnalysis as NSitesAnalysis

src/bloqade/squin/analysis/shape/analysis.py renamed to src/bloqade/squin/analysis/nsites/analysis.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,15 @@
55
from kirin.analysis.forward import ForwardFrame
66

77
from bloqade.squin.op.types import OpType
8-
from bloqade.squin.op.traits import Sized, HasSize
8+
from bloqade.squin.op.traits import Sites, HasNSitesTrait
99

10-
from .lattice import Shape, NoShape, OpShape
10+
from .lattice import NSites, NoSites, HasNSites
1111

1212

13-
class ShapeAnalysis(Forward[Shape]):
13+
class NSitesAnalysis(Forward[NSites]):
1414

15-
keys = ["op.shape"]
16-
lattice = Shape
17-
18-
def initialize(self):
19-
super().initialize()
20-
return self
15+
keys = ["op.nsites"]
16+
lattice = NSites
2117

2218
# Take a page from const prop in Kirin,
2319
# I can get the data I want from the SizedTrait
@@ -28,20 +24,20 @@ def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement):
2824
method = self.lookup_registry(frame, stmt)
2925
if method is not None:
3026
return method(self, frame, stmt)
31-
elif stmt.has_trait(HasSize):
32-
has_size_inst = stmt.get_trait(HasSize)
33-
size = has_size_inst.get_size(stmt)
34-
return (OpShape(size=size),)
35-
elif stmt.has_trait(Sized):
36-
size = stmt.get_trait(Sized)
37-
return (OpShape(size=size.data),)
27+
elif stmt.has_trait(HasNSitesTrait):
28+
has_n_sites_trait = stmt.get_trait(HasNSitesTrait)
29+
sites = has_n_sites_trait.get_sites(stmt)
30+
return (HasNSites(sites=sites),)
31+
elif stmt.has_trait(Sites):
32+
sites_trait = stmt.get_trait(Sites)
33+
return (HasNSites(sites=sites_trait.data),)
3834
else:
39-
return (NoShape(),)
35+
return (NoSites(),)
4036

4137
# For when no implementation is found for the statement
4238
def eval_stmt_fallback(
43-
self, frame: ForwardFrame[Shape], stmt: ir.Statement
44-
) -> tuple[Shape, ...]: # some form of Shape will go back into the frame
39+
self, frame: ForwardFrame[NSites], stmt: ir.Statement
40+
) -> tuple[NSites, ...]: # some form of Shape will go back into the frame
4541
return tuple(
4642
(
4743
self.lattice.top()
@@ -51,6 +47,6 @@ def eval_stmt_fallback(
5147
for result in stmt.results
5248
)
5349

54-
def run_method(self, method: ir.Method, args: tuple[Shape, ...]):
50+
def run_method(self, method: ir.Method, args: tuple[NSites, ...]):
5551
# NOTE: we do not support dynamic calls here, thus no need to propagate method object
5652
return self.run_callable(method.code, (self.lattice.bottom(),) + args)
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from typing import cast
2+
3+
from kirin import ir, interp
4+
5+
from bloqade.squin import op
6+
7+
from .lattice import (
8+
NoSites,
9+
HasNSites,
10+
)
11+
from .analysis import NSitesAnalysis
12+
13+
14+
@op.dialect.register(key="op.nsites")
15+
class SquinOp(interp.MethodTable):
16+
17+
@interp.impl(op.stmts.Kron)
18+
def kron(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Kron):
19+
lhs = frame.get(stmt.lhs)
20+
rhs = frame.get(stmt.rhs)
21+
if isinstance(lhs, HasNSites) and isinstance(rhs, HasNSites):
22+
new_n_sites = lhs.sites + rhs.sites
23+
return (HasNSites(sites=new_n_sites),)
24+
else:
25+
return (NoSites(),)
26+
27+
@interp.impl(op.stmts.Mult)
28+
def mult(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Mult):
29+
lhs = frame.get(stmt.lhs)
30+
rhs = frame.get(stmt.rhs)
31+
32+
if isinstance(lhs, HasNSites) and isinstance(rhs, HasNSites):
33+
lhs_sites = lhs.sites
34+
rhs_sites = rhs.sites
35+
# I originally considered throwing an exception here
36+
# but Xiu-zhe (Roger) Luo has pointed out it would be
37+
# a much better UX to add a type element that
38+
# could explicitly indicate the error. The downside
39+
# is you'll have some added complexity in the type lattice.
40+
if lhs_sites != rhs_sites:
41+
return (NoSites(),)
42+
else:
43+
return (HasNSites(sites=lhs_sites + rhs_sites),)
44+
else:
45+
return (NoSites(),)
46+
47+
@interp.impl(op.stmts.Control)
48+
def control(
49+
self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Control
50+
):
51+
op_sites = frame.get(stmt.op)
52+
53+
if isinstance(op_sites, HasNSites):
54+
n_sites = op_sites.sites
55+
n_controls_attr = stmt.get_attr_or_prop("n_controls")
56+
n_controls = cast(ir.PyAttr[int], n_controls_attr).data
57+
return (HasNSites(sites=n_sites + n_controls),)
58+
else:
59+
return (NoSites(),)
60+
61+
@interp.impl(op.stmts.Rot)
62+
def rot(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Rot):
63+
op_sites = frame.get(stmt.axis)
64+
return (op_sites,)
65+
66+
@interp.impl(op.stmts.Scale)
67+
def scale(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Scale):
68+
op_sites = frame.get(stmt.op)
69+
return (op_sites,)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from typing import final
2+
from dataclasses import dataclass
3+
4+
from kirin.lattice import (
5+
SingletonMeta,
6+
BoundedLattice,
7+
SimpleJoinMixin,
8+
SimpleMeetMixin,
9+
)
10+
11+
12+
@dataclass
13+
class NSites(
14+
SimpleJoinMixin["NSites"], SimpleMeetMixin["NSites"], BoundedLattice["NSites"]
15+
):
16+
@classmethod
17+
def bottom(cls) -> "NSites":
18+
return NoSites()
19+
20+
@classmethod
21+
def top(cls) -> "NSites":
22+
return AnySites()
23+
24+
25+
@final
26+
@dataclass
27+
class NoSites(NSites, metaclass=SingletonMeta):
28+
29+
def is_subseteq(self, other: NSites) -> bool:
30+
return True
31+
32+
33+
@final
34+
@dataclass
35+
class AnySites(NSites, metaclass=SingletonMeta):
36+
37+
def is_subseteq(self, other: NSites) -> bool:
38+
return isinstance(other, NSites)
39+
40+
41+
@final
42+
@dataclass
43+
class HasNSites(NSites):
44+
sites: int
45+
46+
def is_subseteq(self, other: NSites) -> bool:
47+
if isinstance(other, HasNSites):
48+
return self.sites == other.sites
49+
return False

src/bloqade/squin/analysis/shape/impls.py

Lines changed: 0 additions & 68 deletions
This file was deleted.

src/bloqade/squin/analysis/shape/lattice.py

Lines changed: 0 additions & 49 deletions
This file was deleted.

src/bloqade/squin/op/stmts.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from kirin.decl import info, statement
33

44
from .types import OpType
5-
from .traits import Sized, HasSize, Unitary, MaybeUnitary
5+
from .traits import Sites, Unitary, MaybeUnitary, HasNSitesTrait
66
from .complex import Complex
77
from ._dialect import dialect
88

@@ -77,21 +77,21 @@ class Rot(CompositeOp):
7777

7878
@statement(dialect=dialect)
7979
class Identity(CompositeOp):
80-
traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), HasSize()})
81-
size: int = info.attribute()
80+
traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), HasNSitesTrait()})
81+
sites: int = info.attribute()
8282
result: ir.ResultValue = info.result(OpType)
8383

8484

8585
@statement
8686
class ConstantOp(PrimitiveOp):
87-
traits = frozenset({ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Sized(1)})
87+
traits = frozenset({ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Sites(1)})
8888
result: ir.ResultValue = info.result(OpType)
8989

9090

9191
@statement
9292
class ConstantUnitary(ConstantOp):
9393
traits = frozenset(
94-
{ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Unitary(), Sized(1)}
94+
{ir.Pure(), ir.FromPythonCall(), ir.ConstantLike(), Unitary(), Sites(1)}
9595
)
9696

9797

@@ -105,7 +105,7 @@ class PhaseOp(PrimitiveOp):
105105
$$
106106
"""
107107

108-
traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), Sized(1)})
108+
traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), Sites(1)})
109109
theta: ir.SSAValue = info.argument(types.Float)
110110
result: ir.ResultValue = info.result(OpType)
111111

@@ -120,7 +120,7 @@ class ShiftOp(PrimitiveOp):
120120
$$
121121
"""
122122

123-
traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), Sized(1)})
123+
traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), Sites(1)})
124124
theta: ir.SSAValue = info.argument(types.Float)
125125
result: ir.ResultValue = info.result(OpType)
126126

0 commit comments

Comments
 (0)