Skip to content

Commit d463fa9

Browse files
committed
[mlir][scf] Implement Conversion from scf.parallel to Nested scf.for
Add a utility function/transform operation to convert `scf.parallel` loops to nested `scf.for` loops.
1 parent 7d92756 commit d463fa9

File tree

9 files changed

+320
-0
lines changed

9 files changed

+320
-0
lines changed

mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,34 @@ def ForallToParallelOp : Op<Transform_Dialect, "loop.forall_to_parallel",
105105
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
106106
}
107107

108+
def ParallelForToNestedForOps : Op<Transform_Dialect, "loop.parallel_for_to_nested_fors",
109+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
110+
DeclareOpInterfaceMethods<TransformOpInterface>]> {
111+
let summary = "Converts scf.parallel into a nest of scf.for operations";
112+
let description = [{
113+
Converts the `scf.parallel` operation pointed to by the given handle into a
114+
set of nested `scf.for` operations. Each new operation corresponds to one
115+
dimension of the original parallel loop.
116+
117+
The operand handle must be associated with exactly one payload operation.
118+
119+
Loops with shared outputs are currently not supported.
120+
121+
#### Return Modes
122+
123+
Consumes the operand handle. Produces a silenceable failure if the operand
124+
is not associated with a single `scf.parallel` payload operation.
125+
Returns a single handle that is associated with the outermost generated
126+
`scf.for` loop of the resulting loop nest.
127+
Produces a silenceable failure if another number of resulting handles is
128+
requested.
129+
}];
130+
let arguments = (ins TransformHandleTypeInterface:$target);
131+
let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);
132+
133+
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
134+
}
135+
108136
def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
109137
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
110138
DeclareOpInterfaceMethods<TransformOpInterface>]> {

mlir/include/mlir/Dialect/SCF/Transforms/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ std::unique_ptr<Pass> createForallToForLoopPass();
6262
/// Creates a pass that converts SCF forall loops to SCF parallel loops.
6363
std::unique_ptr<Pass> createForallToParallelLoopPass();
6464

65+
/// Creates a pass that converts SCF parallel loops to nested SCF for loops.
66+
std::unique_ptr<Pass> createParallelForToNestedForsPass();
67+
6568
// Creates a pass which lowers for loops into while loops.
6669
std::unique_ptr<Pass> createForToWhileLoopPass();
6770

mlir/include/mlir/Dialect/SCF/Transforms/Passes.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,17 @@ def SCFForallToParallelLoop : Pass<"scf-forall-to-parallel"> {
124124
let constructor = "mlir::createForallToParallelLoopPass()";
125125
}
126126

127+
def SCFParallelForToNestedFors : Pass<"scf-parallel-for-to-nested-fors"> {
128+
let summary = "Convert SCF parallel for loops to nested SCF for loops";
129+
let constructor = "mlir::createParallelForToNestedForsPass()";
130+
let description = [{
131+
This pass transforms SCF.ParallelOp operations into a nest of SCF.ForOp
132+
operations. The transformation is useful for cases where the parallel loop
133+
can be expressed as a series of sequential iterations, allowing for more
134+
fine-grained control over the loop execution.
135+
}];
136+
}
137+
127138
def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
128139
let summary = "Convert SCF for loops to SCF while loops";
129140
let constructor = "mlir::createForToWhileLoopPass()";

mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,12 @@ LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp,
4242
LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp,
4343
ParallelOp *result = nullptr);
4444

45+
/// Try converting scf.parallel into a nest of scf.for loops.
46+
/// The conversion is only supported for parallel operations with no results.
47+
LogicalResult parallelForToNestedFors(RewriterBase &rewriter,
48+
ParallelOp parallelOp,
49+
ForOp *result = nullptr);
50+
4551
/// Fuses all adjacent scf.parallel operations with identical bounds and step
4652
/// into one scf.parallel operations. Uses a naive aliasing and dependency
4753
/// analysis.

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,44 @@ transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter,
149149
return DiagnosedSilenceableFailure::success();
150150
}
151151

152+
//===----------------------------------------------------------------------===//
153+
// ParallelForToNestedForOps
154+
//===----------------------------------------------------------------------===//
155+
156+
/// Applies the transform to the payload associated with the target handle.
///
/// Fails silenceably when the handle does not map to exactly one payload op,
/// when that op is not an `scf.parallel`, when the transform op was not
/// declared with exactly one result, or when the underlying conversion fails.
/// On success the single result handle is associated with the loop returned by
/// `scf::parallelForToNestedFors`.
DiagnosedSilenceableFailure transform::ParallelForToNestedForOps::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  auto payloadOps = state.getPayloadOps(getTarget());
  if (!llvm::hasSingleElement(payloadOps))
    return emitSilenceableError() << "expected a single payload op";

  // Exactly one payload op is present at this point; inspect it once.
  Operation *payloadOp = *payloadOps.begin();
  auto parallelOp = dyn_cast<scf::ParallelOp>(payloadOp);
  if (!parallelOp) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError() << "expected the payload to be scf.parallel";
    diag.attachNote(payloadOp->getLoc()) << "payload op";
    return diag;
  }

  // Only a single result handle can be populated by this transform.
  const unsigned numResults = getNumResults();
  if (numResults != 1) {
    DiagnosedSilenceableFailure diag = emitSilenceableError()
                                       << "op expects one result, given "
                                       << numResults;
    diag.attachNote(parallelOp.getLoc()) << "payload op";
    return diag;
  }

  scf::ForOp generatedLoop;
  if (failed(scf::parallelForToNestedFors(rewriter, parallelOp,
                                          &generatedLoop)))
    return emitSilenceableError()
           << "failed to convert parallel into nested fors";

  results.set(cast<OpResult>(getTransformed()[0]), {generatedLoop});
  return DiagnosedSilenceableFailure::success();
}
189+
152190
//===----------------------------------------------------------------------===//
153191
// LoopOutlineOp
154192
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
88
LoopPipelining.cpp
99
LoopRangeFolding.cpp
1010
LoopSpecialization.cpp
11+
ParallelForToNestedFors.cpp
1112
ParallelLoopCollapsing.cpp
1213
ParallelLoopFusion.cpp
1314
ParallelLoopTiling.cpp
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
//===- ParallelForToNestedFors.cpp - scf.parallel to nested scf.for ops --===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Transforms SCF.ParallelOp to nested scf.for ops.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/SCF/IR/SCF.h"
14+
#include "mlir/Dialect/SCF/Transforms/Passes.h"
15+
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
16+
#include "mlir/IR/PatternMatch.h"
17+
#include "llvm/Support/Debug.h"
18+
19+
namespace mlir {
20+
#define GEN_PASS_DEF_SCFPARALLELFORTONESTEDFORS
21+
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
22+
} // namespace mlir
23+
24+
using namespace mlir;
25+
26+
/// Converts `parallelOp` into a nest of `scf.for` loops built from its lower
/// bounds, upper bounds and steps, one loop per dimension.
///
/// Parallel loops with results are not supported: an error is emitted on
/// `parallelOp`, failure is returned and the IR is left unchanged.
///
/// On success, all operations of the parallel body except its terminator are
/// moved in front of the terminator of the last loop of the nest, uses of the
/// old induction variables are redirected to the induction variables of the
/// new loops, and `parallelOp` is erased. If `result` is non-null, `*result`
/// is set to the first loop of the generated nest.
///
/// All IR mutations go through `rewriter` so that attached listeners are
/// notified, and the rewriter's insertion point is restored on return.
LogicalResult mlir::scf::parallelForToNestedFors(RewriterBase &rewriter,
                                                 scf::ParallelOp parallelOp,
                                                 scf::ForOp *result) {
  if (!parallelOp.getResults().empty()) {
    parallelOp->emitError("Currently ScfParallel to ScfFor conversion "
                          "doesn't support ScfParallel with results.");
    return failure();
  }

  // The insertion point is set on `parallelOp`, which is erased below. The
  // guard restores the caller's insertion point instead of leaving the
  // rewriter pointing at a deleted operation.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(parallelOp);

  Location loc = parallelOp.getLoc();
  auto lowerBounds = parallelOp.getLowerBound();
  auto upperBounds = parallelOp.getUpperBound();
  auto steps = parallelOp.getStep();

  assert(lowerBounds.size() == upperBounds.size() &&
         lowerBounds.size() == steps.size() &&
         "Mismatched parallel loop bounds");

  scf::LoopNest loopNest =
      scf::buildLoopNest(rewriter, loc, lowerBounds, upperBounds, steps);
  assert(!loopNest.loops.empty() && "Expected at least one generated loop");

  auto oldInductionVars = parallelOp.getInductionVars();
  auto newInductionVars = llvm::map_to_vector(
      loopNest.loops, [](scf::ForOp forOp) { return forOp.getInductionVar(); });
  assert(oldInductionVars.size() == newInductionVars.size() &&
         "Mismatched induction variables");
  // Use the rewriter (not Value::replaceAllUsesWith) so listeners see the
  // in-place modification of every user.
  for (auto [oldIV, newIV] : llvm::zip(oldInductionVars, newInductionVars))
    rewriter.replaceAllUsesWith(oldIV, newIV);

  // Move everything but the terminator of the parallel body in front of the
  // terminator of the last generated loop. The early-increment range keeps
  // the iteration valid while operations are unlinked from the block.
  Operation *innermostTerminator =
      loopNest.loops.back().getBody()->getTerminator();
  Block &parallelBody = *parallelOp.getBody();
  for (Operation &op :
       llvm::make_early_inc_range(parallelBody.without_terminator()))
    rewriter.moveOpBefore(&op, innermostTerminator);

  rewriter.eraseOp(parallelOp);
  if (result)
    *result = loopNest.loops.front();
  return success();
}
72+
73+
namespace {
74+
struct ParallelForToNestedFors final
75+
: public impl::SCFParallelForToNestedForsBase<ParallelForToNestedFors> {
76+
void runOnOperation() override {
77+
Operation *parentOp = getOperation();
78+
IRRewriter rewriter(parentOp->getContext());
79+
80+
parentOp->walk([&](scf::ParallelOp parallelOp) {
81+
if (failed(scf::parallelForToNestedFors(rewriter, parallelOp))) {
82+
return signalPassFailure();
83+
}
84+
});
85+
}
86+
};
87+
} // namespace
88+
89+
std::unique_ptr<Pass> mlir::createParallelForToNestedForsPass() {
90+
return std::make_unique<ParallelForToNestedFors>();
91+
}
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-parallel-for-to-nested-fors))' -split-input-file -verify-diagnostics | FileCheck %s
2+
3+
func.func private @callee(%i: index, %j: index)
4+
5+
func.func @two_iters(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
6+
scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
7+
func.call @callee(%i, %j) : (index, index) -> ()
8+
}
9+
// CHECK: scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG2:.*]] step %[[ARG4:.*]] {
10+
// CHECK: scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG3:.*]] step %[[ARG5:.*]] {
11+
// CHECK: func.call @callee(%[[VAL_0]], %[[VAL_1]]) : (index, index) -> ()
12+
// CHECK: }
13+
// CHECK: }
14+
return
15+
}
16+
17+
// -----
18+
19+
func.func private @callee(%i: index, %j: index)
20+
21+
func.func @repeated(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
22+
scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
23+
func.call @callee(%i, %j) : (index, index) -> ()
24+
}
25+
26+
scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
27+
func.call @callee(%i, %j) : (index, index) -> ()
28+
}
29+
// CHECK: scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG2:.*]] step %[[ARG4:.*]] {
30+
// CHECK: scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG3:.*]] step %[[ARG5:.*]] {
31+
// CHECK: func.call @callee(%[[VAL_0]], %[[VAL_1]]) : (index, index) -> ()
32+
// CHECK: }
33+
// CHECK: }
34+
// CHECK: scf.for %[[VAL_2:.*]] = %[[ARG0]] to %[[ARG2]] step %[[ARG4]] {
35+
// CHECK: scf.for %[[VAL_3:.*]] = %[[ARG1]] to %[[ARG3]] step %[[ARG5]] {
36+
// CHECK: func.call @callee(%[[VAL_2]], %[[VAL_3]]) : (index, index) -> ()
37+
// CHECK: }
38+
// CHECK: }
39+
40+
return
41+
}
42+
43+
// -----
44+
45+
func.func private @callee(%i: index, %j: index, %k: index, %l: index)
46+
47+
func.func @nested(%lb1: index, %lb2: index, %lb3: index, %lb4: index, %ub1: index, %ub2: index, %ub3: index, %ub4: index, %step1: index, %step2: index, %step3: index, %step4: index) {
48+
scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
49+
scf.parallel (%k, %l) = (%lb3, %lb4) to (%ub3, %ub4) step (%step3, %step4) {
50+
func.call @callee(%i, %j, %k, %l) : (index, index, index, index) -> ()
51+
}
52+
}
53+
// CHECK: scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG4:.*]] step %[[ARG8:.*]] {
54+
// CHECK: scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG5:.*]] step %[[ARG9:.*]] {
55+
// CHECK: scf.for %[[VAL_2:.*]] = %[[ARG2:.*]] to %[[ARG6:.*]] step %[[ARG10:.*]] {
56+
// CHECK: scf.for %[[VAL_3:.*]] = %[[ARG3:.*]] to %[[ARG7:.*]] step %[[ARG11:.*]] {
57+
// CHECK: func.call @callee(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]]) : (index, index, index, index) -> ()
58+
// CHECK: }
59+
// CHECK: }
60+
// CHECK: }
61+
// CHECK: }
62+
return
63+
}
64+
65+
// -----
66+
func.func private @callee(%i: index, %j: index) -> i32
67+
68+
func.func @two_iters_with_reduce(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) -> i32 {
69+
%c0 = arith.constant 0 : i32
70+
// expected-error@+1 {{Currently ScfParallel to ScfFor conversion doesn't support ScfParallel with results}}
71+
%0 = scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) init (%c0) -> i32 {
72+
%curr = func.call @callee(%i, %j) : (index, index) -> i32
73+
scf.reduce(%curr : i32) {
74+
^bb0(%arg3: i32, %arg4: i32):
75+
%3 = arith.addi %arg3, %arg4 : i32
76+
scf.reduce.return %3 : i32
77+
}
78+
}
79+
return %0 : i32
80+
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics | FileCheck %s
2+
3+
func.func private @callee(%i: index, %j: index)
4+
5+
func.func @two_iters(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
6+
scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
7+
func.call @callee(%i, %j) : (index, index) -> ()
8+
}
9+
// CHECK: scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG2:.*]] step %[[ARG4:.*]] {
10+
// CHECK: scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG3:.*]] step %[[ARG5:.*]] {
11+
// CHECK: func.call @callee(%[[VAL_0]], %[[VAL_1]]) : (index, index) -> ()
12+
// CHECK: }
13+
// CHECK: }
14+
return
15+
}
16+
17+
module attributes {transform.with_named_sequence} {
18+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
19+
%0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
20+
transform.loop.parallel_for_to_nested_fors %0 : (!transform.any_op) -> (!transform.any_op)
21+
transform.yield
22+
}
23+
}
24+
25+
// -----
26+
27+
func.func private @callee(%i: index, %j: index)
28+
29+
func.func @repeated(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
30+
scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
31+
func.call @callee(%i, %j) : (index, index) -> ()
32+
}
33+
34+
scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
35+
func.call @callee(%i, %j) : (index, index) -> ()
36+
}
37+
38+
return
39+
}
40+
41+
module attributes {transform.with_named_sequence} {
42+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
43+
%0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
44+
// expected-error @below {{expected a single payload op}}
45+
transform.loop.parallel_for_to_nested_fors %0 : (!transform.any_op) -> (!transform.any_op)
46+
transform.yield
47+
}
48+
}
49+
50+
// -----
51+
52+
// expected-note @below {{payload op}}
53+
func.func private @callee(%i: index, %j: index)
54+
55+
module attributes {transform.with_named_sequence} {
56+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
57+
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
58+
// expected-error @below {{expected the payload to be scf.parallel}}
59+
transform.loop.parallel_for_to_nested_fors %0 : (!transform.any_op) -> !transform.any_op
60+
transform.yield
61+
}
62+
}

0 commit comments

Comments
 (0)