Skip to content

Commit f6a9b6b

Browse files
authored
[Codegen][GPU] Enable destination fusion for unit trip loops (#18674)
When doing loop fusion + hoisting, new fusion opportunities can arise after earlier canonicalization steps have kicked in. This causes problems for unit-trip loops, where the slice on the destination gets folded away. This adds a pattern that moves any DPS producer of the destination into the body of the forall loop in such cases, because unit-trip loops are equivalent to single-threaded regions.
1 parent ad68964 commit f6a9b6b

File tree

2 files changed

+100
-1
lines changed

2 files changed

+100
-1
lines changed

compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/FuseAndHoistParallelLoops.cpp

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,6 @@ struct FuseForalls final : OpRewritePattern<scf::ForallOp> {
132132

133133
private:
134134
int64_t flatWorkgroupSize;
135-
int64_t subgroupSize;
136135
};
137136

138137
struct FuseTilableDestinationProducers final : OpRewritePattern<scf::ForallOp> {
@@ -174,6 +173,68 @@ struct FuseTilableDestinationProducers final : OpRewritePattern<scf::ForallOp> {
174173
}
175174
};
176175

176+
// Fuses the destination-style (DPS) producer of a unit-trip-count
// scf.forall's shared_outs operand into the loop body. A forall with a
// static trip count of 1 is equivalent to a single-threaded region, so
// moving the producer inside is safe and restores fusion opportunities
// lost when canonicalization folds away the unit slice on the destination
// (see commit message).
struct FuseUnitLoopDestination final : OpRewritePattern<scf::ForallOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(scf::ForallOp forallOp,
                                PatternRewriter &rewriter) const override {
    // Only fire on loops whose total trip count is statically known to be 1.
    // NOTE(review): getStaticForallTripCount is defined elsewhere in this
    // file — presumably the product of the static loop bounds; confirm.
    std::optional<int64_t> maybeTripCount = getStaticForallTripCount(forallOp);
    if (!maybeTripCount || *maybeTripCount != 1) {
      return rewriter.notifyMatchFailure(forallOp,
                                         "not a unit trip count loop");
    }
    // Find the first shared_outs operand whose defining op implements
    // DestinationStyleOpInterface; remember both the init value and the
    // corresponding region iter arg.
    DestinationStyleOpInterface dpsProducer;
    BlockArgument bodyArg;
    Value dpsResult;
    for (auto iterArg : forallOp.getRegionIterArgs()) {
      dpsResult = forallOp.getTiedLoopInit(iterArg)->get();
      bodyArg = iterArg;
      dpsProducer = dpsResult.getDefiningOp<DestinationStyleOpInterface>();
      if (dpsProducer) {
        break;
      }
    }
    // Require a single use so moving the producer into the loop cannot
    // change the value observed by any other consumer.
    if (!dpsProducer || !dpsProducer->hasOneUse()) {
      return rewriter.notifyMatchFailure(forallOp,
                                         "no single use DPS producer");
    }

    // The iter arg must feed exactly one tensor.parallel_insert_slice in the
    // terminator; that use is kept pointing at the iter arg below.
    Operation *parallelInsert = nullptr;
    for (auto user : bodyArg.getUsers()) {
      if (isa<tensor::ParallelInsertSliceOp>(user)) {
        // This should be illegal but check anyway.
        if (parallelInsert) {
          return rewriter.notifyMatchFailure(forallOp, "multiple insert users");
        }
        parallelInsert = user;
      }
    }
    if (!parallelInsert) {
      return rewriter.notifyMatchFailure(
          forallOp, "destination not used by a parallel insert");
    }

    // In-place update: the forall op is modified (operand + body) rather
    // than replaced, so bracket the mutations for the rewriter.
    rewriter.startOpModification(forallOp);
    // Move the producer into the body of the forall loop.
    rewriter.moveOpBefore(dpsProducer, forallOp.getBody(),
                          forallOp.getBody()->begin());

    // Replace all uses of the region iter arg with the moved dps op.
    // The parallel insert is excluded: it must keep writing into the
    // loop's destination (the iter arg), not the producer's result.
    rewriter.replaceAllUsesExcept(bodyArg, dpsResult, parallelInsert);

    // Set the init operand of the forall op to the init operand of the
    // producer.
    int64_t dpsInitIndex = cast<OpResult>(dpsResult).getResultNumber();
    forallOp->setOperand(forallOp.getTiedOpOperand(bodyArg)->getOperandNumber(),
                         dpsProducer.getDpsInitOperand(dpsInitIndex)->get());

    // Finally replace the init operand of the moved producer with the region
    // iter arg.
    dpsProducer.setDpsInitOperand(dpsInitIndex, bodyArg);
    rewriter.finalizeOpModification(forallOp);
    return success();
  }
};
237+
177238
struct FuseTilableSliceProducers final
178239
: OpRewritePattern<tensor::ExtractSliceOp> {
179240
using OpRewritePattern::OpRewritePattern;
@@ -290,6 +351,7 @@ void FuseAndHoistParallelLoopsPass::runOnOperation() {
290351
{
291352
RewritePatternSet patterns(context);
292353
patterns.add<FuseTilableDestinationProducers>(context);
354+
patterns.add<FuseUnitLoopDestination>(context);
293355
patterns.add<FuseTilableForallConsumers>(context);
294356
tensor::populateFoldTensorEmptyPatterns(patterns);
295357
scf::ForallOp::getCanonicalizationPatterns(patterns, context);

compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/fuse_and_hoist_forall.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,3 +449,40 @@ func.func @fuse_direct_forall_use(%arg0: tensor<128x128xf32>, %arg1: tensor<16x1
449449
// CHECK: %[[BARRIER:.+]] = iree_gpu.barrier_region
450450
// CHECK: linalg.matmul ins(%[[BARRIER]]
451451
// CHECK: return %[[FUSED_LOOP]]
452+
453+
// -----

// Regression test for FuseUnitLoopDestination: the scf.forall has a static
// 1x1 (unit) trip count, so the linalg.fill producing its shared_outs
// destination should be fused into the forall body. Per the CHECK lines,
// the forall also ends up hoisted outside the scf.for, with the fill's
// tensor.empty as its destination and the scf.for iterating on the fill
// result inside.
#translation_info = #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64>

func.func @forall_hoist_unit_loop_with_fill(%3: tensor<1x128xf16>, %4: tensor<128x1xf16>) -> tensor<1x1xf32>
    attributes {translation_info = #translation_info} {
  %c4 = arith.constant 4 : index
  %c128 = arith.constant 128 : index
  %c0 = arith.constant 0 : index
  %empty = tensor.empty() : tensor<1x1xf32>
  %cst = arith.constant 0.0 : f32
  // Zero-initialized accumulator: the DPS producer the pattern should fuse.
  %5 = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x1xf32>) -> tensor<1x1xf32>
  // K-loop over 128 in steps of 4, accumulating a 1x1 matmul result.
  %8 = scf.for %arg0 = %c0 to %c128 step %c4 iter_args(%arg1 = %5) -> (tensor<1x1xf32>) {
    // Unit trip count forall: (1, 1).
    %11 = scf.forall (%arg2, %arg3) in (1, 1) shared_outs(%arg4 = %arg1) -> (tensor<1x1xf32>) {
      %12 = affine.apply affine_map<(d0) -> (d0 * 4)>(%arg0)
      %extracted_slice = tensor.extract_slice %3[0, %12] [1, 4] [1, 1] : tensor<1x128xf16> to tensor<1x4xf16>
      %extracted_slice_0 = tensor.extract_slice %4[%12, 0] [4, 1] [1, 1] : tensor<128x1xf16> to tensor<4x1xf16>
      %14 = linalg.matmul ins(%extracted_slice, %extracted_slice_0 : tensor<1x4xf16>, tensor<4x1xf16>) outs(%arg4 : tensor<1x1xf32>) -> tensor<1x1xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %14 into %arg4[%arg2, %arg3] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<1x1xf32>
      }
    } {mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
    scf.yield %11 : tensor<1x1xf32>
  }
  return %8 : tensor<1x1xf32>
}

// CHECK-LABEL: func @forall_hoist_unit_loop_with_fill
//       CHECK:   %[[EMPTY:.+]] = tensor.empty() : tensor<1x1xf32>
//       CHECK:   %[[OUTER_PARALLEL:.+]] = scf.forall {{.*}}shared_outs(%[[ITER:.+]] = %[[EMPTY]])
//       CHECK:     %[[FILL:.+]] = linalg.fill {{.*}} outs(%[[ITER]]
//       CHECK:     %[[LOOP:.+]] = scf.for {{.*}} iter_args(%{{.*}} = %[[FILL]])
//       CHECK:       scf.yield {{.*}} : tensor<1x1xf32>
//       CHECK:     scf.forall.in_parallel
//  CHECK-NEXT:       tensor.parallel_insert_slice %[[LOOP]] into %[[ITER]]
//       CHECK:   return %[[OUTER_PARALLEL]]

0 commit comments

Comments
 (0)