Skip to content

Commit 645b446

Browse files
Use scf::tileAndFuseConsumer in GPUFuseAndHoistParallelLoops (#22617)
This change simplifies the logic in `GPUFuseAndHoistParallelLoops` by using the `scf::tileAndFuseConsumer` method, which directly takes the consumer to fuse as an operand and finds the slices to fuse along. The previous implementation resulted in a subtle bug, where the operation that was expected to be fused and the operation actually fused were different. The new method disallows this by construction. There is still an issue of the pattern rewriter going into an infinite loop (or hitting the iteration limit). That is a problem because the tiling generates operations before failing. The tiling method is not intended to be called within pattern rewriters, but some outstanding changes to `TilingInterface` can also address this issue. Leaving this as an error for now. Fixes #22576 Signed-off-by: MaheshRavishankar <[email protected]>
1 parent 1f322ce commit 645b446

File tree

3 files changed

+41
-34
lines changed

3 files changed

+41
-34
lines changed

compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@ void GPUFuseAndHoistParallelLoopsPass::runOnOperation() {
358358
patterns.add<FuseNestedLaneAndWarpForalls>(context);
359359
populateForallLoopHoistingPattern(patterns);
360360
if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
361+
funcOp->emitOpError("failed to apply fusion + hoisting patterns (set 1)");
361362
return signalPassFailure();
362363
}
363364
}
@@ -379,6 +380,7 @@ void GPUFuseAndHoistParallelLoopsPass::runOnOperation() {
379380
tensor::populateFoldTensorEmptyPatterns(patterns);
380381
scf::ForallOp::getCanonicalizationPatterns(patterns, context);
381382
if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
383+
funcOp->emitOpError("failed to apply fusion + hoisting patterns (set 2)");
382384
return signalPassFailure();
383385
}
384386
}
@@ -393,6 +395,7 @@ void GPUFuseAndHoistParallelLoopsPass::runOnOperation() {
393395
tensor::populateFoldTensorEmptyPatterns(patterns);
394396
scf::ForallOp::getCanonicalizationPatterns(patterns, context);
395397
if (failed(applyPatternsGreedily(funcOp, std::move(patterns)))) {
398+
funcOp->emitOpError("failed to apply fusion + hoisting patterns (set 3)");
396399
return signalPassFailure();
397400
}
398401
}

compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fuse_and_hoist_forall.mlir

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: iree-opt %s --pass-pipeline='builtin.module(func.func(iree-codegen-gpu-fuse-and-hoist-parallel-loops))' --split-input-file | FileCheck %s
1+
// RUN: iree-opt %s --pass-pipeline='builtin.module(func.func(iree-codegen-gpu-fuse-and-hoist-parallel-loops))' --split-input-file --verify-diagnostics | FileCheck %s
22

33
#translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64>
44

@@ -486,7 +486,7 @@ func.func @forall_hoist_unit_loop_with_fill(%3: tensor<1x128xf16>, %4: tensor<12
486486

487487
// -----
488488

489-
func.func @no_fuse_multi_use(%2: tensor<128x128xf16>, %3: tensor<128x128xf16>) -> tensor<128x128xf16> {
489+
func.func @fuse_multi_use(%2: tensor<128x128xf16>, %3: tensor<128x128xf16>) -> tensor<128x128xf16> {
490490
%c4 = arith.constant 4 : index
491491
%c128 = arith.constant 128 : index
492492
%c0 = arith.constant 0 : index
@@ -496,10 +496,9 @@ func.func @no_fuse_multi_use(%2: tensor<128x128xf16>, %3: tensor<128x128xf16>) -
496496
%extracted_slice_2 = tensor.extract_slice %arg7[%arg5, %arg6] [2, 2] [1, 1] : tensor<128x128xf16> to tensor<2x2xf16>
497497
%extracted_slice_3 = tensor.extract_slice %arg8[%arg6, %arg5] [2, 2] [1, 1] : tensor<128x128xf16> to tensor<2x2xf16>
498498
%16 = linalg.copy ins(%extracted_slice_1 : tensor<2x2xf16>) outs(%extracted_slice_2 : tensor<2x2xf16>) -> tensor<2x2xf16>
499-
%17 = linalg.transpose ins(%extracted_slice_1 : tensor<2x2xf16>) outs(%extracted_slice_3 : tensor<2x2xf16>) permutation = [1, 0]
500499
scf.forall.in_parallel {
501500
tensor.parallel_insert_slice %16 into %arg7[%arg5, %arg6] [2, 2] [1, 1] : tensor<2x2xf16> into tensor<128x128xf16>
502-
tensor.parallel_insert_slice %17 into %arg8[%arg6, %arg5] [2, 2] [1, 1] : tensor<2x2xf16> into tensor<128x128xf16>
501+
tensor.parallel_insert_slice %16 into %arg8[%arg5, %arg6] [2, 2] [1, 1] : tensor<2x2xf16> into tensor<128x128xf16>
503502
}
504503
} {mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
505504
%add = linalg.add
@@ -508,14 +507,42 @@ func.func @no_fuse_multi_use(%2: tensor<128x128xf16>, %3: tensor<128x128xf16>) -
508507
return %add : tensor<128x128xf16>
509508
}
510509

511-
// CHECK-LABEL: func @no_fuse_multi_use
510+
// CHECK-LABEL: func @fuse_multi_use(
512511
// CHECK: scf.forall
513512
// CHECK: linalg.copy
514-
// CHECK: linalg.transpose
513+
// CHECK: linalg.add
515514
// CHECK: scf.forall.in_parallel
516-
// CHECK: linalg.add
517515
// CHECK: return
518516

517+
518+
// -----
519+
520+
// For now this test errors out because the pattern rewriter goes into an infinite loop. This happens because the consumer
521+
// fusion fails, but modifies the IR before failing. This will be fixed shortly upstream.
522+
523+
// expected-error @+1 {{failed to apply fusion + hoisting patterns (set 1)}}
524+
func.func @no_fuse_incompatible_multi_use(%2: tensor<128x128xf16>, %3: tensor<128x128xf16>) -> tensor<128x128xf16> {
525+
%c4 = arith.constant 4 : index
526+
%c128 = arith.constant 128 : index
527+
%c0 = arith.constant 0 : index
528+
%empty = tensor.empty() : tensor<128x128xf16>
529+
%10:2 = scf.forall (%arg5, %arg6) in (32, 32) shared_outs(%arg7 = %empty, %arg8 = %empty) -> (tensor<128x128xf16>, tensor<128x128xf16>) {
530+
%extracted_slice_1 = tensor.extract_slice %2[%arg5, %arg6] [2, 2] [1, 1] : tensor<128x128xf16> to tensor<2x2xf16>
531+
%extracted_slice_2 = tensor.extract_slice %arg7[%arg5, %arg6] [2, 2] [1, 1] : tensor<128x128xf16> to tensor<2x2xf16>
532+
%extracted_slice_3 = tensor.extract_slice %arg8[%arg6, %arg5] [2, 2] [1, 1] : tensor<128x128xf16> to tensor<2x2xf16>
533+
%16 = linalg.copy ins(%extracted_slice_1 : tensor<2x2xf16>) outs(%extracted_slice_2 : tensor<2x2xf16>) -> tensor<2x2xf16>
534+
%17 = linalg.transpose ins(%extracted_slice_1 : tensor<2x2xf16>) outs(%extracted_slice_3 : tensor<2x2xf16>) permutation = [1, 0]
535+
scf.forall.in_parallel {
536+
tensor.parallel_insert_slice %16 into %arg7[%arg5, %arg6] [2, 2] [1, 1] : tensor<2x2xf16> into tensor<128x128xf16>
537+
tensor.parallel_insert_slice %17 into %arg8[%arg6, %arg5] [2, 2] [1, 1] : tensor<2x2xf16> into tensor<128x128xf16>
538+
}
539+
} {mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
540+
%add = linalg.add
541+
ins(%10#0, %10#1 : tensor<128x128xf16>, tensor<128x128xf16>)
542+
outs(%empty: tensor<128x128xf16>) -> tensor<128x128xf16>
543+
return %add : tensor<128x128xf16>
544+
}
545+
519546
// -----
520547

521548
#map = affine_map<(d0) -> (d0 * 64)>

compiler/src/iree/compiler/Codegen/Common/Transforms.cpp

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ struct FuseTilableForallConsumers final
4848

4949
tensor::ParallelInsertSliceOp producerSlice;
5050
LoopLikeOpInterface sliceOwner;
51-
Value fusionOperand;
5251
for (auto operand : dpsOp.getDpsInputs()) {
5352
auto forallProducer = operand.getDefiningOp<scf::ForallOp>();
5453
if (!forallProducer) {
@@ -57,36 +56,15 @@ struct FuseTilableForallConsumers final
5756
if (forallProducer->getBlock() != tilableOp->getBlock()) {
5857
continue;
5958
}
60-
Value iterArg = forallProducer.getTiedBlockArgument(
61-
forallProducer.getTiedOpOperand(cast<OpResult>(operand)));
62-
63-
for (auto user : iterArg.getUsers()) {
64-
auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(user);
65-
if (sliceOp && sliceOp.getDest() == iterArg) {
66-
producerSlice = sliceOp;
67-
sliceOwner = forallProducer;
68-
fusionOperand = operand;
69-
break;
70-
}
71-
}
72-
if (producerSlice) {
73-
break;
74-
}
59+
sliceOwner = forallProducer;
60+
break;
7561
}
7662

77-
if (!producerSlice) {
63+
if (!sliceOwner) {
7864
return rewriter.notifyMatchFailure(tilableOp,
7965
"no scf.forall producer to fuse into");
8066
}
8167

82-
for (auto operand : tilableOp->getOperands()) {
83-
if (operand != fusionOperand && operand.getDefiningOp() == sliceOwner) {
84-
return rewriter.notifyMatchFailure(tilableOp,
85-
"unimplemented: Cannot fuse op with "
86-
"multiple uses of producer loop");
87-
}
88-
}
89-
9068
// The `tileAndFuseConsumerOfSlices` transform will fail if there are any
9169
// users of the loop that do not dominate the `tilableOp`, so we move the
9270
// `tilableOp` and any producers needed for dominance right after the loop.
@@ -116,8 +94,7 @@ struct FuseTilableForallConsumers final
11694
}
11795

11896
FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
119-
scf::tileAndFuseConsumerOfSlices(rewriter, producerSlice.getOperation(),
120-
{sliceOwner});
97+
scf::tileAndFuseConsumer(rewriter, tilableOp, {sliceOwner});
12198
if (failed(fuseConsumerResults)) {
12299
return failure();
123100
}

0 commit comments

Comments
 (0)