Skip to content

Commit 6d25b91

Browse files
authored
Teach PropagateDispatchSizeBounds about gpu.lane_id (#20456)
If a subgroup size is available (from the GPU target or, when set, the export op), set the upper bound of gpu.lane_id to it. Fixes: #20385
1 parent 6bf49d9 commit 6d25b91

File tree

2 files changed

+54
-2
lines changed

2 files changed

+54
-2
lines changed

compiler/src/iree/compiler/Codegen/Common/PropagateDispatchSizeBounds.cpp

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -64,10 +64,16 @@ static void foldConstantBounds(
6464

6565
static void applyBounds(FunctionOpInterface funcOp,
6666
ArrayRef<std::optional<int64_t>> workgroupSizes,
67-
ArrayRef<std::optional<int64_t>> workgroupCounts) {
67+
ArrayRef<std::optional<int64_t>> workgroupCounts,
68+
std::optional<uint64_t> subgroupSize) {
6869
Builder b(funcOp->getContext());
6970
funcOp->walk([&](Operation *op) {
7071
TypeSwitch<Operation *>(op)
72+
.Case([&](gpu::LaneIdOp laneIdOp) {
73+
if (subgroupSize) {
74+
laneIdOp.setUpperBoundAttr(b.getIndexAttr(*subgroupSize));
75+
}
76+
})
7177
.Case([&](gpu::ThreadIdOp tidOp) {
7278
std::optional<int64_t> bound =
7379
workgroupSizes[static_cast<uint32_t>(tidOp.getDimension())];
@@ -132,6 +138,8 @@ struct PropagateDispatchSizeBoundsPass final
132138
std::optional<SmallVector<int64_t>> staticWorkgroupSize =
133139
getWorkgroupSize(funcOp);
134140

141+
std::optional<uint64_t> subgroupSize = getGPUSubgroupSize(funcOp);
142+
135143
// Late in codegen, we've reconciled the workgroup size onto the export op.
136144
if (std::optional<IREE::HAL::ExecutableExportOp> exportOp =
137145
getEntryPoint(funcOp)) {
@@ -141,6 +149,11 @@ struct PropagateDispatchSizeBoundsPass final
141149
llvm::map_to_vector(exportWorkgroupSize->getAsRange<IntegerAttr>(),
142150
[](IntegerAttr a) { return a.getInt(); });
143151
}
152+
153+
if (std::optional<uint64_t> exportSubgroupSize =
154+
exportOp->getSubgroupSizeAsUInt()) {
155+
subgroupSize = exportSubgroupSize;
156+
}
144157
}
145158

146159
if (staticWorkgroupSize) {
@@ -162,7 +175,7 @@ struct PropagateDispatchSizeBoundsPass final
162175
}
163176

164177
foldConstantBounds(funcOp, staticWorkgroupSize, staticWorkgroupCounts);
165-
applyBounds(funcOp, workgroupSizes, workgroupCounts);
178+
applyBounds(funcOp, workgroupSizes, workgroupCounts, subgroupSize);
166179
}
167180
};
168181
} // namespace

compiler/src/iree/compiler/Codegen/Common/test/propagate_dispatch_size_bounds.mlir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ hal.executable private @static {
2727
builtin.module {
2828
// CHECK-LABEL: func.func @static()
2929
func.func @static() {
30+
// CHECK-NEXT: gpu.lane_id upper_bound 32
31+
%lane_id = gpu.lane_id
32+
3033
// CHECK-NEXT: gpu.thread_id x upper_bound 64
3134
// CHECK-NEXT: gpu.thread_id y upper_bound 2
3235
// CHECK-NEXT: gpu.thread_id z upper_bound 1
@@ -70,6 +73,42 @@ hal.executable private @static {
7073

7174
// -----
7275

76+
// Note: this is not the real target definition; some types are missing.
77+
#executable_target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #iree_gpu.target<arch = "gfx1100", features = "",
78+
wgp = <compute = fp32,
79+
storage = b32,
80+
subgroup = arithmetic,
81+
dot = none, mma = [],
82+
subgroup_size_choices = [32, 64],
83+
max_workgroup_sizes = [1024, 1024, 1024],
84+
max_thread_count_per_workgroup = 1024,
85+
max_workgroup_memory_bytes = 65536,
86+
max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>}>
87+
#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer>]>
88+
89+
hal.executable private @manual_subgroup_size {
90+
hal.executable.variant public @rocm_hsaco_fb target(#executable_target) {
91+
hal.executable.export public @manual_subgroup_size ordinal(0) layout(#pipeline_layout) attributes {subgroup_size = 64 : index} {
92+
^bb0(%arg0: !hal.device):
93+
%c32 = arith.constant 32 : index
94+
%c8 = arith.constant 8 : index
95+
%c1 = arith.constant 1 : index
96+
hal.return %c32, %c8, %c1 : index, index, index
97+
}
98+
builtin.module {
99+
// CHECK-LABEL: func.func @manual_subgroup_size()
100+
func.func @manual_subgroup_size() {
101+
// CHECK-NEXT: gpu.lane_id upper_bound 64
102+
%lane_id = gpu.lane_id
103+
104+
return
105+
}
106+
}
107+
}
108+
}
109+
110+
// -----
111+
73112
#executable_target = #hal.executable.target<"rocm", "rocm-hsaco-fb",
74113
{iree.gpu.target = #iree_gpu.target<arch = "gfx1100", features = "",
75114
wgp = <compute = fp32,

0 commit comments

Comments
 (0)