66from kirin import ir , types , interp
77from kirin .emit import EmitABC , EmitFrame
88from kirin .interp import MethodTable , impl
9- from kirin .dialects import py , func
9+ from kirin .dialects import py , func , ilist
1010from typing_extensions import Self
1111
1212from bloqade .squin import kernel
@@ -149,17 +149,23 @@ def main():
149149 new_func = func .Function (
150150 sym_name = sym_name , body = callable_region , signature = new_signature
151151 )
152- mt_ = ir .Method (None , None , sym_name , [], mt .dialects , new_func )
152+ # mt_ = ir.Method(None, None, sym_name, [], mt.dialects, new_func)
153+ mt_ = ir .Method (
154+ dialects = mt .dialects ,
155+ code = new_func ,
156+ sym_name = sym_name ,
157+ )
153158
154159 AggressiveUnroll (mt_ .dialects ).fixpoint (mt_ )
155- return emitter .run (mt_ , args = ())
160+ emitter .initialize ()
161+ emitter .run (mt_ )
162+ return emitter .circuit
156163
157164
158165@dataclass
159166class EmitCirqFrame (EmitFrame ):
160167 qubit_index : int = 0
161168 qubits : Sequence [cirq .Qid ] | None = None
162- circuit : cirq .Circuit = field (default_factory = cirq .Circuit )
163169
164170
165171def _default_kernel ():
@@ -172,19 +178,20 @@ class EmitCirq(EmitABC[EmitCirqFrame, cirq.Circuit]):
172178 dialects : ir .DialectGroup = field (default_factory = _default_kernel )
173179 void = cirq .Circuit ()
174180 qubits : Sequence [cirq .Qid ] | None = None
181+ circuit : cirq .Circuit = field (default_factory = cirq .Circuit )
175182
176183 def initialize (self ) -> Self :
177184 return super ().initialize ()
178185
179186 def initialize_frame (
180- self , code : ir .Statement , * , has_parent_access : bool = False
187+ self , node : ir .Statement , * , has_parent_access : bool = False
181188 ) -> EmitCirqFrame :
182189 return EmitCirqFrame (
183- code , has_parent_access = has_parent_access , qubits = self .qubits
190+ node , has_parent_access = has_parent_access , qubits = self .qubits
184191 )
185192
186193 def run_method (self , method : ir .Method , args : tuple [cirq .Circuit , ...]):
187- return self .run_callable (method . code , args )
194+ return self .call (method , * args )
188195
189196 def run_callable_region (
190197 self ,
@@ -198,7 +205,7 @@ def run_callable_region(
198205 # NOTE: skip self arg
199206 frame .set_values (block_args [1 :], args )
200207
201- results = self .eval_stmt (frame , code )
208+ results = self .frame_eval (frame , code )
202209 if isinstance (results , tuple ):
203210 if len (results ) == 0 :
204211 return self .void
@@ -208,20 +215,32 @@ def run_callable_region(
208215
209216 def emit_block (self , frame : EmitCirqFrame , block : ir .Block ) -> cirq .Circuit :
210217 for stmt in block .stmts :
211- result = self .eval_stmt (frame , stmt )
218+ result = self .frame_eval (frame , stmt )
212219 if isinstance (result , tuple ):
213220 frame .set_values (stmt .results , result )
214221
215- return frame .circuit
222+ return self .circuit
223+
224+ def reset (self ):
225+ pass
216226
217227
218228@func .dialect .register (key = "emit.cirq" )
219229class __FuncEmit (MethodTable ):
220230
221231 @impl (func .Function )
222232 def emit_func (self , emit : EmitCirq , frame : EmitCirqFrame , stmt : func .Function ):
223- emit .run_ssacfg_region (frame , stmt .body , ())
224- return (frame .circuit ,)
233+ for block in stmt .body .blocks :
234+ frame .current_block = block
235+ for s in block .stmts :
236+ frame .current_stmt = s
237+ stmt_results = emit .frame_eval (frame , s )
238+ if isinstance (stmt_results , tuple ):
239+ if len (stmt_results ) != 0 :
240+ frame .set_values (s .results , stmt_results )
241+ continue
242+
243+ return (emit .circuit ,)
225244
226245 @impl (func .Invoke )
227246 def emit_invoke (self , emit : EmitCirq , frame : EmitCirqFrame , stmt : func .Invoke ):
@@ -235,6 +254,12 @@ def return_(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Return):
235254 # NOTE: should only be hit if ignore_returns == True
236255 return ()
237256
257+ @impl (func .ConstantNone )
258+ def emit_constant_none (
259+ self , emit : EmitCirq , frame : EmitCirqFrame , stmt : func .ConstantNone
260+ ):
261+ return ()
262+
238263
239264@py .indexing .dialect .register (key = "emit.cirq" )
240265class __Concrete (interp .MethodTable ):
@@ -243,3 +268,19 @@ class __Concrete(interp.MethodTable):
243268 def getindex (self , interp , frame : interp .Frame , stmt : py .indexing .GetItem ):
244269 # NOTE: no support for indexing into single statements in cirq
245270 return ()
271+
272+ @interp .impl (py .Constant )
273+ def emit_constant (self , emit : EmitCirq , frame : EmitCirqFrame , stmt : py .Constant ):
274+ return (stmt .value .data ,) # pyright: ignore[reportAttributeAccessIssue]
275+
276+
277+ @ilist .dialect .register (key = "emit.cirq" )
278+ class __IList (interp .MethodTable ):
279+ @interp .impl (ilist .New )
280+ def new_ilist (
281+ self ,
282+ emit : EmitCirq ,
283+ frame : interp .Frame ,
284+ stmt : ilist .New ,
285+ ):
286+ return (ilist .IList (data = frame .get_values (stmt .values )),)
0 commit comments