1515#include " mlir/Dialect/Affine/Analysis/AffineStructures.h"
1616#include " mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
1717#include " mlir/Dialect/Affine/Analysis/Utils.h"
18- #include " mlir/Dialect/Affine/IR/AffineOps.h"
1918#include " mlir/Dialect/Affine/LoopFusionUtils.h"
2019#include " mlir/Dialect/Affine/LoopUtils.h"
2120#include " mlir/Dialect/Affine/Utils.h"
@@ -274,6 +273,58 @@ getDominanceFilterForPrivateMemRefRepl(Block *sliceInsertionBlock,
274273 return firstAncestor;
275274}
276275
276+ // / Returns the amount of additional (redundant) computation that will be done
277+ // / as a fraction of the total computation if `srcForOp` is fused into
278+ // / `dstForOp` at depth `depth`. The method returns the compute cost of the
279+ // / slice and the fused nest's compute cost in the trailing output arguments.
280+ static std::optional<double > getAdditionalComputeFraction (
281+ AffineForOp srcForOp, AffineForOp dstForOp, unsigned depth,
282+ ArrayRef<ComputationSliceState> depthSliceUnions, int64_t &sliceCost,
283+ int64_t &fusedLoopNestComputeCost) {
284+ LLVM_DEBUG (llvm::dbgs () << " Determining additional compute fraction...\n " ;);
285+ // Compute cost of sliced and unsliced src loop nest.
286+ // Walk src loop nest and collect stats.
287+ LoopNestStats srcLoopNestStats;
288+ if (!getLoopNestStats (srcForOp, &srcLoopNestStats)) {
289+ LLVM_DEBUG (llvm::dbgs () << " Failed to get source loop nest stats.\n " );
290+ return std::nullopt ;
291+ }
292+
293+ // Compute cost of dst loop nest.
294+ LoopNestStats dstLoopNestStats;
295+ if (!getLoopNestStats (dstForOp, &dstLoopNestStats)) {
296+ LLVM_DEBUG (llvm::dbgs () << " Failed to get destination loop nest stats.\n " );
297+ return std::nullopt ;
298+ }
299+
300+ // Compute op instance count for the src loop nest without iteration slicing.
301+ uint64_t srcLoopNestCost = getComputeCost (srcForOp, srcLoopNestStats);
302+
303+ // Compute op cost for the dst loop nest.
304+ uint64_t dstLoopNestCost = getComputeCost (dstForOp, dstLoopNestStats);
305+
306+ const ComputationSliceState &slice = depthSliceUnions[depth - 1 ];
307+ // Skip slice union if it wasn't computed for this depth.
308+ if (slice.isEmpty ()) {
309+ LLVM_DEBUG (llvm::dbgs () << " Slice wasn't computed.\n " );
310+ return std::nullopt ;
311+ }
312+
313+ if (!getFusionComputeCost (srcForOp, srcLoopNestStats, dstForOp,
314+ dstLoopNestStats, slice,
315+ &fusedLoopNestComputeCost)) {
316+ LLVM_DEBUG (llvm::dbgs () << " Unable to compute fusion compute cost\n " );
317+ return std::nullopt ;
318+ }
319+
320+ double additionalComputeFraction =
321+ fusedLoopNestComputeCost /
322+ (static_cast <double >(srcLoopNestCost) + dstLoopNestCost) -
323+ 1 ;
324+
325+ return additionalComputeFraction;
326+ }
327+
277328// Creates and returns a private (single-user) memref for fused loop rooted at
278329// 'forOp', with (potentially reduced) memref size based on the memref region
279330// written to by `storeOps` at depth 'dstLoopDepth'. 'sliceInsertionBlock'
@@ -384,20 +435,19 @@ static Value createPrivateMemRef(AffineForOp forOp,
384435}
385436
386437// Checks the profitability of fusing a backwards slice of the loop nest
387- // surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
388- // The argument 'srcStoreOpInst' is used to calculate the storage reduction on
389- // the memref being produced and consumed, which is an input to the cost model.
390- // For producer-consumer fusion, 'srcStoreOpInst' will be the same as
391- // 'srcOpInst', as we are slicing w.r.t to that producer. For input-reuse
392- // fusion, 'srcOpInst' will be the src loop nest LoadOp which reads from the
393- // same memref as dst loop nest load ops, and 'srcStoreOpInst' will be the
394- // unique store op in the src node, which will be used to check that the write
395- // region is the same after input-reuse fusion. Computation slices are provided
396- // in 'depthSliceUnions' for each legal fusion depth. The maximal depth at which
397- // fusion is legal is provided in 'maxLegalFusionDepth'. Returns true if it is
398- // profitable to fuse the candidate loop nests. Returns false otherwise.
399- // `dstLoopDepth` is set to the most profitable depth at which to materialize
400- // the source loop nest slice.
438+ // `srcForOp` into the loop nest surrounding 'dstLoadOpInsts'. The argument
439+ // 'srcStoreOpInst' is used to calculate the storage reduction on the memref
440+ // being produced and consumed, which is an input to the cost model. For
441+ // producer-consumer fusion, 'srcStoreOpInst' will be the same as 'srcOpInst',
442+ // as we are slicing w.r.t to that producer. For input-reuse fusion, 'srcOpInst'
443+ // will be the src loop nest LoadOp which reads from the same memref as dst loop
444+ // nest load ops, and 'srcStoreOpInst' will be the unique store op in the src
445+ // node, which will be used to check that the write region is the same after
446+ // input-reuse fusion. Computation slices are provided in 'depthSliceUnions' for
447+ // each legal fusion depth. The maximal depth at which fusion is legal is
448+ // provided in 'maxLegalFusionDepth'. Returns true if it is profitable to fuse
449+ // the candidate loop nests. Returns false otherwise. `dstLoopDepth` is set to
450+ // the most profitable depth at which to materialize the source loop nest slice.
401451// The profitability model executes the following steps:
402452// *) Computes the backward computation slice at 'srcOpInst'. This
403453// computation slice of the loop nest surrounding 'srcOpInst' is
@@ -422,15 +472,16 @@ static Value createPrivateMemRef(AffineForOp forOp,
422472// is lower.
423473// TODO: Extend profitability analysis to support scenarios with multiple
424474// stores.
425- static bool isFusionProfitable (Operation *srcOpInst , Operation *srcStoreOpInst,
475+ static bool isFusionProfitable (AffineForOp srcForOp , Operation *srcStoreOpInst,
426476 AffineForOp dstForOp,
427477 ArrayRef<ComputationSliceState> depthSliceUnions,
428478 unsigned maxLegalFusionDepth,
429479 unsigned *dstLoopDepth,
430480 double computeToleranceThreshold) {
431481 LLVM_DEBUG ({
432- llvm::dbgs () << " Checking whether fusion is profitable between src op:\n " ;
433- llvm::dbgs () << ' ' << *srcOpInst << " and destination loop:\n " ;
482+ llvm::dbgs ()
483+ << " Checking whether fusion is profitable between source nest:\n " ;
484+ llvm::dbgs () << ' ' << srcForOp << " and destination nest:\n " ;
434485 llvm::dbgs () << dstForOp << " \n " ;
435486 });
436487
@@ -440,12 +491,10 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
440491 }
441492
442493 // Compute cost of sliced and unsliced src loop nest.
443- SmallVector<AffineForOp, 4 > srcLoopIVs;
444- getAffineForIVs (*srcOpInst, &srcLoopIVs);
445494
446495 // Walk src loop nest and collect stats.
447496 LoopNestStats srcLoopNestStats;
448- if (!getLoopNestStats (srcLoopIVs[ 0 ] , &srcLoopNestStats))
497+ if (!getLoopNestStats (srcForOp , &srcLoopNestStats))
449498 return false ;
450499
451500 // Compute cost of dst loop nest.
@@ -467,7 +516,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
467516 std::optional<unsigned > bestDstLoopDepth;
468517
469518 // Compute op instance count for the src loop nest without iteration slicing.
470- uint64_t srcLoopNestCost = getComputeCost (srcLoopIVs[ 0 ] , srcLoopNestStats);
519+ uint64_t srcLoopNestCost = getComputeCost (srcForOp , srcLoopNestStats);
471520
472521 // Compute src loop nest write region size.
473522 MemRefRegion srcWriteRegion (srcStoreOpInst->getLoc ());
@@ -494,18 +543,21 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
494543 if (slice.isEmpty ())
495544 continue ;
496545
546+ // Compute cost of the slice separately, i.e, the compute cost of the slice
547+ // if all outer trip counts are one.
548+ int64_t sliceCost;
549+
497550 int64_t fusedLoopNestComputeCost;
498- if (!getFusionComputeCost (srcLoopIVs[0 ], srcLoopNestStats, dstForOp,
499- dstLoopNestStats, slice,
500- &fusedLoopNestComputeCost)) {
501- LLVM_DEBUG (llvm::dbgs () << " Unable to compute fusion compute cost\n " );
551+
552+ auto mayAdditionalComputeFraction =
553+ getAdditionalComputeFraction (srcForOp, dstForOp, i, depthSliceUnions,
554+ sliceCost, fusedLoopNestComputeCost);
555+ if (!mayAdditionalComputeFraction) {
556+ LLVM_DEBUG (llvm::dbgs ()
557+ << " Can't determine additional compute fraction.\n " );
502558 continue ;
503559 }
504-
505- double additionalComputeFraction =
506- fusedLoopNestComputeCost /
507- (static_cast <double >(srcLoopNestCost) + dstLoopNestCost) -
508- 1 ;
560+ double additionalComputeFraction = *mayAdditionalComputeFraction;
509561
510562 // Determine what the slice write MemRefRegion would be, if the src loop
511563 // nest slice 'slice' were to be inserted into the dst loop nest at loop
@@ -530,14 +582,6 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
530582 }
531583 int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;
532584
533- // If we are fusing for reuse, check that write regions remain the same.
534- // TODO: Write region check should check sizes and offsets in
535- // each dimension, so that we are sure they are covering the same memref
536- // region. Also, move this out to a isMemRefRegionSuperSet helper function.
537- if (srcOpInst != srcStoreOpInst &&
538- sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes)
539- continue ;
540-
541585 double storageReduction = static_cast <double >(srcWriteRegionSizeBytes) /
542586 static_cast <double >(sliceWriteRegionSizeBytes);
543587
@@ -595,7 +639,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
595639 << minFusedLoopNestComputeCost << " \n " );
596640
597641 auto dstMemSize = getMemoryFootprintBytes (dstForOp);
598- auto srcMemSize = getMemoryFootprintBytes (srcLoopIVs[ 0 ] );
642+ auto srcMemSize = getMemoryFootprintBytes (srcForOp );
599643
600644 std::optional<double > storageReduction;
601645
@@ -840,6 +884,8 @@ struct GreedyFusion {
840884 LLVM_DEBUG (llvm::dbgs ()
841885 << " Trying to fuse producer loop nest " << srcId
842886 << " with consumer loop nest " << dstId << " \n " );
887+ LLVM_DEBUG (llvm::dbgs () << " Compute tolerance threshold: "
888+ << computeToleranceThreshold << ' \n ' );
843889 LLVM_DEBUG (llvm::dbgs ()
844890 << " Producer loop nest:\n "
845891 << *srcNode->op << " \n and consumer loop nest:\n "
@@ -926,6 +972,9 @@ struct GreedyFusion {
926972 continue ;
927973 }
928974
975+ LLVM_DEBUG (llvm::dbgs () << " Max legal depth for fusion: "
976+ << maxLegalFusionDepth << ' \n ' );
977+
929978 // Check if fusion would be profitable. We skip profitability analysis
930979 // for maximal fusion since we already know the maximal legal depth to
931980 // fuse.
@@ -945,14 +994,28 @@ struct GreedyFusion {
945994 // if only one of the stores is involved the producer-consumer
946995 // relationship of the candidate loops.
947996 assert (!producerStores.empty () && " Expected producer store" );
948- if (producerStores.size () > 1 )
997+ if (producerStores.size () > 1 ) {
949998 LLVM_DEBUG (llvm::dbgs () << " Skipping profitability analysis. Not "
950999 " supported for this case\n " );
951- else if (!isFusionProfitable (producerStores[0 ], producerStores[0 ],
952- dstAffineForOp, depthSliceUnions,
953- maxLegalFusionDepth, &bestDstLoopDepth,
954- computeToleranceThreshold))
1000+ // We will still fuse if fusion obeys the specified compute
1001+ // tolerance at the max legal depth.
1002+ int64_t sliceCost;
1003+ int64_t fusedLoopNestComputeCost;
1004+ auto fraction = getAdditionalComputeFraction (
1005+ srcAffineForOp, dstAffineForOp, maxLegalFusionDepth,
1006+ depthSliceUnions, sliceCost, fusedLoopNestComputeCost);
1007+ if (!fraction || fraction > computeToleranceThreshold) {
1008+ LLVM_DEBUG (llvm::dbgs () << " Additional computation exceeds "
1009+ " compute tolerance. Not fusing.\n " );
1010+ continue ;
1011+ }
1012+ }
1013+ if (!isFusionProfitable (srcAffineForOp, producerStores[0 ],
1014+ dstAffineForOp, depthSliceUnions,
1015+ maxLegalFusionDepth, &bestDstLoopDepth,
1016+ computeToleranceThreshold)) {
9551017 continue ;
1018+ }
9561019 }
9571020
9581021 assert (bestDstLoopDepth > 0 && " Unexpected loop fusion depth" );
@@ -1169,7 +1232,7 @@ struct GreedyFusion {
11691232 // load op is treated as the src "store" op for fusion profitability
11701233 // purposes. The footprint of the load in the slice relative to the
11711234 // unfused source's determines reuse.
1172- if (!isFusionProfitable (sibLoadOpInst , sibLoadOpInst, dstAffineForOp,
1235+ if (!isFusionProfitable (sibAffineForOp , sibLoadOpInst, dstAffineForOp,
11731236 depthSliceUnions, maxLegalFusionDepth,
11741237 &bestDstLoopDepth, computeToleranceThreshold))
11751238 continue ;
0 commit comments