1- from typing import TypeVar
1+ from typing import Any , Type , TypeVar
22from dataclasses import field
33
4- from kirin import ir , interp
4+ from kirin import ir , types , interp
55from kirin .analysis import Forward , const
6+ from kirin .dialects .ilist import IList
67from kirin .analysis .forward import ForwardFrame
8+ from kirin .analysis .const .lattice import PartialLambda
79
8- from bloqade .types import QubitType
9-
10- from .lattice import Address
10+ from .lattice import Address , AddressReg , ConstResult , PartialIList , PartialTuple
1111
1212
1313class AddressAnalysis (Forward [Address ]):
@@ -16,12 +16,15 @@ class AddressAnalysis(Forward[Address]):
1616 """
1717
1818 keys = ["qubit.address" ]
19+ _const_prop : const .Propagate
1920 lattice = Address
2021 next_address : int = field (init = False )
2122
2223 def initialize (self ):
2324 super ().initialize ()
2425 self .next_address : int = 0
26+ self ._const_prop = const .Propagate (self .dialects )
27+ self ._const_prop .initialize ()
2528 return self
2629
2730 @property
@@ -31,30 +34,117 @@ def qubit_count(self) -> int:
3134
3235 T = TypeVar ("T" )
3336
34- def get_const_value (self , typ : type [T ], value : ir .SSAValue ) -> T :
35- if isinstance (hint := value .hints .get ("const" ), const .Value ):
36- data = hint .data
37- if isinstance (data , typ ):
38- return hint .data
39- raise interp .InterpreterError (
40- f"Expected constant value <type = { typ } >, got { data } "
41- )
42- raise interp .InterpreterError (
43- f"Expected constant value <type = { typ } >, got { value } "
44- )
45-
46- def eval_stmt_fallback (
47- self , frame : ForwardFrame [Address ], stmt : ir .Statement
48- ) -> tuple [Address , ...] | interp .SpecialValue [Address ]:
49- return tuple (
50- (
51- self .lattice .top ()
52- if result .type .is_subseteq (QubitType )
53- else self .lattice .bottom ()
54- )
55- for result in stmt .results
56- )
37+ def to_address (self , result : const .Result ):
38+ return ConstResult (result )
39+
40+ def try_eval_const_prop (
41+ self ,
42+ frame : ForwardFrame [Address ],
43+ stmt : ir .Statement ,
44+ args : tuple [ConstResult , ...],
45+ ) -> interp .StatementResult [Address ]:
46+ _frame = self ._const_prop .initialize_frame (frame .code )
47+ _frame .set_values (stmt .args , tuple (x .result for x in args ))
48+ result = self ._const_prop .eval_stmt (_frame , stmt )
49+
50+ match result :
51+ case interp .ReturnValue (constant_ret ):
52+ return interp .ReturnValue (self .to_address (constant_ret ))
53+ case interp .YieldValue (constant_values ):
54+ return interp .YieldValue (tuple (map (self .to_address , constant_values )))
55+ case interp .Successor (block , block_args ):
56+ return interp .Successor (block , * map (self .to_address , block_args ))
57+ case tuple ():
58+ return tuple (map (self .to_address , result ))
59+ case _:
60+ return result
61+
62+ def unpack_iterable (self , iterable : Address ):
63+ """Extract the values of a container lattice element.
64+
65+ Args:
66+ iterable: The lattice element representing a container.
67+
68+ Returns:
69+ A tuple of the container type and the contained values.
70+
71+ """
72+
73+ def from_constant (constant : const .Result ) -> Address :
74+ return ConstResult (constant )
75+
76+ def from_literal (literal : Any ) -> Address :
77+ return ConstResult (const .Value (literal ))
78+
79+ match iterable :
80+ case PartialIList (data ):
81+ return PartialIList , data
82+ case PartialTuple (data ):
83+ return PartialTuple , data
84+ case AddressReg ():
85+ return PartialIList , iterable .qubits
86+ case ConstResult (const .Value (IList () as data )):
87+ return PartialIList , tuple (map (from_literal , data ))
88+ case ConstResult (const .Value (tuple () as data )):
89+ return PartialTuple , tuple (map (from_literal , data ))
90+ case ConstResult (const .PartialTuple (data )):
91+ return PartialTuple , tuple (map (from_constant , data ))
92+ case _:
93+ return None , ()
94+
95+ def run_lattice (
96+ self ,
97+ callee : Address ,
98+ inputs : tuple [Address , ...],
99+ kwargs : tuple [str , ...],
100+ ) -> Address :
101+ """Run a callable lattice element with the given inputs and keyword arguments.
102+
103+ Args:
104+ callee (Address): The lattice element representing the callable.
105+ inputs (tuple[Address, ...]): The input lattice elements.
106+ kwargs (tuple[str, ...]): The keyword argument names.
107+
108+ Returns:
109+ Address: The resulting lattice element after invoking the callable.
110+
111+ """
112+
113+ match callee :
114+ case PartialLambda (code = code , argnames = argnames ):
115+ _ , ret = self .run_callable (
116+ code , (callee ,) + self .permute_values (argnames , inputs , kwargs )
117+ )
118+ return ret
119+ case ConstResult (const .Value (ir .Method () as method )):
120+ _ , ret = self .run_method (
121+ method ,
122+ self .permute_values (method .arg_names , inputs , kwargs ),
123+ )
124+ return ret
125+ case _:
126+ return Address .top ()
127+
128+ def get_const_value (self , addr : Address , typ : Type [T ]) -> T | None :
129+ if not isinstance (addr , ConstResult ):
130+ return None
131+
132+ if not isinstance (result := addr .result , const .Value ):
133+ return None
134+
135+ if not isinstance (value := result .data , typ ):
136+ return None
137+
138+ return value
139+
140+ def eval_stmt_fallback (self , frame : ForwardFrame [Address ], stmt : ir .Statement ):
141+ args = frame .get_values (stmt .args )
142+ if types .is_tuple_of (args , ConstResult ):
143+ return self .try_eval_const_prop (frame , stmt , args )
144+
145+ return tuple (Address .from_type (result .type ) for result in stmt .results )
57146
58147 def run_method (self , method : ir .Method , args : tuple [Address , ...]):
59148 # NOTE: we do not support dynamic calls here, thus no need to propagate method object
60- return self .run_callable (method .code , (self .lattice .bottom (),) + args )
149+ self_mt = ConstResult (const .Value (method ))
150+ return self .run_callable (method .code , (self_mt ,) + args )
0 commit comments