Skip to content

Commit e7e2ef2

Browse files
committed
Verify that U3 can only occur at first position
1 parent 84b3492 commit e7e2ef2

File tree

7 files changed

+119
-19
lines changed

7 files changed

+119
-19
lines changed
Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
from kirin import ir
22

3+
from bloqade import squin
34
from bloqade.validation.analysis import ValidationFrame, ValidationAnalysis
4-
from bloqade.validation.analysis.lattice import ErrorType
55

66

77
class GeminiLogicalValidationAnalysis(ValidationAnalysis):
88
keys = ["gemini.validate.logical"]
9-
lattice = ErrorType
109

11-
has_allocated_qubits: bool = False
12-
13-
def run_method(self, method: ir.Method, args: tuple[ErrorType, ...]):
14-
return self.run_callable(method.code, (self.lattice.bottom(),) + args)
10+
first_gate = True
1511

1612
def eval_stmt_fallback(self, frame: ValidationFrame, stmt: ir.Statement):
13+
14+
if isinstance(stmt, squin.gate.stmts.Gate):
15+
# NOTE: to validate that only the first encountered gate can be non-Clifford, we need to track this here
16+
self.first_gate = False
17+
1718
return (self.lattice.top(),)

src/bloqade/gemini/analysis/logical_validation/impls.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from kirin.analysis import const
33
from kirin.dialects import scf, func
44

5+
from bloqade.squin import gate
56
from bloqade.validation.analysis import ValidationFrame
67
from bloqade.validation.analysis.lattice import Error
78

@@ -58,3 +59,24 @@ def invoke(
5859
help="Make sure to decorate your function with `@logical(inline = True)` or `@logical(aggressive_unroll = True)` to inline function calls",
5960
),
6061
)
62+
63+
64+
@gate.dialect.register(key="gemini.validate.logical")
65+
class __GateGeminiLogicalValidation(_interp.MethodTable):
66+
@_interp.impl(gate.stmts.U3)
67+
def u3(
68+
self,
69+
interp: GeminiLogicalValidationAnalysis,
70+
frame: ValidationFrame,
71+
stmt: gate.stmts.U3,
72+
):
73+
if interp.first_gate:
74+
interp.first_gate = False
75+
return (interp.lattice.top(),)
76+
77+
return (
78+
Error(
79+
stmt,
80+
"U3 gate can only be used for initial state preparation, i.e. as the first gate!",
81+
),
82+
)

src/bloqade/gemini/groups.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ def run_pass(
4242
no_raise: Annotated[bool, Doc("do not raise exception during analysis")] = True,
4343
) -> None:
4444

45+
if inline and not aggressive_unroll:
46+
InlinePass(mt.dialects, no_raise=no_raise).fixpoint(mt)
47+
4548
if aggressive_unroll:
4649
AggressiveUnroll(mt.dialects, no_raise=no_raise).fixpoint(mt)
4750
else:
@@ -56,9 +59,6 @@ def run_pass(
5659

5760
default_pass.fixpoint(mt)
5861

59-
if inline and not aggressive_unroll:
60-
InlinePass(mt.dialects, no_raise=no_raise).fixpoint(mt)
61-
6262
if verify:
6363
validator = KernelValidation(GeminiLogicalValidationAnalysis)
6464
validator.run(mt, no_raise=no_raise)

src/bloqade/squin/gate/stmts.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,13 @@
88

99

1010
@statement
11-
class SingleQubitGate(ir.Statement):
11+
class Gate(ir.Statement):
12+
# NOTE: just for easier isinstance checks elsewhere, all gates inherit from this class
13+
pass
14+
15+
16+
@statement
17+
class SingleQubitGate(Gate):
1218
traits = frozenset({lowering.FromPythonCall()})
1319
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
1420

@@ -59,7 +65,7 @@ class SqrtY(SingleQubitNonHermitianGate):
5965

6066

6167
@statement
62-
class RotationGate(ir.Statement):
68+
class RotationGate(Gate):
6369
# NOTE: don't inherit from SingleQubitGate here so the wrapper doesn't have qubits as first arg
6470
traits = frozenset({lowering.FromPythonCall()})
6571
angle: ir.SSAValue = info.argument(types.Float)
@@ -85,7 +91,7 @@ class Rz(RotationGate):
8591

8692

8793
@statement
88-
class ControlledGate(ir.Statement):
94+
class ControlledGate(Gate):
8995
traits = frozenset({lowering.FromPythonCall()})
9096
controls: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
9197
targets: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
@@ -110,7 +116,7 @@ class CZ(ControlledGate):
110116

111117

112118
@statement(dialect=dialect)
113-
class U3(ir.Statement):
119+
class U3(Gate):
114120
# NOTE: don't inherit from SingleQubitGate here so the wrapper doesn't have qubits as first arg
115121
traits = frozenset({lowering.FromPythonCall()})
116122
theta: ir.SSAValue = info.argument(types.Float)

src/bloqade/validation/analysis/analysis.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
from abc import ABC
2+
from typing import Iterable
3+
from dataclasses import field, dataclass
24

35
from kirin import ir
6+
from kirin.interp import AbstractFrame
47
from kirin.analysis import Forward, ForwardFrame
58

69
from .lattice import ErrorType
710

811
ValidationFrame = ForwardFrame[ErrorType]
912

1013

14+
@dataclass
1115
class ValidationAnalysis(Forward[ErrorType], ABC):
1216
"""Analysis pass that indicates errors in the IR according to the respective method tables.
1317
@@ -17,9 +21,43 @@ class ValidationAnalysis(Forward[ErrorType], ABC):
1721

1822
lattice = ErrorType
1923

24+
additional_errors: list[ErrorType] = field(default_factory=list)
25+
"""List to store return values that are not associated with an SSA Value (e.g. when the statement has no ResultValue)"""
26+
2027
def run_method(self, method: ir.Method, args: tuple[ErrorType, ...]):
21-
return self.run_callable(method.code, (self.lattice.bottom(),) + args)
28+
return self.run_callable(method.code, (self.lattice.top(),) + args)
2229

2330
def eval_stmt_fallback(self, frame: ValidationFrame, stmt: ir.Statement):
2431
# NOTE: default to no errors
2532
return (self.lattice.top(),)
33+
34+
def set_values(
35+
self,
36+
frame: AbstractFrame[ErrorType],
37+
ssa: Iterable[ir.SSAValue],
38+
results: Iterable[ErrorType],
39+
):
40+
"""Set the abstract values for the given SSA values in the frame.
41+
42+
This method is overridden to account for additional errors we may
43+
encounter when they are not associated to an SSA Value.
44+
"""
45+
46+
number_of_ssa_values = 0
47+
for ssa_value, result in zip(ssa, results):
48+
number_of_ssa_values += 1
49+
if ssa_value in frame.entries:
50+
frame.entries[ssa_value] = frame.entries[ssa_value].join(result)
51+
else:
52+
frame.entries[ssa_value] = result
53+
54+
if isinstance(results, tuple):
55+
# NOTE: usually what we have
56+
self.additional_errors.extend(results[number_of_ssa_values:])
57+
58+
for i, result in enumerate(results):
59+
# NOTE: only sure-fire way I found to get remaining values from an Iterable
60+
if i < number_of_ssa_values:
61+
continue
62+
63+
self.additional_errors.append(result)

src/bloqade/validation/kernel_validation.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
import itertools
12
from dataclasses import dataclass
23

34
from kirin import ir
45

56
from .analysis import ValidationFrame, ValidationAnalysis
6-
from .analysis.lattice import Error
7+
from .analysis.lattice import Error, ErrorType
78

89

910
@dataclass
@@ -14,7 +15,9 @@ def run(self, mt: ir.Method, **kwargs) -> None:
1415
validation_analysis = self.validation_analysis_cls(mt.dialects)
1516
validation_frame, _ = validation_analysis.run_analysis(mt, **kwargs)
1617

17-
errors = self.get_exceptions(mt, validation_frame)
18+
errors = self.get_exceptions(
19+
mt, validation_frame, validation_analysis.additional_errors
20+
)
1821

1922
if len(errors) == 0:
2023
# Valid program
@@ -23,9 +26,16 @@ def run(self, mt: ir.Method, **kwargs) -> None:
2326
# TODO: Make something similar to an ExceptionGroup that pretty-prints ValidationErrors
2427
raise errors[0]
2528

26-
def get_exceptions(self, mt: ir.Method, validation_frame: ValidationFrame):
29+
def get_exceptions(
30+
self,
31+
mt: ir.Method,
32+
validation_frame: ValidationFrame,
33+
additional_errors: list[ErrorType],
34+
):
2735
errors = []
28-
for value in validation_frame.entries.values():
36+
for value in itertools.chain(
37+
validation_frame.entries.values(), additional_errors
38+
):
2939
if not isinstance(value, Error):
3040
continue
3141

test/gemini/test_logical.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,4 +84,27 @@ def invalid():
8484
sub_kernel(q[0])
8585

8686

87-
test_func()
87+
def test_clifford_gates():
88+
@gemini.logical
89+
def main():
90+
q = squin.qalloc(2)
91+
squin.u3(0.123, 0.253, 1.2, q[0])
92+
93+
squin.h(q[0])
94+
squin.cx(q[0], q[1])
95+
96+
with pytest.raises(ir.ValidationError):
97+
98+
@gemini.logical(no_raise=False)
99+
def invalid():
100+
q = squin.qalloc(2)
101+
102+
squin.h(q[0])
103+
squin.cx(q[0], q[1])
104+
squin.u3(0.123, 0.253, 1.2, q[0])
105+
106+
frame, _ = GeminiLogicalValidationAnalysis(invalid.dialects).run_analysis(
107+
invalid, no_raise=False
108+
)
109+
110+
invalid.print(analysis=frame.entries)

0 commit comments

Comments
 (0)