Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/bloqade/squin/groups.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from kirin import ir, passes
from kirin.prelude import structural_no_opt
from kirin.dialects import ilist
from kirin.rewrite.walk import Walk

from . import op, wire, qubit
from .rewrite.measure_desugar import MeasureDesugarRule


@ir.dialect_group(structural_no_opt.union([op, qubit]))
def kernel(self):
fold_pass = passes.Fold(self)
typeinfer_pass = passes.TypeInfer(self)
ilist_desugar_pass = ilist.IListDesugar(self)
measure_desugar_pass = Walk(MeasureDesugarRule())

def run_pass(method: ir.Method, *, fold=True, typeinfer=True):
method.verify()
Expand All @@ -18,6 +21,7 @@ def run_pass(method: ir.Method, *, fold=True, typeinfer=True):

if typeinfer:
typeinfer_pass(method)
measure_desugar_pass.rewrite(method.code)
ilist_desugar_pass(method)
if typeinfer:
typeinfer_pass(method) # fix types after desugaring
Expand Down
58 changes: 43 additions & 15 deletions src/bloqade/squin/qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
- `kirin.dialects.ilist`: provides the `ilist.IListType` type for lists of qubits.
"""

from typing import Any
from typing import Any, overload

from kirin import ir, types, lowering
from kirin.decl import info, statement
Expand Down Expand Up @@ -42,7 +42,27 @@ class Broadcast(ir.Statement):


@statement(dialect=dialect)
class Measure(ir.Statement):
class MeasureAny(ir.Statement):
name = "measure"

traits = frozenset({lowering.FromPythonCall()})
input: ir.SSAValue = info.argument(types.Any)
result: ir.ResultValue = info.result(types.Any)


@statement(dialect=dialect)
class MeasureQubit(ir.Statement):
name = "measure.qubit"

traits = frozenset({lowering.FromPythonCall()})
qubit: ir.SSAValue = info.argument(ilist.IListType[QubitType])
result: ir.ResultValue = info.result(ilist.IListType[types.Bool])


@statement(dialect=dialect)
class MeasureQubitList(ir.Statement):
name = "measure.qubit.list"

traits = frozenset({lowering.FromPythonCall()})
qubits: ir.SSAValue = info.argument(ilist.IListType[QubitType])
result: ir.ResultValue = info.result(ilist.IListType[types.Bool])
Expand Down Expand Up @@ -89,31 +109,39 @@ def apply(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> None:
...


@wraps(Broadcast)
def broadcast(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> None:
"""Broadcast and apply an operator to a list of qubits. For example, an operator
that expects 2 qubits can be applied to a list of 2n qubits, where n is an integer > 0.
@overload
def measure(input: Qubit) -> bool: ...
@overload
def measure(input: ilist.IList[Qubit, Any] | list[Qubit]) -> list[bool]: ...


@wraps(MeasureAny)
def measure(input: Any) -> Any:
"""Measure a qubit or qubits in the list.

Args:
operator: The operator to broadcast and apply.
qubits: The list of qubits to broadcast and apply the operator to. The size of the list
must be inferable and match the number of qubits expected by the operator.
input: A qubit or a list of qubits to measure.

Returns:
None
bool | list[bool]: The result of the measurement. If a single qubit is measured,
a single boolean is returned. If a list of qubits is measured, a list of booleans
is returned.
"""
...


@wraps(Measure)
def measure(qubits: ilist.IList[Qubit, Any]) -> int:
"""Measure the qubits in the list."
@wraps(Broadcast)
def broadcast(operator: Op, qubits: ilist.IList[Qubit, Any] | list[Qubit]) -> None:
"""Broadcast and apply an operator to a list of qubits. For example, an operator
that expects 2 qubits can be applied to a list of 2n qubits, where n is an integer > 0.

Args:
qubits: The list of qubits to measure.
operator: The operator to broadcast and apply.
qubits: The list of qubits to broadcast and apply the operator to. The size of the list
must be inferable and match the number of qubits expected by the operator.

Returns:
int: The result of the measurement.
None
"""
...

Expand Down
Empty file.
33 changes: 33 additions & 0 deletions src/bloqade/squin/rewrite/measure_desugar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from kirin import ir, types
from kirin.dialects import ilist
from kirin.rewrite.abc import RewriteRule, RewriteResult

from bloqade.squin.qubit import QubitType, MeasureAny, MeasureQubit, MeasureQubitList


class MeasureDesugarRule(RewriteRule):
"""
Desugar measure operations in the circuit.
"""

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:

if not isinstance(node, MeasureAny):
return RewriteResult()

if node.input.type.is_subseteq(QubitType):
node.replace_by(
MeasureQubit(
qubit=node.input,
)
)
return RewriteResult(has_done_something=True)
elif node.input.type.is_subseteq(ilist.IListType[QubitType, types.Any]):
node.replace_by(
MeasureQubitList(
qubits=node.input,
)
)
return RewriteResult(has_done_something=True)

return RewriteResult()

Check warning on line 33 in src/bloqade/squin/rewrite/measure_desugar.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/rewrite/measure_desugar.py#L33

Added line #L33 was not covered by tests
36 changes: 36 additions & 0 deletions test/squin/test_measure_sugar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from kirin import ir
from kirin.dialects import func

from bloqade import squin


def get_return_value_stmt(kernel: ir.Method):
assert isinstance(
last_stmt := kernel.callable_region.blocks[-1].last_stmt, func.Return
)
return last_stmt.value.owner


def test_measure_register():
@squin.kernel
def test_measure_sugar():
q = squin.qubit.new(2)

return squin.qubit.measure(q)

assert isinstance(
get_return_value_stmt(test_measure_sugar), squin.qubit.MeasureQubitList
)


def test_measure_qubit():
@squin.kernel
def test_measure_sugar():
q = squin.qubit.new(2)

return squin.qubit.measure(q[0])

assert isinstance(
get_return_value_stmt(test_measure_sugar),
squin.qubit.MeasureQubit,
)