From 891d5abf14c9b0d18e7af6a242ee80f2ea6a21ff Mon Sep 17 00:00:00 2001 From: fairywreath Date: Thu, 29 May 2025 02:13:25 -0400 Subject: [PATCH 1/4] [mlir][spirv] Implement UMod canonicalization for vector constants --- .../SPIRV/IR/SPIRVCanonicalization.cpp | 27 ++++++++++++------- .../SPIRV/Transforms/canonicalize.mlir | 27 ++++++++++++++----- 2 files changed, 39 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp index e36d4b910193e..837932efc72ba 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -326,7 +326,6 @@ void spirv::UMulExtendedOp::getCanonicalizationPatterns( // The transformation is only applied if one divisor is a multiple of the other. -// TODO(https://github.com/llvm/llvm-project/issues/63174): Add support for vector constants struct UModSimplification final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -336,19 +335,29 @@ struct UModSimplification final : OpRewritePattern { if (!prevUMod) return failure(); - IntegerAttr prevValue; - IntegerAttr currValue; + TypedAttr prevValue; + TypedAttr currValue; if (!matchPattern(prevUMod.getOperand(1), m_Constant(&prevValue)) || !matchPattern(umodOp.getOperand(1), m_Constant(&currValue))) return failure(); - APInt prevConstValue = prevValue.getValue(); - APInt currConstValue = currValue.getValue(); + // Ensure that previous divisor is a multiple of the current divisor. If + // not, fail the transformation. + bool isApplicable = false; + if (auto prevInt = dyn_cast(prevValue)) { + auto currInt = dyn_cast(currValue); + isApplicable = prevInt.getValue().urem(currInt.getValue()) == 0; + } else if (auto prevVec = dyn_cast(prevValue)) { + auto currVec = dyn_cast(currValue); + isApplicable = llvm::all_of( + llvm::zip(prevVec.getValues(), currVec.getValues()), + [](const auto &pair) { + const auto &[a, b] = pair; + return a.urem(b) == 0; + }); + } - // Ensure that one divisor is a multiple of the other. If not, fail the - // transformation. - if (prevConstValue.urem(currConstValue) != 0 && - currConstValue.urem(prevConstValue) != 0) + if (!isApplicable) return failure(); // The transformation is safe. Replace the existing UMod operation with a diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir index 0fd6c18a6c241..52c915bfebc66 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir @@ -967,17 +967,17 @@ func.func @umod_fold(%arg0: i32) -> (i32, i32) { return %0, %1: i32, i32 } -// CHECK-LABEL: @umod_fail_vector_fold +// CHECK-LABEL: @umod_vector_fold // CHECK-SAME: (%[[ARG:.*]]: vector<4xi32>) -func.func @umod_fail_vector_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) { +func.func @umod_vector_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) { // CHECK: %[[CONST4:.*]] = spirv.Constant dense<4> : vector<4xi32> // CHECK: %[[CONST32:.*]] = spirv.Constant dense<32> : vector<4xi32> %const1 = spirv.Constant dense<32> : vector<4xi32> %0 = spirv.UMod %arg0, %const1 : vector<4xi32> - // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]] %const2 = spirv.Constant dense<4> : vector<4xi32> %1 = spirv.UMod %0, %const2 : vector<4xi32> - // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST4]] + // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]] + // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[ARG]], %[[CONST4]] // CHECK: return %[[UMOD0]], %[[UMOD1]] return %0, %1: vector<4xi32>, vector<4xi32> } @@ -996,9 +996,9 @@ func.func @umod_fold_same_divisor(%arg0: i32) -> (i32, i32) { return %0, %1: i32, i32 } -// CHECK-LABEL: @umod_fail_fold +// CHECK-LABEL: @umod_fail_1_fold // CHECK-SAME: (%[[ARG:.*]]: i32) -func.func @umod_fail_fold(%arg0: i32) -> (i32, i32) { +func.func @umod_fail_1_fold(%arg0: i32) -> (i32, i32) { // CHECK: %[[CONST5:.*]] = spirv.Constant 5 // CHECK: %[[CONST32:.*]] = spirv.Constant 32 %const1 = spirv.Constant 32 : i32 @@ -1011,6 +1011,21 @@ func.func @umod_fail_fold(%arg0: i32) -> (i32, i32) { return %0, %1: i32, i32 } +// CHECK-LABEL: @umod_fail_2_fold +// CHECK-SAME: (%[[ARG:.*]]: vector<4xi32>) +func.func @umod_fail_2_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) { + // CHECK: %[[CONST32:.*]] = spirv.Constant dense<32> : vector<4xi32> + // CHECK: %[[CONST4:.*]] = spirv.Constant dense<4> : vector<4xi32> + %const1 = spirv.Constant dense<4> : vector<4xi32> + %0 = spirv.UMod %arg0, %const1 : vector<4xi32> + // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST4]] + %const2 = spirv.Constant dense<32> : vector<4xi32> + %1 = spirv.UMod %0, %const2 : vector<4xi32> + // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST32]] + // CHECK: return %[[UMOD0]], %[[UMOD1]] + return %0, %1: vector<4xi32>, vector<4xi32> +} + // ----- //===----------------------------------------------------------------------===// From 2d294b2eeef68bd74ef9583eb55b9e31bcb22e41 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Sun, 1 Jun 2025 00:56:48 -0400 Subject: [PATCH 2/4] Add more test and improve code --- .../Dialect/SPIRV/IR/SPIRVCanonicalization.cpp | 3 +-- .../Dialect/SPIRV/Transforms/canonicalize.mlir | 17 ++++++++++++++++- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp index 837932efc72ba..21cc742da0d9c 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -352,8 +352,7 @@ struct UModSimplification final : OpRewritePattern { isApplicable = llvm::all_of( llvm::zip(prevVec.getValues(), currVec.getValues()), [](const auto &pair) { - const auto &[a, b] = pair; - return a.urem(b) == 0; + return std::get<0>(pair).urem(std::get<1>(pair)) == 0; }); } diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir index 52c915bfebc66..60625c8d354dd 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir @@ -1012,8 +1012,23 @@ func.func @umod_fail_1_fold(%arg0: i32) -> (i32, i32) { } // CHECK-LABEL: @umod_fail_2_fold +// CHECK-SAME: (%[[ARG:.*]]: i32) +func.func @umod_fail_2_fold(%arg0: i32) -> (i32, i32) { + // CHECK: %[[CONST32:.*]] = spirv.Constant 32 + // CHECK: %[[CONST4:.*]] = spirv.Constant 4 + %const1 = spirv.Constant 4 : i32 + %0 = spirv.UMod %arg0, %const1 : i32 + // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST4]] + %const2 = spirv.Constant 32 : i32 + %1 = spirv.UMod %0, %const2 : i32 + // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST32]] + // CHECK: return %[[UMOD0]], %[[UMOD1]] + return %0, %1: i32, i32 +} + +// CHECK-LABEL: @umod_vec_fail_fold // CHECK-SAME: (%[[ARG:.*]]: vector<4xi32>) -func.func @umod_fail_2_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) { +func.func @umod_vec_fail_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) { // CHECK: %[[CONST32:.*]] = spirv.Constant dense<32> : vector<4xi32> // CHECK: %[[CONST4:.*]] = spirv.Constant dense<4> : vector<4xi32> %const1 = spirv.Constant dense<4> : vector<4xi32> From 75a25e4ddfbeba3933e01dcfdd5671b88edce0f0 Mon Sep 17 00:00:00 2001 From: fairywreath Date: Sun, 1 Jun 2025 01:02:27 -0400 Subject: [PATCH 3/4] Add one more test --- .../SPIRV/Transforms/canonicalize.mlir | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir index 60625c8d354dd..722c27586aa61 100644 --- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir +++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir @@ -1026,9 +1026,24 @@ func.func @umod_fail_2_fold(%arg0: i32) -> (i32, i32) { return %0, %1: i32, i32 } -// CHECK-LABEL: @umod_vec_fail_fold +// CHECK-LABEL: @umod_vector_fail_1_fold // CHECK-SAME: (%[[ARG:.*]]: vector<4xi32>) -func.func @umod_vec_fail_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) { +func.func @umod_vector_fail_1_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) { + // CHECK: %[[CONST9:.*]] = spirv.Constant dense<9> : vector<4xi32> + // CHECK: %[[CONST64:.*]] = spirv.Constant dense<64> : vector<4xi32> + %const1 = spirv.Constant dense<64> : vector<4xi32> + %0 = spirv.UMod %arg0, %const1 : vector<4xi32> + // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST64]] + %const2 = spirv.Constant dense<9> : vector<4xi32> + %1 = spirv.UMod %0, %const2 : vector<4xi32> + // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST9]] + // CHECK: return %[[UMOD0]], %[[UMOD1]] + return %0, %1: vector<4xi32>, vector<4xi32> +} + +// CHECK-LABEL: @umod_vector_fail_2_fold +// CHECK-SAME: (%[[ARG:.*]]: vector<4xi32>) +func.func @umod_vector_fail_2_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) { // CHECK: %[[CONST32:.*]] = spirv.Constant dense<32> : vector<4xi32> // CHECK: %[[CONST4:.*]] = spirv.Constant dense<4> : vector<4xi32> %const1 = spirv.Constant dense<4> : vector<4xi32> From 9835e16b3c85cc5b3f1c09a69fe9d832430e0c6c Mon Sep 17 00:00:00 2001 From: fairywreath Date: Wed, 4 Jun 2025 02:40:53 -0400 Subject: [PATCH 4/4] Address review comments --- .../Dialect/SPIRV/IR/SPIRVCanonicalization.cpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp index 21cc742da0d9c..03af61c81ae6c 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -345,15 +345,16 @@ struct UModSimplification final : OpRewritePattern { // not, fail the transformation. bool isApplicable = false; if (auto prevInt = dyn_cast(prevValue)) { - auto currInt = dyn_cast(currValue); + auto currInt = cast(currValue); isApplicable = prevInt.getValue().urem(currInt.getValue()) == 0; } else if (auto prevVec = dyn_cast(prevValue)) { - auto currVec = dyn_cast(currValue); - isApplicable = llvm::all_of( - llvm::zip(prevVec.getValues(), currVec.getValues()), - [](const auto &pair) { - return std::get<0>(pair).urem(std::get<1>(pair)) == 0; - }); + auto currVec = cast(currValue); + isApplicable = llvm::all_of(llvm::zip_equal(prevVec.getValues(), + currVec.getValues()), + [](const auto &pair) { + auto &[prev, curr] = pair; + return prev.urem(curr) == 0; + }); } if (!isApplicable)