
Commit 21234ed

[Codegen][GPU] Infer workgroup size multiples from producers and consumers (#19804)
This PR adds new logic in ConfigUtils.cpp to analyze a dispatch and determine required multiples of workgroup tile sizes for the root operation. This affects dispatches that contain either tensor.pack or tensor.unpack ops, since the pack and unpack ops require the workgroup tile sizes to be a multiple of their inner_tiles in order for them to be fused into the workgroup scf.forall loop. The following example of a gpu set_encoding dispatch illustrates the new constraint imposed by this PR:

```mlir
%in = flow.dispatch.tensor.load ... -> tensor<256x64xi8>
%pack = tensor.pack %in ... inner_tiles = [128, 64] ... tensor<256x64xi8> -> tensor<2x1x128x64xi8>
%expanded = tensor.expand_shape %pack [[0], [1], [2, 3, 4], [5, 6, 7]] : tensor<2x1x128x64xi8> into tensor<2x1x4x8x4x2x4x8xi8>
// linalg.transpose is the root op. The workgroup tile sizes must contain an
// even multiple of the tensor.pack inner_tiles.
%transposed = linalg.transpose ins(%expanded : tensor<2x1x4x8x4x2x4x8xi8>)
                               outs(%empty : tensor<2x1x8x4x4x4x2x8xi8>)
                               permutation = [0, 1, 3, 6, 2, 4, 5, 7]
flow.dispatch.tensor.store %transposed
```

Since the linalg.transpose is the root op, it needs to be aware of its producer chain when selecting tile sizes. With this PR, the lowering config selection logic will walk producers until it hits an unsupported operation or a block argument, and find the LCM of any pack or unpack tiles along the dimensions of their inner_tiles. In the above example, this would look like the following:

1. Walk the producer chain up to the producer of the `tensor.pack`, and stop at the `flow.dispatch.tensor.load`. The initial workgroup tile size multiples will be `[1, 1]` (i.e., no constraint for unsupported ops).
2. The workgroup tile sizes will be propagated through the `tensor.pack`, which updates the workgroup tile size multiples to `[1, 1, 128, 64]`.
3. Then, it will propagate through the `tensor.expand_shape`, which will expand the workgroup size multiples if possible. In this case, they are expanded to `[1, 1, 4, 8, 4, 2, 4, 8]`.
4. Now walk the consumer chain to find the multiples for the workgroup tile slice of the root op result. In this case, the propagation simply stops at the `flow.dispatch.tensor.store`, and the multiples are `[1, 1, 1, ...]`.
5. Now the root op has the required workgroup tile size multiples for the operand and result slices. The multiples for the iteration space of the op are computed from the indexing maps of the operation by taking, for each iteration dimension, the LCM of that dimension's multiples from all operands and results. In this case the final workgroup tile size multiples become `[1, 1, 8, 4, 4, 4, 2, 8]`.

---------

Signed-off-by: Max Dawkins <[email protected]>
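As a companion to step 5 above, the following is a minimal standalone C++ sketch of how per-operand slice multiples can be merged into iteration-space multiples through indexing maps. This is an illustration only, not the code added in TileInferenceUtils.cpp; the function name and the explicit map representation are hypothetical. It reproduces the `[1, 1, 8, 4, 4, 4, 2, 8]` result from the example.

```cpp
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Each operand's indexing map is modeled as a projected permutation:
// map[d] is the iteration-space dimension accessed by operand dimension d.
using IndexingMap = std::vector<int64_t>;

// For every iteration dimension, take the LCM of the multiples required by
// each operand/result slice that indexes it (step 5 of the description).
std::vector<int64_t> mergeSliceMultiples(
    int64_t numLoops, const std::vector<IndexingMap> &maps,
    const std::vector<std::vector<int64_t>> &sliceMultiples) {
  std::vector<int64_t> loopMultiples(numLoops, 1);
  for (size_t i = 0; i < maps.size(); ++i) {
    for (size_t d = 0; d < maps[i].size(); ++d) {
      int64_t loop = maps[i][d];
      loopMultiples[loop] = std::lcm(loopMultiples[loop], sliceMultiples[i][d]);
    }
  }
  return loopMultiples;
}

int main() {
  // linalg.transpose example: iteration dimensions ordered like the result.
  // The result uses the identity map; the input's map is the inverse of the
  // permutation [0, 1, 3, 6, 2, 4, 5, 7].
  IndexingMap inputMap = {0, 1, 4, 2, 5, 6, 3, 7};
  IndexingMap resultMap = {0, 1, 2, 3, 4, 5, 6, 7};
  std::vector<int64_t> inputMultiples = {1, 1, 4, 8, 4, 2, 4, 8};  // step 3
  std::vector<int64_t> resultMultiples = {1, 1, 1, 1, 1, 1, 1, 1}; // step 4

  std::vector<int64_t> multiples = mergeSliceMultiples(
      8, {inputMap, resultMap}, {inputMultiples, resultMultiples});
  for (int64_t m : multiples)
    std::cout << m << ' ';
  std::cout << '\n'; // 1 1 8 4 4 4 2 8
}
```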
1 parent e4c683f commit 21234ed

File tree

9 files changed: +669 -51 lines changed


compiler/src/iree/compiler/Codegen/Common/BUILD.bazel

Lines changed: 2 additions & 0 deletions
@@ -154,6 +154,7 @@ iree_compiler_cc_library(
         "TileAndDistributeToWorkgroupsPass.cpp",
         "TileDispatchUsingForall.cpp",
         "TileDispatchUsingInterface.cpp",
+        "TileInferenceUtils.cpp",
         "TileLargeTensors.cpp",
         "TileSizeSelection.cpp",
         "Transforms.cpp",
@@ -171,6 +172,7 @@ iree_compiler_cc_library(
         "PassUtils.h",
         "Passes.h",
         "TensorDynamicDimAnalysis.h",
+        "TileInferenceUtils.h",
         "TileSizeSelection.h",
         "Transforms.h",
         "UserConfig.h",

compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -74,6 +74,7 @@ iree_cc_library(
     "PassUtils.h"
     "Passes.h"
     "TensorDynamicDimAnalysis.h"
+    "TileInferenceUtils.h"
     "TileSizeSelection.h"
     "Transforms.h"
     "UserConfig.h"
@@ -146,6 +147,7 @@ iree_cc_library(
     "TileAndDistributeToWorkgroupsPass.cpp"
     "TileDispatchUsingForall.cpp"
     "TileDispatchUsingInterface.cpp"
+    "TileInferenceUtils.cpp"
     "TileLargeTensors.cpp"
     "TileSizeSelection.cpp"
     "Transforms.cpp"

compiler/src/iree/compiler/Codegen/Common/TileInferenceUtils.cpp

Lines changed: 397 additions & 0 deletions
Large diffs are not rendered by default.
compiler/src/iree/compiler/Codegen/Common/TileInferenceUtils.h

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+// Copyright 2025 The IREE Authors
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef IREE_COMPILER_CODEGEN_LLVMCPU_TILEINFERENCEUTILS_H_
+#define IREE_COMPILER_CODEGEN_LLVMCPU_TILEINFERENCEUTILS_H_
+
+#include "mlir/Interfaces/TilingInterface.h"
+
+namespace mlir::iree_compiler {
+
+/// Walks the producer and consumer chains of the `tilingOp`, and looks for ops
+/// that require specific workgroup tile size multiples. Right now, the only ops
+/// that require a specific multiple are pack and unpack, since the workgroup
+/// tile sizes need to be multiples of the inner_tiles. After walking the IR and
+/// finding multiples for the slices of the `tilingOp` operands and results, the
+/// function computes and returns the multiples of the `tilingOp` iteration
+/// space. The function may fail to find a valid set of workgroup size
+/// multiples, in which case the function will fallback to returning a list of
+/// all 1, meaning no constraints on the workgroup tile sizes.
+SmallVector<int64_t> getWorkgroupSizeMultiples(TilingInterface tilingOp);
+
+} // namespace mlir::iree_compiler
+
+#endif // IREE_COMPILER_CODEGEN_LLVMCPU_TILEINFERENCEUTILS_H_
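To make the propagation this header describes more concrete, below is a small standalone C++ sketch of how multiples might flow through a `tensor.pack` and a `tensor.expand_shape` (steps 2 and 3 of the commit message). It is a simplification under stated assumptions, not the logic in TileInferenceUtils.cpp: the helper names are hypothetical, and the pack helper assumes the source multiples on the packed dims are 1, as in the example.

```cpp
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// tensor.pack: the packed inner dimensions must be tiled as whole inner_tiles,
// so the result multiples are the source multiples on the outer dims followed
// by the inner_tiles themselves (assuming the source multiples on the packed
// dims are 1, as in the example).
std::vector<int64_t> propagateThroughPack(const std::vector<int64_t> &srcMultiples,
                                          const std::vector<int64_t> &innerTiles) {
  std::vector<int64_t> result(srcMultiples);
  result.insert(result.end(), innerTiles.begin(), innerTiles.end());
  return result;
}

// tensor.expand_shape: factor each source multiple across its expanded dims,
// starting from the innermost expanded dim. Returns an empty vector when a
// multiple does not factor cleanly (the real logic would then fall back to
// all-1 multiples, i.e. no constraint).
std::vector<int64_t> propagateThroughExpandShape(
    const std::vector<int64_t> &srcMultiples,
    const std::vector<std::vector<int64_t>> &expandedSizes) {
  std::vector<int64_t> result;
  for (size_t group = 0; group < srcMultiples.size(); ++group) {
    const std::vector<int64_t> &sizes = expandedSizes[group];
    int64_t remaining = srcMultiples[group];
    std::vector<int64_t> groupMultiples(sizes.size(), 1);
    for (int64_t i = static_cast<int64_t>(sizes.size()) - 1; i >= 0; --i) {
      int64_t take = std::gcd(remaining, sizes[i]);
      groupMultiples[i] = take;
      remaining /= take;
    }
    if (remaining != 1)
      return {}; // Could not factor the multiple through this group.
    result.insert(result.end(), groupMultiples.begin(), groupMultiples.end());
  }
  return result;
}

int main() {
  // tensor<256x64xi8> packed with inner_tiles = [128, 64]:
  //   [1, 1] -> [1, 1, 128, 64]
  std::vector<int64_t> packed = propagateThroughPack({1, 1}, {128, 64});
  // tensor<2x1x128x64xi8> expanded into tensor<2x1x4x8x4x2x4x8xi8>:
  //   [1, 1, 128, 64] -> [1, 1, 4, 8, 4, 2, 4, 8]
  std::vector<int64_t> expanded =
      propagateThroughExpandShape(packed, {{2}, {1}, {4, 8, 4}, {2, 4, 8}});
  for (int64_t m : expanded)
    std::cout << m << ' ';
  std::cout << '\n'; // 1 1 4 8 4 2 4 8
}
```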

compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/BUILD.bazel

Lines changed: 2 additions & 0 deletions
@@ -21,12 +21,14 @@ iree_compiler_cc_library(
         "ConfigUtils.h",
     ],
     deps = [
+        "//compiler/src/iree/compiler/Codegen/Common",
         "//compiler/src/iree/compiler/Codegen/Common/GPU:GPUHeuristics",
         "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
         "//compiler/src/iree/compiler/Codegen/Dialect/GPU/IR:IREEGPUDialect",
         "//compiler/src/iree/compiler/Codegen/Utils",
         "//compiler/src/iree/compiler/Dialect/LinalgExt/Utils",
         "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:DialectUtils",
         "@llvm-project//mlir:FunctionInterfaces",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:LinalgDialect",

compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ iree_cc_library(
     MLIRLinalgDialect
     MLIRLinalgUtils
     MLIRSupport
+    iree::compiler::Codegen::Common
     iree::compiler::Codegen::Common::GPU::GPUHeuristics
     iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
     iree::compiler::Codegen::Dialect::GPU::IR::IREEGPUDialect

compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp

Lines changed: 84 additions & 44 deletions
@@ -7,6 +7,7 @@
 #include "iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.h"
 
 #include "iree/compiler/Codegen/Common/GPU/GPUHeuristics.h"
+#include "iree/compiler/Codegen/Common/TileInferenceUtils.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
@@ -20,6 +21,7 @@
 #include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
@@ -34,6 +36,10 @@ namespace mlir::iree_compiler::IREE::GPU {
 constexpr int64_t kCacheLineSizeBits = 128 * 8;
 constexpr int64_t kPreferredCopyNumBits = 128;
 
+//===----------------------------------------------------------------------===//
+// Lowering Config Selection
+//===----------------------------------------------------------------------===//
+
 LogicalResult setDataTiledMultiMmaLoweringConfig(
     IREE::GPU::TargetAttr target, mlir::FunctionOpInterface entryPoint,
     Operation *op, IREE::GPU::UKernelConfigAttr ukernelConfig) {
@@ -529,6 +535,17 @@ LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,
   SmallVector<int64_t> workgroupTileSizes(loopDepth, 0);
   SmallVector<int64_t> threadTileSizes(loopDepth, 0);
 
+  // Find constraints on workgroup tile sizes due to pack or unpack ops in the
+  // dispatch. If there are no pack or unpack ops present, then these multiples
+  // will be 1, which means there is no constraint on workgroup tile sizes.
+  //
+  // TODO(Max191): Getting the workgroup size multiples is needed for current
+  // pack and unpack GPU codegen. Ideally, we won't rely on propagating pack
+  // and unpack tile size information during lowering strategy selection, and
+  // this logic should be dropped once we have a better solution.
+  SmallVector<int64_t> workgroupTileSizeMultiples =
+      getWorkgroupSizeMultiples(cast<TilingInterface>(op));
+
   // Common case for all linalg ops.
 
   // The core idea is to distribute the partitioned loops to the workgroup
@@ -566,23 +583,15 @@ LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,
     LDBG("Loss factor: " << lossFactor << "\n");
     // Initialize the configuration.
     flatWorkgroupSize = 1;
-    // Initialize tiling along all partitioned loops with size 1.
+    // Initialize thread tiling along all partitioned loops with size 1, and
+    // workgroup tiling with the required tile size multiples. This may lead
+    // to larger workgroup tiles than the number of threads in the workgroup,
+    // but it is unavoidable.
     for (int64_t loopIndex : partitionableLoops) {
-      workgroupTileSizes[loopIndex] = threadTileSizes[loopIndex] = 1;
-    }
-    // Override the innermost dimension to distribute to threads in a subgroup.
-    workgroupTileSizes[partitionableLoops.back()] = subgroupSize;
-
-    // If there are more than 3 parallel dim try to tile the extra higher level
-    // dimensions to 1 for extra dimensions.
-    if (isa<linalg::GenericOp>(linalgOp.getOperation())) {
-      for (auto [i, tileSize] : llvm::enumerate(workgroupTileSizes)) {
-        if (tileSize != 0)
-          break;
-        if (loopBounds[i] != 1)
-          tileSize = 1;
-      }
+      workgroupTileSizes[loopIndex] = workgroupTileSizeMultiples[loopIndex];
+      threadTileSizes[loopIndex] = 1;
     }
+
     // Scan from the innermost shape dimension and try to deduce the
     // configuration for the corresponding GPU workgroup dimension.
     int64_t wgDim = 0;
@@ -592,18 +601,26 @@ LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,
       if (ShapedType::isDynamic(loopBound))
        continue;
 
-      // Try to find some power of two that can devide the current shape dim
+      // Try to find some power of two that can divide the current shape dim
       // size. This vector keeps the candidate tile sizes.
       SmallVector<int64_t, 8> candidates;
 
+      // Ensure vectorization works with the `workgroupTileMultiple`.
+      int64_t workgroupTileMultiple = workgroupTileSizeMultiples[shapeDim];
+      vectorizable =
+          vectorizable && 4 * numThreads % workgroupTileMultiple == 0;
       // For the inner most workgroup dim, try to see if we can have 4
       // elements per thread. This enables vectorization.
      if (vectorizable && wgDim == 0 && !lossFactor) {
        candidates.push_back(4 * numThreads);
      }
-      // Try all power of two numbers up to the subgroup size.
-      for (unsigned i = numThreads; i >= 1; i >>= 1) {
-        candidates.push_back(i);
+      // Try all power of two multiples of `workgroupTileMultiple` up to the
+      // subgroup size.
+      uint64_t maxCandidate =
+          std::max<uint64_t>(1, llvm::PowerOf2Ceil(llvm::divideCeil(
+                                    numThreads, workgroupTileMultiple)));
+      for (unsigned i = maxCandidate; i >= 1; i >>= 1) {
+        candidates.push_back(i * workgroupTileMultiple);
      }
      LLVM_DEBUG({
        llvm::dbgs() << "Base candidate tile sizes: [";
@@ -629,13 +646,10 @@ LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,
          continue;
        }
 
-        // Found a suitable candidate. Try to let each thread handle 4
-        // elements if this is the workgroup x dimension.
+        // Try to let each thread handle 4 elements if this is the workgroup x
+        // dimension.
        // TODO: Try to take into account element type bit width to get
        // 4xdword reads instead of 4x{elements}.
-        workgroupTileSizes[shapeDim] = scaledTileSize;
-        LLVM_DEBUG(llvm::dbgs()
-                   << "Chosen workgroup tile size: " << scaledTileSize << "\n");
        if (vectorizable && wgDim == 0 && !lossFactor && candidate % 4 == 0) {
          // Use size-1 vectors to increase parallelism if larger ones causes
          // idle threads in the subgroup.
@@ -648,13 +662,29 @@ LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,
          assert(numThreads % (candidate / vectorSize) == 0);
          numThreads /= candidate / vectorSize;
        } else {
+          // When the workgroupTileMultiple is not a Po2, then the candidate
+          // may not evenly divide the numThreads. In this case, we get some
+          // idle threads in the last iteration of the workgroup tile. Verify
+          // that the idle threads are within the lossFactor.
+          int64_t maybeCandidateWorkgroupSize = candidate;
+          if (numThreads % candidate != 0) {
+            maybeCandidateWorkgroupSize =
+                std::min<int64_t>(1ll << llvm::Log2_64(candidate), numThreads);
+            int64_t idleThreads = candidate % maybeCandidateWorkgroupSize;
+            if (idleThreads != 0 &&
+                (!lossFactor || idleThreads > candidate / *lossFactor)) {
+              continue;
+            }
+          }
          if (wgDim == 0)
            vectorizable = false;
          threadTileSizes[shapeDim] = scaleToByte;
-          candidateWorkgroupSize = candidate;
-          assert(numThreads % candidate == 0);
-          numThreads /= candidate;
+          candidateWorkgroupSize = maybeCandidateWorkgroupSize;
+          numThreads /= candidateWorkgroupSize;
        }
+        workgroupTileSizes[shapeDim] = scaledTileSize;
+        LLVM_DEBUG(llvm::dbgs()
+                   << "Chosen workgroup tile size: " << scaledTileSize << "\n");
        assert(numThreads >= 1);
        break;
      }
@@ -674,8 +704,17 @@ LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,
  if (distributeToThreads(newNumThreads) != 1) {
    // Otherwise, allow larger and larger loss factor.
 
-    // Threads for distribution. Use 32 at least.
-    int64_t numThreads = std::max(subgroupSize, 32);
+    // Threads for distribution. Use `minPreferredNumThreads` at least, but no
+    // more than 4 subgroups.
+    int64_t minPreferredNumThreads = std::reduce(
+        workgroupTileSizeMultiples.begin(), workgroupTileSizeMultiples.end(), 1,
+        std::multiplies<int64_t>());
+    int64_t numThreads =
+        std::min<int64_t>(4 * subgroupSize, minPreferredNumThreads);
+    // If minPreferredNumThreads is small, use at least 32 or subgroupSize
+    // threads, whichever is larger.
+    numThreads =
+        std::max<int64_t>(std::max<int64_t>(subgroupSize, 32), numThreads);
    // We can tolerate (1 / lossFactor) of threads in the workgroup to be idle.
    int64_t lossFactor = 32;
 
@@ -685,21 +724,6 @@ LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,
    }
  }
 
-  // Attach the MMA schedule as an attribute to the entry point export function
-  // for later access in the pipeline.
-  MLIRContext *context = linalgOp.getContext();
-  SmallVector<NamedAttribute, 1> attrs;
-  Builder b(context);
-  attrs.emplace_back(StringAttr::get(context, "workgroup"),
-                     b.getI64ArrayAttr(workgroupTileSizes));
-
-  attrs.emplace_back(StringAttr::get(context, "thread"),
-                     b.getI64ArrayAttr(threadTileSizes));
-
-  if (isNonMatvecContraction(linalgOp)) {
-    GPU::setPromotedOperandList(context, attrs, {0, 1});
-  }
-
  // Heuristic value chosen to limit maximum vector sizes when tiling below.
  const unsigned maxVectorSize = 32;
 
@@ -726,6 +750,22 @@ LogicalResult setTileAndFuseLoweringConfig(IREE::GPU::TargetAttr target,
      loopTileSizes[i] = tileSize;
    }
  }
+
+  // Attach the MMA schedule as an attribute to the entry point export function
+  // for later access in the pipeline.
+  MLIRContext *context = linalgOp.getContext();
+  SmallVector<NamedAttribute, 1> attrs;
+  Builder b(context);
+  attrs.emplace_back(StringAttr::get(context, "workgroup"),
+                     b.getI64ArrayAttr(workgroupTileSizes));
+
+  attrs.emplace_back(StringAttr::get(context, "thread"),
+                     b.getI64ArrayAttr(threadTileSizes));
+
+  if (isNonMatvecContraction(linalgOp)) {
+    GPU::setPromotedOperandList(context, attrs, {0, 1});
+  }
+
  if (llvm::any_of(loopTileSizes, [](int64_t s) { return s != 0; })) {
    attrs.emplace_back(StringAttr::get(context, "reduction"),
                       b.getI64ArrayAttr(loopTileSizes));
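As a side illustration of the candidate-size change in the diff above: the heuristic now enumerates power-of-two scalings of the required workgroup tile multiple instead of bare powers of two up to the thread count. The standalone C++ sketch below mirrors that loop; the `divideCeil`/`powerOf2Ceil` helpers are plain C++ stand-ins for the LLVM utilities, and the concrete numbers are made up for demonstration.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Plain stand-ins for llvm::divideCeil and llvm::PowerOf2Ceil.
static uint64_t divideCeil(uint64_t a, uint64_t b) { return (a + b - 1) / b; }
static uint64_t powerOf2Ceil(uint64_t v) {
  uint64_t p = 1;
  while (p < v)
    p <<= 1;
  return p;
}

int main() {
  const int64_t numThreads = 64;            // e.g. flat workgroup size budget
  const int64_t workgroupTileMultiple = 24; // e.g. required multiple from a pack

  // Largest power-of-two scaling such that the scaled tile stays near the
  // thread budget, then walk down by halving.
  uint64_t maxCandidate = std::max<uint64_t>(
      1, powerOf2Ceil(divideCeil(numThreads, workgroupTileMultiple)));
  std::vector<int64_t> candidates;
  for (uint64_t i = maxCandidate; i >= 1; i >>= 1)
    candidates.push_back(i * workgroupTileMultiple);

  // With numThreads = 64 and a multiple of 24, the candidates are 96, 48, 24:
  // each one a power-of-two scaling of the required multiple.
  for (int64_t c : candidates)
    std::cout << c << ' ';
  std::cout << '\n';
}
```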

compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp

Lines changed: 19 additions & 7 deletions
@@ -2520,14 +2520,15 @@ LogicalResult initGPULaunchConfig(FunctionOpInterface funcOp) {
 
   Operation *rootOperation = nullptr;
 
-  // Find the root operation. linalg.generic, linalg.fill, and scatter are not
-  // root operations if there are other compute operations present.
-  // Also, construct a set of generic ops that are to be skipped. These generic
-  // ops that are used to compute scatter indices are not root operations.
+  // Find the root operation. linalg.generic, linalg.fill, tensor.pack,
+  // tensor.unpack, and scatter are not root operations if there are other
+  // compute operations present. Also, construct a set of generic ops that
+  // are to be skipped. These generic ops that are used to compute scatter
+  // indices are not root operations.
   llvm::SmallDenseSet<Operation *, 4> genericToSkip;
   for (Operation *op : llvm::reverse(computeOps)) {
-    if (!isa<linalg::GenericOp, linalg::FillOp, IREE::LinalgExt::ScatterOp>(
-            op)) {
+    if (!isa<linalg::GenericOp, linalg::FillOp, IREE::LinalgExt::ScatterOp,
+             tensor::PackOp, tensor::UnPackOp>(op)) {
      rootOperation = op;
      break;
    }
@@ -2554,7 +2555,8 @@ LogicalResult initGPULaunchConfig(FunctionOpInterface funcOp) {
    }
  }
 
-  // Generic ops take priority over scatter and fill ops as the root op.
+  // Generic ops take priority over pack, unpack, scatter, and fill ops as the
+  // root op.
  if (!rootOperation) {
    for (Operation *op : llvm::reverse(computeOps)) {
      if (isa<linalg::GenericOp>(op) && !genericToSkip.contains(op)) {
@@ -2564,6 +2566,16 @@ LogicalResult initGPULaunchConfig(FunctionOpInterface funcOp) {
    }
  }
 
+  // Pack and unpack ops take priority over scatter and fill ops as the root op.
+  if (!rootOperation) {
+    for (Operation *op : llvm::reverse(computeOps)) {
+      if (isa<tensor::PackOp, tensor::UnPackOp>(op)) {
+        rootOperation = op;
+        break;
+      }
+    }
+  }
+
  if (!rootOperation) {
    for (Operation *op : llvm::reverse(computeOps)) {
      if (isa<IREE::LinalgExt::ScatterOp, linalg::FillOp>(op)) {