Skip to content

Commit 15c9cea

Browse files
[DispatchCreation] Rework horizontal fusion to not create concats.
Signed-off-by: MaheshRavishankar <[email protected]>
1 parent e09cf66 commit 15c9cea

File tree

1 file changed: +164 −79 lines changed

compiler/src/iree/compiler/DispatchCreation/FuseHorizontalContractions.cpp

Lines changed: 164 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
88
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
99
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
10+
#include "iree/compiler/DispatchCreation/FusionUtils.h"
1011
#include "iree/compiler/DispatchCreation/Passes.h"
1112
#include "mlir/Analysis/SliceAnalysis.h"
1213
#include "mlir/Analysis/TopologicalSortUtils.h"
@@ -50,9 +51,9 @@ struct FuseHorizontalContractionsPass final
5051
/// Structs that captures the ops that are to be fused
5152
struct HorizontalFusionGroup {
5253
// Contractions op that are to be fused.
53-
SmallVector<linalg::LinalgOp> contractionOps;
54+
SmallVector<Operation *> contractionOps;
5455
// Optional truncate operations that could be following the contraction op.
55-
std::optional<SmallVector<linalg::GenericOp>> truncateOps;
56+
std::optional<SmallVector<Operation *>> truncateOps;
5657
};
5758

5859
/// Helper method to check operations equivalence
@@ -128,11 +129,11 @@ static bool isHorizontalToGroup(Operation *op,
128129
}
129130

130131
/// Get user of operation that is a truncate operation.
131-
static std::optional<linalg::GenericOp>
132+
static std::optional<linalg::LinalgOp>
132133
getTruncateOp(Operation *op,
133134
const llvm::SetVector<Operation *> &groupedOperations,
134135
const DominanceInfo &dominanceInfo,
135-
std::optional<linalg::GenericOp> seedTruncateOp = std::nullopt) {
136+
std::optional<linalg::LinalgOp> seedTruncateOp = std::nullopt) {
136137
if (!op->hasOneUse()) {
137138
return std::nullopt;
138139
}
@@ -177,7 +178,7 @@ getTruncateOp(Operation *op,
177178
/// the `truncf` on the result.
178179
static std::optional<HorizontalFusionGroup> getHorizontalFusionGroupMembers(
179180
linalg::LinalgOp seedOp,
180-
const llvm::SmallDenseSet<linalg::LinalgOp> &groupedOperations,
181+
const llvm::SmallDenseSet<Operation *> &groupedOperations,
181182
const DominanceInfo &dominanceInfo, int fusionLimit) {
182183

183184
Value lhs = seedOp->getOperand(0);
@@ -188,10 +189,10 @@ static std::optional<HorizontalFusionGroup> getHorizontalFusionGroupMembers(
188189
auto outType = cast<RankedTensorType>(out.getType());
189190

190191
SetVector<Operation *> allOps;
191-
SmallVector<linalg::LinalgOp> contractionOps = {seedOp};
192-
std::optional<linalg::GenericOp> seedTruncOp =
192+
SmallVector<Operation *> contractionOps = {seedOp};
193+
std::optional<linalg::LinalgOp> seedTruncOp =
193194
getTruncateOp(seedOp, allOps, dominanceInfo);
194-
std::optional<SmallVector<linalg::GenericOp>> truncateOps;
195+
std::optional<SmallVector<Operation *>> truncateOps;
195196
if (seedTruncOp) {
196197
truncateOps = {seedTruncOp.value()};
197198
}
@@ -254,7 +255,7 @@ static std::optional<HorizontalFusionGroup> getHorizontalFusionGroupMembers(
254255
continue;
255256
}
256257

257-
std::optional<linalg::GenericOp> userTruncOp =
258+
std::optional<linalg::LinalgOp> userTruncOp =
258259
getTruncateOp(linalgUser, allOps, dominanceInfo, seedTruncOp);
259260

260261
// If there are truncate ops to fuse and current contraction op
@@ -356,38 +357,140 @@ static AffineMap getConcatenatedIndexingMap(RewriterBase &rewriter,
356357
return newIndexingMap.insertResult(rewriter.getAffineDimExpr(0), 0);
357358
}
358359

359-
/// During horizontal fusion, there might be operands of the fused operations
360-
/// whose definitions are interspersed between the fused operations. For groups
361-
/// chosen to fuse horizontally, such operations can be moved before the
362-
/// seed contraction operation (where the fused operation is generated).
363-
template <typename T>
364-
static LogicalResult
365-
moveOperandDefs(RewriterBase &rewriter, ArrayRef<T> operations,
366-
Operation *insertionPoint, DominanceInfo &dominanceInfo,
367-
ArrayRef<linalg::LinalgOp> ignoreOperations = {}) {
368-
BackwardSliceOptions options;
369-
llvm::DenseSet<Operation *> ignoreOperationsSet;
370-
ignoreOperationsSet.insert(ignoreOperations.begin(), ignoreOperations.end());
371-
options.filter = [&](Operation *op) {
372-
return !dominanceInfo.properlyDominates(op, insertionPoint) &&
373-
!ignoreOperationsSet.contains(op);
374-
};
375-
// Set inclusive to true cause the slice is computed from the operand, and
376-
// we want to include the defining op (which is the point here)
377-
options.inclusive = true;
/// Copies every element of `range` onto the end of `vector`. Works with any
/// container exposing an `append(first, last)` member (e.g. llvm::SmallVector
/// or std::string); the range only needs `begin()`/`end()`.
template <typename ContainerT, typename RangeT>
static void appendRange(ContainerT &vector, RangeT &&range) {
  auto first = range.begin();
  auto last = range.end();
  vector.append(first, last);
}
378364

379-
llvm::SetVector<Operation *> slice;
380-
for (auto op : operations) {
381-
for (auto operand : op->getOperands()) {
382-
getBackwardSlice(operand, &slice, options);
365+
static FailureOr<linalg::GenericOp>
366+
fuseHorizontally(RewriterBase &rewriter, Location loc,
367+
ArrayRef<Operation *> ops) {
368+
369+
SmallVector<linalg::LinalgOp> linalgOps;
370+
for (auto op : ops) {
371+
auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
372+
if (!linalgOp) {
373+
return op->emitOpError("expected linalg op for horizontal fusion");
383374
}
375+
linalgOps.push_back(linalgOp);
384376
}
385377

386-
mlir::topologicalSort(slice);
387-
for (auto op : slice) {
388-
rewriter.moveOpBefore(op, insertionPoint);
378+
SmallVector<Value> fusedIns;
379+
SmallVector<Value> fusedOuts;
380+
SmallVector<Type> fusedResultTypes;
381+
SmallVector<AffineMap> fusedInsIndexingMaps;
382+
SmallVector<AffineMap> fusedOutsIndexingMaps;
383+
384+
for (auto linalgOp : linalgOps) {
385+
fusedIns.append(linalgOp.getDpsInputs());
386+
appendRange(fusedOuts, llvm::map_range(linalgOp.getDpsInitsMutable(),
387+
[](OpOperand &operand) {
388+
return operand.get();
389+
}));
390+
fusedResultTypes.append(linalgOp->result_type_begin(),
391+
linalgOp->result_type_end());
392+
appendRange(
393+
fusedInsIndexingMaps,
394+
llvm::map_range(linalgOp.getIndexingMaps().getValue().take_front(
395+
linalgOp.getNumDpsInputs()),
396+
[](Attribute attr) {
397+
return cast<AffineMapAttr>(attr).getValue();
398+
}));
399+
appendRange(
400+
fusedOutsIndexingMaps,
401+
llvm::map_range(linalgOp.getIndexingMaps().getValue().drop_front(
402+
linalgOp.getNumDpsInputs()),
403+
[](Attribute attr) {
404+
return cast<AffineMapAttr>(attr).getValue();
405+
}));
389406
}
390-
return success();
407+
408+
SmallVector<utils::IteratorType> fusedIteratorTypes =
409+
linalgOps.front().getIteratorTypesArray();
410+
SmallVector<AffineMap> fusedIndexingMaps = std::move(fusedInsIndexingMaps);
411+
fusedIndexingMaps.append(fusedOutsIndexingMaps);
412+
auto fusedOp = rewriter.create<linalg::GenericOp>(
413+
loc, fusedResultTypes, fusedIns, fusedOuts, fusedIndexingMaps,
414+
fusedIteratorTypes, [](OpBuilder &, Location, ValueRange) {});
415+
416+
Block *fusedBody = fusedOp.getBlock();
417+
auto insIndex = 0;
418+
auto outsIndex = fusedOp.getNumDpsInputs();
419+
SmallVector<Value> yieldVals;
420+
for (auto op : linalgOps) {
421+
auto linalgOp = cast<linalg::LinalgOp>(op);
422+
Block *body = linalgOp.getBlock();
423+
SmallVector<Value> replacements = llvm::map_to_vector(
424+
fusedBody->getArguments().slice(insIndex, linalgOp.getNumDpsInputs()),
425+
[](BlockArgument arg) -> Value { return arg; });
426+
appendRange(
427+
replacements,
428+
llvm::map_range(fusedBody->getArguments().slice(
429+
outsIndex, linalgOp.getNumDpsInits()),
430+
[](BlockArgument arg) -> Value { return arg; }));
431+
432+
rewriter.mergeBlocks(body, fusedBody, replacements);
433+
insIndex += linalgOp.getNumDpsInputs();
434+
outsIndex += linalgOp.getNumDpsInits();
435+
436+
auto yieldOp = cast<linalg::YieldOp>(fusedBody->getTerminator());
437+
yieldVals.append(yieldOp->operand_begin(), yieldOp->operand_end());
438+
rewriter.eraseOp(yieldOp);
439+
}
440+
OpBuilder::InsertionGuard g(rewriter);
441+
rewriter.setInsertionPointToEnd(fusedBody);
442+
rewriter.create<linalg::YieldOp>(loc, yieldVals);
443+
444+
auto resultsIndex = 0;
445+
for (auto linalgOp : linalgOps) {
446+
rewriter.replaceOp(linalgOp, fusedOp->getResults().slice(
447+
resultsIndex, linalgOp->getNumResults()));
448+
resultsIndex += linalgOp->getNumResults();
449+
}
450+
451+
return fusedOp;
452+
}
453+
454+
/// Horizontally fuses all ops in `fusionGroup`: first the contraction ops,
/// then (when present) their trailing truncate ops. Returns the list of fused
/// `linalg.generic` ops created (contraction first, then truncate), or
/// failure if the group cannot be fused.
static FailureOr<SmallVector<linalg::GenericOp>>
fuseGroup2(RewriterBase &rewriter, HorizontalFusionGroup &fusionGroup,
           DominanceInfo &dominanceInfo) {
  if (!llvm::all_of(fusionGroup.contractionOps,
                    [](Operation *op) { return isa<linalg::LinalgOp>(op); })) {
    return failure();
  }
  linalg::LinalgOp baseContractOp =
      cast<linalg::LinalgOp>(fusionGroup.contractionOps.front());
  Location loc = baseContractOp.getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(baseContractOp);
  SmallVector<linalg::GenericOp> fusedOperations;

  // Operand definitions of the grouped ops may be interspersed between them;
  // hoist them above the seed contraction where the fused op is created.
  if (failed(moveOperandDefs(rewriter, fusionGroup.contractionOps,
                             baseContractOp, dominanceInfo))) {
    return baseContractOp.emitOpError("failed to re-order operand definitions");
  }

  FailureOr<linalg::GenericOp> fusedContractionOp =
      fuseHorizontally(rewriter, loc, fusionGroup.contractionOps);
  if (failed(fusedContractionOp)) {
    return baseContractOp.emitOpError(
        "failed to fuse contraction ops horizontally");
  }
  fusedOperations.push_back(fusedContractionOp.value());

  if (!fusionGroup.truncateOps) {
    return fusedOperations;
  }

  // Fuse the trailing truncate ops as well. This handling was left commented
  // out, which silently dropped truncate fusion for groups that collected
  // truncateOps; restore it. NOTE(review): unlike the legacy fuseGroup path,
  // no moveOperandDefs is performed for the truncate operands here — confirm
  // that the truncate operands already dominate the insertion point.
  rewriter.setInsertionPoint(fusionGroup.truncateOps->front());
  FailureOr<linalg::GenericOp> fusedTruncateOp =
      fuseHorizontally(rewriter, loc, fusionGroup.truncateOps.value());
  if (failed(fusedTruncateOp)) {
    return baseContractOp.emitOpError(
        "failed to fuse truncation ops horizontally");
  }
  fusedOperations.push_back(fusedTruncateOp.value());
  return fusedOperations;
}
392495

393496
/// On finding this pattern
@@ -428,22 +531,23 @@ moveOperandDefs(RewriterBase &rewriter, ArrayRef<T> operations,
428531
static LogicalResult fuseGroup(RewriterBase &rewriter,
429532
HorizontalFusionGroup &fusionGroup,
430533
DominanceInfo &dominanceInfo) {
431-
linalg::LinalgOp baseContractOp = fusionGroup.contractionOps.front();
534+
linalg::LinalgOp baseContractOp =
535+
cast<linalg::LinalgOp>(fusionGroup.contractionOps.front());
432536
Location loc = baseContractOp.getLoc();
433537
OpBuilder::InsertionGuard g(rewriter);
434538
rewriter.setInsertionPoint(baseContractOp);
435539

436-
if (failed(moveOperandDefs(
437-
rewriter, ArrayRef<linalg::LinalgOp>(fusionGroup.contractionOps),
438-
baseContractOp, dominanceInfo))) {
540+
if (failed(moveOperandDefs(rewriter, fusionGroup.contractionOps,
541+
baseContractOp, dominanceInfo))) {
439542
return baseContractOp.emitOpError("failed to re-order operand definitions");
440543
}
441544

442545
SmallVector<Value> rhsValues;
443546
SmallVector<Value> initValues;
444547
for (auto op : fusionGroup.contractionOps) {
445-
Value rhs = op.getDpsInputOperand(1)->get();
446-
Value init = op.getDpsInitOperand(0)->get();
548+
auto linalgOp = cast<linalg::LinalgOp>(op);
549+
Value rhs = linalgOp.getDpsInputOperand(1)->get();
550+
Value init = linalgOp.getDpsInitOperand(0)->get();
447551
rhsValues.push_back(rhs);
448552
initValues.push_back(init);
449553
}
@@ -485,15 +589,15 @@ static LogicalResult fuseGroup(RewriterBase &rewriter,
485589
if (fusionGroup.truncateOps) {
486590
SmallVector<Value> newTruncOperands;
487591
SmallVector<AffineMap> newTruncIndexingMaps;
488-
linalg::GenericOp baseTruncOp = fusionGroup.truncateOps->front();
592+
linalg::LinalgOp baseTruncOp =
593+
cast<linalg::LinalgOp>(fusionGroup.truncateOps->front());
489594
SmallVector<AffineMap> baseTruncOpIndexingMaps =
490595
baseTruncOp.getIndexingMapsArray();
491596

492597
rewriter.setInsertionPoint(baseTruncOp);
493-
if (failed(moveOperandDefs(
494-
rewriter,
495-
ArrayRef<linalg::GenericOp>(fusionGroup.truncateOps.value()),
496-
baseTruncOp, dominanceInfo, fusionGroup.contractionOps))) {
598+
if (failed(moveOperandDefs(rewriter, fusionGroup.truncateOps.value(),
599+
baseTruncOp, dominanceInfo,
600+
fusionGroup.contractionOps))) {
497601
return baseTruncOp.emitOpError(
498602
"failed to move operand defs for truncate operations");
499603
}
@@ -503,7 +607,7 @@ static LogicalResult fuseGroup(RewriterBase &rewriter,
503607
// Collect all the operands for the trunc operation.
504608
SmallVector<Value> truncOperands;
505609
for (auto truncOp : fusionGroup.truncateOps.value()) {
506-
truncOperands.push_back(truncOp.getOperand(operandIndex));
610+
truncOperands.push_back(truncOp->getOperand(operandIndex));
507611
}
508612

509613
// Three cases to handle here.
@@ -539,7 +643,7 @@ static LogicalResult fuseGroup(RewriterBase &rewriter,
539643

540644
// Insert truncate operator.
541645
auto baseTruncType =
542-
cast<RankedTensorType>(baseTruncOp.getResult(0).getType());
646+
cast<RankedTensorType>(baseTruncOp->getResult(0).getType());
543647
SmallVector<int64_t> newTruncShape = {
544648
static_cast<int64_t>(rhsValues.size())};
545649
newTruncShape.append(baseTruncType.getShape().begin(),
@@ -610,7 +714,7 @@ void FuseHorizontalContractionsPass::runOnOperation() {
610714
DominanceInfo dominanceInfo(getOperation());
611715

612716
SmallVector<HorizontalFusionGroup> horizontalFusionGroups;
613-
llvm::SmallDenseSet<linalg::LinalgOp> groupedOperations;
717+
llvm::SmallDenseSet<Operation *> groupedOperations;
614718

615719
getOperation()->walk([&](linalg::LinalgOp linalgOp) {
616720
if (!isEmptyFillContractionDAGRootOp(linalgOp)) {
@@ -653,38 +757,19 @@ void FuseHorizontalContractionsPass::runOnOperation() {
653757

654758
IRRewriter rewriter(context);
655759
for (auto &fusionGroup : horizontalFusionGroups) {
656-
if (failed(fuseGroup(rewriter, fusionGroup, dominanceInfo))) {
760+
FailureOr<SmallVector<linalg::GenericOp>> fusedOperations =
761+
fuseGroup2(rewriter, fusionGroup, dominanceInfo);
762+
if (failed(fusedOperations)) {
657763
return signalPassFailure();
658764
}
659-
}
660-
661-
{
662-
RewritePatternSet foldReshapePatterns(context);
663-
tensor::populateFoldTensorEmptyPatterns(foldReshapePatterns);
664-
linalg::FillOp::getCanonicalizationPatterns(foldReshapePatterns, context);
665-
if (failed(applyPatternsGreedily(getOperation(),
666-
std::move(foldReshapePatterns)))) {
667-
getOperation()->emitOpError("failed during reshape folding patterns");
668-
return signalPassFailure();
669-
}
670-
671-
RewritePatternSet foldPatterns(context);
672-
tensor::populateFoldTensorEmptyPatterns(foldPatterns);
673-
linalg::FillOp::getCanonicalizationPatterns(foldPatterns, context);
674-
if (failed(
675-
applyPatternsGreedily(getOperation(), std::move(foldPatterns)))) {
676-
getOperation()->emitOpError("failed to fold empty/fill with concats");
677-
return signalPassFailure();
765+
for (auto fusedOp : fusedOperations.value()) {
766+
rewriter.setInsertionPoint(fusedOp);
767+
if (failed(linalg::deduplicateOperandsAndRemoveDeadResults(
768+
rewriter, fusedOp, /*removeOutputs=*/false))) {
769+
fusedOp->emitOpError("failed to remove duplicate operands");
770+
return signalPassFailure();
771+
}
678772
}
679773
}
680-
681-
// Note: Currently these patterns are required due to early lowering of
682-
// tensor.concat. When we choose to move the lowering of tensor.concat later,
683-
// these patterns should be dropped.
684-
RewritePatternSet patterns(context);
685-
tensor::populateDecomposeTensorConcatPatterns(patterns);
686-
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
687-
return signalPassFailure();
688-
}
689774
}
690775
} // namespace mlir::iree_compiler::DispatchCreation

0 commit comments

Comments
 (0)