diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index ecd829ed14add..cdc52f4f3668c 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -221,6 +221,45 @@ FailureOr<scf::ForallOp> normalizeForallOp(RewriterBase &rewriter,
 /// 4. Each region iter arg and result has exactly one use
 bool isPerfectlyNestedForLoops(MutableArrayRef<scf::ForOp> loops);
 
+/// Generates unrolled copies of an scf loop's 'loopBodyBlock', with 'iterArgs'
+/// and 'yieldedValues' as the block arguments and yielded values of the loop.
+/// The content of the loop body is replicated 'unrollFactor' times, calling
+/// 'ivRemapFn' to remap 'iv' for each unrolled body. If specified, annotates
+/// the ops in each unrolled iteration using 'annotateFn'. If provided,
+/// 'clonedToSrcOpsMap' is populated with the mappings from the cloned ops to
+/// the original ops.
+void generateUnrolledLoop(
+    Block *loopBodyBlock, Value iv, uint64_t unrollFactor,
+    function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
+    function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
+    ValueRange iterArgs, ValueRange yieldedValues,
+    IRMapping *clonedToSrcOpsMap = nullptr);
+
+/// Unrolls this scf::ParallelOp by the specified unroll factors. Returns the
+/// unrolled loop if the unroll succeeded; otherwise returns failure if the
+/// loop cannot be unrolled, either due to restrictions or to invalid unroll
+/// factors. Requires positive loop bounds and steps. If specified, annotates
+/// the ops in each unrolled iteration by applying 'annotateFn'. If provided,
+/// 'clonedToSrcOpsMap' is populated with the mappings from the cloned ops to
+/// the original ops.
+FailureOr<scf::ParallelOp> parallelLoopUnrollByFactors(
+    scf::ParallelOp op, ArrayRef<uint64_t> unrollFactors,
+    RewriterBase &rewriter,
+    function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr,
+    IRMapping *clonedToSrcOpsMap = nullptr);
+
+/// Gets the constant trip counts for each of the induction variables of the
+/// given loop operation. Returns an empty vector if any of the loop's trip
+/// counts is not constant.
+llvm::SmallVector<int64_t>
+getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp);
+
+namespace scf {
+/// Helper function to compute the difference between two values. This is used
+/// by the loop implementations to compute the trip count.
+std::optional<llvm::APSInt> computeUbMinusLb(Value lb, Value ub, bool isSigned);
+} // namespace scf
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
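For context on how the new entry points compose, here is a hedged sketch of a caller; the surrounding walk-and-rewrite boilerplate and the helper name are hypothetical, and only `parallelLoopUnrollByFactors` and its parameters come from the header above:

```cpp
// Illustrative only, not part of the patch: unroll the innermost dimension of
// every scf.parallel by 4 and keep the clone-to-source map.
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

static void unrollInnermostDimByFour(Operation *root) {
  IRRewriter rewriter(root->getContext());
  root->walk([&](scf::ParallelOp parLoop) {
    IRMapping clonedToSrc;
    // With fewer factors than loop dims, the factors apply to the innermost
    // dims, so {4} only unrolls the last induction variable.
    FailureOr<scf::ParallelOp> unrolled = parallelLoopUnrollByFactors(
        parLoop, /*unrollFactors=*/{4}, rewriter,
        /*annotateFn=*/nullptr, &clonedToSrc);
    if (failed(unrolled)) {
      // The loop is left untouched, e.g. for dynamic bounds or for factors
      // that do not divide the trip count evenly.
    }
  });
}
```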
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 744a5951330a3..395b52fe46d25 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/IRMapping.h"
@@ -111,24 +112,6 @@ static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
   return nullptr;
 }
 
-/// Helper function to compute the difference between two values. This is used
-/// by the loop implementations to compute the trip count.
-static std::optional<llvm::APSInt> computeUbMinusLb(Value lb, Value ub,
-                                                    bool isSigned) {
-  llvm::APSInt diff;
-  auto addOp = ub.getDefiningOp<arith::AddIOp>();
-  if (!addOp)
-    return std::nullopt;
-  if ((isSigned && !addOp.hasNoSignedWrap()) ||
-      (!isSigned && !addOp.hasNoUnsignedWrap()))
-    return std::nullopt;
-
-  if (addOp.getLhs() != lb ||
-      !matchPattern(addOp.getRhs(), m_ConstantInt(&diff)))
-    return std::nullopt;
-  return diff;
-}
-
 //===----------------------------------------------------------------------===//
 // ExecuteRegionOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 10eae8906ce31..2d989d50bb8ac 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -291,47 +291,61 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
   return arith::DivUIOp::create(builder, loc, sum, divisor);
 }
 
-/// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
-/// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
-/// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
-/// unrolled iteration using annotateFn.
-static void generateUnrolledLoop(
-    Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
+void mlir::generateUnrolledLoop(
+    Block *loopBodyBlock, Value iv, uint64_t unrollFactor,
     function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
-    ValueRange iterArgs, ValueRange yieldedValues) {
+    ValueRange iterArgs, ValueRange yieldedValues,
+    IRMapping *clonedToSrcOpsMap) {
+
+  // Check if the op was cloned from another source op, and return that source
+  // op if found (or the same op if not found).
+  auto findOriginalSrcOp =
+      [](Operation *op, const IRMapping &clonedToSrcOpsMap) -> Operation * {
+    Operation *srcOp = op;
+    // If the source op derives from another op: traverse the chain to find
+    // the original source op.
+    while (srcOp && clonedToSrcOpsMap.contains(srcOp))
+      srcOp = clonedToSrcOpsMap.lookup(srcOp);
+    return srcOp;
+  };
+
   // Builder to insert unrolled bodies just before the terminator of the body of
-  // 'forOp'.
+  // the loop.
   auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);
 
-  constexpr auto defaultAnnotateFn = [](unsigned, Operation *, OpBuilder) {};
+  static const auto noopAnnotateFn = [](unsigned, Operation *, OpBuilder) {};
   if (!annotateFn)
-    annotateFn = defaultAnnotateFn;
+    annotateFn = noopAnnotateFn;
 
   // Keep a pointer to the last non-terminator operation in the original block
   // so that we know what to clone (since we are doing this in-place).
   Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);
 
-  // Unroll the contents of 'forOp' (append unrollFactor - 1 additional copies).
+  // Unroll the contents of the loop body (append unrollFactor - 1 additional
+  // copies).
   SmallVector<Value> lastYielded(yieldedValues);
   for (unsigned i = 1; i < unrollFactor; i++) {
-    IRMapping operandMap;
-    // Prepare operand map.
+    IRMapping operandMap;
     operandMap.map(iterArgs, lastYielded);
 
     // If the induction variable is used, create a remapping to the value for
     // this unrolled instance.
-    if (!forOpIV.use_empty()) {
-      Value ivUnroll = ivRemapFn(i, forOpIV, builder);
-      operandMap.map(forOpIV, ivUnroll);
+    if (!iv.use_empty()) {
+      Value ivUnroll = ivRemapFn(i, iv, builder);
+      operandMap.map(iv, ivUnroll);
     }
 
     // Clone the original body of 'forOp'.
     for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd);
          it++) {
-      Operation *clonedOp = builder.clone(*it, operandMap);
+      Operation *srcOp = &(*it);
+      Operation *clonedOp = builder.clone(*srcOp, operandMap);
       annotateFn(i, clonedOp, builder);
+      if (clonedToSrcOpsMap)
+        clonedToSrcOpsMap->map(clonedOp,
+                               findOriginalSrcOp(srcOp, *clonedToSrcOpsMap));
    }
 
     // Update yielded values.
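A note on the `findOriginalSrcOp` chain-walk above: within a single call only original ops are cloned, but when the same map is threaded through repeated unrolling passes over the same block (as `parallelLoopUnrollByFactors` below does, once per unrolled dimension), a clone recorded in one pass can itself be the source of a clone in the next. The loop collapses such chains so every map entry resolves to a true source op. A standalone model of the same idea, using plain pointers in place of `IRMapping` (all names here are illustrative):

```cpp
// Minimal model: `origin` plays the role of clonedToSrcOpsMap. With
// origin = {c2 -> c1, c1 -> src}, resolving c2 returns src, never the
// intermediate clone c1.
#include <cassert>
#include <unordered_map>

template <typename OpT>
static OpT *resolveOriginal(OpT *op,
                            const std::unordered_map<OpT *, OpT *> &origin) {
  while (op) {
    auto it = origin.find(op);
    if (it == origin.end())
      break; // `op` was never produced by cloning: it is the original.
    op = it->second;
  }
  return op;
}

static void demoResolveOriginal() {
  int src = 0, c1 = 0, c2 = 0; // Stand-ins for three operations.
  std::unordered_map<int *, int *> origin{{&c2, &c1}, {&c1, &src}};
  assert(resolveOriginal(&c2, origin) == &src);
}
```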
@@ -1544,3 +1558,116 @@ bool mlir::isPerfectlyNestedForLoops(
   }
   return true;
 }
+
+std::optional<llvm::APSInt> mlir::scf::computeUbMinusLb(Value lb, Value ub,
+                                                        bool isSigned) {
+  llvm::APSInt diff;
+  auto addOp = ub.getDefiningOp<arith::AddIOp>();
+  if (!addOp)
+    return std::nullopt;
+  if ((isSigned && !addOp.hasNoSignedWrap()) ||
+      (!isSigned && !addOp.hasNoUnsignedWrap()))
+    return std::nullopt;
+
+  if (addOp.getLhs() != lb ||
+      !matchPattern(addOp.getRhs(), m_ConstantInt(&diff)))
+    return std::nullopt;
+  return diff;
+}
+
+llvm::SmallVector<int64_t>
+mlir::getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp) {
+  std::optional<SmallVector<OpFoldResult>> loBnds = loopOp.getLoopLowerBounds();
+  std::optional<SmallVector<OpFoldResult>> upBnds = loopOp.getLoopUpperBounds();
+  std::optional<SmallVector<OpFoldResult>> steps = loopOp.getLoopSteps();
+  if (!loBnds || !upBnds || !steps)
+    return {};
+  llvm::SmallVector<int64_t> tripCounts;
+  for (auto [lb, ub, step] : llvm::zip(*loBnds, *upBnds, *steps)) {
+    std::optional<llvm::APSInt> numIter = constantTripCount(
+        lb, ub, step, /*isSigned=*/true, scf::computeUbMinusLb);
+    if (!numIter)
+      return {};
+    tripCounts.push_back(numIter->getSExtValue());
+  }
+  return tripCounts;
+}
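To make the constant-trip-count contract concrete: `computeUbMinusLb` only folds when the upper bound is literally `arith.addi %lb, %cst` carrying the matching no-wrap flag, and `getConstLoopTripCounts` then divides the folded difference by the step. A minimal model of that arithmetic, under the assumption that the difference and step have already been folded to constants (the helper name is illustrative):

```cpp
#include <cassert>
#include <cstdint>

// ceildiv(ub - lb, step): the quantity getConstLoopTripCounts collects per
// induction variable once computeUbMinusLb has folded ub - lb to `diff`.
static int64_t tripCountFromDiff(int64_t diff, int64_t step) {
  assert(step > 0 && "positive step required");
  return (diff + step - 1) / step;
}

static void demoTripCount() {
  assert(tripCountFromDiff(12, 1) == 12); // iterations 0, 1, ..., 11
  assert(tripCountFromDiff(12, 5) == 3);  // iterations 0, 5, 10
}
```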
Note that " + "dynamic loop sizes are not supported."); + + for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) { + const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx]; + if (tripCounts[dimIdx] % unrollFactor) + return rewriter.notifyMatchFailure( + op, "Unroll factors don't divide the iteration space evenly"); + } + + std::optional> maybeFoldSteps = op.getLoopSteps(); + if (!maybeFoldSteps) + return rewriter.notifyMatchFailure(op, "Failed to retrieve loop steps"); + llvm::SmallVector steps{}; + for (auto step : *maybeFoldSteps) + steps.push_back(static_cast(*getConstantIntValue(step))); + + for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) { + const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx]; + if (unrollFactor == 1) + continue; + const size_t origStep = steps[dimIdx]; + const int64_t newStep = origStep * unrollFactor; + IRMapping clonedToSrcOpsMap; + + ValueRange iterArgs = ValueRange(op.getRegionIterArgs()); + auto yieldedValues = op.getBody()->getTerminator()->getOperands(); + + generateUnrolledLoop( + op.getBody(), op.getInductionVars()[dimIdx], unrollFactor, + [&](unsigned i, Value iv, OpBuilder b) { + // iv' = iv + step * i; + const AffineExpr expr = b.getAffineDimExpr(0) + (origStep * i); + const auto map = + b.getDimIdentityMap().dropResult(0).insertResult(expr, 0); + return affine::AffineApplyOp::create(b, iv.getLoc(), map, + ValueRange{iv}); + }, + /*annotateFn*/ annotateFn, iterArgs, yieldedValues, &clonedToSrcOpsMap); + + // Update loop step + auto prevInsertPoint = rewriter.saveInsertionPoint(); + rewriter.setInsertionPoint(op); + op.getStepMutable()[dimIdx].assign( + arith::ConstantIndexOp::create(rewriter, op.getLoc(), newStep)); + rewriter.restoreInsertionPoint(prevInsertPoint); + } + return op; +} diff --git a/mlir/test/Dialect/SCF/parallel-loop-unroll.mlir b/mlir/test/Dialect/SCF/parallel-loop-unroll.mlir new file mode 100644 index 0000000000000..12b502e996c60 --- /dev/null +++ b/mlir/test/Dialect/SCF/parallel-loop-unroll.mlir @@ -0,0 +1,171 @@ +// RUN: mlir-opt %s -test-parallel-loop-unrolling='unroll-factors=1,2' -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-parallel-loop-unrolling='unroll-factors=1,2 loop-depth=1' -split-input-file | FileCheck %s --check-prefix CHECK-UNROLL-INNER +// RUN: mlir-opt %s -test-parallel-loop-unrolling='unroll-factors=3,1' -split-input-file | FileCheck %s --check-prefix CHECK-UNROLL-BY-3 + +func.func @unroll_simple_parallel_loop(%src: memref<1x16x12xf32>, %dst: memref<1x16x12xf32>) { + %c12 = arith.constant 12 : index + %c16 = arith.constant 16 : index + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + scf.parallel (%arg2, %arg3, %arg4) = (%c0, %c0, %c0) to (%c1, %c16, %c12) step (%c1, %c1, %c1) { + %read = memref.load %src[%arg2, %arg3, %arg4] : memref<1x16x12xf32> + memref.store %read, %dst[%arg2, %arg3, %arg4] : memref<1x16x12xf32> + scf.reduce + } + return +} + +// CHECK-LABEL: func @unroll_simple_parallel_loop +// CHECK-SAME: ([[ARG0:%.*]]: memref<1x16x12xf32>, [[ARG1:%.*]]: memref<1x16x12xf32>) +// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index +// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index +// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index +// CHECK-DAG: [[C12:%.*]] = arith.constant 12 : index +// CHECK-DAG: [[C16:%.*]] = arith.constant 16 : index +// CHECK: scf.parallel ([[IV0:%.*]], [[IV1:%.*]], [[IV2:%.*]]) = ([[C0]], [[C0]], [[C0]]) to ([[C1]], [[C16]], [[C12]]) step ([[C1]], [[C1]], [[C2]]) +// CHECK: 
+
+// -----
+
+func.func @negative_unroll_factors_dont_divide_evenly(%src: memref<1x16x12xf32>, %dst: memref<1x16x12xf32>) {
+  %c12 = arith.constant 12 : index
+  %c16 = arith.constant 16 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  scf.parallel (%arg2, %arg3, %arg4) = (%c0, %c0, %c0) to (%c1, %c16, %c12) step (%c1, %c1, %c1) {
+    %read = memref.load %src[%arg2, %arg3, %arg4] : memref<1x16x12xf32>
+    memref.store %read, %dst[%arg2, %arg3, %arg4] : memref<1x16x12xf32>
+    scf.reduce
+  }
+  return
+}
+
+// CHECK-UNROLL-BY-3-LABEL: func @negative_unroll_factors_dont_divide_evenly
+// CHECK-UNROLL-BY-3-SAME: ([[ARG0:%.*]]: memref<1x16x12xf32>, [[ARG1:%.*]]: memref<1x16x12xf32>)
+// CHECK-UNROLL-BY-3: [[C1:%.*]] = arith.constant 1 : index
+// CHECK-UNROLL-BY-3: scf.parallel ([[IV0:%.*]], [[IV1:%.*]], [[IV2:%.*]]) = {{.*}} step ([[C1]], [[C1]], [[C1]])
+// CHECK-UNROLL-BY-3: [[LOADED:%.*]] = memref.load [[ARG0]][[[IV0]], [[IV1]], [[IV2]]] : memref<1x16x12xf32>
+// CHECK-UNROLL-BY-3: memref.store [[LOADED]], [[ARG1]][[[IV0]], [[IV1]], [[IV2]]] : memref<1x16x12xf32>
+// CHECK-UNROLL-BY-3-NOT: affine.apply
+// CHECK-UNROLL-BY-3-NOT: memref.load
+// CHECK-UNROLL-BY-3-NOT: memref.store
+
+// -----
+
+func.func @unroll_outer_nested_parallel_loop(%src: memref<5x16x12x4x4xf32>, %dst: memref<5x16x12x4x4xf32>) {
+  %c4 = arith.constant 4 : index
+  %c12 = arith.constant 12 : index
+  %c16 = arith.constant 16 : index
+  %c5 = arith.constant 5 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  scf.parallel (%arg3, %arg4, %arg5) = (%c0, %c0, %c0) to (%c5, %c16, %c12) step (%c1, %c1, %c1) {
+    scf.parallel (%arg6, %arg7) = (%c0, %c0) to (%c4, %c4) step (%c1, %c1) {
+      %0 = affine.apply affine_map<(d0, d1) -> (d0 + (d1 floordiv 4) * 4)>(%arg4, %arg6)
+      %1 = affine.apply affine_map<(d0, d1) -> (d0 + (d1 floordiv 4) * 4)>(%arg5, %arg7)
+      %subv_in = memref.subview %src[%arg3, %0, %1, 0, 0] [1, 1, 1, 4, 4] [1, 1, 1, 1, 1] : memref<5x16x12x4x4xf32> to memref<4x4xf32, strided<[4, 1], offset: ?>>
+      %subv_out = memref.subview %dst[%arg3, %0, %1, 0, 0] [1, 1, 1, 4, 4] [1, 1, 1, 1, 1] : memref<5x16x12x4x4xf32> to memref<4x4xf32, strided<[4, 1], offset: ?>>
+      linalg.erf ins(%subv_in : memref<4x4xf32, strided<[4, 1], offset: ?>>) outs(%subv_out : memref<4x4xf32, strided<[4, 1], offset: ?>>)
+      scf.reduce
+    }
+    scf.reduce
+  }
+  return
+}
+
+// CHECK-UNROLL-BY-3-LABEL: func @unroll_outer_nested_parallel_loop
+// CHECK-LABEL: func @unroll_outer_nested_parallel_loop
+// CHECK-SAME: ([[ARG0:%.*]]: memref<5x16x12x4x4xf32>, [[ARG1:%.*]]: memref<5x16x12x4x4xf32>)
+// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
+// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
+// CHECK-DAG: [[C4:%.*]] = arith.constant 4 : index
+// CHECK-DAG: [[C5:%.*]] = arith.constant 5 : index
+// CHECK-DAG: [[C12:%.*]] = arith.constant 12 : index
+// CHECK-DAG: [[C16:%.*]] = arith.constant 16 : index
+// CHECK: scf.parallel ([[OUTV0:%.*]], [[OUTV1:%.*]], [[OUTV2:%.*]]) = ([[C0]], [[C0]], [[C0]]) to ([[C5]], [[C16]], [[C12]]) step ([[C1]], [[C1]], [[C2]])
+// CHECK: scf.parallel ([[INV0:%.*]], [[INV1:%.*]]) = ([[C0]], [[C0]]) to ([[C4]], [[C4]]) step ([[C1]], [[C1]])
+// CHECK: affine.apply {{.*}}([[OUTV1]], [[INV0]])
+// CHECK: affine.apply {{.*}}([[OUTV2]], [[INV1]])
+// CHECK: linalg.erf
+
+// CHECK: [[UNR_OUTV2:%.*]] = affine.apply {{.*}}([[OUTV2]])
+// CHECK: scf.parallel ([[INV0B:%.*]], [[INV1B:%.*]]) = ([[C0]], [[C0]]) to ([[C4]], [[C4]]) step ([[C1]], [[C1]])
+// CHECK: affine.apply {{.*}}([[OUTV1]], [[INV0B]])
+// CHECK: affine.apply {{.*}}([[UNR_OUTV2]], [[INV1B]])
+// CHECK: linalg.erf
+
+// -----
+
+func.func @negative_unroll_dynamic_parallel_loop(%src: memref<1x16x12xf32>, %dst: memref<1x16x12xf32>, %ub3: index) {
+  %c12 = arith.constant 12 : index
+  %c16 = arith.constant 16 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  scf.parallel (%arg2, %arg3, %arg4) = (%c0, %c0, %c0) to (%c1, %c16, %ub3) step (%c1, %c1, %c1) {
+    %read = memref.load %src[%arg2, %arg3, %arg4] : memref<1x16x12xf32>
+    memref.store %read, %dst[%arg2, %arg3, %arg4] : memref<1x16x12xf32>
+    scf.reduce
+  }
+  return
+}
+
+// CHECK-LABEL: func @negative_unroll_dynamic_parallel_loop
+// CHECK-SAME: ([[ARG0:%.*]]: memref<1x16x12xf32>, [[ARG1:%.*]]: memref<1x16x12xf32>, [[UB3:%.*]]: index)
+// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
+// CHECK-DAG: [[C16:%.*]] = arith.constant 16 : index
+// CHECK: scf.parallel ([[IV0:%.*]], [[IV1:%.*]], [[IV2:%.*]]) = ([[C0]], [[C0]], [[C0]]) to ([[C1]], [[C16]], [[UB3]]) step ([[C1]], [[C1]], [[C1]])
+// CHECK: [[LOADED:%.*]] = memref.load [[ARG0]][[[IV0]], [[IV1]], [[IV2]]] : memref<1x16x12xf32>
+// CHECK: memref.store [[LOADED]], [[ARG1]][[[IV0]], [[IV1]], [[IV2]]] : memref<1x16x12xf32>
+// CHECK-NOT: affine.apply
+// CHECK-NOT: memref.load
+// CHECK-NOT: memref.store
+
+// -----
+
+func.func @unroll_inner_nested_parallel_loop(%src: memref<5x16x12x4x4xf32>, %dst: memref<5x16x12x4x4xf32>) {
+  %c4 = arith.constant 4 : index
+  %c12 = arith.constant 12 : index
+  %c16 = arith.constant 16 : index
+  %c5 = arith.constant 5 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  scf.parallel (%arg3, %arg4, %arg5) = (%c0, %c0, %c0) to (%c5, %c16, %c12) step (%c1, %c1, %c1) {
+    scf.parallel (%arg6, %arg7) = (%c0, %c0) to (%c4, %c4) step (%c1, %c1) {
+      %0 = affine.apply affine_map<(d0, d1) -> (d0 + (d1 floordiv 4) * 4)>(%arg4, %arg6)
+      %1 = affine.apply affine_map<(d0, d1) -> (d0 + (d1 floordiv 4) * 4)>(%arg5, %arg7)
+      %subv_in = memref.subview %src[%arg3, %0, %1, 0, 0] [1, 1, 1, 4, 4] [1, 1, 1, 1, 1] : memref<5x16x12x4x4xf32> to memref<4x4xf32, strided<[4, 1], offset: ?>>
+      %subv_out = memref.subview %dst[%arg3, %0, %1, 0, 0] [1, 1, 1, 4, 4] [1, 1, 1, 1, 1] : memref<5x16x12x4x4xf32> to memref<4x4xf32, strided<[4, 1], offset: ?>>
+      linalg.erf ins(%subv_in : memref<4x4xf32, strided<[4, 1], offset: ?>>) outs(%subv_out : memref<4x4xf32, strided<[4, 1], offset: ?>>)
+      scf.reduce
+    }
+    scf.reduce
+  }
+  return
+}
+
+// CHECK-LABEL: func @unroll_inner_nested_parallel_loop
+// CHECK-UNROLL-INNER-LABEL: func @unroll_inner_nested_parallel_loop
+// CHECK-UNROLL-INNER-SAME: ([[ARG0:%.*]]: memref<5x16x12x4x4xf32>, [[ARG1:%.*]]: memref<5x16x12x4x4xf32>)
+// CHECK-UNROLL-INNER-DAG: [[C0:%.*]] = arith.constant 0 : index
+// CHECK-UNROLL-INNER-DAG: [[C1:%.*]] = arith.constant 1 : index
+// CHECK-UNROLL-INNER-DAG: [[C4:%.*]] = arith.constant 4 : index
+// CHECK-UNROLL-INNER-DAG: [[C5:%.*]] = arith.constant 5 : index
+// CHECK-UNROLL-INNER-DAG: [[C12:%.*]] = arith.constant 12 : index
+// CHECK-UNROLL-INNER-DAG: [[C16:%.*]] = arith.constant 16 : index
+// CHECK-UNROLL-INNER: scf.parallel ([[OUTV0:%.*]], [[OUTV1:%.*]], [[OUTV2:%.*]]) = ([[C0]], [[C0]], [[C0]]) to ([[C5]], [[C16]], [[C12]]) step ([[C1]], [[C1]], [[C1]])
+// CHECK-UNROLL-INNER-DAG: [[C2:%.*]] = arith.constant 2 : index
+// CHECK-UNROLL-INNER: scf.parallel ([[INV0:%.*]], [[INV1:%.*]]) = ([[C0]], [[C0]]) to ([[C4]], [[C4]]) step ([[C1]], [[C2]])
+// CHECK-UNROLL-INNER: affine.apply {{.*}}([[OUTV1]], [[INV0]])
+// CHECK-UNROLL-INNER: affine.apply {{.*}}([[OUTV2]], [[INV1]])
+// CHECK-UNROLL-INNER: linalg.erf
+
+// CHECK-UNROLL-INNER: [[UNR_INV1:%.*]] = affine.apply {{.*}}([[INV1]])
+// CHECK-UNROLL-INNER: affine.apply {{.*}}([[OUTV1]], [[INV0]])
+// CHECK-UNROLL-INNER: affine.apply {{.*}}([[OUTV2]], [[UNR_INV1]])
+// CHECK-UNROLL-INNER: linalg.erf
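None of the RUN lines above exercise the pass's `annotate` option; when it is enabled, every cloned op is tagged through the `annotateFn` hook declared in Utils.h. A minimal free-function callback matching that hook's signature, mirroring the lambda the test pass below installs (the function name is illustrative):

```cpp
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"

using namespace mlir;

// Matches function_ref<void(unsigned, Operation *, OpBuilder)>: tag each
// clone with the unroll iteration that produced it.
static void tagUnrolledIteration(unsigned iteration, Operation *clonedOp,
                                 OpBuilder builder) {
  clonedOp->setAttr("unrolled_iteration",
                    builder.getUI32IntegerAttr(iteration));
}
```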
diff --git a/mlir/test/lib/Dialect/SCF/CMakeLists.txt b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
index 791c2e681415a..d2f97e816cc14 100644
--- a/mlir/test/lib/Dialect/SCF/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
@@ -2,6 +2,7 @@ add_mlir_library(MLIRSCFTestPasses
   TestLoopParametricTiling.cpp
   TestLoopUnrolling.cpp
+  TestParallelLoopUnrolling.cpp
   TestSCFUtils.cpp
   TestSCFWrapInZeroTripCheck.cpp
   TestUpliftWhileToFor.cpp
diff --git a/mlir/test/lib/Dialect/SCF/TestParallelLoopUnrolling.cpp b/mlir/test/lib/Dialect/SCF/TestParallelLoopUnrolling.cpp
new file mode 100644
index 0000000000000..77a22a1812537
--- /dev/null
+++ b/mlir/test/lib/Dialect/SCF/TestParallelLoopUnrolling.cpp
@@ -0,0 +1,85 @@
+//===- TestParallelLoopUnrolling.cpp - Parallel loop unrolling test pass -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to unroll parallel loops by specified unroll
+// factors.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+static unsigned getNestingDepth(Operation *op) {
+  Operation *currOp = op;
+  unsigned depth = 0;
+  while ((currOp = currOp->getParentOp())) {
+    if (isa<scf::ParallelOp>(currOp))
+      depth++;
+  }
+  return depth;
+}
+
+struct TestParallelLoopUnrollingPass
+    : public PassWrapper<TestParallelLoopUnrollingPass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestParallelLoopUnrollingPass)
+
+  StringRef getArgument() const final {
+    return "test-parallel-loop-unrolling";
+  }
+  StringRef getDescription() const final {
+    return "Tests parallel loop unrolling transformation";
+  }
+  TestParallelLoopUnrollingPass() = default;
+  TestParallelLoopUnrollingPass(const TestParallelLoopUnrollingPass &) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<arith::ArithDialect, affine::AffineDialect>();
+  }
+
+  void runOnOperation() override {
+    SmallVector<scf::ParallelOp> loops;
+    getOperation()->walk([&](scf::ParallelOp parLoop) {
+      if (getNestingDepth(parLoop) == loopDepth)
+        loops.push_back(parLoop);
+    });
+    auto annotateFn = [this](unsigned i, Operation *op, OpBuilder b) {
+      if (annotateLoop) {
+        op->setAttr("unrolled_iteration", b.getUI32IntegerAttr(i));
+      }
+    };
+    PatternRewriter rewriter(getOperation()->getContext());
+    for (auto loop : loops) {
+      (void)parallelLoopUnrollByFactors(loop, unrollFactors, rewriter,
+                                        annotateFn);
+    }
+  }
+
+  ListOption<uint64_t> unrollFactors{
+      *this, "unroll-factors",
+      llvm::cl::desc(
+          "Unroll factors for each parallel loop dim. If fewer factors than "
+          "loop dims are provided, they are applied to the inner dims.")};
+  Option<unsigned> loopDepth{*this, "loop-depth",
+                             llvm::cl::desc("Loop depth."), llvm::cl::init(0)};
+  Option<bool> annotateLoop{*this, "annotate",
+                            llvm::cl::desc("Annotate unrolled iterations."),
+                            llvm::cl::init(false)};
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestParallelLoopUnrollingPass() {
+  PassRegistration<TestParallelLoopUnrollingPass>();
+}
+} // namespace test
+} // namespace mlir
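Note that `getNestingDepth` above counts only `scf.parallel` ancestors, so `loop-depth=1` selects the inner loops of the nested tests even if an `scf.for` were interposed. A driver that instead wanted the innermost parallel loops at any depth could use a selector along these lines (a sketch; the helper name is hypothetical and not part of the patch):

```cpp
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Operation.h"

using namespace mlir;

// Collect scf.parallel ops whose bodies contain no further scf.parallel.
static SmallVector<scf::ParallelOp>
collectInnermostParallelLoops(Operation *root) {
  SmallVector<scf::ParallelOp> innermost;
  root->walk([&](scf::ParallelOp parLoop) {
    bool hasNestedParallel = false;
    parLoop.getBody()->walk([&](scf::ParallelOp) {
      hasNestedParallel = true;
      return WalkResult::interrupt();
    });
    if (!hasNestedParallel)
      innermost.push_back(parLoop);
  });
  return innermost;
}
```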
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 88421800fed1e..ac739be8c5cb5 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -140,6 +140,7 @@ void registerTestOneShotModuleBufferizePass();
 void registerTestOpaqueLoc();
 void registerTestOpLoweringPasses();
 void registerTestPadFusion();
+void registerTestParallelLoopUnrollingPass();
 void registerTestRecursiveTypesPass();
 void registerTestSCFUpliftWhileToFor();
 void registerTestSCFUtilsPass();
@@ -289,6 +290,7 @@ void registerTestPasses() {
   mlir::test::registerTestOpaqueLoc();
   mlir::test::registerTestOpLoweringPasses();
   mlir::test::registerTestPadFusion();
+  mlir::test::registerTestParallelLoopUnrollingPass();
   mlir::test::registerTestRecursiveTypesPass();
   mlir::test::registerTestSCFUpliftWhileToFor();
   mlir::test::registerTestSCFUtilsPass();