Skip to content

Commit 442a376

Browse files
committed
cosmetics
Signed-off-by: James Newling <[email protected]>
1 parent 8f748ae commit 442a376

File tree

2 files changed

+53
-72
lines changed

2 files changed

+53
-72
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 41 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -7603,89 +7603,66 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
76037603

76047604
namespace {
76057605

7606-
/// Constant fold vector.step when it is compared to constant with arith.cmpi
7607-
/// and the result is the same at all indices. For example, rewrite:
7606+
/// Fold `vector.step -> arith.cmpi` when the step value is compared to a
7607+
/// constant large enough such that the result is the same at all indices.
7608+
///
7609+
/// For example, rewrite the 'greater than' comparison below,
76087610
///
76097611
/// %cst = arith.constant dense<7> : vector<3xindex>
7610-
/// %0 = vector.step : vector<3xindex>
7611-
/// %1 = arith.cmpi ugt, %0, %cst : vector<3xindex>
7612+
/// %stp = vector.step : vector<3xindex>
7613+
/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
76127614
///
7613-
/// as
7615+
/// as,
76147616
///
7615-
/// %out = arith.constant dense<false> : vector<3xi1>
7617+
/// %out = arith.constant dense<false> : vector<3xi1>.
76167618
///
76177619
/// Above [0, 1, 2] > [7, 7, 7] => [false, false, false]. Because the result is
7618-
/// false at ALL indices we fold to the constant. false. If the constant was 1,
7619-
/// then [0, 1, 2] > [1, 1, 1] => [false, false, true] and we do not constant
7620-
/// fold, preferring the more 'compact' vector.step representation.
7620+
/// false at ALL indices we fold. If the constant was 1, then
7621+
/// [0, 1, 2] > [1, 1, 1] => [false, false, true] and we do fold, conservatively
7622+
/// preferring the 'compact' vector.step representation.
76217623
struct StepCompareFolder : public OpRewritePattern<StepOp> {
76227624
using OpRewritePattern::OpRewritePattern;
76237625

76247626
LogicalResult matchAndRewrite(StepOp stepOp,
76257627
PatternRewriter &rewriter) const override {
7626-
7627-
int64_t stepSize = stepOp.getResult().getType().getNumElements();
7628+
const int64_t stepSize = stepOp.getResult().getType().getNumElements();
76287629

76297630
for (auto &use : stepOp.getResult().getUses()) {
76307631
if (auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner())) {
7631-
unsigned stepOperandNumber = use.getOperandNumber();
7632+
const unsigned stepOperandNumber = use.getOperandNumber();
76327633

7633-
// arith.cmpi has a canonicalizer to put constants on operand 1. Let it
7634-
// run first.
7635-
if (stepOperandNumber != 0) {
7634+
// arith.cmpi canonicalizer makes constants final operands.
7635+
if (stepOperandNumber != 0)
76367636
continue;
7637-
}
76387637

76397638
// Check that operand 1 is a constant.
7640-
unsigned otherOperandNumber = 1;
7641-
Value otherOperand = cmpiOp.getOperand(otherOperandNumber);
7639+
unsigned constOperandNumber = 1;
7640+
Value otherOperand = cmpiOp.getOperand(constOperandNumber);
76427641
auto maybeConstValue = getConstantIntValue(otherOperand);
76437642
if (!maybeConstValue.has_value())
76447643
continue;
7645-
int64_t constValue = maybeConstValue.value();
76467644

7645+
int64_t constValue = maybeConstValue.value();
76477646
arith::CmpIPredicate pred = cmpiOp.getPredicate();
76487647

76497648
auto maybeSplat = [&]() -> std::optional<bool> {
76507649
// Handle ult (unsigned less than) and uge (unsigned greater equal).
7651-
// Examples where stepSize = constValue = 3, for the 4
7652-
// cases of [ult, uge] x [stepOperandNumber = 0, 1]:
7653-
//
7654-
// pred stepOperandNumber
7655-
// ==== =================
7656-
// ult 0 [0, 1, 2] < 3 ==> true.
7657-
// ult 1 3 < [0, 1, 2] ==> false.
7658-
// uge 0 [0, 1, 2] >= 3 ==> true.
7659-
// uge 1 3 >= [0, 1, 2] ==> false.
7660-
//
7661-
// If constValue is any smaller, the comparison is not constant.
7662-
if (pred == arith::CmpIPredicate::ult ||
7663-
pred == arith::CmpIPredicate::uge) {
7664-
if (stepSize <= constValue) {
7665-
return pred == arith::CmpIPredicate::ult;
7666-
}
7667-
}
7650+
if ((pred == arith::CmpIPredicate::ult ||
7651+
pred == arith::CmpIPredicate::uge) &&
7652+
stepSize <= constValue)
7653+
return pred == arith::CmpIPredicate::ult;
76687654

76697655
// Handle ule and ugt.
7670-
//
7671-
// pred stepOperandNumber
7672-
// ==== =================
7673-
// ule 0 [0, 1, 2] <= 2 ==> true
7674-
// (stepSize = 3, constValue = 2).
7675-
if (pred == arith::CmpIPredicate::ule ||
7676-
pred == arith::CmpIPredicate::ugt) {
7677-
if (stepSize <= constValue + 1) {
7678-
return pred == arith::CmpIPredicate::ule;
7679-
}
7680-
}
7656+
if ((pred == arith::CmpIPredicate::ule ||
7657+
pred == arith::CmpIPredicate::ugt) &&
7658+
stepSize <= constValue + 1)
7659+
return pred == arith::CmpIPredicate::ule;
76817660

7682-
// Handle eq and ne
7683-
if (pred == arith::CmpIPredicate::eq ||
7684-
pred == arith::CmpIPredicate::ne) {
7685-
if (stepSize <= constValue) {
7686-
return pred == arith::CmpIPredicate::ne;
7687-
}
7688-
}
7661+
// Handle eq and ne.
7662+
if ((pred == arith::CmpIPredicate::eq ||
7663+
pred == arith::CmpIPredicate::ne) &&
7664+
stepSize <= constValue)
7665+
return pred == arith::CmpIPredicate::ne;
76897666

76907667
return std::optional<bool>();
76917668
}();
@@ -7694,13 +7671,17 @@ struct StepCompareFolder : public OpRewritePattern<StepOp> {
76947671
continue;
76957672

76967673
rewriter.setInsertionPointAfter(cmpiOp);
7697-
auto boolConst = mlir::arith::ConstantOp::create(
7698-
rewriter, cmpiOp.getLoc(),
7699-
rewriter.getBoolAttr(maybeSplat.value()));
7700-
auto splat = vector::BroadcastOp::create(
7701-
rewriter, cmpiOp.getLoc(), cmpiOp.getResult().getType(), boolConst);
77027674

7703-
rewriter.replaceOp(cmpiOp, splat.getResult());
7675+
auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
7676+
if (!type)
7677+
continue;
7678+
7679+
DenseElementsAttr boolAttr =
7680+
DenseElementsAttr::get(type, maybeSplat.value());
7681+
Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
7682+
type, boolAttr);
7683+
7684+
rewriter.replaceOp(cmpiOp, splat);
77047685
return success();
77057686
}
77067687
}

mlir/test/Dialect/Vector/canonicalize/vector-step.mlir

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
///===----------------------------------------------===//
66

77

8-
///===--------------===//
8+
///===------------------------------------===//
99
/// Tests of `ugt` (unsigned greater than)
10-
///===--------------===//
10+
///===------------------------------------===//
1111

1212
// CHECK-LABEL: @check_ugt_constant_3_lhs
1313
// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
@@ -87,9 +87,9 @@ func.func @check_ugt_constant_1_rhs() -> vector<3xi1> {
8787

8888
// -----
8989

90-
///===--------------===//
90+
///===------------------------------------===//
9191
/// Tests of `uge` (unsigned greater than or equal)
92-
///===--------------===//
92+
///===------------------------------------===//
9393

9494
// CHECK-LABEL: @check_uge_constant_3_lhs
9595
// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
@@ -171,9 +171,9 @@ func.func @check_uge_constant_1_rhs() -> vector<3xi1> {
171171

172172

173173

174-
///===--------------===//
174+
///===------------------------------------===//
175175
/// Tests of `ult` (unsigned less than)
176-
///===--------------===//
176+
///===------------------------------------===//
177177

178178
// CHECK-LABEL: @check_ult_constant_3_lhs
179179
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
@@ -247,9 +247,9 @@ func.func @check_ult_constant_1_rhs() -> vector<3xi1> {
247247

248248
// -----
249249

250-
///===--------------===//
250+
///===------------------------------------===//
251251
/// Tests of `ule` (unsigned less than or equal)
252-
///===--------------===//
252+
///===------------------------------------===//
253253

254254
// CHECK-LABEL: @check_ule_constant_3_lhs
255255
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
@@ -323,9 +323,9 @@ func.func @check_ule_constant_1_rhs() -> vector<3xi1> {
323323

324324
// -----
325325

326-
///===--------------===//
326+
///===------------------------------------===//
327327
/// Tests of `eq` (equal)
328-
///===--------------===//
328+
///===------------------------------------===//
329329

330330
// CHECK-LABEL: @check_eq_constant_3
331331
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
@@ -351,9 +351,9 @@ func.func @check_eq_constant_2() -> vector<3xi1> {
351351

352352
// -----
353353

354-
///===--------------===//
354+
///===------------------------------------===//
355355
/// Tests of `ne` (not equal)
356-
///===--------------===//
356+
///===------------------------------------===//
357357

358358
// CHECK-LABEL: @check_ne_constant_3
359359
// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>

0 commit comments

Comments
 (0)