Skip to content

Commit 76399c1

Browse files
johnzl-777weinbe58
andauthored
TerminalLogicalMeasurement Validation (#641)
This is what I think should be sufficient for TerminalLogicalMeasurement Validation. I enforce the following conditions: 1. A program should only ever have ONE TerminalLogicalMeasurement 2. all spawned qubits should be consumed in that one terminal measurement. There should be no dangling qubits. There are some additions that come to mind - ranging from immediately relevant to some things that are further out in terms of infrastructure: - I had to make a separate validation pass as opposed to tacking things onto the current logical validation. I think because of this "logical validation" should be renamed to something a bit more precise. - AddressAnalysis already counts the number of (unique) qubits spawned but there's no way to access that information directly. It would be nice if there was some way to attach that to the final frame returned at the end of the analysis. Currently I tally up unique qubit ids and compare that against the number of things fed into the terminal measurement --------- Co-authored-by: Phillip Weinberg <[email protected]>
1 parent 0958f24 commit 76399c1

File tree

7 files changed

+253
-18
lines changed

7 files changed

+253
-18
lines changed

src/bloqade/analysis/measure_id/impls.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from kirin.analysis import const
33
from kirin.dialects import py, scf, func, ilist
44

5-
from bloqade import qubit, annotate
5+
from bloqade import qubit, gemini, annotate
66

77
from .lattice import (
88
Predicate,
@@ -15,6 +15,8 @@
1515
)
1616
from .analysis import MeasureIDFrame, MeasurementIDAnalysis
1717

18+
# from bloqade.gemini.dialects.logical import stmts as gemini_stmts, dialect as logical_dialect
19+
1820

1921
@qubit.dialect.register(key="measure_id")
2022
class SquinQubit(interp.MethodTable):
@@ -74,6 +76,31 @@ def measurement_predicate(
7476
return (MeasureIdTuple(data=tuple(predicate_measure_ids)),)
7577

7678

79+
@gemini.logical.dialect.register(key="measure_id")
80+
class LogicalQubit(interp.MethodTable):
81+
@interp.impl(gemini.logical.stmts.TerminalLogicalMeasurement)
82+
def terminal_measurement(
83+
self,
84+
interp: MeasurementIDAnalysis,
85+
frame: interp.Frame,
86+
stmt: gemini.logical.stmts.TerminalLogicalMeasurement,
87+
):
88+
# try to get the length of the list
89+
qubits_type = stmt.qubits.type
90+
# vars[0] is just the type of the elements in the ilist,
91+
# vars[1] can contain a literal with length information
92+
num_qubits = qubits_type.vars[1]
93+
if not isinstance(num_qubits, kirin_types.Literal):
94+
return (AnyMeasureId(),)
95+
96+
measure_id_bools = []
97+
for _ in range(num_qubits.data):
98+
interp.measure_count += 1
99+
measure_id_bools.append(RawMeasureId(interp.measure_count))
100+
101+
return (MeasureIdTuple(data=tuple(measure_id_bools)),)
102+
103+
77104
@annotate.dialect.register(key="measure_id")
78105
class Annotate(interp.MethodTable):
79106
@interp.impl(annotate.stmts.SetObservable)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from . import impls as impls, analysis as analysis # NOTE: register methods
2+
from .analysis import (
3+
GeminiTerminalMeasurementValidation as GeminiTerminalMeasurementValidation,
4+
)
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from typing import Any
2+
from dataclasses import field, dataclass
3+
4+
from kirin import ir
5+
from kirin.lattice import EmptyLattice
6+
from kirin.analysis import Forward, ForwardFrame
7+
from kirin.validation import ValidationPass
8+
9+
from bloqade.analysis import address, measure_id
10+
11+
12+
@dataclass
13+
class _GeminiTerminalMeasurementValidationAnalysis(Forward[EmptyLattice]):
14+
keys = ("gemini.validate.terminal_measurement",)
15+
16+
measurement_analysis_results: ForwardFrame
17+
unique_qubits_allocated: int = 0
18+
terminal_measurement_encountered: bool = False
19+
lattice = EmptyLattice
20+
21+
# boilerplate, not really worried about these right now
22+
def eval_fallback(self, frame: ForwardFrame, node: ir.Statement):
23+
return tuple(self.lattice.bottom() for _ in range(len(node.results)))
24+
25+
def method_self(self, method: ir.Method) -> EmptyLattice:
26+
return self.lattice.bottom()
27+
28+
29+
@dataclass
30+
class GeminiTerminalMeasurementValidation(ValidationPass):
31+
32+
analysis_cache: dict = field(default_factory=dict)
33+
34+
def name(self) -> str:
35+
return "Gemini Terminal Measurement Validation"
36+
37+
def get_required_analyses(self) -> list[type]:
38+
return [measure_id.MeasurementIDAnalysis]
39+
40+
def set_analysis_cache(self, cache: dict[type, Any]) -> None:
41+
self.analysis_cache.update(cache)
42+
43+
def run(self, method: ir.Method) -> tuple[Any, list[ir.ValidationError]]:
44+
45+
# get the data out of the cache and forward it to the underlying analysis
46+
measurement_analysis_results = self.analysis_cache.get(
47+
measure_id.MeasurementIDAnalysis
48+
)
49+
50+
address_analysis = address.AddressAnalysis(dialects=method.dialects)
51+
address_analysis.run(method)
52+
unique_qubits_allocated = address_analysis.qubit_count
53+
54+
assert (
55+
measurement_analysis_results is not None
56+
), "Measurement ID analysis results not found in cache"
57+
58+
analysis = _GeminiTerminalMeasurementValidationAnalysis(
59+
method.dialects,
60+
measurement_analysis_results,
61+
unique_qubits_allocated=unique_qubits_allocated,
62+
)
63+
64+
frame, _ = analysis.run(method)
65+
66+
return frame, analysis.get_validation_errors()
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from kirin import ir, interp as _interp
2+
from kirin.analysis import ForwardFrame
3+
from kirin.dialects import func
4+
5+
from bloqade import qubit, gemini
6+
from bloqade.analysis.address.impls import Func as AddressFuncMethodTable
7+
from bloqade.analysis.measure_id.lattice import MeasureIdTuple
8+
9+
from .analysis import _GeminiTerminalMeasurementValidationAnalysis
10+
11+
12+
@qubit.dialect.register(key="gemini.validate.terminal_measurement")
13+
class __QubitGeminiMeasurementValidation(_interp.MethodTable):
14+
15+
# This is a non-logical measurement, can safely flag as invalid
16+
@_interp.impl(qubit.stmts.Measure)
17+
def measure(
18+
self,
19+
interp: _GeminiTerminalMeasurementValidationAnalysis,
20+
frame: ForwardFrame,
21+
stmt: qubit.stmts.Measure,
22+
):
23+
24+
interp.add_validation_error(
25+
stmt,
26+
ir.ValidationError(
27+
stmt,
28+
"Non-terminal measurements are not allowed in Gemini programs!",
29+
),
30+
)
31+
32+
return (interp.lattice.bottom(),)
33+
34+
35+
@gemini.logical.dialect.register(key="gemini.validate.terminal_measurement")
36+
class __GeminiLogicalMeasurementValidation(_interp.MethodTable):
37+
38+
@_interp.impl(gemini.logical.stmts.TerminalLogicalMeasurement)
39+
def terminal_measure(
40+
self,
41+
interp: _GeminiTerminalMeasurementValidationAnalysis,
42+
frame: ForwardFrame,
43+
stmt: gemini.logical.stmts.TerminalLogicalMeasurement,
44+
):
45+
46+
# should only be one terminal measurement EVER
47+
if not interp.terminal_measurement_encountered:
48+
interp.terminal_measurement_encountered = True
49+
else:
50+
interp.add_validation_error(
51+
stmt,
52+
ir.ValidationError(
53+
stmt,
54+
"Multiple terminal measurements are not allowed in Gemini logical programs!",
55+
),
56+
)
57+
return (interp.lattice.bottom(),)
58+
59+
measurement_analysis_results = interp.measurement_analysis_results
60+
total_qubits_allocated = interp.unique_qubits_allocated
61+
62+
measure_lattice_element = measurement_analysis_results.get(stmt.result)
63+
if not isinstance(measure_lattice_element, MeasureIdTuple):
64+
interp.add_validation_error(
65+
stmt,
66+
ir.ValidationError(
67+
stmt,
68+
"Measurement ID Analysis failed to produce the necessary results needed for validation.",
69+
),
70+
)
71+
return (interp.lattice.bottom(),)
72+
73+
if len(measure_lattice_element.data) != total_qubits_allocated:
74+
interp.add_validation_error(
75+
stmt,
76+
ir.ValidationError(
77+
stmt,
78+
"The number of qubits in the terminal measurement does not match the number of total qubits allocated! "
79+
+ f"{total_qubits_allocated} qubits were allocated but only {len(measure_lattice_element.data)} were measured.",
80+
),
81+
)
82+
return (interp.lattice.bottom(),)
83+
84+
return (interp.lattice.bottom(),)
85+
86+
87+
@func.dialect.register(key="gemini.validate.terminal_measurement")
88+
class Func(AddressFuncMethodTable):
89+
pass

src/bloqade/gemini/dialects/logical/groups.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
from bloqade.squin import gate, qubit
1212
from bloqade.rewrite.passes import AggressiveUnroll
13-
from bloqade.gemini.analysis.logical_validation import GeminiLogicalValidation
1413

1514
from ._dialect import dialect
1615

@@ -63,7 +62,17 @@ def run_pass(
6362
default_pass.fixpoint(mt)
6463

6564
if verify:
66-
validator = ValidationSuite([GeminiLogicalValidation])
65+
# stop circular import problems
66+
from bloqade.gemini.analysis.logical_validation import (
67+
GeminiLogicalValidation,
68+
)
69+
from bloqade.gemini.analysis.measurement_validation import (
70+
GeminiTerminalMeasurementValidation,
71+
)
72+
73+
validator = ValidationSuite(
74+
[GeminiLogicalValidation, GeminiTerminalMeasurementValidation]
75+
)
6776
validation_result = validator.validate(mt)
6877
validation_result.raise_if_invalid()
6978
mt.verify()

test/analysis/measure_id/test_measure_id.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from kirin.dialects import scf
22
from kirin.passes.inline import InlinePass
33

4-
from bloqade import squin
4+
from bloqade import squin, gemini
55
from bloqade.analysis.measure_id import MeasurementIDAnalysis
66
from bloqade.stim.passes.flatten import Flatten
77
from bloqade.analysis.measure_id.lattice import (
@@ -339,3 +339,23 @@ def test():
339339
assert frame.get(results["is_zero_bools"]) == expected_is_zero_bools
340340
assert frame.get(results["is_one_bools"]) == expected_is_one_bools
341341
assert frame.get(results["is_lost_bools"]) == expected_is_lost_bools
342+
343+
344+
def test_terminal_logical_measurement():
345+
346+
@gemini.logical.kernel(no_raise=False, typeinfer=True, aggressive_unroll=True)
347+
def tm_logical_kernel():
348+
q = squin.qalloc(3)
349+
tm = gemini.logical.terminal_measure(q)
350+
return tm
351+
352+
frame, _ = MeasurementIDAnalysis(tm_logical_kernel.dialects).run(tm_logical_kernel)
353+
# will have a MeasureIdTuple that's not from the terminal measurement,
354+
# basically a container of InvalidMeasureIds from the qubits that get allocated
355+
analysis_results = [
356+
val for val in frame.entries.values() if isinstance(val, MeasureIdTuple)
357+
]
358+
expected_result = MeasureIdTuple(
359+
data=tuple([RawMeasureId(idx=i) for i in range(1, 4)])
360+
)
361+
assert expected_result in analysis_results

test/gemini/test_logical_validation.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import pytest
2-
from kirin import ir
32
from kirin.validation import ValidationSuite
43
from kirin.ir.exception import ValidationErrorGroup
54

@@ -9,6 +8,9 @@
98
GeminiLogicalValidation,
109
_GeminiLogicalValidationAnalysis,
1110
)
11+
from bloqade.gemini.analysis.measurement_validation.analysis import (
12+
GeminiTerminalMeasurementValidation,
13+
)
1214

1315

1416
def test_if_stmt_invalid():
@@ -122,21 +124,39 @@ def main():
122124

123125
main.print()
124126

125-
with pytest.raises(ir.ValidationError):
127+
@gemini.logical.kernel(
128+
verify=False, no_raise=False, aggressive_unroll=True, typeinfer=True
129+
)
130+
def not_all_qubits_consumed():
131+
qs = squin.qalloc(3)
132+
sub_qs = qs[0:2]
133+
tm = gemini.logical.terminal_measure(sub_qs)
134+
return tm
126135

127-
@gemini.logical.kernel(no_raise=False)
128-
def invalid():
129-
q = squin.qalloc(3)
130-
squin.x(q[0])
131-
m = gemini.logical.terminal_measure(q)
132-
another_m = gemini.logical.terminal_measure(q)
133-
return m, another_m
136+
validator = ValidationSuite([GeminiTerminalMeasurementValidation])
137+
validation_result = validator.validate(not_all_qubits_consumed)
134138

135-
frame, _ = _GeminiLogicalValidationAnalysis(invalid.dialects).run_no_raise(
136-
invalid
137-
)
139+
with pytest.raises(ValidationErrorGroup):
140+
validation_result.raise_if_invalid()
138141

139-
invalid.print(analysis=frame.entries)
142+
@gemini.logical.kernel(verify=False)
143+
def terminal_measure_kernel(q):
144+
return gemini.logical.terminal_measure(q)
145+
146+
@gemini.logical.kernel(
147+
verify=False, no_raise=False, aggressive_unroll=True, typeinfer=True
148+
)
149+
def terminal_measure_in_kernel():
150+
q = squin.qalloc(10)
151+
sub_qs = q[:2]
152+
m = terminal_measure_kernel(sub_qs)
153+
return m
154+
155+
validator = ValidationSuite([GeminiTerminalMeasurementValidation])
156+
validation_result = validator.validate(terminal_measure_in_kernel)
157+
158+
with pytest.raises(ValidationErrorGroup):
159+
validation_result.raise_if_invalid()
140160

141161

142162
def test_multiple_errors():
@@ -159,6 +179,6 @@ def main(n: int):
159179

160180
except ValidationErrorGroup as e:
161181
did_error = True
162-
assert len(e.errors) == 3
182+
assert len(e.errors) == 4
163183

164184
assert did_error

0 commit comments

Comments
 (0)