Patch summary (free-form text placed ahead of the first `diff --git` header).

mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
  LoopPipelinerInternal::initializeLoopInfo:
  * The block that calls options.getScheduleFn and fills maxStage, stages and
    opOrder is moved from after the "no epilogue or predicate set" bail-out to
    right after lb/step are read. As a result maxStage is populated from the
    schedule before it is compared against numIteration.
  * The static trip-count test changes from `numIteration > maxStage` to
    `numIteration >= maxStage`; when it holds, dynamicLoop is set to false.

mlir/test/Dialect/SCF/loop-pipelining.mlir
  * Adds @iteration_lt_stage: a loop from 0 to 2 step 1 whose ops are tagged
    with stages 0 and 3, i.e. a static loop that does not satisfy
    `numIteration >= maxStage`.

NOTE(review): the angle-bracket arguments in
`std::vector<std::pair<Operation *, unsigned>>` and `memref<?xf32>` were absent
from the collapsed copy of this patch and have been restored by assumption;
confirm against the original commit.
NOTE(review): the Kernel CHECK line writes `%[[C0:.*]]`, `%[[C_NEG1:.*]]` and
`%[[C1:.*]]`, which re-define those captures instead of using the values bound
by the CHECK-DAG lines, so the loop bounds are not actually constrained.
Presumably `%[[C0]] to %[[C_NEG1]] step %[[C1]]` was intended; verify with
FileCheck before changing.
NOTE(review): the last hunk header declares 42 new-side lines but only 40 are
visible in this view; the trailing context presumably continues past the end
of what is shown here.

diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 4aacbe739ca5d..bcecef5e6e0a9 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -106,6 +106,20 @@ bool LoopPipelinerInternal::initializeLoopInfo(
   lb = forOp.getLowerBound();
   step = forOp.getStep();
 
+  std::vector<std::pair<Operation *, unsigned>> schedule;
+  options.getScheduleFn(forOp, schedule);
+  if (schedule.empty()) {
+    LDBG("--empty schedule -> BAIL");
+    return false;
+  }
+
+  opOrder.reserve(schedule.size());
+  for (auto &opSchedule : schedule) {
+    maxStage = std::max(maxStage, opSchedule.second);
+    stages[opSchedule.first] = opSchedule.second;
+    opOrder.push_back(opSchedule.first);
+  }
+
   dynamicLoop = true;
   auto upperBoundCst = getConstantIntValue(ub);
   auto lowerBoundCst = getConstantIntValue(lb);
@@ -124,7 +138,7 @@ bool LoopPipelinerInternal::initializeLoopInfo(
       return false;
     }
     int64_t numIteration = llvm::divideCeilSigned(ubImm - lbImm, stepImm);
-    if (numIteration > maxStage) {
+    if (numIteration >= maxStage) {
       dynamicLoop = false;
     } else if (!options.supportDynamicLoops) {
       LDBG("--fewer loop iterations than pipeline stages -> BAIL");
@@ -137,19 +151,6 @@ bool LoopPipelinerInternal::initializeLoopInfo(
     LDBG("--no epilogue or predicate set -> BAIL");
     return false;
   }
-  std::vector<std::pair<Operation *, unsigned>> schedule;
-  options.getScheduleFn(forOp, schedule);
-  if (schedule.empty()) {
-    LDBG("--empty schedule -> BAIL");
-    return false;
-  }
-
-  opOrder.reserve(schedule.size());
-  for (auto &opSchedule : schedule) {
-    maxStage = std::max(maxStage, opSchedule.second);
-    stages[opSchedule.first] = opSchedule.second;
-    opOrder.push_back(opSchedule.first);
-  }
 
   // All operations need to have a stage.
   for (Operation &op : forOp.getBody()->without_terminator()) {
diff --git a/mlir/test/Dialect/SCF/loop-pipelining.mlir b/mlir/test/Dialect/SCF/loop-pipelining.mlir
index 25e5703cea5c6..86af637fc05d7 100644
--- a/mlir/test/Dialect/SCF/loop-pipelining.mlir
+++ b/mlir/test/Dialect/SCF/loop-pipelining.mlir
@@ -34,6 +34,42 @@ func.func @simple_pipeline(%A: memref<?xf32>, %result: memref<?xf32>) {
   return
 }
 
+// -----
+
+// A static loop does not satisfy `numIteration >= maxStage`
+
+// CHECK-LABEL: func.func @iteration_lt_stage(
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C_NEG1:.*]] = arith.constant -1 : index
+// CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+// CHECK-DAG: %[[FALSE:.*]] = arith.constant false
+// Prologue:
+// CHECK: scf.if %[[TRUE]]
+// CHECK: scf.if %[[TRUE]]
+// CHECK: scf.if %[[FALSE]]
+// Kernel:
+// CHECK: scf.for %[[IV:.*]] = %[[C0:.*]] to %[[C_NEG1:.*]] step %[[C1:.*]]
+// Epilogue:
+// CHECK: scf.if %[[TRUE]]
+// CHECK: scf.if %[[TRUE]]
+// CHECK: scf.if %[[TRUE]]
+// CHECK: scf.if %[[TRUE]]
+// CHECK: scf.if %[[FALSE]]
+// CHECK: scf.if %[[FALSE]]
+func.func @iteration_lt_stage(%A: memref<?xf32>, %result: memref<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %cf = arith.constant 1.0 : f32
+  scf.for %i0 = %c0 to %c2 step %c1 {
+    %A_elem = memref.load %A[%i0] { __test_pipelining_stage__ = 0, __test_pipelining_op_order__ = 2 } : memref<?xf32>
+    %A1_elem = arith.addf %A_elem, %cf { __test_pipelining_stage__ = 3, __test_pipelining_op_order__ = 0 } : f32
+    memref.store %A1_elem, %result[%i0] { __test_pipelining_stage__ = 3, __test_pipelining_op_order__ = 1 } : memref<?xf32>
+  } { __test_pipelining_loop__ }
+  return
+}
+
 // -----