Skip to content

Commit dce1424

Browse files
committed
[mlir][arith] Fix multiplication canonicalizations
The Arith dialect includes patterns that canonicalize the following sequences:

- trunci(shrui(mul(sext(x), sext(y)), c)) -> mulsi_extended(x, y)
- trunci(shrui(mul(zext(x), zext(y)), c)) -> mului_extended(x, y)

These patterns return the high word of an extended multiplication, so the rewrite is only valid when the shift amount is equal to the bit width of the original operands. This check was missing, leading to incorrect canonicalizations when the shift amount was less than the bit width. For example, the following code:

```mlir
%x = arith.extui %a: i32 to i33
%y = arith.extui %b: i32 to i33
%m = arith.muli %x, %y: i33
%c1 = arith.constant 1: i33
%sh = arith.shrui %m, %c1 : i33
%hi = arith.trunci %sh: i33 to i32
```

would incorrectly be canonicalized to:

```mlir
_, %hi = arith.mului_extended %a, %b : i32
```

Here the existing truncation check passes (33 - 32 == 1, the shift amount), but `%hi` holds bits 1 through 32 of the product rather than its high 32 bits.
1 parent 17f5b8b commit dce1424

File tree

2 files changed

+42
-4
lines changed

2 files changed

+42
-4
lines changed

mlir/lib/Dialect/Arith/IR/ArithCanonicalization.td

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def RedundantSelectFalse :
273273
Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)),
274274
(SelectOp $pred, $a, $c)>;
275275

276-
// select(pred, false, true) => not(pred)
276+
// select(pred, false, true) => not(pred)
277277
def SelectI1ToNot :
278278
Pat<(SelectOp $pred,
279279
(ConstantLikeMatcher ConstantAttr<I1Attr, "0">),
@@ -376,6 +376,12 @@ def TruncationMatchesShiftAmount :
376376
CPred<"(getScalarOrElementWidth($0) - getScalarOrElementWidth($1)) == "
377377
"*getIntOrSplatIntValue($2)">]>>;
378378

379+
def ValueWidthMatchesShiftAmount :
380+
Constraint<And<[
381+
CPred<"succeeded(getIntOrSplatIntValue($1))">,
382+
CPred<"getScalarOrElementWidth($0) == "
383+
"*getIntOrSplatIntValue($1)">]>>;
384+
379385
// trunci(extsi(x)) -> extsi(x), when only the sign-extension bits are truncated
380386
def TruncIExtSIToExtSI :
381387
Pat<(Arith_TruncIOp:$tr (Arith_ExtSIOp:$ext $x)),
@@ -406,7 +412,8 @@ def TruncIShrUIMulIToMulSIExtended :
406412
(Arith_MulSIExtendedOp:$res__1 $x, $y),
407413
[(ValuesWithSameType $tr, $x, $y),
408414
(ValueWiderThan $mul, $x),
409-
(TruncationMatchesShiftAmount $mul, $x, $c0)]>;
415+
(TruncationMatchesShiftAmount $mul, $x, $c0),
416+
(ValueWidthMatchesShiftAmount $x, $c0)]>;
410417

411418
// trunci(shrui(mul(zext(x), zext(y)), c)) -> mului_extended(x, y)
412419
def TruncIShrUIMulIToMulUIExtended :
@@ -417,7 +424,8 @@ def TruncIShrUIMulIToMulUIExtended :
417424
(Arith_MulUIExtendedOp:$res__1 $x, $y),
418425
[(ValuesWithSameType $tr, $x, $y),
419426
(ValueWiderThan $mul, $x),
420-
(TruncationMatchesShiftAmount $mul, $x, $c0)]>;
427+
(TruncationMatchesShiftAmount $mul, $x, $c0),
428+
(ValueWidthMatchesShiftAmount $x, $c0)]>;
421429

422430
//===----------------------------------------------------------------------===//
423431
// TruncIOp

mlir/test/Dialect/Arith/canonicalize.mlir

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1000,7 +1000,7 @@ func.func @tripleAddAddOvf2(%arg0: index) -> index {
10001000

10011001

10021002
// CHECK-LABEL: @foldSubXX_tensor
1003-
// CHECK: %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32>
1003+
// CHECK: %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32>
10041004
// CHECK: %[[sub:.+]] = arith.subi
10051005
// CHECK: return %[[c0]], %[[sub]]
10061006
func.func @foldSubXX_tensor(%static : tensor<10xi32>, %dyn : tensor<?x?xi32>) -> (tensor<10xi32>, tensor<?x?xi32>) {
@@ -2966,6 +2966,21 @@ func.func @wideMulToMulSIExtended(%a: i32, %b: i32) -> i32 {
29662966
return %hi : i32
29672967
}
29682968

2969+
// Verify that the signed extended multiplication pattern does not match
2970+
// if the right shift does not match the bitwidth of the multipliers.
2971+
2972+
// CHECK-LABEL: @wideMulToMulSIExtendedWithWrongShift
2973+
// CHECK-NOT: arith.mulsi_extended
2974+
func.func @wideMulToMulSIExtendedWithWrongShift(%a: i32, %b: i32) -> i32 {
2975+
%x = arith.extsi %a: i32 to i33
2976+
%y = arith.extsi %b: i32 to i33
2977+
%m = arith.muli %x, %y: i33
2978+
%c1 = arith.constant 1: i33
2979+
%sh = arith.shrui %m, %c1 : i33
2980+
%hi = arith.trunci %sh: i33 to i32
2981+
return %hi : i32
2982+
}
2983+
29692984
// CHECK-LABEL: @wideMulToMulSIExtendedVector
29702985
// CHECK-SAME: (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>)
29712986
// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mulsi_extended %[[A]], %[[B]] : vector<3xi32>
@@ -2994,6 +3009,21 @@ func.func @wideMulToMulUIExtended(%a: i32, %b: i32) -> i32 {
29943009
return %hi : i32
29953010
}
29963011

3012+
// Verify that the unsigned extended multiplication pattern does not match
3013+
// if the right shift does not match the bitwidth of the multipliers.
3014+
3015+
// CHECK-LABEL: @wideMulToMulUIExtendedWithWrongShift
3016+
// CHECK-NOT: arith.mului_extended
3017+
func.func @wideMulToMulUIExtendedWithWrongShift(%a: i32, %b: i32) -> i32 {
3018+
%x = arith.extui %a: i32 to i33
3019+
%y = arith.extui %b: i32 to i33
3020+
%m = arith.muli %x, %y: i33
3021+
%c1 = arith.constant 1: i33
3022+
%sh = arith.shrui %m, %c1 : i33
3023+
%hi = arith.trunci %sh: i33 to i32
3024+
return %hi : i32
3025+
}
3026+
29973027
// CHECK-LABEL: @wideMulToMulUIExtendedVector
29983028
// CHECK-SAME: (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>)
29993029
// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mului_extended %[[A]], %[[B]] : vector<3xi32>

0 commit comments

Comments
 (0)