Skip to content

Commit 420da4b

Browse files
committed
address Jakub's comments
Signed-off-by: James Newling <[email protected]>
1 parent 442a376 commit 420da4b

File tree

2 files changed

+63
-47
lines changed

2 files changed

+63
-47
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 48 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -7628,62 +7628,64 @@ struct StepCompareFolder : public OpRewritePattern<StepOp> {
76287628
const int64_t stepSize = stepOp.getResult().getType().getNumElements();
76297629

76307630
for (auto &use : stepOp.getResult().getUses()) {
7631-
if (auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner())) {
7632-
const unsigned stepOperandNumber = use.getOperandNumber();
7633-
7634-
// arith.cmpi canonicalizer makes constants final operands.
7635-
if (stepOperandNumber != 0)
7636-
continue;
7637-
7638-
// Check that operand 1 is a constant.
7639-
unsigned constOperandNumber = 1;
7640-
Value otherOperand = cmpiOp.getOperand(constOperandNumber);
7641-
auto maybeConstValue = getConstantIntValue(otherOperand);
7642-
if (!maybeConstValue.has_value())
7643-
continue;
7631+
auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
7632+
if (!cmpiOp)
7633+
continue;
76447634

7645-
int64_t constValue = maybeConstValue.value();
7646-
arith::CmpIPredicate pred = cmpiOp.getPredicate();
7635+
// arith.cmpi canonicalizer makes constants final operands.
7636+
const unsigned stepOperandNumber = use.getOperandNumber();
7637+
if (stepOperandNumber != 0)
7638+
continue;
76477639

7648-
auto maybeSplat = [&]() -> std::optional<bool> {
7649-
// Handle ult (unsigned less than) and uge (unsigned greater equal).
7650-
if ((pred == arith::CmpIPredicate::ult ||
7651-
pred == arith::CmpIPredicate::uge) &&
7652-
stepSize <= constValue)
7653-
return pred == arith::CmpIPredicate::ult;
7640+
// Check that operand 1 is a constant.
7641+
unsigned constOperandNumber = 1;
7642+
Value otherOperand = cmpiOp.getOperand(constOperandNumber);
7643+
auto maybeConstValue = getConstantIntValue(otherOperand);
7644+
if (!maybeConstValue.has_value())
7645+
continue;
76547646

7655-
// Handle ule and ugt.
7656-
if ((pred == arith::CmpIPredicate::ule ||
7657-
pred == arith::CmpIPredicate::ugt) &&
7658-
stepSize <= constValue + 1)
7659-
return pred == arith::CmpIPredicate::ule;
7647+
int64_t constValue = maybeConstValue.value();
7648+
arith::CmpIPredicate pred = cmpiOp.getPredicate();
7649+
7650+
auto maybeSplat = [&]() -> std::optional<bool> {
7651+
// Handle ult (unsigned less than) and uge (unsigned greater equal).
7652+
if ((pred == arith::CmpIPredicate::ult ||
7653+
pred == arith::CmpIPredicate::uge) &&
7654+
stepSize <= constValue)
7655+
return pred == arith::CmpIPredicate::ult;
7656+
7657+
// Handle ule and ugt.
7658+
if ((pred == arith::CmpIPredicate::ule ||
7659+
pred == arith::CmpIPredicate::ugt) &&
7660+
stepSize - 1 <= constValue) {
7661+
return pred == arith::CmpIPredicate::ule;
7662+
}
76607663

7661-
// Handle eq and ne.
7662-
if ((pred == arith::CmpIPredicate::eq ||
7663-
pred == arith::CmpIPredicate::ne) &&
7664-
stepSize <= constValue)
7665-
return pred == arith::CmpIPredicate::ne;
7664+
// Handle eq and ne.
7665+
if ((pred == arith::CmpIPredicate::eq ||
7666+
pred == arith::CmpIPredicate::ne) &&
7667+
stepSize <= constValue)
7668+
return pred == arith::CmpIPredicate::ne;
76667669

7667-
return std::optional<bool>();
7668-
}();
7670+
return std::nullopt;
7671+
}();
76697672

7670-
if (!maybeSplat.has_value())
7671-
continue;
7673+
if (!maybeSplat.has_value())
7674+
continue;
76727675

7673-
rewriter.setInsertionPointAfter(cmpiOp);
7676+
rewriter.setInsertionPointAfter(cmpiOp);
76747677

7675-
auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
7676-
if (!type)
7677-
continue;
7678+
auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
7679+
if (!type)
7680+
continue;
76787681

7679-
DenseElementsAttr boolAttr =
7680-
DenseElementsAttr::get(type, maybeSplat.value());
7681-
Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
7682-
type, boolAttr);
7682+
DenseElementsAttr boolAttr =
7683+
DenseElementsAttr::get(type, maybeSplat.value());
7684+
Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
7685+
type, boolAttr);
76837686

7684-
rewriter.replaceOp(cmpiOp, splat);
7685-
return success();
7686-
}
7687+
rewriter.replaceOp(cmpiOp, splat);
7688+
return success();
76877689
}
76887690

76897691
return failure();

mlir/test/Dialect/Vector/canonicalize/vector-step.mlir

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
1+
// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file | FileCheck %s
22

33
///===----------------------------------------------===//
44
/// Tests of `StepCompareFolder`
@@ -59,6 +59,20 @@ func.func @check_ugt_constant_3_rhs() -> vector<3xi1> {
5959
return %1 : vector<3xi1>
6060
}
6161

62+
// -----
63+
64+
// CHECK-LABEL: @check_ugt_constant_max_rhs
65+
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
66+
// CHECK: return %[[CST]] : vector<3xi1>
67+
func.func @check_ugt_constant_max_rhs() -> vector<3xi1> {
68+
// The largest i64 possible:
69+
%cst = arith.constant dense<0x7fffffffffffffff> : vector<3xindex>
70+
%0 = vector.step : vector<3xindex>
71+
%1 = arith.cmpi ugt, %0, %cst: vector<3xindex>
72+
return %1 : vector<3xi1>
73+
}
74+
75+
6276
// -----
6377

6478
// CHECK-LABEL: @check_ugt_constant_2_rhs

0 commit comments

Comments
 (0)