22qubit.address method table for a few builtin dialects.
33"""
44
5- from typing import Any
65from itertools import chain
76
87from kirin import ir , interp
98from kirin .analysis import ForwardFrame , const
109from kirin .dialects import cf , py , scf , func , ilist
11- from kirin .dialects .ilist import IList
1210
1311from .lattice import (
1412 Address ,
2220from .analysis import AddressAnalysis
2321
2422
25- # TODO: put this in the interpreter
26- class CallInterfaceMixin :
27- """This mixin provides a generic implementation to call lattice elements for method tables.
28-
29- It handles PartialLambda and ConstResult wrapping ir.Method."""
30-
31- def call_function (
32- self ,
33- interp_ : AddressAnalysis ,
34- callee : Address ,
35- inputs : tuple [Address , ...],
36- kwargs : tuple [str , ...],
37- ) -> Address :
38-
39- match callee :
40- case PartialLambda (code = code , argnames = argnames ):
41- _ , ret = interp_ .run_callable (
42- code , (callee ,) + interp_ .permute_values (argnames , inputs , kwargs )
43- )
44- return ret
45- case ConstResult (const .Value (ir .Method () as method )):
46- _ , ret = interp_ .run_method (
47- method ,
48- interp_ .permute_values (method .arg_names , inputs , kwargs ),
49- )
50- return ret
51- case _:
52- return Address .top ()
53-
54-
55- class GetValuesMixin :
56- """This mixin provides a generic implementation to extract values of lattice elements
57-
58- that are represent the values of containers. The return type is used to differentiate
59- between IList and Tuple containers in the analysis for cases where the type information
60- is important for the analysis not just the contained values.
61-
62- """
63-
64- def get_values (self , collection : Address ):
65- """Extract the values of a container lattice element.
66-
67- Args:
68- collection: The lattice element representing a container.
69-
70- Returns:
71- A tuple of the container type and the contained values.
72-
73- """
74-
75- def from_constant (constant : const .Result ) -> Address :
76- return ConstResult (constant )
77-
78- def from_literal (literal : Any ) -> Address :
79- return ConstResult (const .Value (literal ))
80-
81- match collection :
82- case PartialIList (data ):
83- return PartialIList , data
84- case PartialTuple (data ):
85- return PartialTuple , data
86- case AddressReg ():
87- return PartialIList , collection .qubits
88- case ConstResult (const .Value (IList () as data )):
89- return PartialIList , tuple (map (from_literal , data ))
90- case ConstResult (const .Value (tuple () as data )):
91- return PartialTuple , tuple (map (from_literal , data ))
92- case ConstResult (const .PartialTuple (data )):
93- return PartialTuple , tuple (map (from_constant , data ))
94- case _:
95- return None , ()
96-
97-
9823@py .constant .dialect .register (key = "qubit.address" )
9924class PyConstant (interp .MethodTable ):
10025 @interp .impl (py .Constant )
@@ -108,7 +33,7 @@ def constant(
10833
10934
11035@py .binop .dialect .register (key = "qubit.address" )
111- class PyBinOp (interp .MethodTable , GetValuesMixin ):
36+ class PyBinOp (interp .MethodTable ):
11237 @interp .impl (py .Add )
11338 def add (
11439 self ,
@@ -119,8 +44,8 @@ def add(
11944 lhs = frame .get (stmt .lhs )
12045 rhs = frame .get (stmt .rhs )
12146
122- lhs_type , lhs_values = self . get_values (lhs )
123- rhs_type , rhs_values = self . get_values (rhs )
47+ lhs_type , lhs_values = interp_ . unpack_iterable (lhs )
48+ rhs_type , rhs_values = interp_ . unpack_iterable (rhs )
12449
12550 if lhs_type is None or rhs_type is None :
12651 return (Address .top (),)
@@ -144,7 +69,7 @@ def new_tuple(
14469
14570
14671@ilist .dialect .register (key = "qubit.address" )
147- class IListMethods (interp .MethodTable , CallInterfaceMixin , GetValuesMixin ):
72+ class IListMethods (interp .MethodTable ):
14873 @interp .impl (ilist .New )
14974 def new_ilist (
15075 self ,
@@ -165,7 +90,7 @@ def map_(
16590 results = []
16691 fn = frame .get (stmt .fn )
16792 collection = frame .get (stmt .collection )
168- collection_type , values = self . get_values (collection )
93+ collection_type , values = interp_ . unpack_iterable (collection )
16994
17095 if collection_type is None :
17196 return (Address .top (),)
@@ -175,21 +100,21 @@ def map_(
175100
176101 results = []
177102 for ele in values :
178- ret = self . call_function ( interp_ , fn , (ele ,), ())
103+ ret = interp_ . run_lattice ( fn , (ele ,), ())
179104 results .append (ret )
180105
181106 if isinstance (stmt , ilist .Map ):
182107 return (PartialIList (tuple (results )),)
183108
184109
185110@py .len .dialect .register (key = "qubit.address" )
186- class PyLen (interp .MethodTable , GetValuesMixin ):
111+ class PyLen (interp .MethodTable ):
187112 @interp .impl (py .Len )
188113 def len_ (
189114 self , interp_ : AddressAnalysis , frame : ForwardFrame [Address ], stmt : py .Len
190115 ):
191116 obj = frame .get (stmt .value )
192- _ , values = self . get_values (obj )
117+ _ , values = interp_ . unpack_iterable (obj )
193118
194119 if values is None :
195120 return (Address .top (),)
@@ -198,7 +123,7 @@ def len_(
198123
199124
200125@py .indexing .dialect .register (key = "qubit.address" )
201- class PyIndexing (interp .MethodTable , GetValuesMixin ):
126+ class PyIndexing (interp .MethodTable ):
202127 @interp .impl (py .GetItem )
203128 def getitem (
204129 self ,
@@ -211,7 +136,7 @@ def getitem(
211136 obj = frame .get (stmt .obj )
212137 index = frame .get (stmt .index )
213138
214- typ , values = self . get_values (obj )
139+ typ , values = interp_ . unpack_iterable (obj )
215140 if typ is None :
216141 return (Address .top (),)
217142
@@ -242,7 +167,7 @@ def alias(
242167
243168# TODO: look for abstract method table for func.
244169@func .dialect .register (key = "qubit.address" )
245- class Func (interp .MethodTable , CallInterfaceMixin ):
170+ class Func (interp .MethodTable ):
246171 @interp .impl (func .Return )
247172 def return_ (
248173 self ,
@@ -296,8 +221,7 @@ def call(
296221 frame : ForwardFrame [Address ],
297222 stmt : func .Call ,
298223 ):
299- result = self .call_function (
300- interp_ ,
224+ result = interp_ .run_lattice (
301225 frame .get (stmt .callee ),
302226 frame .get_values (stmt .inputs ),
303227 stmt .kwargs ,
@@ -376,7 +300,7 @@ def conditional_branch(
376300
377301
378302@scf .dialect .register (key = "qubit.address" )
379- class Scf (interp .MethodTable , GetValuesMixin ):
303+ class Scf (interp .MethodTable ):
380304 @interp .impl (scf .Yield )
381305 def yield_ (
382306 self ,
@@ -442,7 +366,7 @@ def for_loop(
442366 stmt : scf .For ,
443367 ):
444368 loop_vars = frame .get_values (stmt .initializers )
445- iter_type , iterable = self . get_values (frame .get (stmt .iterable ))
369+ iter_type , iterable = interp_ . unpack_iterable (frame .get (stmt .iterable ))
446370
447371 if iter_type is None :
448372 return interp_ .eval_stmt_fallback (frame , stmt )
0 commit comments