44
55import cirq
66from kirin import ir , types , interp
7- from kirin .emit import EmitABC , EmitError , EmitFrame
7+ from kirin .emit import EmitABC , EmitFrame
88from kirin .interp import MethodTable , impl
9- from kirin .dialects import py , func
9+ from kirin .dialects import py , func , ilist
1010from typing_extensions import Self
1111
1212from bloqade .squin import kernel
@@ -102,7 +102,7 @@ def main():
102102 and isinstance (mt .code , func .Function )
103103 and not mt .code .signature .output .is_subseteq (types .NoneType )
104104 ):
105- raise EmitError (
105+ raise interp . exceptions . InterpreterError (
106106 "The method you are trying to convert to a circuit has a return value, but returning from a circuit is not supported."
107107 " Set `ignore_returns = True` in order to simply ignore the return values and emit a circuit."
108108 )
@@ -116,12 +116,14 @@ def main():
116116
117117 symbol_op_trait = mt .code .get_trait (ir .SymbolOpInterface )
118118 if (symbol_op_trait := mt .code .get_trait (ir .SymbolOpInterface )) is None :
119- raise EmitError ("The method is not a symbol, cannot emit circuit!" )
119+ raise interp .exceptions .InterpreterError (
120+ "The method is not a symbol, cannot emit circuit!"
121+ )
120122
121123 sym_name = symbol_op_trait .get_sym_name (mt .code ).unwrap ()
122124
123125 if (signature_trait := mt .code .get_trait (ir .HasSignature )) is None :
124- raise EmitError (
126+ raise interp . exceptions . InterpreterError (
125127 f"The method { sym_name } does not have a signature, cannot emit circuit!"
126128 )
127129
@@ -135,7 +137,7 @@ def main():
135137
136138 assert first_stmt is not None , "Method has no statements!"
137139 if len (args_ssa ) - 1 != len (args ):
138- raise EmitError (
140+ raise interp . exceptions . InterpreterError (
139141 f"The method { sym_name } takes { len (args_ssa ) - 1 } arguments, but you passed in { len (args )} via the `args` keyword!"
140142 )
141143
@@ -147,17 +149,22 @@ def main():
147149 new_func = func .Function (
148150 sym_name = sym_name , body = callable_region , signature = new_signature
149151 )
150- mt_ = ir .Method (None , None , sym_name , [], mt .dialects , new_func )
152+ mt_ = ir .Method (
153+ dialects = mt .dialects ,
154+ code = new_func ,
155+ sym_name = sym_name ,
156+ )
151157
152158 AggressiveUnroll (mt_ .dialects ).fixpoint (mt_ )
153- return emitter .run (mt_ , args = ())
159+ emitter .initialize ()
160+ emitter .run (mt_ )
161+ return emitter .circuit
154162
155163
156164@dataclass
157165class EmitCirqFrame (EmitFrame ):
158166 qubit_index : int = 0
159167 qubits : Sequence [cirq .Qid ] | None = None
160- circuit : cirq .Circuit = field (default_factory = cirq .Circuit )
161168
162169
163170def _default_kernel ():
@@ -166,23 +173,24 @@ def _default_kernel():
166173
167174@dataclass
168175class EmitCirq (EmitABC [EmitCirqFrame , cirq .Circuit ]):
169- keys = [ "emit.cirq" , "main" ]
176+ keys = ( "emit.cirq" , "emit. main" )
170177 dialects : ir .DialectGroup = field (default_factory = _default_kernel )
171178 void = cirq .Circuit ()
172179 qubits : Sequence [cirq .Qid ] | None = None
180+ circuit : cirq .Circuit = field (default_factory = cirq .Circuit )
173181
174182 def initialize (self ) -> Self :
175183 return super ().initialize ()
176184
177185 def initialize_frame (
178- self , code : ir .Statement , * , has_parent_access : bool = False
186+ self , node : ir .Statement , * , has_parent_access : bool = False
179187 ) -> EmitCirqFrame :
180188 return EmitCirqFrame (
181- code , has_parent_access = has_parent_access , qubits = self .qubits
189+ node , has_parent_access = has_parent_access , qubits = self .qubits
182190 )
183191
184192 def run_method (self , method : ir .Method , args : tuple [cirq .Circuit , ...]):
185- return self .run_callable (method . code , args )
193+ return self .call (method , * args )
186194
187195 def run_callable_region (
188196 self ,
@@ -196,7 +204,7 @@ def run_callable_region(
196204 # NOTE: skip self arg
197205 frame .set_values (block_args [1 :], args )
198206
199- results = self .eval_stmt (frame , code )
207+ results = self .frame_eval (frame , code )
200208 if isinstance (results , tuple ):
201209 if len (results ) == 0 :
202210 return self .void
@@ -206,33 +214,43 @@ def run_callable_region(
206214
207215 def emit_block (self , frame : EmitCirqFrame , block : ir .Block ) -> cirq .Circuit :
208216 for stmt in block .stmts :
209- result = self .eval_stmt (frame , stmt )
217+ result = self .frame_eval (frame , stmt )
210218 if isinstance (result , tuple ):
211219 frame .set_values (stmt .results , result )
212220
213- return frame .circuit
221+ return self .circuit
222+
223+ def reset (self ):
224+ pass
225+
226+ def eval_fallback (self , frame : EmitCirqFrame , node : ir .Statement ) -> tuple :
227+ return tuple (None for _ in range (len (node .results )))
214228
215229
216230@func .dialect .register (key = "emit.cirq" )
217231class __FuncEmit (MethodTable ):
218232
219233 @impl (func .Function )
220234 def emit_func (self , emit : EmitCirq , frame : EmitCirqFrame , stmt : func .Function ):
221- emit .run_ssacfg_region (frame , stmt .body , ())
222- return (frame .circuit ,)
235+ for block in stmt .body .blocks :
236+ frame .current_block = block
237+ for s in block .stmts :
238+ frame .current_stmt = s
239+ stmt_results = emit .frame_eval (frame , s )
240+ if isinstance (stmt_results , tuple ):
241+ if len (stmt_results ) != 0 :
242+ frame .set_values (s .results , stmt_results )
243+ continue
244+
245+ return (emit .circuit ,)
223246
224247 @impl (func .Invoke )
225248 def emit_invoke (self , emit : EmitCirq , frame : EmitCirqFrame , stmt : func .Invoke ):
226- raise EmitError (
249+ raise interp . exceptions . InterpreterError (
227250 "Function invokes should need to be inlined! "
228251 "If you called the emit_circuit method, that should have happened, please report this issue."
229252 )
230253
231- @impl (func .Return )
232- def return_ (self , emit : EmitCirq , frame : EmitCirqFrame , stmt : func .Return ):
233- # NOTE: should only be hit if ignore_returns == True
234- return ()
235-
236254
237255@py .indexing .dialect .register (key = "emit.cirq" )
238256class __Concrete (interp .MethodTable ):
@@ -241,3 +259,19 @@ class __Concrete(interp.MethodTable):
241259 def getindex (self , interp , frame : interp .Frame , stmt : py .indexing .GetItem ):
242260 # NOTE: no support for indexing into single statements in cirq
243261 return ()
262+
263+ @interp .impl (py .Constant )
264+ def emit_constant (self , emit : EmitCirq , frame : EmitCirqFrame , stmt : py .Constant ):
265+ return (stmt .value .data ,) # pyright: ignore[reportAttributeAccessIssue]
266+
267+
268+ @ilist .dialect .register (key = "emit.cirq" )
269+ class __IList (interp .MethodTable ):
270+ @interp .impl (ilist .New )
271+ def new_ilist (
272+ self ,
273+ emit : EmitCirq ,
274+ frame : interp .Frame ,
275+ stmt : ilist .New ,
276+ ):
277+ return (ilist .IList (data = frame .get_values (stmt .values )),)
0 commit comments