77#include " iree/compiler/Dialect/Flow/IR/FlowOps.h"
88#include " iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
99#include " iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
10+ #include " iree/compiler/DispatchCreation/FusionUtils.h"
1011#include " iree/compiler/DispatchCreation/Passes.h"
1112#include " mlir/Analysis/SliceAnalysis.h"
1213#include " mlir/Analysis/TopologicalSortUtils.h"
@@ -50,9 +51,9 @@ struct FuseHorizontalContractionsPass final
5051// / Structs that captures the ops that are to be fused
5152struct HorizontalFusionGroup {
5253 // Contractions op that are to be fused.
53- SmallVector<linalg::LinalgOp > contractionOps;
54+ SmallVector<Operation * > contractionOps;
5455 // Optional truncate operations that could be following the contraction op.
55- std::optional<SmallVector<linalg::GenericOp >> truncateOps;
56+ std::optional<SmallVector<Operation * >> truncateOps;
5657};
5758
5859// / Helper method to check operations equivalence
@@ -128,11 +129,11 @@ static bool isHorizontalToGroup(Operation *op,
128129}
129130
130131// / Get user of operation that is a truncate operation.
131- static std::optional<linalg::GenericOp >
132+ static std::optional<linalg::LinalgOp >
132133getTruncateOp (Operation *op,
133134 const llvm::SetVector<Operation *> &groupedOperations,
134135 const DominanceInfo &dominanceInfo,
135- std::optional<linalg::GenericOp > seedTruncateOp = std::nullopt ) {
136+ std::optional<linalg::LinalgOp > seedTruncateOp = std::nullopt ) {
136137 if (!op->hasOneUse ()) {
137138 return std::nullopt ;
138139 }
@@ -177,7 +178,7 @@ getTruncateOp(Operation *op,
177178// / the `truncf` on the result.
178179static std::optional<HorizontalFusionGroup> getHorizontalFusionGroupMembers (
179180 linalg::LinalgOp seedOp,
180- const llvm::SmallDenseSet<linalg::LinalgOp > &groupedOperations,
181+ const llvm::SmallDenseSet<Operation * > &groupedOperations,
181182 const DominanceInfo &dominanceInfo, int fusionLimit) {
182183
183184 Value lhs = seedOp->getOperand (0 );
@@ -188,10 +189,10 @@ static std::optional<HorizontalFusionGroup> getHorizontalFusionGroupMembers(
188189 auto outType = cast<RankedTensorType>(out.getType ());
189190
190191 SetVector<Operation *> allOps;
191- SmallVector<linalg::LinalgOp > contractionOps = {seedOp};
192- std::optional<linalg::GenericOp > seedTruncOp =
192+ SmallVector<Operation * > contractionOps = {seedOp};
193+ std::optional<linalg::LinalgOp > seedTruncOp =
193194 getTruncateOp (seedOp, allOps, dominanceInfo);
194- std::optional<SmallVector<linalg::GenericOp >> truncateOps;
195+ std::optional<SmallVector<Operation * >> truncateOps;
195196 if (seedTruncOp) {
196197 truncateOps = {seedTruncOp.value ()};
197198 }
@@ -254,7 +255,7 @@ static std::optional<HorizontalFusionGroup> getHorizontalFusionGroupMembers(
254255 continue ;
255256 }
256257
257- std::optional<linalg::GenericOp > userTruncOp =
258+ std::optional<linalg::LinalgOp > userTruncOp =
258259 getTruncateOp (linalgUser, allOps, dominanceInfo, seedTruncOp);
259260
260261 // If there are truncate ops to fuse and current contraction op
@@ -356,38 +357,140 @@ static AffineMap getConcatenatedIndexingMap(RewriterBase &rewriter,
356357 return newIndexingMap.insertResult (rewriter.getAffineDimExpr (0 ), 0 );
357358}
358359
359- // / During horizontal fusion, there might be operands of the fused operations
360- // / whose definitions are interspersed between the fused operations. For groups
361- // / chosen to fuse horizontally, such operations can be moved before the
362- // / seed contraction operation (where the fused operation is generated).
363- template <typename T>
364- static LogicalResult
365- moveOperandDefs (RewriterBase &rewriter, ArrayRef<T> operations,
366- Operation *insertionPoint, DominanceInfo &dominanceInfo,
367- ArrayRef<linalg::LinalgOp> ignoreOperations = {}) {
368- BackwardSliceOptions options;
369- llvm::DenseSet<Operation *> ignoreOperationsSet;
370- ignoreOperationsSet.insert (ignoreOperations.begin (), ignoreOperations.end ());
371- options.filter = [&](Operation *op) {
372- return !dominanceInfo.properlyDominates (op, insertionPoint) &&
373- !ignoreOperationsSet.contains (op);
374- };
375- // Set inclusive to true cause the slice is computed from the operand, and
376- // we want to include the defining op (which is the point here)
377- options.inclusive = true ;
360+ template <typename V, typename R>
361+ static void appendRange (V &vector, R &&range) {
362+ vector.append (range.begin (), range.end ());
363+ }
378364
379- llvm::SetVector<Operation *> slice;
380- for (auto op : operations) {
381- for (auto operand : op->getOperands ()) {
382- getBackwardSlice (operand, &slice, options);
365+ static FailureOr<linalg::GenericOp>
366+ fuseHorizontally (RewriterBase &rewriter, Location loc,
367+ ArrayRef<Operation *> ops) {
368+
369+ SmallVector<linalg::LinalgOp> linalgOps;
370+ for (auto op : ops) {
371+ auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
372+ if (!linalgOp) {
373+ return op->emitOpError (" expected linalg op for horizontal fusion" );
383374 }
375+ linalgOps.push_back (linalgOp);
384376 }
385377
386- mlir::topologicalSort (slice);
387- for (auto op : slice) {
388- rewriter.moveOpBefore (op, insertionPoint);
378+ SmallVector<Value> fusedIns;
379+ SmallVector<Value> fusedOuts;
380+ SmallVector<Type> fusedResultTypes;
381+ SmallVector<AffineMap> fusedInsIndexingMaps;
382+ SmallVector<AffineMap> fusedOutsIndexingMaps;
383+
384+ for (auto linalgOp : linalgOps) {
385+ fusedIns.append (linalgOp.getDpsInputs ());
386+ appendRange (fusedOuts, llvm::map_range (linalgOp.getDpsInitsMutable (),
387+ [](OpOperand &operand) {
388+ return operand.get ();
389+ }));
390+ fusedResultTypes.append (linalgOp->result_type_begin (),
391+ linalgOp->result_type_end ());
392+ appendRange (
393+ fusedInsIndexingMaps,
394+ llvm::map_range (linalgOp.getIndexingMaps ().getValue ().take_front (
395+ linalgOp.getNumDpsInputs ()),
396+ [](Attribute attr) {
397+ return cast<AffineMapAttr>(attr).getValue ();
398+ }));
399+ appendRange (
400+ fusedOutsIndexingMaps,
401+ llvm::map_range (linalgOp.getIndexingMaps ().getValue ().drop_front (
402+ linalgOp.getNumDpsInputs ()),
403+ [](Attribute attr) {
404+ return cast<AffineMapAttr>(attr).getValue ();
405+ }));
389406 }
390- return success ();
407+
408+ SmallVector<utils::IteratorType> fusedIteratorTypes =
409+ linalgOps.front ().getIteratorTypesArray ();
410+ SmallVector<AffineMap> fusedIndexingMaps = std::move (fusedInsIndexingMaps);
411+ fusedIndexingMaps.append (fusedOutsIndexingMaps);
412+ auto fusedOp = rewriter.create <linalg::GenericOp>(
413+ loc, fusedResultTypes, fusedIns, fusedOuts, fusedIndexingMaps,
414+ fusedIteratorTypes, [](OpBuilder &, Location, ValueRange) {});
415+
416+ Block *fusedBody = fusedOp.getBlock ();
417+ auto insIndex = 0 ;
418+ auto outsIndex = fusedOp.getNumDpsInputs ();
419+ SmallVector<Value> yieldVals;
420+ for (auto op : linalgOps) {
421+ auto linalgOp = cast<linalg::LinalgOp>(op);
422+ Block *body = linalgOp.getBlock ();
423+ SmallVector<Value> replacements = llvm::map_to_vector (
424+ fusedBody->getArguments ().slice (insIndex, linalgOp.getNumDpsInputs ()),
425+ [](BlockArgument arg) -> Value { return arg; });
426+ appendRange (
427+ replacements,
428+ llvm::map_range (fusedBody->getArguments ().slice (
429+ outsIndex, linalgOp.getNumDpsInits ()),
430+ [](BlockArgument arg) -> Value { return arg; }));
431+
432+ rewriter.mergeBlocks (body, fusedBody, replacements);
433+ insIndex += linalgOp.getNumDpsInputs ();
434+ outsIndex += linalgOp.getNumDpsInits ();
435+
436+ auto yieldOp = cast<linalg::YieldOp>(fusedBody->getTerminator ());
437+ yieldVals.append (yieldOp->operand_begin (), yieldOp->operand_end ());
438+ rewriter.eraseOp (yieldOp);
439+ }
440+ OpBuilder::InsertionGuard g (rewriter);
441+ rewriter.setInsertionPointToEnd (fusedBody);
442+ rewriter.create <linalg::YieldOp>(loc, yieldVals);
443+
444+ auto resultsIndex = 0 ;
445+ for (auto linalgOp : linalgOps) {
446+ rewriter.replaceOp (linalgOp, fusedOp->getResults ().slice (
447+ resultsIndex, linalgOp->getNumResults ()));
448+ resultsIndex += linalgOp->getNumResults ();
449+ }
450+
451+ return fusedOp;
452+ }
453+
454+ static FailureOr<SmallVector<linalg::GenericOp>>
455+ fuseGroup2 (RewriterBase &rewriter, HorizontalFusionGroup &fusionGroup,
456+ DominanceInfo &dominanceInfo) {
457+ if (!llvm::all_of (fusionGroup.contractionOps ,
458+ [](Operation *op) { return isa<linalg::LinalgOp>(op); })) {
459+ return failure ();
460+ }
461+ linalg::LinalgOp baseContractOp =
462+ cast<linalg::LinalgOp>(fusionGroup.contractionOps .front ());
463+ Location loc = baseContractOp.getLoc ();
464+ OpBuilder::InsertionGuard g (rewriter);
465+ rewriter.setInsertionPoint (baseContractOp);
466+ SmallVector<linalg::GenericOp> fusedOperations;
467+
468+ if (failed (moveOperandDefs (rewriter, fusionGroup.contractionOps ,
469+ baseContractOp, dominanceInfo))) {
470+ return baseContractOp.emitOpError (" failed to re-order operand definitions" );
471+ }
472+
473+ FailureOr<linalg::GenericOp> fusedContractionOp =
474+ fuseHorizontally (rewriter, loc, fusionGroup.contractionOps );
475+ if (failed (fusedContractionOp)) {
476+ return baseContractOp.emitOpError (
477+ " failed to fuse contraction ops horizontally" );
478+ }
479+ fusedOperations.push_back (fusedContractionOp.value ());
480+
481+ // if (!fusionGroup.truncateOps) {
482+ // return fusedOperations;
483+ // }
484+
485+ // rewriter.setInsertionPoint(fusionGroup.truncateOps->front());
486+ // FailureOr<linalg::GenericOp> fusedTruncateOp =
487+ // fuseHorizontally(rewriter, loc, fusionGroup.truncateOps.value());
488+ // if (failed(fusedTruncateOp)) {
489+ // return baseContractOp.emitOpError(
490+ // "failed to fuse truncation ops horizontally");
491+ // }
492+ // fusedOperations.push_back(fusedTruncateOp.value());
493+ return fusedOperations;
391494}
392495
393496// / On finding this pattern
@@ -428,22 +531,23 @@ moveOperandDefs(RewriterBase &rewriter, ArrayRef<T> operations,
428531static LogicalResult fuseGroup (RewriterBase &rewriter,
429532 HorizontalFusionGroup &fusionGroup,
430533 DominanceInfo &dominanceInfo) {
431- linalg::LinalgOp baseContractOp = fusionGroup.contractionOps .front ();
534+ linalg::LinalgOp baseContractOp =
535+ cast<linalg::LinalgOp>(fusionGroup.contractionOps .front ());
432536 Location loc = baseContractOp.getLoc ();
433537 OpBuilder::InsertionGuard g (rewriter);
434538 rewriter.setInsertionPoint (baseContractOp);
435539
436- if (failed (moveOperandDefs (
437- rewriter, ArrayRef<linalg::LinalgOp>(fusionGroup.contractionOps ),
438- baseContractOp, dominanceInfo))) {
540+ if (failed (moveOperandDefs (rewriter, fusionGroup.contractionOps ,
541+ baseContractOp, dominanceInfo))) {
439542 return baseContractOp.emitOpError (" failed to re-order operand definitions" );
440543 }
441544
442545 SmallVector<Value> rhsValues;
443546 SmallVector<Value> initValues;
444547 for (auto op : fusionGroup.contractionOps ) {
445- Value rhs = op.getDpsInputOperand (1 )->get ();
446- Value init = op.getDpsInitOperand (0 )->get ();
548+ auto linalgOp = cast<linalg::LinalgOp>(op);
549+ Value rhs = linalgOp.getDpsInputOperand (1 )->get ();
550+ Value init = linalgOp.getDpsInitOperand (0 )->get ();
447551 rhsValues.push_back (rhs);
448552 initValues.push_back (init);
449553 }
@@ -485,15 +589,15 @@ static LogicalResult fuseGroup(RewriterBase &rewriter,
485589 if (fusionGroup.truncateOps ) {
486590 SmallVector<Value> newTruncOperands;
487591 SmallVector<AffineMap> newTruncIndexingMaps;
488- linalg::GenericOp baseTruncOp = fusionGroup.truncateOps ->front ();
592+ linalg::LinalgOp baseTruncOp =
593+ cast<linalg::LinalgOp>(fusionGroup.truncateOps ->front ());
489594 SmallVector<AffineMap> baseTruncOpIndexingMaps =
490595 baseTruncOp.getIndexingMapsArray ();
491596
492597 rewriter.setInsertionPoint (baseTruncOp);
493- if (failed (moveOperandDefs (
494- rewriter,
495- ArrayRef<linalg::GenericOp>(fusionGroup.truncateOps .value ()),
496- baseTruncOp, dominanceInfo, fusionGroup.contractionOps ))) {
598+ if (failed (moveOperandDefs (rewriter, fusionGroup.truncateOps .value (),
599+ baseTruncOp, dominanceInfo,
600+ fusionGroup.contractionOps ))) {
497601 return baseTruncOp.emitOpError (
498602 " failed to move operand defs for truncate operations" );
499603 }
@@ -503,7 +607,7 @@ static LogicalResult fuseGroup(RewriterBase &rewriter,
503607 // Collect all the operands for the trunc operation.
504608 SmallVector<Value> truncOperands;
505609 for (auto truncOp : fusionGroup.truncateOps .value ()) {
506- truncOperands.push_back (truncOp. getOperand (operandIndex));
610+ truncOperands.push_back (truncOp-> getOperand (operandIndex));
507611 }
508612
509613 // Three cases to handle here.
@@ -539,7 +643,7 @@ static LogicalResult fuseGroup(RewriterBase &rewriter,
539643
540644 // Insert truncate operator.
541645 auto baseTruncType =
542- cast<RankedTensorType>(baseTruncOp. getResult (0 ).getType ());
646+ cast<RankedTensorType>(baseTruncOp-> getResult (0 ).getType ());
543647 SmallVector<int64_t > newTruncShape = {
544648 static_cast <int64_t >(rhsValues.size ())};
545649 newTruncShape.append (baseTruncType.getShape ().begin (),
@@ -610,7 +714,7 @@ void FuseHorizontalContractionsPass::runOnOperation() {
610714 DominanceInfo dominanceInfo (getOperation ());
611715
612716 SmallVector<HorizontalFusionGroup> horizontalFusionGroups;
613- llvm::SmallDenseSet<linalg::LinalgOp > groupedOperations;
717+ llvm::SmallDenseSet<Operation * > groupedOperations;
614718
615719 getOperation ()->walk ([&](linalg::LinalgOp linalgOp) {
616720 if (!isEmptyFillContractionDAGRootOp (linalgOp)) {
@@ -653,38 +757,19 @@ void FuseHorizontalContractionsPass::runOnOperation() {
653757
654758 IRRewriter rewriter (context);
655759 for (auto &fusionGroup : horizontalFusionGroups) {
656- if (failed (fuseGroup (rewriter, fusionGroup, dominanceInfo))) {
760+ FailureOr<SmallVector<linalg::GenericOp>> fusedOperations =
761+ fuseGroup2 (rewriter, fusionGroup, dominanceInfo);
762+ if (failed (fusedOperations)) {
657763 return signalPassFailure ();
658764 }
659- }
660-
661- {
662- RewritePatternSet foldReshapePatterns (context);
663- tensor::populateFoldTensorEmptyPatterns (foldReshapePatterns);
664- linalg::FillOp::getCanonicalizationPatterns (foldReshapePatterns, context);
665- if (failed (applyPatternsGreedily (getOperation (),
666- std::move (foldReshapePatterns)))) {
667- getOperation ()->emitOpError (" failed during reshape folding patterns" );
668- return signalPassFailure ();
669- }
670-
671- RewritePatternSet foldPatterns (context);
672- tensor::populateFoldTensorEmptyPatterns (foldPatterns);
673- linalg::FillOp::getCanonicalizationPatterns (foldPatterns, context);
674- if (failed (
675- applyPatternsGreedily (getOperation (), std::move (foldPatterns)))) {
676- getOperation ()->emitOpError (" failed to fold empty/fill with concats" );
677- return signalPassFailure ();
765+ for (auto fusedOp : fusedOperations.value ()) {
766+ rewriter.setInsertionPoint (fusedOp);
767+ if (failed (linalg::deduplicateOperandsAndRemoveDeadResults (
768+ rewriter, fusedOp, /* removeOutputs=*/ false ))) {
769+ fusedOp->emitOpError (" failed to remove duplicate operands" );
770+ return signalPassFailure ();
771+ }
678772 }
679773 }
680-
681- // Note: Currently these patterns are required due to early lowering of
682- // tensor.concat. When we choose to move the lowering of tensor.concat later,
683- // these patterns should be dropped.
684- RewritePatternSet patterns (context);
685- tensor::populateDecomposeTensorConcatPatterns (patterns);
686- if (failed (applyPatternsGreedily (getOperation (), std::move (patterns)))) {
687- return signalPassFailure ();
688- }
689774}
690775} // namespace mlir::iree_compiler::DispatchCreation
0 commit comments