2121from typing import Dict , Optional , Sequence , Union , NewType
2222
2323
24+ @register_attribute_builder ("ParamOperandIndexAttr" )
25+ def _paramOperandIndexAttr (x : int , context ) -> Attribute :
26+ return Attribute .parse (f"#transform.param_operand_index<{ x } >" , context = context )
27+
28+
2429@_ods_cext .register_operation (_Dialect , replace = True )
2530class CastOp (CastOp ):
2631 def __init__ (
@@ -214,11 +219,6 @@ def __init__(
214219 super ().__init__ (_get_op_results_or_values (operands ), loc = loc , ip = ip )
215220
216221
217- @register_attribute_builder ("ParamOperandIndexAttr" )
218- def _paramOperandIndexAttr (x : int , context ) -> Attribute :
219- return Attribute .parse (f"#transform.param_operand_index<{ x } >" , context = context )
220-
221-
222222@_ods_cext .register_operation (_Dialect , replace = True )
223223class ApplyRegisteredPassOp (ApplyRegisteredPassOp ):
224224 def __init__ (
@@ -227,10 +227,12 @@ def __init__(
227227 pass_name : Union [str , StringAttr ],
228228 target : Union [Operation , Value , OpView ],
229229 * ,
230- options : Dict [
231- Union [str , StringAttr ],
232- Union [Attribute , Value , Operation , OpView ],
233- ] = {},
230+ options : Optional [
231+ Dict [
232+ Union [str , StringAttr ],
233+ Union [Attribute , Value , Operation , OpView ],
234+ ]
235+ ] = None ,
234236 loc = None ,
235237 ip = None ,
236238 ):
@@ -241,20 +243,16 @@ def __init__(
241243 context = (loc and loc .context ) or Context .current
242244
243245 cur_param_operand_idx = 0
244- for key , value in options .items ():
246+ for key , value in options .items () if options is not None else {} :
245247 if isinstance (key , StringAttr ):
246248 key = key .value
247249
248250 if isinstance (value , (Value , Operation , OpView )):
249- value = _get_op_result_or_value (value )
250- # v = Attribute.parse(
251- # f"#transform.param_operand_index<{cur_param_operand_idx}>",
252- # context=context,
253- # )
254- v = _paramOperandIndexAttr (cur_param_operand_idx , context )
255- options_dict [key ] = v
251+ dynamic_options .append (_get_op_result_or_value (value ))
252+ options_dict [key ] = ParamOperandIndexAttr (
253+ cur_param_operand_idx , context
254+ )
256255 cur_param_operand_idx += 1
257- dynamic_options .append (value )
258256 elif isinstance (value , Attribute ):
259257 options_dict [key ] = value
260258 elif isinstance (value , str ):
@@ -279,10 +277,12 @@ def apply_registered_pass(
279277 pass_name : Union [str , StringAttr ],
280278 target : Union [Operation , Value , OpView ],
281279 * ,
282- options : Dict [
283- Union [str , StringAttr ],
284- Union [Attribute , Value , Operation , OpView ],
285- ] = {},
280+ options : Optional [
281+ Dict [
282+ Union [str , StringAttr ],
283+ Union [Attribute , Value , Operation , OpView ],
284+ ]
285+ ] = None ,
286286 loc = None ,
287287 ip = None ,
288288) -> Value :
0 commit comments