Skip to content

Commit d7bb36b

Browse files
authored
[Codegen][ROCDL] Drop nominal support for dynamic shared mem (#20805)
This was introduced for CUDA to work around static allocation limits (see #8317) and then subsequently cloned for ROCM. There is no such allocation limit on the AMD side and we don't actually support dynamic shared memory, so we can drop it. Immediately this fixes issues where multiple exported functions will request the sum of the memory required across all functions because the dynamic symbol obfuscates the actual required amount. To avoid regressions, this matches the current behavior of forcing all shared memory allocations to be aligned by 16. Runtime support for dynamic shared is dropped and the `block_shared_memory_size` field in the flatbuffer is marked as deprecated. To retain load time checking of shared memory resource constraints, we now query exported functions directly using `hipFuncGetAttribute(..., HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES)`.
1 parent f9d892a commit d7bb36b

File tree

11 files changed

+79
-39
lines changed

11 files changed

+79
-39
lines changed

compiler/plugins/target/ROCM/ROCMTarget.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -828,12 +828,6 @@ class ROCMTargetBackend final : public TargetBackend {
828828
blockDims.z = cast<IntegerAttr>(workgroupSize[2]).getInt();
829829
}
830830

831-
uint32_t blockSharedMemorySize = 0;
832-
if (std::optional<APInt> workgroupLocalMemoryAttr =
833-
exportOp.getWorkgroupLocalMemory()) {
834-
blockSharedMemorySize = workgroupLocalMemoryAttr->getSExtValue();
835-
}
836-
837831
auto layoutAttr = exportOp.getLayoutAttr();
838832
uint32_t constantCount = static_cast<uint32_t>(layoutAttr.getConstants());
839833
SmallVector<iree_hal_hip_BindingBits_enum_t> bindingFlags;
@@ -856,8 +850,6 @@ class ROCMTargetBackend final : public TargetBackend {
856850
iree_hal_hip_ExportDef_module_ordinal_add(builder, 0); // always 0 today
857851
iree_hal_hip_ExportDef_kernel_name_add(builder, kernelNameRef);
858852
iree_hal_hip_ExportDef_block_dims_add(builder, &blockDims);
859-
iree_hal_hip_ExportDef_block_shared_memory_size_add(
860-
builder, blockSharedMemorySize);
861853
iree_hal_hip_ExportDef_constant_count_add(builder, constantCount);
862854
iree_hal_hip_ExportDef_binding_flags_add(builder, bindingFlagsRef);
863855
iree_hal_hip_ExportDef_debug_info_add(builder, exportDebugInfos[ordinal]);

compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ namespace mlir::iree_compiler {
3232

3333
void ConvertToDynamicSharedMemory(ModuleOp moduleOp) {
3434
SymbolTableCollection symbolTableCollection;
35-
// Collect all the adressOfOps to static shared memory globals.
35+
// Collect all the addressOfOps to static shared memory globals.
3636
SmallVector<LLVM::AddressOfOp> addressOfOps;
3737
moduleOp.walk([&](LLVM::AddressOfOp addressOfOp) {
3838
// Check that the global associated with this addressOfOp has shared memory
@@ -95,6 +95,18 @@ void ConvertToDynamicSharedMemory(ModuleOp moduleOp) {
9595
}
9696
}
9797

98+
void setSharedMemoryAlignment(ModuleOp moduleOp, uint64_t newAlignment) {
99+
for (auto global : moduleOp.getOps<LLVM::GlobalOp>()) {
100+
if (global.getAddrSpace() == 3) {
101+
uint64_t baseAlignment = 0;
102+
if (std::optional<uint64_t> alignment = global.getAlignment()) {
103+
baseAlignment = alignment.value();
104+
}
105+
global.setAlignment(std::max<uint64_t>(baseAlignment, newAlignment));
106+
}
107+
}
108+
}
109+
98110
namespace {
99111

100112
/// Scalarize math ops. It is needed to lower vector operation that don't have

compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToLLVM.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ void populateConvertSharedMemoryAllocOps(RewritePatternSet &patterns);
2929

3030
void ConvertToDynamicSharedMemory(ModuleOp moduleOp);
3131

32+
void setSharedMemoryAlignment(ModuleOp moduleOp, uint64_t newAlignment);
33+
3234
using MemorySpaceMapping =
3335
std::function<unsigned(gpu::AddressSpace gpuAddressSpace)>;
3436
void populateGpuMemorySpaceAttributeConversions(

compiler/src/iree/compiler/Codegen/LLVMGPU/ConvertToROCDL.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,9 +339,14 @@ struct ConvertToROCDLPass final
339339
}
340340

341341
LDBG("After converting to rocdl\n" << m);
342-
ConvertToDynamicSharedMemory(m);
343342

344-
LDBG("After converting to dynamic shared memory\n" << m);
343+
// 16 is the maximum relevant alignment for all AMD GPUs. Unceremoniously
344+
// set it to 16 as all of our allocations almost always have much greater
345+
// alignment than this.
346+
// TODO(qedawkins): Set this much earlier when we introduce the allocations.
347+
setSharedMemoryAlignment(m, 16);
348+
349+
LDBG("After updating shared memory alignments\n" << m);
345350
}
346351
};
347352
} // namespace mlir::iree_compiler

compiler/src/iree/compiler/Codegen/LLVMGPU/test/convert_to_rocdl.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,3 +305,30 @@ module {
305305
// CHECK-NEXT: rocdl.s.setprio 3
306306
// CHECK: rocdl.mfma
307307
// CHECK-NEXT: rocdl.s.setprio 4
308+
309+
// -----
310+
311+
builtin.module {
312+
func.func @shared_memory_lowering() {
313+
%c0 = arith.constant 0 : index
314+
%cst = arith.constant dense<0.000000e+00> : vector<4xf32>
315+
%0 = memref.alloc() : memref<1x16x32xf32, #gpu.address_space<workgroup>>
316+
%1 = memref.alloc() : memref<1x32x16xf32, #gpu.address_space<workgroup>>
317+
%2 = memref.alloc() : memref<1x8x16xf32, #gpu.address_space<workgroup>>
318+
vector.store %cst, %1[%c0, %c0, %c0] : memref<1x32x16xf32, #gpu.address_space<workgroup>>, vector<4xf32>
319+
vector.store %cst, %2[%c0, %c0, %c0] : memref<1x8x16xf32, #gpu.address_space<workgroup>>, vector<4xf32>
320+
vector.store %cst, %0[%c0, %c0, %c0] : memref<1x16x32xf32, #gpu.address_space<workgroup>>, vector<4xf32>
321+
return
322+
}
323+
}
324+
325+
// CHECK-DAG: llvm.mlir.global private @__shared_memory___1() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<1 x array<16 x array<32 x f32>>>
326+
// CHECK-DAG: llvm.mlir.global private @__shared_memory___0() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<1 x array<32 x array<16 x f32>>>
327+
// CHECK-DAG: llvm.mlir.global private @__shared_memory__() {addr_space = 3 : i32, alignment = 16 : i64} : !llvm.array<1 x array<8 x array<16 x f32>>>
328+
// CHECK-LABEL: llvm.func @shared_memory_lowering() {
329+
// CHECK-DAG: %[[A1:.+]] = llvm.mlir.addressof @__shared_memory___1
330+
// CHECK-DAG: llvm.getelementptr %[[A1]][0, 0, 0, 0]
331+
// CHECK-DAG: %[[A0:.+]] = llvm.mlir.addressof @__shared_memory___0
332+
// CHECK-DAG: llvm.getelementptr %[[A0]][0, 0, 0, 0]
333+
// CHECK-DAG: %[[A:.+]] = llvm.mlir.addressof @__shared_memory__
334+
// CHECK-DAG: llvm.getelementptr %[[A]][0, 0, 0, 0]

runtime/src/iree/hal/drivers/hip/dynamic_symbol_tables.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ IREE_HAL_HIP_REQUIRED_PFN_DECL(hipEventRecord, hipEvent_t, hipStream_t)
3636
IREE_HAL_HIP_REQUIRED_PFN_DECL(hipEventSynchronize, hipEvent_t)
3737
IREE_HAL_HIP_REQUIRED_PFN_DECL(hipFree, void *)
3838
IREE_HAL_HIP_REQUIRED_PFN_DECL(hipFreeAsync, void *, hipStream_t)
39+
IREE_HAL_HIP_REQUIRED_PFN_DECL(hipFuncGetAttribute, int *, hipFuncAttribute,
40+
const void *)
3941
IREE_HAL_HIP_REQUIRED_PFN_DECL(hipFuncSetAttribute, const void *,
4042
hipFuncAttribute, int)
4143
IREE_HAL_HIP_REQUIRED_PFN_DECL(hipGetDeviceCount, int *)

runtime/src/iree/hal/drivers/hip/graph_command_buffer.c

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -815,7 +815,6 @@ static iree_status_t iree_hal_hip_graph_command_buffer_dispatch(
815815
.gridDim.z = workgroup_count[2],
816816
.func = kernel_params->function,
817817
.kernelParams = params_ptr,
818-
.sharedMemBytes = kernel_params->block_shared_memory_size,
819818
};
820819

821820
if (command_buffer->graph_node_count >=

runtime/src/iree/hal/drivers/hip/native_executable.c

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -174,17 +174,6 @@ static iree_status_t iree_hal_hip_native_executable_flatbuffer_verify(
174174
i);
175175
}
176176

177-
uint32_t block_shared_memory_size =
178-
iree_hal_hip_ExportDef_block_shared_memory_size_get(export_def);
179-
if (block_shared_memory_size > limits->max_block_shared_memory_size) {
180-
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
181-
"exports[%" PRIhsz
182-
"] requires %uB of shared memory and "
183-
"exceeds the device maximum of %uB per block",
184-
i, block_shared_memory_size,
185-
limits->max_block_shared_memory_size);
186-
}
187-
188177
uint32_t constant_count =
189178
iree_hal_hip_ExportDef_constant_count_get(export_def);
190179
if (constant_count > IREE_HAL_HIP_MAX_DISPATCH_CONSTANT_COUNT) {
@@ -212,6 +201,30 @@ static iree_status_t iree_hal_hip_native_executable_flatbuffer_verify(
212201
return iree_ok_status();
213202
}
214203

204+
// Verifies a function against the device limits so that we can avoid doing so
205+
// during runtime.
206+
static iree_status_t iree_hal_hip_function_attributes_verify(
207+
iree_host_size_t id, const iree_hal_hip_dynamic_symbols_t* symbols,
208+
hipFunction_t function, const iree_hal_hip_limits_t* limits) {
209+
int block_shared_memory_size;
210+
IREE_RETURN_IF_ERROR(IREE_HIP_CALL_TO_STATUS(
211+
symbols,
212+
hipFuncGetAttribute(
213+
&block_shared_memory_size,
214+
(hipFuncAttribute)HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, function),
215+
"hipFuncGetAttribute"));
216+
if (block_shared_memory_size > limits->max_block_shared_memory_size) {
217+
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
218+
"exports[%" PRIhsz
219+
"] requires %uB of shared memory and "
220+
"exceeds the device maximum of %uB per block",
221+
id, block_shared_memory_size,
222+
limits->max_block_shared_memory_size);
223+
}
224+
225+
return iree_ok_status();
226+
}
227+
215228
iree_status_t iree_hal_hip_native_executable_create(
216229
const iree_hal_hip_dynamic_symbols_t* symbols,
217230
iree_hal_hip_device_topology_t topology,
@@ -397,16 +410,8 @@ iree_status_t iree_hal_hip_native_executable_create(
397410
break;
398411
}
399412

400-
uint32_t block_shared_memory_size =
401-
iree_hal_hip_ExportDef_block_shared_memory_size_get(export_def);
402-
status = IREE_HIP_CALL_TO_STATUS(
403-
symbols,
404-
hipFuncSetAttribute(
405-
function,
406-
(hipFuncAttribute)
407-
HIP_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
408-
block_shared_memory_size),
409-
"hipFuncSetAttribute");
413+
status = iree_hal_hip_function_attributes_verify(i, symbols, function,
414+
&limits);
410415
if (!iree_status_is_ok(status)) break;
411416

412417
// Package required parameters for kernel launches for each entry
@@ -419,8 +424,6 @@ iree_status_t iree_hal_hip_native_executable_create(
419424
kernel_info->block_dims[0] = block_dims->x;
420425
kernel_info->block_dims[1] = block_dims->y;
421426
kernel_info->block_dims[2] = block_dims->z;
422-
kernel_info->block_shared_memory_size =
423-
iree_hal_hip_ExportDef_block_shared_memory_size_get(export_def);
424427
kernel_info->constant_count =
425428
iree_hal_hip_ExportDef_constant_count_get(export_def);
426429
iree_hal_hip_BindingBits_vec_t binding_flags_vec =

runtime/src/iree/hal/drivers/hip/native_executable.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ typedef struct iree_hal_hip_kernel_params_t {
3737
uint32_t binding_count;
3838

3939
uint32_t block_dims[3];
40-
uint32_t block_shared_memory_size;
4140

4241
IREE_TRACE(iree_hal_hip_kernel_debug_info_t debug_info;)
4342
} iree_hal_hip_kernel_params_t;

runtime/src/iree/hal/drivers/hip/stream_command_buffer.c

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -578,8 +578,7 @@ static iree_status_t iree_hal_hip_stream_command_buffer_dispatch(
578578
kernel_params->function, workgroup_count[0], workgroup_count[1],
579579
workgroup_count[2], kernel_params->block_dims[0],
580580
kernel_params->block_dims[1], kernel_params->block_dims[2],
581-
kernel_params->block_shared_memory_size, command_buffer->hip_stream,
582-
params_ptr, NULL),
581+
/*sharedMemBytes=*/0, command_buffer->hip_stream, params_ptr, NULL),
583582
"hipModuleLaunchKernel");
584583

585584
IREE_HAL_STREAM_TRACE_ZONE_END(command_buffer->tracing_context,

0 commit comments

Comments
 (0)