@@ -28,18 +28,18 @@ def make_pdl_module():
2828 @pdl .pattern (benefit = 1 , sym_name = "addi_to_mul" )
2929 def pat ():
3030 # Match arith.addi with index types.
31- index_type = pdl .TypeOp (IndexType . get ( ))
32- operand0 = pdl .OperandOp (index_type )
33- operand1 = pdl .OperandOp (index_type )
31+ i64_type = pdl .TypeOp (IntegerType . get_signless ( 64 ))
32+ operand0 = pdl .OperandOp (i64_type )
33+ operand1 = pdl .OperandOp (i64_type )
3434 op0 = pdl .OperationOp (
35- name = "arith.addi" , args = [operand0 , operand1 ], types = [index_type ]
35+ name = "arith.addi" , args = [operand0 , operand1 ], types = [i64_type ]
3636 )
3737
3838 # Replace the matched op with arith.muli.
3939 @pdl .rewrite ()
4040 def rew ():
4141 newOp = pdl .OperationOp (
42- name = "arith.muli" , args = [operand0 , operand1 ], types = [index_type ]
42+ name = "arith.muli" , args = [operand0 , operand1 ], types = [i64_type ]
4343 )
4444 pdl .ReplaceOp (op0 , with_op = newOp )
4545
@@ -63,17 +63,21 @@ def run(self, m):
6363 module = ModuleOp .parse (
6464 r"""
6565 module {
66- func.func @add(%a: index , %b: index ) -> index {
67- %sum = arith.addi %a, %b : index
68- return %sum : index
66+ func.func @add(%a: i64 , %b: i64 ) -> i64 {
67+ %sum = arith.addi %a, %b : i64
68+ return %sum : i64
6969 }
7070 }
7171 """
7272 )
7373
74- # CHECK-LABEL: Dump After CustomPass
75- # CHECK: arith.muli
7674 pm = PassManager ("any" )
7775 pm .enable_ir_printing ()
76+
77+ # CHECK-LABEL: Dump After CustomPass
78+ # CHECK: arith.muli
7879 pm .add (CustomPass ())
80+ # CHECK-LABEL: Dump After ArithToLLVMConversionPass
81+ # CHECK: llvm.mul
82+ pm .add ("convert-arith-to-llvm" )
7983 pm .run (module )
0 commit comments