diff --git a/src/bloqade/analysis/measure_id/impls.py b/src/bloqade/analysis/measure_id/impls.py index 80f63076..36f73bac 100644 --- a/src/bloqade/analysis/measure_id/impls.py +++ b/src/bloqade/analysis/measure_id/impls.py @@ -2,7 +2,7 @@ from kirin.analysis import const from kirin.dialects import py, scf, func, ilist -from bloqade import qubit, annotate +from bloqade import qubit, gemini, annotate from .lattice import ( Predicate, @@ -15,6 +15,8 @@ ) from .analysis import MeasureIDFrame, MeasurementIDAnalysis +# from bloqade.gemini.dialects.logical import stmts as gemini_stmts, dialect as logical_dialect + @qubit.dialect.register(key="measure_id") class SquinQubit(interp.MethodTable): @@ -74,6 +76,31 @@ def measurement_predicate( return (MeasureIdTuple(data=tuple(predicate_measure_ids)),) +@gemini.logical.dialect.register(key="measure_id") +class LogicalQubit(interp.MethodTable): + @interp.impl(gemini.logical.stmts.TerminalLogicalMeasurement) + def terminal_measurement( + self, + interp: MeasurementIDAnalysis, + frame: interp.Frame, + stmt: gemini.logical.stmts.TerminalLogicalMeasurement, + ): + # try to get the length of the list + qubits_type = stmt.qubits.type + # vars[0] is just the type of the elements in the ilist, + # vars[1] can contain a literal with length information + num_qubits = qubits_type.vars[1] + if not isinstance(num_qubits, kirin_types.Literal): + return (AnyMeasureId(),) + + measure_id_bools = [] + for _ in range(num_qubits.data): + interp.measure_count += 1 + measure_id_bools.append(RawMeasureId(interp.measure_count)) + + return (MeasureIdTuple(data=tuple(measure_id_bools)),) + + @annotate.dialect.register(key="measure_id") class Annotate(interp.MethodTable): @interp.impl(annotate.stmts.SetObservable) diff --git a/src/bloqade/gemini/analysis/measurement_validation/__init__.py b/src/bloqade/gemini/analysis/measurement_validation/__init__.py new file mode 100644 index 00000000..3b403330 --- /dev/null +++ b/src/bloqade/gemini/analysis/measurement_validation/__init__.py @@ -0,0 +1,4 @@ +from . import impls as impls, analysis as analysis # NOTE: register methods +from .analysis import ( + GeminiTerminalMeasurementValidation as GeminiTerminalMeasurementValidation, +) diff --git a/src/bloqade/gemini/analysis/measurement_validation/analysis.py b/src/bloqade/gemini/analysis/measurement_validation/analysis.py new file mode 100644 index 00000000..a7c35f4b --- /dev/null +++ b/src/bloqade/gemini/analysis/measurement_validation/analysis.py @@ -0,0 +1,66 @@ +from typing import Any +from dataclasses import field, dataclass + +from kirin import ir +from kirin.lattice import EmptyLattice +from kirin.analysis import Forward, ForwardFrame +from kirin.validation import ValidationPass + +from bloqade.analysis import address, measure_id + + +@dataclass +class _GeminiTerminalMeasurementValidationAnalysis(Forward[EmptyLattice]): + keys = ("gemini.validate.terminal_measurement",) + + measurement_analysis_results: ForwardFrame + unique_qubits_allocated: int = 0 + terminal_measurement_encountered: bool = False + lattice = EmptyLattice + + # boilerplate, not really worried about these right now + def eval_fallback(self, frame: ForwardFrame, node: ir.Statement): + return tuple(self.lattice.bottom() for _ in range(len(node.results))) + + def method_self(self, method: ir.Method) -> EmptyLattice: + return self.lattice.bottom() + + +@dataclass +class GeminiTerminalMeasurementValidation(ValidationPass): + + analysis_cache: dict = field(default_factory=dict) + + def name(self) -> str: + return "Gemini Terminal Measurement Validation" + + def get_required_analyses(self) -> list[type]: + return [measure_id.MeasurementIDAnalysis] + + def set_analysis_cache(self, cache: dict[type, Any]) -> None: + self.analysis_cache.update(cache) + + def run(self, method: ir.Method) -> tuple[Any, list[ir.ValidationError]]: + + # get the data out of the cache and forward it to the underlying analysis + measurement_analysis_results = self.analysis_cache.get( + measure_id.MeasurementIDAnalysis + ) + + address_analysis = address.AddressAnalysis(dialects=method.dialects) + address_analysis.run(method) + unique_qubits_allocated = address_analysis.qubit_count + + assert ( + measurement_analysis_results is not None + ), "Measurement ID analysis results not found in cache" + + analysis = _GeminiTerminalMeasurementValidationAnalysis( + method.dialects, + measurement_analysis_results, + unique_qubits_allocated=unique_qubits_allocated, + ) + + frame, _ = analysis.run(method) + + return frame, analysis.get_validation_errors() diff --git a/src/bloqade/gemini/analysis/measurement_validation/impls.py b/src/bloqade/gemini/analysis/measurement_validation/impls.py new file mode 100644 index 00000000..6a713ce4 --- /dev/null +++ b/src/bloqade/gemini/analysis/measurement_validation/impls.py @@ -0,0 +1,89 @@ +from kirin import ir, interp as _interp +from kirin.analysis import ForwardFrame +from kirin.dialects import func + +from bloqade import qubit, gemini +from bloqade.analysis.address.impls import Func as AddressFuncMethodTable +from bloqade.analysis.measure_id.lattice import MeasureIdTuple + +from .analysis import _GeminiTerminalMeasurementValidationAnalysis + + +@qubit.dialect.register(key="gemini.validate.terminal_measurement") +class __QubitGeminiMeasurementValidation(_interp.MethodTable): + + # This is a non-logical measurement, can safely flag as invalid + @_interp.impl(qubit.stmts.Measure) + def measure( + self, + interp: _GeminiTerminalMeasurementValidationAnalysis, + frame: ForwardFrame, + stmt: qubit.stmts.Measure, + ): + + interp.add_validation_error( + stmt, + ir.ValidationError( + stmt, + "Non-terminal measurements are not allowed in Gemini programs!", + ), + ) + + return (interp.lattice.bottom(),) + + +@gemini.logical.dialect.register(key="gemini.validate.terminal_measurement") +class __GeminiLogicalMeasurementValidation(_interp.MethodTable): + + @_interp.impl(gemini.logical.stmts.TerminalLogicalMeasurement) + def terminal_measure( + self, + interp: _GeminiTerminalMeasurementValidationAnalysis, + frame: ForwardFrame, + stmt: gemini.logical.stmts.TerminalLogicalMeasurement, + ): + + # should only be one terminal measurement EVER + if not interp.terminal_measurement_encountered: + interp.terminal_measurement_encountered = True + else: + interp.add_validation_error( + stmt, + ir.ValidationError( + stmt, + "Multiple terminal measurements are not allowed in Gemini logical programs!", + ), + ) + return (interp.lattice.bottom(),) + + measurement_analysis_results = interp.measurement_analysis_results + total_qubits_allocated = interp.unique_qubits_allocated + + measure_lattice_element = measurement_analysis_results.get(stmt.result) + if not isinstance(measure_lattice_element, MeasureIdTuple): + interp.add_validation_error( + stmt, + ir.ValidationError( + stmt, + "Measurement ID Analysis failed to produce the necessary results needed for validation.", + ), + ) + return (interp.lattice.bottom(),) + + if len(measure_lattice_element.data) != total_qubits_allocated: + interp.add_validation_error( + stmt, + ir.ValidationError( + stmt, + "The number of qubits in the terminal measurement does not match the number of total qubits allocated! " + + f"{total_qubits_allocated} qubits were allocated but only {len(measure_lattice_element.data)} were measured.", + ), + ) + return (interp.lattice.bottom(),) + + return (interp.lattice.bottom(),) + + +@func.dialect.register(key="gemini.validate.terminal_measurement") +class Func(AddressFuncMethodTable): + pass diff --git a/src/bloqade/gemini/dialects/logical/groups.py b/src/bloqade/gemini/dialects/logical/groups.py index b9258aae..c4c7b9d2 100644 --- a/src/bloqade/gemini/dialects/logical/groups.py +++ b/src/bloqade/gemini/dialects/logical/groups.py @@ -10,7 +10,6 @@ from bloqade.squin import gate, qubit from bloqade.rewrite.passes import AggressiveUnroll -from bloqade.gemini.analysis.logical_validation import GeminiLogicalValidation from ._dialect import dialect @@ -63,7 +62,17 @@ def run_pass( default_pass.fixpoint(mt) if verify: - validator = ValidationSuite([GeminiLogicalValidation]) + # stop circular import problems + from bloqade.gemini.analysis.logical_validation import ( + GeminiLogicalValidation, + ) + from bloqade.gemini.analysis.measurement_validation import ( + GeminiTerminalMeasurementValidation, + ) + + validator = ValidationSuite( + [GeminiLogicalValidation, GeminiTerminalMeasurementValidation] + ) validation_result = validator.validate(mt) validation_result.raise_if_invalid() mt.verify() diff --git a/test/analysis/measure_id/test_measure_id.py b/test/analysis/measure_id/test_measure_id.py index 19bc5f06..9745e1d7 100644 --- a/test/analysis/measure_id/test_measure_id.py +++ b/test/analysis/measure_id/test_measure_id.py @@ -1,7 +1,7 @@ from kirin.dialects import scf from kirin.passes.inline import InlinePass -from bloqade import squin +from bloqade import squin, gemini from bloqade.analysis.measure_id import MeasurementIDAnalysis from bloqade.stim.passes.flatten import Flatten from bloqade.analysis.measure_id.lattice import ( @@ -339,3 +339,23 @@ def test(): assert frame.get(results["is_zero_bools"]) == expected_is_zero_bools assert frame.get(results["is_one_bools"]) == expected_is_one_bools assert frame.get(results["is_lost_bools"]) == expected_is_lost_bools + + +def test_terminal_logical_measurement(): + + @gemini.logical.kernel(no_raise=False, typeinfer=True, aggressive_unroll=True) + def tm_logical_kernel(): + q = squin.qalloc(3) + tm = gemini.logical.terminal_measure(q) + return tm + + frame, _ = MeasurementIDAnalysis(tm_logical_kernel.dialects).run(tm_logical_kernel) + # will have a MeasureIdTuple that's not from the terminal measurement, + # basically a container of InvalidMeasureIds from the qubits that get allocated + analysis_results = [ + val for val in frame.entries.values() if isinstance(val, MeasureIdTuple) + ] + expected_result = MeasureIdTuple( + data=tuple([RawMeasureId(idx=i) for i in range(1, 4)]) + ) + assert expected_result in analysis_results diff --git a/test/gemini/test_logical_validation.py b/test/gemini/test_logical_validation.py index b814b04a..736d49fa 100644 --- a/test/gemini/test_logical_validation.py +++ b/test/gemini/test_logical_validation.py @@ -1,5 +1,4 @@ import pytest -from kirin import ir from kirin.validation import ValidationSuite from kirin.ir.exception import ValidationErrorGroup @@ -9,6 +8,9 @@ GeminiLogicalValidation, _GeminiLogicalValidationAnalysis, ) +from bloqade.gemini.analysis.measurement_validation.analysis import ( + GeminiTerminalMeasurementValidation, +) def test_if_stmt_invalid(): @@ -122,21 +124,39 @@ def main(): main.print() - with pytest.raises(ir.ValidationError): + @gemini.logical.kernel( + verify=False, no_raise=False, aggressive_unroll=True, typeinfer=True + ) + def not_all_qubits_consumed(): + qs = squin.qalloc(3) + sub_qs = qs[0:2] + tm = gemini.logical.terminal_measure(sub_qs) + return tm - @gemini.logical.kernel(no_raise=False) - def invalid(): - q = squin.qalloc(3) - squin.x(q[0]) - m = gemini.logical.terminal_measure(q) - another_m = gemini.logical.terminal_measure(q) - return m, another_m + validator = ValidationSuite([GeminiTerminalMeasurementValidation]) + validation_result = validator.validate(not_all_qubits_consumed) - frame, _ = _GeminiLogicalValidationAnalysis(invalid.dialects).run_no_raise( - invalid - ) + with pytest.raises(ValidationErrorGroup): + validation_result.raise_if_invalid() - invalid.print(analysis=frame.entries) + @gemini.logical.kernel(verify=False) + def terminal_measure_kernel(q): + return gemini.logical.terminal_measure(q) + + @gemini.logical.kernel( + verify=False, no_raise=False, aggressive_unroll=True, typeinfer=True + ) + def terminal_measure_in_kernel(): + q = squin.qalloc(10) + sub_qs = q[:2] + m = terminal_measure_kernel(sub_qs) + return m + + validator = ValidationSuite([GeminiTerminalMeasurementValidation]) + validation_result = validator.validate(terminal_measure_in_kernel) + + with pytest.raises(ValidationErrorGroup): + validation_result.raise_if_invalid() def test_multiple_errors(): @@ -159,6 +179,6 @@ def main(n: int): except ValidationErrorGroup as e: did_error = True - assert len(e.errors) == 3 + assert len(e.errors) == 4 assert did_error