Skip to content

Commit 0cbf12a

Browse files
committed
address Jakub's comments
Signed-off-by: James Newling <[email protected]>
1 parent d0ae5b5 commit 0cbf12a

File tree

2 files changed

+63
-47
lines changed

2 files changed

+63
-47
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 48 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -7629,62 +7629,64 @@ struct StepCompareFolder : public OpRewritePattern<StepOp> {
76297629
const int64_t stepSize = stepOp.getResult().getType().getNumElements();
76307630

76317631
for (auto &use : stepOp.getResult().getUses()) {
7632-
if (auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner())) {
7633-
const unsigned stepOperandNumber = use.getOperandNumber();
7634-
7635-
// arith.cmpi canonicalizer makes constants final operands.
7636-
if (stepOperandNumber != 0)
7637-
continue;
7638-
7639-
// Check that operand 1 is a constant.
7640-
unsigned constOperandNumber = 1;
7641-
Value otherOperand = cmpiOp.getOperand(constOperandNumber);
7642-
auto maybeConstValue = getConstantIntValue(otherOperand);
7643-
if (!maybeConstValue.has_value())
7644-
continue;
7632+
auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
7633+
if (!cmpiOp)
7634+
continue;
76457635

7646-
int64_t constValue = maybeConstValue.value();
7647-
arith::CmpIPredicate pred = cmpiOp.getPredicate();
7636+
// arith.cmpi canonicalizer makes constants final operands.
7637+
const unsigned stepOperandNumber = use.getOperandNumber();
7638+
if (stepOperandNumber != 0)
7639+
continue;
76487640

7649-
auto maybeSplat = [&]() -> std::optional<bool> {
7650-
// Handle ult (unsigned less than) and uge (unsigned greater equal).
7651-
if ((pred == arith::CmpIPredicate::ult ||
7652-
pred == arith::CmpIPredicate::uge) &&
7653-
stepSize <= constValue)
7654-
return pred == arith::CmpIPredicate::ult;
7641+
// Check that operand 1 is a constant.
7642+
unsigned constOperandNumber = 1;
7643+
Value otherOperand = cmpiOp.getOperand(constOperandNumber);
7644+
auto maybeConstValue = getConstantIntValue(otherOperand);
7645+
if (!maybeConstValue.has_value())
7646+
continue;
76557647

7656-
// Handle ule and ugt.
7657-
if ((pred == arith::CmpIPredicate::ule ||
7658-
pred == arith::CmpIPredicate::ugt) &&
7659-
stepSize <= constValue + 1)
7660-
return pred == arith::CmpIPredicate::ule;
7648+
int64_t constValue = maybeConstValue.value();
7649+
arith::CmpIPredicate pred = cmpiOp.getPredicate();
7650+
7651+
auto maybeSplat = [&]() -> std::optional<bool> {
7652+
// Handle ult (unsigned less than) and uge (unsigned greater equal).
7653+
if ((pred == arith::CmpIPredicate::ult ||
7654+
pred == arith::CmpIPredicate::uge) &&
7655+
stepSize <= constValue)
7656+
return pred == arith::CmpIPredicate::ult;
7657+
7658+
// Handle ule and ugt.
7659+
if ((pred == arith::CmpIPredicate::ule ||
7660+
pred == arith::CmpIPredicate::ugt) &&
7661+
stepSize - 1 <= constValue) {
7662+
return pred == arith::CmpIPredicate::ule;
7663+
}
76617664

7662-
// Handle eq and ne.
7663-
if ((pred == arith::CmpIPredicate::eq ||
7664-
pred == arith::CmpIPredicate::ne) &&
7665-
stepSize <= constValue)
7666-
return pred == arith::CmpIPredicate::ne;
7665+
// Handle eq and ne.
7666+
if ((pred == arith::CmpIPredicate::eq ||
7667+
pred == arith::CmpIPredicate::ne) &&
7668+
stepSize <= constValue)
7669+
return pred == arith::CmpIPredicate::ne;
76677670

7668-
return std::optional<bool>();
7669-
}();
7671+
return std::nullopt;
7672+
}();
76707673

7671-
if (!maybeSplat.has_value())
7672-
continue;
7674+
if (!maybeSplat.has_value())
7675+
continue;
76737676

7674-
rewriter.setInsertionPointAfter(cmpiOp);
7677+
rewriter.setInsertionPointAfter(cmpiOp);
76757678

7676-
auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
7677-
if (!type)
7678-
continue;
7679+
auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
7680+
if (!type)
7681+
continue;
76797682

7680-
DenseElementsAttr boolAttr =
7681-
DenseElementsAttr::get(type, maybeSplat.value());
7682-
Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
7683-
type, boolAttr);
7683+
DenseElementsAttr boolAttr =
7684+
DenseElementsAttr::get(type, maybeSplat.value());
7685+
Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
7686+
type, boolAttr);
76847687

7685-
rewriter.replaceOp(cmpiOp, splat);
7686-
return success();
7687-
}
7688+
rewriter.replaceOp(cmpiOp, splat);
7689+
return success();
76887690
}
76897691

76907692
return failure();

mlir/test/Dialect/Vector/canonicalize/vector-step.mlir

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
1+
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file | FileCheck %s
22

33
///===----------------------------------------------===//
44
/// Tests of `StepCompareFolder`
@@ -59,6 +59,20 @@ func.func @check_ugt_constant_3_rhs() -> vector<3xi1> {
5959
return %1 : vector<3xi1>
6060
}
6161

62+
// -----
63+
64+
// CHECK-LABEL: @check_ugt_constant_max_rhs
65+
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
66+
// CHECK: return %[[CST]] : vector<3xi1>
67+
func.func @check_ugt_constant_max_rhs() -> vector<3xi1> {
68+
// The largest i64 possible:
69+
%cst = arith.constant dense<0x7fffffffffffffff> : vector<3xindex>
70+
%0 = vector.step : vector<3xindex>
71+
%1 = arith.cmpi ugt, %0, %cst: vector<3xindex>
72+
return %1 : vector<3xi1>
73+
}
74+
75+
6276
// -----
6377

6478
// CHECK-LABEL: @check_ugt_constant_2_rhs

0 commit comments

Comments
 (0)