Skip to content

Commit ca25934

Browse files
Add default option to only do loop fission for unit trip loops (#21069)
This option is useful since we don't get good performance when we fission multi-trip loops. Additionally, the prefetching pass that makes use of fission is not set up to support multi-trip nested loops. Note that even if one candidate is found to be multi-trip, we don't do the whole pass, as we won't be doing prefetching in that case and hence there is no point in fissioning at all. --------- Signed-off-by: Nirvedh Meshram <[email protected]>
1 parent f7f7ea8 commit ca25934

File tree

6 files changed

+122
-75
lines changed

6 files changed

+122
-75
lines changed

compiler/src/iree/compiler/Codegen/Common/FissionTransferOpsInControlFlow.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "iree/compiler/Codegen/Common/Passes.h"
88
#include "iree/compiler/Codegen/Common/Transforms.h"
99
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
10+
#include "iree/compiler/Codegen/Utils/Utils.h"
1011
#include "mlir/Analysis/SliceAnalysis.h"
1112
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1213
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -254,6 +255,9 @@ static FailureOr<FissionTarget> populateFissionTarget(scf::ForOp forOp) {
254255
struct FissionTransferOpsInControlFlowPass final
255256
: impl::FissionTransferOpsInControlFlowPassBase<
256257
FissionTransferOpsInControlFlowPass> {
258+
using impl::FissionTransferOpsInControlFlowPassBase<
259+
FissionTransferOpsInControlFlowPass>::
260+
FissionTransferOpsInControlFlowPassBase;
257261
void runOnOperation() override {
258262
FunctionOpInterface funcOp = getOperation();
259263
IRRewriter rewriter(funcOp.getContext());
@@ -267,6 +271,12 @@ struct FissionTransferOpsInControlFlowPass final
267271
if (failed(result)) {
268272
continue;
269273
}
274+
// When not doing multi-trip fission if we have even one multi-trip loop
275+
// we bail out from this pass and don't do fission as we won't be doing any
276+
// prefetching which is the point of doing fission.
277+
if (!FissionMultiTrip && !neverRunsSecondIteration(forOp)) {
278+
return;
279+
}
270280
fissionTargets.push_back(result.value());
271281
}
272282

compiler/src/iree/compiler/Codegen/Common/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,11 @@ def FissionTransferOpsInControlFlowPass : InterfacePass<"iree-codegen-fission-tr
355355
let dependentDialects = [
356356
"memref::MemRefDialect"
357357
];
358+
let options = [
359+
Option<"FissionMultiTrip", "fission-multi-trip",
360+
"bool", /*default=*/"false",
361+
"Allow fission in presence of loops with greater than one trip count.">
362+
];
358363
}
359364

360365
def FlattenMemRefSubspanPass : Pass<"iree-codegen-flatten-memref-subspan", "ModuleOp"> {
Lines changed: 72 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
// RUN: iree-opt --split-input-file -pass-pipeline="builtin.module(func.func(iree-codegen-fission-transfer-ops-in-control-flow),cse,canonicalize)" %s | FileCheck %s
1+
// RUN: iree-opt --split-input-file -pass-pipeline="builtin.module(func.func(iree-codegen-fission-transfer-ops-in-control-flow{fission-multi-trip}),cse,canonicalize)" %s | FileCheck %s --check-prefixes=CHECK-ALL,MULTI
2+
// RUN: iree-opt --split-input-file -pass-pipeline="builtin.module(func.func(iree-codegen-fission-transfer-ops-in-control-flow),cse)" %s | FileCheck %s --check-prefixes=CHECK-ALL,SINGLE
23

3-
// CHECK-LABEL: @fission_global_read_to_private_write
4-
// CHECK-SAME: %[[ARG0:.*]]: memref<1x?x?x8xbf16, #amdgpu.address_space<fat_raw_buffer>>
5-
// CHECK-SAME: %[[ARG1:.*]]: index
6-
// CHECK-SAME: %[[ARG2:.*]]: i1
7-
// CHECK-SAME: %[[ARG3:.*]]: vector<1x1x1x8xbf16>
8-
// CHECK-SAME: %[[ARG4:.*]]: memref<1x1x1x8xbf16, #gpu.address_space<private>>
4+
// CHECK-ALL-LABEL: @fission_global_read_to_private_write
5+
// CHECK-ALL-SAME: %[[ARG0:.*]]: memref<1x?x?x8xbf16, #amdgpu.address_space<fat_raw_buffer>>
6+
// CHECK-ALL-SAME: %[[ARG1:.*]]: index
7+
// CHECK-ALL-SAME: %[[ARG2:.*]]: i1
8+
// CHECK-ALL-SAME: %[[ARG3:.*]]: vector<1x1x1x8xbf16>
9+
// CHECK-ALL-SAME: %[[ARG4:.*]]: memref<1x1x1x8xbf16, #gpu.address_space<private>>
910
func.func @fission_global_read_to_private_write(%arg0: memref<1x?x?x8xbf16, #amdgpu.address_space<fat_raw_buffer>>, %arg1: index, %arg2: i1, %arg3: vector<1x1x1x8xbf16>, %arg4: memref<1x1x1x8xbf16, #gpu.address_space<private>>) {
1011
%c0 = arith.constant 0 : index
1112
%c1 = arith.constant 1 : index
@@ -17,23 +18,26 @@ func.func @fission_global_read_to_private_write(%arg0: memref<1x?x?x8xbf16, #amd
1718
}
1819
return
1920
}
20-
// CHECK: %[[ALLOCA:.*]] = memref.alloca(%[[ARG1]])
21-
// CHECK: scf.for %[[ITER:.*]] = %c0 to %[[ARG1]] step %c1 {
22-
// CHECK: %[[read:.*]] = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true]}
23-
// CHECK: vector.transfer_write %[[read]], %[[ALLOCA]][%[[ITER]], %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]}
24-
// CHECK: }
25-
// CHECK: scf.for %[[ITER:.*]] = %c0 to %[[ARG1]] step %c1 {
26-
// CHECK: %[[read:.*]] = vector.transfer_read %[[ALLOCA]][%[[ITER]], %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true]}
27-
// CHECK: %[[select:.*]] = arith.select %[[ARG2]], %[[read]], %[[ARG3]]
28-
// CHECK: vector.transfer_write %[[select]], %arg4[%c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]}
29-
// CHECK: }
21+
// MULTI: %[[ALLOCA:.*]] = memref.alloca(%[[ARG1]])
22+
// MULTI: scf.for %[[ITER:.*]] = %c0 to %[[ARG1]] step %c1 {
23+
// MULTI: %[[read:.*]] = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true]}
24+
// MULTI: vector.transfer_write %[[read]], %[[ALLOCA]][%[[ITER]], %c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]}
25+
// MULTI: }
26+
// MULTI: scf.for %[[ITER:.*]] = %c0 to %[[ARG1]] step %c1 {
27+
// MULTI: %[[read:.*]] = vector.transfer_read %[[ALLOCA]][%[[ITER]], %c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true]}
28+
// MULTI: %[[select:.*]] = arith.select %[[ARG2]], %[[read]], %[[ARG3]]
29+
// MULTI: vector.transfer_write %[[select]], %arg4[%c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]}
30+
// MULTI: }
31+
32+
// SINGLE: scf.for
33+
// SINGLE-NOT: scf.for
3034

3135
// -----
3236

33-
// CHECK-LABEL: @fission_global_read_to_workgroup_write
34-
// CHECK-SAME: %[[ARG0:.*]]: index
35-
// CHECK-SAME: %[[ARG1:.*]]: memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>
36-
// CHECK-SAME: %[[ARG2:.*]]: memref<1x4xf32, #gpu.address_space<workgroup>>
37+
// CHECK-ALL-LABEL: @fission_global_read_to_workgroup_write
38+
// CHECK-ALL-SAME: %[[ARG0:.*]]: index
39+
// CHECK-ALL-SAME: %[[ARG1:.*]]: memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>
40+
// CHECK-ALL-SAME: %[[ARG2:.*]]: memref<1x4xf32, #gpu.address_space<workgroup>>
3741
func.func @fission_global_read_to_workgroup_write(%arg0: index, %arg1: memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>, %arg2: memref<1x4xf32, #gpu.address_space<workgroup>>) {
3842
%c0 = arith.constant 0 : index
3943
%c16 = arith.constant 16 : index
@@ -45,28 +49,31 @@ func.func @fission_global_read_to_workgroup_write(%arg0: index, %arg1: memref<?x
4549
}
4650
return
4751
}
48-
// CHECK: %[[SUB:.*]] = arith.subi %c16, %[[ARG0]]
49-
// CHECK: %[[DIV:.*]] = arith.ceildivui %[[SUB]], %c128
50-
// CHECK: %[[ALLOCA:.*]] = memref.alloca(%[[DIV]])
51-
// CHECK: scf.for %[[ITER:.*]] = %[[ARG0]] to %c16 step %c128 {
52-
// CHECK: %[[READ:.*]] = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]}
53-
// CHECK: %[[SUB:.*]] = arith.subi %[[ITER]], %[[ARG0]]
54-
// CHECK: %[[DIV:.*]] = arith.divui %[[SUB]], %c128
55-
// CHECK: vector.transfer_write %[[READ]], %[[ALLOCA]][%[[DIV]], %c0, %c0] {in_bounds = [true, true]}
56-
// CHECK: }
57-
// CHECK: scf.for %[[ITER:.*]] = %[[ARG0]] to %c16 step %c128 {
58-
// CHECK: %[[SUB:.*]] = arith.subi %[[ITER]], %[[ARG0]]
59-
// CHECK: %[[DIV:.*]] = arith.divui %[[SUB]], %c128
60-
// CHECK: %[[READ:.*]] = vector.transfer_read %[[ALLOCA]][%[[DIV]], %c0, %c0], %cst {in_bounds = [true, true]}
61-
// CHECK: vector.transfer_write %[[READ]], %arg2[%c0, %c0] {in_bounds = [true, true]}
62-
// CHECK: }
52+
// MULTI: %[[SUB:.*]] = arith.subi %c16, %[[ARG0]]
53+
// MULTI: %[[DIV:.*]] = arith.ceildivui %[[SUB]], %c128
54+
// MULTI: %[[ALLOCA:.*]] = memref.alloca(%[[DIV]])
55+
// MULTI: scf.for %[[ITER:.*]] = %[[ARG0]] to %c16 step %c128 {
56+
// MULTI: %[[READ:.*]] = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]}
57+
// MULTI: %[[SUB:.*]] = arith.subi %[[ITER]], %[[ARG0]]
58+
// MULTI: %[[DIV:.*]] = arith.divui %[[SUB]], %c128
59+
// MULTI: vector.transfer_write %[[READ]], %[[ALLOCA]][%[[DIV]], %c0, %c0] {in_bounds = [true, true]}
60+
// MULTI: }
61+
// MULTI: scf.for %[[ITER:.*]] = %[[ARG0]] to %c16 step %c128 {
62+
// MULTI: %[[SUB:.*]] = arith.subi %[[ITER]], %[[ARG0]]
63+
// MULTI: %[[DIV:.*]] = arith.divui %[[SUB]], %c128
64+
// MULTI: %[[READ:.*]] = vector.transfer_read %[[ALLOCA]][%[[DIV]], %c0, %c0], %cst {in_bounds = [true, true]}
65+
// MULTI: vector.transfer_write %[[READ]], %arg2[%c0, %c0] {in_bounds = [true, true]}
66+
// MULTI: }
67+
68+
// SINGLE: scf.for
69+
// SINGLE-NOT: scf.for
6370

6471
// -----
6572

66-
// CHECK-LABEL: @no_fission_global_read_to_global_write
67-
// CHECK-SAME: %[[ARG0:.*]]: memref<1x?x?xf32, #amdgpu.address_space<fat_raw_buffer>>
68-
// CHECK-SAME: %[[ARG1:.*]]: memref<1x?x?xf32, #gpu.address_space<global>>
69-
// CHECK-SAME: %[[ARG2:.*]]: index
73+
// CHECK-ALL-LABEL: @no_fission_global_read_to_global_write
74+
// CHECK-ALL-SAME: %[[ARG0:.*]]: memref<1x?x?xf32, #amdgpu.address_space<fat_raw_buffer>>
75+
// CHECK-ALL-SAME: %[[ARG1:.*]]: memref<1x?x?xf32, #gpu.address_space<global>>
76+
// CHECK-ALL-SAME: %[[ARG2:.*]]: index
7077
func.func @no_fission_global_read_to_global_write(%arg0: memref<1x?x?xf32, #amdgpu.address_space<fat_raw_buffer>>, %arg1: memref<1x?x?xf32, #gpu.address_space<global>>, %arg2: index) {
7178
%c0 = arith.constant 0 : index
7279
%c1 = arith.constant 1 : index
@@ -77,8 +84,28 @@ func.func @no_fission_global_read_to_global_write(%arg0: memref<1x?x?xf32, #amdg
7784
}
7885
return
7986
}
80-
// CHECK: scf.for %[[ITER:.*]] = %c0 to %[[ARG2]] step %c1 {
81-
// CHECK: %[[READ:.*]] = vector.transfer_read
82-
// CHECK: vector.transfer_write %[[READ]], %arg1[%[[ITER]], %c0, %c0] {in_bounds = [true, true, true]}
83-
// CHECK: }
84-
// CHECK-NOT: scf.for
87+
// MULTI: scf.for %[[ITER:.*]] = %c0 to %[[ARG2]] step %c1 {
88+
// MULTI: %[[READ:.*]] = vector.transfer_read
89+
// MULTI: vector.transfer_write %[[READ]], %arg1[%[[ITER]], %c0, %c0] {in_bounds = [true, true, true]}
90+
// MULTI: }
91+
// MULTI-NOT: scf.for
92+
93+
// SINGLE: scf.for
94+
// SINGLE-NOT: scf.for
95+
96+
// -----
97+
98+
// CHECK-ALL-LABEL: @fission_unit_trip
99+
func.func @fission_unit_trip(%arg0: memref<1x?x?x8xbf16, #amdgpu.address_space<fat_raw_buffer>>, %arg1: index, %arg2: i1, %arg3: vector<1x1x1x8xbf16>, %arg4: memref<1x1x1x8xbf16, #gpu.address_space<private>>) {
100+
%c0 = arith.constant 0 : index
101+
%c1 = arith.constant 1 : index
102+
%cst = arith.constant 0.000000e+00 : bf16
103+
%ub = affine.min affine_map<(d0) -> (1, d0)>(%arg1)
104+
scf.for %arg5 = %c0 to %ub step %c1 {
105+
%read = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true, true]} : memref<1x?x?x8xbf16, #amdgpu.address_space<fat_raw_buffer>>, vector<1x1x1x8xbf16>
106+
%select = arith.select %arg2, %read, %arg3 : vector<1x1x1x8xbf16>
107+
vector.transfer_write %select, %arg4[%c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<1x1x1x8xbf16>, memref<1x1x1x8xbf16, #gpu.address_space<private>>
108+
}
109+
return
110+
}
111+
// CHECK-ALL-COUNT-2: scf.for

compiler/src/iree/compiler/Codegen/Transforms/RemoveSingleIterationLoop.cpp

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
//===----------------------------------------------------------------------===//
1313

1414
#include "iree/compiler/Codegen/Transforms/Transforms.h"
15+
#include "iree/compiler/Codegen/Utils/Utils.h"
1516
#include "llvm/Support/Debug.h"
1617
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1718
#include "mlir/Dialect/Affine/Utils.h"
@@ -60,36 +61,6 @@ static void replaceForWithIf(PatternRewriter &rewriter, scf::ForOp op,
6061
rewriter.replaceOp(op, ifOp);
6162
}
6263

63-
/// Return true if we can prove that the we always run at least the first
64-
/// iteration of the ForOp.
65-
static bool alwaysRunsFirstIteration(scf::ForOp op) {
66-
// Can't perform the analysis if the loops's bounds aren't index-typed.
67-
if (!op.getInductionVar().getType().isIndex())
68-
return false;
69-
FailureOr<bool> isLb = ValueBoundsConstraintSet::compare(
70-
getAsOpFoldResult(op.getLowerBound()), ValueBoundsConstraintSet::LT,
71-
getAsOpFoldResult(op.getUpperBound()));
72-
return isLb.value_or(false);
73-
}
74-
75-
/// Return true if we can prove that the we never run more than one iteration of
76-
/// the ForOp.
77-
static bool neverRunsSecondIteration(scf::ForOp op) {
78-
// Can't perform the analysis if the loops's bounds aren't index-typed.
79-
if (!op.getInductionVar().getType().isIndex())
80-
return false;
81-
// If the upper bound (ub) is less than or equal to the loop step, then
82-
// lower bound + step must be greater than the upper bound, assuming the
83-
// lower bound is non-negative.
84-
FailureOr<bool> isUbUnderStep = ValueBoundsConstraintSet::compare(
85-
getAsOpFoldResult(op.getUpperBound()), ValueBoundsConstraintSet::LE,
86-
getAsOpFoldResult(op.getStep()));
87-
FailureOr<bool> isLbNonNegative = ValueBoundsConstraintSet::compare(
88-
getAsOpFoldResult(op.getLowerBound()), ValueBoundsConstraintSet::GE,
89-
getAsIndexOpFoldResult(op.getContext(), 0));
90-
return isUbUnderStep.value_or(false) && isLbNonNegative.value_or(false);
91-
}
92-
9364
namespace {
9465
/// Rewriting pattern that replaces single-iteration loops with their bodies.
9566
struct SimplifyTrivialLoops : public OpRewritePattern<scf::ForOp> {

compiler/src/iree/compiler/Codegen/Utils/Utils.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1903,4 +1903,30 @@ std::optional<int64_t> getConstantIndex(Value value) {
19031903
return val.getSExtValue();
19041904
}
19051905

1906+
/// Returns true when we can prove the loop executes its first iteration, i.e.
/// the lower bound is provably strictly less than the upper bound.
bool alwaysRunsFirstIteration(scf::ForOp op) {
  // The value-bounds analysis only handles index-typed loop bounds.
  if (!op.getInductionVar().getType().isIndex())
    return false;
  FailureOr<bool> lbLessThanUb = ValueBoundsConstraintSet::compare(
      getAsOpFoldResult(op.getLowerBound()), ValueBoundsConstraintSet::LT,
      getAsOpFoldResult(op.getUpperBound()));
  // Treat an inconclusive analysis as "cannot prove it".
  return lbLessThanUb.value_or(false);
}
1915+
1916+
/// Returns true when we can prove the loop runs at most one iteration.
bool neverRunsSecondIteration(scf::ForOp op) {
  // The value-bounds analysis only handles index-typed loop bounds.
  if (!op.getInductionVar().getType().isIndex())
    return false;
  // If the upper bound is <= the step and the lower bound is non-negative,
  // then lb + step must exceed the upper bound, so the induction variable
  // never reaches a second iteration.
  FailureOr<bool> ubFitsInOneStep = ValueBoundsConstraintSet::compare(
      getAsOpFoldResult(op.getUpperBound()), ValueBoundsConstraintSet::LE,
      getAsOpFoldResult(op.getStep()));
  FailureOr<bool> lbIsNonNegative = ValueBoundsConstraintSet::compare(
      getAsOpFoldResult(op.getLowerBound()), ValueBoundsConstraintSet::GE,
      getAsIndexOpFoldResult(op.getContext(), 0));
  // Both facts must be provable; inconclusive results count as false.
  return ubFitsInOneStep.value_or(false) && lbIsNonNegative.value_or(false);
}
1931+
19061932
} // namespace mlir::iree_compiler

compiler/src/iree/compiler/Codegen/Utils/Utils.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,14 @@ inferSizesFromIR(linalg::LinalgOp linalgOp, std::optional<OpResult> opResult);
322322
/// Returns the underlying index if the given value is a constant index.
323323
std::optional<int64_t> getConstantIndex(Value value);
324324

325+
/// Return true if we can prove that we always run at least the first
326+
/// iteration of the ForOp.
327+
bool alwaysRunsFirstIteration(scf::ForOp op);
328+
329+
/// Return true if we can prove that we never run more than one iteration of
330+
/// the ForOp.
331+
bool neverRunsSecondIteration(scf::ForOp op);
332+
325333
} // namespace mlir::iree_compiler
326334

327335
#endif // IREE_COMPILER_CODEGEN_UTILS_UTILS_H_

0 commit comments

Comments
 (0)