Commit c8117eb

[mlir] Add a pattern to fold tensor.cast into scf.forall.
Differential revision: https://reviews.llvm.org/D146558
1 parent 7949a2a
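In effect, the new canonicalization lets scf.forall absorb a tensor.cast on one of its shared_outs when the cast only erases static shape information. The sketch below is adapted from the test case added in this commit; the exact output is not part of the commit and assumes the compensating casts inserted by the pattern are cleaned up by the other canonicalizations that run alongside it.

Before:

  %out_cast = tensor.cast %out : tensor<2xi32> to tensor<?xi32>
  %result = scf.forall (%i) = (0) to (2) step (1)
      shared_outs (%out_ = %out_cast) -> tensor<?xi32> {
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %cst into %out_[%i] [1] [1]
        : tensor<1xi32> into tensor<?xi32>
    }
  }
  %result_cast = tensor.cast %result : tensor<?xi32> to tensor<2xi32>

After canonicalization (roughly):

  %result = scf.forall (%i) = (0) to (2) step (1)
      shared_outs (%out_ = %out) -> tensor<2xi32> {
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %cst into %out_[%i] [1] [1]
        : tensor<1xi32> into tensor<2xi32>
    }
  }

The loop now carries the static tensor<2xi32> type end to end, and both casts disappear.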

File tree

2 files changed: +131 -1 lines changed

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 82 additions & 1 deletion
@@ -21,6 +21,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
@@ -1594,11 +1595,91 @@ struct ForallOpSingleOrZeroIterationDimsFolder
   }
 };
 
+struct FoldTensorCastOfOutputIntoForallOp
+    : public OpRewritePattern<scf::ForallOp> {
+  using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
+
+  struct TypeCast {
+    Type srcType;
+    Type dstType;
+  };
+
+  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
+                                PatternRewriter &rewriter) const final {
+    llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
+    llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
+    for (auto en : llvm::enumerate(newOutputTensors)) {
+      auto castOp = en.value().getDefiningOp<tensor::CastOp>();
+      if (!castOp)
+        continue;
+
+      // Only casts that preserve static information, i.e. will make the
+      // loop result type "more" static than before, will be folded.
+      if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
+                                              castOp.getSource().getType())) {
+        continue;
+      }
+
+      tensorCastProducers[en.index()] =
+          TypeCast{castOp.getSource().getType(), castOp.getType()};
+      newOutputTensors[en.index()] = castOp.getSource();
+    }
+
+    if (tensorCastProducers.empty())
+      return failure();
+
+    // Create new loop.
+    Location loc = forallOp.getLoc();
+    auto newForallOp = rewriter.create<ForallOp>(
+        loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+        forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
+        [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
+          auto castBlockArgs =
+              llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
+          for (auto [index, cast] : tensorCastProducers) {
+            Value &oldTypeBBArg = castBlockArgs[index];
+            oldTypeBBArg = nestedBuilder.create<tensor::CastOp>(
+                nestedLoc, cast.dstType, oldTypeBBArg);
+          }
+
+          // Move old body into new parallel loop.
+          SmallVector<Value> ivsBlockArgs =
+              llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
+          ivsBlockArgs.append(castBlockArgs);
+          rewriter.mergeBlocks(forallOp.getBody(),
+                               bbArgs.front().getParentBlock(), ivsBlockArgs);
+        });
+
+    // After `mergeBlocks` happened, the destinations in the terminator were
+    // mapped to the tensor.cast old-typed results of the output bbArgs. The
+    // destinations have to be updated to point to the output bbArgs directly.
+    auto terminator = newForallOp.getTerminator();
+    for (auto [yieldingOp, outputBlockArg] :
+         llvm::zip(terminator.getYieldingOps(),
+                   newForallOp.getOutputBlockArguments())) {
+      auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
+      insertSliceOp.getDestMutable().assign(outputBlockArg);
+    }
+
+    // Cast results back to the original types.
+    rewriter.setInsertionPointAfter(newForallOp);
+    SmallVector<Value> castResults = newForallOp.getResults();
+    for (auto &item : tensorCastProducers) {
+      Value &oldTypeResult = castResults[item.first];
+      oldTypeResult = rewriter.create<tensor::CastOp>(loc, item.second.dstType,
+                                                      oldTypeResult);
+    }
+    rewriter.replaceOp(forallOp, castResults);
+    return success();
+  }
+};
+
 } // namespace
 
 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<DimOfForallOp, ForallOpControlOperandsFolder,
+  results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
+              ForallOpControlOperandsFolder,
               ForallOpSingleOrZeroIterationDimsFolder>(context);
 }

mlir/test/Dialect/SCF/canonicalize.mlir

Lines changed: 49 additions & 0 deletions
@@ -1651,3 +1651,52 @@ func.func @remove_empty_forall(%in: tensor<8x8xf32>) -> tensor<8x8xf32> {
 // CHECK: %[[EMPTY:.*]] = tensor.empty
 // CHECK: return %[[EMPTY]]
 
+// -----
+
+func.func @fold_tensor_cast_into_forall(
+    %in: tensor<2xi32>, %out: tensor<2xi32>) -> tensor<2xi32> {
+  %cst = arith.constant dense<[100500]> : tensor<1xi32>
+
+  %out_cast = tensor.cast %out : tensor<2xi32> to tensor<?xi32>
+  %result = scf.forall (%i) = (0) to (2) step (1)
+      shared_outs (%out_ = %out_cast) -> tensor<?xi32> {
+
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %cst into %out_[%i] [1] [1]
+        : tensor<1xi32> into tensor<?xi32>
+    }
+  }
+  %result_cast = tensor.cast %result : tensor<?xi32> to tensor<2xi32>
+  func.return %result_cast : tensor<2xi32>
+}
+// CHECK-LABEL: @fold_tensor_cast_into_forall
+// CHECK-NOT:     tensor.cast
+// CHECK:         parallel_insert_slice
+// CHECK-SAME:      : tensor<1xi32> into tensor<2xi32>
+// CHECK-NOT:     tensor.cast
+
+// -----
+
+func.func @do_not_fold_tensor_cast_from_dynamic_to_static_type_into_forall(
+    %in: tensor<?xi32>, %out: tensor<?xi32>) -> tensor<?xi32> {
+  %cst = arith.constant dense<[100500]> : tensor<1xi32>
+
+  %out_cast = tensor.cast %out : tensor<?xi32> to tensor<2xi32>
+  %result = scf.forall (%i) = (0) to (2) step (1)
+      shared_outs (%out_ = %out_cast) -> tensor<2xi32> {
+
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %cst into %out_[%i] [1] [1]
+        : tensor<1xi32> into tensor<2xi32>
+    }
+  }
+  %result_cast = tensor.cast %result : tensor<2xi32> to tensor<?xi32>
+  func.return %result_cast : tensor<?xi32>
+}
+// CHECK-LABEL: @do_not_fold_tensor_cast_
+// CHECK:         tensor.cast
+// CHECK:         parallel_insert_slice
+// CHECK-SAME:      : tensor<1xi32> into tensor<2xi32>
+// CHECK:         tensor.cast
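As with the rest of canonicalize.mlir, these cases are driven by the RUN line at the top of the file, which is not part of this diff. An approximate invocation, with the flags assumed rather than copied from the file, would be:

  mlir-opt canonicalize.mlir -canonicalize -split-input-file | FileCheck canonicalize.mlir

The "// -----" separators mark -split-input-file boundaries, so each test function above is canonicalized and checked independently.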
