Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 31 additions & 22 deletions mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -642,48 +642,57 @@ LogicalResult
LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
llvm::SmallVector<Value> &returnValues) {
Location loc = forOp.getLoc();
Type t = lb.getType();

// Emit different versions of the induction variable. They will be
// removed by dead code if not used.

// bounds_range = ub - lb
// total_iterations = (bounds_range + step - 1) / step
Type t = lb.getType();
Value zero =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 0));
Value one =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, 1));
Value minusOne =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
auto getConst = [&](int v) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe rename to createConst?

return rewriter.create<arith::ConstantOp>(loc,
rewriter.getIntegerAttr(t, v));
};

// total_iterations = cdiv(range_diff, step);
// - range_diff = ub - lb
// - total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step
Value zero = getConst(0);
Value one = getConst(1);
Value stepLessZero = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, step, zero);
Value stepDecr =
rewriter.create<arith::SelectOp>(loc, stepLessZero, one, minusOne);
rewriter.create<arith::SelectOp>(loc, stepLessZero, one, getConst(-1));

Value rangeDiff = rewriter.create<arith::SubIOp>(loc, ub, lb);
Value rangeIncrStep = rewriter.create<arith::AddIOp>(loc, rangeDiff, step);
Value rangeDecr =
rewriter.create<arith::AddIOp>(loc, rangeIncrStep, stepDecr);
Value totalIterations = rewriter.create<arith::DivSIOp>(loc, rangeDecr, step);

// If total_iters < max_stage, start the epilogue at zero to match the
// ramp-up in the prologue.
// start_iter = max(0, total_iters - max_stage)
Value iterI =
rewriter.create<arith::SubIOp>(loc, totalIterations, getConst(maxStage));
iterI = rewriter.create<arith::MaxSIOp>(loc, zero, iterI);

// Capture predicates for dynamic loops.
SmallVector<Value> predicates(maxStage + 1);
for (int64_t i = 0; i < maxStage; i++) {
// iterI = total_iters - 1 - i
// May go negative...
Value minusI =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
Value iterI = rewriter.create<arith::AddIOp>(
loc, rewriter.create<arith::AddIOp>(loc, totalIterations, minusOne),
minusI);

for (int64_t i = 1; i <= maxStage; i++) {
// newLastIter = lb + step * iterI
Value newlastIter = rewriter.create<arith::AddIOp>(
loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));

setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
setValueMapping(forOp.getInductionVar(), newlastIter, i);

// increment to next iterI
iterI = rewriter.create<arith::AddIOp>(loc, iterI, one);

if (dynamicLoop) {
// pred = iterI >= 0
predicates[i + 1] = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, iterI, zero);
// Disable stages when `i` is greater than total_iters.
// pred = total_iters >= i
predicates[i] = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, totalIterations, getConst(i));
}
}

Expand Down
67 changes: 34 additions & 33 deletions mlir/test/Dialect/SCF/loop-pipelining.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,7 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
// Check for predicated epilogue for dynamic loop.
// CHECK-LABEL: dynamic_loop(
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[CM1:.*]] = arith.constant -1 : index
// CHECK: %[[UBM:.*]] = arith.subi %[[UB:.*]], %{{.*}}
Expand All @@ -779,32 +780,32 @@ func.func @stage_0_value_escape(%A: memref<?xf32>, %result: memref<?xf32>, %ub:
// CHECK: scf.yield %[[ADDF_24]], %[[LOAD_27]]
// CHECK: }
// CHECK: %[[CMPI_10:.*]] = arith.cmpi slt, %[[STEP]], %[[C0]]
// CHECK: %[[SEL_10:.*]] = arith.select %[[CMPI_10]], %[[C1]], %[[CM1]]
// CHECK: %[[SUBI_10:.*]] = arith.subi %[[UB]], %[[LB]]
// CHECK: %[[ADDI_11:.*]] = arith.addi %[[SUBI_10]], %[[STEP]]
// CHECK: %[[ADDI_12:.*]] = arith.addi %[[ADDI_11]], %[[SEL_10]]
// CHECK: %[[DIVSI_13:.*]] = arith.divsi %[[ADDI_12]], %[[STEP]]
// CHECK: %[[ADDI_14:.*]] = arith.addi %[[DIVSI_13]], %[[CM1]]
// CHECK: %[[MULI_15:.*]] = arith.muli %{{.*}}, %[[ADDI_14]]
// CHECK: %[[ADDI_16:.*]] = arith.addi %{{.*}}, %[[MULI_15]]
// CHECK: %[[CMPI_17:.*]] = arith.cmpi sge, %[[ADDI_14]], %[[C0]]
// CHECK: %[[ADDI_18:.*]] = arith.addi %[[DIVSI_13]], %{{.*}}-1
// CHECK: %[[ADDI_19:.*]] = arith.addi %[[ADDI_18]], %{{.*}}-1
// CHECK: %[[MULI_20:.*]] = arith.muli %{{.*}}, %[[ADDI_19]]
// CHECK: %[[ADDI_21:.*]] = arith.addi %{{.*}}, %[[MULI_20]]
// CHECK: %[[CMPI_22:.*]] = arith.cmpi sge, %[[ADDI_19]], %[[C0]]
// CHECK: scf.if %[[CMPI_17]] {
// CHECK: memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_21]]]
// CHECK: %[[SELECT_11:.*]] = arith.select %[[CMPI_10]], %[[C1]], %[[CM1]]
// CHECK: %[[SUBI_12:.*]] = arith.subi %[[UB]], %[[LB]]
// CHECK: %[[ADDI_13:.*]] = arith.addi %[[SUBI_12]], %[[STEP]]
// CHECK: %[[ADDI_14:.*]] = arith.addi %[[ADDI_13]], %[[SELECT_11]]
// CHECK: %[[DIVSI_15:.*]] = arith.divsi %[[ADDI_14]], %[[STEP]]
// CHECK: %[[SUBI_17:.*]] = arith.subi %[[DIVSI_15]], %[[C2]]
// CHECK: %[[MAXSI_18:.*]] = arith.maxsi %[[SUBI_17]], %[[C0]]
// CHECK: %[[MULI_19:.*]] = arith.muli %[[STEP]], %[[MAXSI_18]]
// CHECK: %[[ADDI_20:.*]] = arith.addi %[[LB]], %[[MULI_19]]
// CHECK: %[[ADDI_21:.*]] = arith.addi %[[MAXSI_18]], %[[C1]]
// CHECK: %[[CMPI_22:.*]] = arith.cmpi sge, %[[DIVSI_15]], %[[C1]]
// CHECK: %[[MULI_23:.*]] = arith.muli %[[STEP]], %[[ADDI_21]]
// CHECK: %[[ADDI_24:.*]] = arith.addi %[[LB]], %[[MULI_23]]
// CHECK: %[[CMPI_25:.*]] = arith.cmpi sge, %[[DIVSI_15]], %[[C2]]
// CHECK: scf.if %[[CMPI_22]] {
// CHECK: memref.store %{{.*}}#0, %{{.*}}[%[[ADDI_20]]]
// CHECK: } else {
// CHECK: }
// CHECK: %[[IF_23:.*]] = scf.if %[[CMPI_22]] -> (f32) {
// CHECK: %[[ADDF_24:.*]] = arith.addf %{{.*}}#1, %{{.*}}
// CHECK: scf.yield %[[ADDF_24]]
// CHECK: %[[IF_26:.*]] = scf.if %[[CMPI_25]]
// CHECK: %[[ADDF_27:.*]] = arith.addf %{{.*}}#1, %{{.*}}
// CHECK: scf.yield %[[ADDF_27]]
// CHECK: } else {
// CHECK: scf.yield %{{.*}}
// CHECK: }
// CHECK: scf.if %[[CMPI_22]] {
// CHECK: memref.store %[[IF_23]], %{{.*}}[%[[ADDI_16]]]
// CHECK: scf.if %[[CMPI_25]] {
// CHECK: memref.store %[[IF_26]], %{{.*}}[%[[ADDI_24]]]
// CHECK: } else {
// CHECK: }
// CHECK: return
Expand Down Expand Up @@ -842,6 +843,7 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[CM1:.*]] = arith.constant -1 : index
// CHECK-DAG: %[[CF0:.*]] = arith.constant 0.000000e+00
// CHECK: %[[UBM:.*]] = arith.subi %[[UB:.*]], %{{.*}}
// CHECK: %{{.*}}:2 = scf.for %[[ARG5:.*]] = %[[LB:.*]] to %[[UBM]] step %[[STEP:.*]] iter_args(%[[ARG6:.*]] = %{{.*}}, %[[ARG7:.*]] = %{{.*}})
// CHECK: %[[ADDF_13:.*]] = arith.addf %[[ARG7]], %[[ARG6]]
Expand All @@ -856,22 +858,21 @@ func.func @dynamic_loop(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %
// CHECK: %[[ADDI_7:.*]] = arith.addi %[[SUBI_6]], %[[STEP]]
// CHECK: %[[ADDI_8:.*]] = arith.addi %[[ADDI_7]], %[[SELECT_5]]
// CHECK: %[[DIVSI_9:.*]] = arith.divsi %[[ADDI_8]], %[[STEP]]
// CHECK: %[[ADDI_10:.*]] = arith.addi %[[DIVSI_9]], %[[CM1]]
// CHECK: %[[CMPI_11:.*]] = arith.cmpi sge, %[[ADDI_10]], %[[C0]]
// CHECK: %[[IF_10:.*]] = scf.if %[[CMPI_11]]
// CHECK: %[[ADDF_13:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0
// CHECK: scf.yield %[[ADDF_13]]
// CHECK: %[[CMPI_10:.*]] = arith.cmpi sge, %[[DIVSI_9]], %[[C1]]
// CHECK: %[[IF_11:.*]] = scf.if %[[CMPI_10]]
// CHECK: %[[ADDF_14:.*]] = arith.addf %{{.*}}#1, %{{.*}}#0
// CHECK: scf.yield %[[ADDF_14]]
// CHECK: } else {
// CHECK: scf.yield %{{.*}}
// CHECK: scf.yield %[[CF0]]
// CHECK: }
// CHECK: %[[IF_11:.*]] = scf.if %[[CMPI_11]]
// CHECK: %[[MULF_13:.*]] = arith.mulf %[[IF_10]], %{{.*}}
// CHECK: scf.yield %[[MULF_13]]
// CHECK: %[[IF_12:.*]] = scf.if %[[CMPI_10]]
// CHECK: %[[MULF_14:.*]] = arith.mulf %[[IF_11]], %{{.*}}
// CHECK: scf.yield %[[MULF_14]]
// CHECK: } else {
// CHECK: scf.yield %{{.*}}
// CHECK: scf.yield %[[CF0]]
// CHECK: }
// CHECK: %[[SELECT_12:.*]] = arith.select %[[CMPI_11]], %[[IF_11]], %{{.*}}#0
// CHECK: memref.store %[[SELECT_12]], %{{.*}}[%{{.*}}]
// CHECK: %[[SELECT_13:.*]] = arith.select %[[CMPI_10]], %[[IF_12]], %{{.*}}#0
// CHECK: memref.store %[[SELECT_13]], %{{.*}}[%[[C0]]]
func.func @dynamic_loop_result(%A: memref<?xf32>, %result: memref<?xf32>, %lb: index, %ub: index, %step: index) {
%cf0 = arith.constant 1.0 : f32
%cf1 = arith.constant 33.0 : f32
Expand Down
Loading