Skip to content

Commit 19dd7b3

Browse files
weinbe58david-pl
andauthored
Simplyfing qubit.new (#518)
Closes #508 however I am not sure where to reexport `new` since I Can't reexport it into the `qubit` module. For now I have re-exported it into `squin.new`. --------- Co-authored-by: David Plankensteiner <[email protected]>
1 parent ff4d74f commit 19dd7b3

37 files changed

+389
-326
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ requires-python = ">=3.10"
1313
dependencies = [
1414
"numpy>=1.22.0",
1515
"scipy>=1.13.1",
16-
"kirin-toolchain~=0.17.26",
16+
"kirin-toolchain~=0.17.30",
1717
"rich>=13.9.4",
1818
"pydantic>=1.3.0,<2.11.0",
1919
"pandas>=2.2.3",

src/bloqade/cirq_utils/emit/base.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
from kirin import ir, types, interp
77
from kirin.emit import EmitABC, EmitError, EmitFrame
88
from kirin.interp import MethodTable, impl
9-
from kirin.passes import inline
10-
from kirin.dialects import func
9+
from kirin.dialects import py, func
1110
from typing_extensions import Self
1211

1312
from bloqade.squin import kernel
13+
from bloqade.rewrite.passes import AggressiveUnroll
1414

1515

1616
def emit_circuit(
@@ -28,7 +28,7 @@ def emit_circuit(
2828
Keyword Args:
2929
circuit_qubits (Sequence[cirq.Qid] | None):
3030
A list of qubits to use as the qubits in the circuit. Defaults to None.
31-
If this is None, then `cirq.LineQubit`s are inserted for every `squin.qubit.new`
31+
If this is None, then `cirq.LineQubit`s are inserted for every `squin.qalloc`
3232
statement in the order they appear inside the kernel.
3333
**Note**: If a list of qubits is provided, make sure that there is a sufficient
3434
number of qubits for the resulting circuit.
@@ -48,7 +48,7 @@ def emit_circuit(
4848
4949
@squin.kernel
5050
def main():
51-
q = squin.qubit.new(2)
51+
q = squin.qalloc(2)
5252
squin.h(q[0])
5353
squin.cx(q[0], q[1])
5454
@@ -74,8 +74,10 @@ def entangle(q: ilist.IList[squin.qubit.Qubit, Literal[2]]):
7474
7575
@squin.kernel
7676
def main():
77-
q = squin.qubit.new(2)
78-
entangle(q)
77+
q = squin.qalloc(2)
78+
q2 = squin.qalloc(3)
79+
squin.cx(q[1], q2[2])
80+
7981
8082
# custom list of qubits on grid
8183
qubits = [cirq.GridQubit(i, i+1) for i in range(5)]
@@ -112,10 +114,43 @@ def main():
112114

113115
emitter = EmitCirq(qubits=circuit_qubits)
114116

115-
mt_ = mt.similar(mt.dialects)
116-
inline.InlinePass(mt_.dialects).fixpoint(mt_)
117+
symbol_op_trait = mt.code.get_trait(ir.SymbolOpInterface)
118+
if (symbol_op_trait := mt.code.get_trait(ir.SymbolOpInterface)) is None:
119+
raise EmitError("The method is not a symbol, cannot emit circuit!")
120+
121+
sym_name = symbol_op_trait.get_sym_name(mt.code).unwrap()
122+
123+
if (signature_trait := mt.code.get_trait(ir.HasSignature)) is None:
124+
raise EmitError(
125+
f"The method {sym_name} does not have a signature, cannot emit circuit!"
126+
)
127+
128+
signature = signature_trait.get_signature(mt.code)
129+
new_signature = func.Signature(inputs=(), output=signature.output)
130+
131+
callable_region = mt.callable_region.clone()
132+
entry_block = callable_region.blocks[0]
133+
args_ssa = list(entry_block.args)
134+
first_stmt = entry_block.first_stmt
135+
136+
assert first_stmt is not None, "Method has no statements!"
137+
if len(args_ssa) - 1 != len(args):
138+
raise EmitError(
139+
f"The method {sym_name} takes {len(args_ssa) - 1} arguments, but you passed in {len(args)} via the `args` keyword!"
140+
)
141+
142+
for arg, arg_ssa in zip(args, args_ssa[1:], strict=True):
143+
(value := py.Constant(arg)).insert_before(first_stmt)
144+
arg_ssa.replace_by(value.result)
145+
entry_block.args.delete(arg_ssa)
146+
147+
new_func = func.Function(
148+
sym_name=sym_name, body=callable_region, signature=new_signature
149+
)
150+
mt_ = ir.Method(None, None, sym_name, [], mt.dialects, new_func)
117151

118-
return emitter.run(mt_, args=args)
152+
AggressiveUnroll(mt_.dialects).fixpoint(mt_)
153+
return emitter.run(mt_, args=())
119154

120155

121156
@dataclass

src/bloqade/cirq_utils/emit/qubit.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,13 @@
1010
class EmitCirqQubitMethods(MethodTable):
1111
@impl(qubit.New)
1212
def new(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.New):
13-
n_qubits = frame.get(stmt.n_qubits)
14-
1513
if frame.qubits is not None:
16-
cirq_qubits = tuple(
17-
frame.qubits[i + frame.qubit_index] for i in range(n_qubits)
18-
)
14+
cirq_qubit = frame.qubits[frame.qubit_index]
1915
else:
20-
cirq_qubits = tuple(
21-
cirq.LineQubit(i + frame.qubit_index) for i in range(n_qubits)
22-
)
16+
cirq_qubit = cirq.LineQubit(frame.qubit_index)
2317

24-
frame.qubit_index += n_qubits
25-
return (cirq_qubits,)
18+
frame.qubit_index += 1
19+
return (cirq_qubit,)
2620

2721
@impl(qubit.MeasureQubit)
2822
def measure_qubit(

src/bloqade/cirq_utils/lowering.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from kirin.rewrite import Walk, CFGCompactify
77
from kirin.dialects import py, scf, func, ilist
88

9-
from bloqade.squin import gate, noise, qubit, kernel
9+
from bloqade.squin import gate, noise, qubit, kernel, qalloc
1010

1111

1212
def load_circuit(
@@ -92,7 +92,7 @@ def load_circuit(
9292
@squin.kernel
9393
def main():
9494
qreg = get_entangled_qubits()
95-
qreg2 = squin.qubit.new(1)
95+
qreg2 = squin.qalloc(1)
9696
entangle_qubits([qreg[1], qreg2[0]])
9797
return squin.qubit.measure(qreg2)
9898
```
@@ -142,7 +142,7 @@ def main():
142142
body=body,
143143
)
144144

145-
return ir.Method(
145+
mt = ir.Method(
146146
mod=None,
147147
py_func=None,
148148
sym_name=kernel_name,
@@ -151,6 +151,11 @@ def main():
151151
code=code,
152152
)
153153

154+
assert (run_pass := kernel.run_pass) is not None
155+
run_pass(mt, typeinfer=True)
156+
157+
return mt
158+
154159

155160
CirqNode = (
156161
cirq.Circuit
@@ -254,7 +259,9 @@ def run(
254259
# NOTE: create a new register of appropriate size
255260
n_qubits = len(self.qreg_index)
256261
n = frame.push(py.Constant(n_qubits))
257-
self.qreg = frame.push(qubit.New(n_qubits=n.result)).result
262+
self.qreg = frame.push(
263+
func.Invoke((n.result,), callee=qalloc, kwargs=())
264+
).result
258265

259266
self.visit(state, stmt)
260267

@@ -382,8 +389,16 @@ def bool_op_or(x: bool, y: bool) -> bool:
382389
# NOTE: remove stmt from parent block
383390
then_stmt.detach()
384391
then_body = ir.Block((then_stmt,))
392+
then_body.args.append_from(types.Bool, name="cond")
393+
then_body.stmts.append(scf.Yield())
385394

386-
return state.current_frame.push(scf.IfElse(condition, then_body=then_body))
395+
else_body = ir.Block(())
396+
else_body.args.append_from(types.Bool, name="cond")
397+
else_body.stmts.append(scf.Yield())
398+
399+
return state.current_frame.push(
400+
scf.IfElse(condition, then_body=then_body, else_body=else_body)
401+
)
387402

388403
def visit_MeasurementGate(
389404
self, state: lowering.State[cirq.Circuit], node: cirq.GateOperation

src/bloqade/pyqrack/squin/qubit.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,12 @@
1111
@qubit.dialect.register(key="pyqrack")
1212
class PyQrackMethods(interp.MethodTable):
1313
@interp.impl(qubit.New)
14-
def new(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.New):
15-
n_qubits: int = frame.get(stmt.n_qubits)
16-
qreg = ilist.IList(
17-
[
18-
PyQrackQubit(i, interp.memory.sim_reg, QubitState.Active)
19-
for i in interp.memory.allocate(n_qubits=n_qubits)
20-
]
21-
)
22-
return (qreg,)
14+
def new_qubit(
15+
self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: qubit.New
16+
):
17+
(addr,) = interp.memory.allocate(1)
18+
qb = PyQrackQubit(addr, interp.memory.sim_reg, QubitState.Active)
19+
return (qb,)
2320

2421
def _measure_qubit(self, qbit: PyQrackQubit, interp: PyQrackInterpreter):
2522
if qbit.is_active():

src/bloqade/rewrite/passes/aggressive_unroll.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def unsafe_run(self, mt: Method) -> RewriteResult:
3838
InlineGetField(),
3939
InlineGetItem(),
4040
ilist.rewrite.InlineGetItem(),
41+
ilist.rewrite.FlattenAdd(),
4142
ilist.rewrite.HintLen(),
4243
)
4344
result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result)
@@ -68,7 +69,7 @@ def unsafe_run(self, mt: Method) -> RewriteResult:
6869
.rewrite(mt.code)
6970
.join(result)
7071
)
71-
result = self.typeinfer.unsafe_run(mt).join(result)
72+
self.typeinfer.unsafe_run(mt)
7273
result = self.fold.unsafe_run(mt).join(result)
7374
result = Walk(Inline(self.inline_heuristic)).rewrite(mt.code).join(result)
7475
result = Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(result)

src/bloqade/squin/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
noise as noise,
44
qubit as qubit,
55
analysis as analysis,
6-
_typeinfer as _typeinfer,
76
)
87
from .groups import kernel as kernel
8+
from .stdlib.qubit import qalloc as qalloc
99
from .stdlib.simple import (
1010
h as h,
1111
s as s,

src/bloqade/squin/_typeinfer.py

Lines changed: 0 additions & 19 deletions
This file was deleted.

src/bloqade/squin/analysis/address_impl.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from bloqade.analysis.address.lattice import (
55
Address,
6-
AddressReg,
6+
AddressQubit,
77
)
88
from bloqade.analysis.address.analysis import AddressAnalysis
99

@@ -27,15 +27,13 @@
2727
@qubit.dialect.register(key="qubit.address")
2828
class SquinQubitMethodTable(interp.MethodTable):
2929

30-
# This can be treated like a QRegNew impl
3130
@interp.impl(qubit.New)
32-
def new(
31+
def new_qubit(
3332
self,
3433
interp_: AddressAnalysis,
3534
frame: ForwardFrame[Address],
3635
stmt: qubit.New,
3736
):
38-
n_qubits = interp_.get_const_value(int, stmt.n_qubits)
39-
addr = AddressReg(range(interp_.next_address, interp_.next_address + n_qubits))
40-
interp_.next_address += n_qubits
37+
addr = AddressQubit(interp_.next_address)
38+
interp_.next_address += 1
4139
return (addr,)

src/bloqade/squin/groups.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def run_pass(method: ir.Method, *, fold=True, typeinfer=True):
2020
fold_pass.fixpoint(method)
2121

2222
if typeinfer:
23-
typeinfer_pass(method)
23+
typeinfer_pass(method) # infer types before desugaring
2424
desugar_pass.rewrite(method.code)
2525

2626
ilist_desugar_pass(method)

0 commit comments

Comments
 (0)