
Commit 346b9d1

[mlir][Linalg] Canonicalize TensorCastOp away when it feeds a LinalgOp.
This canonicalization is the counterpart of the MemRefCastOp -> LinalgOp folding, but on tensors. It is needed to properly canonicalize the IR produced after Linalg tiling on tensors.

Differential Revision: https://reviews.llvm.org/D88729
1 parent 348d85a commit 346b9d1
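
For context, here is a minimal before/after sketch of the fold this commit enables. The value names and shapes are illustrative (they mirror the test added to mlir/test/Dialect/Linalg/canonicalize.mlir below): a tensor_cast that only erases static shape information folds into the consuming LinalgOp, and the op's result type regains the static sizes carried by the init tensor.

```mlir
// Before canonicalization: the casts only make the operand shapes more dynamic.
%ta = tensor_cast %a : tensor<3x4xf32> to tensor<?x?xf32>
%tc = tensor_cast %c : tensor<3x?xf32> to tensor<?x?xf32>
%0 = linalg.matmul ins(%ta, %b : tensor<?x?xf32>, tensor<4x?xf32>)
                  init(%tc : tensor<?x?xf32>) -> tensor<?x?xf32>

// After canonicalization: the casts fold into the matmul and the result type
// picks up the static sizes of the init tensor.
%0 = linalg.matmul ins(%a, %b : tensor<3x4xf32>, tensor<4x?xf32>)
                  init(%c : tensor<3x?xf32>) -> tensor<3x?xf32>
```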

6 files changed: +189 additions, -2 deletions

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td

Lines changed: 35 additions & 1 deletion

@@ -404,6 +404,19 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         return getInitTensors()[i];
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the number of inputs, output buffers and init tensors operands.
+      }],
+      /*retTy=*/"unsigned",
+      /*methodName=*/"getNumShapedOperands",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        auto range = this->getOperation()->getOperands();
+        return getNumInputsAndOutputBuffers() + $_op.getNumInitTensors();
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the range over inputs, output buffers and init tensors.
@@ -414,7 +427,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       /*methodBody=*/"",
      /*defaultImplementation=*/[{
         auto range = this->getOperation()->getOperands();
-        return {range.begin(), range.begin() + getNumInputsAndOutputs()};
+        return {range.begin(), range.begin() + getNumShapedOperands()};
       }]
     >,
     InterfaceMethod<
@@ -621,6 +634,27 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       }]
     >
   ];
+
+  let extraClassDeclaration = [{
+    /// Returns all the operands past the inputs, output_buffers and
+    /// init_tensors operands. Asserts that these operands are value types to
+    /// allow transformations like tiling to just use the values when cloning
+    /// `linalgOp`.
+    SmallVector<Value, 4> getAssumedNonShapedOperands() {
+      unsigned numShapedOperands = getNumInputsAndOutputs();
+      unsigned nExtraOperands =
+          getOperation()->getNumOperands() - numShapedOperands;
+      SmallVector<Value, 4> res;
+      res.reserve(nExtraOperands);
+      for (unsigned i = 0; i < nExtraOperands; ++i) {
+        res.push_back(getOperation()->getOperand(numShapedOperands + i));
+        assert((res.back().getType().isSignlessIntOrIndexOrFloat()
+                || res.back().getType().isa<VectorType>()) &&
+               "expected scalar or vector type");
+      }
+      return res;
+    }
+  }];
 }

 #endif // LINALG_IR_STRUCTURED_OPS_INTERFACE

mlir/include/mlir/Dialect/StandardOps/IR/Ops.h

Lines changed: 25 additions & 0 deletions

@@ -350,6 +350,31 @@ ParseResult parseDimAndSymbolList(OpAsmParser &parser,
 /// ```
 bool canFoldIntoConsumerOp(MemRefCastOp castOp);

+/// Counterpart of `canFoldIntoConsumerOp(MemRefCastOp castOp)` for tensors.
+/// Determines whether TensorCastOp casts to a more dynamic version of the
+/// source tensor. This is useful to fold a tensor_cast into a consuming op and
+/// implement canonicalization patterns for ops in different dialects that may
+/// consume the results of tensor_cast operations. Such foldable tensor_cast
+/// operations are typically inserted as `subtensor` ops and are canonicalized,
+/// to preserve the type compatibility of their uses.
+///
+/// Returns true when all conditions are met:
+/// 1. source and result are ranked tensors with same element type and rank.
+/// 2. the tensor type has more static information than the result
+///
+/// Example:
+/// ```mlir
+///   %1 = tensor_cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+///   %2 = consumer %1 ... : tensor<?x?xf32> ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %2 = consumer %0 ... : tensor<8x16xf32> ...
+/// ```
+bool canFoldIntoConsumerOp(TensorCastOp castOp);
+
 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
 /// comparison predicates.
 bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td

Lines changed: 1 addition & 1 deletion

@@ -3334,7 +3334,7 @@ def TensorCastOp : CastOp<"tensor_cast"> {
     ```
   }];

-  let arguments = (ins AnyTensor);
+  let arguments = (ins AnyTensor:$source);
   let results = (outs AnyTensor);

   let extraClassDeclaration = [{

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 54 additions & 0 deletions

@@ -25,6 +25,7 @@
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Support/LLVM.h"

+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
@@ -1498,12 +1499,65 @@ struct EraseDeadLinalgOp : public RewritePattern {
     return failure();
   }
 };
+
+struct FoldTensorCastOp : public RewritePattern {
+  FoldTensorCastOp(PatternBenefit benefit = 1)
+      : RewritePattern(benefit, MatchAnyOpTypeTag()) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    auto linalgOp = dyn_cast<LinalgOp>(op);
+    if (!linalgOp)
+      return failure();
+
+    // If no operand comes from a TensorCastOp and can be folded then fail.
+    bool hasTensorCastOperand =
+        llvm::any_of(linalgOp.getShapedOperands(), [&](Value v) {
+          if (v.isa<BlockArgument>())
+            return false;
+          auto castOp = v.getDefiningOp<TensorCastOp>();
+          return castOp && canFoldIntoConsumerOp(castOp);
+        });
+    if (!hasTensorCastOperand)
+      return failure();
+
+    SmallVector<Type, 4> newResultTypes;
+    newResultTypes.reserve(op->getNumResults());
+    SmallVector<Value, 4> newOperands;
+    newOperands.reserve(op->getNumOperands());
+    // Inputs may fold.
+    for (Value v : linalgOp.getInputs()) {
+      auto tensorCastOp = v.getDefiningOp<TensorCastOp>();
+      newOperands.push_back(
+          canFoldIntoConsumerOp(tensorCastOp) ? tensorCastOp.source() : v);
+    }
+    // Output buffers are memrefs, they don't fold.
+    newOperands.append(linalgOp.getOutputBuffers().begin(),
+                       linalgOp.getOutputBuffers().end());
+    // Init tensors may fold, in which case the resultType must also change.
+    for (Value v : linalgOp.getInitTensors()) {
+      auto tensorCastOp = v.getDefiningOp<TensorCastOp>();
+      bool fold = canFoldIntoConsumerOp(tensorCastOp);
+      newOperands.push_back(fold ? tensorCastOp.getOperand() : v);
+      newResultTypes.push_back(newOperands.back().getType());
+    }
+    auto extraOperands = linalgOp.getAssumedNonShapedOperands();
+    newOperands.append(extraOperands.begin(), extraOperands.end());
+    // Clone op.
+    Operation *newOp =
+        linalgOp.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
+    rewriter.replaceOp(op, newOp->getResults());
+
+    return success();
+  }
+};
 } // namespace

 #define CANONICALIZERS_AND_FOLDERS(XXX) \
   void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results, \
                                         MLIRContext *context) { \
     results.insert<EraseDeadLinalgOp>(); \
+    results.insert<FoldTensorCastOp>(); \
   } \
 \
   LogicalResult XXX::fold(ArrayRef<Attribute>, \

mlir/lib/Dialect/StandardOps/IR/Ops.cpp

Lines changed: 54 additions & 0 deletions

@@ -3157,6 +3157,60 @@ bool mlir::canFoldIntoConsumerOp(MemRefCastOp castOp) {
   return true;
 }

+/// Counterpart of `canFoldIntoConsumerOp(MemRefCastOp castOp)` for tensors.
+/// Determines whether TensorCastOp casts to a more dynamic version of the
+/// source tensor. This is useful to fold a tensor_cast into a consuming op and
+/// implement canonicalization patterns for ops in different dialects that may
+/// consume the results of tensor_cast operations. Such foldable tensor_cast
+/// operations are typically inserted as `subtensor` ops and are canonicalized,
+/// to preserve the type compatibility of their uses.
+///
+/// Returns true when all conditions are met:
+/// 1. source and result are ranked tensors with same element type and rank.
+/// 2. the tensor type has more static information than the result
+///
+/// Example:
+/// ```mlir
+///   %1 = tensor_cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+///   %2 = consumer %1 ... : tensor<?x?xf32> ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %2 = consumer %0 ... : tensor<8x16xf32> ...
+/// ```
+bool mlir::canFoldIntoConsumerOp(TensorCastOp castOp) {
+  if (!castOp)
+    return false;
+
+  RankedTensorType sourceType =
+      castOp.source().getType().dyn_cast<RankedTensorType>();
+  RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();
+
+  // Requires RankedTensorType.
+  if (!sourceType || !resultType)
+    return false;
+
+  // Requires same elemental type.
+  if (sourceType.getElementType() != resultType.getElementType())
+    return false;
+
+  // Requires same rank.
+  if (sourceType.getRank() != resultType.getRank())
+    return false;
+
+  // If cast is towards more static sizes along any dimension, don't fold.
+  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
+    auto ss = std::get<0>(it), st = std::get<1>(it);
+    if (ss != st)
+      if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st))
+        return false;
+  }
+
+  return true;
+}
+
 namespace {
 /// Pattern to rewrite a subview op with MemRefCast arguments.
 /// This essentially pushes memref_cast past its consuming subview when
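
To make the shape check concrete, here are two illustrative casts (not taken from the commit) evaluated against the rule above: a cast is foldable only when every dimension it changes goes from a static size to a dynamic one.

```mlir
// canFoldIntoConsumerOp returns true: the cast only erases static information.
%1 = tensor_cast %0 : tensor<8x16xf32> to tensor<?x16xf32>

// canFoldIntoConsumerOp returns false: the cast adds static information along
// the first dimension.
%3 = tensor_cast %2 : tensor<?x16xf32> to tensor<8x16xf32>
```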

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 20 additions & 0 deletions

@@ -259,3 +259,23 @@ func @reshape_splat_constant_float64() -> tensor<2x4x2xf64>
 // CHECK: %[[CST:.*]] = constant dense<{{.*}}> : tensor<2x4x2xf64>
 // CHECK-NOT: linalg.tensor_reshape
 // CHECK: return %[[CST]]
+
+// -----
+
+// CHECK-LABEL: func @tensor_cast(
+func @tensor_cast(%a : tensor<3x4xf32>, %b : tensor<4x?xf32>, %c : tensor<3x?xf32>)
+  -> tensor<3x?xf32>
+{
+  %ta = tensor_cast %a : tensor<3x4xf32> to tensor<?x?xf32>
+  %tb = tensor_cast %b : tensor<4x?xf32> to tensor<?x?xf32>
+  %tc = tensor_cast %c : tensor<3x?xf32> to tensor<?x?xf32>
+
+  // CHECK: linalg.matmul ins({{.*}}tensor<3x4xf32>, tensor<4x?xf32>)
+  // CHECK-SAME: init({{.*}}tensor<3x?xf32>) -> tensor<3x?xf32>
+  %0 = linalg.matmul ins(%ta, %tb: tensor<?x?xf32>, tensor<?x?xf32>)
+                    init(%tc: tensor<?x?xf32>) -> tensor<?x?xf32>
+
+  %1 = tensor_cast %0 : tensor<?x?xf32> to tensor<3x?xf32>
+
+  return %1: tensor<3x?xf32>
+}
