Skip to content

Conversation

@Abhishek-Varma
Copy link
Contributor

@Abhishek-Varma Abhishek-Varma commented Nov 25, 2024

-- This commit updates ::fold() to have constant(CST)
attribute for affine.delinearize_index/linearize_index op's basis
wherever applicable.
-- Essentially the code checks if the mixed basis OpFoldResult
set contains any constant SSA value and converts it to a constant
integer instead.

Signed-off-by: Abhishek Varma [email protected]

-- This commit adds canonicalization pattern to have constant(CST)
   attribute for affine.delinearize_index/linearize_index op's basis
   wherever applicable.
-- Essentially the patterns check if the mixed basis OpFoldResult
   set contains any constant SSA value and converts it to a constant
   integer attribute instead.

Signed-off-by: Abhishek Varma <[email protected]>
@llvmbot
Copy link
Member

llvmbot commented Nov 25, 2024

@llvm/pr-subscribers-mlir-affine

@llvm/pr-subscribers-mlir

Author: Abhishek Varma (Abhishek-Varma)

Changes

-- This commit adds canonicalization pattern to have constant(CST)
attribute for affine.delinearize_index/linearize_index op's basis
wherever applicable.
-- Essentially the patterns check if the mixed basis OpFoldResult
set contains any constant SSA value and converts it to a constant
integer attribute instead.

Signed-off-by: Abhishek Varma <[email protected]>


Full diff: https://github.com/llvm/llvm-project/pull/117572.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+64-2)
  • (modified) mlir/test/Dialect/Affine/canonicalize.mlir (+30)
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 67d7da622a3550..3e82ec00763142 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -4729,12 +4729,55 @@ struct CancelDelinearizeOfLinearizeDisjointExactTail
     return success();
   }
 };
+
+/// Give mixed basis of affine.delinearize_index/linearize_index replace
+/// constant SSA values with constant attribute as OpFoldResult. In case no
+/// change is made to the existing mixed basis set, return failure; success
+/// otherwise.
+static LogicalResult
+fetchNewConstantBasis(PatternRewriter &rewriter,
+                      SmallVector<OpFoldResult> mixedBasis,
+                      SmallVector<OpFoldResult> &newBasis) {
+  // Replace all constant SSA values with the constant attribute.
+  bool hasConstantSSAVal = false;
+  for (OpFoldResult basis : mixedBasis) {
+    std::optional<int64_t> basisVal = getConstantIntValue(basis);
+    if (basisVal && !isa<Attribute>(basis)) {
+      newBasis.push_back(rewriter.getIndexAttr(*basisVal));
+      hasConstantSSAVal = true;
+    } else {
+      newBasis.push_back(basis);
+    }
+  }
+  if (hasConstantSSAVal)
+    return success();
+  return failure();
+}
+
+/// Folds away constant SSA Value with constant Attribute in basis.
+struct ConstantAttributeBasisDelinearizeIndexOpPattern
+    : public OpRewritePattern<affine::AffineDelinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineDelinearizeIndexOp op,
+                                PatternRewriter &rewriter) const override {
+    // Replace all constant SSA values with the constant attribute.
+    SmallVector<OpFoldResult> newBasis;
+    if (failed(fetchNewConstantBasis(rewriter, op.getMixedBasis(), newBasis)))
+      return rewriter.notifyMatchFailure(op, "no constant SSA value in basis");
+
+    rewriter.replaceOpWithNewOp<affine::AffineDelinearizeIndexOp>(
+        op, op.getLinearIndex(), newBasis, op.hasOuterBound());
+    return success();
+  }
+};
 } // namespace
 
 void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
   patterns.insert<CancelDelinearizeOfLinearizeDisjointExactTail,
-                  DropUnitExtentBasis>(context);
+                  DropUnitExtentBasis,
+                  ConstantAttributeBasisDelinearizeIndexOpPattern>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -4959,12 +5002,31 @@ struct DropLinearizeLeadingZero final
     return success();
   }
 };
+
+/// Folds away constant SSA Value with constant Attribute in basis.
+struct ConstantAttributeBasisLinearizeIndexOpPattern
+    : public OpRewritePattern<affine::AffineLinearizeIndexOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op,
+                                PatternRewriter &rewriter) const override {
+    // Replace all constant SSA values with the constant attribute.
+    SmallVector<OpFoldResult> newBasis;
+    if (failed(fetchNewConstantBasis(rewriter, op.getMixedBasis(), newBasis)))
+      return rewriter.notifyMatchFailure(op, "no constant SSA value in basis");
+
+    rewriter.replaceOpWithNewOp<affine::AffineLinearizeIndexOp>(
+        op, op.getMultiIndex(), newBasis, op.getDisjoint());
+    return success();
+  }
+};
 } // namespace
 
 void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns(
     RewritePatternSet &patterns, MLIRContext *context) {
   patterns.add<CancelLinearizeOfDelinearizeExact, DropLinearizeLeadingZero,
-               DropLinearizeUnitComponentsIfDisjointOrZero>(context);
+               DropLinearizeUnitComponentsIfDisjointOrZero,
+               ConstantAttributeBasisLinearizeIndexOpPattern>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
index 5384977151b47f..16cbce35aeec7e 100644
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -1946,3 +1946,33 @@ func.func @affine_leading_zero_no_outer_bound(%arg0: index, %arg1: index) -> ind
   return %ret : index
 }
 
+// -----
+
+// CHECK-LABEL: @cst_value_to_cst_attr_basis_delinearize_index
+// CHECK-SAME:    (%[[ARG0:.*]]: index)
+// CHECK-DAG:     %[[C0:.+]] = arith.constant 0 : index
+// CHECK:         %[[RET:.*]]:2 = affine.delinearize_index %[[ARG0]] into (3, 4) : index, index
+// CHECK:         return %[[RET]]#0, %[[RET]]#1, %[[C0]] : index, index, index
+func.func @cst_value_to_cst_attr_basis_delinearize_index(%arg0 : index) ->
+    (index, index, index) {
+  %c4 = arith.constant 4 : index
+  %c3 = arith.constant 3 : index
+  %c1 = arith.constant 1 : index
+  %0:3 = affine.delinearize_index %arg0 into (%c3, %c4, %c1)
+      : index, index, index
+  return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+// CHECK-LABEL: @cst_value_to_cst_attr_basis_linearize_index
+// CHECK-SAME:    (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+// CHECK:         %[[RET:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]]] by (3, 4) : index
+// CHECK:         return %[[RET]] : index
+func.func @cst_value_to_cst_attr_basis_linearize_index(%arg0 : index, %arg1 : index, %arg2 : index) ->
+    (index) {
+  %c4 = arith.constant 4 : index
+  %c1 = arith.constant 1 : index
+  %1 = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by  (%c1, 3, %c4) : index
+  return %1 : index
+}

Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I figure that this can be done in fold(), looking at whether the getDynamicBasis() adaptor has some non-null Attributes

@Abhishek-Varma
Copy link
Contributor Author

I figure that this can be done in fold(), looking at whether the getDynamicBasis() adaptor has some non-null Attributes

Hi @krzysz00 - I've made the changes. I figured I've to use MutableOperandRange for updating the dynamic basis - hence wasn't able to work with ::fold earlier.

Please take a look. Thanks!

@Abhishek-Varma Abhishek-Varma changed the title [MLIR][Affine] Add canonicalization pattern to have constant basis attr for affine.delinearize_index/linearize_index [MLIR][Affine] Update ::fold() to have constant basis attr for affine.delinearize_index/linearize_index wherever applicable Nov 26, 2024
Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you double-check the requirements on fold(). I have the vague sense that if you mutated the operation you need to return the results of the operation you just mutated.

(Also, it might be a better move to look through the dynamic basis for constants and, if you find any non-null Attributes, to go swap them in at the right spot)

@Abhishek-Varma
Copy link
Contributor Author

Could you double-check the requirements on fold(). I have the vague sense that if you mutated the operation you need to return the results of the operation you just mutated.

Hi @krzysz00 - I've made the changes.
For affine.delinearize_index I'm returning success() whereas for affine.linearize_index I'm returning the result of the operation.

(Also, it might be a better move to look through the dynamic basis for constants and, if you find any non-null Attributes, to go swap them in at the right spot)

Correct, but swapping them at the right spot would need mixedBasis (or staticBasis) too - so I'm going ahead with inferring the constants from mixedBasis itself. That should be okay to go ahead with.

Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One minor note, otherwise, approved, thank you!

@Abhishek-Varma Abhishek-Varma merged commit d83148f into llvm:main Nov 29, 2024
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants