Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion src/bloqade/analysis/measure_id/impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from kirin.analysis import const
from kirin.dialects import py, scf, func, ilist

from bloqade import qubit, annotate
from bloqade import qubit, gemini, annotate

from .lattice import (
Predicate,
Expand All @@ -15,6 +15,8 @@
)
from .analysis import MeasureIDFrame, MeasurementIDAnalysis

# from bloqade.gemini.dialects.logical import stmts as gemini_stmts, dialect as logical_dialect


@qubit.dialect.register(key="measure_id")
class SquinQubit(interp.MethodTable):
Expand Down Expand Up @@ -74,6 +76,31 @@ def measurement_predicate(
return (MeasureIdTuple(data=tuple(predicate_measure_ids)),)


@gemini.logical.dialect.register(key="measure_id")
class LogicalQubit(interp.MethodTable):
@interp.impl(gemini.logical.stmts.TerminalLogicalMeasurement)
def terminal_measurement(
self,
interp: MeasurementIDAnalysis,
frame: interp.Frame,
stmt: gemini.logical.stmts.TerminalLogicalMeasurement,
):
# try to get the length of the list
qubits_type = stmt.qubits.type
# vars[0] is just the type of the elements in the ilist,
# vars[1] can contain a literal with length information
num_qubits = qubits_type.vars[1]
if not isinstance(num_qubits, kirin_types.Literal):
return (AnyMeasureId(),)

measure_id_bools = []
for _ in range(num_qubits.data):
interp.measure_count += 1
measure_id_bools.append(RawMeasureId(interp.measure_count))

return (MeasureIdTuple(data=tuple(measure_id_bools)),)


@annotate.dialect.register(key="measure_id")
class Annotate(interp.MethodTable):
@interp.impl(annotate.stmts.SetObservable)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from . import impls as impls, analysis as analysis # NOTE: register methods
from .analysis import (
GeminiTerminalMeasurementValidation as GeminiTerminalMeasurementValidation,
)
64 changes: 64 additions & 0 deletions src/bloqade/gemini/analysis/measurement_validation/analysis.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To fix the circular import put this validation inside logical folder since this is only a logical program validation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By extension that would also mean that the current logical_validation should also go in there as well? Just as a heads up I'll be splitting up the rules in logical_validation right after this PR (so the scf, gate, and func impls are separate validations you can compose)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gave it a shot, didn't play nice ): Willing to revisit in a subsequent PR though, that will involve shuffling more things around anyways

Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from typing import Any
from dataclasses import field, dataclass

from kirin import ir
from kirin.lattice import EmptyLattice
from kirin.analysis import Forward, ForwardFrame
from kirin.validation import ValidationPass

from bloqade.analysis import address, measure_id


@dataclass
class _GeminiTerminalMeasurementValidationAnalysis(Forward[EmptyLattice]):
keys = ["gemini.validate.terminal_measurement"]

address_analysis_results: ForwardFrame
measurement_analysis_results: ForwardFrame
num_terminal_measurements: int = 0
lattice = EmptyLattice

# boilerplate, not really worried about these right now
def eval_fallback(self, frame: ForwardFrame, node: ir.Statement):
return tuple(self.lattice.bottom() for _ in range(len(node.results)))

def method_self(self, method: ir.Method) -> EmptyLattice:
return self.lattice.bottom()


@dataclass
class GeminiTerminalMeasurementValidation(ValidationPass):

analysis_cache: dict = field(default_factory=dict)

def name(self) -> str:
return "Gemini Terminal Measurement Validation"

def get_required_analyses(self) -> list[type]:
return [measure_id.MeasurementIDAnalysis, address.AddressAnalysis]

def set_analysis_cache(self, cache: dict[type, Any]) -> None:
self.analysis_cache.update(cache)

def run(self, method: ir.Method) -> tuple[Any, list[ir.ValidationError]]:

# get the data out of the cache and forward it to the underlying analysis
address_analysis_results = self.analysis_cache.get(address.AddressAnalysis)
measurement_analysis_results = self.analysis_cache.get(
measure_id.MeasurementIDAnalysis
)

assert (
address_analysis_results is not None
), "Address analysis results not found in cache"
assert (
measurement_analysis_results is not None
), "Measurement ID analysis results not found in cache"

analysis = _GeminiTerminalMeasurementValidationAnalysis(
method.dialects, address_analysis_results, measurement_analysis_results
)

frame, _ = analysis.run(method)

return frame, analysis.get_validation_errors()
129 changes: 129 additions & 0 deletions src/bloqade/gemini/analysis/measurement_validation/impls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from kirin import ir, interp as _interp
from kirin.analysis import ForwardFrame
from kirin.dialects import func

from bloqade import qubit, gemini
from bloqade.analysis.address.lattice import AddressReg, AddressQubit
from bloqade.analysis.measure_id.lattice import MeasureIdTuple

from .analysis import _GeminiTerminalMeasurementValidationAnalysis


@qubit.dialect.register(key="gemini.validate.terminal_measurement")
class __QubitGeminiMeasurementValidation(_interp.MethodTable):

# This is a non-logical measurement, can safely flag as invalid
@_interp.impl(qubit.stmts.Measure)
def measure(
self,
interp: _GeminiTerminalMeasurementValidationAnalysis,
frame: ForwardFrame,
stmt: qubit.stmts.Measure,
):

interp.add_validation_error(
stmt,
ir.ValidationError(
stmt,
"Non-terminal measurements are not allowed in Gemini programs!",
),
)

return (interp.lattice.bottom(),)


@gemini.logical.dialect.register(key="gemini.validate.terminal_measurement")
class __GeminiLogicalMeasurementValidation(_interp.MethodTable):

# This is a logical terminal measurement, which is allowed
# but we impose the following restrictions:
# 1. All qubits spawned MUST be consumed
@_interp.impl(gemini.logical.stmts.TerminalLogicalMeasurement)
def terminal_measure(
self,
interp: _GeminiTerminalMeasurementValidationAnalysis,
frame: ForwardFrame,
stmt: gemini.logical.stmts.TerminalLogicalMeasurement,
):

# should only be one MeasureIDTuple
interp.num_terminal_measurements += 1

if interp.num_terminal_measurements > 1:
interp.add_validation_error(
stmt,
ir.ValidationError(
stmt,
"Multiple terminal measurements are not allowed in Gemini logical programs!",
),
)
return (interp.lattice.bottom(),)

# If we confirm there isn't a duplicate terminal measurement,
# now we need to check that for all the qubits that were spawned,
# they are all consumed by this measurement
address_analysis_results = interp.address_analysis_results
measurement_analysis_results = interp.measurement_analysis_results

# should just be one MeasureIDTuple
measure_lattice_element = measurement_analysis_results.get_values(stmt.results)[
0
]

# Figure out the total number of qubits spawned, keeping in mind that if a user
# "shuffles" the qubits (puts them in a new container, splits one off from a container type, etc.)
# it should be accounted for. This would be much cleaner if there was a way to propagate the
# final qubit count saved in the actual interpreter for address analysis...
witnessed_qubits = set()
total_qubits_allocated = 0
for address_lattice_elem in address_analysis_results.entries.values():
if isinstance(address_lattice_elem, AddressReg):
for member in address_lattice_elem.data:
if (
isinstance(member, AddressQubit)
and member.data not in witnessed_qubits
):
witnessed_qubits.add(member.data)
total_qubits_allocated += 1

if (
isinstance(address_lattice_elem, AddressQubit)
and address_lattice_elem.data not in witnessed_qubits
):
witnessed_qubits.add(address_lattice_elem.data)
total_qubits_allocated += 1

# could make these proper exceptions but would be tricky to communicate to user
# without revealing under-the-hood details
assert isinstance(measure_lattice_element, MeasureIdTuple)

if len(measure_lattice_element.data) != total_qubits_allocated:
interp.add_validation_error(
stmt,
ir.ValidationError(
stmt,
"The number of qubits in the terminal measurement does not match the number of total qubits allocated! "
+ f"{total_qubits_allocated} qubits were allocated but only {len(measure_lattice_element.data)} were measured.",
),
)
return (interp.lattice.bottom(),)

return (interp.lattice.bottom(),)


@func.dialect.register(key="gemini.validate.terminal_measurement")
class Func(_interp.MethodTable):
@_interp.impl(func.Invoke)
def return_(
self,
interp: _GeminiTerminalMeasurementValidationAnalysis,
frame: ForwardFrame,
stmt: func.Invoke,
):
_, ret = interp.call(
stmt.callee.code,
interp.method_self(stmt.callee),
*frame.get_values(stmt.inputs),
)

return (ret,)
6 changes: 5 additions & 1 deletion src/bloqade/gemini/dialects/logical/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from bloqade.squin import gate, qubit
from bloqade.rewrite.passes import AggressiveUnroll
from bloqade.gemini.analysis.logical_validation import GeminiLogicalValidation

from ._dialect import dialect

Expand Down Expand Up @@ -63,6 +62,11 @@ def run_pass(
default_pass.fixpoint(mt)

if verify:
# stop circular import problems
from bloqade.gemini.analysis.logical_validation import (
GeminiLogicalValidation,
)

validator = ValidationSuite([GeminiLogicalValidation])
validation_result = validator.validate(mt)
validation_result.raise_if_invalid()
Expand Down
22 changes: 21 additions & 1 deletion test/analysis/measure_id/test_measure_id.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from kirin.dialects import scf
from kirin.passes.inline import InlinePass

from bloqade import squin
from bloqade import squin, gemini
from bloqade.analysis.measure_id import MeasurementIDAnalysis
from bloqade.stim.passes.flatten import Flatten
from bloqade.analysis.measure_id.lattice import (
Expand Down Expand Up @@ -339,3 +339,23 @@ def test():
assert frame.get(results["is_zero_bools"]) == expected_is_zero_bools
assert frame.get(results["is_one_bools"]) == expected_is_one_bools
assert frame.get(results["is_lost_bools"]) == expected_is_lost_bools


def test_terminal_logical_measurement():

@gemini.logical.kernel(no_raise=False, typeinfer=True, aggressive_unroll=True)
def tm_logical_kernel():
q = squin.qalloc(3)
tm = gemini.logical.terminal_measure(q)
return tm

frame, _ = MeasurementIDAnalysis(tm_logical_kernel.dialects).run(tm_logical_kernel)
# will have a MeasureIdTuple that's not from the terminal measurement,
# basically a container of InvalidMeasureIds from the qubits that get allocated
analysis_results = [
val for val in frame.entries.values() if isinstance(val, MeasureIdTuple)
]
expected_result = MeasureIdTuple(
data=tuple([RawMeasureId(idx=i) for i in range(1, 4)])
)
assert expected_result in analysis_results
42 changes: 29 additions & 13 deletions test/gemini/test_logical_validation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import pytest
from kirin import ir
from kirin.validation import ValidationSuite
from kirin.ir.exception import ValidationErrorGroup

Expand All @@ -9,6 +8,9 @@
GeminiLogicalValidation,
_GeminiLogicalValidationAnalysis,
)
from bloqade.gemini.analysis.measurement_validation.analysis import (
GeminiTerminalMeasurementValidation,
)


def test_if_stmt_invalid():
Expand Down Expand Up @@ -122,21 +124,35 @@ def main():

main.print()

with pytest.raises(ir.ValidationError):
@gemini.logical.kernel(no_raise=False, aggressive_unroll=True, typeinfer=True)
def not_all_qubits_consumed():
qs = squin.qalloc(3)
sub_qs = qs[0:2]
tm = gemini.logical.terminal_measure(sub_qs)
return tm

@gemini.logical.kernel(no_raise=False)
def invalid():
q = squin.qalloc(3)
squin.x(q[0])
m = gemini.logical.terminal_measure(q)
another_m = gemini.logical.terminal_measure(q)
return m, another_m
validator = ValidationSuite([GeminiTerminalMeasurementValidation])
validation_result = validator.validate(not_all_qubits_consumed)

frame, _ = _GeminiLogicalValidationAnalysis(invalid.dialects).run_no_raise(
invalid
)
with pytest.raises(ValidationErrorGroup):
validation_result.raise_if_invalid()

invalid.print(analysis=frame.entries)
@gemini.logical.kernel
def terminal_measure_kernel(q):
return gemini.logical.terminal_measure(q)

@gemini.logical.kernel(no_raise=False, aggressive_unroll=True, typeinfer=True)
def terminal_measure_in_kernel():
q = squin.qalloc(10)
sub_qs = q[:2]
m = terminal_measure_kernel(sub_qs)
return m

validator = ValidationSuite([GeminiTerminalMeasurementValidation])
validation_result = validator.validate(terminal_measure_in_kernel)

with pytest.raises(ValidationErrorGroup):
validation_result.raise_if_invalid()


def test_multiple_errors():
Expand Down
Loading