Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,6 @@ void spirv::UMulExtendedOp::getCanonicalizationPatterns(

// The transformation is only applied if one divisor is a multiple of the other.

// TODO(https://github.com/llvm/llvm-project/issues/63174): Add support for vector constants
struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
using OpRewritePattern::OpRewritePattern;

Expand All @@ -336,19 +335,28 @@ struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
if (!prevUMod)
return failure();

IntegerAttr prevValue;
IntegerAttr currValue;
TypedAttr prevValue;
TypedAttr currValue;
if (!matchPattern(prevUMod.getOperand(1), m_Constant(&prevValue)) ||
!matchPattern(umodOp.getOperand(1), m_Constant(&currValue)))
return failure();

APInt prevConstValue = prevValue.getValue();
APInt currConstValue = currValue.getValue();
// Ensure that previous divisor is a multiple of the current divisor. If
// not, fail the transformation.
bool isApplicable = false;
if (auto prevInt = dyn_cast<IntegerAttr>(prevValue)) {
auto currInt = dyn_cast<IntegerAttr>(currValue);
isApplicable = prevInt.getValue().urem(currInt.getValue()) == 0;
} else if (auto prevVec = dyn_cast<DenseElementsAttr>(prevValue)) {
auto currVec = dyn_cast<DenseElementsAttr>(currValue);
isApplicable = llvm::all_of(
llvm::zip(prevVec.getValues<APInt>(), currVec.getValues<APInt>()),
[](const auto &pair) {
return std::get<0>(pair).urem(std::get<1>(pair)) == 0;
});
}

// Ensure that one divisor is a multiple of the other. If not, fail the
// transformation.
if (prevConstValue.urem(currConstValue) != 0 &&
currConstValue.urem(prevConstValue) != 0)
if (!isApplicable)
return failure();

// The transformation is safe. Replace the existing UMod operation with a
Expand Down
57 changes: 51 additions & 6 deletions mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -967,17 +967,17 @@ func.func @umod_fold(%arg0: i32) -> (i32, i32) {
return %0, %1: i32, i32
}

// CHECK-LABEL: @umod_fail_vector_fold
// CHECK-LABEL: @umod_vector_fold
// CHECK-SAME: (%[[ARG:.*]]: vector<4xi32>)
func.func @umod_fail_vector_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
func.func @umod_vector_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
// CHECK: %[[CONST4:.*]] = spirv.Constant dense<4> : vector<4xi32>
// CHECK: %[[CONST32:.*]] = spirv.Constant dense<32> : vector<4xi32>
%const1 = spirv.Constant dense<32> : vector<4xi32>
%0 = spirv.UMod %arg0, %const1 : vector<4xi32>
// CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]]
%const2 = spirv.Constant dense<4> : vector<4xi32>
%1 = spirv.UMod %0, %const2 : vector<4xi32>
// CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST4]]
// CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]]
// CHECK: %[[UMOD1:.*]] = spirv.UMod %[[ARG]], %[[CONST4]]
// CHECK: return %[[UMOD0]], %[[UMOD1]]
return %0, %1: vector<4xi32>, vector<4xi32>
}
Expand All @@ -996,9 +996,9 @@ func.func @umod_fold_same_divisor(%arg0: i32) -> (i32, i32) {
return %0, %1: i32, i32
}

// CHECK-LABEL: @umod_fail_fold
// CHECK-LABEL: @umod_fail_1_fold
// CHECK-SAME: (%[[ARG:.*]]: i32)
func.func @umod_fail_fold(%arg0: i32) -> (i32, i32) {
func.func @umod_fail_1_fold(%arg0: i32) -> (i32, i32) {
// CHECK: %[[CONST5:.*]] = spirv.Constant 5
// CHECK: %[[CONST32:.*]] = spirv.Constant 32
%const1 = spirv.Constant 32 : i32
Expand All @@ -1011,6 +1011,51 @@ func.func @umod_fail_fold(%arg0: i32) -> (i32, i32) {
return %0, %1: i32, i32
}

// CHECK-LABEL: @umod_fail_2_fold
// CHECK-SAME: (%[[ARG:.*]]: i32)
func.func @umod_fail_2_fold(%arg0: i32) -> (i32, i32) {
// CHECK: %[[CONST32:.*]] = spirv.Constant 32
// CHECK: %[[CONST4:.*]] = spirv.Constant 4
%const1 = spirv.Constant 4 : i32
%0 = spirv.UMod %arg0, %const1 : i32
// CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST4]]
%const2 = spirv.Constant 32 : i32
%1 = spirv.UMod %0, %const2 : i32
// CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST32]]
// CHECK: return %[[UMOD0]], %[[UMOD1]]
return %0, %1: i32, i32
}

// CHECK-LABEL: @umod_vector_fail_1_fold
// CHECK-SAME: (%[[ARG:.*]]: vector<4xi32>)
func.func @umod_vector_fail_1_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
// CHECK: %[[CONST9:.*]] = spirv.Constant dense<9> : vector<4xi32>
// CHECK: %[[CONST64:.*]] = spirv.Constant dense<64> : vector<4xi32>
%const1 = spirv.Constant dense<64> : vector<4xi32>
%0 = spirv.UMod %arg0, %const1 : vector<4xi32>
// CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST64]]
%const2 = spirv.Constant dense<9> : vector<4xi32>
%1 = spirv.UMod %0, %const2 : vector<4xi32>
// CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST9]]
// CHECK: return %[[UMOD0]], %[[UMOD1]]
return %0, %1: vector<4xi32>, vector<4xi32>
}

// CHECK-LABEL: @umod_vector_fail_2_fold
// CHECK-SAME: (%[[ARG:.*]]: vector<4xi32>)
func.func @umod_vector_fail_2_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
// CHECK: %[[CONST32:.*]] = spirv.Constant dense<32> : vector<4xi32>
// CHECK: %[[CONST4:.*]] = spirv.Constant dense<4> : vector<4xi32>
%const1 = spirv.Constant dense<4> : vector<4xi32>
%0 = spirv.UMod %arg0, %const1 : vector<4xi32>
// CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST4]]
%const2 = spirv.Constant dense<32> : vector<4xi32>
%1 = spirv.UMod %0, %const2 : vector<4xi32>
// CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST32]]
// CHECK: return %[[UMOD0]], %[[UMOD1]]
return %0, %1: vector<4xi32>, vector<4xi32>
}

// -----

//===----------------------------------------------------------------------===//
Expand Down
Loading