diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp index e36d4b910193e..03af61c81ae6c 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -326,7 +326,6 @@ void spirv::UMulExtendedOp::getCanonicalizationPatterns( // The transformation is only applied if one divisor is a multiple of the other. -// TODO(https://github.com/llvm/llvm-project/issues/63174): Add support for vector constants struct UModSimplification final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -336,19 +335,29 @@ struct UModSimplification final : OpRewritePattern { if (!prevUMod) return failure(); - IntegerAttr prevValue; - IntegerAttr currValue; + TypedAttr prevValue; + TypedAttr currValue; if (!matchPattern(prevUMod.getOperand(1), m_Constant(&prevValue)) || !matchPattern(umodOp.getOperand(1), m_Constant(&currValue))) return failure(); - APInt prevConstValue = prevValue.getValue(); - APInt currConstValue = currValue.getValue(); + // Ensure that previous divisor is a multiple of the current divisor. If + // not, fail the transformation. + bool isApplicable = false; + if (auto prevInt = dyn_cast(prevValue)) { + auto currInt = cast(currValue); + isApplicable = prevInt.getValue().urem(currInt.getValue()) == 0; + } else if (auto prevVec = dyn_cast(prevValue)) { + auto currVec = cast(currValue); + isApplicable = llvm::all_of(llvm::zip_equal(prevVec.getValues(), + currVec.getValues()), + [](const auto &pair) { + auto &[prev, curr] = pair; + return prev.urem(curr) == 0; + }); + } - // Ensure that one divisor is a multiple of the other. If not, fail the - // transformation. - if (prevConstValue.urem(currConstValue) != 0 && - currConstValue.urem(prevConstValue) != 0) + if (!isApplicable) return failure(); // The transformation is safe. Replace the existing UMod operation with a diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir index 0fd6c18a6c241..722c27586aa61 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir @@ -967,17 +967,17 @@ func.func @umod_fold(%arg0: i32) -> (i32, i32) { return %0, %1: i32, i32 } -// CHECK-LABEL: @umod_fail_vector_fold +// CHECK-LABEL: @umod_vector_fold // CHECK-SAME: (%[[ARG:.*]]: vector<4xi32>) -func.func @umod_fail_vector_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) { +func.func @umod_vector_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) { // CHECK: %[[CONST4:.*]] = spirv.Constant dense<4> : vector<4xi32> // CHECK: %[[CONST32:.*]] = spirv.Constant dense<32> : vector<4xi32> %const1 = spirv.Constant dense<32> : vector<4xi32> %0 = spirv.UMod %arg0, %const1 : vector<4xi32> - // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]] %const2 = spirv.Constant dense<4> : vector<4xi32> %1 = spirv.UMod %0, %const2 : vector<4xi32> - // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST4]] + // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]] + // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[ARG]], %[[CONST4]] // CHECK: return %[[UMOD0]], %[[UMOD1]] return %0, %1: vector<4xi32>, vector<4xi32> } @@ -996,9 +996,9 @@ func.func @umod_fold_same_divisor(%arg0: i32) -> (i32, i32) { return %0, %1: i32, i32 } -// CHECK-LABEL: @umod_fail_fold +// CHECK-LABEL: @umod_fail_1_fold // CHECK-SAME: (%[[ARG:.*]]: i32) -func.func @umod_fail_fold(%arg0: i32) -> (i32, i32) { +func.func @umod_fail_1_fold(%arg0: i32) -> (i32, i32) { // CHECK: %[[CONST5:.*]] = spirv.Constant 5 // CHECK: %[[CONST32:.*]] = spirv.Constant 32 %const1 = spirv.Constant 32 : i32 @@ -1011,6 +1011,51 @@ func.func @umod_fail_fold(%arg0: i32) -> (i32, i32) { return %0, %1: i32, i32 } +// CHECK-LABEL: @umod_fail_2_fold +// CHECK-SAME: (%[[ARG:.*]]: i32) +func.func @umod_fail_2_fold(%arg0: i32) -> (i32, i32) { + // CHECK: %[[CONST32:.*]] = spirv.Constant 32 + // CHECK: %[[CONST4:.*]] = spirv.Constant 4 + %const1 = spirv.Constant 4 : i32 + %0 = spirv.UMod %arg0, %const1 : i32 + // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST4]] + %const2 = spirv.Constant 32 : i32 + %1 = spirv.UMod %0, %const2 : i32 + // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST32]] + // CHECK: return %[[UMOD0]], %[[UMOD1]] + return %0, %1: i32, i32 +} + +// CHECK-LABEL: @umod_vector_fail_1_fold +// CHECK-SAME: (%[[ARG:.*]]: vector<4xi32>) +func.func @umod_vector_fail_1_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) { + // CHECK: %[[CONST9:.*]] = spirv.Constant dense<9> : vector<4xi32> + // CHECK: %[[CONST64:.*]] = spirv.Constant dense<64> : vector<4xi32> + %const1 = spirv.Constant dense<64> : vector<4xi32> + %0 = spirv.UMod %arg0, %const1 : vector<4xi32> + // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST64]] + %const2 = spirv.Constant dense<9> : vector<4xi32> + %1 = spirv.UMod %0, %const2 : vector<4xi32> + // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST9]] + // CHECK: return %[[UMOD0]], %[[UMOD1]] + return %0, %1: vector<4xi32>, vector<4xi32> +} + +// CHECK-LABEL: @umod_vector_fail_2_fold +// CHECK-SAME: (%[[ARG:.*]]: vector<4xi32>) +func.func @umod_vector_fail_2_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) { + // CHECK: %[[CONST32:.*]] = spirv.Constant dense<32> : vector<4xi32> + // CHECK: %[[CONST4:.*]] = spirv.Constant dense<4> : vector<4xi32> + %const1 = spirv.Constant dense<4> : vector<4xi32> + %0 = spirv.UMod %arg0, %const1 : vector<4xi32> + // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST4]] + %const2 = spirv.Constant dense<32> : vector<4xi32> + %1 = spirv.UMod %0, %const2 : vector<4xi32> + // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST32]] + // CHECK: return %[[UMOD0]], %[[UMOD1]] + return %0, %1: vector<4xi32>, vector<4xi32> +} + // ----- //===----------------------------------------------------------------------===//