
Commit cc3b28f

[Codegen][GPU] Improve loop fusion pattern verification (#18671)
The current loop fusion patterns don't verify that the consumer loop won't be predicated after resolution. This verification is needed because fusion introduces barrier semantics into the loop body, which produces invalid IR if the consumer loop later resolves to an `scf.if` (or anything else that could lead to thread divergence). This change also drops the requirement that the consumer loop contain a `tensor.extract_slice`, making the pattern more robust.
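
For intuition, a rough sketch (hypothetical, not taken from this patch) of the failure mode the new check prevents: if the consumer loop's trip count does not match the workgroup size, distributing it over thread IDs introduces a guard, and the barrier that fusion relies on ends up under divergent control flow.

// Hypothetical post-distribution IR; %c64 and the guard are illustrative only.
%tid = gpu.thread_id x
%active = arith.cmpi slt, %tid, %c64 : index
scf.if %active {
  // fused producer results staged through shared memory
  gpu.barrier  // only the guarded threads reach this barrier -> undefined behavior
  // consumer work
}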
1 parent 88cb0ab commit cc3b28f

3 files changed: +213 −34 lines changed

compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp

Lines changed: 98 additions & 28 deletions
@@ -4,16 +4,21 @@
 // See https://llvm.org/LICENSE.txt for license information.
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
 #include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h"
 #include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"
 #include "iree/compiler/Codegen/Transforms/Transforms.h"
+#include "iree/compiler/Codegen/Utils/GPUUtils.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Casting.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
@@ -36,36 +41,82 @@ struct FuseAndHoistParallelLoopsPass final
 };
 } // namespace
 
-struct FuseForalls final : OpRewritePattern<tensor::ExtractSliceOp> {
+static std::optional<int64_t> getStaticForallTripCount(scf::ForallOp forall) {
+  // TODO: Handle non-normalized loops.
+  if (!forall.isNormalized()) {
+    return std::nullopt;
+  }
+  int64_t tripCount = 1;
+  for (OpFoldResult ub : forall.getMixedUpperBound()) {
+    std::optional<int64_t> maybeConstantUb = getConstantIntValue(ub);
+    if (!maybeConstantUb) {
+      return std::nullopt;
+    }
+    tripCount *= *maybeConstantUb;
+  }
+  return tripCount;
+}
+
+static bool forallTripCountMatchesWorkgroupSize(scf::ForallOp forallOp,
+                                                int64_t flatWorkgroupSize) {
+  std::optional<int64_t> maybeTripCount = getStaticForallTripCount(forallOp);
+  if (!maybeTripCount) {
+    return false;
+  }
+
+  // For lane mapped foralls we need to verify that it is contained within
+  // a parent warp mapped op that combines to match the workgroup size.
+  if (forallOpHasMappingType<IREE::GPU::LaneIdAttr>(forallOp)) {
+    auto parentForall = forallOp->getParentOfType<scf::ForallOp>();
+    if (!parentForall ||
+        !forallOpHasMappingType<gpu::GPUWarpMappingAttr>(parentForall)) {
+      return false;
+    }
+
+    std::optional<int64_t> maybeParentTripCount =
+        getStaticForallTripCount(parentForall);
+    if (!maybeParentTripCount) {
+      return false;
+    }
+
+    return *maybeParentTripCount * *maybeTripCount == flatWorkgroupSize;
+  }
+
+  // All other loops must be mapped to threads to compare.
+  if (!forallOpHasMappingType<gpu::GPUThreadMappingAttr>(forallOp)) {
+    return false;
+  }
+
+  return *maybeTripCount == flatWorkgroupSize;
+}
+
+struct FuseForalls final : OpRewritePattern<scf::ForallOp> {
   using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+  FuseForalls(MLIRContext *ctx, int64_t flatWorkgroupSize, PatternBenefit b = 1)
+      : OpRewritePattern<scf::ForallOp>(ctx, b),
+        flatWorkgroupSize(flatWorkgroupSize) {}
+  LogicalResult matchAndRewrite(scf::ForallOp producerForall,
                                 PatternRewriter &rewriter) const override {
-    auto sliceParent = sliceOp->getParentOfType<scf::ForallOp>();
-    if (!sliceParent) {
-      return failure();
+    if (!producerForall->hasOneUse()) {
+      return rewriter.notifyMatchFailure(producerForall,
+                                         "multi-use producer forall");
     }
 
-    SmallVector<Operation *> consumerChain = {sliceOp};
-    Operation *currProducer = sliceOp.getSource().getDefiningOp();
-    while (currProducer && !llvm::isa<scf::ForallOp>(currProducer) &&
-           currProducer->hasOneUse()) {
-      consumerChain.insert(consumerChain.begin(), currProducer);
-      currProducer =
-          llvm::TypeSwitch<Operation *, Operation *>(currProducer)
-              .Case<tensor::ExpandShapeOp>([](tensor::ExpandShapeOp expand) {
-                return expand.getSrc().getDefiningOp();
-              })
-              .Case<tensor::CollapseShapeOp>(
-                  [](tensor::CollapseShapeOp collapse) {
-                    return collapse.getSrc().getDefiningOp();
-                  })
-              .Default([](Operation *) { return nullptr; });
-    }
-
-    auto producerForall =
-        llvm::dyn_cast_if_present<scf::ForallOp>(currProducer);
-    if (!producerForall) {
-      return failure();
+    SmallVector<Operation *> consumerChain;
+    Operation *currProducer = *producerForall->user_begin();
+    while (currProducer && currProducer->hasOneUse()) {
+      consumerChain.push_back(currProducer);
+      if (!isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp>(currProducer)) {
+        break;
+      }
+      currProducer = *currProducer->user_begin();
+    }
+
+    auto consumerForall = currProducer->getParentOfType<scf::ForallOp>();
+    if (!consumerForall || !forallTripCountMatchesWorkgroupSize(
+                               consumerForall, flatWorkgroupSize)) {
+      return rewriter.notifyMatchFailure(
+          producerForall,
+          "no consumer forall with trip count matching workgroup size");
     }
 
     // TODO: Allow extracting multiple uses within the same consumer loop. Still
@@ -75,9 +126,13 @@ struct FuseForalls final : OpRewritePattern<tensor::ExtractSliceOp> {
       return failure();
     }
 
-    return fuseForallIntoConsumer(rewriter, producerForall, sliceParent,
+    return fuseForallIntoConsumer(rewriter, producerForall, consumerForall,
                                   consumerChain);
   }
+
+private:
+  int64_t flatWorkgroupSize;
+  int64_t subgroupSize;
 };
 
 struct FuseTilableDestinationProducers final : OpRewritePattern<scf::ForallOp> {
@@ -198,12 +253,27 @@ void FuseAndHoistParallelLoopsPass::runOnOperation() {
 
   FunctionOpInterface funcOp = getOperation();
 
+  // Try to get the flat workgroup size if possible.
+  std::optional<int64_t> maybeFlatWorkgroupSize = std::nullopt;
+  if (std::optional<SmallVector<int64_t>> workgroupSize =
+          getWorkgroupSize(funcOp)) {
+    maybeFlatWorkgroupSize =
+        std::accumulate(workgroupSize->begin(), workgroupSize->end(), 1,
+                        std::multiplies<int64_t>());
+  }
+
   // First run the hoisting and fusion patterns.
   {
     RewritePatternSet patterns(context);
     // These two patterns are run to a fixed point, allowing fusion within
    // potentially nested loops, hoisting from said loops, and continued fusion.
-    patterns.add<FuseForalls>(context);
+    if (maybeFlatWorkgroupSize) {
+      // Forall fusion requires knowing the workgroup size to verify the fusion
+      // is valid. Without validation we risk putting barriers inside
+      // conditioned regions (e.g. scf.if/for).
+      patterns.add<FuseForalls>(context, *maybeFlatWorkgroupSize,
+                                /*benefit=*/1);
+    }
     patterns.add<FuseTilableForallConsumers>(context);
     populateForallLoopHoistingPattern(patterns);
     if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
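
To illustrate the new trip-count check in FuseForalls: a thread-mapped consumer forall is accepted only when its static trip count equals the flat workgroup size, while a lane-mapped forall is accepted only when its trip count multiplied by that of its parent warp-mapped forall equals it. A minimal sketch of the lane-mapped nesting, assuming a flat workgroup size of 256 (4 warps of 64 lanes); the mapping attribute spellings here are illustrative and may not match the exact printed forms:

// 4 (warps) * 64 (lanes) == 256 == flat workgroup size, so fusion into this nest is allowed.
scf.forall (%w) in (4) {
  scf.forall (%lane) in (64) {
    // per-lane work
  } {mapping = [#iree_gpu.lane_id<0>]}
} {mapping = [#gpu.warp<x>]}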

compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir

Lines changed: 112 additions & 3 deletions
@@ -1,11 +1,14 @@
 // RUN: iree-opt %s --pass-pipeline='builtin.module(func.func(iree-gpu-fuse-and-hoist-parallel-loops))' --split-input-file | FileCheck %s
 
+#translation_info = #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64>
+
 #map = affine_map<(d0) -> (d0 * 2)>
 #map1 = affine_map<(d0) -> (d0 * 4)>
 #map2 = affine_map<(d0)[s0] -> (d0 * 4 + s0)>
 #map3 = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
 #map4 = affine_map<(d0) -> (d0 * 16)>
-func.func @forall_fuse_then_hoist(%3: tensor<128x128xf16>, %4: tensor<128x128xf16>, %5: tensor<128x128xf32>) -> tensor<128x128xf32> {
+func.func @forall_fuse_then_hoist(%3: tensor<128x128xf16>, %4: tensor<128x128xf16>, %5: tensor<128x128xf32>) -> tensor<128x128xf32>
+    attributes {translation_info = #translation_info} {
   %c4 = arith.constant 4 : index
   %c128 = arith.constant 128 : index
   %c0 = arith.constant 0 : index
@@ -62,11 +65,14 @@ func.func @forall_fuse_then_hoist(%3: tensor<128x128xf16>, %4: tensor<128x128xf1
 
 // -----
 
+#translation_info = #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64>
+
 #map = affine_map<(d0) -> (d0 * 2)>
 #map1 = affine_map<(d0) -> (d0 * 4)>
 #map2 = affine_map<(d0)[s0] -> (d0 * 4 + s0)>
 #map3 = affine_map<(d0) -> (d0 * 16)>
-func.func @forall_fuse_then_hoist_mixed_mappings(%3: tensor<128x128xf16>, %5: tensor<128x128xf32>) -> tensor<128x128xf32> {
+func.func @forall_fuse_then_hoist_mixed_mappings(%3: tensor<128x128xf16>, %5: tensor<128x128xf32>) -> tensor<128x128xf32>
+    attributes {translation_info = #translation_info} {
   %c4 = arith.constant 4 : index
   %c128 = arith.constant 128 : index
   %c0 = arith.constant 0 : index
@@ -113,12 +119,15 @@ func.func @forall_fuse_then_hoist_mixed_mappings(%3: tensor<128x128xf16>, %5: te
 
 // -----
 
+#translation_info = #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64>
+
 #map = affine_map<(d0) -> (d0 * 2)>
 #map1 = affine_map<(d0) -> (d0 * 4)>
 #map2 = affine_map<(d0)[s0] -> (d0 * 4 + s0)>
 #map3 = affine_map<(d0)[s0] -> (d0 * 2 + s0)>
 #map4 = affine_map<(d0) -> (d0 * 16)>
-func.func @forall_fuse_then_hoist_with_fill(%3: tensor<128x128xf16>, %4: tensor<128x128xf16>) -> tensor<128x128xf32> {
+func.func @forall_fuse_then_hoist_with_fill(%3: tensor<128x128xf16>, %4: tensor<128x128xf16>) -> tensor<128x128xf32>
+    attributes {translation_info = #translation_info} {
   %c4 = arith.constant 4 : index
   %c128 = arith.constant 128 : index
   %c0 = arith.constant 0 : index
@@ -340,3 +349,103 @@ func.func @hoist_with_single_trip_loops(%2: tensor<128x128xf16>, %3: tensor<128x
 // CHECK: scf.forall.in_parallel
 // CHECK: scf.forall.in_parallel
 // CHECK: return
+
+// -----
+
+#map = affine_map<(d0) -> (d0 * 2)>
+#map1 = affine_map<(d0) -> (d0 * 16)>
+func.func @no_fuse_forall_without_workgroup_size(%arg0: tensor<128x128xf32>) -> tensor<128x128xf32> {
+  %0 = tensor.empty() : tensor<128x128xf32>
+  %2 = scf.forall (%arg5, %arg6) in (64, 1) shared_outs(%arg7 = %0) -> (tensor<128x128xf32>) {
+    %4 = affine.apply #map(%arg5)
+    %extracted_slice = tensor.extract_slice %arg0[%4, %arg6] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
+    %extracted_slice_0 = tensor.extract_slice %arg7[%4, %arg6] [2, 128] [1, 1] : tensor<128x128xf32> to tensor<2x128xf32>
+    %5 = linalg.copy ins(%extracted_slice : tensor<2x128xf32>) outs(%extracted_slice_0 : tensor<2x128xf32>) -> tensor<2x128xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %5 into %arg7[%4, %arg6] [2, 128] [1, 1] : tensor<2x128xf32> into tensor<128x128xf32>
+    }
+  } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+  %3 = scf.forall (%arg5, %arg6) in (8, 8) shared_outs(%arg7 = %0) -> (tensor<128x128xf32>) {
+    %6 = affine.apply #map1(%arg5)
+    %7 = affine.apply #map1(%arg6)
+    %extracted_slice_0 = tensor.extract_slice %2[%6, %7] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
+    %extracted_slice_1 = tensor.extract_slice %arg7[%6, %7] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
+    %8 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_0 : tensor<16x16xf32>, tensor<16x16xf32>) outs(%extracted_slice_1 : tensor<16x16xf32>) -> tensor<16x16xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %8 into %arg7[%6, %7] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
+    }
+  } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+  return %3 : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: func @no_fuse_forall_without_workgroup_size
+// CHECK-COUNT-2: scf.forall {{.*}} -> (tensor<128x128xf32>)
+
+// -----
+
+#translation_info = #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [128, 1, 1] subgroup_size = 64>
+#map = affine_map<(d0) -> (d0 * 2)>
+#map1 = affine_map<(d0) -> (d0 * 16)>
+func.func @no_fuse_forall_workgroup_size_mismatch(%arg0: tensor<128x128xf32>) -> tensor<128x128xf32>
+    attributes {translation_info = #translation_info} {
+  %0 = tensor.empty() : tensor<128x128xf32>
+  %2 = scf.forall (%arg5, %arg6) in (128, 1) shared_outs(%arg7 = %0) -> (tensor<128x128xf32>) {
+    %4 = affine.apply #map(%arg5)
+    %extracted_slice = tensor.extract_slice %arg0[%4, %arg6] [1, 128] [1, 1] : tensor<128x128xf32> to tensor<1x128xf32>
+    %extracted_slice_0 = tensor.extract_slice %arg7[%4, %arg6] [1, 128] [1, 1] : tensor<128x128xf32> to tensor<1x128xf32>
+    %5 = linalg.copy ins(%extracted_slice : tensor<1x128xf32>) outs(%extracted_slice_0 : tensor<1x128xf32>) -> tensor<1x128xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %5 into %arg7[%4, %arg6] [1, 128] [1, 1] : tensor<1x128xf32> into tensor<128x128xf32>
+    }
+  } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+  // We have 128 threads but only use 64 here, so loops cannot be fused.
+  %3 = scf.forall (%arg5, %arg6) in (8, 8) shared_outs(%arg7 = %0) -> (tensor<128x128xf32>) {
+    %6 = affine.apply #map1(%arg5)
+    %7 = affine.apply #map1(%arg6)
+    %extracted_slice_0 = tensor.extract_slice %2[%6, %7] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
+    %extracted_slice_1 = tensor.extract_slice %arg7[%6, %7] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
+    %8 = linalg.matmul ins(%extracted_slice_0, %extracted_slice_0 : tensor<16x16xf32>, tensor<16x16xf32>) outs(%extracted_slice_1 : tensor<16x16xf32>) -> tensor<16x16xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %8 into %arg7[%6, %7] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
+    }
+  } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+  return %3 : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: func @no_fuse_forall_workgroup_size_mismatch
+// CHECK-COUNT-2: scf.forall {{.*}} -> (tensor<128x128xf32>)
+
+// -----
+
+#translation_info = #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64>
+#map1 = affine_map<(d0) -> (d0 * 16)>
+func.func @fuse_direct_forall_use(%arg0: tensor<128x128xf32>, %arg1: tensor<16x16xf32>) -> tensor<128x128xf32>
+    attributes {translation_info = #translation_info} {
+  %0 = tensor.empty() : tensor<128x128xf32>
+  %1 = tensor.empty() : tensor<16x16xf32>
+  %2 = scf.forall (%arg5, %arg6) in (4, 4) shared_outs(%arg7 = %1) -> (tensor<16x16xf32>) {
+    %extracted_slice = tensor.extract_slice %arg1[%arg5, %arg6] [4, 4] [1, 1] : tensor<16x16xf32> to tensor<4x4xf32>
+    %extracted_slice_0 = tensor.extract_slice %arg7[%arg5, %arg6] [4, 4] [1, 1] : tensor<16x16xf32> to tensor<4x4xf32>
+    %5 = linalg.copy ins(%extracted_slice : tensor<4x4xf32>) outs(%extracted_slice_0 : tensor<4x4xf32>) -> tensor<4x4xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %5 into %arg7[%arg5, %arg6] [4, 4] [1, 1] : tensor<4x4xf32> into tensor<16x16xf32>
+    }
+  } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+  %3 = scf.forall (%arg5, %arg6) in (8, 8) shared_outs(%arg7 = %0) -> (tensor<128x128xf32>) {
+    %6 = affine.apply #map1(%arg5)
+    %7 = affine.apply #map1(%arg6)
+    %extracted_slice_0 = tensor.extract_slice %arg0[%6, %7] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
+    %extracted_slice_1 = tensor.extract_slice %arg7[%6, %7] [16, 16] [1, 1] : tensor<128x128xf32> to tensor<16x16xf32>
+    %8 = linalg.matmul ins(%2, %extracted_slice_0 : tensor<16x16xf32>, tensor<16x16xf32>) outs(%extracted_slice_1 : tensor<16x16xf32>) -> tensor<16x16xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %8 into %arg7[%6, %7] [16, 16] [1, 1] : tensor<16x16xf32> into tensor<128x128xf32>
+    }
+  } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
+  return %3 : tensor<128x128xf32>
+}
+
+// CHECK-LABEL: func @fuse_direct_forall_use
+// CHECK: %[[FUSED_LOOP:.+]] = scf.forall
+// CHECK: %[[BARRIER:.+]] = iree_gpu.barrier_region
+// CHECK: linalg.matmul ins(%[[BARRIER]]
+// CHECK: return %[[FUSED_LOOP]]

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir

Lines changed: 3 additions & 3 deletions
@@ -400,7 +400,7 @@ hal.executable public @main {
 #config = #iree_gpu.lowering_config<{
   workgroup = [64, 64, 0],
   reduction = [0, 0, 2],
-  subgroup = [2, 2],
+  subgroup = [1, 1],
   mma_kind = #iree_gpu.mma_layout<MFMA_I32_32x32x16_I8>,
   promote_operands = [0, 1]
 }>
@@ -440,8 +440,8 @@ hal.executable public @main {
 // CHECK-LABEL: func @matmul_transpose_b_mfma_32x32x16_i8
 // CHECK-DAG: memref.alloc() : memref<64x40xi8, #gpu.address_space<workgroup>>
 // CHECK-DAG: memref.alloc() : memref<64x40xi8, #gpu.address_space<workgroup>>
-// CHECK: scf.for %{{.*}} = %c0 to %c80 step %c2 {{.*}} -> (vector<2x2x4x4x1xi32>)
-// CHECK-COUNT-8: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32
+// CHECK: scf.for %{{.*}} = %c0 to %c80 step %c2 {{.*}} -> (vector<1x1x4x4x1xi32>)
+// CHECK-COUNT-2: amdgpu.mfma {{.*}}blocks = 1 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32
 // CHECK: scf.yield
 
 // -----
