Skip to content

Commit 6e6f823

Browse files
authored
Fix state initialization validation in gemini logical (#653)
Previously, we'd only allow any single non-Clifford gate to be applied to a set of qubits. That means, you couldn't e.g. apply an `Rx` to the first and an `Ry` to the second qubit, since the latter would have already counted as the second gate and it's non-Clifford. This fixes this by actually checking the first gates per qubit address.
1 parent c2a6a32 commit 6e6f823

File tree

3 files changed

+77
-8
lines changed

3 files changed

+77
-8
lines changed
Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,46 @@
11
from typing import Any
2-
from dataclasses import dataclass
2+
from dataclasses import field, dataclass
33

44
from kirin import ir
5+
from kirin.interp import InterpreterError
56
from kirin.lattice import EmptyLattice
67
from kirin.analysis import Forward, ForwardFrame
78
from kirin.validation import ValidationPass
89

910
from bloqade import squin
11+
from bloqade.analysis.address import Address, AddressReg, AddressAnalysis
1012

1113

14+
@dataclass
1215
class _GeminiLogicalValidationAnalysis(Forward[EmptyLattice]):
1316
keys = ["gemini.validate.logical"]
1417

15-
first_gate = True
1618
lattice = EmptyLattice
19+
addr_frame: ForwardFrame[Address]
20+
21+
first_gates: dict[int, bool] = field(init=False, default_factory=dict)
1722

1823
def eval_fallback(self, frame: ForwardFrame, node: ir.Statement):
1924
if isinstance(node, squin.gate.stmts.Gate):
20-
# NOTE: to validate that only the first encountered gate can be non-Clifford, we need to track this here
21-
self.first_gate = False
25+
raise InterpreterError(f"Missing implementation for gate {node}")
2226

2327
return tuple(self.lattice.bottom() for _ in range(len(node.results)))
2428

29+
def check_first_gate(self, qubits: ir.SSAValue) -> bool:
30+
address = self.addr_frame.get(qubits)
31+
32+
if not isinstance(address, AddressReg):
33+
# NOTE: we should have a flat kernel with simple address analysis, so in case we don't
34+
# get concrete addresses, we might as well error here since something's wrong
35+
return False
36+
37+
is_first = True
38+
for addr_int in address.data:
39+
is_first = is_first and self.first_gates.get(addr_int, True)
40+
self.first_gates[addr_int] = False
41+
42+
return is_first
43+
2544
def method_self(self, method: ir.Method) -> EmptyLattice:
2645
return self.lattice.bottom()
2746

@@ -34,7 +53,10 @@ def name(self) -> str:
3453
return "Gemini Logical Validation"
3554

3655
def run(self, method: ir.Method) -> tuple[Any, list[ir.ValidationError]]:
37-
analysis = _GeminiLogicalValidationAnalysis(method.dialects)
56+
addr_frame, _ = AddressAnalysis(method.dialects).run(method)
57+
analysis = _GeminiLogicalValidationAnalysis(
58+
method.dialects, addr_frame=addr_frame
59+
)
3860
frame, _ = analysis.run(method)
3961

4062
return frame, analysis.get_validation_errors()

src/bloqade/gemini/analysis/logical_validation/impls.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,7 @@ def non_clifford(
8080
frame: ForwardFrame,
8181
stmt: gate.stmts.SingleQubitGate | gate.stmts.RotationGate,
8282
):
83-
if interp.first_gate:
84-
interp.first_gate = False
83+
if interp.check_first_gate(stmt.qubits):
8584
return ()
8685

8786
interp.add_validation_error(
@@ -92,3 +91,34 @@ def non_clifford(
9291
),
9392
)
9493
return ()
94+
95+
@_interp.impl(gate.stmts.X)
96+
@_interp.impl(gate.stmts.Y)
97+
@_interp.impl(gate.stmts.SqrtX)
98+
@_interp.impl(gate.stmts.SqrtY)
99+
@_interp.impl(gate.stmts.Z)
100+
@_interp.impl(gate.stmts.H)
101+
@_interp.impl(gate.stmts.S)
102+
def clifford(
103+
self,
104+
interp: _GeminiLogicalValidationAnalysis,
105+
frame: ForwardFrame,
106+
stmt: gate.stmts.SingleQubitGate,
107+
):
108+
# NOTE: ignore result, but make sure the first gate flag is set to False
109+
interp.check_first_gate(stmt.qubits)
110+
111+
return ()
112+
113+
@_interp.impl(gate.stmts.CX)
114+
@_interp.impl(gate.stmts.CY)
115+
@_interp.impl(gate.stmts.CZ)
116+
def controlled_gate(
117+
self,
118+
interp: _GeminiLogicalValidationAnalysis,
119+
frame: ForwardFrame,
120+
stmt: gate.stmts.ControlledGate,
121+
):
122+
interp.check_first_gate(stmt.controls)
123+
interp.check_first_gate(stmt.targets)
124+
return ()

test/gemini/test_logical_validation.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from bloqade import squin, gemini
66
from bloqade.types import Qubit
7+
from bloqade.analysis.address import AddressAnalysis
78
from bloqade.gemini.analysis.logical_validation.analysis import (
89
GeminiLogicalValidation,
910
_GeminiLogicalValidationAnalysis,
@@ -35,7 +36,10 @@ def main():
3536
if m2:
3637
squin.y(q[2])
3738

38-
frame, _ = _GeminiLogicalValidationAnalysis(main.dialects).run_no_raise(main)
39+
addr_frame, _ = AddressAnalysis(main.dialects).run(main)
40+
frame, _ = _GeminiLogicalValidationAnalysis(
41+
main.dialects, addr_frame=addr_frame
42+
).run_no_raise(main)
3943

4044
main.print(analysis=frame.entries)
4145

@@ -188,3 +192,16 @@ def main(n: int):
188192
assert len(e.errors) == 4
189193

190194
assert did_error
195+
196+
197+
def test_non_clifford_parallel_gates():
198+
@gemini.logical.kernel
199+
def main():
200+
q = squin.qalloc(5)
201+
squin.rx(0.123, q[0])
202+
squin.broadcast.ry(0.333, q[1:])
203+
204+
squin.broadcast.x(q)
205+
squin.broadcast.h(q[1:])
206+
207+
main.print()

0 commit comments

Comments
 (0)