44
55from typing import Any
66from itertools import chain
7- from collections .abc import Iterable
87
98from kirin import ir , interp
109from kirin .analysis import ForwardFrame , const
@@ -57,14 +56,20 @@ def from_literal(literal: Any) -> Address:
5756 return ConstResult (const .Value (literal ))
5857
5958 match collection :
60- case PartialIList (data ) | PartialTuple (data ):
61- return data
59+ case PartialIList (data ):
60+ return PartialIList , data
61+ case PartialTuple (data ):
62+ return PartialTuple , data
6263 case AddressReg ():
63- return collection .qubits
64- case ConstResult (const .Value (data )) if isinstance (data , Iterable ):
65- return tuple (map (from_literal , data ))
64+ return PartialIList , collection .qubits
65+ case ConstResult (const .Value (IList () as data )):
66+ return PartialIList , tuple (map (from_literal , data ))
67+ case ConstResult (const .Value (tuple () as data )):
68+ return PartialTuple , tuple (map (from_literal , data ))
6669 case ConstResult (const .PartialTuple (data )):
67- return tuple (map (from_constant , data ))
70+ return PartialTuple , tuple (map (from_constant , data ))
71+ case _:
72+ return None , ()
6873
6974
7075@py .constant .dialect .register (key = "qubit.address" )
@@ -90,36 +95,17 @@ def add(
9095 ):
9196 lhs = frame .get (stmt .lhs )
9297 rhs = frame .get (stmt .rhs )
93- print (lhs , rhs )
94- match lhs , rhs :
95- case (PartialTuple (lhs_data ), PartialTuple (rhs_data )) | (
96- PartialIList (lhs_data ),
97- PartialIList (rhs_data ),
98- ):
99- return (lhs .new (lhs_data + rhs_data ),)
100- case (AddressReg (lhs_data ), AddressReg (rhs_data )):
101- return (AddressReg (tuple (chain (lhs_data , rhs_data ))),)
10298
103- lhs_constant = lhs . result if isinstance ( lhs , ConstResult ) else const . Unknown ( )
104- rhs_constant = rhs . result if isinstance ( rhs , ConstResult ) else const . Unknown ( )
99+ lhs_type , lhs_values = self . get_values ( lhs )
100+ rhs_type , rhs_values = self . get_values ( rhs )
105101
106- match (lhs , rhs_constant ):
107- case PartialIList () | AddressReg (), const .Value (IList () as lst ) if (
108- len (lst ) == 0
109- ):
110- return (lhs ,)
111- case PartialTuple (), const .Value (()):
112- return (lhs ,)
102+ if lhs_type is None or rhs_type is None :
103+ return (Address .top (),)
113104
114- match (lhs_constant , rhs ):
115- case const .Value (IList () as lst ), PartialIList () | AddressReg () if (
116- len (lst ) == 0
117- ):
118- return (rhs ,)
119- case const .Value (()), PartialTuple ():
120- return (rhs ,)
105+ if lhs_type is not rhs_type :
106+ return (Address .bottom (),)
121107
122- return interp_ . eval_stmt_fallback ( frame , stmt )
108+ return ( lhs_type ( tuple ( chain ( lhs_values , rhs_values ))), )
123109
124110
125111@py .tuple .dialect .register (key = "qubit.address" )
@@ -156,11 +142,14 @@ def map_(
156142 results = []
157143 fn = frame .get (stmt .fn )
158144 collection = frame .get (stmt .collection )
159- iterable = self .get_values (collection )
145+ iterable_type , iterable = self .get_values (collection )
160146
161- if iterable is None :
147+ if iterable_type is None :
162148 return (Address .top (),)
163149
150+ if iterable_type is not PartialIList :
151+ return (Address .bottom (),)
152+
164153 results = []
165154 for ele in iterable :
166155 ret = self .call_function (interp_ , fn , (ele ,), ())
@@ -429,9 +418,9 @@ def for_loop(
429418 stmt : scf .For ,
430419 ):
431420 loop_vars = frame .get_values (stmt .initializers )
432- iterable = self .get_values (frame .get (stmt .iterable ))
421+ iter_type , iterable = self .get_values (frame .get (stmt .iterable ))
433422
434- if iterable is None :
423+ if iter_type is None :
435424 return interp_ .eval_stmt_fallback (frame , stmt )
436425
437426 for value in iterable :
0 commit comments