-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[MLIR][Affine] Update ::fold() to have constant basis attr for affine.delinearize_index/linearize_index wherever applicable #117572
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
-- This commit adds canonicalization pattern to have constant(CST) attribute for affine.delinearize_index/linearize_index op's basis wherever applicable. -- Essentially the patterns check if the mixed basis OpFoldResult set contains any constant SSA value and converts it to a constant integer attribute instead. Signed-off-by: Abhishek Varma <[email protected]>
|
@llvm/pr-subscribers-mlir-affine @llvm/pr-subscribers-mlir Author: Abhishek Varma (Abhishek-Varma) Changes-- This commit adds canonicalization pattern to have constant(CST) Signed-off-by: Abhishek Varma <[email protected]> Full diff: https://github.com/llvm/llvm-project/pull/117572.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 67d7da622a3550..3e82ec00763142 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4729,12 +4729,55 @@ struct CancelDelinearizeOfLinearizeDisjointExactTail
return success();
}
};
+
+/// Give mixed basis of affine.delinearize_index/linearize_index replace
+/// constant SSA values with constant attribute as OpFoldResult. In case no
+/// change is made to the existing mixed basis set, return failure; success
+/// otherwise.
+static LogicalResult
+fetchNewConstantBasis(PatternRewriter &rewriter,
+ SmallVector<OpFoldResult> mixedBasis,
+ SmallVector<OpFoldResult> &newBasis) {
+ // Replace all constant SSA values with the constant attribute.
+ bool hasConstantSSAVal = false;
+ for (OpFoldResult basis : mixedBasis) {
+ std::optional<int64_t> basisVal = getConstantIntValue(basis);
+ if (basisVal && !isa<Attribute>(basis)) {
+ newBasis.push_back(rewriter.getIndexAttr(*basisVal));
+ hasConstantSSAVal = true;
+ } else {
+ newBasis.push_back(basis);
+ }
+ }
+ if (hasConstantSSAVal)
+ return success();
+ return failure();
+}
+
+/// Folds away constant SSA Value with constant Attribute in basis.
+struct ConstantAttributeBasisDelinearizeIndexOpPattern
+ : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp op,
+ PatternRewriter &rewriter) const override {
+ // Replace all constant SSA values with the constant attribute.
+ SmallVector<OpFoldResult> newBasis;
+ if (failed(fetchNewConstantBasis(rewriter, op.getMixedBasis(), newBasis)))
+ return rewriter.notifyMatchFailure(op, "no constant SSA value in basis");
+
+ rewriter.replaceOpWithNewOp<affine::AffineDelinearizeIndexOp>(
+ op, op.getLinearIndex(), newBasis, op.hasOuterBound());
+ return success();
+ }
+};
} // namespace
void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.insert<CancelDelinearizeOfLinearizeDisjointExactTail,
- DropUnitExtentBasis>(context);
+ DropUnitExtentBasis,
+ ConstantAttributeBasisDelinearizeIndexOpPattern>(context);
}
//===----------------------------------------------------------------------===//
@@ -4959,12 +5002,31 @@ struct DropLinearizeLeadingZero final
return success();
}
};
+
+/// Folds away constant SSA Value with constant Attribute in basis.
+struct ConstantAttributeBasisLinearizeIndexOpPattern
+ : public OpRewritePattern<affine::AffineLinearizeIndexOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
+ PatternRewriter &rewriter) const override {
+ // Replace all constant SSA values with the constant attribute.
+ SmallVector<OpFoldResult> newBasis;
+ if (failed(fetchNewConstantBasis(rewriter, op.getMixedBasis(), newBasis)))
+ return rewriter.notifyMatchFailure(op, "no constant SSA value in basis");
+
+ rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
+ op, op.getMultiIndex(), newBasis, op.getDisjoint());
+ return success();
+ }
+};
} // namespace
void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.add<CancelLinearizeOfDelinearizeExact, DropLinearizeLeadingZero,
- DropLinearizeUnitComponentsIfDisjointOrZero>(context);
+ DropLinearizeUnitComponentsIfDisjointOrZero,
+ ConstantAttributeBasisLinearizeIndexOpPattern>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index 5384977151b47f..16cbce35aeec7e 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1946,3 +1946,33 @@ func.func @affine_leading_zero_no_outer_bound(%arg0: index, %arg1: index) -> ind
return %ret : index
}
+// -----
+
+// CHECK-LABEL: @cst_value_to_cst_attr_basis_delinearize_index
+// CHECK-SAME: (%[[ARG0:.*]]: index)
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[RET:.*]]:2 = affine.delinearize_index %[[ARG0]] into (3, 4) : index, index
+// CHECK: return %[[RET]]#0, %[[RET]]#1, %[[C0]] : index, index, index
+func.func @cst_value_to_cst_attr_basis_delinearize_index(%arg0 : index) ->
+ (index, index, index) {
+ %c4 = arith.constant 4 : index
+ %c3 = arith.constant 3 : index
+ %c1 = arith.constant 1 : index
+ %0:3 = affine.delinearize_index %arg0 into (%c3, %c4, %c1)
+ : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: @cst_value_to_cst_attr_basis_linearize_index
+// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+// CHECK: %[[RET:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]]] by (3, 4) : index
+// CHECK: return %[[RET]] : index
+func.func @cst_value_to_cst_attr_basis_linearize_index(%arg0 : index, %arg1 : index, %arg2 : index) ->
+ (index) {
+ %c4 = arith.constant 4 : index
+ %c1 = arith.constant 1 : index
+ %1 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (%c1, 3, %c4) : index
+ return %1 : index
+}
|
krzysz00
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I figure that this can be done in fold(), looking at whether the getDynamicBasis() adaptor has some non-null Attributes
Hi @krzysz00 - I've made the changes. I figured I've to use Please take a look. Thanks! |
krzysz00
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you double-check the requirements on fold(). I have the vague sense that if you mutated the operation you need to return the results of the operation you just mutated.
(Also, it might be a better move to look through the dynamic basis for constants and, if you find any non-null Attributes, to go swap them in at the right spot)
Hi @krzysz00 - I've made the changes.
Correct, but swapping them at the right spot would need |
krzysz00
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One minor note, otherwise, approved, thank you!
-- This commit updates
::fold()to have constant(CST)attribute for affine.delinearize_index/linearize_index op's basis
wherever applicable.
-- Essentially the code checks if the mixed basis OpFoldResult
set contains any constant SSA value and converts it to a constant
integer instead.
Signed-off-by: Abhishek Varma [email protected]