Skip to content

Commit 891d5ab

Browse files
committed
[mlir][spirv] Implement UMod canonicalization for vector constants
1 parent b7f5950 commit 891d5ab

File tree

2 files changed

+39
-15
lines changed

2 files changed

+39
-15
lines changed

mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,6 @@ void spirv::UMulExtendedOp::getCanonicalizationPatterns(
326326

327327
// The transformation is only applied if one divisor is a multiple of the other.
328328

329-
// TODO(https://github.com/llvm/llvm-project/issues/63174): Add support for vector constants
330329
struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
331330
using OpRewritePattern::OpRewritePattern;
332331

@@ -336,19 +335,29 @@ struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
336335
if (!prevUMod)
337336
return failure();
338337

339-
IntegerAttr prevValue;
340-
IntegerAttr currValue;
338+
TypedAttr prevValue;
339+
TypedAttr currValue;
341340
if (!matchPattern(prevUMod.getOperand(1), m_Constant(&prevValue)) ||
342341
!matchPattern(umodOp.getOperand(1), m_Constant(&currValue)))
343342
return failure();
344343

345-
APInt prevConstValue = prevValue.getValue();
346-
APInt currConstValue = currValue.getValue();
344+
// Ensure that previous divisor is a multiple of the current divisor. If
345+
// not, fail the transformation.
346+
bool isApplicable = false;
347+
if (auto prevInt = dyn_cast<IntegerAttr>(prevValue)) {
348+
auto currInt = dyn_cast<IntegerAttr>(currValue);
349+
isApplicable = prevInt.getValue().urem(currInt.getValue()) == 0;
350+
} else if (auto prevVec = dyn_cast<DenseElementsAttr>(prevValue)) {
351+
auto currVec = dyn_cast<DenseElementsAttr>(currValue);
352+
isApplicable = llvm::all_of(
353+
llvm::zip(prevVec.getValues<APInt>(), currVec.getValues<APInt>()),
354+
[](const auto &pair) {
355+
const auto &[a, b] = pair;
356+
return a.urem(b) == 0;
357+
});
358+
}
347359

348-
// Ensure that one divisor is a multiple of the other. If not, fail the
349-
// transformation.
350-
if (prevConstValue.urem(currConstValue) != 0 &&
351-
currConstValue.urem(prevConstValue) != 0)
360+
if (!isApplicable)
352361
return failure();
353362

354363
// The transformation is safe. Replace the existing UMod operation with a

mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -967,17 +967,17 @@ func.func @umod_fold(%arg0: i32) -> (i32, i32) {
967967
return %0, %1: i32, i32
968968
}
969969

970-
// CHECK-LABEL: @umod_fail_vector_fold
970+
// CHECK-LABEL: @umod_vector_fold
971971
// CHECK-SAME: (%[[ARG:.*]]: vector<4xi32>)
972-
func.func @umod_fail_vector_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
972+
func.func @umod_vector_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
973973
// CHECK: %[[CONST4:.*]] = spirv.Constant dense<4> : vector<4xi32>
974974
// CHECK: %[[CONST32:.*]] = spirv.Constant dense<32> : vector<4xi32>
975975
%const1 = spirv.Constant dense<32> : vector<4xi32>
976976
%0 = spirv.UMod %arg0, %const1 : vector<4xi32>
977-
// CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]]
978977
%const2 = spirv.Constant dense<4> : vector<4xi32>
979978
%1 = spirv.UMod %0, %const2 : vector<4xi32>
980-
// CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST4]]
979+
// CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]]
980+
// CHECK: %[[UMOD1:.*]] = spirv.UMod %[[ARG]], %[[CONST4]]
981981
// CHECK: return %[[UMOD0]], %[[UMOD1]]
982982
return %0, %1: vector<4xi32>, vector<4xi32>
983983
}
@@ -996,9 +996,9 @@ func.func @umod_fold_same_divisor(%arg0: i32) -> (i32, i32) {
996996
return %0, %1: i32, i32
997997
}
998998

999-
// CHECK-LABEL: @umod_fail_fold
999+
// CHECK-LABEL: @umod_fail_1_fold
10001000
// CHECK-SAME: (%[[ARG:.*]]: i32)
1001-
func.func @umod_fail_fold(%arg0: i32) -> (i32, i32) {
1001+
func.func @umod_fail_1_fold(%arg0: i32) -> (i32, i32) {
10021002
// CHECK: %[[CONST5:.*]] = spirv.Constant 5
10031003
// CHECK: %[[CONST32:.*]] = spirv.Constant 32
10041004
%const1 = spirv.Constant 32 : i32
@@ -1011,6 +1011,21 @@ func.func @umod_fail_fold(%arg0: i32) -> (i32, i32) {
10111011
return %0, %1: i32, i32
10121012
}
10131013

1014+
// CHECK-LABEL: @umod_fail_2_fold
1015+
// CHECK-SAME: (%[[ARG:.*]]: vector<4xi32>)
1016+
func.func @umod_fail_2_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
1017+
// CHECK: %[[CONST32:.*]] = spirv.Constant dense<32> : vector<4xi32>
1018+
// CHECK: %[[CONST4:.*]] = spirv.Constant dense<4> : vector<4xi32>
1019+
%const1 = spirv.Constant dense<4> : vector<4xi32>
1020+
%0 = spirv.UMod %arg0, %const1 : vector<4xi32>
1021+
// CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST4]]
1022+
%const2 = spirv.Constant dense<32> : vector<4xi32>
1023+
%1 = spirv.UMod %0, %const2 : vector<4xi32>
1024+
// CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST32]]
1025+
// CHECK: return %[[UMOD0]], %[[UMOD1]]
1026+
return %0, %1: vector<4xi32>, vector<4xi32>
1027+
}
1028+
10141029
// -----
10151030

10161031
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)