Skip to content

Commit 447126c

Browse files
authored
Merge branch 'main' into khwu/stim_annotate
2 parents 5126e6d + dadcd37 commit 447126c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+427
-356
lines changed

src/bloqade/analysis/measure_id/impls.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from kirin.analysis import const
33
from kirin.dialects import py, scf, func, ilist
44

5-
from bloqade.squin import qubit
5+
from bloqade import qubit
66

77
from .lattice import (
88
AnyMeasureId,
@@ -21,22 +21,12 @@
2121
@qubit.dialect.register(key="measure_id")
2222
class SquinQubit(interp.MethodTable):
2323

24-
@interp.impl(qubit.MeasureQubit)
25-
def measure_qubit(
26-
self,
27-
interp: MeasurementIDAnalysis,
28-
frame: interp.Frame,
29-
stmt: qubit.MeasureQubit,
30-
):
31-
interp.measure_count += 1
32-
return (MeasureIdBool(interp.measure_count),)
33-
34-
@interp.impl(qubit.MeasureQubitList)
24+
@interp.impl(qubit.stmts.Measure)
3525
def measure_qubit_list(
3626
self,
3727
interp: MeasurementIDAnalysis,
3828
frame: interp.Frame,
39-
stmt: qubit.MeasureQubitList,
29+
stmt: qubit.stmts.Measure,
4030
):
4131

4232
# try to get the length of the list

src/bloqade/cirq_utils/emit/base.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,3 +230,17 @@ def emit_invoke(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Invoke):
230230
"Function invokes should need to be inlined! "
231231
"If you called the emit_circuit method, that should have happened, please report this issue."
232232
)
233+
234+
@impl(func.Return)
235+
def return_(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Return):
236+
# NOTE: should only be hit if ignore_returns == True
237+
return ()
238+
239+
240+
@py.indexing.dialect.register(key="emit.cirq")
241+
class __Concrete(interp.MethodTable):
242+
243+
@interp.impl(py.indexing.GetItem)
244+
def getindex(self, interp, frame: interp.Frame, stmt: py.indexing.GetItem):
245+
# NOTE: no support for indexing into single statements in cirq
246+
return ()

src/bloqade/cirq_utils/emit/qubit.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import cirq
22
from kirin.interp import MethodTable, impl
33

4-
from bloqade.squin import qubit
4+
from bloqade.qubit import stmts as qubit
55

66
from .base import EmitCirq, EmitCirqFrame
77

@@ -18,17 +18,9 @@ def new(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.New):
1818
frame.qubit_index += 1
1919
return (cirq_qubit,)
2020

21-
@impl(qubit.MeasureQubit)
22-
def measure_qubit(
23-
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.MeasureQubit
24-
):
25-
qbit = frame.get(stmt.qubit)
26-
frame.circuit.append(cirq.measure(qbit))
27-
return (emit.void,)
28-
29-
@impl(qubit.MeasureQubitList)
21+
@impl(qubit.Measure)
3022
def measure_qubit_list(
31-
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.MeasureQubitList
23+
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Measure
3224
):
3325
qbits = frame.get(stmt.qubits)
3426
frame.circuit.append(cirq.measure(qbits))

src/bloqade/cirq_utils/lowering.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from kirin.rewrite import Walk, CFGCompactify
77
from kirin.dialects import py, scf, func, ilist
88

9-
from bloqade.squin import gate, noise, qubit, kernel, qalloc
9+
from bloqade import qubit
10+
from bloqade.squin import gate, noise, kernel, qalloc
1011

1112

1213
def load_circuit(
@@ -403,13 +404,8 @@ def bool_op_or(x: bool, y: bool) -> bool:
403404
def visit_MeasurementGate(
404405
self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation
405406
):
406-
cirq_qubits = node.qubits
407-
if len(cirq_qubits) == 1:
408-
qbit = self.lower_qubit_getindex(state, node.qubits[0])
409-
stmt = state.current_frame.push(qubit.MeasureQubit(qbit))
410-
else:
411-
qubits = self.lower_qubit_getindices(state, node.qubits)
412-
stmt = state.current_frame.push(qubit.MeasureQubitList(qubits))
407+
qubits = self.lower_qubit_getindices(state, node.qubits)
408+
stmt = state.current_frame.push(qubit.stmts.Measure(qubits))
413409

414410
# NOTE: add for classically controlled lowering
415411
key = node.gate.key

src/bloqade/native/_prelude.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from kirin.prelude import structural_no_opt
66
from typing_extensions import Doc
77

8-
from bloqade.squin import qubit
8+
from bloqade import qubit
99

1010
from .dialects import gates
1111

src/bloqade/native/dialects/gates/_interface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from kirin import lowering
44
from kirin.dialects import ilist
55

6-
from bloqade.squin import qubit
6+
from bloqade import qubit
77

88
from .stmts import CZ, R, Rz
99

src/bloqade/native/dialects/gates/stmts.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from kirin.decl import info, statement
33
from kirin.dialects import ilist
44

5-
from bloqade.squin import qubit
5+
from bloqade.types import QubitType
66

77
from ._dialect import dialect
88

@@ -12,20 +12,20 @@
1212
@statement(dialect=dialect)
1313
class CZ(ir.Statement):
1414
traits = frozenset({lowering.FromPythonCall()})
15-
ctrls: ir.SSAValue = info.argument(ilist.IListType[qubit.QubitType, N])
16-
qargs: ir.SSAValue = info.argument(ilist.IListType[qubit.QubitType, N])
15+
ctrls: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
16+
qargs: ir.SSAValue = info.argument(ilist.IListType[QubitType, N])
1717

1818

1919
@statement(dialect=dialect)
2020
class R(ir.Statement):
2121
traits = frozenset({lowering.FromPythonCall()})
22-
inputs: ir.SSAValue = info.argument(ilist.IListType[qubit.QubitType, types.Any])
22+
inputs: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
2323
axis_angle: ir.SSAValue = info.argument(types.Float)
2424
rotation_angle: ir.SSAValue = info.argument(types.Float)
2525

2626

2727
@statement(dialect=dialect)
2828
class Rz(ir.Statement):
2929
traits = frozenset({lowering.FromPythonCall()})
30-
inputs: ir.SSAValue = info.argument(ilist.IListType[qubit.QubitType, types.Any])
30+
inputs: ir.SSAValue = info.argument(ilist.IListType[QubitType, types.Any])
3131
rotation_angle: ir.SSAValue = info.argument(types.Float)

src/bloqade/native/stdlib/broadcast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from kirin.dialects import ilist
55

6-
from bloqade.squin import qubit
6+
from bloqade import qubit
77
from bloqade.native._prelude import kernel
88
from bloqade.native.dialects.gates import _interface as native
99

src/bloqade/native/stdlib/simple.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from kirin.dialects import ilist
22

3-
from bloqade.squin import qubit
3+
from bloqade import qubit
44

55
from . import broadcast
66
from .._prelude import kernel

src/bloqade/pyqrack/squin/qubit.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from kirin import interp
44
from kirin.dialects import ilist
55

6-
from bloqade.squin import qubit
6+
from bloqade.qubit import stmts as qubit
77
from bloqade.pyqrack.reg import QubitState, Measurement, PyQrackQubit
88
from bloqade.pyqrack.base import PyQrackInterpreter
99

@@ -27,38 +27,32 @@ def _measure_qubit(self, qbit: PyQrackQubit, interp: PyQrackInterpreter):
2727
interp.set_global_measurement_id(m)
2828
return m
2929

30-
@interp.impl(qubit.MeasureQubitList)
30+
@interp.impl(qubit.Measure)
3131
def measure_qubit_list(
3232
self,
3333
interp: PyQrackInterpreter,
3434
frame: interp.Frame,
35-
stmt: qubit.MeasureQubitList,
35+
stmt: qubit.Measure,
3636
):
3737
qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
3838
result = ilist.IList([self._measure_qubit(qbit, interp) for qbit in qubits])
3939
return (result,)
4040

41-
@interp.impl(qubit.MeasureQubit)
42-
def measure_qubit(
43-
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.MeasureQubit
44-
):
45-
qbit: PyQrackQubit = frame.get(stmt.qubit)
46-
result = self._measure_qubit(qbit, interp)
47-
return (result,)
48-
4941
@interp.impl(qubit.QubitId)
5042
def qubit_id(
5143
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.QubitId
5244
):
53-
qbit: PyQrackQubit = frame.get(stmt.qubit)
54-
return (qbit.addr,)
45+
qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
46+
ids = ilist.IList([qbit.addr for qbit in qubits])
47+
return (ids,)
5548

5649
@interp.impl(qubit.MeasurementId)
5750
def measurement_id(
5851
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.MeasurementId
5952
):
60-
measurement: Measurement = frame.get(stmt.measurement)
61-
return (measurement.measurement_id,)
53+
measurements: ilist.IList[Measurement, Any] = frame.get(stmt.measurements)
54+
ids = ilist.IList([measurement.measurement_id for measurement in measurements])
55+
return (ids,)
6256

6357
@interp.impl(qubit.Reset)
6458
def reset(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.Reset):

0 commit comments

Comments
 (0)