Skip to content

Commit fd62258

Browse files
committed
add validation pass to cath unsupported squin statements/structure in rewrite to stim
1 parent 64e6028 commit fd62258

File tree

6 files changed

+311
-0
lines changed

6 files changed

+311
-0
lines changed

src/bloqade/stim/analysis/__init__.py

Whitespace-only changes.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .analysis import StimFromSquinValidation as StimFromSquinValidation
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
from typing import Any
2+
3+
from kirin import ir, interp
4+
from kirin.lattice import EmptyLattice
5+
from kirin.analysis import Forward
6+
from kirin.dialects import scf
7+
from kirin.validation import ValidationPass
8+
from kirin.analysis.forward import ForwardFrame
9+
10+
from bloqade.qubit import stmts as qubit_stmts
11+
from bloqade.squin import gate
12+
from bloqade.qubit._dialect import dialect as qubit_dialect
13+
14+
PauliGateType = (gate.stmts.X, gate.stmts.Y, gate.stmts.Z)
15+
16+
17+
class _StimIfElseValidationAnalysis(Forward[EmptyLattice]):
18+
keys = ["stim.validate.from_squin"]
19+
20+
lattice = EmptyLattice
21+
22+
def method_self(self, method: ir.Method) -> EmptyLattice:
23+
return self.lattice.bottom()
24+
25+
def eval_fallback(
26+
self, frame: ForwardFrame[EmptyLattice], node: ir.Statement
27+
) -> tuple[EmptyLattice, ...]:
28+
return tuple(self.lattice.bottom() for _ in range(len(node.results)))
29+
30+
31+
@scf.dialect.register(key="stim.validate.from_squin")
32+
class _ScfMethods(interp.MethodTable):
33+
34+
@interp.impl(scf.IfElse)
35+
def if_else(
36+
self,
37+
interp_: _StimIfElseValidationAnalysis,
38+
frame: ForwardFrame[EmptyLattice],
39+
stmt: scf.IfElse,
40+
):
41+
for child in stmt.walk(include_self=False):
42+
if isinstance(child, scf.IfElse):
43+
interp_.add_validation_error(
44+
stmt,
45+
ir.ValidationError(
46+
stmt,
47+
"Nested IfElse statements are not supported in rewriting to Stim IR.",
48+
),
49+
)
50+
break
51+
52+
if stmt.else_body.blocks and not (
53+
len(stmt.else_body.blocks[0].stmts) == 1
54+
and isinstance(stmt.else_body.blocks[0].last_stmt, scf.Yield)
55+
):
56+
interp_.add_validation_error(
57+
stmt,
58+
ir.ValidationError(
59+
stmt,
60+
"IfElse statements with an else body are not supported in rewriting to Stim IR.",
61+
),
62+
)
63+
64+
for child in stmt.then_body.walk():
65+
if isinstance(child, gate.stmts.Gate) and not isinstance(
66+
child, PauliGateType
67+
):
68+
interp_.add_validation_error(
69+
stmt,
70+
ir.ValidationError(
71+
stmt,
72+
f"Only Pauli gates (X, Y, Z) are allowed inside an scf.IfElse "
73+
f"'then'-body for rewriting to Stim IR. Found: {type(child).__name__}",
74+
),
75+
)
76+
77+
78+
@qubit_dialect.register(key="stim.validate.from_squin")
79+
class _QubitMethods(interp.MethodTable):
80+
81+
@interp.impl(qubit_stmts.IsZero)
82+
def is_zero(
83+
self,
84+
interp_: _StimIfElseValidationAnalysis,
85+
frame: ForwardFrame[EmptyLattice],
86+
stmt: qubit_stmts.IsZero,
87+
):
88+
interp_.add_validation_error(
89+
stmt,
90+
ir.ValidationError(
91+
stmt,
92+
"is_zero predicate is not supported in rewriting to Stim IR. Only the is_one predicate is supported.",
93+
),
94+
)
95+
96+
@interp.impl(qubit_stmts.IsLost)
97+
def is_lost(
98+
self,
99+
interp_: _StimIfElseValidationAnalysis,
100+
frame: ForwardFrame[EmptyLattice],
101+
stmt: qubit_stmts.IsLost,
102+
):
103+
interp_.add_validation_error(
104+
stmt,
105+
ir.ValidationError(
106+
stmt,
107+
"is_lost predicate is not supported in rewriting to Stim IR. Only the is_one predicate is supported.",
108+
),
109+
)
110+
111+
112+
class StimFromSquinValidation(ValidationPass):
113+
def name(self) -> str:
114+
return "Stim from Squin Validation"
115+
116+
def run(self, method: ir.Method) -> tuple[Any, list[ir.ValidationError]]:
117+
analysis = _StimIfElseValidationAnalysis(method.dialects)
118+
frame, _ = analysis.run(method)
119+
return frame, analysis.get_validation_errors()

test/analysis/measure_id/dev.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from bloqade import squin
2+
from bloqade.analysis.measure_id import MeasurementIDAnalysis
3+
from bloqade.stim.passes.flatten import Flatten
4+
5+
6+
@squin.kernel
7+
def test():
8+
qs = squin.qalloc(5)
9+
ms = squin.broadcast.measure(qs)
10+
squin.set_detector([ms[0]], coordinates=[0, 0])
11+
new_ms = squin.broadcast.measure(qs)
12+
squin.set_detector([new_ms[0]], coordinates=[0, 0])
13+
14+
return ms
15+
16+
17+
Flatten(test.dialects).fixpoint(test)
18+
frame, _ = MeasurementIDAnalysis(test.dialects).run(test)
19+
20+
test.print(analysis=frame.entries)

test/stim/from_squin_validation/__init__.py

Whitespace-only changes.
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
import pytest
2+
from kirin.validation import ValidationSuite
3+
from kirin.ir.exception import ValidationErrorGroup
4+
5+
from bloqade import squin
6+
from bloqade.rewrite.passes import AggressiveUnroll
7+
from bloqade.stim.analysis.from_squin_validation import StimFromSquinValidation
8+
9+
10+
def test_is_zero_prohibited():
11+
@squin.kernel
12+
def main():
13+
q = squin.qalloc(3)
14+
ms = squin.broadcast.measure(q)
15+
could_be_zero = squin.broadcast.is_zero(ms)
16+
squin.broadcast.reset(q)
17+
if could_be_zero[0]:
18+
squin.x(q[0])
19+
20+
AggressiveUnroll(main.dialects).fixpoint(main)
21+
22+
suite = ValidationSuite([StimFromSquinValidation])
23+
result = suite.validate(main)
24+
assert result.error_count() == 1
25+
26+
with pytest.raises(ValidationErrorGroup):
27+
result.raise_if_invalid()
28+
29+
30+
def test_is_lost_prohibited():
31+
@squin.kernel
32+
def main():
33+
q = squin.qalloc(3)
34+
ms = squin.broadcast.measure(q)
35+
could_be_lost = squin.broadcast.is_lost(ms)
36+
squin.broadcast.reset(q)
37+
if could_be_lost[0]:
38+
squin.x(q[0])
39+
40+
AggressiveUnroll(main.dialects).fixpoint(main)
41+
42+
suite = ValidationSuite([StimFromSquinValidation])
43+
result = suite.validate(main)
44+
assert result.error_count() == 1
45+
46+
with pytest.raises(ValidationErrorGroup):
47+
result.raise_if_invalid()
48+
49+
50+
def test_is_one_allowed():
51+
@squin.kernel
52+
def main():
53+
q = squin.qalloc(3)
54+
ms = squin.broadcast.measure(q)
55+
could_be_one = squin.broadcast.is_one(ms)
56+
squin.broadcast.reset(q)
57+
if could_be_one[0]:
58+
squin.x(q[0])
59+
60+
AggressiveUnroll(main.dialects).fixpoint(main)
61+
62+
suite = ValidationSuite([StimFromSquinValidation])
63+
result = suite.validate(main)
64+
assert result.error_count() == 0
65+
66+
67+
def test_nested_ifelse():
68+
@squin.kernel
69+
def main():
70+
q = squin.qalloc(3)
71+
ms = squin.broadcast.measure(q)
72+
could_be_one = squin.broadcast.is_one(ms)
73+
squin.broadcast.reset(q)
74+
if could_be_one[0]:
75+
if could_be_one[1]:
76+
squin.x(q[1])
77+
78+
AggressiveUnroll(main.dialects).fixpoint(main)
79+
80+
suite = ValidationSuite([StimFromSquinValidation])
81+
result = suite.validate(main)
82+
assert result.error_count() >= 1
83+
84+
with pytest.raises(ValidationErrorGroup):
85+
result.raise_if_invalid()
86+
87+
88+
def test_else_body():
89+
@squin.kernel(fold=False)
90+
def main():
91+
q = squin.qalloc(3)
92+
ms = squin.broadcast.measure(q)
93+
could_be_one = squin.broadcast.is_one(ms)
94+
squin.broadcast.reset(q)
95+
if could_be_one[0]:
96+
squin.x(q[0])
97+
else:
98+
squin.z(q[0])
99+
100+
suite = ValidationSuite([StimFromSquinValidation])
101+
result = suite.validate(main)
102+
assert result.error_count() == 1
103+
104+
with pytest.raises(ValidationErrorGroup):
105+
result.raise_if_invalid()
106+
107+
108+
def test_non_pauli_gate_in_ifelse():
109+
@squin.kernel
110+
def main():
111+
q = squin.qalloc(3)
112+
ms = squin.broadcast.measure(q)
113+
could_be_one = squin.broadcast.is_one(ms)
114+
squin.broadcast.reset(q)
115+
if could_be_one[0]:
116+
squin.h(q[0])
117+
118+
AggressiveUnroll(main.dialects).fixpoint(main)
119+
120+
suite = ValidationSuite([StimFromSquinValidation])
121+
result = suite.validate(main)
122+
assert result.error_count() == 1
123+
124+
with pytest.raises(ValidationErrorGroup):
125+
result.raise_if_invalid()
126+
127+
128+
def test_pauli_gates_valid():
129+
@squin.kernel
130+
def main():
131+
q = squin.qalloc(3)
132+
ms = squin.broadcast.measure(q)
133+
could_be_one = squin.broadcast.is_one(ms)
134+
squin.broadcast.reset(q)
135+
if could_be_one[0]:
136+
squin.x(q[0])
137+
if could_be_one[1]:
138+
squin.y(q[1])
139+
if could_be_one[2]:
140+
squin.z(q[2])
141+
142+
AggressiveUnroll(main.dialects).fixpoint(main)
143+
144+
suite = ValidationSuite([StimFromSquinValidation])
145+
result = suite.validate(main)
146+
assert result.error_count() == 0
147+
148+
149+
def test_multiple_errors():
150+
@squin.kernel
151+
def main():
152+
q = squin.qalloc(3)
153+
ms = squin.broadcast.measure(q)
154+
could_be_zero = squin.broadcast.is_zero(ms)
155+
could_be_one = squin.broadcast.is_one(ms)
156+
squin.broadcast.reset(q)
157+
158+
if could_be_one[0]:
159+
squin.h(q[0])
160+
161+
if could_be_zero[1]:
162+
squin.x(q[1])
163+
164+
AggressiveUnroll(main.dialects).fixpoint(main)
165+
166+
suite = ValidationSuite([StimFromSquinValidation])
167+
result = suite.validate(main)
168+
assert result.error_count() == 2
169+
170+
with pytest.raises(ValidationErrorGroup):
171+
result.raise_if_invalid()

0 commit comments

Comments
 (0)