Skip to content

Commit c28ff2c

Browse files
committed
WIP: trying to fix cirq emit
1 parent 0edcf0f commit c28ff2c

File tree

3 files changed

+48
-8
lines changed

3 files changed

+48
-8
lines changed

src/bloqade/cirq_utils/emit/base.py

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
from warnings import warn
33
from dataclasses import field, dataclass
44

5+
from bloqade.rewrite.passes import AggressiveUnroll
56
import cirq
6-
from kirin import ir, types, interp
7+
from kirin import ir, types, interp, passes
78
from kirin.emit import EmitABC, EmitError, EmitFrame
89
from kirin.interp import MethodTable, impl
910
from kirin.passes import inline
10-
from kirin.dialects import func
11+
from kirin.dialects import func, py
1112
from typing_extensions import Self
1213

1314
from bloqade.squin import kernel
@@ -114,10 +115,49 @@ def main():
114115

115116
emitter = EmitCirq(qubits=circuit_qubits)
116117

117-
mt_ = mt.similar(mt.dialects)
118-
inline.InlinePass(mt_.dialects).fixpoint(mt_)
119118

120-
return emitter.run(mt_, args=args)
119+
120+
symbol_op_trait = mt.code.get_trait(ir.SymbolOpInterface)
121+
if (symbol_op_trait := mt.code.get_trait(ir.SymbolOpInterface)) is None:
122+
raise EmitError(
123+
f"The method is not a symbol, cannot emit circuit!"
124+
)
125+
126+
sym_name = symbol_op_trait.get_sym_name(mt.code).unwrap()
127+
128+
if (signature_trait := mt.code.get_trait(ir.HasSignature)) is None:
129+
raise EmitError(
130+
f"The method {sym_name} does not have a signature, cannot emit circuit!"
131+
)
132+
133+
signature = signature_trait.get_signature(mt.code)
134+
new_signature = func.Signature(inputs=(), output=signature.output)
135+
136+
callable_region = mt.callable_region.clone()
137+
entry_block = callable_region.blocks[0]
138+
args_ssa = list(entry_block.args)
139+
first_stmt = entry_block.first_stmt
140+
141+
assert first_stmt is not None, "Method has no statements!"
142+
if len(args_ssa) - 1 != len(args):
143+
raise EmitError(
144+
f"The method {sym_name} takes {len(args_ssa)} arguments, but you passed in {len(args)} via the `args` keyword!"
145+
)
146+
147+
for arg, arg_ssa in zip(args, args_ssa[1:], strict=True):
148+
(value := py.Constant(arg)).insert_before(first_stmt)
149+
arg_ssa.replace_by(value.result)
150+
entry_block.args.delete(arg_ssa)
151+
152+
new_func = func.Function(sym_name=sym_name, body=callable_region, signature=new_signature)
153+
mt_ = ir.Method(None, None, sym_name, [], mt.dialects, new_func)
154+
155+
passes.Fold(mt_.dialects, no_raise=False)(mt_)
156+
mt_.print(hint="const")
157+
158+
# AggressiveUnroll(mt_.dialects)(mt_)
159+
# mt_.print(hint="const")
160+
return emitter.run(mt_, args=())
121161

122162

123163
@dataclass

src/bloqade/cirq_utils/emit/qubit.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
class EmitCirqQubitMethods(MethodTable):
1111
@impl(qubit.New)
1212
def new(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.New):
13-
print("emitting new qubit")
1413
if frame.qubits is not None:
1514
cirq_qubit = frame.qubits[frame.qubit_index]
1615
else:

test/cirq_utils/test_cirq_to_squin.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -378,8 +378,6 @@ def main(n: int):
378378
for i in range(n):
379379
squin.x(q[i])
380380

381-
main.print()
382-
383381
n_arg = 3
384382
circuit = emit_circuit(main, args=(n_arg,))
385383
print(circuit)
@@ -403,6 +401,9 @@ def multi_arg(n: int, p: float):
403401

404402
print(circuit)
405403

404+
if __name__ == "__main__":
405+
test_kernel_with_args()
406+
406407

407408
@pytest.mark.xfail
408409
def test_amplitude_damping():

0 commit comments

Comments
 (0)