Commit 6001f9c
Fix distribution logic when number of parallel loops is greater than 3 (#18714)
Make the distribution logic for handling more than 3 loops more robust: avoid using tile sizes to figure out which loops are distributed, and instead pass only the loop ranges that are guaranteed to be distributed. This also requires that the ranges passed for these loops be the ranges of the tiled loops.

Fixes #18708

Signed-off-by: MaheshRavishankar <[email protected]>
1 parent 5b0680d commit 6001f9c
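
As a rough standalone sketch of the new selection logic (not IREE's actual API; plain int64_t bounds stand in for the OpFoldResult/affine expressions the pass uses, and all names here are illustrative), a loop is marked for distribution only if it is parallel, has a non-zero tile size, and is not statically known to have a single tile, and the range handed to the distribution callback already carries the tiled step:

// Minimal sketch, assuming constant bounds; the real pass works on
// OpFoldResult values and affine expressions instead of int64_t.
#include <cstdint>
#include <cstdio>
#include <vector>

struct LoopRange {
  int64_t offset, size, stride; // "size" is the upper bound, as in mlir::Range
};

// Collect only the ranges that are guaranteed to be distributed. The stride of
// each returned range is the tiled step (stride * tileSize), which is what the
// distribution callback now receives instead of the raw tile sizes.
std::vector<LoopRange>
getDistributedRanges(const std::vector<LoopRange> &domain,
                     const std::vector<int64_t> &tileSizes,
                     const std::vector<bool> &isParallel) {
  std::vector<LoopRange> distributed;
  for (size_t i = 0; i < domain.size(); ++i) {
    if (!isParallel[i] || tileSizes[i] == 0)
      continue; // non-parallel or untiled loops are never distributed
    int64_t step = domain[i].stride * tileSizes[i];
    int64_t numTiles = (domain[i].size - domain[i].offset + step - 1) / step;
    if (numTiles == 1)
      continue; // a single tile needs no workgroup dimension
    distributed.push_back({domain[i].offset, domain[i].size, step});
  }
  return distributed;
}

int main() {
  // Shapes/tiles as in the test added by this commit, with the dynamic sizes
  // replaced by concrete values: only loops whose tile count is not a static 1
  // survive, so 5 of the 8 parallel loops are distributed.
  std::vector<LoopRange> domain = {{0, 100, 1}, {0, 2, 1}, {0, 30, 1}, {0, 3, 1},
                                   {0, 20, 1},  {0, 4, 1}, {0, 10, 1}, {0, 5, 1}};
  std::vector<int64_t> tileSizes = {1, 2, 1, 4, 1, 4, 1, 1};
  std::vector<bool> isParallel(8, true);
  std::printf("%zu\n", getDistributedRanges(domain, tileSizes, isParallel).size()); // 5
}

The distribution callback then only ever sees ranges it is actually expected to distribute, so it no longer has to re-derive that set from the tile sizes.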

File tree

7 files changed (+94, -44 lines)


compiler/src/iree/compiler/Codegen/Common/TileAndDistributeToWorkgroupsPass.cpp

Lines changed: 1 addition & 1 deletion
@@ -372,7 +372,7 @@ void TileAndDistributeToWorkgroupsPass::runOnOperation() {
   auto linalgTilingOptions =
       linalg::LinalgTilingOptions()
           .setDistributionOptions(getIREELinalgLoopDistributionOptions(
-              tileSizes, distributionMethodValue, maxWorkgroupParallelDims))
+              distributionMethodValue, maxWorkgroupParallelDims))
           .setInterchange(llvm::map_to_vector(
               interchange,
               [](int64_t v) -> unsigned { return static_cast<unsigned>(v); }))

compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp

Lines changed: 40 additions & 31 deletions
@@ -291,10 +291,28 @@ tileDispatchUsingSCFFopOp(RewriterBase &rewriter, TilingInterface op,
 
   IREETilingResult tilingResult;
   tilingResult.tiledLoops.resize(numLoops, false);
-  for (auto [index, tileSize] : llvm::enumerate(tileSizes)) {
-    if (!isConstantIntValue(tileSize, 0)) {
-      tilingResult.tiledLoops.set(index);
+  AffineExpr s0, s1, s2, s3; // lb, ub, step, tileSize
+  bindSymbols(rewriter.getContext(), s0, s1, s2, s3);
+  AffineExpr numTilesExprs = (s1 - s0).ceilDiv(s2 * s3);
+  for (auto [index, iteratorType, range, tileSize] :
+       llvm::enumerate(op.getLoopIteratorTypes(), iterationDomain, tileSizes)) {
+    // If distribution is specified, only parallel loops are tiled.
+    if (options.distribution && iteratorType != utils::IteratorType::parallel) {
+      continue;
+    }
+    // If tile size is 0, it isn't tiled.
+    if (isConstantIntValue(tileSize, 0)) {
+      continue;
     }
+    // If number of tiles is statically known to be 1, the loop isn't tiled.
+    OpFoldResult numTiles = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, numTilesExprs,
+        {range.offset, range.size, range.stride, tileSize});
+    if (isConstantIntValue(numTiles, 1)) {
+      continue;
+    }
+
+    tilingResult.tiledLoops.set(index);
   }
 
   if (!tilingResult.tiledLoops.any()) {
@@ -328,40 +346,30 @@ tileDispatchUsingSCFFopOp(RewriterBase &rewriter, TilingInterface op,
       iterationDomain.size(), linalg::DistributionMethod::None);
   SmallVector<linalg::ProcInfo> procInfo;
   if (options.distribution) {
-    SmallVector<utils::IteratorType> iteratorTypes =
-        op.getLoopIteratorTypes();
-
-    // The parallel loops that are tiled are partitionable loops.
     SmallVector<Range> parallelLoopRanges;
-    SmallVector<unsigned> partitionedLoopIds;
-
-    AffineExpr s0, s1, s2, s3; // lb, ub, step, tileSize
-    bindSymbols(rewriter.getContext(), s0, s1, s2, s3);
-    AffineExpr numTilesExprs = (s1 - s0).ceilDiv(s2 * s3);
-    for (auto [index, iteratorType] : llvm::enumerate(iteratorTypes)) {
-      if (iteratorType != utils::IteratorType::parallel ||
-          isConstantIntValue(tileSizes[index], 0)) {
-        continue;
-      }
-
-      OpFoldResult numTiles = affine::makeComposedFoldedAffineApply(
-          rewriter, loc, numTilesExprs,
-          {iterationDomain[index].offset, iterationDomain[index].size,
-           iterationDomain[index].stride, tileSizes[index]});
-      if (isConstantIntValue(numTiles, 1)) {
-        continue;
+    for (auto loopIdx : llvm::seq<unsigned>(0, numLoops)) {
+      if (tilingResult.tiledLoops.test(loopIdx)) {
+        AffineExpr s0, s1;
+        bindSymbols(rewriter.getContext(), s0, s1);
+        OpFoldResult parallelLoopStep = affine::makeComposedFoldedAffineApply(
+            rewriter, loc, s0 * s1,
+            {iterationDomain[loopIdx].stride, tileSizes[loopIdx]});
+        Range r = {iterationDomain[loopIdx].offset,
+                   iterationDomain[loopIdx].size, parallelLoopStep};
+        parallelLoopRanges.emplace_back(std::move(r));
       }
-
-      parallelLoopRanges.push_back(iterationDomain[index]);
-      partitionedLoopIds.push_back(index);
     }
 
-    // Query the callback to get the {procId, nprocs} to use.
     procInfo =
         options.distribution->procInfo(rewriter, loc, parallelLoopRanges);
 
-    for (auto [index, loopIdx] : llvm::enumerate(partitionedLoopIds)) {
-      distributionMethods[loopIdx] = procInfo[index].distributionMethod;
+    unsigned partitionedLoopIdx = 0;
+    for (auto loopIdx : llvm::seq<unsigned>(0, numLoops)) {
+      if (!tilingResult.tiledLoops.test(loopIdx)) {
+        continue;
+      }
+      distributionMethods[loopIdx] =
+          procInfo[partitionedLoopIdx++].distributionMethod;
     }
   }
 
@@ -443,7 +451,8 @@ static SmallVector<Operation *> getAllFusableProducers(TilingInterface op) {
     worklist.pop_front();
     for (OpOperand &operand : currOp->getOpOperands()) {
       Operation *definingOp = operand.get().getDefiningOp();
-      auto tilingInterfaceProducer = dyn_cast<TilingInterface>(definingOp);
+      auto tilingInterfaceProducer =
+          dyn_cast_or_null<TilingInterface>(definingOp);
       if (!tilingInterfaceProducer || isa<tensor::PadOp>(definingOp) ||
           producers.count(tilingInterfaceProducer)) {
         continue;
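
The last hunk also hardens the producer lookup: an operand that is a block argument has no defining op, so getDefiningOp() returns null and the cast must tolerate that. A self-contained model of the difference (the templates below are stand-ins written for this example, not LLVM's actual casting utilities):

#include <cassert>
#include <iostream>

struct Operation { virtual ~Operation() = default; };
struct TilingProducer : Operation {};

// Stand-in for dyn_cast<T>: like the replaced call, it must not see a null op.
template <typename T> T *dynCast(Operation *op) {
  assert(op && "dynCast expects a non-null operation");
  return dynamic_cast<T *>(op);
}

// Stand-in for dyn_cast_or_null<T>: a null input simply yields null.
template <typename T> T *dynCastOrNull(Operation *op) {
  return op ? dynamic_cast<T *>(op) : nullptr;
}

int main() {
  Operation *definingOp = nullptr;                 // e.g. the operand is a block argument
  // dynCast<TilingProducer>(definingOp);          // would trip the assert
  if (!dynCastOrNull<TilingProducer>(definingOp))  // safely skips instead
    std::cout << "no tileable producer for this operand\n";
}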

compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
@@ -65,6 +65,7 @@ iree_lit_test_suite(
             "repeated_matcher_use.mlir",
             "replace_slow_min_max_ops.mlir",
             "test_partitionable_loops_interface.mlir",
+            "tile_and_distribute_to_workgroups_func_scope.mlir",
             "tile_and_distribute_to_workgroups.mlir",
             "tile_and_distribute_workgroups_using_forall.mlir",
             "transform_buffer_opt.mlir",

compiler/src/iree/compiler/Codegen/Common/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@ iree_lit_test_suite(
     "replace_slow_min_max_ops.mlir"
     "test_partitionable_loops_interface.mlir"
     "tile_and_distribute_to_workgroups.mlir"
+    "tile_and_distribute_to_workgroups_func_scope.mlir"
     "tile_and_distribute_workgroups_using_forall.mlir"
     "transform_buffer_opt.mlir"
     "transform_copy_operand.mlir"

compiler/src/iree/compiler/Codegen/Common/test/tile_and_distribute_to_workgroups_func_scope.mlir

Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-tile-and-distribute-to-workgroups{distribution-method=2}, canonicalize, cse))" --mlir-print-local-scope --split-input-file %s | FileCheck %s
+
+func.func @multiple_dim_distribute(%s0 : index, %s1 : index, %s2 : index, %s3 : index,
+    %arg0 : tensor<2x3x4x5xf32>) attributes {
+    translation_info = #iree_codegen.translation_info<LLVMGPUTileAndFuse workgroup_size = [32, 1, 1] subgroup_size = 32>} {
+  %c0 = arith.constant 0 : index
+  %result = hal.interface.binding.subspan layout(
+      <bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">,
+                   #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>)
+      binding(0) alignment(64) offset(%c0) flags(Indirect)
+      : !flow.dispatch.tensor<writeonly:tensor<?x2x?x3x?x4x?x5xf32>>{%s0, %s1, %s2, %s3}
+  %35 = tensor.empty(%s0, %s1, %s2, %s3) : tensor<?x2x?x3x?x4x?x5xf32>
+  %36 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d3, d5, d7)>,
+                       affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+      ins(%arg0 : tensor<2x3x4x5xf32>) outs(%35 : tensor<?x2x?x3x?x4x?x5xf32>)
+      attrs = {lowering_config = #iree_gpu.lowering_config<{thread = [1, 1, 1, 1, 1, 1, 1, 1], workgroup = [1, 2, 1, 4, 1, 4, 1, 1]}>} {
+  ^bb0(%in: f32, %out: f32):
+    linalg.yield %in : f32
+  } -> tensor<?x2x?x3x?x4x?x5xf32>
+  flow.dispatch.tensor.store %36, %result, offsets = [0, 0, 0, 0, 0, 0, 0, 0], sizes = [%s0, 2, %s1, 3, %s2, 4, %s3, 5], strides = [1, 1, 1, 1, 1, 1, 1, 1]
+      : tensor<?x2x?x3x?x4x?x5xf32> -> !flow.dispatch.tensor<writeonly:tensor<?x2x?x3x?x4x?x5xf32>>{%s0, %s1, %s2, %s3}
+  return
+}
+// CHECK-LABEL: func @multiple_dim_distribute(
+// CHECK-SAME:    %[[S0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:    %[[S1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:    %[[S2:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:    %[[S3:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:    %[[INPUT:.+]]: tensor<2x3x4x5xf32>)
+// CHECK-DAG:     %[[WG_ID_X:.+]] = hal.interface.workgroup.id[0]
+// CHECK-DAG:     %[[WG_ID_Y:.+]] = hal.interface.workgroup.id[1]
+// CHECK-DAG:     %[[WG_ID_Z:.+]] = hal.interface.workgroup.id[2]
+// CHECK-DAG:     %[[EMPTY:.+]] = tensor.empty() : tensor<1x2x1x3x1x4x1x1xf32>
+// CHECK-DAG:     %[[IN_SLICE:.+]] = tensor.extract_slice %[[INPUT]][0, 0, 0, %[[WG_ID_X]]] [2, 3, 4, 1]
+// CHECK:         %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:        ins(%[[IN_SLICE]] :
+// CHECK-SAME:        outs(%[[EMPTY]] :
+// CHECK-DAG:     %[[WG_ID_Z_0:.+]] = affine.apply affine_map<()[s0, s1, s2] -> ((s1 floordiv s2) floordiv s0)>()[%[[S1]], %[[WG_ID_Z]], %[[S2]]]
+// CHECK-DAG:     %[[WG_ID_Z_1:.+]] = affine.apply affine_map<()[s0, s1, s2] -> ((s1 floordiv s2) mod s0)>()[%[[S1]], %[[WG_ID_Z]], %[[S2]]]
+// CHECK-DAG:     %[[WG_ID_Z_2:.+]] = affine.apply affine_map<()[s0, s1] -> (s0 mod s1)>()[%[[WG_ID_Z]], %[[S2]]]
+// CHECK:         flow.dispatch.tensor.store %[[GENERIC]],
+// CHECK-SAME:        offsets = [%[[WG_ID_Z_0]], 0, %[[WG_ID_Z_1]], 0, %[[WG_ID_Z_2]], 0, %[[WG_ID_Y]], %[[WG_ID_X]]]
+// CHECK-SAME:        sizes = [1, 2, 1, 3, 1, 4, 1, 1]
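
In this test the lowering_config tiles the eight parallel loops with workgroup tile sizes [1, 2, 1, 4, 1, 4, 1, 1] over a ?x2x?x3x?x4x?x5 iteration space, so loops 1, 3 and 5 have a statically known single tile and are not distributed. The remaining five loops exceed the three hal.interface.workgroup.id dimensions: loops 7 and 6 take ids x and y, and loops 0, 2 and 4 share id z. A sketch of the id recovery the CHECK lines verify (assuming per-loop workgroup counts of s0, s1 and s2, since those loops use tile size 1):

#include <cstdint>
#include <tuple>

// Recover the three loop offsets folded into workgroup id z, mirroring the
// affine.apply maps checked above: ((z floordiv s2) floordiv s1),
// ((z floordiv s2) mod s1) and (z mod s2).
std::tuple<int64_t, int64_t, int64_t> delinearizeZ(int64_t z, int64_t s1,
                                                   int64_t s2) {
  int64_t offsetLoop0 = (z / s2) / s1;
  int64_t offsetLoop2 = (z / s2) % s1;
  int64_t offsetLoop4 = z % s2;
  return {offsetLoop0, offsetLoop2, offsetLoop4};
}

int main() {
  auto [o0, o2, o4] = delinearizeZ(/*z=*/23, /*s1=*/3, /*s2=*/4); // (1, 2, 3)
  return (o0 == 1 && o2 == 2 && o4 == 3) ? 0 : 1;
}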

compiler/src/iree/compiler/Codegen/Utils/Utils.cpp

Lines changed: 6 additions & 11 deletions
@@ -704,17 +704,11 @@ static Value buildHALWorkgroupInfoOp(OpBuilder &b, unsigned dim) {
 }
 
 linalg::LinalgLoopDistributionOptions getIREELinalgLoopDistributionOptions(
-    const SmallVector<int64_t> &tileSizes,
     linalg::DistributionMethod distributionMethod,
     int32_t maxWorkgroupParallelDims) {
-  return {[&tileSizes, distributionMethod,
+  return {[distributionMethod,
           maxWorkgroupParallelDims](OpBuilder &builder, Location loc,
                                     ArrayRef<Range> parallelLoopRanges) {
-    SmallVector<int64_t> nonZeroTileSizes;
-    for (int64_t size : tileSizes) {
-      if (size != 0)
-        nonZeroTileSizes.push_back(size);
-    }
     auto numParallelDims = parallelLoopRanges.size();
 
     SmallVector<linalg::ProcInfo, 3> procInfo(numParallelDims);
@@ -729,11 +723,12 @@ linalg::LinalgLoopDistributionOptions getIREELinalgLoopDistributionOptions(
       OpFoldResult size = parallelLoopRanges[numParallelDims - dim - 1].size;
       OpFoldResult offset =
           parallelLoopRanges[numParallelDims - dim - 1].offset;
-      AffineExpr d0, d1;
-      int64_t tileSize = nonZeroTileSizes[numParallelDims - dim - 1];
-      bindSymbols(builder.getContext(), d0, d1);
+      OpFoldResult step =
+          parallelLoopRanges[numParallelDims - dim - 1].stride;
+      AffineExpr d0, d1, d2;
+      bindSymbols(builder.getContext(), d0, d1, d2);
       OpFoldResult numTiles = affine::makeComposedFoldedAffineApply(
-          builder, loc, (d0 - d1).ceilDiv(tileSize), {size, offset});
+          builder, loc, (d1 - d0).ceilDiv(d2), {offset, size, step});
       OpFoldResult dimValue;
       if (dim == numParallelDims - 1)
         dimValue = splitDim.value();
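
With the tiled step already folded into each distributed range, the per-dimension workgroup count here reduces to ceildiv(size - offset, step). For example (illustrative numbers), a loop with offset 0, upper bound 100 and stride 1 tiled by 4 arrives as the range (0, 100, 4) and yields ceildiv(100 - 0, 4) = 25 workgroups along that dimension. Previously the same count was computed from a separately filtered nonZeroTileSizes list, which had to stay index-aligned with the ranges being passed in.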

compiler/src/iree/compiler/Codegen/Utils/Utils.h

Lines changed: 0 additions & 1 deletion
@@ -183,7 +183,6 @@ Operation *createLinalgCopyOp(OpBuilder &b, Location loc, Value from, Value to,
 /// Returns the option that distributes the ops using the flow workgroup
 /// ID/Count operations.
 linalg::LinalgLoopDistributionOptions getIREELinalgLoopDistributionOptions(
-    const SmallVector<int64_t> &tileSizes,
     linalg::DistributionMethod distributionMethod,
     int32_t maxWorkgroupParallelDims = kNumMaxParallelDims);
 
