Skip to content

Commit 8eb84b7

Browse files
authored
Merge branch 'main' into phil/rewrite-non-clifford-u3-2
2 parents ae462fc + 76399c1 commit 8eb84b7

File tree

7 files changed

+253
-18
lines changed

7 files changed

+253
-18
lines changed

src/bloqade/analysis/measure_id/impls.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from kirin.analysis import const
33
from kirin.dialects import py, scf, func, ilist
44

5-
from bloqade import qubit, annotate
5+
from bloqade import qubit, gemini, annotate
66

77
from .lattice import (
88
Predicate,
@@ -15,6 +15,8 @@
1515
)
1616
from .analysis import MeasureIDFrame, MeasurementIDAnalysis
1717

18+
# from bloqade.gemini.dialects.logical import stmts as gemini_stmts, dialect as logical_dialect
19+
1820

1921
@qubit.dialect.register(key="measure_id")
2022
class SquinQubit(interp.MethodTable):
@@ -74,6 +76,31 @@ def measurement_predicate(
7476
return (MeasureIdTuple(data=tuple(predicate_measure_ids)),)
7577

7678

79+
@gemini.logical.dialect.register(key="measure_id")
80+
class LogicalQubit(interp.MethodTable):
81+
@interp.impl(gemini.logical.stmts.TerminalLogicalMeasurement)
82+
def terminal_measurement(
83+
self,
84+
interp: MeasurementIDAnalysis,
85+
frame: interp.Frame,
86+
stmt: gemini.logical.stmts.TerminalLogicalMeasurement,
87+
):
88+
# try to get the length of the list
89+
qubits_type = stmt.qubits.type
90+
# vars[0] is just the type of the elements in the ilist,
91+
# vars[1] can contain a literal with length information
92+
num_qubits = qubits_type.vars[1]
93+
if not isinstance(num_qubits, kirin_types.Literal):
94+
return (AnyMeasureId(),)
95+
96+
measure_id_bools = []
97+
for _ in range(num_qubits.data):
98+
interp.measure_count += 1
99+
measure_id_bools.append(RawMeasureId(interp.measure_count))
100+
101+
return (MeasureIdTuple(data=tuple(measure_id_bools)),)
102+
103+
77104
@annotate.dialect.register(key="measure_id")
78105
class Annotate(interp.MethodTable):
79106
@interp.impl(annotate.stmts.SetObservable)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from . import impls as impls, analysis as analysis # NOTE: register methods
2+
from .analysis import (
3+
GeminiTerminalMeasurementValidation as GeminiTerminalMeasurementValidation,
4+
)
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from typing import Any
2+
from dataclasses import field, dataclass
3+
4+
from kirin import ir
5+
from kirin.lattice import EmptyLattice
6+
from kirin.analysis import Forward, ForwardFrame
7+
from kirin.validation import ValidationPass
8+
9+
from bloqade.analysis import address, measure_id
10+
11+
12+
@dataclass
13+
class _GeminiTerminalMeasurementValidationAnalysis(Forward[EmptyLattice]):
14+
keys = ("gemini.validate.terminal_measurement",)
15+
16+
measurement_analysis_results: ForwardFrame
17+
unique_qubits_allocated: int = 0
18+
terminal_measurement_encountered: bool = False
19+
lattice = EmptyLattice
20+
21+
# boilerplate, not really worried about these right now
22+
def eval_fallback(self, frame: ForwardFrame, node: ir.Statement):
23+
return tuple(self.lattice.bottom() for _ in range(len(node.results)))
24+
25+
def method_self(self, method: ir.Method) -> EmptyLattice:
26+
return self.lattice.bottom()
27+
28+
29+
@dataclass
30+
class GeminiTerminalMeasurementValidation(ValidationPass):
31+
32+
analysis_cache: dict = field(default_factory=dict)
33+
34+
def name(self) -> str:
35+
return "Gemini Terminal Measurement Validation"
36+
37+
def get_required_analyses(self) -> list[type]:
38+
return [measure_id.MeasurementIDAnalysis]
39+
40+
def set_analysis_cache(self, cache: dict[type, Any]) -> None:
41+
self.analysis_cache.update(cache)
42+
43+
def run(self, method: ir.Method) -> tuple[Any, list[ir.ValidationError]]:
44+
45+
# get the data out of the cache and forward it to the underlying analysis
46+
measurement_analysis_results = self.analysis_cache.get(
47+
measure_id.MeasurementIDAnalysis
48+
)
49+
50+
address_analysis = address.AddressAnalysis(dialects=method.dialects)
51+
address_analysis.run(method)
52+
unique_qubits_allocated = address_analysis.qubit_count
53+
54+
assert (
55+
measurement_analysis_results is not None
56+
), "Measurement ID analysis results not found in cache"
57+
58+
analysis = _GeminiTerminalMeasurementValidationAnalysis(
59+
method.dialects,
60+
measurement_analysis_results,
61+
unique_qubits_allocated=unique_qubits_allocated,
62+
)
63+
64+
frame, _ = analysis.run(method)
65+
66+
return frame, analysis.get_validation_errors()
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from kirin import ir, interp as _interp
2+
from kirin.analysis import ForwardFrame
3+
from kirin.dialects import func
4+
5+
from bloqade import qubit, gemini
6+
from bloqade.analysis.address.impls import Func as AddressFuncMethodTable
7+
from bloqade.analysis.measure_id.lattice import MeasureIdTuple
8+
9+
from .analysis import _GeminiTerminalMeasurementValidationAnalysis
10+
11+
12+
@qubit.dialect.register(key="gemini.validate.terminal_measurement")
13+
class __QubitGeminiMeasurementValidation(_interp.MethodTable):
14+
15+
# This is a non-logical measurement, can safely flag as invalid
16+
@_interp.impl(qubit.stmts.Measure)
17+
def measure(
18+
self,
19+
interp: _GeminiTerminalMeasurementValidationAnalysis,
20+
frame: ForwardFrame,
21+
stmt: qubit.stmts.Measure,
22+
):
23+
24+
interp.add_validation_error(
25+
stmt,
26+
ir.ValidationError(
27+
stmt,
28+
"Non-terminal measurements are not allowed in Gemini programs!",
29+
),
30+
)
31+
32+
return (interp.lattice.bottom(),)
33+
34+
35+
@gemini.logical.dialect.register(key="gemini.validate.terminal_measurement")
36+
class __GeminiLogicalMeasurementValidation(_interp.MethodTable):
37+
38+
@_interp.impl(gemini.logical.stmts.TerminalLogicalMeasurement)
39+
def terminal_measure(
40+
self,
41+
interp: _GeminiTerminalMeasurementValidationAnalysis,
42+
frame: ForwardFrame,
43+
stmt: gemini.logical.stmts.TerminalLogicalMeasurement,
44+
):
45+
46+
# should only be one terminal measurement EVER
47+
if not interp.terminal_measurement_encountered:
48+
interp.terminal_measurement_encountered = True
49+
else:
50+
interp.add_validation_error(
51+
stmt,
52+
ir.ValidationError(
53+
stmt,
54+
"Multiple terminal measurements are not allowed in Gemini logical programs!",
55+
),
56+
)
57+
return (interp.lattice.bottom(),)
58+
59+
measurement_analysis_results = interp.measurement_analysis_results
60+
total_qubits_allocated = interp.unique_qubits_allocated
61+
62+
measure_lattice_element = measurement_analysis_results.get(stmt.result)
63+
if not isinstance(measure_lattice_element, MeasureIdTuple):
64+
interp.add_validation_error(
65+
stmt,
66+
ir.ValidationError(
67+
stmt,
68+
"Measurement ID Analysis failed to produce the necessary results needed for validation.",
69+
),
70+
)
71+
return (interp.lattice.bottom(),)
72+
73+
if len(measure_lattice_element.data) != total_qubits_allocated:
74+
interp.add_validation_error(
75+
stmt,
76+
ir.ValidationError(
77+
stmt,
78+
"The number of qubits in the terminal measurement does not match the number of total qubits allocated! "
79+
+ f"{total_qubits_allocated} qubits were allocated but only {len(measure_lattice_element.data)} were measured.",
80+
),
81+
)
82+
return (interp.lattice.bottom(),)
83+
84+
return (interp.lattice.bottom(),)
85+
86+
87+
@func.dialect.register(key="gemini.validate.terminal_measurement")
88+
class Func(AddressFuncMethodTable):
89+
pass

src/bloqade/gemini/dialects/logical/groups.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
from bloqade.squin import gate, qubit
1212
from bloqade.rewrite.passes import AggressiveUnroll
13-
from bloqade.gemini.analysis.logical_validation import GeminiLogicalValidation
1413

1514
from ._dialect import dialect
1615

@@ -63,7 +62,17 @@ def run_pass(
6362
default_pass.fixpoint(mt)
6463

6564
if verify:
66-
validator = ValidationSuite([GeminiLogicalValidation])
65+
# stop circular import problems
66+
from bloqade.gemini.analysis.logical_validation import (
67+
GeminiLogicalValidation,
68+
)
69+
from bloqade.gemini.analysis.measurement_validation import (
70+
GeminiTerminalMeasurementValidation,
71+
)
72+
73+
validator = ValidationSuite(
74+
[GeminiLogicalValidation, GeminiTerminalMeasurementValidation]
75+
)
6776
validation_result = validator.validate(mt)
6877
validation_result.raise_if_invalid()
6978
mt.verify()

test/analysis/measure_id/test_measure_id.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from kirin.dialects import scf
22
from kirin.passes.inline import InlinePass
33

4-
from bloqade import squin
4+
from bloqade import squin, gemini
55
from bloqade.analysis.measure_id import MeasurementIDAnalysis
66
from bloqade.stim.passes.flatten import Flatten
77
from bloqade.analysis.measure_id.lattice import (
@@ -339,3 +339,23 @@ def test():
339339
assert frame.get(results["is_zero_bools"]) == expected_is_zero_bools
340340
assert frame.get(results["is_one_bools"]) == expected_is_one_bools
341341
assert frame.get(results["is_lost_bools"]) == expected_is_lost_bools
342+
343+
344+
def test_terminal_logical_measurement():
345+
346+
@gemini.logical.kernel(no_raise=False, typeinfer=True, aggressive_unroll=True)
347+
def tm_logical_kernel():
348+
q = squin.qalloc(3)
349+
tm = gemini.logical.terminal_measure(q)
350+
return tm
351+
352+
frame, _ = MeasurementIDAnalysis(tm_logical_kernel.dialects).run(tm_logical_kernel)
353+
# will have a MeasureIdTuple that's not from the terminal measurement,
354+
# basically a container of InvalidMeasureIds from the qubits that get allocated
355+
analysis_results = [
356+
val for val in frame.entries.values() if isinstance(val, MeasureIdTuple)
357+
]
358+
expected_result = MeasureIdTuple(
359+
data=tuple([RawMeasureId(idx=i) for i in range(1, 4)])
360+
)
361+
assert expected_result in analysis_results

test/gemini/test_logical_validation.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import pytest
2-
from kirin import ir
32
from kirin.validation import ValidationSuite
43
from kirin.ir.exception import ValidationErrorGroup
54

@@ -9,6 +8,9 @@
98
GeminiLogicalValidation,
109
_GeminiLogicalValidationAnalysis,
1110
)
11+
from bloqade.gemini.analysis.measurement_validation.analysis import (
12+
GeminiTerminalMeasurementValidation,
13+
)
1214

1315

1416
def test_if_stmt_invalid():
@@ -122,21 +124,39 @@ def main():
122124

123125
main.print()
124126

125-
with pytest.raises(ir.ValidationError):
127+
@gemini.logical.kernel(
128+
verify=False, no_raise=False, aggressive_unroll=True, typeinfer=True
129+
)
130+
def not_all_qubits_consumed():
131+
qs = squin.qalloc(3)
132+
sub_qs = qs[0:2]
133+
tm = gemini.logical.terminal_measure(sub_qs)
134+
return tm
126135

127-
@gemini.logical.kernel(no_raise=False)
128-
def invalid():
129-
q = squin.qalloc(3)
130-
squin.x(q[0])
131-
m = gemini.logical.terminal_measure(q)
132-
another_m = gemini.logical.terminal_measure(q)
133-
return m, another_m
136+
validator = ValidationSuite([GeminiTerminalMeasurementValidation])
137+
validation_result = validator.validate(not_all_qubits_consumed)
134138

135-
frame, _ = _GeminiLogicalValidationAnalysis(invalid.dialects).run_no_raise(
136-
invalid
137-
)
139+
with pytest.raises(ValidationErrorGroup):
140+
validation_result.raise_if_invalid()
138141

139-
invalid.print(analysis=frame.entries)
142+
@gemini.logical.kernel(verify=False)
143+
def terminal_measure_kernel(q):
144+
return gemini.logical.terminal_measure(q)
145+
146+
@gemini.logical.kernel(
147+
verify=False, no_raise=False, aggressive_unroll=True, typeinfer=True
148+
)
149+
def terminal_measure_in_kernel():
150+
q = squin.qalloc(10)
151+
sub_qs = q[:2]
152+
m = terminal_measure_kernel(sub_qs)
153+
return m
154+
155+
validator = ValidationSuite([GeminiTerminalMeasurementValidation])
156+
validation_result = validator.validate(terminal_measure_in_kernel)
157+
158+
with pytest.raises(ValidationErrorGroup):
159+
validation_result.raise_if_invalid()
140160

141161

142162
def test_multiple_errors():
@@ -159,6 +179,6 @@ def main(n: int):
159179

160180
except ValidationErrorGroup as e:
161181
did_error = True
162-
assert len(e.errors) == 3
182+
assert len(e.errors) == 4
163183

164184
assert did_error

0 commit comments

Comments
 (0)