Skip to content

Commit 12b9c0d

Browse files
authored
[mlir][scf] Implement Conversion from scf.parallel to Nested scf.for (#147692)
Add a utility function/transform operation to convert `scf.parallel` loops to nested `scf.for` loops.
1 parent 7cd1ce3 commit 12b9c0d

File tree

9 files changed

+316
-0
lines changed

9 files changed

+316
-0
lines changed

mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,34 @@ def ForallToParallelOp : Op<Transform_Dialect, "loop.forall_to_parallel",
105105
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
106106
}
107107

108+
def ParallelForToNestedForOps : Op<Transform_Dialect, "loop.parallel_for_to_nested_fors",
109+
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
110+
DeclareOpInterfaceMethods<TransformOpInterface>]> {
111+
let summary = "Converts scf.parallel into a nest of scf.for operations";
112+
let description = [{
113+
Converts the `scf.parallel` operation pointed to by the given handle into a
114+
set of nested `scf.for` operations. Each new operation corresponds to one
115+
dimension of the original parallel loop.
116+
117+
The operand handle must be associated with exactly one payload operation.
118+
119+
Loops with results (i.e. reductions) are currently not supported.
120+
121+
#### Return Modes
122+
123+
Consumes the operand handle. Produces a silenceable failure if the operand
124+
is not associated with a single `scf.parallel` payload operation.
125+
Returns a single handle associated with the outermost of the generated
126+
`scf.for` loops.
127+
Produces a silenceable failure if a number of result handles other than one
128+
is requested.
129+
}];
130+
let arguments = (ins TransformHandleTypeInterface:$target);
131+
let results = (outs Variadic<TransformHandleTypeInterface>:$transformed);
132+
133+
let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)";
134+
}
135+
108136
def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
109137
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
110138
DeclareOpInterfaceMethods<TransformOpInterface>]> {

mlir/include/mlir/Dialect/SCF/Transforms/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ std::unique_ptr<Pass> createForallToForLoopPass();
6262
/// Creates a pass that converts SCF forall loops to SCF parallel loops.
6363
std::unique_ptr<Pass> createForallToParallelLoopPass();
6464

65+
/// Creates a pass that converts SCF parallel loops to nested SCF for loops.
66+
std::unique_ptr<Pass> createParallelForToNestedForsPass();
67+
6568
// Creates a pass which lowers for loops into while loops.
6669
std::unique_ptr<Pass> createForToWhileLoopPass();
6770

mlir/include/mlir/Dialect/SCF/Transforms/Passes.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,17 @@ def SCFForallToParallelLoop : Pass<"scf-forall-to-parallel"> {
130130
let constructor = "mlir::createForallToParallelLoopPass()";
131131
}
132132

133+
def SCFParallelForToNestedFors : Pass<"scf-parallel-for-to-nested-fors"> {
134+
let summary = "Convert SCF parallel for loops to nested SCF for loops";
135+
let constructor = "mlir::createParallelForToNestedForsPass()";
136+
let description = [{
137+
This pass transforms SCF::ParallelOp operations into a nest of SCF::ForOp
138+
operations. The transformation is useful for cases where the parallel loop
139+
can be expressed as a series of sequential iterations, allowing for more
140+
fine-grained control over the loop execution.
141+
}];
142+
}
143+
133144
def SCFForToWhileLoop : Pass<"scf-for-to-while"> {
134145
let summary = "Convert SCF for loops to SCF while loops";
135146
let constructor = "mlir::createForToWhileLoopPass()";

mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef MLIR_DIALECT_SCF_TRANSFORMS_TRANSFORMS_H_
1414
#define MLIR_DIALECT_SCF_TRANSFORMS_TRANSFORMS_H_
1515

16+
#include "mlir/Dialect/SCF/IR/SCF.h"
1617
#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
1718
#include "mlir/Support/LLVM.h"
1819
#include "llvm/ADT/ArrayRef.h"
@@ -42,6 +43,11 @@ LogicalResult forallToForLoop(RewriterBase &rewriter, ForallOp forallOp,
4243
LogicalResult forallToParallelLoop(RewriterBase &rewriter, ForallOp forallOp,
4344
ParallelOp *result = nullptr);
4445

46+
/// Try converting scf.parallel into a nest of scf.for loops.
47+
/// The conversion is only supported for parallel operations with no results.
48+
FailureOr<scf::LoopNest> parallelForToNestedFors(RewriterBase &rewriter,
49+
ParallelOp parallelOp);
50+
4551
/// Fuses all adjacent scf.parallel operations with identical bounds and step
4652
/// into one scf.parallel operations. Uses a naive aliasing and dependency
4753
/// analysis.

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,45 @@ transform::ForallToParallelOp::apply(transform::TransformRewriter &rewriter,
146146
return DiagnosedSilenceableFailure::success();
147147
}
148148

149+
//===----------------------------------------------------------------------===//
150+
// ParallelForToNestedForOps
151+
//===----------------------------------------------------------------------===//
152+
153+
/// Applies the transform: converts the single `scf.parallel` payload op
/// associated with the target handle into nested `scf.for` loops and maps the
/// only result handle to the first loop of the generated nest.
///
/// Emits a silenceable error when the handle is not associated with exactly
/// one payload op, when that op is not an `scf.parallel`, when the transform
/// op does not have exactly one result, or when the conversion itself fails.
DiagnosedSilenceableFailure transform::ParallelForToNestedForOps::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  auto payloadOps = state.getPayloadOps(getTarget());
  if (!llvm::hasSingleElement(payloadOps))
    return emitSilenceableError() << "expected a single payload op";

  Operation *payloadOp = *payloadOps.begin();
  auto parallelOp = dyn_cast<scf::ParallelOp>(payloadOp);
  if (!parallelOp) {
    DiagnosedSilenceableFailure notParallel =
        emitSilenceableError() << "expected the payload to be scf.parallel";
    notParallel.attachNote(payloadOp->getLoc()) << "payload op";
    return notParallel;
  }

  // Only one result handle is produced, regardless of the loop's rank.
  if (getNumResults() != 1) {
    DiagnosedSilenceableFailure wrongResultCount =
        emitSilenceableError()
        << "op expects one result, given " << getNumResults();
    wrongResultCount.attachNote(parallelOp.getLoc()) << "payload op";
    return wrongResultCount;
  }

  FailureOr<scf::LoopNest> nest =
      scf::parallelForToNestedFors(rewriter, parallelOp);
  if (failed(nest))
    return emitSilenceableError()
           << "failed to convert parallel into nested fors";

  results.set(cast<OpResult>(getTransformed()[0]), {nest->loops.front()});
  return DiagnosedSilenceableFailure::success();
}
187+
149188
//===----------------------------------------------------------------------===//
150189
// LoopOutlineOp
151190
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
99
LoopPipelining.cpp
1010
LoopRangeFolding.cpp
1111
LoopSpecialization.cpp
12+
ParallelForToNestedFors.cpp
1213
ParallelLoopCollapsing.cpp
1314
ParallelLoopFusion.cpp
1415
ParallelLoopTiling.cpp
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
//===- ParallelForToNestedFors.cpp - scf.parallel to nested scf.for ops --===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Transforms SCF.ParallelOp to nested scf.for ops.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/SCF/IR/SCF.h"
14+
#include "mlir/Dialect/SCF/Transforms/Passes.h"
15+
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
16+
#include "mlir/IR/PatternMatch.h"
17+
#include "llvm/Support/Debug.h"
18+
19+
namespace mlir {
20+
#define GEN_PASS_DEF_SCFPARALLELFORTONESTEDFORS
21+
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
22+
} // namespace mlir
23+
24+
#define DEBUG_TYPE "parallel-for-to-nested-fors"
25+
using namespace mlir;
26+
27+
/// Rewrites `parallelOp` as a nest of `scf.for` loops, one per dimension of
/// the parallel loop, built from the original lower bounds, upper bounds and
/// steps. The body of `parallelOp` is moved into the last loop of the nest,
/// with the induction variables of the new loops substituted for the block
/// arguments of the parallel body, and `parallelOp` is then erased.
///
/// Returns the generated loop nest on success. Returns a match failure (and
/// leaves the IR untouched) if `parallelOp` has results, which is currently
/// unsupported.
FailureOr<scf::LoopNest>
mlir::scf::parallelForToNestedFors(RewriterBase &rewriter,
                                   scf::ParallelOp parallelOp) {

  if (!parallelOp.getResults().empty())
    return rewriter.notifyMatchFailure(
        parallelOp, "Currently scf.parallel to scf.for conversion doesn't "
                    "support scf.parallel with results.");

  // The new loops are created right before the op they replace.
  rewriter.setInsertionPoint(parallelOp);

  Location loc = parallelOp.getLoc();
  SmallVector<Value> lowerBounds = parallelOp.getLowerBound();
  SmallVector<Value> upperBounds = parallelOp.getUpperBound();
  SmallVector<Value> steps = parallelOp.getStep();

  assert(lowerBounds.size() == upperBounds.size() &&
         lowerBounds.size() == steps.size() &&
         "Mismatched parallel loop bounds");

  scf::LoopNest loopNest =
      scf::buildLoopNest(rewriter, loc, lowerBounds, upperBounds, steps);

  // One induction variable per generated loop, in nest order; these replace
  // the block arguments of the parallel body when it is inlined below.
  SmallVector<Value> newInductionVars = llvm::map_to_vector(
      loopNest.loops, [](scf::ForOp forOp) { return forOp.getInductionVar(); });

  Block *innermostBody = loopNest.loops.back().getBody();
  Block *parallelBody = parallelOp.getBody();

  // Drop the terminator of the parallel body so that only the payload
  // operations are moved in front of the innermost loop's own terminator.
  rewriter.eraseOp(parallelBody->getTerminator());
  rewriter.inlineBlockBefore(parallelBody, innermostBody->getTerminator(),
                             newInductionVars);
  rewriter.eraseOp(parallelOp);
  return loopNest;
}
61+
62+
namespace {
63+
struct ParallelForToNestedFors final
64+
: public impl::SCFParallelForToNestedForsBase<ParallelForToNestedFors> {
65+
void runOnOperation() override {
66+
Operation *parentOp = getOperation();
67+
IRRewriter rewriter(parentOp->getContext());
68+
69+
parentOp->walk(
70+
[&](scf::ParallelOp parallelOp) {
71+
if (failed(scf::parallelForToNestedFors(rewriter, parallelOp))) {
72+
LLVM_DEBUG(
73+
llvm::dbgs()
74+
<< "Failed to convert scf.parallel to nested scf.for ops for:\n"
75+
<< parallelOp << "\n");
76+
return WalkResult::advance();
77+
}
78+
return WalkResult::advance();
79+
});
80+
}
81+
};
82+
} // namespace
83+
84+
std::unique_ptr<Pass> mlir::createParallelForToNestedForsPass() {
85+
return std::make_unique<ParallelForToNestedFors>();
86+
}
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(scf-parallel-for-to-nested-fors))' -split-input-file -verify-diagnostics | FileCheck %s
2+
3+
func.func private @callee(%i: index, %j: index)
4+
5+
func.func @two_iters(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
6+
scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
7+
func.call @callee(%i, %j) : (index, index) -> ()
8+
}
9+
// CHECK: scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG2:.*]] step %[[ARG4:.*]] {
10+
// CHECK: scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG3:.*]] step %[[ARG5:.*]] {
11+
// CHECK: func.call @callee(%[[VAL_0]], %[[VAL_1]]) : (index, index) -> ()
12+
// CHECK: }
13+
// CHECK: }
14+
return
15+
}
16+
17+
// -----
18+
19+
func.func private @callee(%i: index, %j: index)
20+
21+
func.func @repeated(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
22+
scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
23+
func.call @callee(%i, %j) : (index, index) -> ()
24+
}
25+
26+
scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
27+
func.call @callee(%i, %j) : (index, index) -> ()
28+
}
29+
// CHECK: scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG2:.*]] step %[[ARG4:.*]] {
30+
// CHECK: scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG3:.*]] step %[[ARG5:.*]] {
31+
// CHECK: func.call @callee(%[[VAL_0]], %[[VAL_1]]) : (index, index) -> ()
32+
// CHECK: }
33+
// CHECK: }
34+
// CHECK: scf.for %[[VAL_2:.*]] = %[[ARG0]] to %[[ARG2]] step %[[ARG4]] {
35+
// CHECK: scf.for %[[VAL_3:.*]] = %[[ARG1]] to %[[ARG3]] step %[[ARG5]] {
36+
// CHECK: func.call @callee(%[[VAL_2]], %[[VAL_3]]) : (index, index) -> ()
37+
// CHECK: }
38+
// CHECK: }
39+
40+
return
41+
}
42+
43+
// -----
44+
45+
func.func private @callee(%i: index, %j: index, %k: index, %l: index)
46+
47+
func.func @nested(%lb1: index, %lb2: index, %lb3: index, %lb4: index, %ub1: index, %ub2: index, %ub3: index, %ub4: index, %step1: index, %step2: index, %step3: index, %step4: index) {
48+
scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
49+
scf.parallel (%k, %l) = (%lb3, %lb4) to (%ub3, %ub4) step (%step3, %step4) {
50+
func.call @callee(%i, %j, %k, %l) : (index, index, index, index) -> ()
51+
}
52+
}
53+
// CHECK: scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG4:.*]] step %[[ARG8:.*]] {
54+
// CHECK: scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG5:.*]] step %[[ARG9:.*]] {
55+
// CHECK: scf.for %[[VAL_2:.*]] = %[[ARG2:.*]] to %[[ARG6:.*]] step %[[ARG10:.*]] {
56+
// CHECK: scf.for %[[VAL_3:.*]] = %[[ARG3:.*]] to %[[ARG7:.*]] step %[[ARG11:.*]] {
57+
// CHECK: func.call @callee(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]]) : (index, index, index, index) -> ()
58+
// CHECK: }
59+
// CHECK: }
60+
// CHECK: }
61+
// CHECK: }
62+
return
63+
}
64+
65+
// -----
66+
func.func private @callee(%i: index, %j: index) -> i32
67+
68+
func.func @two_iters_with_reduce(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) -> i32 {
69+
%c0 = arith.constant 0 : i32
70+
// CHECK: scf.parallel
71+
%0 = scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) init (%c0) -> i32 {
72+
%curr = func.call @callee(%i, %j) : (index, index) -> i32
73+
scf.reduce(%curr : i32) {
74+
^bb0(%arg3: i32, %arg4: i32):
75+
%3 = arith.addi %arg3, %arg4 : i32
76+
scf.reduce.return %3 : i32
77+
}
78+
}
79+
return %0 : i32
80+
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics | FileCheck %s
2+
3+
func.func private @callee(%i: index, %j: index)
4+
5+
func.func @two_iters(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
6+
scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
7+
func.call @callee(%i, %j) : (index, index) -> ()
8+
}
9+
// CHECK: scf.for %[[VAL_0:.*]] = %[[ARG0:.*]] to %[[ARG2:.*]] step %[[ARG4:.*]] {
10+
// CHECK: scf.for %[[VAL_1:.*]] = %[[ARG1:.*]] to %[[ARG3:.*]] step %[[ARG5:.*]] {
11+
// CHECK: func.call @callee(%[[VAL_0]], %[[VAL_1]]) : (index, index) -> ()
12+
// CHECK: }
13+
// CHECK: }
14+
return
15+
}
16+
17+
module attributes {transform.with_named_sequence} {
18+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
19+
%0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
20+
transform.loop.parallel_for_to_nested_fors %0 : (!transform.any_op) -> (!transform.any_op)
21+
transform.yield
22+
}
23+
}
24+
25+
// -----
26+
27+
func.func private @callee(%i: index, %j: index)
28+
29+
func.func @repeated(%lb1: index, %lb2: index, %ub1: index, %ub2: index, %step1: index, %step2: index) {
30+
scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
31+
func.call @callee(%i, %j) : (index, index) -> ()
32+
}
33+
34+
scf.parallel (%i, %j) = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
35+
func.call @callee(%i, %j) : (index, index) -> ()
36+
}
37+
38+
return
39+
}
40+
41+
module attributes {transform.with_named_sequence} {
42+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
43+
%0 = transform.structured.match ops{["scf.parallel"]} in %arg0 : (!transform.any_op) -> !transform.any_op
44+
// expected-error @below {{expected a single payload op}}
45+
transform.loop.parallel_for_to_nested_fors %0 : (!transform.any_op) -> (!transform.any_op)
46+
transform.yield
47+
}
48+
}
49+
50+
// -----
51+
52+
// expected-note @below {{payload op}}
53+
func.func private @callee(%i: index, %j: index)
54+
55+
module attributes {transform.with_named_sequence} {
56+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
57+
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
58+
// expected-error @below {{expected the payload to be scf.parallel}}
59+
transform.loop.parallel_for_to_nested_fors %0 : (!transform.any_op) -> !transform.any_op
60+
transform.yield
61+
}
62+
}

0 commit comments

Comments
 (0)