diff --git a/src/bloqade/analysis/measure_id/impls.py b/src/bloqade/analysis/measure_id/impls.py index 381e66f46..993b97bd3 100644 --- a/src/bloqade/analysis/measure_id/impls.py +++ b/src/bloqade/analysis/measure_id/impls.py @@ -2,7 +2,7 @@ from kirin.analysis import const from kirin.dialects import py, scf, func, ilist -from bloqade.squin import qubit +from bloqade import qubit from .lattice import ( AnyMeasureId, @@ -21,22 +21,12 @@ @qubit.dialect.register(key="measure_id") class SquinQubit(interp.MethodTable): - @interp.impl(qubit.MeasureQubit) - def measure_qubit( - self, - interp: MeasurementIDAnalysis, - frame: interp.Frame, - stmt: qubit.MeasureQubit, - ): - interp.measure_count += 1 - return (MeasureIdBool(interp.measure_count),) - - @interp.impl(qubit.MeasureQubitList) + @interp.impl(qubit.stmts.Measure) def measure_qubit_list( self, interp: MeasurementIDAnalysis, frame: interp.Frame, - stmt: qubit.MeasureQubitList, + stmt: qubit.stmts.Measure, ): # try to get the length of the list diff --git a/src/bloqade/cirq_utils/emit/base.py b/src/bloqade/cirq_utils/emit/base.py index 8f585b877..37dd5095c 100644 --- a/src/bloqade/cirq_utils/emit/base.py +++ b/src/bloqade/cirq_utils/emit/base.py @@ -230,3 +230,17 @@ def emit_invoke(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Invoke): "Function invokes should need to be inlined! " "If you called the emit_circuit method, that should have happened, please report this issue." ) + + @impl(func.Return) + def return_(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Return): + # NOTE: should only be hit if ignore_returns == True + return () + + +@py.indexing.dialect.register(key="emit.cirq") +class __Concrete(interp.MethodTable): + + @interp.impl(py.indexing.GetItem) + def getindex(self, interp, frame: interp.Frame, stmt: py.indexing.GetItem): + # NOTE: no support for indexing into single statements in cirq + return () diff --git a/src/bloqade/cirq_utils/emit/qubit.py b/src/bloqade/cirq_utils/emit/qubit.py index 86ddf2fba..222d17983 100644 --- a/src/bloqade/cirq_utils/emit/qubit.py +++ b/src/bloqade/cirq_utils/emit/qubit.py @@ -1,7 +1,7 @@ import cirq from kirin.interp import MethodTable, impl -from bloqade.squin import qubit +from bloqade.qubit import stmts as qubit from .base import EmitCirq, EmitCirqFrame @@ -18,17 +18,9 @@ def new(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.New): frame.qubit_index += 1 return (cirq_qubit,) - @impl(qubit.MeasureQubit) - def measure_qubit( - self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.MeasureQubit - ): - qbit = frame.get(stmt.qubit) - frame.circuit.append(cirq.measure(qbit)) - return (emit.void,) - - @impl(qubit.MeasureQubitList) + @impl(qubit.Measure) def measure_qubit_list( - self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.MeasureQubitList + self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Measure ): qbits = frame.get(stmt.qubits) frame.circuit.append(cirq.measure(qbits)) diff --git a/src/bloqade/cirq_utils/lowering.py b/src/bloqade/cirq_utils/lowering.py index 53da263f1..8806e8799 100644 --- a/src/bloqade/cirq_utils/lowering.py +++ b/src/bloqade/cirq_utils/lowering.py @@ -6,7 +6,8 @@ from kirin.rewrite import Walk, CFGCompactify from kirin.dialects import py, scf, func, ilist -from bloqade.squin import gate, noise, qubit, kernel, qalloc +from bloqade import qubit +from bloqade.squin import gate, noise, kernel, qalloc def load_circuit( @@ -403,13 +404,8 @@ def bool_op_or(x: bool, y: bool) -> bool: def visit_MeasurementGate( self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation ): - cirq_qubits = node.qubits - if len(cirq_qubits) == 1: - qbit = self.lower_qubit_getindex(state, node.qubits[0]) - stmt = state.current_frame.push(qubit.MeasureQubit(qbit)) - else: - qubits = self.lower_qubit_getindices(state, node.qubits) - stmt = state.current_frame.push(qubit.MeasureQubitList(qubits)) + qubits = self.lower_qubit_getindices(state, node.qubits) + stmt = state.current_frame.push(qubit.stmts.Measure(qubits)) # NOTE: add for classically controlled lowering key = node.gate.key diff --git a/src/bloqade/native/_prelude.py b/src/bloqade/native/_prelude.py index 7fd06a64f..0776de0c1 100644 --- a/src/bloqade/native/_prelude.py +++ b/src/bloqade/native/_prelude.py @@ -5,7 +5,7 @@ from kirin.prelude import structural_no_opt from typing_extensions import Doc -from bloqade.squin import qubit +from bloqade import qubit from .dialects import gates diff --git a/src/bloqade/native/dialects/gates/_interface.py b/src/bloqade/native/dialects/gates/_interface.py index 1e1c88308..6d68a5891 100644 --- a/src/bloqade/native/dialects/gates/_interface.py +++ b/src/bloqade/native/dialects/gates/_interface.py @@ -3,7 +3,7 @@ from kirin import lowering from kirin.dialects import ilist -from bloqade.squin import qubit +from bloqade import qubit from .stmts import CZ, R, Rz diff --git a/src/bloqade/native/dialects/gates/stmts.py b/src/bloqade/native/dialects/gates/stmts.py index 064a3fd16..458af507d 100644 --- a/src/bloqade/native/dialects/gates/stmts.py +++ b/src/bloqade/native/dialects/gates/stmts.py @@ -2,7 +2,7 @@ from kirin.decl import info, statement from kirin.dialects import ilist -from bloqade.squin import qubit +from bloqade.types import QubitType from ._dialect import dialect @@ -12,14 +12,14 @@ @statement(dialect=dialect) class CZ(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) - ctrls: ir.SSAValue = info.argument(ilist.IListType[qubit.QubitType, N]) - qargs: ir.SSAValue = info.argument(ilist.IListType[qubit.QubitType, N]) + ctrls: ir.SSAValue = info.argument(ilist.IListType[QubitType, N]) + qargs: ir.SSAValue = info.argument(ilist.IListType[QubitType, N]) @statement(dialect=dialect) class R(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) - inputs: ir.SSAValue = info.argument(ilist.IListType[qubit.QubitType, types.Any]) + inputs: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any]) axis_angle: ir.SSAValue = info.argument(types.Float) rotation_angle: ir.SSAValue = info.argument(types.Float) @@ -27,5 +27,5 @@ class R(ir.Statement): @statement(dialect=dialect) class Rz(ir.Statement): traits = frozenset({lowering.FromPythonCall()}) - inputs: ir.SSAValue = info.argument(ilist.IListType[qubit.QubitType, types.Any]) + inputs: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any]) rotation_angle: ir.SSAValue = info.argument(types.Float) diff --git a/src/bloqade/native/stdlib/broadcast.py b/src/bloqade/native/stdlib/broadcast.py index 36e0783ab..e5a10ccc6 100644 --- a/src/bloqade/native/stdlib/broadcast.py +++ b/src/bloqade/native/stdlib/broadcast.py @@ -3,7 +3,7 @@ from kirin.dialects import ilist -from bloqade.squin import qubit +from bloqade import qubit from bloqade.native._prelude import kernel from bloqade.native.dialects.gates import _interface as native diff --git a/src/bloqade/native/stdlib/simple.py b/src/bloqade/native/stdlib/simple.py index bc01634b8..3ead67785 100644 --- a/src/bloqade/native/stdlib/simple.py +++ b/src/bloqade/native/stdlib/simple.py @@ -1,6 +1,6 @@ from kirin.dialects import ilist -from bloqade.squin import qubit +from bloqade import qubit from . import broadcast from .._prelude import kernel diff --git a/src/bloqade/pyqrack/squin/qubit.py b/src/bloqade/pyqrack/squin/qubit.py index f04106ed7..a1e6e6818 100644 --- a/src/bloqade/pyqrack/squin/qubit.py +++ b/src/bloqade/pyqrack/squin/qubit.py @@ -3,7 +3,7 @@ from kirin import interp from kirin.dialects import ilist -from bloqade.squin import qubit +from bloqade.qubit import stmts as qubit from bloqade.pyqrack.reg import QubitState, Measurement, PyQrackQubit from bloqade.pyqrack.base import PyQrackInterpreter @@ -27,38 +27,32 @@ def _measure_qubit(self, qbit: PyQrackQubit, interp: PyQrackInterpreter): interp.set_global_measurement_id(m) return m - @interp.impl(qubit.MeasureQubitList) + @interp.impl(qubit.Measure) def measure_qubit_list( self, interp: PyQrackInterpreter, frame: interp.Frame, - stmt: qubit.MeasureQubitList, + stmt: qubit.Measure, ): qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits) result = ilist.IList([self._measure_qubit(qbit, interp) for qbit in qubits]) return (result,) - @interp.impl(qubit.MeasureQubit) - def measure_qubit( - self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.MeasureQubit - ): - qbit: PyQrackQubit = frame.get(stmt.qubit) - result = self._measure_qubit(qbit, interp) - return (result,) - @interp.impl(qubit.QubitId) def qubit_id( self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.QubitId ): - qbit: PyQrackQubit = frame.get(stmt.qubit) - return (qbit.addr,) + qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits) + ids = ilist.IList([qbit.addr for qbit in qubits]) + return (ids,) @interp.impl(qubit.MeasurementId) def measurement_id( self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.MeasurementId ): - measurement: Measurement = frame.get(stmt.measurement) - return (measurement.measurement_id,) + measurements: ilist.IList[Measurement, Any] = frame.get(stmt.measurements) + ids = ilist.IList([measurement.measurement_id for measurement in measurements]) + return (ids,) @interp.impl(qubit.Reset) def reset(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.Reset): diff --git a/src/bloqade/qubit/__init__.py b/src/bloqade/qubit/__init__.py new file mode 100644 index 000000000..e8881b4ac --- /dev/null +++ b/src/bloqade/qubit/__init__.py @@ -0,0 +1,12 @@ +from bloqade.types import Qubit as Qubit, QubitType as QubitType + +from . import stmts as stmts, analysis as analysis +from .stdlib import new as new, qalloc as qalloc, broadcast as broadcast +from ._dialect import dialect as dialect +from ._prelude import kernel as kernel +from .stdlib.simple import ( + reset as reset, + measure as measure, + get_qubit_id as get_qubit_id, + get_measurement_id as get_measurement_id, +) diff --git a/src/bloqade/qubit/_dialect.py b/src/bloqade/qubit/_dialect.py new file mode 100644 index 000000000..5c431df2a --- /dev/null +++ b/src/bloqade/qubit/_dialect.py @@ -0,0 +1,3 @@ +from kirin import ir + +dialect = ir.Dialect("qubit") diff --git a/src/bloqade/qubit/_interface.py b/src/bloqade/qubit/_interface.py new file mode 100644 index 000000000..86910cdd2 --- /dev/null +++ b/src/bloqade/qubit/_interface.py @@ -0,0 +1,49 @@ +from typing import Any, TypeVar + +from kirin.dialects import ilist +from kirin.lowering import wraps + +from bloqade.types import Qubit, MeasurementResult + +from .stmts import New, Reset, Measure, QubitId, MeasurementId + + +@wraps(New) +def new() -> Qubit: + """Create a new qubit. + + Returns: + Qubit: A new qubit. + """ + ... + + +N = TypeVar("N", bound=int) + + +@wraps(Measure) +def measure(qubits: ilist.IList[Qubit, N]) -> ilist.IList[MeasurementResult, N]: + """Measure a list of qubits. + + Args: + qubits (IList[Qubit, N]): The list of qubits to measure. + + Returns: + IList[MeasurementResult, N]: The list containing the results of the measurements. + A MeasurementResult can represent both 0 and 1, but also atoms that are lost. + """ + ... + + +@wraps(QubitId) +def get_qubit_id(qubits: ilist.IList[Qubit, N]) -> ilist.IList[int, N]: ... + + +@wraps(MeasurementId) +def get_measurement_id( + measurements: ilist.IList[MeasurementResult, N], +) -> ilist.IList[int, N]: ... + + +@wraps(Reset) +def reset(qubits: ilist.IList[Qubit, Any]) -> None: ... diff --git a/src/bloqade/qubit/_prelude.py b/src/bloqade/qubit/_prelude.py new file mode 100644 index 000000000..5c402c6a6 --- /dev/null +++ b/src/bloqade/qubit/_prelude.py @@ -0,0 +1,45 @@ +from typing import Annotated + +from kirin import ir +from kirin.passes import Default +from kirin.prelude import structural_no_opt +from typing_extensions import Doc + +from . import _dialect as qubit + + +@ir.dialect_group(structural_no_opt.union([qubit])) +def kernel(self): + """Compile to a qubit kernel""" + + def run_pass( + mt, + *, + verify: Annotated[ + bool, Doc("run `verify` before running passes, default is `True`") + ] = True, + typeinfer: Annotated[ + bool, + Doc( + "run type inference and apply the inferred type to IR, default `False`" + ), + ] = False, + fold: Annotated[bool, Doc("run folding passes, default is `True`")] = True, + aggressive: Annotated[ + bool, Doc("run aggressive folding passes if `fold=True`") + ] = False, + no_raise: Annotated[ + bool, Doc("do not raise exception during analysis, default is `True`") + ] = True, + ) -> None: + default_pass = Default( + self, + verify=verify, + fold=fold, + aggressive=aggressive, + typeinfer=typeinfer, + no_raise=no_raise, + ) + default_pass.fixpoint(mt) + + return run_pass diff --git a/src/bloqade/qubit/analysis/__init__.py b/src/bloqade/qubit/analysis/__init__.py new file mode 100644 index 000000000..dea64d4e8 --- /dev/null +++ b/src/bloqade/qubit/analysis/__init__.py @@ -0,0 +1 @@ +from . import address_impl as address_impl diff --git a/src/bloqade/squin/analysis/address_impl.py b/src/bloqade/qubit/analysis/address_impl.py similarity index 86% rename from src/bloqade/squin/analysis/address_impl.py rename to src/bloqade/qubit/analysis/address_impl.py index cecc00cc9..d095a6e93 100644 --- a/src/bloqade/squin/analysis/address_impl.py +++ b/src/bloqade/qubit/analysis/address_impl.py @@ -7,7 +7,8 @@ ) from bloqade.analysis.address.analysis import AddressAnalysis -from .. import qubit +from .. import stmts +from .._dialect import dialect # Address lattice elements we can work with: ## NotQubit (bottom), AnyAddress (top) @@ -24,15 +25,15 @@ ### Base qubit address type -@qubit.dialect.register(key="qubit.address") +@dialect.register(key="qubit.address") class SquinQubitMethodTable(interp.MethodTable): - @interp.impl(qubit.New) + @interp.impl(stmts.New) def new_qubit( self, interp_: AddressAnalysis, frame: ForwardFrame[Address], - stmt: qubit.New, + stmt: stmts.New, ): addr = AddressQubit(interp_.next_address) interp_.next_address += 1 diff --git a/src/bloqade/qubit/stdlib/__init__.py b/src/bloqade/qubit/stdlib/__init__.py new file mode 100644 index 000000000..7902eea43 --- /dev/null +++ b/src/bloqade/qubit/stdlib/__init__.py @@ -0,0 +1,2 @@ +from . import simple as simple, broadcast as broadcast +from ._new import new as new, qalloc as qalloc diff --git a/src/bloqade/squin/stdlib/qubit.py b/src/bloqade/qubit/stdlib/_new.py similarity index 57% rename from src/bloqade/squin/stdlib/qubit.py rename to src/bloqade/qubit/stdlib/_new.py index b46f0e770..66da365b8 100644 --- a/src/bloqade/squin/stdlib/qubit.py +++ b/src/bloqade/qubit/stdlib/_new.py @@ -2,9 +2,21 @@ from kirin.dialects import ilist -from .. import qubit, kernel +from .. import _interface as qubit +from .._prelude import kernel +@kernel(typeinfer=True) +def new() -> qubit.Qubit: + """Allocate a single new qubit + + Returns: + (Qubit): The newly allocated qubit. + """ + return qubit.new() + + +# NOTE: this is a special case, that doesn't use the usual simple / broadcast semantics. @kernel(typeinfer=True) def qalloc(n_qubits: int) -> ilist.IList[qubit.Qubit, Any]: """Allocate a new list of qubits. diff --git a/src/bloqade/qubit/stdlib/broadcast.py b/src/bloqade/qubit/stdlib/broadcast.py new file mode 100644 index 000000000..09e9adff2 --- /dev/null +++ b/src/bloqade/qubit/stdlib/broadcast.py @@ -0,0 +1,62 @@ +from typing import Any, TypeVar + +from kirin.dialects import ilist + +from bloqade.types import Qubit, MeasurementResult + +from .. import _interface as _qubit +from .._prelude import kernel + +N = TypeVar("N", bound=int) + + +@kernel +def reset(qubits: ilist.IList[Qubit, Any]) -> None: + """ + Reset a list of qubits to the zero state. + + Args: + qubits (IList[Qubit, Any]): The list of qubits to reset. + """ + _qubit.reset(qubits) + + +@kernel +def measure(qubits: ilist.IList[Qubit, N]) -> ilist.IList[MeasurementResult, N]: + """Measure a list of qubits. + + Args: + qubits (IList[Qubit, N]): The list of qubits to measure. + + Returns: + IList[MeasurementResult, N]: The list containing the results of the measurements. + A MeasurementResult can represent both 0 and 1 as well as atom loss. + """ + return _qubit.measure(qubits) + + +@kernel +def get_qubit_id(qubits: ilist.IList[Qubit, N]) -> ilist.IList[int, N]: + """Get the global, unique ID of each qubit in the list. + + Args: + qubits (IList[Qubit, N]): The list of qubits of which you want the ID. + + Returns: + qubit_ids (IList[int, N]): The list of global, unique IDs of the qubits. + """ + return _qubit.get_qubit_id(qubits) + + +@kernel +def get_measurement_id( + measurements: ilist.IList[MeasurementResult, N], +) -> ilist.IList[int, N]: + """Get the global, unique ID of each of the measurement results in the list. + + Args: + measurements (IList[MeasurementResult, N]): The previously taken measurement of which you want to know the ID. + Returns: + measurement_ids (IList[int, N]): The list of global, unique IDs of the measurements. + """ + return _qubit.get_measurement_id(measurements) diff --git a/src/bloqade/qubit/stdlib/simple.py b/src/bloqade/qubit/stdlib/simple.py new file mode 100644 index 000000000..29953611b --- /dev/null +++ b/src/bloqade/qubit/stdlib/simple.py @@ -0,0 +1,59 @@ +from kirin.dialects import ilist + +from bloqade.types import Qubit, MeasurementResult + +from . import broadcast +from .._prelude import kernel + + +@kernel +def reset(qubit: Qubit) -> None: + """ + Reset a qubit to the zero state. + + Args: + qubit (Qubit): The list qubit to reset. + """ + return broadcast.reset(ilist.IList([qubit])) + + +@kernel +def measure(qubit: Qubit) -> MeasurementResult: + """Measure a qubit. + + Args: + qubit (Qubit): The qubit to measure. + + Returns: + MeasurementResult: The result of the measurement. + A MeasurementResult can represent both 0 and 1, but also atoms that are lost. + """ + measurement_results = broadcast.measure(ilist.IList([qubit])) + return measurement_results[0] + + +@kernel +def get_qubit_id(qubit: Qubit) -> int: + """Get the global, unique ID of the qubit. + + Args: + qubit (Qubit): The qubit of which you want the ID. + + Returns: + qubit_id (int): The global, unique ID of the qubit. + """ + ids = broadcast.get_qubit_id(ilist.IList([qubit])) + return ids[0] + + +@kernel +def get_measurement_id(measurement: MeasurementResult) -> int: + """Get the global, unique ID of the measurement result. + + Args: + measurement (MeasurementResult): The previously taken measurement of which you want to know the ID. + Returns: + measurement_id (int): The global, unique ID of the measurement. + """ + ids = broadcast.get_measurement_id(ilist.IList([measurement])) + return ids[0] diff --git a/src/bloqade/qubit/stmts.py b/src/bloqade/qubit/stmts.py new file mode 100644 index 000000000..bd99756c5 --- /dev/null +++ b/src/bloqade/qubit/stmts.py @@ -0,0 +1,60 @@ +from kirin import ir, types, interp, lowering +from kirin.decl import info, statement +from kirin.dialects import ilist + +from bloqade.types import QubitType, MeasurementResultType + +from ._dialect import dialect + + +@statement(dialect=dialect) +class New(ir.Statement): + traits = frozenset({lowering.FromPythonCall()}) + result: ir.ResultValue = info.result(QubitType) + + +Len = types.TypeVar("Len", bound=types.Int) + + +@statement(dialect=dialect) +class Measure(ir.Statement): + traits = frozenset({lowering.FromPythonCall()}) + qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, Len]) + result: ir.ResultValue = info.result(ilist.IListType[MeasurementResultType, Len]) + + +@statement(dialect=dialect) +class QubitId(ir.Statement): + traits = frozenset({lowering.FromPythonCall(), ir.Pure()}) + qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, Len]) + result: ir.ResultValue = info.result(ilist.IListType[types.Int, Len]) + + +@statement(dialect=dialect) +class MeasurementId(ir.Statement): + traits = frozenset({lowering.FromPythonCall(), ir.Pure()}) + measurements: ir.SSAValue = info.argument( + ilist.IListType[MeasurementResultType, Len] + ) + result: ir.ResultValue = info.result(ilist.IListType[types.Int, Len]) + + +@statement(dialect=dialect) +class Reset(ir.Statement): + traits = frozenset({lowering.FromPythonCall()}) + qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any]) + + +# TODO: investigate why this is needed to get type inference to be correct. +@dialect.register(key="typeinfer") +class __TypeInfer(interp.MethodTable): + @interp.impl(Measure) + def measure_list(self, _interp, frame: interp.AbstractFrame, stmt: Measure): + qubit_type = frame.get(stmt.qubits) + + if isinstance(qubit_type, types.Generic): + len_type = qubit_type.vars[1] + else: + len_type = types.Any + + return (ilist.IListType[MeasurementResultType, len_type],) diff --git a/src/bloqade/squin/__init__.py b/src/bloqade/squin/__init__.py index 949a9b233..3b4e8c881 100644 --- a/src/bloqade/squin/__init__.py +++ b/src/bloqade/squin/__init__.py @@ -1,11 +1,17 @@ from . import ( gate as gate, noise as noise, - qubit as qubit, analysis as analysis, ) +from .. import qubit as qubit +from ..qubit import ( + reset as reset, + qalloc as qalloc, + measure as measure, + get_qubit_id as get_qubit_id, + get_measurement_id as get_measurement_id, +) from .groups import kernel as kernel -from .stdlib.qubit import qalloc as qalloc from .stdlib.simple import ( h as h, s as s, diff --git a/src/bloqade/squin/analysis/__init__.py b/src/bloqade/squin/analysis/__init__.py index dea64d4e8..e69de29bb 100644 --- a/src/bloqade/squin/analysis/__init__.py +++ b/src/bloqade/squin/analysis/__init__.py @@ -1 +0,0 @@ -from . import address_impl as address_impl diff --git a/src/bloqade/squin/groups.py b/src/bloqade/squin/groups.py index d90e19a31..c23f78fe1 100644 --- a/src/bloqade/squin/groups.py +++ b/src/bloqade/squin/groups.py @@ -1,10 +1,9 @@ from kirin import ir, passes from kirin.prelude import structural_no_opt -from kirin.rewrite import Walk, Chain from kirin.dialects import debug, ilist -from . import gate, noise, qubit -from .rewrite.desugar import MeasureDesugarRule +from . import gate, noise +from .. import qubit @ir.dialect_group(structural_no_opt.union([qubit, noise, gate, debug])) @@ -12,7 +11,6 @@ def kernel(self): fold_pass = passes.Fold(self) typeinfer_pass = passes.TypeInfer(self) ilist_desugar_pass = ilist.IListDesugar(self) - desugar_pass = Walk(Chain(MeasureDesugarRule())) def run_pass(method: ir.Method, *, fold=True, typeinfer=True): method.verify() @@ -21,7 +19,6 @@ def run_pass(method: ir.Method, *, fold=True, typeinfer=True): if typeinfer: typeinfer_pass(method) # infer types before desugaring - desugar_pass.rewrite(method.code) ilist_desugar_pass(method) diff --git a/src/bloqade/squin/qubit.py b/src/bloqade/squin/qubit.py deleted file mode 100644 index a46d9f828..000000000 --- a/src/bloqade/squin/qubit.py +++ /dev/null @@ -1,135 +0,0 @@ -"""qubit dialect for squin language. - -This dialect defines the operations that can be performed on qubits. - -Depends on: -- `bloqade.squin.op`: provides the `OpType` type and semantics for operators applied to qubits. -- `kirin.dialects.ilist`: provides the `ilist.IListType` type for lists of qubits. -""" - -from typing import Any, overload - -from kirin import ir, types, interp, lowering -from kirin.decl import info, statement -from kirin.dialects import ilist -from kirin.lowering import wraps - -from bloqade.types import Qubit, QubitType, MeasurementResult, MeasurementResultType - -dialect = ir.Dialect("squin.qubit") - - -@statement(dialect=dialect) -class New(ir.Statement): - traits = frozenset({lowering.FromPythonCall()}) - result: ir.ResultValue = info.result(QubitType) - - -@statement(dialect=dialect) -class MeasureAny(ir.Statement): - name = "measure" - - traits = frozenset({lowering.FromPythonCall()}) - input: ir.SSAValue = info.argument(types.Any) - result: ir.ResultValue = info.result(types.Any) - - -@statement(dialect=dialect) -class MeasureQubit(ir.Statement): - name = "measure.qubit" - - traits = frozenset({lowering.FromPythonCall()}) - qubit: ir.SSAValue = info.argument(QubitType) - result: ir.ResultValue = info.result(MeasurementResultType) - - -Len = types.TypeVar("Len") - - -@statement(dialect=dialect) -class MeasureQubitList(ir.Statement): - name = "measure.qubit.list" - - traits = frozenset({lowering.FromPythonCall()}) - qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, Len]) - result: ir.ResultValue = info.result(ilist.IListType[MeasurementResultType, Len]) - - -@statement(dialect=dialect) -class QubitId(ir.Statement): - traits = frozenset({lowering.FromPythonCall(), ir.Pure()}) - qubit: ir.SSAValue = info.argument(QubitType) - result: ir.ResultValue = info.result(types.Int) - - -@statement(dialect=dialect) -class MeasurementId(ir.Statement): - traits = frozenset({lowering.FromPythonCall(), ir.Pure()}) - measurement: ir.SSAValue = info.argument(MeasurementResultType) - result: ir.ResultValue = info.result(types.Int) - - -@statement(dialect=dialect) -class Reset(ir.Statement): - traits = frozenset({lowering.FromPythonCall()}) - qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any]) - - -# NOTE: no dependent types in Python, so we have to mark it Any... -@wraps(New) -def new() -> Qubit: - """Create a new qubit. - - Returns: - Qubit: A new qubit. - """ - ... - - -@overload -def measure(input: Qubit) -> MeasurementResult: ... -@overload -def measure( - input: ilist.IList[Qubit, Any] | list[Qubit], -) -> ilist.IList[MeasurementResult, Any]: ... - - -@wraps(MeasureAny) -def measure(input: Any) -> Any: - """Measure a qubit or qubits in the list. - - Args: - input: A qubit or a list of qubits to measure. - - Returns: - MeasurementResult | list[MeasurementResult]: The result of the measurement. If a single qubit is measured, - a single result is returned. If a list of qubits is measured, a list of results - is returned. - A MeasurementResult can represent both 0 and 1, but also atoms that are lost. - """ - ... - - -@wraps(QubitId) -def get_qubit_id(qubit: Qubit) -> int: ... - - -@wraps(MeasurementId) -def get_measurement_id(measurement: MeasurementResult) -> int: ... - - -# TODO: investigate why this is needed to get type inference to be correct. -@dialect.register(key="typeinfer") -class __TypeInfer(interp.MethodTable): - @interp.impl(MeasureQubitList) - def measure_list( - self, _interp, frame: interp.AbstractFrame, stmt: MeasureQubitList - ): - qubit_type = frame.get(stmt.qubits) - - if isinstance(qubit_type, types.Generic): - len_type = qubit_type.vars[1] - else: - len_type = types.Any - - return (ilist.IListType[MeasurementResultType, len_type],) diff --git a/src/bloqade/squin/rewrite/desugar.py b/src/bloqade/squin/rewrite/desugar.py deleted file mode 100644 index a5375960f..000000000 --- a/src/bloqade/squin/rewrite/desugar.py +++ /dev/null @@ -1,38 +0,0 @@ -from kirin import ir, types -from kirin.dialects import ilist -from kirin.rewrite.abc import RewriteRule, RewriteResult - -from bloqade.squin.qubit import ( - QubitType, - MeasureAny, - MeasureQubit, - MeasureQubitList, -) - - -class MeasureDesugarRule(RewriteRule): - """ - Desugar measure operations in the circuit. - """ - - def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: - - if not isinstance(node, MeasureAny): - return RewriteResult() - - if node.input.type.is_subseteq(QubitType): - node.replace_by( - MeasureQubit( - qubit=node.input, - ) - ) - return RewriteResult(has_done_something=True) - elif node.input.type.is_subseteq(ilist.IListType[QubitType, types.Any]): - node.replace_by( - MeasureQubitList( - qubits=node.input, - ) - ) - return RewriteResult(has_done_something=True) - - return RewriteResult() diff --git a/src/bloqade/squin/rewrite/remove_dangling_qubits.py b/src/bloqade/squin/rewrite/remove_dangling_qubits.py index d63bdace4..eebd89f3b 100644 --- a/src/bloqade/squin/rewrite/remove_dangling_qubits.py +++ b/src/bloqade/squin/rewrite/remove_dangling_qubits.py @@ -1,14 +1,14 @@ from kirin import ir from kirin.rewrite.abc import RewriteRule, RewriteResult -from bloqade.squin import qubit +from bloqade import qubit class RemoveDeadRegister(RewriteRule): def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: - if not isinstance(node, qubit.New): + if not isinstance(node, qubit.stmts.New): return RewriteResult() if bool(node.result.uses): diff --git a/src/bloqade/squin/rewrite/wrap_analysis.py b/src/bloqade/squin/rewrite/wrap_analysis.py index 99b38e6d9..e1a354402 100644 --- a/src/bloqade/squin/rewrite/wrap_analysis.py +++ b/src/bloqade/squin/rewrite/wrap_analysis.py @@ -5,7 +5,7 @@ from kirin.rewrite.abc import RewriteRule, RewriteResult from kirin.print.printer import Printer -from bloqade.squin import qubit +from bloqade import qubit from bloqade.analysis.address import Address diff --git a/src/bloqade/squin/stdlib/broadcast/__init__.py b/src/bloqade/squin/stdlib/broadcast/__init__.py index c8b864ed8..ae1015ea1 100644 --- a/src/bloqade/squin/stdlib/broadcast/__init__.py +++ b/src/bloqade/squin/stdlib/broadcast/__init__.py @@ -31,3 +31,4 @@ two_qubit_pauli_channel as two_qubit_pauli_channel, single_qubit_pauli_channel as single_qubit_pauli_channel, ) +from ._qubit import reset as reset, measure as measure diff --git a/src/bloqade/squin/stdlib/broadcast/_qubit.py b/src/bloqade/squin/stdlib/broadcast/_qubit.py new file mode 100644 index 000000000..f6ef7baec --- /dev/null +++ b/src/bloqade/squin/stdlib/broadcast/_qubit.py @@ -0,0 +1,4 @@ +from bloqade.qubit.stdlib.broadcast import ( + reset as reset, + measure as measure, +) diff --git a/src/bloqade/stim/passes/squin_to_stim.py b/src/bloqade/stim/passes/squin_to_stim.py index 925454b38..bf73b2c73 100644 --- a/src/bloqade/stim/passes/squin_to_stim.py +++ b/src/bloqade/stim/passes/squin_to_stim.py @@ -26,7 +26,6 @@ from bloqade.analysis.address import AddressAnalysis from bloqade.analysis.measure_id import MeasurementIDAnalysis from bloqade.stim.passes.flatten import Flatten -from bloqade.squin.rewrite.desugar import MeasureDesugarRule from ..rewrite.ifs_to_stim import IfToStim @@ -40,9 +39,6 @@ def unsafe_run(self, mt: Method) -> RewriteResult: rewrite_result = Flatten(dialects=mt.dialects, no_raise=self.no_raise).fixpoint( mt ) - rewrite_result = ( - Walk(Chain(MeasureDesugarRule())).rewrite(mt.code).join(rewrite_result) - ) # after this the program should be in a state where it is analyzable # ------------------------------------------------------------------- diff --git a/src/bloqade/stim/rewrite/qubit_to_stim.py b/src/bloqade/stim/rewrite/qubit_to_stim.py index e3dee7c31..e14cfeedd 100644 --- a/src/bloqade/stim/rewrite/qubit_to_stim.py +++ b/src/bloqade/stim/rewrite/qubit_to_stim.py @@ -1,7 +1,8 @@ from kirin import ir from kirin.rewrite.abc import RewriteRule, RewriteResult -from bloqade.squin import gate, qubit +from bloqade import qubit +from bloqade.squin import gate from bloqade.squin.rewrite import AddressAttribute from bloqade.stim.dialects import gate as stim_gate, collapse as stim_collapse from bloqade.stim.rewrite.util import ( @@ -21,7 +22,7 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: case gate.stmts.T() | gate.stmts.RotationGate(): return RewriteResult() # If you've reached this point all gates have stim equivalents - case qubit.Reset(): + case qubit.stmts.Reset(): return self.rewrite_Reset(node) case gate.stmts.SingleQubitGate(): return self.rewrite_SingleQubitGate(node) @@ -30,7 +31,7 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: case _: return RewriteResult() - def rewrite_Reset(self, stmt: qubit.Reset) -> RewriteResult: + def rewrite_Reset(self, stmt: qubit.stmts.Reset) -> RewriteResult: qubit_addr_attr = stmt.qubits.hints.get("address", None) diff --git a/src/bloqade/stim/rewrite/squin_measure.py b/src/bloqade/stim/rewrite/squin_measure.py index 8926d5acc..25f4d759e 100644 --- a/src/bloqade/stim/rewrite/squin_measure.py +++ b/src/bloqade/stim/rewrite/squin_measure.py @@ -5,7 +5,7 @@ from kirin.dialects import py from kirin.rewrite.abc import RewriteRule, RewriteResult -from bloqade.squin import qubit +from bloqade import qubit from bloqade.squin.rewrite import AddressAttribute from bloqade.stim.dialects import collapse from bloqade.stim.rewrite.util import ( @@ -22,14 +22,12 @@ class SquinMeasureToStim(RewriteRule): def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: match node: - case qubit.MeasureQubit() | qubit.MeasureQubitList(): + case qubit.stmts.Measure(): return self.rewrite_Measure(node) case _: return RewriteResult() - def rewrite_Measure( - self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList - ) -> RewriteResult: + def rewrite_Measure(self, measure_stmt: qubit.stmts.Measure) -> RewriteResult: qubit_idx_ssas = self.get_qubit_idx_ssas(measure_stmt) if qubit_idx_ssas is None: @@ -52,19 +50,12 @@ def rewrite_Measure( return RewriteResult(has_done_something=True) def get_qubit_idx_ssas( - self, measure_stmt: qubit.MeasureQubit | qubit.MeasureQubitList + self, measure_stmt: qubit.stmts.Measure ) -> tuple[ir.SSAValue, ...] | None: """ Extract the address attribute and insert qubit indices for the given measure statement. """ - match measure_stmt: - case qubit.MeasureQubit(): - address_attr = measure_stmt.qubit.hints.get("address") - case qubit.MeasureQubitList(): - address_attr = measure_stmt.qubits.hints.get("address") - case _: - return None - + address_attr = measure_stmt.qubits.hints.get("address") if address_attr is None: return None diff --git a/test/analysis/measure_id/test_measure_id.py b/test/analysis/measure_id/test_measure_id.py index e59a98a4b..b3041fa6a 100644 --- a/test/analysis/measure_id/test_measure_id.py +++ b/test/analysis/measure_id/test_measure_id.py @@ -1,5 +1,5 @@ import pytest -from kirin.passes import HintConst +from kirin.passes import HintConst, inline from kirin.dialects import scf from bloqade import squin @@ -24,8 +24,8 @@ def test(): ql2 = squin.qalloc(5) squin.broadcast.x(ql1) squin.broadcast.x(ql2) - ml1 = squin.qubit.measure(ql1) - ml2 = squin.qubit.measure(ql2) + ml1 = squin.broadcast.measure(ql1) + ml2 = squin.broadcast.measure(ql2) return ml1 + ml2 frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test) @@ -47,7 +47,7 @@ def test_measure_alias(): @squin.kernel def test(): ql = squin.qalloc(5) - ml = squin.qubit.measure(ql) + ml = squin.broadcast.measure(ql) ml_alias = ml return ml_alias @@ -80,7 +80,7 @@ def test_measure_count_at_if_else(): def test(): q = squin.qalloc(5) squin.x(q[2]) - ms = squin.qubit.measure(q) + ms = squin.broadcast.measure(q) if ms[1]: squin.x(q[0]) @@ -106,9 +106,9 @@ def test(): ms = None cond = True if cond: - ms = squin.qubit.measure(q) + ms = squin.broadcast.measure(q) else: - ms = squin.qubit.measure(q[0]) + ms = squin.measure(q[0]) return ms @@ -136,12 +136,14 @@ def test(): ms = None cond = False if cond: - ms = squin.qubit.measure(q) + ms = squin.broadcast.measure(q) else: ms = squin.qubit.measure(q[0]) return ms + inline.InlinePass(test.dialects).fixpoint(test) + HintConst(dialects=test.dialects).unsafe_run(test) frame, _ = MeasurementIDAnalysis(test.dialects).run_analysis(test) @@ -161,7 +163,7 @@ def test(): q = squin.qalloc(6) squin.x(q[2]) - ms = squin.qubit.measure(q) + ms = squin.broadcast.measure(q) msi = ms[1:] # MeasureIdTuple becomes a python tuple msi2 = msi[1:] # slicing should still work on previous tuple ms_final = msi2[::2] @@ -187,7 +189,7 @@ def test_getitem_no_hint(): @squin.kernel def test(idx): q = squin.qalloc(6) - ms = squin.qubit.measure(q) + ms = squin.broadcast.measure(q) return ms[idx] @@ -202,7 +204,7 @@ def test_getitem_invalid_hint(): @squin.kernel def test(): q = squin.qalloc(6) - ms = squin.qubit.measure(q) + ms = squin.broadcast.measure(q) return ms["x"] @@ -218,7 +220,7 @@ def test_getitem_propagate_invalid_measure(): @squin.kernel def test(): q = squin.qalloc(6) - ms = squin.qubit.measure(q) + ms = squin.broadcast.measure(q) # this will return an InvalidMeasureId invalid_ms = ms["x"] return invalid_ms[0] diff --git a/test/cirq_utils/test_clifford_to_cirq.py b/test/cirq_utils/test_clifford_to_cirq.py index 3d62dbd80..3cb8ef3b0 100644 --- a/test/cirq_utils/test_clifford_to_cirq.py +++ b/test/cirq_utils/test_clifford_to_cirq.py @@ -168,7 +168,7 @@ def test_measurement(): def main(): q = squin.qalloc(2) squin.broadcast.y(q) - squin.qubit.measure(q) + squin.broadcast.measure(q) circuit = emit_circuit(main) @@ -318,22 +318,12 @@ def main(): def test_reset(): - # TODO: remove this wrapper once we have a proper one - from typing import Any - - from kirin.lowering import wraps - - from bloqade.types import Qubit - - @wraps(squin.qubit.Reset) - def reset(qubits: ilist.IList[Qubit, Any]) -> None: ... - @squin.kernel def main(): q = squin.qalloc(4) squin.broadcast.x(q) - reset(q) - return squin.qubit.measure(q) + squin.broadcast.reset(q) + return squin.broadcast.measure(q) main.print() diff --git a/test/cirq_utils/test_squin_noise_to_cirq.py b/test/cirq_utils/test_squin_noise_to_cirq.py index cab6b211c..eb0b275c7 100644 --- a/test/cirq_utils/test_squin_noise_to_cirq.py +++ b/test/cirq_utils/test_squin_noise_to_cirq.py @@ -33,7 +33,7 @@ def main(): q[0], q[1], ) - squin.qubit.measure(q) + squin.broadcast.measure(q) main.print() diff --git a/test/pyqrack/squin/test_kernel.py b/test/pyqrack/squin/test_kernel.py index 9ec95b1c1..7e1e80af9 100644 --- a/test/pyqrack/squin/test_kernel.py +++ b/test/pyqrack/squin/test_kernel.py @@ -36,7 +36,7 @@ def new(): @squin.kernel def m(): q = squin.qalloc(3) - m = squin.qubit.measure(q) + m = squin.broadcast.measure(q) return m target = PyQrack(3) @@ -179,7 +179,7 @@ def broadcast_adjoint(): # rotate back down squin.broadcast.u3(-math.pi / 2.0, 0, 0, q) - return squin.qubit.measure(q) + return squin.broadcast.measure(q) target = PyQrack(3) result = target.run(broadcast_adjoint) diff --git a/test/squin/noise/test_stdlib_noise.py b/test/squin/noise/test_stdlib_noise.py index 271668005..a43455a61 100644 --- a/test/squin/noise/test_stdlib_noise.py +++ b/test/squin/noise/test_stdlib_noise.py @@ -85,7 +85,7 @@ def main(): q = squin.qalloc(1) squin.bit_flip(1.0, q[0]) squin.single_qubit_pauli_channel(0.0, 1.0, 0.0, q[0]) - return squin.qubit.measure(q) + return squin.broadcast.measure(q) sim = StackMemorySimulator(min_qubits=1) result = sim.run(main) diff --git a/test/squin/test_qubit.py b/test/squin/test_qubit.py index 82f5723fb..bc16f4cc3 100644 --- a/test/squin/test_qubit.py +++ b/test/squin/test_qubit.py @@ -10,7 +10,7 @@ def test_get_ids(): def main(): q = squin.qalloc(3) - m = squin.qubit.measure(q) + m = squin.broadcast.measure(q) qid = squin.qubit.get_qubit_id(q[0]) mid = squin.qubit.get_measurement_id(m[1]) @@ -39,7 +39,7 @@ def main2(): if m2_id != 1: squin.x(q[4]) - return squin.qubit.measure(q) + return squin.broadcast.measure(q) sim = StackMemorySimulator(min_qubits=2) result = sim.run(main2) diff --git a/test/squin/test_sugar.py b/test/squin/test_sugar.py deleted file mode 100644 index bdc0c2083..000000000 --- a/test/squin/test_sugar.py +++ /dev/null @@ -1,36 +0,0 @@ -from kirin import ir -from kirin.dialects import func - -from bloqade import squin - - -def get_return_value_stmt(kernel: ir.Method): - assert isinstance( - last_stmt := kernel.callable_region.blocks[-1].last_stmt, func.Return - ) - return last_stmt.value.owner - - -def test_measure_register(): - @squin.kernel - def test_measure_sugar(): - q = squin.qalloc(2) - - return squin.qubit.measure(q) - - assert isinstance( - get_return_value_stmt(test_measure_sugar), squin.qubit.MeasureQubitList - ) - - -def test_measure_qubit(): - @squin.kernel - def test_measure_sugar(): - q = squin.qalloc(2) - - return squin.qubit.measure(q[0]) - - assert isinstance( - get_return_value_stmt(test_measure_sugar), - squin.qubit.MeasureQubit, - ) diff --git a/test/stim/passes/test_squin_meas_to_stim.py b/test/stim/passes/test_squin_meas_to_stim.py index 71986b338..d52555e18 100644 --- a/test/stim/passes/test_squin_meas_to_stim.py +++ b/test/stim/passes/test_squin_meas_to_stim.py @@ -33,7 +33,7 @@ def main(): n_qubits = 4 q = sq.qalloc(n_qubits) - ms = sq.qubit.measure(q) + ms = sq.broadcast.measure(q) if ms[0]: sq.z(q[0]) @@ -44,7 +44,7 @@ def main(): sq.x(q[0]) sq.y(q[1]) - sq.qubit.measure(q) + sq.broadcast.measure(q) SquinToStimPass(main.dialects)(main) @@ -59,7 +59,7 @@ def test_alias_with_measure_list(): def main(): q = sq.qalloc(4) - ms = sq.qubit.measure(q) + ms = sq.broadcast.measure(q) new_ms = ms if new_ms[0]: @@ -79,19 +79,19 @@ def main(): n_qubits = 4 q = sq.qalloc(n_qubits) - ms0 = sq.qubit.measure(q) + ms0 = sq.broadcast.measure(q) if ms0[0]: # should be rec[-4] sq.z(q[0]) # another measurement - ms1 = sq.qubit.measure(q) + ms1 = sq.broadcast.measure(q) if ms1[0]: # should be rec[-4] sq.x(q[0]) # second round of measurement - ms2 = sq.qubit.measure(q) # noqa: F841 + ms2 = sq.broadcast.measure(q) # noqa: F841 # try accessing measurements from the very first round ## There are now 12 total measurements, ms0[0] @@ -113,18 +113,18 @@ def main(): n_qubits = 4 q = sq.qalloc(n_qubits) - ms0 = sq.qubit.measure(q) + ms0 = sq.broadcast.measure(q) if ms0[0]: sq.z(q[0]) - ms1 = sq.qubit.measure(q) + ms1 = sq.broadcast.measure(q) if ms1[0]: sq.x(q[1]) # another measurement - ms2 = sq.qubit.measure(q) + ms2 = sq.broadcast.measure(q) if ms2[0]: sq.y(q[2]) diff --git a/test/stim/passes/test_squin_qubit_to_stim.py b/test/stim/passes/test_squin_qubit_to_stim.py index 14f74a5b8..eafd4c3eb 100644 --- a/test/stim/passes/test_squin_qubit_to_stim.py +++ b/test/stim/passes/test_squin_qubit_to_stim.py @@ -5,8 +5,8 @@ from kirin import ir from kirin.dialects import py -from bloqade import squin as sq -from bloqade.squin import qubit, kernel +from bloqade import qubit, squin as sq +from bloqade.squin import kernel from bloqade.stim.emit import EmitStimMain from bloqade.stim.passes import SquinToStimPass from bloqade.rewrite.passes.aggressive_unroll import AggressiveUnroll @@ -45,7 +45,7 @@ def test(): sq.broadcast.h(ql) sq.x(ql[0]) sq.cx(ql[0], ql[1]) - sq.qubit.measure(ql) + sq.broadcast.measure(ql) return SquinToStimPass(test.dialects)(test) @@ -59,7 +59,7 @@ def test(): n_qubits = 1 q = sq.qalloc(n_qubits) # reset the qubit - qubit.Reset(q) + qubit.broadcast.reset(q) # measure out sq.qubit.measure(q[0]) return @@ -78,7 +78,7 @@ def test(): # apply Hadamard to all qubits sq.broadcast.h(ql) # measure out - sq.qubit.measure(ql) + sq.broadcast.measure(ql) return SquinToStimPass(test.dialects)(test) @@ -98,7 +98,7 @@ def test(): sq.qubit_loss(p=0.1, qubit=ql[3]) sq.broadcast.qubit_loss(p=0.05, qubits=ql) # measure out - sq.qubit.measure(ql) + sq.broadcast.measure(ql) return SquinToStimPass(test.dialects)(test) @@ -116,7 +116,7 @@ def test(): # apply U3 rotation that can be translated to a Clifford gate sq.u3(0.25 * math.tau, 0.0 * math.tau, 0.5 * math.tau, qubit=q[0]) # measure out - sq.qubit.measure(q) + sq.broadcast.measure(q) return SquinToStimPass(test.dialects)(test) @@ -240,7 +240,7 @@ def main(): def test_pick_if_else(): - @sq.kernel + @sq.kernel(fold=False) def main(): q = sq.qalloc(10) if False: @@ -260,7 +260,7 @@ def test_non_pure_loop_iterator(): @kernel def test_squin_kernel(): q = sq.qalloc(5) - result = qubit.measure(q) + result = qubit.broadcast.measure(q) outputs = [] for rnd in range(len(result)): # Non-pure loop iterator outputs += [] @@ -292,7 +292,7 @@ def rep_code(): ancilla = q[1::2] # reset everything initially - qubit.Reset(q) + qubit.broadcast.reset(q) ## Initial round, entangle data qubits with ancillas. ## This entanglement will happen again so it's best we @@ -309,10 +309,10 @@ def rep_code(): entangle(cx_pairs) - qubit.measure(ancilla) + qubit.broadcast.measure(ancilla) entangle(cx_pairs) - qubit.measure(ancilla) + qubit.broadcast.measure(ancilla) # Let's make this one a bit noisy entangle(cx_pairs) @@ -321,7 +321,7 @@ def rep_code(): ) sq.broadcast.qubit_loss(p=0.001, qubits=q) - qubit.measure(ancilla) + qubit.broadcast.measure(ancilla) SquinToStimPass(rep_code.dialects)(rep_code) base_stim_prog = load_reference_program("rep_code.stim") diff --git a/test/stim/test_measure_id_analysis.py b/test/stim/test_measure_id_analysis.py index 8281356b8..3f7a79da9 100644 --- a/test/stim/test_measure_id_analysis.py +++ b/test/stim/test_measure_id_analysis.py @@ -1,4 +1,5 @@ -from bloqade.squin import qubit, kernel, qalloc +from bloqade import qubit +from bloqade.squin import kernel, qalloc from bloqade.analysis.measure_id import MeasurementIDAnalysis