Skip to content

Commit d7c8b3a

Browse files
committed
merging
2 parents 36ae4a1 + 3b788ca commit d7c8b3a

File tree

7 files changed

+454
-14
lines changed

7 files changed

+454
-14
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Need this for impl registration to work properly!
2+
from . import impls as impls
3+
from .lattice import (
4+
NoSites as NoSites,
5+
AnySites as AnySites,
6+
NumberSites as NumberSites,
7+
)
8+
from .analysis import NSitesAnalysis as NSitesAnalysis
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# from typing import cast
2+
3+
from kirin import ir
4+
from kirin.analysis import Forward
5+
from kirin.analysis.forward import ForwardFrame
6+
7+
from bloqade.squin.op.types import OpType
8+
from bloqade.squin.op.traits import HasSites, FixedSites
9+
10+
from .lattice import Sites, NoSites, NumberSites
11+
12+
13+
class NSitesAnalysis(Forward[Sites]):
14+
15+
keys = ["op.nsites"]
16+
lattice = Sites
17+
18+
# Take a page from const prop in Kirin,
19+
# I can get the data I want from the SizedTrait
20+
# and go from there
21+
22+
## This gets called before the registry look up
23+
def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement):
24+
method = self.lookup_registry(frame, stmt)
25+
if method is not None:
26+
return method(self, frame, stmt)
27+
elif stmt.has_trait(HasSites):
28+
has_sites_trait = stmt.get_trait(HasSites)
29+
sites = has_sites_trait.get_sites(stmt)
30+
return (NumberSites(sites=sites),)
31+
elif stmt.has_trait(FixedSites):
32+
sites_trait = stmt.get_trait(FixedSites)
33+
return (NumberSites(sites=sites_trait.data),)
34+
else:
35+
return (NoSites(),)
36+
37+
# For when no implementation is found for the statement
38+
def eval_stmt_fallback(
39+
self, frame: ForwardFrame[Sites], stmt: ir.Statement
40+
) -> tuple[Sites, ...]: # some form of Shape will go back into the frame
41+
return tuple(
42+
(
43+
self.lattice.top()
44+
if result.type.is_subseteq(OpType)
45+
else self.lattice.bottom()
46+
)
47+
for result in stmt.results
48+
)
49+
50+
def run_method(self, method: ir.Method, args: tuple[Sites, ...]):
51+
# NOTE: we do not support dynamic calls here, thus no need to propagate method object
52+
return self.run_callable(method.code, (self.lattice.bottom(),) + args)
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from typing import cast
2+
3+
from kirin import ir, interp
4+
5+
from bloqade.squin import op
6+
7+
from .lattice import (
8+
NoSites,
9+
NumberSites,
10+
)
11+
from .analysis import NSitesAnalysis
12+
13+
14+
@op.dialect.register(key="op.nsites")
15+
class SquinOp(interp.MethodTable):
16+
17+
@interp.impl(op.stmts.Kron)
18+
def kron(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Kron):
19+
lhs = frame.get(stmt.lhs)
20+
rhs = frame.get(stmt.rhs)
21+
if isinstance(lhs, NumberSites) and isinstance(rhs, NumberSites):
22+
new_n_sites = lhs.sites + rhs.sites
23+
return (NumberSites(sites=new_n_sites),)
24+
else:
25+
return (NoSites(),)
26+
27+
@interp.impl(op.stmts.Mult)
28+
def mult(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Mult):
29+
lhs = frame.get(stmt.lhs)
30+
rhs = frame.get(stmt.rhs)
31+
32+
if isinstance(lhs, NumberSites) and isinstance(rhs, NumberSites):
33+
lhs_sites = lhs.sites
34+
rhs_sites = rhs.sites
35+
# I originally considered throwing an exception here
36+
# but Xiu-zhe (Roger) Luo has pointed out it would be
37+
# a much better UX to add a type element that
38+
# could explicitly indicate the error. The downside
39+
# is you'll have some added complexity in the type lattice.
40+
if lhs_sites != rhs_sites:
41+
return (NoSites(),)
42+
else:
43+
return (NumberSites(sites=lhs_sites + rhs_sites),)
44+
else:
45+
return (NoSites(),)
46+
47+
@interp.impl(op.stmts.Control)
48+
def control(
49+
self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Control
50+
):
51+
op_sites = frame.get(stmt.op)
52+
53+
if isinstance(op_sites, NumberSites):
54+
n_sites = op_sites.sites
55+
n_controls_attr = stmt.get_attr_or_prop("n_controls")
56+
n_controls = cast(ir.PyAttr[int], n_controls_attr).data
57+
return (NumberSites(sites=n_sites + n_controls),)
58+
else:
59+
return (NoSites(),)
60+
61+
@interp.impl(op.stmts.Rot)
62+
def rot(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Rot):
63+
op_sites = frame.get(stmt.axis)
64+
return (op_sites,)
65+
66+
@interp.impl(op.stmts.Scale)
67+
def scale(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Scale):
68+
op_sites = frame.get(stmt.op)
69+
return (op_sites,)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from typing import final
2+
from dataclasses import dataclass
3+
4+
from kirin.lattice import (
5+
SingletonMeta,
6+
BoundedLattice,
7+
SimpleJoinMixin,
8+
SimpleMeetMixin,
9+
)
10+
11+
12+
@dataclass
13+
class Sites(
14+
SimpleJoinMixin["Sites"], SimpleMeetMixin["Sites"], BoundedLattice["Sites"]
15+
):
16+
@classmethod
17+
def bottom(cls) -> "Sites":
18+
return NoSites()
19+
20+
@classmethod
21+
def top(cls) -> "Sites":
22+
return AnySites()
23+
24+
25+
@final
26+
@dataclass
27+
class NoSites(Sites, metaclass=SingletonMeta):
28+
29+
def is_subseteq(self, other: Sites) -> bool:
30+
return True
31+
32+
33+
@final
34+
@dataclass
35+
class AnySites(Sites, metaclass=SingletonMeta):
36+
37+
def is_subseteq(self, other: Sites) -> bool:
38+
return isinstance(other, Sites)
39+
40+
41+
@final
42+
@dataclass
43+
class NumberSites(Sites):
44+
sites: int
45+
46+
def is_subseteq(self, other: Sites) -> bool:
47+
if isinstance(other, NumberSites):
48+
return self.sites == other.sites
49+
return False

src/bloqade/squin/op/stmts.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from kirin.decl import info, statement
33

44
from .types import OpType
5-
from .traits import Sized, HasSize, Unitary, MaybeUnitary
5+
from .traits import Unitary, HasSites, FixedSites, MaybeUnitary
66
from .complex import Complex
77
from ._dialect import dialect
88

@@ -77,23 +77,29 @@ class Rot(CompositeOp):
7777

7878
@statement(dialect=dialect)
7979
class Identity(CompositeOp):
80-
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), HasSize()})
80+
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), HasSites()})
8181
size: int = info.attribute()
8282
result: ir.ResultValue = info.result(OpType)
8383

8484

8585
@statement
8686
class ConstantOp(PrimitiveOp):
8787
traits = frozenset(
88-
{ir.Pure(), lowering.FromPythonCall(), ir.ConstantLike(), Sized(1)}
88+
{ir.Pure(), lowering.FromPythonCall(), ir.ConstantLike(), FixedSites(1)}
8989
)
9090
result: ir.ResultValue = info.result(OpType)
9191

9292

9393
@statement
9494
class ConstantUnitary(ConstantOp):
9595
traits = frozenset(
96-
{ir.Pure(), lowering.FromPythonCall(), ir.ConstantLike(), Unitary(), Sized(1)}
96+
{
97+
ir.Pure(),
98+
lowering.FromPythonCall(),
99+
ir.ConstantLike(),
100+
Unitary(),
101+
FixedSites(1),
102+
}
97103
)
98104

99105

@@ -107,7 +113,7 @@ class PhaseOp(PrimitiveOp):
107113
$$
108114
"""
109115

110-
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), Sized(1)})
116+
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), FixedSites(1)})
111117
theta: ir.SSAValue = info.argument(types.Float)
112118
result: ir.ResultValue = info.result(OpType)
113119

@@ -122,7 +128,7 @@ class ShiftOp(PrimitiveOp):
122128
$$
123129
"""
124130

125-
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), Sized(1)})
131+
traits = frozenset({ir.Pure(), lowering.FromPythonCall(), Unitary(), FixedSites(1)})
126132
theta: ir.SSAValue = info.argument(types.Float)
127133
result: ir.ResultValue = info.result(OpType)
128134

src/bloqade/squin/op/traits.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,22 @@
55

66

77
@dataclass(frozen=True)
8-
class Sized(ir.StmtTrait):
8+
class FixedSites(ir.StmtTrait):
99
data: int
1010

1111

1212
@dataclass(frozen=True)
13-
class HasSize(ir.StmtTrait):
14-
"""An operator with a `size` attribute."""
13+
class HasSites(ir.StmtTrait):
14+
"""An operator with a `sites` attribute."""
1515

16-
def get_size(self, stmt: ir.Statement):
17-
attr = stmt.get_attr_or_prop("size")
16+
def get_sites(self, stmt: ir.Statement):
17+
attr = stmt.get_attr_or_prop("sites")
1818
if attr is None:
19-
raise ValueError(f"Missing size attribute in {stmt}")
19+
raise ValueError(f"Missing sites attribute in {stmt}")
2020
return cast(ir.PyAttr[int], attr).data
2121

22-
def set_size(self, stmt: ir.Statement, value: int):
23-
stmt.attributes["size"] = ir.PyAttr(value)
22+
def set_sites(self, stmt: ir.Statement, value: int):
23+
stmt.attributes["sites"] = ir.PyAttr(value)
2424
return
2525

2626

0 commit comments

Comments
 (0)