Skip to content

Commit 6dcf451

Browse files
committed
add invoke support
1 parent 0c3c8f1 commit 6dcf451

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

src/bloqade/gemini/analysis/measurement_validation/impls.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from kirin import ir, interp as _interp
22
from kirin.analysis import ForwardFrame
3+
from kirin.dialects import func
34

45
from bloqade import qubit, gemini
56
from bloqade.analysis.address.lattice import AddressReg, AddressQubit
@@ -108,3 +109,21 @@ def terminal_measure(
108109
return (interp.lattice.bottom(),)
109110

110111
return (interp.lattice.bottom(),)
112+
113+
114+
@func.dialect.register(key="gemini.validate.terminal_measurement")
115+
class Func(_interp.MethodTable):
116+
@_interp.impl(func.Invoke)
117+
def return_(
118+
self,
119+
interp: _GeminiTerminalMeasurementValidationAnalysis,
120+
frame: ForwardFrame,
121+
stmt: func.Invoke,
122+
):
123+
_, ret = interp.call(
124+
stmt.callee.code,
125+
interp.method_self(stmt.callee),
126+
*frame.get_values(stmt.inputs),
127+
)
128+
129+
return (ret,)

test/gemini/test_logical_validation.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,23 @@ def not_all_qubits_consumed():
137137
with pytest.raises(ValidationErrorGroup):
138138
validation_result.raise_if_invalid()
139139

140+
@gemini.logical.kernel
141+
def terminal_measure_kernel(q):
142+
return gemini.logical.terminal_measure(q)
143+
144+
@gemini.logical.kernel(no_raise=False, aggressive_unroll=True, typeinfer=True)
145+
def terminal_measure_in_kernel():
146+
q = squin.qalloc(10)
147+
sub_qs = q[:2]
148+
m = terminal_measure_kernel(sub_qs)
149+
return m
150+
151+
validator = ValidationSuite([GeminiTerminalMeasurementValidation])
152+
validation_result = validator.validate(terminal_measure_in_kernel)
153+
154+
with pytest.raises(ValidationErrorGroup):
155+
validation_result.raise_if_invalid()
156+
140157

141158
def test_multiple_errors():
142159
did_error = False

0 commit comments

Comments
 (0)