Skip to content

Commit 39fdc84

Browse files
committed
Rename trait and type in lattice to avoid collision
1 parent 22fc9e7 commit 39fdc84

File tree

7 files changed

+26
-26
lines changed

7 files changed

+26
-26
lines changed

src/bloqade/squin/analysis/nsites/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@
33
from .lattice import (
44
NoSites as NoSites,
55
AnySites as AnySites,
6-
HasNSites as HasNSites,
6+
NumberSites as NumberSites,
77
)
88
from .analysis import NSitesAnalysis as NSitesAnalysis

src/bloqade/squin/analysis/nsites/analysis.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
from kirin.analysis.forward import ForwardFrame
66

77
from bloqade.squin.op.types import OpType
8-
from bloqade.squin.op.traits import NSites, HasNSitesTrait
8+
from bloqade.squin.op.traits import NSites, HasSites
99

10-
from .lattice import Sites, NoSites, HasNSites
10+
from .lattice import Sites, NoSites, NumberSites
1111

1212

1313
class NSitesAnalysis(Forward[Sites]):
@@ -24,13 +24,13 @@ def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement):
2424
method = self.lookup_registry(frame, stmt)
2525
if method is not None:
2626
return method(self, frame, stmt)
27-
elif stmt.has_trait(HasNSitesTrait):
28-
has_n_sites_trait = stmt.get_trait(HasNSitesTrait)
27+
elif stmt.has_trait(HasSites):
28+
has_n_sites_trait = stmt.get_trait(HasSites)
2929
sites = has_n_sites_trait.get_sites(stmt)
30-
return (HasNSites(sites=sites),)
30+
return (NumberSites(sites=sites),)
3131
elif stmt.has_trait(NSites):
3232
sites_trait = stmt.get_trait(NSites)
33-
return (HasNSites(sites=sites_trait.data),)
33+
return (NumberSites(sites=sites_trait.data),)
3434
else:
3535
return (NoSites(),)
3636

src/bloqade/squin/analysis/nsites/impls.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from .lattice import (
88
NoSites,
9-
HasNSites,
9+
NumberSites,
1010
)
1111
from .analysis import NSitesAnalysis
1212

@@ -18,9 +18,9 @@ class SquinOp(interp.MethodTable):
1818
def kron(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Kron):
1919
lhs = frame.get(stmt.lhs)
2020
rhs = frame.get(stmt.rhs)
21-
if isinstance(lhs, HasNSites) and isinstance(rhs, HasNSites):
21+
if isinstance(lhs, NumberSites) and isinstance(rhs, NumberSites):
2222
new_n_sites = lhs.sites + rhs.sites
23-
return (HasNSites(sites=new_n_sites),)
23+
return (NumberSites(sites=new_n_sites),)
2424
else:
2525
return (NoSites(),)
2626

@@ -29,7 +29,7 @@ def mult(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Mult)
2929
lhs = frame.get(stmt.lhs)
3030
rhs = frame.get(stmt.rhs)
3131

32-
if isinstance(lhs, HasNSites) and isinstance(rhs, HasNSites):
32+
if isinstance(lhs, NumberSites) and isinstance(rhs, NumberSites):
3333
lhs_sites = lhs.sites
3434
rhs_sites = rhs.sites
3535
# I originally considered throwing an exception here
@@ -40,7 +40,7 @@ def mult(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Mult)
4040
if lhs_sites != rhs_sites:
4141
return (NoSites(),)
4242
else:
43-
return (HasNSites(sites=lhs_sites + rhs_sites),)
43+
return (NumberSites(sites=lhs_sites + rhs_sites),)
4444
else:
4545
return (NoSites(),)
4646

@@ -50,11 +50,11 @@ def control(
5050
):
5151
op_sites = frame.get(stmt.op)
5252

53-
if isinstance(op_sites, HasNSites):
53+
if isinstance(op_sites, NumberSites):
5454
n_sites = op_sites.sites
5555
n_controls_attr = stmt.get_attr_or_prop("n_controls")
5656
n_controls = cast(ir.PyAttr[int], n_controls_attr).data
57-
return (HasNSites(sites=n_sites + n_controls),)
57+
return (NumberSites(sites=n_sites + n_controls),)
5858
else:
5959
return (NoSites(),)
6060

src/bloqade/squin/analysis/nsites/lattice.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,10 @@ def is_subseteq(self, other: Sites) -> bool:
4040

4141
@final
4242
@dataclass
43-
class HasNSites(Sites):
43+
class NumberSites(Sites):
4444
sites: int
4545

4646
def is_subseteq(self, other: Sites) -> bool:
47-
if isinstance(other, HasNSites):
47+
if isinstance(other, NumberSites):
4848
return self.sites == other.sites
4949
return False

src/bloqade/squin/op/stmts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from kirin.decl import info, statement
33

44
from .types import OpType
5-
from .traits import NSites, Unitary, MaybeUnitary, HasNSitesTrait
5+
from .traits import NSites, Unitary, HasSites, MaybeUnitary
66
from .complex import Complex
77
from ._dialect import dialect
88

@@ -77,7 +77,7 @@ class Rot(CompositeOp):
7777

7878
@statement(dialect=dialect)
7979
class Identity(CompositeOp):
80-
traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), HasNSitesTrait()})
80+
traits = frozenset({ir.Pure(), ir.FromPythonCall(), Unitary(), HasSites()})
8181
sites: int = info.attribute()
8282
result: ir.ResultValue = info.result(OpType)
8383

src/bloqade/squin/op/traits.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ class NSites(ir.StmtTrait):
1010

1111

1212
@dataclass(frozen=True)
13-
class HasNSitesTrait(ir.StmtTrait):
13+
class HasSites(ir.StmtTrait):
1414
"""An operator with a `sites` attribute."""
1515

1616
def get_sites(self, stmt: ir.Statement):

test/squin/analysis/test_nsites_analysis.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def test_primitive_ops():
6969

7070
has_n_sites = []
7171
for nsites_type in nsites_frame.entries.values():
72-
if isinstance(nsites_type, nsites.HasNSites):
72+
if isinstance(nsites_type, nsites.NumberSites):
7373
has_n_sites.append(nsites_type)
7474
assert nsites_type.sites == 1
7575

@@ -98,7 +98,7 @@ def test_control():
9898

9999
has_n_sites = []
100100
for nsites_type in nsites_frame.entries.values():
101-
if isinstance(nsites_type, nsites.HasNSites):
101+
if isinstance(nsites_type, nsites.NumberSites):
102102
has_n_sites.append(nsites_type)
103103

104104
assert len(has_n_sites) == 2
@@ -123,7 +123,7 @@ def test_kron():
123123

124124
has_n_sites = []
125125
for nsites_type in nsites_frame.entries.values():
126-
if isinstance(nsites_type, nsites.HasNSites):
126+
if isinstance(nsites_type, nsites.NumberSites):
127127
has_n_sites.append(nsites_type)
128128

129129
assert len(has_n_sites) == 3
@@ -151,7 +151,7 @@ def test_mult_square_same_sites():
151151

152152
has_n_sites = []
153153
for nsites_type in nsites_frame.entries.values():
154-
if isinstance(nsites_type, nsites.HasNSites):
154+
if isinstance(nsites_type, nsites.NumberSites):
155155
has_n_sites.append(nsites_type)
156156

157157
# should be three HasNSites types
@@ -190,7 +190,7 @@ def test_mult_square_different_sites():
190190
has_n_sites = []
191191
no_sites = []
192192
for nsite_type in nsites_types:
193-
if isinstance(nsite_type, nsites.HasNSites):
193+
if isinstance(nsite_type, nsites.NumberSites):
194194
has_n_sites.append(nsite_type)
195195
elif isinstance(nsite_type, nsites.NoSites):
196196
no_sites.append(nsite_type)
@@ -221,7 +221,7 @@ def test_rot():
221221

222222
has_n_sites = []
223223
for nsites_type in nsites_frame.entries.values():
224-
if isinstance(nsites_type, nsites.HasNSites):
224+
if isinstance(nsites_type, nsites.NumberSites):
225225
has_n_sites.append(nsites_type)
226226

227227
assert len(has_n_sites) == 2
@@ -247,7 +247,7 @@ def test_scale():
247247

248248
has_n_sites = []
249249
for nsites_type in nsites_frame.entries.values():
250-
if isinstance(nsites_type, nsites.HasNSites):
250+
if isinstance(nsites_type, nsites.NumberSites):
251251
has_n_sites.append(nsites_type)
252252

253253
assert len(has_n_sites) == 2

0 commit comments

Comments
 (0)