Skip to content

Commit 1a86f0a

Browse files
authored
[Offload] Add device info for shared memory (#167817)
1 parent e5f499f commit 1a86f0a

File tree

8 files changed

+47
-4
lines changed

8 files changed

+47
-4
lines changed

offload/liboffload/API/Device.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def ol_device_info_t : Enum {
4343
TaggedEtor<"ADDRESS_BITS", "uint32_t", "Number of bits used to represent an address in device memory">,
4444
TaggedEtor<"MAX_MEM_ALLOC_SIZE", "uint64_t", "The maximum size of memory object allocation in bytes">,
4545
TaggedEtor<"GLOBAL_MEM_SIZE", "uint64_t", "The size of global device memory in bytes">,
46+
TaggedEtor<"WORK_GROUP_LOCAL_MEM_SIZE", "uint64_t", "The maximum size of local shared memory per work group in bytes">,
4647
];
4748
list<TaggedEtor> fp_configs = !foreach(type, ["Single", "Double", "Half"], TaggedEtor<type # "_FP_CONFIG", "ol_device_fp_capability_flags_t", type # " precision floating point capability">);
4849
list<TaggedEtor> native_vec_widths = !foreach(type, ["char","short","int","long","float","double","half"], TaggedEtor<"NATIVE_VECTOR_WIDTH_" # type, "uint32_t", "Native vector width for " # type>);

offload/liboffload/src/OffloadImpl.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,13 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device,
495495
return Info.write(static_cast<uint32_t>(Value));
496496
}
497497

498+
case OL_DEVICE_INFO_WORK_GROUP_LOCAL_MEM_SIZE: {
499+
if (!std::holds_alternative<uint64_t>(Entry->Value))
500+
return makeError(ErrorCode::BACKEND_FAILURE,
501+
"plugin returned incorrect type");
502+
return Info.write(std::get<uint64_t>(Entry->Value));
503+
}
504+
498505
case OL_DEVICE_INFO_MAX_WORK_SIZE_PER_DIMENSION:
499506
case OL_DEVICE_INFO_MAX_WORK_GROUP_SIZE_PER_DIMENSION: {
500507
// {x, y, z} triples
@@ -590,6 +597,7 @@ Error olGetDeviceInfoImplDetailHost(ol_device_handle_t Device,
590597
return Info.write<uint32_t>(std::numeric_limits<uintptr_t>::digits);
591598
case OL_DEVICE_INFO_MAX_MEM_ALLOC_SIZE:
592599
case OL_DEVICE_INFO_GLOBAL_MEM_SIZE:
600+
case OL_DEVICE_INFO_WORK_GROUP_LOCAL_MEM_SIZE:
593601
return Info.write<uint64_t>(0);
594602
default:
595603
return createOffloadError(ErrorCode::INVALID_ENUMERATION,

offload/plugins-nextgen/amdgpu/src/rtl.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2186,6 +2186,16 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
21862186
if (auto Err = checkIfAPU())
21872187
return Err;
21882188

2189+
// Retrieve the size of the first group (per-work-group shared) memory pool.
2190+
for (const auto *Pool : AllMemoryPools) {
2191+
if (Pool->isGroup()) {
2192+
if (auto Err = Pool->getAttr(HSA_AMD_MEMORY_POOL_INFO_SIZE,
2193+
MaxBlockSharedMemSize))
2194+
return Err;
2195+
break;
2196+
}
2197+
}
2198+
21892199
return Plugin::success();
21902200
}
21912201

@@ -2923,6 +2933,9 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
29232933
if (Status == HSA_STATUS_SUCCESS)
29242934
Info.add("Cacheline Size", TmpUInt);
29252935

2936+
Info.add("Max Shared Memory per Work Group", MaxBlockSharedMemSize, "bytes",
2937+
DeviceInfo::WORK_GROUP_LOCAL_MEM_SIZE);
2938+
29262939
Status = getDeviceAttrRaw(HSA_AMD_AGENT_INFO_MAX_CLOCK_FREQUENCY, TmpUInt);
29272940
if (Status == HSA_STATUS_SUCCESS)
29282941
Info.add("Max Clock Freq", TmpUInt, "MHz",

offload/plugins-nextgen/common/include/PluginInterface.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -794,6 +794,10 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
794794
/// Get the unique identifier of the device.
795795
const char *getDeviceUid() const { return DeviceUid.c_str(); }
796796

797+
/// Get the total shared memory per block (in bytes) that can be used in any
798+
/// kernel.
799+
size_t getMaxBlockSharedMemSize() const { return MaxBlockSharedMemSize; }
800+
797801
/// Set the context of the device if needed, before calling device-specific
798802
/// functions. Plugins may implement this function as a no-op if not needed.
799803
virtual Error setContext() = 0;
@@ -1251,6 +1255,9 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
12511255
/// Internal representation for OMPT device (initialize & finalize)
12521256
std::atomic<bool> OmptInitialized;
12531257
#endif
1258+
1259+
/// The total per-block native shared memory (in bytes) that a kernel may use.
1260+
size_t MaxBlockSharedMemSize = 0;
12541261
};
12551262

12561263
/// Class implementing common functionalities of offload plugins. Each plugin

offload/plugins-nextgen/cuda/src/rtl.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,12 @@ struct CUDADeviceTy : public GenericDeviceTy {
379379
return Err;
380380
HardwareParallelism = NumMuliprocessors * (MaxThreadsPerSM / WarpSize);
381381

382+
uint32_t MaxSharedMem;
383+
if (auto Err = getDeviceAttr(
384+
CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK, MaxSharedMem))
385+
return Err;
386+
MaxBlockSharedMemSize = MaxSharedMem;
387+
382388
return Plugin::success();
383389
}
384390

@@ -1089,10 +1095,8 @@ struct CUDADeviceTy : public GenericDeviceTy {
10891095
if (Res == CUDA_SUCCESS)
10901096
Info.add("Total Constant Memory", TmpInt, "bytes");
10911097

1092-
Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK,
1093-
TmpInt);
1094-
if (Res == CUDA_SUCCESS)
1095-
Info.add("Max Shared Memory per Block", TmpInt, "bytes");
1098+
Info.add("Max Shared Memory per Block", MaxBlockSharedMemSize, "bytes",
1099+
DeviceInfo::WORK_GROUP_LOCAL_MEM_SIZE);
10961100

10971101
Res = getDeviceAttrRaw(CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, TmpInt);
10981102
if (Res == CUDA_SUCCESS)

offload/tools/deviceinfo/llvm-offload-device-info.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,9 @@ ol_result_t printDevice(std::ostream &S, ol_device_handle_t D) {
205205
S, D, OL_DEVICE_INFO_MAX_MEM_ALLOC_SIZE, "Max Mem Allocation Size", "B"));
206206
OFFLOAD_ERR(printDeviceValue<uint64_t>(S, D, OL_DEVICE_INFO_GLOBAL_MEM_SIZE,
207207
"Global Mem Size", "B"));
208+
OFFLOAD_ERR(
209+
printDeviceValue<uint64_t>(S, D, OL_DEVICE_INFO_WORK_GROUP_LOCAL_MEM_SIZE,
210+
"Work Group Shared Mem Size", "B"));
208211
OFFLOAD_ERR(
209212
(printDeviceValue<ol_device_fp_capability_flags_t, PrintKind::FP_FLAGS>(
210213
S, D, OL_DEVICE_INFO_SINGLE_FP_CONFIG,

offload/unittests/OffloadAPI/device/olGetDeviceInfo.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,11 @@ OL_DEVICE_INFO_TEST_DEVICE_VALUE_GT(GlobalMemSize, uint64_t,
217217
OL_DEVICE_INFO_GLOBAL_MEM_SIZE, 0);
218218
OL_DEVICE_INFO_TEST_HOST_SUCCESS(GlobalMemSize, uint64_t,
219219
OL_DEVICE_INFO_GLOBAL_MEM_SIZE);
220+
OL_DEVICE_INFO_TEST_DEVICE_VALUE_GT(SharedMemSize, uint64_t,
221+
OL_DEVICE_INFO_WORK_GROUP_LOCAL_MEM_SIZE,
222+
0);
223+
OL_DEVICE_INFO_TEST_HOST_SUCCESS(SharedMemSize, uint64_t,
224+
OL_DEVICE_INFO_WORK_GROUP_LOCAL_MEM_SIZE);
220225

221226
TEST_P(olGetDeviceInfoTest, InvalidNullHandleDevice) {
222227
ol_device_type_t DeviceType;

offload/unittests/OffloadAPI/device/olGetDeviceInfoSize.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ OL_DEVICE_INFO_SIZE_TEST_EQ(MaxMemAllocSize, uint64_t,
7171
OL_DEVICE_INFO_MAX_MEM_ALLOC_SIZE);
7272
OL_DEVICE_INFO_SIZE_TEST_EQ(GlobalMemSize, uint64_t,
7373
OL_DEVICE_INFO_GLOBAL_MEM_SIZE);
74+
OL_DEVICE_INFO_SIZE_TEST_EQ(SharedMemSize, uint64_t,
75+
OL_DEVICE_INFO_WORK_GROUP_LOCAL_MEM_SIZE);
7476

7577
TEST_P(olGetDeviceInfoSizeTest, SuccessMaxWorkGroupSizePerDimension) {
7678
size_t Size = 0;

0 commit comments

Comments
 (0)