Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 64 additions & 2 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4729,12 +4729,55 @@ struct CancelDelinearizeOfLinearizeDisjointExactTail
return success();
}
};

/// Give mixed basis of affine.delinearize_index/linearize_index replace
/// constant SSA values with constant attribute as OpFoldResult. In case no
/// change is made to the existing mixed basis set, return failure; success
/// otherwise.
static LogicalResult
fetchNewConstantBasis(PatternRewriter &rewriter,
SmallVector<OpFoldResult> mixedBasis,
SmallVector<OpFoldResult> &newBasis) {
// Replace all constant SSA values with the constant attribute.
bool hasConstantSSAVal = false;
for (OpFoldResult basis : mixedBasis) {
std::optional<int64_t> basisVal = getConstantIntValue(basis);
if (basisVal && !isa<Attribute>(basis)) {
newBasis.push_back(rewriter.getIndexAttr(*basisVal));
hasConstantSSAVal = true;
} else {
newBasis.push_back(basis);
}
}
if (hasConstantSSAVal)
return success();
return failure();
}

/// Folds away constant SSA Value with constant Attribute in basis.
struct ConstantAttributeBasisDelinearizeIndexOpPattern
: public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp op,
PatternRewriter &rewriter) const override {
// Replace all constant SSA values with the constant attribute.
SmallVector<OpFoldResult> newBasis;
if (failed(fetchNewConstantBasis(rewriter, op.getMixedBasis(), newBasis)))
return rewriter.notifyMatchFailure(op, "no constant SSA value in basis");

rewriter.replaceOpWithNewOp<affine::AffineDelinearizeIndexOp>(
op, op.getLinearIndex(), newBasis, op.hasOuterBound());
return success();
}
};
} // namespace

void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.insert<CancelDelinearizeOfLinearizeDisjointExactTail,
DropUnitExtentBasis>(context);
DropUnitExtentBasis,
ConstantAttributeBasisDelinearizeIndexOpPattern>(context);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -4959,12 +5002,31 @@ struct DropLinearizeLeadingZero final
return success();
}
};

/// Folds away constant SSA Value with constant Attribute in basis.
struct ConstantAttributeBasisLinearizeIndexOpPattern
: public OpRewritePattern<affine::AffineLinearizeIndexOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
PatternRewriter &rewriter) const override {
// Replace all constant SSA values with the constant attribute.
SmallVector<OpFoldResult> newBasis;
if (failed(fetchNewConstantBasis(rewriter, op.getMixedBasis(), newBasis)))
return rewriter.notifyMatchFailure(op, "no constant SSA value in basis");

rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
op, op.getMultiIndex(), newBasis, op.getDisjoint());
return success();
}
};
} // namespace

void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add<CancelLinearizeOfDelinearizeExact, DropLinearizeLeadingZero,
DropLinearizeUnitComponentsIfDisjointOrZero>(context);
DropLinearizeUnitComponentsIfDisjointOrZero,
ConstantAttributeBasisLinearizeIndexOpPattern>(context);
}

//===----------------------------------------------------------------------===//
Expand Down
30 changes: 30 additions & 0 deletions mlir/test/Dialect/Affine/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1946,3 +1946,33 @@ func.func @affine_leading_zero_no_outer_bound(%arg0: index, %arg1: index) -> ind
return %ret : index
}

// -----

// CHECK-LABEL: @cst_value_to_cst_attr_basis_delinearize_index
// CHECK-SAME: (%[[ARG0:.*]]: index)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[RET:.*]]:2 = affine.delinearize_index %[[ARG0]] into (3, 4) : index, index
// CHECK: return %[[RET]]#0, %[[RET]]#1, %[[C0]] : index, index, index
func.func @cst_value_to_cst_attr_basis_delinearize_index(%arg0 : index) ->
(index, index, index) {
%c4 = arith.constant 4 : index
%c3 = arith.constant 3 : index
%c1 = arith.constant 1 : index
%0:3 = affine.delinearize_index %arg0 into (%c3, %c4, %c1)
: index, index, index
return %0#0, %0#1, %0#2 : index, index, index
}

// -----

// CHECK-LABEL: @cst_value_to_cst_attr_basis_linearize_index
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
// CHECK: %[[RET:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]]] by (3, 4) : index
// CHECK: return %[[RET]] : index
func.func @cst_value_to_cst_attr_basis_linearize_index(%arg0 : index, %arg1 : index, %arg2 : index) ->
(index) {
%c4 = arith.constant 4 : index
%c1 = arith.constant 1 : index
%1 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (%c1, 3, %c4) : index
return %1 : index
}
Loading