@@ -121,8 +121,10 @@ def load_myint_dialect():
121121
122122
123123# This PDL pattern is to fold constant additions,
124- # i.e. add(constant0, constant1) -> constant2
125- # where constant2 = constant0 + constant1.
124+ # including two patterns:
125+ # 1. add(constant0, constant1) -> constant2
126+ # where constant2 = constant0 + constant1;
127+ # 2. add(x, 0) or add(0, x) -> x.
126128def get_pdl_pattern_fold ():
127129 m = Module .create ()
128130 i32 = IntegerType .get_signless (32 )
@@ -237,3 +239,73 @@ def test_pdl_register_function_constraint(module_):
237239 apply_patterns_and_fold_greedily (module_ , frozen )
238240
239241 return module_
242+
243+
244+ # This pattern is to expand constant to additions
245+ # unless the constant is no more than 1,
246+ # e.g. 3 -> 1 + 2 -> 1 + (1 + 1).
247+ def get_pdl_pattern_expand ():
248+ m = Module .create ()
249+ i32 = IntegerType .get_signless (32 )
250+ with InsertionPoint (m .body ):
251+
252+ @pdl .pattern (benefit = 1 , sym_name = "myint_constant_expand" )
253+ def pat ():
254+ t = pdl .TypeOp (i32 )
255+ cst = pdl .AttributeOp ()
256+ pdl .apply_native_constraint ([], "is_one" , [cst ])
257+ op0 = pdl .OperationOp (name = "myint.constant" , attributes = {"value" : cst }, types = [t ])
258+
259+ @pdl .rewrite ()
260+ def rew ():
261+ expanded = pdl .apply_native_rewrite ([pdl .OperationType .get ()], "expand" , [cst ])
262+ pdl .ReplaceOp (op0 , with_op = expanded )
263+
264+ def is_one (rewriter , results , values ):
265+ cst = values [0 ].value
266+ return cst <= 1
267+
268+ def expand (rewriter , results , values ):
269+ cst = values [0 ].value
270+ c1 = cst // 2
271+ c2 = cst - c1
272+ with rewriter .ip ():
273+ op1 = Operation .create ("myint.constant" , results = [i32 ], attributes = {"value" : IntegerAttr .get (i32 , c1 )})
274+ op2 = Operation .create ("myint.constant" , results = [i32 ], attributes = {"value" : IntegerAttr .get (i32 , c2 )})
275+ res = Operation .create ("myint.add" , results = [i32 ], operands = [op1 .result , op2 .result ])
276+ results .append (res )
277+
278+ pdl_module = PDLModule (m )
279+ pdl_module .register_constraint_function ("is_one" , is_one )
280+ pdl_module .register_rewrite_function ("expand" , expand )
281+ return pdl_module .freeze ()
282+
283+
284+ # CHECK-LABEL: TEST: test_pdl_register_function_expand
285+ # CHECK: %0 = "myint.constant"() {value = 1 : i32} : () -> i32
286+ # CHECK: %1 = "myint.constant"() {value = 1 : i32} : () -> i32
287+ # CHECK: %2 = "myint.add"(%0, %1) : (i32, i32) -> i32
288+ # CHECK: %3 = "myint.constant"() {value = 1 : i32} : () -> i32
289+ # CHECK: %4 = "myint.constant"() {value = 1 : i32} : () -> i32
290+ # CHECK: %5 = "myint.constant"() {value = 1 : i32} : () -> i32
291+ # CHECK: %6 = "myint.add"(%4, %5) : (i32, i32) -> i32
292+ # CHECK: %7 = "myint.add"(%3, %6) : (i32, i32) -> i32
293+ # CHECK: %8 = "myint.add"(%2, %7) : (i32, i32) -> i32
294+ # CHECK: return %8 : i32
295+ @construct_and_print_in_module
296+ def test_pdl_register_function_expand (module_ ):
297+ load_myint_dialect ()
298+
299+ module_ = Module .parse (
300+ """
301+ func.func @f() -> i32 {
302+ %0 = "myint.constant"() { value = 5 }: () -> (i32)
303+ return %0 : i32
304+ }
305+ """
306+ )
307+
308+ frozen = get_pdl_pattern_expand ()
309+ apply_patterns_and_fold_greedily (module_ , frozen )
310+
311+ return module_
0 commit comments