From 69282571bcf243b172009821aef5ec82320936bb Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Mon, 2 Nov 2020 14:13:06 -0800 Subject: [PATCH] [mlir][Affine] Align affine fusion code in pass and utilities This patch is a refactoring/clean-up step that is needed to add support for producer-consumer fusion with producer loops with multiple stores. It introduces the following changes: - AffineLoopFusion pass now uses loop fusion utilities more broadly to compute fusion legality (canFuseLoops utility) and perform the fusion transformation (fuseLoops utility). - Loop fusion utilities have been extended to deal with AffineLoopFusion requirements and assumptions while preserving both loop fusion utilities and AffineLoopFusion current functionality within a unified implementation. This integration will be improved in the future when AffineLoopFusion supports more generic cases (WIP). - Improve separation of concerns for legality and profitability analysis: 'isFusionProfitable' no longer filters out illegal scenarios that 'canFuse' didn't detect, or the other way around. 'canFuse' now takes loop dependences into account to determine the fusion loop depth (producer-consumer fusion only). As a result, maximal fusion now doesn't require any profitability analysis. - Computation slices are now computed only once and reused across the legality, profitability and fusion transformation steps (producer-consumer). - Refactor some utilities and remove redundant copies of them. Despite all these changes, this patch is NFCI and should preserve the existing functionality of both the AffineLoopFusion pass and the affine fusion utilities. --- mlir/include/mlir/Analysis/Utils.h | 12 +- .../include/mlir/Transforms/LoopFusionUtils.h | 31 +- mlir/lib/Analysis/Utils.cpp | 38 +- mlir/lib/Transforms/LoopFusion.cpp | 503 +++++++----------- mlir/lib/Transforms/Utils/LoopFusionUtils.cpp | 170 +++++- mlir/test/lib/Transforms/TestLoopFusion.cpp | 2 +- 6 files changed, 421 insertions(+), 335 deletions(-) diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index b502d909d5c0e..30b6272181f52 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -82,6 +82,11 @@ struct ComputationSliceState { // Clears all bounds and operands in slice state. void clearBounds(); + + /// Return true if the computation slice is empty. + bool isEmpty() const { return ivs.empty(); } + + void dump() const; }; /// Computes the computation slice loop bounds for one loop nest as affine maps @@ -212,7 +217,7 @@ struct MemRefRegion { /// The last field is a 2-d FlatAffineConstraints symbolic in %i. /// LogicalResult compute(Operation *op, unsigned loopDepth, - ComputationSliceState *sliceState = nullptr, + const ComputationSliceState *sliceState = nullptr, bool addMemRefDimBounds = true); FlatAffineConstraints *getConstraints() { return &cst; } @@ -309,6 +314,11 @@ bool isLoopParallel(AffineForOp forOp); /// number of constraints. IntegerSet simplifyIntegerSet(IntegerSet set); +/// Returns the innermost common loop depth for the set of operations in 'ops'. +unsigned getInnermostCommonLoopDepth( + ArrayRef ops, + SmallVectorImpl *surroundingLoops = nullptr); + } // end namespace mlir #endif // MLIR_ANALYSIS_UTILS_H diff --git a/mlir/include/mlir/Transforms/LoopFusionUtils.h b/mlir/include/mlir/Transforms/LoopFusionUtils.h index 36d2520b7c85e..313cb5171e75e 100644 --- a/mlir/include/mlir/Transforms/LoopFusionUtils.h +++ b/mlir/include/mlir/Transforms/LoopFusionUtils.h @@ -15,6 +15,7 @@ #ifndef MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H #define MLIR_TRANSFORMS_LOOP_FUSION_UTILS_H +#include "mlir/IR/Value.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" @@ -38,6 +39,24 @@ struct FusionResult { FusionResult(ResultEnum v) : value(v) {} }; +/// Temporary enum to distinguish between the different fusion strategies +/// implemented in Affine. It is used to specialized the loop fusion utilities +/// with the assumptions made in the AffineLoopFusion pass while sharing a +/// single implementation. +// TODO: Remove this enum once the producer-consumer and sibling loop fusion +// strategies in AffineLoopFusion pass are generic enough. +struct FusionStrategy { + enum StrategyEnum { + None, // Generic fusion. No assumtions are made. + ProducerConsumer, // Producer-consumer fusion from AffineLoopFusion pass. + Sibling // Sibling fusion from AffineLoopFusion pass. + } strategy; + + Value memref; + FusionStrategy(StrategyEnum strategy, Value memref) + : strategy(strategy), memref(memref) {} +}; + /// Checks the feasibility of fusing the loop nest rooted at 'srcForOp' into the /// loop nest rooted at 'dstForOp' at 'dstLoopDepth'. Returns FusionResult /// 'Success' if fusion of the src/dst loop nests is feasible (i.e. they are @@ -46,14 +65,15 @@ struct FusionResult { /// NOTE: This function is not feature complete and should only be used in /// testing. /// TODO: Update comments when this function is fully implemented. -FusionResult canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, - unsigned dstLoopDepth, - ComputationSliceState *srcSlice); +FusionResult +canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, unsigned dstLoopDepth, + ComputationSliceState *srcSlice, + FusionStrategy fusionStrategy = {FusionStrategy::None, Value()}); /// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point /// and source slice loop bounds specified in 'srcSlice'. void fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, - ComputationSliceState *srcSlice); + const ComputationSliceState &srcSlice); /// LoopNestStats aggregates various per-loop statistics (eg. loop trip count /// and operation count) for a loop nest up until (and including) the innermost @@ -89,7 +109,8 @@ int64_t getComputeCost(AffineForOp forOp, LoopNestStats &stats); // TODO: Improve this cost model. bool getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, AffineForOp dstForOp, LoopNestStats &dstStats, - ComputationSliceState *slice, int64_t *computeCost); + const ComputationSliceState &slice, + int64_t *computeCost); } // end namespace mlir diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index b02212a09bba6..35678432dd16d 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -105,6 +105,28 @@ void ComputationSliceState::clearBounds() { ubOperands.clear(); } +void ComputationSliceState::dump() const { + llvm::errs() << "\tIVs:\n"; + for (Value iv : ivs) + llvm::errs() << "\t\t" << iv << "\n"; + + llvm::errs() << "\tLBs:\n"; + for (auto &en : llvm::enumerate(lbs)) { + llvm::errs() << "\t\t" << en.value() << "\n"; + llvm::errs() << "\t\tOperands:\n"; + for (Value lbOp : lbOperands[en.index()]) + llvm::errs() << "\t\t\t" << lbOp << "\n"; + } + + llvm::errs() << "\tUBs:\n"; + for (auto &en : llvm::enumerate(ubs)) { + llvm::errs() << "\t\t" << en.value() << "\n"; + llvm::errs() << "\t\tOperands:\n"; + for (Value ubOp : ubOperands[en.index()]) + llvm::errs() << "\t\t\t" << ubOp << "\n"; + } +} + unsigned MemRefRegion::getRank() const { return memref.getType().cast().getRank(); } @@ -211,7 +233,7 @@ LogicalResult MemRefRegion::unionBoundingBox(const MemRefRegion &other) { // TODO: extend this to any other memref dereferencing ops // (dma_start, dma_wait). LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, - ComputationSliceState *sliceState, + const ComputationSliceState *sliceState, bool addMemRefDimBounds) { assert((isa(op)) && "affine read/write op expected"); @@ -541,13 +563,12 @@ static LogicalResult addMissingLoopIVBounds(SmallPtrSet &ivs, return success(); } -// Returns the innermost common loop depth for the set of operations in 'ops'. +/// Returns the innermost common loop depth for the set of operations in 'ops'. // TODO: Move this to LoopUtils. -static unsigned -getInnermostCommonLoopDepth(ArrayRef ops, - SmallVectorImpl &surroundingLoops) { +unsigned mlir::getInnermostCommonLoopDepth( + ArrayRef ops, SmallVectorImpl *surroundingLoops) { unsigned numOps = ops.size(); - assert(numOps > 0); + assert(numOps > 0 && "Expected at least one operation"); std::vector> loops(numOps); unsigned loopDepthLimit = std::numeric_limits::max(); @@ -564,7 +585,8 @@ getInnermostCommonLoopDepth(ArrayRef ops, if (loops[i - 1][d] != loops[i][d]) return loopDepth; } - surroundingLoops.push_back(loops[i - 1][d]); + if (surroundingLoops) + surroundingLoops->push_back(loops[i - 1][d]); ++loopDepth; } return loopDepth; @@ -684,7 +706,7 @@ LogicalResult mlir::computeSliceUnion(ArrayRef opsA, } SmallVector surroundingLoops; unsigned innermostCommonLoopDepth = - getInnermostCommonLoopDepth(ops, surroundingLoops); + getInnermostCommonLoopDepth(ops, &surroundingLoops); if (loopDepth > innermostCommonLoopDepth) { LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n"); return failure(); diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index ed79be02b8165..1075486378467 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -741,77 +741,6 @@ static void moveLoadsAccessingMemrefTo(Value memref, srcLoads->swap(srcLoadsToKeep); } -// Returns the innermost common loop depth for the set of operations in 'ops'. -static unsigned getInnermostCommonLoopDepth(ArrayRef ops) { - unsigned numOps = ops.size(); - assert(numOps > 0); - - std::vector> loops(numOps); - unsigned loopDepthLimit = std::numeric_limits::max(); - for (unsigned i = 0; i < numOps; ++i) { - getLoopIVs(*ops[i], &loops[i]); - loopDepthLimit = - std::min(loopDepthLimit, static_cast(loops[i].size())); - } - - unsigned loopDepth = 0; - for (unsigned d = 0; d < loopDepthLimit; ++d) { - unsigned i; - for (i = 1; i < numOps; ++i) { - if (loops[i - 1][d] != loops[i][d]) - break; - } - if (i != numOps) - break; - ++loopDepth; - } - return loopDepth; -} - -// Returns the maximum loop depth at which no dependences between 'loadOpInsts' -// and 'storeOpInsts' are satisfied. -static unsigned getMaxLoopDepth(ArrayRef loadOpInsts, - ArrayRef storeOpInsts) { - // Merge loads and stores into the same array. - SmallVector ops(loadOpInsts.begin(), loadOpInsts.end()); - ops.append(storeOpInsts.begin(), storeOpInsts.end()); - - // Compute the innermost common loop depth for loads and stores. - unsigned loopDepth = getInnermostCommonLoopDepth(ops); - - // Return common loop depth for loads if there are no store ops. - if (storeOpInsts.empty()) - return loopDepth; - - // Check dependences on all pairs of ops in 'ops' and store the minimum - // loop depth at which a dependence is satisfied. - for (unsigned i = 0, e = ops.size(); i < e; ++i) { - auto *srcOpInst = ops[i]; - MemRefAccess srcAccess(srcOpInst); - for (unsigned j = 0; j < e; ++j) { - auto *dstOpInst = ops[j]; - MemRefAccess dstAccess(dstOpInst); - - unsigned numCommonLoops = - getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst); - for (unsigned d = 1; d <= numCommonLoops + 1; ++d) { - FlatAffineConstraints dependenceConstraints; - // TODO: Cache dependence analysis results, check cache here. - DependenceResult result = checkMemrefAccessDependence( - srcAccess, dstAccess, d, &dependenceConstraints, - /*dependenceComponents=*/nullptr); - if (hasDependence(result)) { - // Store minimum loop depth and break because we want the min 'd' at - // which there is a dependence. - loopDepth = std::min(loopDepth, d - 1); - break; - } - } - } - } - return loopDepth; -} - // Sinks all sequential loops to the innermost levels (while preserving // relative order among them) and moves all parallel loops to the // outermost (while again preserving relative order among them). @@ -1112,9 +1041,9 @@ canFuseSrcWhichWritesToLiveOut(unsigned srcId, unsigned dstId, // is lower. static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, ArrayRef dstLoadOpInsts, - ArrayRef dstStoreOpInsts, - ComputationSliceState *sliceState, - unsigned *dstLoopDepth, bool maximalFusion, + ArrayRef depthSliceUnions, + unsigned maxLegalFusionDepth, + unsigned *dstLoopDepth, double computeToleranceThreshold) { LLVM_DEBUG({ llvm::dbgs() << "Checking whether fusion is profitable between src op:\n"; @@ -1124,10 +1053,14 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, }; }); + if (maxLegalFusionDepth == 0) { + LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth == 0 .\n"); + return false; + } + // Compute cost of sliced and unsliced src loop nest. SmallVector srcLoopIVs; getLoopIVs(*srcOpInst, &srcLoopIVs); - unsigned numSrcLoopIVs = srcLoopIVs.size(); // Walk src loop nest and collect stats. LoopNestStats srcLoopNestStats; @@ -1142,19 +1075,8 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, if (!getLoopNestStats(dstLoopIVs[0], &dstLoopNestStats)) return false; - // Compute the maximum loop depth at which we can can insert the src slice - // and still satisfy dest loop nest dependences, for producer-consumer fusion. - unsigned maxDstLoopDepth = - (srcOpInst == srcStoreOpInst) - ? getMaxLoopDepth(dstLoadOpInsts, dstStoreOpInsts) - : dstLoopIVs.size(); - if (maxDstLoopDepth == 0) { - LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxDstLoopDepth == 0 .\n"); - return false; - } - // Search for min cost value for 'dstLoopDepth'. At each value of - // 'dstLoopDepth' from 'maxDstLoopDepth' to '1', compute computation slice + // 'dstLoopDepth' from 'maxLegalLoopDepth' to '1', compute computation slice // bounds between 'srcOpInst' and each op in 'dstOpinsts' (taking the union // of these bounds). Next the union slice bounds are used to calculate // the cost of the slice and the cost of the slice inserted into the dst @@ -1163,8 +1085,6 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, double maxStorageReduction = 0.0; Optional sliceMemEstimate = None; - SmallVector sliceStates; - sliceStates.resize(maxDstLoopDepth); // The best loop depth at which to materialize the slice. Optional bestDstLoopDepth = None; @@ -1190,21 +1110,14 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, // Evaluate all depth choices for materializing the slice in the destination // loop nest. - for (unsigned i = maxDstLoopDepth; i >= 1; --i) { - // Compute the union of slice bounds of all ops in 'dstLoadOpInsts'. - if (failed(mlir::computeSliceUnion({srcOpInst}, dstLoadOpInsts, - /*loopDepth=*/i, - /*numCommonLoops=*/0, - /*isBackwardSlice=*/true, - &sliceStates[i - 1]))) { - LLVM_DEBUG(llvm::dbgs() - << "computeSliceUnion failed for loopDepth: " << i << "\n"); + for (unsigned i = maxLegalFusionDepth; i >= 1; --i) { + // Skip slice union if it wasn't computed for this depth. + if (depthSliceUnions[i - 1].isEmpty()) continue; - } int64_t fusedLoopNestComputeCost; if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstLoopIVs[0], - dstLoopNestStats, &sliceStates[i - 1], + dstLoopNestStats, depthSliceUnions[i - 1], &fusedLoopNestComputeCost)) { LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost.\n."); continue; @@ -1216,11 +1129,11 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, 1; // Determine what the slice write MemRefRegion would be, if the src loop - // nest slice 'sliceStates[i - 1]' were to be inserted into the dst loop - // nest at loop depth 'i' + // nest slice 'depthSliceUnions[i - 1]' were to be inserted into the dst + // loop nest at loop depth 'i'. MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc()); if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0, - &sliceStates[i - 1]))) { + &depthSliceUnions[i - 1]))) { LLVM_DEBUG(llvm::dbgs() << "Failed to compute slice write region at loopDepth: " << i << "\n"); @@ -1260,7 +1173,8 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, << " fused nest cost: " << fusedLoopNestComputeCost << "\n" << " src write region size: " << srcWriteRegionSizeBytes << "\n" << " slice write region size: " << sliceWriteRegionSizeBytes - << "\n"; + << "\n src loop nest compute cost: " << srcLoopNestCost + << "\n dst loop nest compute cost: " << dstLoopNestCost << "\n"; llvm::dbgs() << msg.str(); }); @@ -1269,8 +1183,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, // (as per computeToleranceThreshold), we will simply pick the one that // reduces the intermediary size the most. if ((storageReduction > maxStorageReduction) && - (maximalFusion || - (additionalComputeFraction < computeToleranceThreshold))) { + (additionalComputeFraction < computeToleranceThreshold)) { maxStorageReduction = storageReduction; bestDstLoopDepth = i; minFusedLoopNestComputeCost = fusedLoopNestComputeCost; @@ -1278,10 +1191,9 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, } } - // A simple cost model: fuse if it reduces the memory footprint. If - // -maximal-fusion is set, fuse nevertheless. + // A simple cost model: fuse if it reduces the memory footprint. - if (!maximalFusion && !bestDstLoopDepth.hasValue()) { + if (!bestDstLoopDepth.hasValue()) { LLVM_DEBUG( llvm::dbgs() << "All fusion choices involve more than the threshold amount of " @@ -1310,33 +1222,30 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, Optional storageReduction = None; - if (!maximalFusion) { - if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) { - LLVM_DEBUG( - llvm::dbgs() - << " fusion memory benefit cannot be evaluated; NOT fusing.\n"); - return false; - } + if (!dstMemSize.hasValue() || !srcMemSize.hasValue()) { + LLVM_DEBUG(llvm::dbgs() + << " fusion memory benefit cannot be evaluated; NOT fusing.\n"); + return false; + } - auto srcMemSizeVal = srcMemSize.getValue(); - auto dstMemSizeVal = dstMemSize.getValue(); + auto srcMemSizeVal = srcMemSize.getValue(); + auto dstMemSizeVal = dstMemSize.getValue(); - assert(sliceMemEstimate.hasValue() && "expected value"); - auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue(); + assert(sliceMemEstimate.hasValue() && "expected value"); + auto fusedMem = dstMemSizeVal + sliceMemEstimate.getValue(); - LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n" - << " dst mem: " << dstMemSizeVal << "\n" - << " fused mem: " << fusedMem << "\n" - << " slice mem: " << sliceMemEstimate << "\n"); + LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n" + << " dst mem: " << dstMemSizeVal << "\n" + << " fused mem: " << fusedMem << "\n" + << " slice mem: " << sliceMemEstimate << "\n"); - if (static_cast(fusedMem) > srcMemSizeVal + dstMemSizeVal) { - LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n"); - return false; - } - storageReduction = - 100.0 * - (1.0 - fusedMem / (static_cast(srcMemSizeVal) + dstMemSizeVal)); + if (static_cast(fusedMem) > srcMemSizeVal + dstMemSizeVal) { + LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n"); + return false; } + storageReduction = + 100.0 * + (1.0 - fusedMem / (static_cast(srcMemSizeVal) + dstMemSizeVal)); double additionalComputeFraction = 100.0 * (minFusedLoopNestComputeCost / @@ -1355,24 +1264,6 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst, llvm::dbgs() << msg.str(); }); - // Update return parameter 'sliceState' with 'bestSliceState'. - ComputationSliceState *bestSliceState = &sliceStates[*dstLoopDepth - 1]; - sliceState->lbs = bestSliceState->lbs; - sliceState->ubs = bestSliceState->ubs; - sliceState->lbOperands = bestSliceState->lbOperands; - sliceState->ubOperands = bestSliceState->ubOperands; - - // Canonicalize slice bound affine maps. - for (unsigned i = 0; i < numSrcLoopIVs; ++i) { - if (sliceState->lbs[i] != AffineMap()) { - canonicalizeMapAndOperands(&sliceState->lbs[i], - &sliceState->lbOperands[i]); - } - if (sliceState->ubs[i] != AffineMap()) { - canonicalizeMapAndOperands(&sliceState->ubs[i], - &sliceState->ubOperands[i]); - } - } return true; } @@ -1592,138 +1483,143 @@ struct GreedyFusion { if (insertPointInst == nullptr) continue; + auto srcAffineForOp = cast(srcNode->op); + auto dstAffineForOp = cast(dstNode->op); + // Compute the innermost common loop depth for dstNode loads/stores. - SmallVector dstOps(dstNode->loads.begin(), - dstNode->loads.end()); - dstOps.append(dstNode->stores.begin(), dstNode->stores.end()); - unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstOps); + SmallVector dstMemrefOps; + for (Operation *op : dstNode->loads) + if (cast(op).getMemRef() == memref) + dstMemrefOps.push_back(op); + for (Operation *op : dstNode->stores) + if (cast(op).getMemRef() == memref) + dstMemrefOps.push_back(op); + unsigned dstLoopDepthTest = getInnermostCommonLoopDepth(dstMemrefOps); + // Check the feasibility of fusing src loop nest into dst loop nest // at loop depths in range [1, dstLoopDepthTest]. - // TODO: Use slice union computation and union of memref - // read/write regions to cost model and fusion. - bool canFuse = false; + unsigned maxLegalFusionDepth = 0; + SmallVector depthSliceUnions; + depthSliceUnions.resize(dstLoopDepthTest); + FusionStrategy strategy(FusionStrategy::ProducerConsumer, memref); for (unsigned i = 1; i <= dstLoopDepthTest; ++i) { - ComputationSliceState sliceUnion; FusionResult result = mlir::canFuseLoops( - cast(srcNode->op), cast(dstNode->op), - /*dstLoopDepth=*/i, &sliceUnion); + srcAffineForOp, dstAffineForOp, + /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy); + if (result.value == FusionResult::Success) - canFuse = true; + maxLegalFusionDepth = i; } - // Skip if fusion is not feasible at all loop depths. - if (!canFuse) + // Skip if fusion is not feasible at any loop depths. + if (maxLegalFusionDepth == 0) continue; - // Gather 'dstNode' store ops to 'memref'. - SmallVector dstStoreOpInsts; - for (auto *storeOpInst : dstNode->stores) - if (cast(storeOpInst).getMemRef() == memref) - dstStoreOpInsts.push_back(storeOpInst); - - unsigned bestDstLoopDepth; - mlir::ComputationSliceState sliceState; - // Check if fusion would be profitable. - if (!isFusionProfitable(srcStoreOp, srcStoreOp, dstLoadOpInsts, - dstStoreOpInsts, &sliceState, - &bestDstLoopDepth, maximalFusion, - computeToleranceThreshold)) + // Check if fusion would be profitable. We skip profitability analysis + // for maximal fusion since we already know the maximal legal depth to + // fuse. + unsigned bestDstLoopDepth = maxLegalFusionDepth; + if (!maximalFusion && + !isFusionProfitable(srcStoreOp, srcStoreOp, dstLoadOpInsts, + depthSliceUnions, maxLegalFusionDepth, + &bestDstLoopDepth, computeToleranceThreshold)) continue; + assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth"); + assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() && + "Missing slice union for depth"); + // Fuse computation slice of 'srcLoopNest' into 'dstLoopNest'. - auto sliceLoopNest = mlir::insertBackwardComputationSlice( - srcStoreOp, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState); - if (sliceLoopNest) { - LLVM_DEBUG(llvm::dbgs() << "\tslice loop nest:\n" - << *sliceLoopNest.getOperation() << "\n"); - // Move 'dstAffineForOp' before 'insertPointInst' if needed. - auto dstAffineForOp = cast(dstNode->op); - if (insertPointInst != dstAffineForOp.getOperation()) { - dstAffineForOp.getOperation()->moveBefore(insertPointInst); - } - // Update edges between 'srcNode' and 'dstNode'. - mdg->updateEdges(srcNode->id, dstNode->id, memref, - createPrivateMemref); - - // Collect slice loop stats. - LoopNestStateCollector sliceCollector; - sliceCollector.collect(sliceLoopNest.getOperation()); - // Promote single iteration slice loops to single IV value. - for (auto forOp : sliceCollector.forOps) { - promoteIfSingleIteration(forOp); - } - if (createPrivateMemref) { - // Create private memref for 'memref' in 'dstAffineForOp'. - SmallVector storesForMemref; - for (auto *storeOpInst : sliceCollector.storeOpInsts) { - if (cast(storeOpInst).getMemRef() == - memref) - storesForMemref.push_back(storeOpInst); - } - // TODO: Use union of memref write regions to compute - // private memref footprint. - auto newMemRef = createPrivateMemRef( - dstAffineForOp, storesForMemref[0], bestDstLoopDepth, - fastMemorySpace, localBufSizeThreshold); - visitedMemrefs.insert(newMemRef); - // Create new node in dependence graph for 'newMemRef' alloc op. - unsigned newMemRefNodeId = - mdg->addNode(newMemRef.getDefiningOp()); - // Add edge from 'newMemRef' node to dstNode. - mdg->addEdge(newMemRefNodeId, dstId, newMemRef); + fuseLoops(srcAffineForOp, dstAffineForOp, + depthSliceUnions[bestDstLoopDepth - 1]); + + LLVM_DEBUG(llvm::dbgs() + << "Fused src loop " << srcId << " into dst loop " << dstId + << " at depth " << bestDstLoopDepth << ":\n" + << dstAffineForOp << "\n"); + + // Move 'dstAffineForOp' before 'insertPointInst' if needed. + if (insertPointInst != dstAffineForOp.getOperation()) + dstAffineForOp.getOperation()->moveBefore(insertPointInst); + + // Update edges between 'srcNode' and 'dstNode'. + mdg->updateEdges(srcNode->id, dstNode->id, memref, + createPrivateMemref); + + // Collect slice loop stats. + LoopNestStateCollector dstForCollector; + dstForCollector.collect(dstAffineForOp); + if (createPrivateMemref) { + // Create private memref for 'memref' in 'dstAffineForOp'. + SmallVector storesForMemref; + for (auto *storeOpInst : dstForCollector.storeOpInsts) { + if (cast(storeOpInst).getMemRef() == + memref) + storesForMemref.push_back(storeOpInst); } + // TODO: Use union of memref write regions to compute + // private memref footprint. + auto newMemRef = createPrivateMemRef( + dstAffineForOp, storesForMemref[0], bestDstLoopDepth, + fastMemorySpace, localBufSizeThreshold); + visitedMemrefs.insert(newMemRef); + // Create new node in dependence graph for 'newMemRef' alloc op. + unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp()); + // Add edge from 'newMemRef' node to dstNode. + mdg->addEdge(newMemRefNodeId, dstId, newMemRef); + } - // Collect dst loop stats after memref privatization transformation. - LoopNestStateCollector dstLoopCollector; - dstLoopCollector.collect(dstAffineForOp.getOperation()); - - // Add new load ops to current Node load op list 'loads' to - // continue fusing based on new operands. - for (auto *loadOpInst : dstLoopCollector.loadOpInsts) { - // NOTE: Change 'loads' to a hash set in case efficiency is an - // issue. We still use a vector since it's expected to be small. - if (!llvm::is_contained(loads, loadOpInst)) - loads.push_back(loadOpInst); - } - // Clear visited memrefs after fusion so that previously visited src - // nodes are considered for fusion again in the context of the new - // fused node. - // TODO: This shouldn't be necessary if we visited candidates in the - // dependence graph in post-order or once we fully support - // multi-store producers. Currently, in a multi-store producer - // scenario such as A->B, A->C, B->C, we fail to fuse A+B due to the - // multiple outgoing edges. However, after fusing B+C, A has a - // single outgoing edge and can be fused if we revisit it in the - // context of the new fused B+C node. - visitedMemrefs.clear(); - - // Clear and add back loads and stores. - mdg->clearNodeLoadAndStores(dstNode->id); - mdg->addToNode(dstId, dstLoopCollector.loadOpInsts, - dstLoopCollector.storeOpInsts); - // Remove old src loop nest if it no longer has outgoing dependence - // edges, and if it does not write to a memref which escapes the - // function. If 'writesToLiveInOrOut' is true, then 'srcNode' has - // been fused into 'dstNode' and write region of 'dstNode' covers - // the write region of 'srcNode', and 'srcNode' has no other users - // so it is safe to remove. - if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) { - mdg->removeNode(srcNode->id); - srcNode->op->erase(); - } else { - // Add remaining users of 'oldMemRef' back on the worklist (if not - // already there), as its replacement with a local/private memref - // has reduced dependences on 'oldMemRef' which may have created - // new fusion opportunities. - if (mdg->outEdges.count(srcNode->id) > 0) { - SmallVector oldOutEdges = - mdg->outEdges[srcNode->id]; - for (auto &outEdge : oldOutEdges) { - if (outEdge.value == memref && - worklistSet.count(outEdge.id) == 0) { - worklist.push_back(outEdge.id); - worklistSet.insert(outEdge.id); - } + // Collect dst loop stats after memref privatization + // transformation. + LoopNestStateCollector dstLoopCollector; + dstLoopCollector.collect(dstAffineForOp.getOperation()); + + // Add new load ops to current Node load op list 'loads' to + // continue fusing based on new operands. + for (auto *loadOpInst : dstLoopCollector.loadOpInsts) { + // NOTE: Change 'loads' to a hash set in case efficiency is an + // issue. We still use a vector since it's expected to be small. + if (!llvm::is_contained(loads, loadOpInst)) + loads.push_back(loadOpInst); + } + // Clear visited memrefs after fusion so that previously visited + // src nodes are considered for fusion again in the context of the + // new fused node. + // TODO: This shouldn't be necessary if we visited candidates in + // the dependence graph in post-order or once we fully support + // multi-store producers. Currently, in a multi-store producer + // scenario such as A->B, A->C, B->C, we fail to fuse A+B due to + // the multiple outgoing edges. However, after fusing B+C, A has a + // single outgoing edge and can be fused if we revisit it in the + // context of the new fused B+C node. + visitedMemrefs.clear(); + + // Clear and add back loads and stores. + mdg->clearNodeLoadAndStores(dstNode->id); + mdg->addToNode(dstId, dstLoopCollector.loadOpInsts, + dstLoopCollector.storeOpInsts); + // Remove old src loop nest if it no longer has outgoing + // dependence edges, and if it does not write to a memref which + // escapes the function. If 'writesToLiveInOrOut' is true, then + // 'srcNode' has been fused into 'dstNode' and write region of + // 'dstNode' covers the write region of 'srcNode', and 'srcNode' + // has no other users so it is safe to remove. + if (writesToLiveInOrOut || mdg->canRemoveNode(srcNode->id)) { + mdg->removeNode(srcNode->id); + srcNode->op->erase(); + } else { + // Add remaining users of 'oldMemRef' back on the worklist (if + // not already there), as its replacement with a local/private + // memref has reduced dependences on 'oldMemRef' which may have + // created new fusion opportunities. + if (mdg->outEdges.count(srcNode->id) > 0) { + SmallVector oldOutEdges = + mdg->outEdges[srcNode->id]; + for (auto &outEdge : oldOutEdges) { + if (outEdge.value == memref && + worklistSet.count(outEdge.id) == 0) { + worklist.push_back(outEdge.id); + worklistSet.insert(outEdge.id); } } } @@ -1759,6 +1655,8 @@ struct GreedyFusion { void fuseWithSiblingNodes(Node *dstNode) { DenseSet visitedSibNodeIds; std::pair idAndMemref; + auto dstAffineForOp = cast(dstNode->op); + while (findSiblingNodeToFuse(dstNode, &visitedSibNodeIds, &idAndMemref)) { unsigned sibId = idAndMemref.first; Value memref = idAndMemref.second; @@ -1791,31 +1689,53 @@ struct GreedyFusion { SmallVector dstLoadOpInsts; dstNode->getLoadOpsForMemref(memref, &dstLoadOpInsts); - // Gather 'dstNode' store ops to 'memref'. - SmallVector dstStoreOpInsts; - dstNode->getStoreOpsForMemref(memref, &dstStoreOpInsts); - - unsigned bestDstLoopDepth; - mlir::ComputationSliceState sliceState; + SmallVector dstLoopIVs; + getLoopIVs(*dstLoadOpInsts[0], &dstLoopIVs); + unsigned dstLoopDepthTest = dstLoopIVs.size(); + auto sibAffineForOp = cast(sibNode->op); + + // Compute loop depth and slice union for fusion. + SmallVector depthSliceUnions; + depthSliceUnions.resize(dstLoopDepthTest); + unsigned maxLegalFusionDepth = 0; + FusionStrategy strategy(FusionStrategy::Sibling, memref); + for (unsigned i = 1; i <= dstLoopDepthTest; ++i) { + FusionResult result = mlir::canFuseLoops( + sibAffineForOp, dstAffineForOp, + /*dstLoopDepth=*/i, &depthSliceUnions[i - 1], strategy); + + if (result.value == FusionResult::Success) + maxLegalFusionDepth = i; + } - // Check if fusion would be profitable. - if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts, - dstStoreOpInsts, &sliceState, &bestDstLoopDepth, - maximalFusion, computeToleranceThreshold)) + // Skip if fusion is not feasible at any loop depths. + if (maxLegalFusionDepth == 0) continue; + unsigned bestDstLoopDepth = dstLoopDepthTest; + if (!maximalFusion) { + // Check if fusion would be profitable. + if (!isFusionProfitable(sibLoadOpInst, sibStoreOpInst, dstLoadOpInsts, + depthSliceUnions, maxLegalFusionDepth, + &bestDstLoopDepth, computeToleranceThreshold)) + continue; + } + + assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth"); + assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() && + "Fusion depth has no computed slice union"); + // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'. - auto sliceLoopNest = mlir::insertBackwardComputationSlice( - sibLoadOpInst, dstLoadOpInsts[0], bestDstLoopDepth, &sliceState); - if (sliceLoopNest != nullptr) { - auto dstForInst = cast(dstNode->op); - // Update operation position of fused loop nest (if needed). - if (insertPointInst != dstForInst.getOperation()) { - dstForInst.getOperation()->moveBefore(insertPointInst); - } - // Update data dependence graph state post fusion. - updateStateAfterSiblingFusion(sliceLoopNest, sibNode, dstNode); + mlir::fuseLoops(sibAffineForOp, dstAffineForOp, + depthSliceUnions[bestDstLoopDepth - 1]); + + auto dstForInst = cast(dstNode->op); + // Update operation position of fused loop nest (if needed). + if (insertPointInst != dstForInst.getOperation()) { + dstForInst.getOperation()->moveBefore(insertPointInst); } + // Update data dependence graph state post fusion. + updateStateAfterSiblingFusion(sibNode, dstNode); } } @@ -1943,19 +1863,10 @@ struct GreedyFusion { return false; } - void updateStateAfterSiblingFusion(AffineForOp sliceLoopNest, Node *sibNode, - Node *dstNode) { + void updateStateAfterSiblingFusion(Node *sibNode, Node *dstNode) { // Update 'sibNode' and 'dstNode' input/output edges to reflect fusion. mdg->updateEdges(sibNode->id, dstNode->id); - // Collect slice loop stats. - LoopNestStateCollector sliceCollector; - sliceCollector.collect(sliceLoopNest.getOperation()); - // Promote single iteration slice loops to single IV value. - for (auto forOp : sliceCollector.forOps) { - promoteIfSingleIteration(forOp); - } - // Collect dst loop stats after memref privatization transformation. auto dstForInst = cast(dstNode->op); LoopNestStateCollector dstLoopCollector; diff --git a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp index 1bf9177bd8161..87f6bd7055cc9 100644 --- a/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopFusionUtils.cpp @@ -47,7 +47,7 @@ static void getLoadAndStoreMemRefAccesses(Operation *opA, }); } -// Returns true if 'op' is a load or store operation which access an memref +// Returns true if 'op' is a load or store operation which access a memref // accessed 'values' and at least one of the access is a store operation. // Returns false otherwise. static bool isDependentLoadOrStoreOp(Operation *op, @@ -187,26 +187,99 @@ gatherLoadsAndStores(AffineForOp forOp, return !hasIfOp; } +// Returns the maximum loop depth at which we could fuse producer loop +// 'srcForOp' into consumer loop 'dstForOp' without violating data dependences. +// TODO: Generalize this check for sibling and more generic fusion scenarios. +// TODO: Support forward slice fusion. +static unsigned getMaxLoopDepth(ArrayRef dstOps, + FusionStrategy fusionStrategy) { + assert(fusionStrategy.strategy == FusionStrategy::ProducerConsumer && + "Fusion strategy not supported"); + + if (dstOps.empty()) + // Expected at least one memory operation. + // TODO: Revisit this case with a specific example. + return 0; + + // Filter out ops in 'dstOps' that do not use the producer-consumer memref so + // that they are not considered for analysis. + // TODO: Currently, we pass the producer-consumer memref through + // fusionStrategy. We will retrieve the memrefs from 'srcOps' once we + // generalize the algorithm. + SmallVector targetDstOps; + for (Operation *dstOp : dstOps) { + auto loadOp = dyn_cast(dstOp); + Value memref = loadOp ? loadOp.getMemRef() + : cast(dstOp).getMemRef(); + if (memref == fusionStrategy.memref) + targetDstOps.push_back(dstOp); + } + + assert(!targetDstOps.empty() && + "No dependences between 'srcForOp' and 'dstForOp'?"); + + // Compute the innermost common loop depth for loads and stores. + unsigned loopDepth = getInnermostCommonLoopDepth(targetDstOps); + + // Return common loop depth for loads if there are no store ops. + if (all_of(targetDstOps, + [&](Operation *op) { return isa(op); })) + return loopDepth; + + // Check dependences on all pairs of ops in 'targetDstOps' and store the + // minimum loop depth at which a dependence is satisfied. + for (unsigned i = 0, e = targetDstOps.size(); i < e; ++i) { + auto *srcOpInst = targetDstOps[i]; + MemRefAccess srcAccess(srcOpInst); + for (unsigned j = 0; j < e; ++j) { + auto *dstOpInst = targetDstOps[j]; + MemRefAccess dstAccess(dstOpInst); + + unsigned numCommonLoops = + getNumCommonSurroundingLoops(*srcOpInst, *dstOpInst); + for (unsigned d = 1; d <= numCommonLoops + 1; ++d) { + FlatAffineConstraints dependenceConstraints; + // TODO: Cache dependence analysis results, check cache here. + DependenceResult result = checkMemrefAccessDependence( + srcAccess, dstAccess, d, &dependenceConstraints, + /*dependenceComponents=*/nullptr); + if (hasDependence(result)) { + // Store minimum loop depth and break because we want the min 'd' at + // which there is a dependence. + loopDepth = std::min(loopDepth, d - 1); + break; + } + } + } + } + + return loopDepth; +} + // TODO: Prevent fusion of loop nests with side-effecting operations. +// TODO: This pass performs some computation that is the same for all the depths +// (e.g., getMaxLoopDepth). Implement a version of this utility that processes +// all the depths at once or only the legal maximal depth for maximal fusion. FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, unsigned dstLoopDepth, - ComputationSliceState *srcSlice) { + ComputationSliceState *srcSlice, + FusionStrategy fusionStrategy) { // Return 'failure' if 'dstLoopDepth == 0'. if (dstLoopDepth == 0) { - LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n."); + LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n"); return FusionResult::FailPrecondition; } // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block. auto *block = srcForOp.getOperation()->getBlock(); if (block != dstForOp.getOperation()->getBlock()) { - LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n."); + LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n"); return FusionResult::FailPrecondition; } // Return 'failure' if no valid insertion point for fused loop nest in 'block' // exists which would preserve dependences. if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) { - LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n."); + LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n"); return FusionResult::FailBlockDependence; } @@ -220,25 +293,68 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, // Gather all load and store from 'forOpA' which precedes 'forOpB' in 'block'. SmallVector opsA; if (!gatherLoadsAndStores(forOpA, opsA)) { - LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n."); + LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n"); return FusionResult::FailPrecondition; } // Gather all load and store from 'forOpB' which succeeds 'forOpA' in 'block'. SmallVector opsB; if (!gatherLoadsAndStores(forOpB, opsB)) { - LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported.\n."); + LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n"); return FusionResult::FailPrecondition; } + // Return 'failure' if fusing loops at depth 'dstLoopDepth' wouldn't preserve + // loop dependences. + // TODO: Enable this check for sibling and more generic loop fusion + // strategies. + if (fusionStrategy.strategy == FusionStrategy::ProducerConsumer) { + // TODO: 'getMaxLoopDepth' does not support forward slice fusion. + assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion"); + if (getMaxLoopDepth(opsB, fusionStrategy) < dstLoopDepth) { + LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n"); + return FusionResult::FailFusionDependence; + } + } + // Calculate the number of common loops surrounding 'srcForOp' and 'dstForOp'. unsigned numCommonLoops = mlir::getNumCommonSurroundingLoops( *srcForOp.getOperation(), *dstForOp.getOperation()); + // Filter out ops in 'opsA' to compute the slice union based on the + // assumptions expected by the fusion strategy. + SmallVector strategyOpsA; + switch (fusionStrategy.strategy) { + case FusionStrategy::None: + // Generic fusion. Take into account all the memory operations to compute + // the slice union. + strategyOpsA.append(opsA.begin(), opsA.end()); + break; + case FusionStrategy::ProducerConsumer: + // Producer-consumer fusion (AffineLoopFusion pass) only takes into + // account stores to 'memref' in 'srcForOp' to compute the slice union. + for (Operation *op : opsA) { + auto store = dyn_cast(op); + if (store && store.getMemRef() == fusionStrategy.memref) + strategyOpsA.push_back(op); + } + break; + case FusionStrategy::Sibling: + // Sibling fusion (AffineLoopFusion pass) only takes into account the loads + // to 'memref' in 'srcForOp' to compute the slice union. + for (Operation *op : opsA) { + auto load = dyn_cast(op); + if (load && load.getMemRef() == fusionStrategy.memref) + strategyOpsA.push_back(op); + } + break; + } + // Compute union of computation slices computed between all pairs of ops // from 'forOpA' and 'forOpB'. - if (failed(mlir::computeSliceUnion(opsA, opsB, dstLoopDepth, numCommonLoops, - isSrcForOpBeforeDstForOp, srcSlice))) { + if (failed(mlir::computeSliceUnion(strategyOpsA, opsB, dstLoopDepth, + numCommonLoops, isSrcForOpBeforeDstForOp, + srcSlice))) { LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n"); return FusionResult::FailPrecondition; } @@ -249,24 +365,30 @@ FusionResult mlir::canFuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, /// Fuses 'srcForOp' into 'dstForOp' with destination loop block insertion point /// and source slice loop bounds specified in 'srcSlice'. void mlir::fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp, - ComputationSliceState *srcSlice) { + const ComputationSliceState &srcSlice) { // Clone 'srcForOp' into 'dstForOp' at 'srcSlice->insertPoint'. - OpBuilder b(srcSlice->insertPoint->getBlock(), srcSlice->insertPoint); + OpBuilder b(srcSlice.insertPoint->getBlock(), srcSlice.insertPoint); BlockAndValueMapping mapper; b.clone(*srcForOp, mapper); // Update 'sliceLoopNest' upper and lower bounds from computed 'srcSlice'. SmallVector sliceLoops; - for (unsigned i = 0, e = srcSlice->ivs.size(); i < e; ++i) { - auto loopIV = mapper.lookupOrNull(srcSlice->ivs[i]); + for (unsigned i = 0, e = srcSlice.ivs.size(); i < e; ++i) { + auto loopIV = mapper.lookupOrNull(srcSlice.ivs[i]); if (!loopIV) continue; auto forOp = getForInductionVarOwner(loopIV); sliceLoops.push_back(forOp); - if (AffineMap lbMap = srcSlice->lbs[i]) - forOp.setLowerBound(srcSlice->lbOperands[i], lbMap); - if (AffineMap ubMap = srcSlice->ubs[i]) - forOp.setUpperBound(srcSlice->ubOperands[i], ubMap); + if (AffineMap lbMap = srcSlice.lbs[i]) { + auto lbOperands = srcSlice.lbOperands[i]; + canonicalizeMapAndOperands(&lbMap, &lbOperands); + forOp.setLowerBound(lbOperands, lbMap); + } + if (AffineMap ubMap = srcSlice.ubs[i]) { + auto ubOperands = srcSlice.ubOperands[i]; + canonicalizeMapAndOperands(&ubMap, &ubOperands); + forOp.setUpperBound(ubOperands, ubMap); + } } // Promote any single iteration slice loops. @@ -393,15 +515,15 @@ static uint64_t getSliceIterationCount( // was encountered). // TODO: Make this work with non-unit step loops. static bool buildSliceTripCountMap( - ComputationSliceState *slice, + const ComputationSliceState &slice, llvm::SmallDenseMap *tripCountMap) { - unsigned numSrcLoopIVs = slice->ivs.size(); + unsigned numSrcLoopIVs = slice.ivs.size(); // Populate map from AffineForOp -> trip count for (unsigned i = 0; i < numSrcLoopIVs; ++i) { - AffineForOp forOp = getForInductionVarOwner(slice->ivs[i]); + AffineForOp forOp = getForInductionVarOwner(slice.ivs[i]); auto *op = forOp.getOperation(); - AffineMap lbMap = slice->lbs[i]; - AffineMap ubMap = slice->ubs[i]; + AffineMap lbMap = slice.lbs[i]; + AffineMap ubMap = slice.ubs[i]; if (lbMap == AffineMap() || ubMap == AffineMap()) { // The iteration of src loop IV 'i' was not sliced. Use full loop bounds. if (forOp.hasConstantLowerBound() && forOp.hasConstantUpperBound()) { @@ -442,7 +564,7 @@ int64_t mlir::getComputeCost(AffineForOp forOp, LoopNestStats &stats) { /// the entire loop nest. bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, AffineForOp dstForOp, LoopNestStats &dstStats, - ComputationSliceState *slice, + const ComputationSliceState &slice, int64_t *computeCost) { llvm::SmallDenseMap sliceTripCountMap; DenseMap computeCostMap; @@ -454,7 +576,7 @@ bool mlir::getFusionComputeCost(AffineForOp srcForOp, LoopNestStats &srcStats, int64_t sliceIterationCount = getSliceIterationCount(sliceTripCountMap); assert(sliceIterationCount > 0); bool storeLoadFwdGuaranteed = (sliceIterationCount == 1); - auto *insertPointParent = slice->insertPoint->getParentOp(); + auto *insertPointParent = slice.insertPoint->getParentOp(); // The store and loads to this memref will disappear. // TODO: Add load coalescing to memref data flow opt pass. diff --git a/mlir/test/lib/Transforms/TestLoopFusion.cpp b/mlir/test/lib/Transforms/TestLoopFusion.cpp index 8e6974d3735c7..8dab63c6873f1 100644 --- a/mlir/test/lib/Transforms/TestLoopFusion.cpp +++ b/mlir/test/lib/Transforms/TestLoopFusion.cpp @@ -129,7 +129,7 @@ static bool testLoopFusionTransformation(AffineForOp forOpA, AffineForOp forOpB, mlir::ComputationSliceState sliceUnion; FusionResult result = mlir::canFuseLoops(forOpA, forOpB, d, &sliceUnion); if (result.value == FusionResult::Success) { - mlir::fuseLoops(forOpA, forOpB, &sliceUnion); + mlir::fuseLoops(forOpA, forOpB, sliceUnion); // Note: 'forOpA' is removed to simplify test output. A proper loop // fusion pass should check the data dependence graph and run memref // region analysis to ensure removing 'forOpA' is safe.