From dce14244c82938dd3281a0db44b915e0df1e87b5 Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Wed, 18 Jun 2025 19:36:12 +0000 Subject: [PATCH] [mlir][arith] Fix multiplication canonicalizations The Arith dialect includes patterns that canonicalize a sequence of: - trunci(shrui(mul(sext(x), sext(y)), c)) -> mulsi_extended(x, y) - trunci(shrui(mul(zext(x), zext(y)), c)) -> mului_extended(x, y) These patterns return the high word of an extended multiplication, which assumes that the shift amount is equal to the bit width of the original operands. This check was missing, leading to incorrect canonicalizations when the shift amount was less than the bit width. For example, the following code: ```mlir %x = arith.extui %a: i32 to i33 %y = arith.extui %b: i32 to i33 %m = arith.muli %x, %y: i33 %c1 = arith.constant 1: i33 %sh = arith.shrui %m, %c1 : i33 %hi = arith.trunci %sh: i33 to i32 ``` would incorrectly be canonicalized to: ```mlir _, %hi = arith.mului_extended %a, %b : i32 ```` --- .../Dialect/Arith/IR/ArithCanonicalization.td | 14 ++++++-- mlir/test/Dialect/Arith/canonicalize.mlir | 32 ++++++++++++++++++- 2 files changed, 42 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td index 13eb97a910bd4..2f7beed549108 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td +++ b/mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td @@ -273,7 +273,7 @@ def RedundantSelectFalse : Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)), (SelectOp $pred, $a, $c)>; -// select(pred, false, true) => not(pred) +// select(pred, false, true) => not(pred) def SelectI1ToNot : Pat<(SelectOp $pred, (ConstantLikeMatcher ConstantAttr), @@ -376,6 +376,12 @@ def TruncationMatchesShiftAmount : CPred<"(getScalarOrElementWidth($0) - getScalarOrElementWidth($1)) == " "*getIntOrSplatIntValue($2)">]>>; +def ValueWidthMatchesShiftAmount : + Constraint, + CPred<"getScalarOrElementWidth($0) == " + "*getIntOrSplatIntValue($1)">]>>; + // trunci(extsi(x)) -> extsi(x), when only the sign-extension bits are truncated def TruncIExtSIToExtSI : Pat<(Arith_TruncIOp:$tr (Arith_ExtSIOp:$ext $x)), @@ -406,7 +412,8 @@ def TruncIShrUIMulIToMulSIExtended : (Arith_MulSIExtendedOp:$res__1 $x, $y), [(ValuesWithSameType $tr, $x, $y), (ValueWiderThan $mul, $x), - (TruncationMatchesShiftAmount $mul, $x, $c0)]>; + (TruncationMatchesShiftAmount $mul, $x, $c0), + (ValueWidthMatchesShiftAmount $x, $c0)]>; // trunci(shrui(mul(zext(x), zext(y)), c)) -> mului_extended(x, y) def TruncIShrUIMulIToMulUIExtended : @@ -417,7 +424,8 @@ def TruncIShrUIMulIToMulUIExtended : (Arith_MulUIExtendedOp:$res__1 $x, $y), [(ValuesWithSameType $tr, $x, $y), (ValueWiderThan $mul, $x), - (TruncationMatchesShiftAmount $mul, $x, $c0)]>; + (TruncationMatchesShiftAmount $mul, $x, $c0), + (ValueWidthMatchesShiftAmount $x, $c0)]>; //===----------------------------------------------------------------------===// // TruncIOp diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index b6188c81ff912..542603722ab8a 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -1000,7 +1000,7 @@ func.func @tripleAddAddOvf2(%arg0: index) -> index { // CHECK-LABEL: @foldSubXX_tensor -// CHECK: %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32> +// CHECK: %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32> // CHECK: %[[sub:.+]] = arith.subi // CHECK: return %[[c0]], %[[sub]] func.func @foldSubXX_tensor(%static : tensor<10xi32>, %dyn : tensor) -> (tensor<10xi32>, tensor) { @@ -2966,6 +2966,21 @@ func.func @wideMulToMulSIExtended(%a: i32, %b: i32) -> i32 { return %hi : i32 } +// Verify that the signed extended multiplication pattern does not match +// if the right shift does not match the bitwidth of the multipliers. + +// CHECK-LABEL: @wideMulToMulSIExtendedWithWrongShift +// CHECK-NOT: arith.mulsi_extended +func.func @wideMulToMulSIExtendedWithWrongShift(%a: i32, %b: i32) -> i32 { + %x = arith.extsi %a: i32 to i33 + %y = arith.extsi %b: i32 to i33 + %m = arith.muli %x, %y: i33 + %c1 = arith.constant 1: i33 + %sh = arith.shrui %m, %c1 : i33 + %hi = arith.trunci %sh: i33 to i32 + return %hi : i32 +} + // CHECK-LABEL: @wideMulToMulSIExtendedVector // CHECK-SAME: (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>) // CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mulsi_extended %[[A]], %[[B]] : vector<3xi32> @@ -2994,6 +3009,21 @@ func.func @wideMulToMulUIExtended(%a: i32, %b: i32) -> i32 { return %hi : i32 } +// Verify that the unsigned extended multiplication pattern does not match +// if the right shift does not match the bitwidth of the multipliers. + +// CHECK-LABEL: @wideMulToMulUIExtendedWithWrongShift +// CHECK-NOT: arith.mului_extended +func.func @wideMulToMulUIExtendedWithWrongShift(%a: i32, %b: i32) -> i32 { + %x = arith.extui %a: i32 to i33 + %y = arith.extui %b: i32 to i33 + %m = arith.muli %x, %y: i33 + %c1 = arith.constant 1: i33 + %sh = arith.shrui %m, %c1 : i33 + %hi = arith.trunci %sh: i33 to i32 + return %hi : i32 +} + // CHECK-LABEL: @wideMulToMulUIExtendedVector // CHECK-SAME: (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>) // CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mului_extended %[[A]], %[[B]] : vector<3xi32>