Skip to content

Commit c8c38ae

Browse files
committed
Move wrap analysis rewrite into its own file
1 parent 47691e2 commit c8c38ae

File tree

3 files changed

+78
-68
lines changed

3 files changed

+78
-68
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .stim import (
22
SitesAttribute as SitesAttribute,
33
AddressAttribute as AddressAttribute,
4-
WrapSquinAnalysis as WrapSquinAnalysis,
54
_SquinToStim as _SquinToStim,
65
)
6+
from .wrap_analysis import WrapSquinAnalysis as WrapSquinAnalysis

src/bloqade/squin/rewrite/stim.py

Lines changed: 4 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,78 +1,15 @@
1-
from typing import Dict, cast
1+
from typing import cast
22
from dataclasses import dataclass
33

44
from kirin import ir
55
from kirin.dialects import py
66
from kirin.rewrite.abc import RewriteRule, RewriteResult
7-
from kirin.print.printer import Printer
87

98
from bloqade import stim
109
from bloqade.squin import op, wire, qubit
11-
from bloqade.analysis.address import Address, AddressWire, AddressQubit, AddressTuple
12-
from bloqade.squin.analysis.nsites import Sites, NumberSites
13-
14-
15-
@wire.dialect.register
16-
@dataclass
17-
class AddressAttribute(ir.Attribute):
18-
19-
name = "Address"
20-
address: Address
21-
22-
def __hash__(self) -> int:
23-
return hash(self.address)
24-
25-
def print_impl(self, printer: Printer) -> None:
26-
# Can return to implementing this later
27-
printer.print(self.address)
28-
29-
30-
@op.dialect.register
31-
@dataclass
32-
class SitesAttribute(ir.Attribute):
33-
34-
name = "Sites"
35-
sites: Sites
36-
37-
def __hash__(self) -> int:
38-
return hash(self.sites)
39-
40-
def print_impl(self, printer: Printer) -> None:
41-
# Can return to implementing this later
42-
printer.print(self.sites)
43-
44-
45-
@dataclass
46-
class WrapSquinAnalysis(RewriteRule):
47-
48-
address_analysis: Dict[ir.SSAValue, Address]
49-
op_site_analysis: Dict[ir.SSAValue, Sites]
50-
51-
def wrap(self, value: ir.SSAValue) -> bool:
52-
address_analysis_result = self.address_analysis[value]
53-
op_site_analysis_result = self.op_site_analysis[value]
54-
55-
if value.hints.get("address") and value.hints.get("sites"):
56-
return False
57-
else:
58-
value.hints["address"] = AddressAttribute(address_analysis_result)
59-
value.hints["sites"] = SitesAttribute(op_site_analysis_result)
60-
61-
return True
62-
63-
def rewrite_Block(self, node: ir.Block) -> RewriteResult:
64-
has_done_something = False
65-
for arg in node.args:
66-
if self.wrap(arg):
67-
has_done_something = True
68-
return RewriteResult(has_done_something=has_done_something)
69-
70-
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
71-
has_done_something = False
72-
for result in node.results:
73-
if self.wrap(result):
74-
has_done_something = True
75-
return RewriteResult(has_done_something=has_done_something)
10+
from bloqade.analysis.address import AddressWire, AddressQubit, AddressTuple
11+
from bloqade.squin.analysis.nsites import NumberSites
12+
from bloqade.squin.rewrite.wrap_analysis import SitesAttribute, AddressAttribute
7613

7714

7815
@dataclass
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
from typing import Dict
2+
from dataclasses import dataclass
3+
4+
from kirin import ir
5+
from kirin.rewrite.abc import RewriteRule, RewriteResult
6+
from kirin.print.printer import Printer
7+
8+
from bloqade.squin import op, wire
9+
from bloqade.analysis.address import Address
10+
from bloqade.squin.analysis.nsites import Sites
11+
12+
13+
@wire.dialect.register
14+
@dataclass
15+
class AddressAttribute(ir.Attribute):
16+
17+
name = "Address"
18+
address: Address
19+
20+
def __hash__(self) -> int:
21+
return hash(self.address)
22+
23+
def print_impl(self, printer: Printer) -> None:
24+
# Can return to implementing this later
25+
printer.print(self.address)
26+
27+
28+
@op.dialect.register
29+
@dataclass
30+
class SitesAttribute(ir.Attribute):
31+
32+
name = "Sites"
33+
sites: Sites
34+
35+
def __hash__(self) -> int:
36+
return hash(self.sites)
37+
38+
def print_impl(self, printer: Printer) -> None:
39+
# Can return to implementing this later
40+
printer.print(self.sites)
41+
42+
43+
@dataclass
44+
class WrapSquinAnalysis(RewriteRule):
45+
46+
address_analysis: Dict[ir.SSAValue, Address]
47+
op_site_analysis: Dict[ir.SSAValue, Sites]
48+
49+
def wrap(self, value: ir.SSAValue) -> bool:
50+
address_analysis_result = self.address_analysis[value]
51+
op_site_analysis_result = self.op_site_analysis[value]
52+
53+
if value.hints.get("address") and value.hints.get("sites"):
54+
return False
55+
else:
56+
value.hints["address"] = AddressAttribute(address_analysis_result)
57+
value.hints["sites"] = SitesAttribute(op_site_analysis_result)
58+
59+
return True
60+
61+
def rewrite_Block(self, node: ir.Block) -> RewriteResult:
62+
has_done_something = False
63+
for arg in node.args:
64+
if self.wrap(arg):
65+
has_done_something = True
66+
return RewriteResult(has_done_something=has_done_something)
67+
68+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
69+
has_done_something = False
70+
for result in node.results:
71+
if self.wrap(result):
72+
has_done_something = True
73+
return RewriteResult(has_done_something=has_done_something)

0 commit comments

Comments
 (0)