@@ -149,25 +149,28 @@ def __init__(
149149 self ,
150150 target : Union [Operation , Value , OpView ],
151151 * ,
152- sizes : Optional [Union [DynamicIndexList , ArrayAttr ]] = None ,
153- interchange : OptionalIntList = None ,
152+ tile_sizes : Optional [Union [DynamicIndexList , ArrayAttr ]] = None ,
153+ tile_interchange : OptionalIntList = None ,
154154 loc = None ,
155155 ip = None ,
156156 ):
157157 ...
158158
159159 def __init__ (
160160 self ,
161- loop_types_or_target : Union [Type , List [Type ], Operation , Value ],
161+ loop_types_or_target : Union [Type , Sequence [Type ], Operation , OpView , Value ],
162162 target_or_none : Optional [Union [Operation , Value , OpView ]] = None ,
163163 * ,
164- sizes : Optional [Union [DynamicIndexList , ArrayAttr ]] = None ,
165- interchange : OptionalIntList = None ,
164+ tile_sizes : Optional [Union [DynamicIndexList , ArrayAttr ]] = None ,
165+ tile_interchange : OptionalIntList = None ,
166166 loc = None ,
167167 ip = None ,
168168 ):
169- sizes = sizes if sizes else []
170- num_loops = sum (v if v == 0 else 1 for v in sizes )
169+ tile_sizes = tile_sizes if tile_sizes else []
170+ tile_interchange = tile_interchange if tile_interchange else []
171+ _ , tile_sizes , _ = _dispatch_dynamic_index_list (tile_sizes )
172+ _ , tile_interchange , _ = _dispatch_dynamic_index_list (tile_interchange )
173+ num_loops = sum (0 if v == 0 else 1 for v in tile_sizes )
171174
172175 if isinstance (loop_types_or_target , (Operation , Value , OpView )):
173176 loop_types = [transform .AnyOpType .get ()] * num_loops
@@ -184,8 +187,8 @@ def __init__(
184187 target .type ,
185188 loop_types ,
186189 target ,
187- tile_sizes = sizes ,
188- tile_interchange = interchange ,
190+ tile_sizes = tile_sizes ,
191+ tile_interchange = tile_interchange ,
189192 loc = loc ,
190193 ip = ip ,
191194 )
0 commit comments