 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/InliningUtils.h"
+#include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
@@ -1594,11 +1595,91 @@ struct ForallOpSingleOrZeroIterationDimsFolder
 }
 };
 
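+// Folds `tensor.cast` ops on the shared outputs of an `scf.forall` into the
+// loop: when a cast discards static type information, the loop is rebuilt
+// on the cast's source tensors, and casts back to the original result types
+// are inserted after the loop.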
+struct FoldTensorCastOfOutputIntoForallOp
+    : public OpRewritePattern<scf::ForallOp> {
+  using OpRewritePattern<scf::ForallOp>::OpRewritePattern;
+
+  struct TypeCast {
+    Type srcType;
+    Type dstType;
+  };
+
+  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
+                                PatternRewriter &rewriter) const final {
+    llvm::SmallMapVector<unsigned, TypeCast, 2> tensorCastProducers;
+    llvm::SmallVector<Value> newOutputTensors = forallOp.getOutputs();
+    for (auto en : llvm::enumerate(newOutputTensors)) {
+      auto castOp = en.value().getDefiningOp<tensor::CastOp>();
+      if (!castOp)
+        continue;
+
+      // Only casts that preserve static information, i.e. casts that make
+      // the loop result type "more" static than before, will be folded.
+      if (!tensor::preservesStaticInformation(castOp.getDest().getType(),
+                                              castOp.getSource().getType())) {
+        continue;
+      }
+
+      tensorCastProducers[en.index()] =
+          TypeCast{castOp.getSource().getType(), castOp.getType()};
+      newOutputTensors[en.index()] = castOp.getSource();
+    }
+
+    if (tensorCastProducers.empty())
+      return failure();
+
+    // Create new loop.
+    Location loc = forallOp.getLoc();
+    auto newForallOp = rewriter.create<ForallOp>(
+        loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+        forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(),
+        [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) {
+          auto castBlockArgs =
+              llvm::to_vector(bbArgs.take_back(forallOp->getNumResults()));
+          for (auto [index, cast] : tensorCastProducers) {
+            Value &oldTypeBBArg = castBlockArgs[index];
+            oldTypeBBArg = nestedBuilder.create<tensor::CastOp>(
+                nestedLoc, cast.dstType, oldTypeBBArg);
+          }
+
+          // Move old body into new parallel loop.
+          SmallVector<Value> ivsBlockArgs =
+              llvm::to_vector(bbArgs.take_front(forallOp.getRank()));
+          ivsBlockArgs.append(castBlockArgs);
+          rewriter.mergeBlocks(forallOp.getBody(),
+                               bbArgs.front().getParentBlock(), ivsBlockArgs);
+        });
+
+    // After `mergeBlocks` has run, the destinations in the terminator are
+    // mapped to the tensor.cast old-typed results of the output bbArgs. The
+    // destinations have to be updated to point to the output bbArgs directly.
+    auto terminator = newForallOp.getTerminator();
+    for (auto [yieldingOp, outputBlockArg] :
+         llvm::zip(terminator.getYieldingOps(),
+                   newForallOp.getOutputBlockArguments())) {
+      auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
+      insertSliceOp.getDestMutable().assign(outputBlockArg);
+    }
+
+    // Cast results back to the original types.
+    rewriter.setInsertionPointAfter(newForallOp);
+    SmallVector<Value> castResults = newForallOp.getResults();
+    for (auto &item : tensorCastProducers) {
+      Value &oldTypeResult = castResults[item.first];
+      oldTypeResult = rewriter.create<tensor::CastOp>(loc, item.second.dstType,
+                                                      oldTypeResult);
+    }
+    rewriter.replaceOp(forallOp, castResults);
+    return success();
+  }
+};
+
 } // namespace
 
 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<DimOfForallOp, ForallOpControlOperandsFolder,
+  results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
+              ForallOpControlOperandsFolder,
               ForallOpSingleOrZeroIterationDimsFolder>(context);
 }
 
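For illustration, a minimal sketch of the rewrite this pattern performs,
assuming a single shared output produced by a tensor.cast; value names and
shapes below are hypothetical, not taken from the commit:

  // Before: the shared output is a tensor.cast that discards static shape
  // information (tensor<4xf32> -> tensor<?xf32>).
  %cast = tensor.cast %out : tensor<4xf32> to tensor<?xf32>
  %res = scf.forall (%i) in (4) shared_outs(%o = %cast) -> (tensor<?xf32>) {
    ...
  }

  // After: the loop consumes %out directly and yields the "more static"
  // type; a cast after the loop restores the original type for all users.
  %new = scf.forall (%i) in (4) shared_outs(%o = %out) -> (tensor<4xf32>) {
    ...
  }
  %res = tensor.cast %new : tensor<4xf32> to tensor<?xf32>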