@@ -28,9 +28,18 @@ def to_muli(op, rewriter, pattern):
2828 new_op = arith .muli (op .operands [0 ], op .operands [1 ], loc = op .location )
2929 rewriter .replace_op (op , new_op .owner )
3030
31+ def constant_1_to_2 (op , rewriter , pattern ):
32+ c = op .attributes ["value" ].value
33+ if c != 1 :
34+ return True # failed to match
35+ with rewriter .ip :
36+ new_op = arith .constant (op .result .type , 2 , loc = op .location )
37+ rewriter .replace_op (op , [new_op ])
38+
3139 with Context ():
3240 patterns = RewritePatternSet ()
3341 patterns .add (arith .AddIOp , to_muli )
42+ patterns .add (arith .ConstantOp , constant_1_to_2 )
3443 frozen = patterns .freeze ()
3544
3645 module = ModuleOp .parse (
@@ -48,3 +57,21 @@ def to_muli(op, rewriter, pattern):
4857 # CHECK: %0 = arith.muli %arg0, %arg1 : i64
4958 # CHECK: return %0 : i64
5059 print (module )
60+
61+ module = ModuleOp .parse (
62+ r"""
63+ module {
64+ func.func @const() -> (i64, i64) {
65+ %0 = arith.constant 1 : i64
66+ %1 = arith.constant 3 : i64
67+ return %0, %1 : i64, i64
68+ }
69+ }
70+ """
71+ )
72+
73+ apply_patterns_and_fold_greedily (module , frozen )
74+ # CHECK: %c2_i64 = arith.constant 2 : i64
75+ # CHECK: %c3_i64 = arith.constant 3 : i64
76+ # CHECK: return %c2_i64, %c3_i64 : i64, i64
77+ print (module )
0 commit comments