Skip to content

Commit d0ae5b5

Browse files
committed
cosmetics
Signed-off-by: James Newling <[email protected]>
1 parent b7f146c commit d0ae5b5

File tree

2 files changed

+53
-72
lines changed

2 files changed

+53
-72
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 41 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -7604,89 +7604,66 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
76047604

76057605
namespace {
76067606

7607-
/// Constant fold vector.step when it is compared to constant with arith.cmpi
7608-
/// and the result is the same at all indices. For example, rewrite:
7607+
/// Fold `vector.step -> arith.cmpi` when the step value is compared to a
7608+
/// constant large enough such that the result is the same at all indices.
7609+
///
7610+
/// For example, rewrite the 'greater than' comparison below,
76097611
///
76107612
/// %cst = arith.constant dense<7> : vector<3xindex>
7611-
/// %0 = vector.step : vector<3xindex>
7612-
/// %1 = arith.cmpi ugt, %0, %cst : vector<3xindex>
7613+
/// %stp = vector.step : vector<3xindex>
7614+
/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
76137615
///
7614-
/// as
7616+
/// as,
76157617
///
7616-
/// %out = arith.constant dense<false> : vector<3xi1>
7618+
/// %out = arith.constant dense<false> : vector<3xi1>.
76177619
///
76187620
/// Above [0, 1, 2] > [7, 7, 7] => [false, false, false]. Because the result is
7619-
/// false at ALL indices we fold to the constant. false. If the constant was 1,
7620-
/// then [0, 1, 2] > [1, 1, 1] => [false, false, true] and we do not constant
7621-
/// fold, preferring the more 'compact' vector.step representation.
7621+
/// false at ALL indices we fold. If the constant was 1, then
7622+
/// [0, 1, 2] > [1, 1, 1] => [false, false, true] and we do fold, conservatively
7623+
/// preferring the 'compact' vector.step representation.
76227624
struct StepCompareFolder : public OpRewritePattern<StepOp> {
76237625
using OpRewritePattern::OpRewritePattern;
76247626

76257627
LogicalResult matchAndRewrite(StepOp stepOp,
76267628
PatternRewriter &rewriter) const override {
7627-
7628-
int64_t stepSize = stepOp.getResult().getType().getNumElements();
7629+
const int64_t stepSize = stepOp.getResult().getType().getNumElements();
76297630

76307631
for (auto &use : stepOp.getResult().getUses()) {
76317632
if (auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner())) {
7632-
unsigned stepOperandNumber = use.getOperandNumber();
7633+
const unsigned stepOperandNumber = use.getOperandNumber();
76337634

7634-
// arith.cmpi has a canonicalizer to put constants on operand 1. Let it
7635-
// run first.
7636-
if (stepOperandNumber != 0) {
7635+
// arith.cmpi canonicalizer makes constants final operands.
7636+
if (stepOperandNumber != 0)
76377637
continue;
7638-
}
76397638

76407639
// Check that operand 1 is a constant.
7641-
unsigned otherOperandNumber = 1;
7642-
Value otherOperand = cmpiOp.getOperand(otherOperandNumber);
7640+
unsigned constOperandNumber = 1;
7641+
Value otherOperand = cmpiOp.getOperand(constOperandNumber);
76437642
auto maybeConstValue = getConstantIntValue(otherOperand);
76447643
if (!maybeConstValue.has_value())
76457644
continue;
7646-
int64_t constValue = maybeConstValue.value();
76477645

7646+
int64_t constValue = maybeConstValue.value();
76487647
arith::CmpIPredicate pred = cmpiOp.getPredicate();
76497648

76507649
auto maybeSplat = [&]() -> std::optional<bool> {
76517650
// Handle ult (unsigned less than) and uge (unsigned greater equal).
7652-
// Examples where stepSize = constValue = 3, for the 4
7653-
// cases of [ult, uge] x [stepOperandNumber = 0, 1]:
7654-
//
7655-
// pred stepOperandNumber
7656-
// ==== =================
7657-
// ult 0 [0, 1, 2] < 3 ==> true.
7658-
// ult 1 3 < [0, 1, 2] ==> false.
7659-
// uge 0 [0, 1, 2] >= 3 ==> true.
7660-
// uge 1 3 >= [0, 1, 2] ==> false.
7661-
//
7662-
// If constValue is any smaller, the comparison is not constant.
7663-
if (pred == arith::CmpIPredicate::ult ||
7664-
pred == arith::CmpIPredicate::uge) {
7665-
if (stepSize <= constValue) {
7666-
return pred == arith::CmpIPredicate::ult;
7667-
}
7668-
}
7651+
if ((pred == arith::CmpIPredicate::ult ||
7652+
pred == arith::CmpIPredicate::uge) &&
7653+
stepSize <= constValue)
7654+
return pred == arith::CmpIPredicate::ult;
76697655

76707656
// Handle ule and ugt.
7671-
//
7672-
// pred stepOperandNumber
7673-
// ==== =================
7674-
// ule 0 [0, 1, 2] <= 2 ==> true
7675-
// (stepSize = 3, constValue = 2).
7676-
if (pred == arith::CmpIPredicate::ule ||
7677-
pred == arith::CmpIPredicate::ugt) {
7678-
if (stepSize <= constValue + 1) {
7679-
return pred == arith::CmpIPredicate::ule;
7680-
}
7681-
}
7657+
if ((pred == arith::CmpIPredicate::ule ||
7658+
pred == arith::CmpIPredicate::ugt) &&
7659+
stepSize <= constValue + 1)
7660+
return pred == arith::CmpIPredicate::ule;
76827661

7683-
// Handle eq and ne
7684-
if (pred == arith::CmpIPredicate::eq ||
7685-
pred == arith::CmpIPredicate::ne) {
7686-
if (stepSize <= constValue) {
7687-
return pred == arith::CmpIPredicate::ne;
7688-
}
7689-
}
7662+
// Handle eq and ne.
7663+
if ((pred == arith::CmpIPredicate::eq ||
7664+
pred == arith::CmpIPredicate::ne) &&
7665+
stepSize <= constValue)
7666+
return pred == arith::CmpIPredicate::ne;
76907667

76917668
return std::optional<bool>();
76927669
}();
@@ -7695,13 +7672,17 @@ struct StepCompareFolder : public OpRewritePattern<StepOp> {
76957672
continue;
76967673

76977674
rewriter.setInsertionPointAfter(cmpiOp);
7698-
auto boolConst = mlir::arith::ConstantOp::create(
7699-
rewriter, cmpiOp.getLoc(),
7700-
rewriter.getBoolAttr(maybeSplat.value()));
7701-
auto splat = vector::BroadcastOp::create(
7702-
rewriter, cmpiOp.getLoc(), cmpiOp.getResult().getType(), boolConst);
77037675

7704-
rewriter.replaceOp(cmpiOp, splat.getResult());
7676+
auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
7677+
if (!type)
7678+
continue;
7679+
7680+
DenseElementsAttr boolAttr =
7681+
DenseElementsAttr::get(type, maybeSplat.value());
7682+
Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
7683+
type, boolAttr);
7684+
7685+
rewriter.replaceOp(cmpiOp, splat);
77057686
return success();
77067687
}
77077688
}

mlir/test/Dialect/Vector/canonicalize/vector-step.mlir

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
///===----------------------------------------------===//
66

77

8-
///===--------------===//
8+
///===------------------------------------===//
99
/// Tests of `ugt` (unsigned greater than)
10-
///===--------------===//
10+
///===------------------------------------===//
1111

1212
// CHECK-LABEL: @check_ugt_constant_3_lhs
1313
// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
@@ -87,9 +87,9 @@ func.func @check_ugt_constant_1_rhs() -> vector<3xi1> {
8787

8888
// -----
8989

90-
///===--------------===//
90+
///===------------------------------------===//
9191
/// Tests of `uge` (unsigned greater than or equal)
92-
///===--------------===//
92+
///===------------------------------------===//
9393

9494
// CHECK-LABEL: @check_uge_constant_3_lhs
9595
// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
@@ -171,9 +171,9 @@ func.func @check_uge_constant_1_rhs() -> vector<3xi1> {
171171

172172

173173

174-
///===--------------===//
174+
///===------------------------------------===//
175175
/// Tests of `ult` (unsigned less than)
176-
///===--------------===//
176+
///===------------------------------------===//
177177

178178
// CHECK-LABEL: @check_ult_constant_3_lhs
179179
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
@@ -247,9 +247,9 @@ func.func @check_ult_constant_1_rhs() -> vector<3xi1> {
247247

248248
// -----
249249

250-
///===--------------===//
250+
///===------------------------------------===//
251251
/// Tests of `ule` (unsigned less than or equal)
252-
///===--------------===//
252+
///===------------------------------------===//
253253

254254
// CHECK-LABEL: @check_ule_constant_3_lhs
255255
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
@@ -323,9 +323,9 @@ func.func @check_ule_constant_1_rhs() -> vector<3xi1> {
323323

324324
// -----
325325

326-
///===--------------===//
326+
///===------------------------------------===//
327327
/// Tests of `eq` (equal)
328-
///===--------------===//
328+
///===------------------------------------===//
329329

330330
// CHECK-LABEL: @check_eq_constant_3
331331
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
@@ -351,9 +351,9 @@ func.func @check_eq_constant_2() -> vector<3xi1> {
351351

352352
// -----
353353

354-
///===--------------===//
354+
///===------------------------------------===//
355355
/// Tests of `ne` (not equal)
356-
///===--------------===//
356+
///===------------------------------------===//
357357

358358
// CHECK-LABEL: @check_ne_constant_3
359359
// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>

0 commit comments

Comments
 (0)