diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h
index 61f24255f305f..6bc488f2d3e1e 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.h
+++ b/mlir/include/mlir/Dialect/Affine/Passes.h
@@ -25,6 +25,7 @@ class FuncOp;
 
 namespace affine {
 class AffineForOp;
+class AffineParallelOp;
 
 /// Fusion mode to attempt. The default mode `Greedy` does both
 /// producer-consumer and sibling fusion.
@@ -108,6 +109,13 @@ std::unique_ptr<InterfacePass<FunctionOpInterface>> createLoopUnrollPass(
 std::unique_ptr<InterfacePass<FunctionOpInterface>>
 createLoopUnrollAndJamPass(int unrollJamFactor = -1);
 
+/// Creates a memory banking pass to explicitly partition the memories used
+/// inside affine parallel operations.
+std::unique_ptr<OperationPass<func::FuncOp>> createParallelBankingPass(
+    int bankingFactor = -1,
+    const std::function<unsigned(AffineParallelOp)> &getBankingFactor =
+        nullptr);
+
 /// Creates a pass to pipeline explicit movement of data across levels of the
 /// memory hierarchy.
 std::unique_ptr<OperationPass<func::FuncOp>> createPipelineDataTransferPass();
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td
index b08e803345f76..6321750f4b632 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Passes.td
@@ -381,6 +381,15 @@ def AffineParallelize : Pass<"affine-parallelize", "func::FuncOp"> {
   ];
 }
 
+def AffineParallelBanking : Pass<"affine-parallel-banking", "func::FuncOp"> {
+  let summary = "Partition the memories used in affine parallel loops into banks";
+  let constructor = "mlir::affine::createParallelBankingPass()";
+  let options = [
+    Option<"bankingFactor", "banking-factor", "unsigned", /*default=*/"1",
+           "Use this banking factor for all memories being partitioned">
+  ];
+}
+
 def AffineLoopNormalize : Pass<"affine-loop-normalize", "func::FuncOp"> {
   let summary = "Apply normalization transformations to affine loop-like ops";
   let constructor = "mlir::affine::createAffineLoopNormalizePass()";
diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
index 772f15335d907..9c1290636ba77 100644
--- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
   LoopTiling.cpp
   LoopUnroll.cpp
   LoopUnrollAndJam.cpp
+  ParallelBanking.cpp
   PipelineDataTransfer.cpp
   ReifyValueBounds.cpp
   SuperVectorize.cpp
@@ -39,5 +40,6 @@ add_mlir_dialect_library(MLIRAffineTransforms
   MLIRValueBoundsOpInterface
   MLIRVectorDialect
   MLIRVectorUtils
+  MLIRSCFDialect
 )
 
diff --git a/mlir/lib/Dialect/Affine/Transforms/ParallelBanking.cpp b/mlir/lib/Dialect/Affine/Transforms/ParallelBanking.cpp
new file mode 100644
index 0000000000000..fb49db90353dd
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/Transforms/ParallelBanking.cpp
@@ -0,0 +1,381 @@
+//===- ParallelBanking.cpp - Code to perform memory banking in parallel loops
+//--------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements parallel loop memory banking.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/LoopUtils.h" +#include "mlir/Dialect/Affine/Passes.h" +#include "mlir/Dialect/Affine/Utils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Visitors.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/TypeSwitch.h" +#include + +namespace mlir { +namespace affine { +#define GEN_PASS_DEF_AFFINEPARALLELBANKING +#include "mlir/Dialect/Affine/Passes.h.inc" +} // namespace affine +} // namespace mlir + +#define DEBUG_TYPE "affine-parallel-banking" + +using namespace mlir; +using namespace mlir::affine; + +namespace { + +/// Partition memories used in `affine.parallel` operation by the +/// `bankingFactor` throughout the program. 
struct ParallelBanking
+    : public affine::impl::AffineParallelBankingBase<ParallelBanking> {
+  const std::function<unsigned(AffineParallelOp)> getBankingFactor;
+  ParallelBanking() : getBankingFactor(nullptr) {}
+  ParallelBanking(const ParallelBanking &other) = default;
+  explicit ParallelBanking(std::optional<unsigned> bankingFactor = std::nullopt,
+                           const std::function<unsigned(AffineParallelOp)>
+                               &getBankingFactor = nullptr)
+      : getBankingFactor(getBankingFactor) {
+    if (bankingFactor)
+      this->bankingFactor = *bankingFactor;
+  }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<scf::SCFDialect, memref::MemRefDialect>();
+  }
+
+  void runOnOperation() override;
+  LogicalResult parallelBankingByFactor(AffineParallelOp parOp,
+                                        uint64_t bankingFactor);
+
+private:
+  // Map from original memory definition to newly allocated banks.
+  DenseMap<Value, SmallVector<Value>> memoryToBanks;
+};
+} // namespace
+
+// Collect all memrefs used in the `parOp`'s region.
+DenseSet<Value> collectMemRefs(AffineParallelOp parOp) {
+  DenseSet<Value> memrefVals;
+  parOp.walk([&](Operation *op) {
+    for (auto operand : op->getOperands()) {
+      if (isa<MemRefType>(operand.getType()))
+        memrefVals.insert(operand);
+    }
+    return WalkResult::advance();
+  });
+  return memrefVals;
+}
+
+MemRefType computeBankedMemRefType(MemRefType originalType,
+                                   uint64_t bankingFactor) {
+  ArrayRef<int64_t> originalShape = originalType.getShape();
+  assert(!originalShape.empty() && "memref shape should not be empty");
+  assert(originalType.getRank() == 1 &&
+         "currently only support one dimension memories");
+  SmallVector<int64_t> newShape(originalShape.begin(), originalShape.end());
+  assert(newShape.front() % bankingFactor == 0 &&
+         "memref shape must be divided by the banking factor");
+  newShape.front() /= bankingFactor;
+  MemRefType newMemRefType =
+      MemRefType::get(newShape, originalType.getElementType(),
+                      originalType.getLayout(), originalType.getMemorySpace());
+
+  return newMemRefType;
+}
+
+SmallVector<Value> createBanks(Value originalMem, uint64_t bankingFactor) {
+  MemRefType originalMemRefType = cast<MemRefType>(originalMem.getType());
+  MemRefType newMemRefType =
+
computeBankedMemRefType(originalMemRefType, bankingFactor); + SmallVector banks; + if (auto blockArgMem = dyn_cast(originalMem)) { + Block *block = blockArgMem.getOwner(); + unsigned blockArgNum = blockArgMem.getArgNumber(); + + SmallVector banksType; + for (unsigned i = 0; i < bankingFactor; ++i) { + block->insertArgument(blockArgNum + 1 + i, newMemRefType, + blockArgMem.getLoc()); + } + + auto blockArgs = + block->getArguments().slice(blockArgNum + 1, bankingFactor); + banks.append(blockArgs.begin(), blockArgs.end()); + } else { + Operation *originalDef = originalMem.getDefiningOp(); + Location loc = originalDef->getLoc(); + OpBuilder builder(originalDef); + builder.setInsertionPointAfter(originalDef); + TypeSwitch(originalDef) + .Case([&](memref::AllocOp allocOp) { + for (uint bankCnt = 0; bankCnt < bankingFactor; bankCnt++) { + auto bankAllocOp = + builder.create(loc, newMemRefType); + banks.push_back(bankAllocOp); + } + }) + .Case([&](memref::AllocaOp allocaOp) { + for (uint bankCnt = 0; bankCnt < bankingFactor; bankCnt++) { + auto bankAllocaOp = + builder.create(loc, newMemRefType); + banks.push_back(bankAllocaOp); + } + }) + .Default([](Operation *) { + llvm_unreachable("Unhandled memory operation type"); + }); + } + return banks; +} + +struct BankAffineLoadPattern : public OpRewritePattern { + BankAffineLoadPattern(MLIRContext *context, uint64_t bankingFactor, + DenseMap> &memoryToBanks) + : OpRewritePattern(context), bankingFactor(bankingFactor), + memoryToBanks(memoryToBanks) {} + + LogicalResult matchAndRewrite(AffineLoadOp loadOp, + PatternRewriter &rewriter) const override { + Location loc = loadOp.getLoc(); + auto banks = memoryToBanks[loadOp.getMemref()]; + Value loadIndex = loadOp.getIndices().front(); + auto modMap = + AffineMap::get(1, 0, {rewriter.getAffineDimExpr(0) % bankingFactor}); + auto divMap = AffineMap::get( + 1, 0, {rewriter.getAffineDimExpr(0).floorDiv(bankingFactor)}); + + Value bankIndex = rewriter.create( + loc, modMap, loadIndex); 
// assuming one-dim + Value offset = rewriter.create(loc, divMap, loadIndex); + + SmallVector resultTypes = {loadOp.getResult().getType()}; + + SmallVector caseValues; + for (unsigned i = 0; i < bankingFactor; ++i) + caseValues.push_back(i); + + rewriter.setInsertionPoint(loadOp); + scf::IndexSwitchOp switchOp = rewriter.create( + loc, resultTypes, bankIndex, caseValues, + /*numRegions=*/bankingFactor); + + for (unsigned i = 0; i < bankingFactor; ++i) { + Region &caseRegion = switchOp.getCaseRegions()[i]; + rewriter.setInsertionPointToStart(&caseRegion.emplaceBlock()); + Value bankedLoad = rewriter.create(loc, banks[i], offset); + rewriter.create(loc, bankedLoad); + } + + Region &defaultRegion = switchOp.getDefaultRegion(); + assert(defaultRegion.empty() && "Default region should be empty"); + rewriter.setInsertionPointToStart(&defaultRegion.emplaceBlock()); + + TypedAttr zeroAttr = + cast(rewriter.getZeroAttr(loadOp.getType())); + auto defaultValue = rewriter.create(loc, zeroAttr); + rewriter.create(loc, defaultValue.getResult()); + + loadOp.getResult().replaceAllUsesWith(switchOp.getResult(0)); + + rewriter.eraseOp(loadOp); + return success(); + } + +private: + uint64_t bankingFactor; + DenseMap> &memoryToBanks; +}; + +struct BankAffineStorePattern : public OpRewritePattern { + BankAffineStorePattern(MLIRContext *context, uint64_t bankingFactor, + DenseMap> &memoryToBanks) + : OpRewritePattern(context), bankingFactor(bankingFactor), + memoryToBanks(memoryToBanks) {} + + LogicalResult matchAndRewrite(AffineStoreOp storeOp, + PatternRewriter &rewriter) const override { + Location loc = storeOp.getLoc(); + auto banks = memoryToBanks[storeOp.getMemref()]; + Value storeIndex = storeOp.getIndices().front(); + + auto modMap = + AffineMap::get(1, 0, {rewriter.getAffineDimExpr(0) % bankingFactor}); + auto divMap = AffineMap::get( + 1, 0, {rewriter.getAffineDimExpr(0).floorDiv(bankingFactor)}); + + Value bankIndex = rewriter.create( + loc, modMap, storeIndex); // assuming 
one-dim + Value offset = rewriter.create(loc, divMap, storeIndex); + + SmallVector resultTypes = {}; + + SmallVector caseValues; + for (unsigned i = 0; i < bankingFactor; ++i) + caseValues.push_back(i); + + rewriter.setInsertionPoint(storeOp); + scf::IndexSwitchOp switchOp = rewriter.create( + loc, resultTypes, bankIndex, caseValues, + /*numRegions=*/bankingFactor); + + for (unsigned i = 0; i < bankingFactor; ++i) { + Region &caseRegion = switchOp.getCaseRegions()[i]; + rewriter.setInsertionPointToStart(&caseRegion.emplaceBlock()); + rewriter.create(loc, storeOp.getValueToStore(), banks[i], + offset); + rewriter.create(loc); + } + + Region &defaultRegion = switchOp.getDefaultRegion(); + assert(defaultRegion.empty() && "Default region should be empty"); + rewriter.setInsertionPointToStart(&defaultRegion.emplaceBlock()); + + rewriter.create(loc); + + rewriter.eraseOp(storeOp); + return success(); + } + +private: + uint64_t bankingFactor; + DenseMap> &memoryToBanks; +}; + +struct BankReturnPattern : public OpRewritePattern { + BankReturnPattern(MLIRContext *context, + DenseMap> &memoryToBanks) + : OpRewritePattern(context), + memoryToBanks(memoryToBanks) {} + + LogicalResult matchAndRewrite(func::ReturnOp returnOp, + PatternRewriter &rewriter) const override { + Location loc = returnOp.getLoc(); + SmallVector newReturnOperands; + bool allOrigMemsUsedByReturn = true; + for (auto operand : returnOp.getOperands()) { + if (!memoryToBanks.contains(operand)) { + newReturnOperands.push_back(operand); + continue; + } + if (operand.hasOneUse()) + allOrigMemsUsedByReturn = false; + auto banks = memoryToBanks[operand]; + newReturnOperands.append(banks.begin(), banks.end()); + } + + func::FuncOp funcOp = returnOp.getParentOp(); + rewriter.setInsertionPointToEnd(&funcOp.getBlocks().front()); + auto newReturnOp = + rewriter.create(loc, ValueRange(newReturnOperands)); + TypeRange newReturnType = TypeRange(newReturnOperands); + FunctionType newFuncType = + 
FunctionType::get(funcOp.getContext(), + funcOp.getFunctionType().getInputs(), newReturnType); + funcOp.setType(newFuncType); + + if (allOrigMemsUsedByReturn) + rewriter.replaceOp(returnOp, newReturnOp); + + return success(); + } + +private: + DenseMap> &memoryToBanks; +}; + +LogicalResult cleanUpOldMemRefs(DenseSet &oldMemRefVals) { + DenseSet funcsToModify; + for (auto &memrefVal : oldMemRefVals) { + if (!memrefVal.use_empty()) + continue; + if (auto blockArg = dyn_cast(memrefVal)) { + Block *block = blockArg.getOwner(); + block->eraseArgument(blockArg.getArgNumber()); + if (auto funcOp = dyn_cast(block->getParentOp())) + funcsToModify.insert(funcOp); + } else + memrefVal.getDefiningOp()->erase(); + } + + // Modify the function type accordingly + for (auto funcOp : funcsToModify) { + SmallVector newArgTypes; + for (BlockArgument arg : funcOp.getArguments()) { + newArgTypes.push_back(arg.getType()); + } + FunctionType newFuncType = + FunctionType::get(funcOp.getContext(), newArgTypes, + funcOp.getFunctionType().getResults()); + funcOp.setType(newFuncType); + } + return success(); +} + +void ParallelBanking::runOnOperation() { + if (getOperation().isExternal()) { + return; + } + + getOperation().walk([&](AffineParallelOp parOp) { + DenseSet memrefsInPar = collectMemRefs(parOp); + + for (auto memrefVal : memrefsInPar) { + SmallVector banks = createBanks(memrefVal, bankingFactor); + memoryToBanks[memrefVal] = banks; + } + }); + + auto *ctx = &getContext(); + RewritePatternSet patterns(ctx); + + patterns.add(ctx, bankingFactor, memoryToBanks); + patterns.add(ctx, bankingFactor, memoryToBanks); + patterns.add(ctx, memoryToBanks); + + GreedyRewriteConfig config; + config.strictMode = GreedyRewriteStrictness::ExistingOps; + + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config))) { + signalPassFailure(); + } + + // Clean up the old memref values + DenseSet oldMemRefVals; + for (const auto &pair : memoryToBanks) + 
oldMemRefVals.insert(pair.first); + + if (failed(cleanUpOldMemRefs(oldMemRefVals))) { + signalPassFailure(); + } +} + +std::unique_ptr> +mlir::affine::createParallelBankingPass( + int bankingFactor, + const std::function &getBankingFactor) { + return std::make_unique( + bankingFactor == -1 ? std::nullopt + : std::optional(bankingFactor), + getBankingFactor); +} diff --git a/mlir/test/Dialect/Affine/parallel-banking.mlir b/mlir/test/Dialect/Affine/parallel-banking.mlir new file mode 100644 index 0000000000000..6300871a44026 --- /dev/null +++ b/mlir/test/Dialect/Affine/parallel-banking.mlir @@ -0,0 +1,69 @@ +// RUN: mlir-opt %s -split-input-file -affine-parallel-banking="banking-factor=2" | FileCheck %s + +// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0) -> (d0 mod 2)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0) -> (d0 floordiv 2)> + +// CHECK-LABEL: func.func @parallel_bank_one_dim( +// CHECK: %[[VAL_0:arg0]]: memref<4xf32>, +// CHECK: %[[VAL_1:arg1]]: memref<4xf32>, +// CHECK: %[[VAL_2:arg2]]: memref<4xf32>, +// CHECK: %[[VAL_3:arg3]]: memref<4xf32>) -> (memref<4xf32>, memref<4xf32>) { +// CHECK: %[[VAL_4:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_5:.*]] = memref.alloc() : memref<4xf32> +// CHECK: %[[VAL_6:.*]] = memref.alloc() : memref<4xf32> +// CHECK: affine.parallel (%[[VAL_7:.*]]) = (0) to (8) { +// CHECK: %[[VAL_8:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_7]]) +// CHECK: %[[VAL_9:.*]] = affine.apply #[[$ATTR_1]](%[[VAL_7]]) +// CHECK: %[[VAL_10:.*]] = scf.index_switch %[[VAL_8]] -> f32 +// CHECK: case 0 { +// CHECK: %[[VAL_11:.*]] = affine.load %[[VAL_0]]{{\[}}%[[VAL_9]]] : memref<4xf32> +// CHECK: scf.yield %[[VAL_11]] : f32 +// CHECK: } +// CHECK: case 1 { +// CHECK: %[[VAL_12:.*]] = affine.load %[[VAL_1]]{{\[}}%[[VAL_9]]] : memref<4xf32> +// CHECK: scf.yield %[[VAL_12]] : f32 +// CHECK: } +// CHECK: default { +// CHECK: scf.yield %[[VAL_4]] : f32 +// CHECK: } +// CHECK: %[[VAL_13:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_7]]) +// CHECK: 
%[[VAL_14:.*]] = affine.apply #[[$ATTR_1]](%[[VAL_7]]) +// CHECK: %[[VAL_15:.*]] = scf.index_switch %[[VAL_13]] -> f32 +// CHECK: case 0 { +// CHECK: %[[VAL_16:.*]] = affine.load %[[VAL_2]]{{\[}}%[[VAL_14]]] : memref<4xf32> +// CHECK: scf.yield %[[VAL_16]] : f32 +// CHECK: } +// CHECK: case 1 { +// CHECK: %[[VAL_17:.*]] = affine.load %[[VAL_3]]{{\[}}%[[VAL_14]]] : memref<4xf32> +// CHECK: scf.yield %[[VAL_17]] : f32 +// CHECK: } +// CHECK: default { +// CHECK: scf.yield %[[VAL_4]] : f32 +// CHECK: } +// CHECK: %[[VAL_18:.*]] = arith.mulf %[[VAL_10]], %[[VAL_15]] : f32 +// CHECK: %[[VAL_19:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_7]]) +// CHECK: %[[VAL_20:.*]] = affine.apply #[[$ATTR_1]](%[[VAL_7]]) +// CHECK: scf.index_switch %[[VAL_19]] +// CHECK: case 0 { +// CHECK: affine.store %[[VAL_18]], %[[VAL_5]]{{\[}}%[[VAL_20]]] : memref<4xf32> +// CHECK: scf.yield +// CHECK: } +// CHECK: case 1 { +// CHECK: affine.store %[[VAL_18]], %[[VAL_6]]{{\[}}%[[VAL_20]]] : memref<4xf32> +// CHECK: scf.yield +// CHECK: } +// CHECK: default { +// CHECK: } +// CHECK: } +// CHECK: return %[[VAL_5]], %[[VAL_6]] : memref<4xf32>, memref<4xf32> +// CHECK: } +func.func @parallel_bank_one_dim(%arg0: memref<8xf32>, %arg1: memref<8xf32>) -> (memref<8xf32>) { + %mem = memref.alloc() : memref<8xf32> + affine.parallel (%i) = (0) to (8) { + %1 = affine.load %arg0[%i] : memref<8xf32> + %2 = affine.load %arg1[%i] : memref<8xf32> + %3 = arith.mulf %1, %2 : f32 + affine.store %3, %mem[%i] : memref<8xf32> + } + return %mem : memref<8xf32> +}