Skip to content

Commit 79880d4

Browse files
committed
Merge branch 'main' into john/repeat-support
2 parents a8da761 + ca47d3d commit 79880d4

File tree

11 files changed

+513
-27
lines changed

11 files changed

+513
-27
lines changed

src/bloqade/analysis/measure_id/impls.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from kirin.dialects import py, scf, func, ilist
55
from kirin.ir.attrs.py import PyAttr
66

7-
from bloqade import qubit, annotate
7+
from bloqade import qubit, gemini, annotate
88

99
from .lattice import (
1010
Predicate,
@@ -18,6 +18,8 @@
1818
)
1919
from .analysis import MeasureIDFrame, MeasurementIDAnalysis
2020

21+
# from bloqade.gemini.dialects.logical import stmts as gemini_stmts, dialect as logical_dialect
22+
2123

2224
@qubit.dialect.register(key="measure_id")
2325
class SquinQubit(interp.MethodTable):
@@ -80,6 +82,31 @@ def measurement_predicate(
8082
return (MeasureIdTuple(data=tuple(predicate_measure_ids)),)
8183

8284

85+
@gemini.logical.dialect.register(key="measure_id")
86+
class LogicalQubit(interp.MethodTable):
87+
@interp.impl(gemini.logical.stmts.TerminalLogicalMeasurement)
88+
def terminal_measurement(
89+
self,
90+
interp: MeasurementIDAnalysis,
91+
frame: interp.Frame,
92+
stmt: gemini.logical.stmts.TerminalLogicalMeasurement,
93+
):
94+
# try to get the length of the list
95+
qubits_type = stmt.qubits.type
96+
# vars[0] is just the type of the elements in the ilist,
97+
# vars[1] can contain a literal with length information
98+
num_qubits = qubits_type.vars[1]
99+
if not isinstance(num_qubits, kirin_types.Literal):
100+
return (AnyMeasureId(),)
101+
102+
measure_id_bools = []
103+
for _ in range(num_qubits.data):
104+
interp.measure_count += 1
105+
measure_id_bools.append(RawMeasureId(interp.measure_count))
106+
107+
return (MeasureIdTuple(data=tuple(measure_id_bools)),)
108+
109+
83110
@annotate.dialect.register(key="measure_id")
84111
class Annotate(interp.MethodTable):
85112
@interp.impl(annotate.stmts.SetObservable)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from . import impls as impls, analysis as analysis # NOTE: register methods
2+
from .analysis import (
3+
GeminiTerminalMeasurementValidation as GeminiTerminalMeasurementValidation,
4+
)
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from typing import Any
2+
from dataclasses import field, dataclass
3+
4+
from kirin import ir
5+
from kirin.lattice import EmptyLattice
6+
from kirin.analysis import Forward, ForwardFrame
7+
from kirin.validation import ValidationPass
8+
9+
from bloqade.analysis import address, measure_id
10+
11+
12+
@dataclass
13+
class _GeminiTerminalMeasurementValidationAnalysis(Forward[EmptyLattice]):
14+
keys = ("gemini.validate.terminal_measurement",)
15+
16+
measurement_analysis_results: ForwardFrame
17+
unique_qubits_allocated: int = 0
18+
terminal_measurement_encountered: bool = False
19+
lattice = EmptyLattice
20+
21+
# boilerplate, not really worried about these right now
22+
def eval_fallback(self, frame: ForwardFrame, node: ir.Statement):
23+
return tuple(self.lattice.bottom() for _ in range(len(node.results)))
24+
25+
def method_self(self, method: ir.Method) -> EmptyLattice:
26+
return self.lattice.bottom()
27+
28+
29+
@dataclass
30+
class GeminiTerminalMeasurementValidation(ValidationPass):
31+
32+
analysis_cache: dict = field(default_factory=dict)
33+
34+
def name(self) -> str:
35+
return "Gemini Terminal Measurement Validation"
36+
37+
def get_required_analyses(self) -> list[type]:
38+
return [measure_id.MeasurementIDAnalysis]
39+
40+
def set_analysis_cache(self, cache: dict[type, Any]) -> None:
41+
self.analysis_cache.update(cache)
42+
43+
def run(self, method: ir.Method) -> tuple[Any, list[ir.ValidationError]]:
44+
45+
# get the data out of the cache and forward it to the underlying analysis
46+
measurement_analysis_results = self.analysis_cache.get(
47+
measure_id.MeasurementIDAnalysis
48+
)
49+
50+
address_analysis = address.AddressAnalysis(dialects=method.dialects)
51+
address_analysis.run(method)
52+
unique_qubits_allocated = address_analysis.qubit_count
53+
54+
assert (
55+
measurement_analysis_results is not None
56+
), "Measurement ID analysis results not found in cache"
57+
58+
analysis = _GeminiTerminalMeasurementValidationAnalysis(
59+
method.dialects,
60+
measurement_analysis_results,
61+
unique_qubits_allocated=unique_qubits_allocated,
62+
)
63+
64+
frame, _ = analysis.run(method)
65+
66+
return frame, analysis.get_validation_errors()
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from kirin import ir, interp as _interp
2+
from kirin.analysis import ForwardFrame
3+
from kirin.dialects import func
4+
5+
from bloqade import qubit, gemini
6+
from bloqade.analysis.address.impls import Func as AddressFuncMethodTable
7+
from bloqade.analysis.measure_id.lattice import MeasureIdTuple
8+
9+
from .analysis import _GeminiTerminalMeasurementValidationAnalysis
10+
11+
12+
@qubit.dialect.register(key="gemini.validate.terminal_measurement")
13+
class __QubitGeminiMeasurementValidation(_interp.MethodTable):
14+
15+
# This is a non-logical measurement, can safely flag as invalid
16+
@_interp.impl(qubit.stmts.Measure)
17+
def measure(
18+
self,
19+
interp: _GeminiTerminalMeasurementValidationAnalysis,
20+
frame: ForwardFrame,
21+
stmt: qubit.stmts.Measure,
22+
):
23+
24+
interp.add_validation_error(
25+
stmt,
26+
ir.ValidationError(
27+
stmt,
28+
"Non-terminal measurements are not allowed in Gemini programs!",
29+
),
30+
)
31+
32+
return (interp.lattice.bottom(),)
33+
34+
35+
@gemini.logical.dialect.register(key="gemini.validate.terminal_measurement")
36+
class __GeminiLogicalMeasurementValidation(_interp.MethodTable):
37+
38+
@_interp.impl(gemini.logical.stmts.TerminalLogicalMeasurement)
39+
def terminal_measure(
40+
self,
41+
interp: _GeminiTerminalMeasurementValidationAnalysis,
42+
frame: ForwardFrame,
43+
stmt: gemini.logical.stmts.TerminalLogicalMeasurement,
44+
):
45+
46+
# should only be one terminal measurement EVER
47+
if not interp.terminal_measurement_encountered:
48+
interp.terminal_measurement_encountered = True
49+
else:
50+
interp.add_validation_error(
51+
stmt,
52+
ir.ValidationError(
53+
stmt,
54+
"Multiple terminal measurements are not allowed in Gemini logical programs!",
55+
),
56+
)
57+
return (interp.lattice.bottom(),)
58+
59+
measurement_analysis_results = interp.measurement_analysis_results
60+
total_qubits_allocated = interp.unique_qubits_allocated
61+
62+
measure_lattice_element = measurement_analysis_results.get(stmt.result)
63+
if not isinstance(measure_lattice_element, MeasureIdTuple):
64+
interp.add_validation_error(
65+
stmt,
66+
ir.ValidationError(
67+
stmt,
68+
"Measurement ID Analysis failed to produce the necessary results needed for validation.",
69+
),
70+
)
71+
return (interp.lattice.bottom(),)
72+
73+
if len(measure_lattice_element.data) != total_qubits_allocated:
74+
interp.add_validation_error(
75+
stmt,
76+
ir.ValidationError(
77+
stmt,
78+
"The number of qubits in the terminal measurement does not match the number of total qubits allocated! "
79+
+ f"{total_qubits_allocated} qubits were allocated but only {len(measure_lattice_element.data)} were measured.",
80+
),
81+
)
82+
return (interp.lattice.bottom(),)
83+
84+
return (interp.lattice.bottom(),)
85+
86+
87+
@func.dialect.register(key="gemini.validate.terminal_measurement")
88+
class Func(AddressFuncMethodTable):
89+
pass

src/bloqade/gemini/dialects/logical/_interface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77

88
from .stmts import TerminalLogicalMeasurement
99

10-
Len = TypeVar("Len", bound=int)
11-
CodeN = TypeVar("CodeN", bound=int)
10+
Len = TypeVar("Len")
11+
CodeN = TypeVar("CodeN")
1212

1313

1414
@lowering.wraps(TerminalLogicalMeasurement)

src/bloqade/gemini/dialects/logical/groups.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
from bloqade.squin import gate, qubit
1212
from bloqade.rewrite.passes import AggressiveUnroll
13-
from bloqade.gemini.analysis.logical_validation import GeminiLogicalValidation
1413

1514
from ._dialect import dialect
1615

@@ -63,7 +62,17 @@ def run_pass(
6362
default_pass.fixpoint(mt)
6463

6564
if verify:
66-
validator = ValidationSuite([GeminiLogicalValidation])
65+
# stop circular import problems
66+
from bloqade.gemini.analysis.logical_validation import (
67+
GeminiLogicalValidation,
68+
)
69+
from bloqade.gemini.analysis.measurement_validation import (
70+
GeminiTerminalMeasurementValidation,
71+
)
72+
73+
validator = ValidationSuite(
74+
[GeminiLogicalValidation, GeminiTerminalMeasurementValidation]
75+
)
6776
validation_result = validator.validate(mt)
6877
validation_result.raise_if_invalid()
6978
mt.verify()

src/bloqade/gemini/dialects/logical/stmts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
from ._dialect import dialect
88

9-
Len = types.TypeVar("Len", bound=types.Int)
10-
CodeN = types.TypeVar("CodeN", bound=types.Int)
9+
Len = types.TypeVar("Len")
10+
CodeN = types.TypeVar("CodeN")
1111

1212

1313
@statement(dialect=dialect)
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
from kirin import ir
2+
from kirin.rewrite import abc as rewrite_abc
3+
from kirin.dialects import py
4+
5+
from bloqade.squin.gate import stmts as gate_stmts
6+
7+
8+
class RewriteNonCliffordToU3(rewrite_abc.RewriteRule):
9+
"""Rewrite non-Clifford gates to U3 gates.
10+
11+
This rewrite rule transforms specific non-Clifford single-qubit gates
12+
into equivalent U3 gate representations. The following transformations are applied:
13+
- T gate (with adjoint attribute) to U3 gate with parameters (0, 0, ±π/4)
14+
- Rx gate to U3 gate with parameters (angle, -π/2, π/2)
15+
- Ry gate to U3 gate with parameters (angle, 0, 0)
16+
- Rz gate is U3 gate with parameters (0, 0, angle)
17+
18+
This rewrite should be paired with `U3ToClifford` to canonicalize the circuit.
19+
20+
"""
21+
22+
def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
23+
if not isinstance(
24+
node,
25+
(
26+
gate_stmts.T,
27+
gate_stmts.Rx,
28+
gate_stmts.Ry,
29+
gate_stmts.Rz,
30+
),
31+
):
32+
return rewrite_abc.RewriteResult()
33+
34+
rule = getattr(self, f"rewrite_{type(node).__name__}")
35+
36+
return rule(node)
37+
38+
def rewrite_T(self, node: gate_stmts.T) -> rewrite_abc.RewriteResult:
39+
if node.adjoint:
40+
lam_value = -1.0 / 8.0
41+
else:
42+
lam_value = 1.0 / 8.0
43+
44+
(theta_stmt := py.Constant(0.0)).insert_before(node)
45+
(phi_stmt := py.Constant(0.0)).insert_before(node)
46+
(lam_stmt := py.Constant(lam_value)).insert_before(node)
47+
48+
node.replace_by(
49+
gate_stmts.U3(
50+
qubits=node.qubits,
51+
theta=theta_stmt.result,
52+
phi=phi_stmt.result,
53+
lam=lam_stmt.result,
54+
)
55+
)
56+
57+
return rewrite_abc.RewriteResult(has_done_something=True)
58+
59+
def rewrite_Rx(self, node: gate_stmts.Rx) -> rewrite_abc.RewriteResult:
60+
(phi_stmt := py.Constant(-0.25)).insert_before(node)
61+
(lam_stmt := py.Constant(0.25)).insert_before(node)
62+
63+
node.replace_by(
64+
gate_stmts.U3(
65+
qubits=node.qubits,
66+
theta=node.angle,
67+
phi=phi_stmt.result,
68+
lam=lam_stmt.result,
69+
)
70+
)
71+
72+
return rewrite_abc.RewriteResult(has_done_something=True)
73+
74+
def rewrite_Ry(self, node: gate_stmts.Ry) -> rewrite_abc.RewriteResult:
75+
(phi_stmt := py.Constant(0.0)).insert_before(node)
76+
(lam_stmt := py.Constant(0.0)).insert_before(node)
77+
78+
node.replace_by(
79+
gate_stmts.U3(
80+
qubits=node.qubits,
81+
theta=node.angle,
82+
phi=phi_stmt.result,
83+
lam=lam_stmt.result,
84+
)
85+
)
86+
87+
return rewrite_abc.RewriteResult(has_done_something=True)
88+
89+
def rewrite_Rz(self, node: gate_stmts.Rz) -> rewrite_abc.RewriteResult:
90+
(theta_stmt := py.Constant(0.0)).insert_before(node)
91+
(phi_stmt := py.Constant(0.0)).insert_before(node)
92+
93+
node.replace_by(
94+
gate_stmts.U3(
95+
qubits=node.qubits,
96+
theta=theta_stmt.result,
97+
phi=phi_stmt.result,
98+
lam=node.angle,
99+
)
100+
)
101+
102+
return rewrite_abc.RewriteResult(has_done_something=True)

test/analysis/measure_id/test_measure_id.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from kirin.dialects import scf
22
from kirin.passes.inline import InlinePass
33

4-
from bloqade import squin
4+
from bloqade import squin, gemini
55
from bloqade.analysis.measure_id import MeasurementIDAnalysis
66
from bloqade.stim.passes.flatten import Flatten
77
from bloqade.analysis.measure_id.lattice import (
@@ -339,3 +339,23 @@ def test():
339339
assert frame.get(results["is_zero_bools"]) == expected_is_zero_bools
340340
assert frame.get(results["is_one_bools"]) == expected_is_one_bools
341341
assert frame.get(results["is_lost_bools"]) == expected_is_lost_bools
342+
343+
344+
def test_terminal_logical_measurement():
345+
346+
@gemini.logical.kernel(no_raise=False, typeinfer=True, aggressive_unroll=True)
347+
def tm_logical_kernel():
348+
q = squin.qalloc(3)
349+
tm = gemini.logical.terminal_measure(q)
350+
return tm
351+
352+
frame, _ = MeasurementIDAnalysis(tm_logical_kernel.dialects).run(tm_logical_kernel)
353+
# will have a MeasureIdTuple that's not from the terminal measurement,
354+
# basically a container of InvalidMeasureIds from the qubits that get allocated
355+
analysis_results = [
356+
val for val in frame.entries.values() if isinstance(val, MeasureIdTuple)
357+
]
358+
expected_result = MeasureIdTuple(
359+
data=tuple([RawMeasureId(idx=i) for i in range(1, 4)])
360+
)
361+
assert expected_result in analysis_results

0 commit comments

Comments
 (0)