Skip to content

Commit dadcd37

Browse files
david-pljohnzl-777
andauthored
Restructure and move qubit dialect (#557)
As discussed in #549, this restructures the qubit dialect. I also moved it to the top level of bloqade, since it's used in multiple places. It's still re-exported by squin, though. I've got it down to a single failing test, where the `AggressiveUnroll` fails because the constant propagation doesn't hint the loop iterable as constant for some reason. This breaks the syntax such that you have to change the function calls using lists such as `squin.qubit.measure(q: IList[Qubit])` to `squin.broadcast.measure(q: IList[Qubit])`. **Edit**: turns out I was just too eager to remove the type inference method @weinbe58 added in #508. That resolves the failing test. **Edit2**: I've also changed `QubitId` and `MeasurementId` to be "broadcasted" (i.e., they now operate on lists of qubits / measurements and return lists). This means that now **all statements except `qubit.new` are broadcasted**, which is as consistent as we can have it, I think. --------- Co-authored-by: John Long <[email protected]>
1 parent da6def7 commit dadcd37

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+427
-356
lines changed

src/bloqade/analysis/measure_id/impls.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from kirin.analysis import const
33
from kirin.dialects import py, scf, func, ilist
44

5-
from bloqade.squin import qubit
5+
from bloqade import qubit
66

77
from .lattice import (
88
AnyMeasureId,
@@ -21,22 +21,12 @@
2121
@qubit.dialect.register(key="measure_id")
2222
class SquinQubit(interp.MethodTable):
2323

24-
@interp.impl(qubit.MeasureQubit)
25-
def measure_qubit(
26-
self,
27-
interp: MeasurementIDAnalysis,
28-
frame: interp.Frame,
29-
stmt: qubit.MeasureQubit,
30-
):
31-
interp.measure_count += 1
32-
return (MeasureIdBool(interp.measure_count),)
33-
34-
@interp.impl(qubit.MeasureQubitList)
24+
@interp.impl(qubit.stmts.Measure)
3525
def measure_qubit_list(
3626
self,
3727
interp: MeasurementIDAnalysis,
3828
frame: interp.Frame,
39-
stmt: qubit.MeasureQubitList,
29+
stmt: qubit.stmts.Measure,
4030
):
4131

4232
# try to get the length of the list

src/bloqade/cirq_utils/emit/base.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,3 +230,17 @@ def emit_invoke(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Invoke):
230230
"Function invokes should need to be inlined! "
231231
"If you called the emit_circuit method, that should have happened, please report this issue."
232232
)
233+
234+
@impl(func.Return)
235+
def return_(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Return):
236+
# NOTE: should only be hit if ignore_returns == True
237+
return ()
238+
239+
240+
@py.indexing.dialect.register(key="emit.cirq")
241+
class __Concrete(interp.MethodTable):
242+
243+
@interp.impl(py.indexing.GetItem)
244+
def getindex(self, interp, frame: interp.Frame, stmt: py.indexing.GetItem):
245+
# NOTE: no support for indexing into single statements in cirq
246+
return ()

src/bloqade/cirq_utils/emit/qubit.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import cirq
22
from kirin.interp import MethodTable, impl
33

4-
from bloqade.squin import qubit
4+
from bloqade.qubit import stmts as qubit
55

66
from .base import EmitCirq, EmitCirqFrame
77

@@ -18,17 +18,9 @@ def new(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.New):
1818
frame.qubit_index += 1
1919
return (cirq_qubit,)
2020

21-
@impl(qubit.MeasureQubit)
22-
def measure_qubit(
23-
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.MeasureQubit
24-
):
25-
qbit = frame.get(stmt.qubit)
26-
frame.circuit.append(cirq.measure(qbit))
27-
return (emit.void,)
28-
29-
@impl(qubit.MeasureQubitList)
21+
@impl(qubit.Measure)
3022
def measure_qubit_list(
31-
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.MeasureQubitList
23+
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Measure
3224
):
3325
qbits = frame.get(stmt.qubits)
3426
frame.circuit.append(cirq.measure(qbits))

src/bloqade/cirq_utils/lowering.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from kirin.rewrite import Walk, CFGCompactify
77
from kirin.dialects import py, scf, func, ilist
88

9-
from bloqade.squin import gate, noise, qubit, kernel, qalloc
9+
from bloqade import qubit
10+
from bloqade.squin import gate, noise, kernel, qalloc
1011

1112

1213
def load_circuit(
@@ -403,13 +404,8 @@ def bool_op_or(x: bool, y: bool) -> bool:
403404
def visit_MeasurementGate(
404405
self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation
405406
):
406-
cirq_qubits = node.qubits
407-
if len(cirq_qubits) == 1:
408-
qbit = self.lower_qubit_getindex(state, node.qubits[0])
409-
stmt = state.current_frame.push(qubit.MeasureQubit(qbit))
410-
else:
411-
qubits = self.lower_qubit_getindices(state, node.qubits)
412-
stmt = state.current_frame.push(qubit.MeasureQubitList(qubits))
407+
qubits = self.lower_qubit_getindices(state, node.qubits)
408+
stmt = state.current_frame.push(qubit.stmts.Measure(qubits))
413409

414410
# NOTE: add for classically controlled lowering
415411
key = node.gate.key

src/bloqade/native/_prelude.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from kirin.prelude import structural_no_opt
66
from typing_extensions import Doc
77

8-
from bloqade.squin import qubit
8+
from bloqade import qubit
99

1010
from .dialects import gates
1111

src/bloqade/native/dialects/gates/_interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from kirin import lowering
44
from kirin.dialects import ilist
55

6-
from bloqade.squin import qubit
6+
from bloqade import qubit
77

88
from .stmts import CZ, R, Rz
99

src/bloqade/native/dialects/gates/stmts.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from kirin.decl import info, statement
33
from kirin.dialects import ilist
44

5-
from bloqade.squin import qubit
5+
from bloqade.types import QubitType
66

77
from ._dialect import dialect
88

@@ -12,20 +12,20 @@
1212
@statement(dialect=dialect)
1313
class CZ(ir.Statement):
1414
traits = frozenset({lowering.FromPythonCall()})
15-
ctrls: ir.SSAValue = info.argument(ilist.IListType[qubit.QubitType, N])
16-
qargs: ir.SSAValue = info.argument(ilist.IListType[qubit.QubitType, N])
15+
ctrls: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
16+
qargs: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
1717

1818

1919
@statement(dialect=dialect)
2020
class R(ir.Statement):
2121
traits = frozenset({lowering.FromPythonCall()})
22-
inputs: ir.SSAValue = info.argument(ilist.IListType[qubit.QubitType, types.Any])
22+
inputs: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
2323
axis_angle: ir.SSAValue = info.argument(types.Float)
2424
rotation_angle: ir.SSAValue = info.argument(types.Float)
2525

2626

2727
@statement(dialect=dialect)
2828
class Rz(ir.Statement):
2929
traits = frozenset({lowering.FromPythonCall()})
30-
inputs: ir.SSAValue = info.argument(ilist.IListType[qubit.QubitType, types.Any])
30+
inputs: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
3131
rotation_angle: ir.SSAValue = info.argument(types.Float)

src/bloqade/native/stdlib/broadcast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from kirin.dialects import ilist
55

6-
from bloqade.squin import qubit
6+
from bloqade import qubit
77
from bloqade.native._prelude import kernel
88
from bloqade.native.dialects.gates import _interface as native
99

src/bloqade/native/stdlib/simple.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from kirin.dialects import ilist
22

3-
from bloqade.squin import qubit
3+
from bloqade import qubit
44

55
from . import broadcast
66
from .._prelude import kernel

src/bloqade/pyqrack/squin/qubit.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from kirin import interp
44
from kirin.dialects import ilist
55

6-
from bloqade.squin import qubit
6+
from bloqade.qubit import stmts as qubit
77
from bloqade.pyqrack.reg import QubitState, Measurement, PyQrackQubit
88
from bloqade.pyqrack.base import PyQrackInterpreter
99

@@ -27,38 +27,32 @@ def _measure_qubit(self, qbit: PyQrackQubit, interp: PyQrackInterpreter):
2727
interp.set_global_measurement_id(m)
2828
return m
2929

30-
@interp.impl(qubit.MeasureQubitList)
30+
@interp.impl(qubit.Measure)
3131
def measure_qubit_list(
3232
self,
3333
interp: PyQrackInterpreter,
3434
frame: interp.Frame,
35-
stmt: qubit.MeasureQubitList,
35+
stmt: qubit.Measure,
3636
):
3737
qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
3838
result = ilist.IList([self._measure_qubit(qbit, interp) for qbit in qubits])
3939
return (result,)
4040

41-
@interp.impl(qubit.MeasureQubit)
42-
def measure_qubit(
43-
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.MeasureQubit
44-
):
45-
qbit: PyQrackQubit = frame.get(stmt.qubit)
46-
result = self._measure_qubit(qbit, interp)
47-
return (result,)
48-
4941
@interp.impl(qubit.QubitId)
5042
def qubit_id(
5143
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.QubitId
5244
):
53-
qbit: PyQrackQubit = frame.get(stmt.qubit)
54-
return (qbit.addr,)
45+
qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
46+
ids = ilist.IList([qbit.addr for qbit in qubits])
47+
return (ids,)
5548

5649
@interp.impl(qubit.MeasurementId)
5750
def measurement_id(
5851
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.MeasurementId
5952
):
60-
measurement: Measurement = frame.get(stmt.measurement)
61-
return (measurement.measurement_id,)
53+
measurements: ilist.IList[Measurement, Any] = frame.get(stmt.measurements)
54+
ids = ilist.IList([measurement.measurement_id for measurement in measurements])
55+
return (ids,)
6256

6357
@interp.impl(qubit.Reset)
6458
def reset(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.Reset):

0 commit comments

Comments
 (0)