1515#include " mlir/Dialect/Affine/Analysis/AffineStructures.h"
1616#include " mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
1717#include " mlir/Dialect/Affine/Analysis/Utils.h"
18- #include " mlir/Dialect/Affine/IR/AffineOps.h"
1918#include " mlir/Dialect/Affine/LoopFusionUtils.h"
2019#include " mlir/Dialect/Affine/LoopUtils.h"
2120#include " mlir/Dialect/Affine/Utils.h"
@@ -473,7 +472,8 @@ static Value createPrivateMemRef(AffineForOp forOp,
473472// is lower.
474473// TODO: Extend profitability analysis to support scenarios with multiple
475474// stores.
476- static bool isFusionProfitable (AffineForOp srcForOp, Operation *srcStoreOpInst,
475+ static bool isFusionProfitable (AffineForOp srcForOp,
476+ ArrayRef<Operation *> producerStores,
477477 AffineForOp dstForOp,
478478 ArrayRef<ComputationSliceState> depthSliceUnions,
479479 unsigned maxLegalFusionDepth,
@@ -503,6 +503,35 @@ static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
503503 if (!getLoopNestStats (dstForOp, &dstLoopNestStats))
504504 return false ;
505505
506+ // We limit profitability analysis to only scenarios with
507+ // a single producer store for now. Note that some multi-store
508+ // producer scenarios will still go through profitability analysis
509+ // if only one of the stores is involved in the producer-consumer
510+ // relationship of the candidate loops.
511+ // TODO: Suppport multiple producer stores in profitability
512+ // analysis.
513+ if (producerStores.size () > 1 ) {
514+ LLVM_DEBUG (llvm::dbgs () << " Limited profitability analysis. Not "
515+ " supported for multiple producer store case.\n " );
516+ int64_t sliceCost;
517+ int64_t fusedLoopNestComputeCost;
518+ // We will still fuse if fusion obeys the specified compute
519+ // tolerance at the max legal depth.
520+ auto fraction = getAdditionalComputeFraction (
521+ srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions, sliceCost,
522+ fusedLoopNestComputeCost);
523+ if (!fraction || fraction > computeToleranceThreshold) {
524+ LLVM_DEBUG (llvm::dbgs () << " Additional computation exceeds "
525+ " compute tolerance. Not fusing.\n " );
526+ return false ;
527+ }
528+ LLVM_DEBUG (llvm::dbgs ()
529+ << " Considering fusion profitable at max legal depth.\n " );
530+ return true ;
531+ }
532+
533+ Operation *srcStoreOp = producerStores.front ();
534+
506535 // Search for min cost value for 'dstLoopDepth'. At each value of
507536 // 'dstLoopDepth' from 'maxLegalLoopDepth' to '1', compute computation slice
508537 // bounds between 'srcOpInst' and each op in 'dstOpinsts' (taking the union
@@ -516,12 +545,9 @@ static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
516545 // The best loop depth at which to materialize the slice.
517546 std::optional<unsigned > bestDstLoopDepth;
518547
519- // Compute op instance count for the src loop nest without iteration slicing.
520- uint64_t srcLoopNestCost = getComputeCost (srcForOp, srcLoopNestStats);
521-
522548 // Compute src loop nest write region size.
523- MemRefRegion srcWriteRegion (srcStoreOpInst ->getLoc ());
524- if (failed (srcWriteRegion.compute (srcStoreOpInst , /* loopDepth=*/ 0 ))) {
549+ MemRefRegion srcWriteRegion (srcStoreOp ->getLoc ());
550+ if (failed (srcWriteRegion.compute (srcStoreOp , /* loopDepth=*/ 0 ))) {
525551 LLVM_DEBUG (llvm::dbgs ()
526552 << " Unable to compute MemRefRegion for source operation\n " );
527553 return false ;
@@ -533,7 +559,10 @@ static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
533559 return false ;
534560 int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes;
535561
536- // Compute op instance count for the src loop nest.
562+ // Compute op instance count for the src loop nest without iteration slicing.
563+ uint64_t srcLoopNestCost = getComputeCost (srcForOp, srcLoopNestStats);
564+
565+ // Compute op instance count for the destination loop nest.
537566 uint64_t dstLoopNestCost = getComputeCost (dstForOp, dstLoopNestStats);
538567
539568 // Evaluate all depth choices for materializing the slice in the destination
@@ -563,9 +592,8 @@ static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
563592 // Determine what the slice write MemRefRegion would be, if the src loop
564593 // nest slice 'slice' were to be inserted into the dst loop nest at loop
565594 // depth 'i'.
566- MemRefRegion sliceWriteRegion (srcStoreOpInst->getLoc ());
567- if (failed (sliceWriteRegion.compute (srcStoreOpInst, /* loopDepth=*/ 0 ,
568- &slice))) {
595+ MemRefRegion sliceWriteRegion (srcStoreOp->getLoc ());
596+ if (failed (sliceWriteRegion.compute (srcStoreOp, /* loopDepth=*/ 0 , &slice))) {
569597 LLVM_DEBUG (llvm::dbgs ()
570598 << " Failed to compute slice write region at loopDepth: " << i
571599 << " \n " );
@@ -1025,21 +1053,13 @@ struct GreedyFusion {
10251053 cast<AffineWriteOpInterface>(op).getMemRef ()))
10261054 producerStores.push_back (op);
10271055
1028- // TODO: Suppport multiple producer stores in profitability
1029- // analysis. We limit profitability analysis to only scenarios with
1030- // a single producer store for now. Note that some multi-store
1031- // producer scenarios will still go through profitability analysis
1032- // if only one of the stores is involved the producer-consumer
1033- // relationship of the candidate loops.
10341056 assert (!producerStores.empty () && " Expected producer store" );
1035- if (producerStores.size () > 1 )
1036- LLVM_DEBUG (llvm::dbgs () << " Skipping profitability analysis. Not "
1037- " supported for this case\n " );
1038- else if (!isFusionProfitable (srcAffineForOp, producerStores[0 ],
1039- dstAffineForOp, depthSliceUnions,
1040- maxLegalFusionDepth, &bestDstLoopDepth,
1041- computeToleranceThresholdToUse))
1057+ if (!isFusionProfitable (srcAffineForOp, producerStores,
1058+ dstAffineForOp, depthSliceUnions,
1059+ maxLegalFusionDepth, &bestDstLoopDepth,
1060+ computeToleranceThresholdToUse)) {
10421061 continue ;
1062+ }
10431063 }
10441064
10451065 assert (bestDstLoopDepth > 0 && " Unexpected loop fusion depth" );
0 commit comments