diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td index 92fd6e99338ae..1dd9b9a440ecc 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -1158,7 +1158,7 @@ def AffineLinearizeIndexOp : Affine_Op<"linearize_index", let assemblyFormat = [{ (`disjoint` $disjoint^)? ` ` - `[` $multi_index `]` `by` ` ` + `[` $multi_index `]` `by` custom($dynamic_basis, $static_basis, "::mlir::AsmParser::Delimiter::Paren") attr-dict `:` type($linear_index) }]; diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index ca55c44856d19..3d38de4bf1068 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -4734,11 +4734,29 @@ struct DropLinearizeUnitComponentsIfDisjointOrZero final return success(); } }; + +/// Rewrite `affine.linearize_index [%%x] by (%b)`, into `%x`. +/// +/// By definition, that operation is `affine.apply affine_map<()[s0] -> (s0)>,` +/// which is the identity. +struct DropLinearizeOneBasisElement final + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op, + PatternRewriter &rewriter) const override { + if (op.getStaticBasis().size() != 1 || op.getMultiIndex().size() != 1) + return failure(); + rewriter.replaceOp(op, op.getMultiIndex().front()); + return success(); + } +}; } // namespace void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns( RewritePatternSet &patterns, MLIRContext *context) { - patterns.add(context); + patterns.add(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir index f6007aa16c126..fa179744094c6 100644 --- a/mlir/test/Dialect/Affine/canonicalize.mlir +++ b/mlir/test/Dialect/Affine/canonicalize.mlir @@ -1566,3 +1566,14 @@ func.func @linearize_all_zero_unit_basis() -> index { %ret = affine.linearize_index [%c0, %c0] by (1, 1) : index return %ret : index } + +// ----- + +// CHECK-LABEL: @linearize_one_element_basis +// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index) +// CHECK-NOT: affine.linearize_index +// CHECK: return %[[arg0]] +func.func @linearize_one_element_basis(%arg0: index, %arg1: index) -> index { + %ret = affine.linearize_index [%arg0] by (%arg1) : index + return %ret : index +}