@@ -16,6 +16,7 @@ def construct_and_print_in_module(f):
1616 print (module )
1717 return f
1818
19+
1920def get_pdl_patterns ():
2021 # Create a rewrite from add to mul. This will match
2122 # - operation name is arith.addi
@@ -254,25 +255,39 @@ def pat():
254255 t = pdl .TypeOp (i32 )
255256 cst = pdl .AttributeOp ()
256257 pdl .apply_native_constraint ([], "is_one" , [cst ])
257- op0 = pdl .OperationOp (name = "myint.constant" , attributes = {"value" : cst }, types = [t ])
258+ op0 = pdl .OperationOp (
259+ name = "myint.constant" , attributes = {"value" : cst }, types = [t ]
260+ )
258261
259262 @pdl .rewrite ()
260263 def rew ():
261- expanded = pdl .apply_native_rewrite ([pdl .OperationType .get ()], "expand" , [cst ])
264+ expanded = pdl .apply_native_rewrite (
265+ [pdl .OperationType .get ()], "expand" , [cst ]
266+ )
262267 pdl .ReplaceOp (op0 , with_op = expanded )
263268
264269 def is_one (rewriter , results , values ):
265270 cst = values [0 ].value
266271 return cst <= 1
267-
272+
268273 def expand (rewriter , results , values ):
269274 cst = values [0 ].value
270275 c1 = cst // 2
271276 c2 = cst - c1
272277 with rewriter .ip ():
273- op1 = Operation .create ("myint.constant" , results = [i32 ], attributes = {"value" : IntegerAttr .get (i32 , c1 )})
274- op2 = Operation .create ("myint.constant" , results = [i32 ], attributes = {"value" : IntegerAttr .get (i32 , c2 )})
275- res = Operation .create ("myint.add" , results = [i32 ], operands = [op1 .result , op2 .result ])
278+ op1 = Operation .create (
279+ "myint.constant" ,
280+ results = [i32 ],
281+ attributes = {"value" : IntegerAttr .get (i32 , c1 )},
282+ )
283+ op2 = Operation .create (
284+ "myint.constant" ,
285+ results = [i32 ],
286+ attributes = {"value" : IntegerAttr .get (i32 , c2 )},
287+ )
288+ res = Operation .create (
289+ "myint.add" , results = [i32 ], operands = [op1 .result , op2 .result ]
290+ )
276291 results .append (res )
277292
278293 pdl_module = PDLModule (m )
0 commit comments