Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 3 additions & 13 deletions src/bloqade/analysis/measure_id/impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from kirin.analysis import const
from kirin.dialects import py, scf, func, ilist

from bloqade.squin import qubit
from bloqade import qubit

from .lattice import (
AnyMeasureId,
Expand All @@ -21,22 +21,12 @@
@qubit.dialect.register(key="measure_id")
class SquinQubit(interp.MethodTable):

@interp.impl(qubit.MeasureQubit)
def measure_qubit(
self,
interp: MeasurementIDAnalysis,
frame: interp.Frame,
stmt: qubit.MeasureQubit,
):
interp.measure_count += 1
return (MeasureIdBool(interp.measure_count),)

@interp.impl(qubit.MeasureQubitList)
@interp.impl(qubit.stmts.Measure)
def measure_qubit_list(
self,
interp: MeasurementIDAnalysis,
frame: interp.Frame,
stmt: qubit.MeasureQubitList,
stmt: qubit.stmts.Measure,
):

# try to get the length of the list
Expand Down
15 changes: 15 additions & 0 deletions src/bloqade/cirq_utils/emit/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def main():
mt_ = ir.Method(None, None, sym_name, [], mt.dialects, new_func)

AggressiveUnroll(mt_.dialects).fixpoint(mt_)
mt_.print()
return emitter.run(mt_, args=())


Expand Down Expand Up @@ -230,3 +231,17 @@ def emit_invoke(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Invoke):
"Function invokes should need to be inlined! "
"If you called the emit_circuit method, that should have happened, please report this issue."
)

@impl(func.Return)
def return_(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Return):
# NOTE: should only be hit if ignore_returns == True
return ()


@py.indexing.dialect.register(key="emit.cirq")
class __Concrete(interp.MethodTable):

@interp.impl(py.indexing.GetItem)
def getindex(self, interp, frame: interp.Frame, stmt: py.indexing.GetItem):
# NOTE: no support for indexing into single statements in cirq
return ()
14 changes: 3 additions & 11 deletions src/bloqade/cirq_utils/emit/qubit.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import cirq
from kirin.interp import MethodTable, impl

from bloqade.squin import qubit
from bloqade.qubit import stmts as qubit

from .base import EmitCirq, EmitCirqFrame

Expand All @@ -18,17 +18,9 @@ def new(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.New):
frame.qubit_index += 1
return (cirq_qubit,)

@impl(qubit.MeasureQubit)
def measure_qubit(
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.MeasureQubit
):
qbit = frame.get(stmt.qubit)
frame.circuit.append(cirq.measure(qbit))
return (emit.void,)

@impl(qubit.MeasureQubitList)
@impl(qubit.Measure)
def measure_qubit_list(
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.MeasureQubitList
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Measure
):
qbits = frame.get(stmt.qubits)
frame.circuit.append(cirq.measure(qbits))
Expand Down
12 changes: 4 additions & 8 deletions src/bloqade/cirq_utils/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from kirin.rewrite import Walk, CFGCompactify
from kirin.dialects import py, scf, func, ilist

from bloqade.squin import gate, noise, qubit, kernel, qalloc
from bloqade import qubit
from bloqade.squin import gate, noise, kernel, qalloc


def load_circuit(
Expand Down Expand Up @@ -403,13 +404,8 @@ def bool_op_or(x: bool, y: bool) -> bool:
def visit_MeasurementGate(
self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation
):
cirq_qubits = node.qubits
if len(cirq_qubits) == 1:
qbit = self.lower_qubit_getindex(state, node.qubits[0])
stmt = state.current_frame.push(qubit.MeasureQubit(qbit))
else:
qubits = self.lower_qubit_getindices(state, node.qubits)
stmt = state.current_frame.push(qubit.MeasureQubitList(qubits))
qubits = self.lower_qubit_getindices(state, node.qubits)
stmt = state.current_frame.push(qubit.stmts.Measure(qubits))

# NOTE: add for classically controlled lowering
key = node.gate.key
Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/native/_prelude.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from kirin.prelude import structural_no_opt
from typing_extensions import Doc

from bloqade.squin import qubit
from bloqade import qubit

from .dialects import gates

Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/native/dialects/gates/_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from kirin import lowering
from kirin.dialects import ilist

from bloqade.squin import qubit
from bloqade import qubit

from .stmts import CZ, R, Rz

Expand Down
10 changes: 5 additions & 5 deletions src/bloqade/native/dialects/gates/stmts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from kirin.decl import info, statement
from kirin.dialects import ilist

from bloqade.squin import qubit
from bloqade.types import QubitType

from ._dialect import dialect

Expand All @@ -12,20 +12,20 @@
@statement(dialect=dialect)
class CZ(ir.Statement):
traits = frozenset({lowering.FromPythonCall()})
ctrls: ir.SSAValue = info.argument(ilist.IListType[qubit.QubitType, N])
qargs: ir.SSAValue = info.argument(ilist.IListType[qubit.QubitType, N])
ctrls: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
qargs: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this isn't part of the PR (so feel free to just resolve this) but I'm curious, why is it a qarg here when we usually do something like controls vs. targets?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, that was introduced by @weinbe58 somewhere. I think I even left a similar comment on the PR? Should we change it? (Not in this PR, of course).



@statement(dialect=dialect)
class R(ir.Statement):
traits = frozenset({lowering.FromPythonCall()})
inputs: ir.SSAValue = info.argument(ilist.IListType[qubit.QubitType, types.Any])
inputs: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
axis_angle: ir.SSAValue = info.argument(types.Float)
rotation_angle: ir.SSAValue = info.argument(types.Float)


@statement(dialect=dialect)
class Rz(ir.Statement):
traits = frozenset({lowering.FromPythonCall()})
inputs: ir.SSAValue = info.argument(ilist.IListType[qubit.QubitType, types.Any])
inputs: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
rotation_angle: ir.SSAValue = info.argument(types.Float)
2 changes: 1 addition & 1 deletion src/bloqade/native/stdlib/broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from kirin.dialects import ilist

from bloqade.squin import qubit
from bloqade import qubit
from bloqade.native._prelude import kernel
from bloqade.native.dialects.gates import _interface as native

Expand Down
2 changes: 1 addition & 1 deletion src/bloqade/native/stdlib/simple.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from kirin.dialects import ilist

from bloqade.squin import qubit
from bloqade import qubit

from . import broadcast
from .._prelude import kernel
Expand Down
24 changes: 9 additions & 15 deletions src/bloqade/pyqrack/squin/qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from kirin import interp
from kirin.dialects import ilist

from bloqade.squin import qubit
from bloqade.qubit import stmts as qubit
from bloqade.pyqrack.reg import QubitState, Measurement, PyQrackQubit
from bloqade.pyqrack.base import PyQrackInterpreter

Expand All @@ -27,38 +27,32 @@ def _measure_qubit(self, qbit: PyQrackQubit, interp: PyQrackInterpreter):
interp.set_global_measurement_id(m)
return m

@interp.impl(qubit.MeasureQubitList)
@interp.impl(qubit.Measure)
def measure_qubit_list(
self,
interp: PyQrackInterpreter,
frame: interp.Frame,
stmt: qubit.MeasureQubitList,
stmt: qubit.Measure,
):
qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
result = ilist.IList([self._measure_qubit(qbit, interp) for qbit in qubits])
return (result,)

@interp.impl(qubit.MeasureQubit)
def measure_qubit(
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.MeasureQubit
):
qbit: PyQrackQubit = frame.get(stmt.qubit)
result = self._measure_qubit(qbit, interp)
return (result,)

@interp.impl(qubit.QubitId)
def qubit_id(
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.QubitId
):
qbit: PyQrackQubit = frame.get(stmt.qubit)
return (qbit.addr,)
qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
ids = ilist.IList([qbit.addr for qbit in qubits])
return (ids,)

@interp.impl(qubit.MeasurementId)
def measurement_id(
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.MeasurementId
):
measurement: Measurement = frame.get(stmt.measurement)
return (measurement.measurement_id,)
measurements: ilist.IList[Measurement, Any] = frame.get(stmt.measurements)
ids = ilist.IList([measurement.measurement_id for measurement in measurements])
return (ids,)

@interp.impl(qubit.Reset)
def reset(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.Reset):
Expand Down
12 changes: 12 additions & 0 deletions src/bloqade/qubit/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from bloqade.types import Qubit as Qubit, QubitType as QubitType

from . import stmts as stmts, analysis as analysis
from .stdlib import new as new, qalloc as qalloc, broadcast as broadcast
from ._dialect import dialect as dialect
from ._prelude import kernel as kernel
from .stdlib.simple import (
reset as reset,
measure as measure,
get_qubit_id as get_qubit_id,
get_measurement_id as get_measurement_id,
)
3 changes: 3 additions & 0 deletions src/bloqade/qubit/_dialect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from kirin import ir

dialect = ir.Dialect("qubit")
49 changes: 49 additions & 0 deletions src/bloqade/qubit/_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from typing import Any, TypeVar

from kirin.dialects import ilist
from kirin.lowering import wraps

from bloqade.types import Qubit, MeasurementResult

from .stmts import New, Reset, Measure, QubitId, MeasurementId


@wraps(New)
def new() -> Qubit:
"""Create a new qubit.

Returns:
Qubit: A new qubit.
"""
...


N = TypeVar("N", bound=int)


@wraps(Measure)
def measure(qubits: ilist.IList[Qubit, N]) -> ilist.IList[MeasurementResult, N]:
"""Measure a list of qubits.

Args:
qubits (IList[Qubit, N]): The list of qubits to measure.

Returns:
IList[MeasurementResult, N]: The list containing the results of the measurements.
A MeasurementResult can represent both 0 and 1, but also atoms that are lost.
"""
...


@wraps(QubitId)
def get_qubit_id(qubits: ilist.IList[Qubit, N]) -> ilist.IList[int, N]: ...


@wraps(MeasurementId)
def get_measurement_id(
measurements: ilist.IList[MeasurementResult, N],
) -> ilist.IList[int, N]: ...


@wraps(Reset)
def reset(qubits: ilist.IList[Qubit, Any]) -> None: ...
43 changes: 43 additions & 0 deletions src/bloqade/qubit/_prelude.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from typing import Annotated

from kirin import ir
from kirin.passes import Default
from kirin.prelude import structural_no_opt
from typing_extensions import Doc

from . import _dialect as qubit


@ir.dialect_group(structural_no_opt.union([qubit]))
def kernel(self):
"""Compile to a qubit kernel"""

def run_pass(
mt,
*,
verify: Annotated[
bool, Doc("run `verify` before running passes, default is `True`")
] = True,
typeinfer: Annotated[
bool,
Doc(
"run type inference and apply the inferred type to IR, default `False`"
),
] = False,
fold: Annotated[bool, Doc("run folding passes")] = True,
aggressive: Annotated[
bool, Doc("run aggressive folding passes if `fold=True`")
] = False,
no_raise: Annotated[bool, Doc("do not raise exception during analysis")] = True,
) -> None:
default_pass = Default(
self,
verify=verify,
fold=fold,
aggressive=aggressive,
typeinfer=typeinfer,
no_raise=no_raise,
)
default_pass.fixpoint(mt)

return run_pass
1 change: 1 addition & 0 deletions src/bloqade/qubit/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import address_impl as address_impl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
)
from bloqade.analysis.address.analysis import AddressAnalysis

from .. import qubit
from .. import stmts
from .._dialect import dialect

# Address lattice elements we can work with:
## NotQubit (bottom), AnyAddress (top)
Expand All @@ -24,15 +25,15 @@
### Base qubit address type


@qubit.dialect.register(key="qubit.address")
@dialect.register(key="qubit.address")
class SquinQubitMethodTable(interp.MethodTable):

@interp.impl(qubit.New)
@interp.impl(stmts.New)
def new_qubit(
self,
interp_: AddressAnalysis,
frame: ForwardFrame[Address],
stmt: qubit.New,
stmt: stmts.New,
):
addr = AddressQubit(interp_.next_address)
interp_.next_address += 1
Expand Down
2 changes: 2 additions & 0 deletions src/bloqade/qubit/stdlib/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from . import simple as simple, broadcast as broadcast
from ._new import new as new, qalloc as qalloc
Loading