66from kirin import ir , types , interp
77from kirin .emit import EmitABC , EmitError , EmitFrame
88from kirin .interp import MethodTable , impl
9- from kirin .passes import inline
10- from kirin .dialects import func
9+ from kirin .dialects import py , func
1110from typing_extensions import Self
1211
1312from bloqade .squin import kernel
13+ from bloqade .rewrite .passes import AggressiveUnroll
1414
1515
1616def emit_circuit (
@@ -28,7 +28,7 @@ def emit_circuit(
2828 Keyword Args:
2929 circuit_qubits (Sequence[cirq.Qid] | None):
3030 A list of qubits to use as the qubits in the circuit. Defaults to None.
31- If this is None, then `cirq.LineQubit`s are inserted for every `squin.qubit.new `
31+ If this is None, then `cirq.LineQubit`s are inserted for every `squin.qalloc `
3232 statement in the order they appear inside the kernel.
3333 **Note**: If a list of qubits is provided, make sure that there is a sufficient
3434 number of qubits for the resulting circuit.
@@ -48,7 +48,7 @@ def emit_circuit(
4848
4949 @squin.kernel
5050 def main():
51- q = squin.qubit.new (2)
51+ q = squin.qalloc (2)
5252 squin.h(q[0])
5353 squin.cx(q[0], q[1])
5454
@@ -74,8 +74,10 @@ def entangle(q: ilist.IList[squin.qubit.Qubit, Literal[2]]):
7474
7575 @squin.kernel
7676 def main():
77- q = squin.qubit.new(2)
78- entangle(q)
77+ q = squin.qalloc(2)
78+ q2 = squin.qalloc(3)
79+ squin.cx(q[1], q2[2])
80+
7981
8082 # custom list of qubits on grid
8183 qubits = [cirq.GridQubit(i, i+1) for i in range(5)]
@@ -112,10 +114,43 @@ def main():
112114
113115 emitter = EmitCirq (qubits = circuit_qubits )
114116
115- mt_ = mt .similar (mt .dialects )
116- inline .InlinePass (mt_ .dialects ).fixpoint (mt_ )
117+ symbol_op_trait = mt .code .get_trait (ir .SymbolOpInterface )
118+ if (symbol_op_trait := mt .code .get_trait (ir .SymbolOpInterface )) is None :
119+ raise EmitError ("The method is not a symbol, cannot emit circuit!" )
120+
121+ sym_name = symbol_op_trait .get_sym_name (mt .code ).unwrap ()
122+
123+ if (signature_trait := mt .code .get_trait (ir .HasSignature )) is None :
124+ raise EmitError (
125+ f"The method { sym_name } does not have a signature, cannot emit circuit!"
126+ )
127+
128+ signature = signature_trait .get_signature (mt .code )
129+ new_signature = func .Signature (inputs = (), output = signature .output )
130+
131+ callable_region = mt .callable_region .clone ()
132+ entry_block = callable_region .blocks [0 ]
133+ args_ssa = list (entry_block .args )
134+ first_stmt = entry_block .first_stmt
135+
136+ assert first_stmt is not None , "Method has no statements!"
137+ if len (args_ssa ) - 1 != len (args ):
138+ raise EmitError (
139+ f"The method { sym_name } takes { len (args_ssa ) - 1 } arguments, but you passed in { len (args )} via the `args` keyword!"
140+ )
141+
142+ for arg , arg_ssa in zip (args , args_ssa [1 :], strict = True ):
143+ (value := py .Constant (arg )).insert_before (first_stmt )
144+ arg_ssa .replace_by (value .result )
145+ entry_block .args .delete (arg_ssa )
146+
147+ new_func = func .Function (
148+ sym_name = sym_name , body = callable_region , signature = new_signature
149+ )
150+ mt_ = ir .Method (None , None , sym_name , [], mt .dialects , new_func )
117151
118- return emitter .run (mt_ , args = args )
152+ AggressiveUnroll (mt_ .dialects ).fixpoint (mt_ )
153+ return emitter .run (mt_ , args = ())
119154
120155
121156@dataclass
0 commit comments