Skip to content

Commit 2b69f17

Browse files
Jakub Chlanda and npmiller
authored and committed
UR_KERNEL_SUB_GROUP_INFO_SUB_GROUP_SIZE_INTEL on Cuda and HIP (#17137)
For HIP, the sub-group size can be either 32 or 64; it can be retrieved from the `intel_reqd_sub_group_size` metadata node. CUDA only supports 32, which is enforced in the compiler; see [SemaSYCL::addIntelReqdSubGroupSizeAttr](https://github.com/intel/llvm/blob/sycl/clang/lib/Sema/SemaSYCLDeclAttr.cpp#L828).

---------

Co-authored-by: Nicolas Miller <[email protected]>
1 parent dac89ee commit 2b69f17

File tree

7 files changed

+32
-10
lines changed

7 files changed

+32
-10
lines changed

source/adapters/cuda/kernel.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -339,10 +339,15 @@ urKernelGetSubGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
339339
return ReturnValue(0);
340340
}
341341
case UR_KERNEL_SUB_GROUP_INFO_SUB_GROUP_SIZE_INTEL: {
342-
// Return value of 0 => unspecified or "auto" sub-group size
343-
// Correct for now, since warp size may be read from special register
344-
// TODO: Return warp size once default is primary sub-group size
345-
// TODO: Revisit if we can recover [[sub_group_size]] attribute from PTX
342+
const auto &KernelReqdSubGroupSizeMap =
343+
hKernel->getProgram()->KernelReqdSubGroupSizeMD;
344+
// If present, return the value of intel_reqd_sub_group_size metadata, if
345+
// not: 0, which stands for unspecified or auto sub-group size.
346+
if (auto KernelReqdSubGroupSize =
347+
KernelReqdSubGroupSizeMap.find(hKernel->getName());
348+
KernelReqdSubGroupSize != KernelReqdSubGroupSizeMap.end())
349+
return ReturnValue(KernelReqdSubGroupSize->second);
350+
346351
return ReturnValue(0);
347352
}
348353
default:

source/adapters/cuda/program.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ ur_program_handle_t_::setMetadata(const ur_program_metadata_t *Metadata,
9898
} else if (Tag ==
9999
__SYCL_UR_PROGRAM_METADATA_TAG_MAX_LINEAR_WORK_GROUP_SIZE) {
100100
KernelMaxLinearWorkGroupSizeMD[Prefix] = MetadataElement.value.data64;
101+
} else if (Tag == __SYCL_UR_PROGRAM_METADATA_TAG_REQD_SUB_GROUP_SIZE) {
102+
assert(MetadataElement.type == UR_PROGRAM_METADATA_TYPE_UINT32);
103+
KernelReqdSubGroupSizeMD[Prefix] = MetadataElement.value.data32;
101104
}
102105
}
103106
return UR_RESULT_SUCCESS;

source/adapters/cuda/program.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ struct ur_program_handle_t_ {
3939
std::unordered_map<std::string, std::tuple<uint32_t, uint32_t, uint32_t>>
4040
KernelMaxWorkGroupSizeMD;
4141
std::unordered_map<std::string, uint64_t> KernelMaxLinearWorkGroupSizeMD;
42+
std::unordered_map<std::string, uint32_t> KernelReqdSubGroupSizeMD;
4243

4344
constexpr static size_t MaxLogSize = 8192u;
4445

@@ -49,7 +50,8 @@ struct ur_program_handle_t_ {
4950
ur_program_handle_t_(ur_context_handle_t Context, ur_device_handle_t Device)
5051
: Module{nullptr}, Binary{}, BinarySizeInBytes{0}, RefCount{1},
5152
Context{Context}, Device{Device}, KernelReqdWorkGroupSizeMD{},
52-
KernelMaxWorkGroupSizeMD{}, KernelMaxLinearWorkGroupSizeMD{} {
53+
KernelMaxWorkGroupSizeMD{}, KernelMaxLinearWorkGroupSizeMD{},
54+
KernelReqdSubGroupSizeMD{} {
5355
urContextRetain(Context);
5456
urDeviceRetain(Device);
5557
}

source/adapters/hip/kernel.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -274,10 +274,15 @@ urKernelGetSubGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
274274
return ReturnValue(0);
275275
}
276276
case UR_KERNEL_SUB_GROUP_INFO_SUB_GROUP_SIZE_INTEL: {
277-
// Return value of 0 => unspecified or "auto" sub-group size
278-
// Correct for now, since warp size may be read from special register
279-
// TODO: Return warp size once default is primary sub-group size
280-
// TODO: Revisit if we can recover [[sub_group_size]] attribute from PTX
277+
const auto &KernelReqdSubGroupSizeMap =
278+
hKernel->getProgram()->KernelReqdSubGroupSizeMD;
279+
// If present, return the value of intel_reqd_sub_group_size metadata, if
280+
// not: 0, which stands for unspecified or auto sub-group size.
281+
if (auto KernelReqdSubGroupSize =
282+
KernelReqdSubGroupSizeMap.find(hKernel->getName());
283+
KernelReqdSubGroupSize != KernelReqdSubGroupSizeMap.end())
284+
return ReturnValue(KernelReqdSubGroupSize->second);
285+
281286
return ReturnValue(0);
282287
}
283288
default:

source/adapters/hip/program.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,9 @@ ur_program_handle_t_::setMetadata(const ur_program_metadata_t *Metadata,
119119
KernelReqdWorkGroupSizeMD[Prefix] =
120120
std::make_tuple(ReqdWorkGroupElements[0], ReqdWorkGroupElements[1],
121121
ReqdWorkGroupElements[2]);
122+
} else if (Tag == __SYCL_UR_PROGRAM_METADATA_TAG_REQD_SUB_GROUP_SIZE) {
123+
assert(MetadataElement.type == UR_PROGRAM_METADATA_TYPE_UINT32);
124+
KernelReqdSubGroupSizeMD[Prefix] = MetadataElement.value.data32;
122125
}
123126
}
124127

source/adapters/hip/program.hpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ struct ur_program_handle_t_ {
3939
std::unordered_map<std::string, std::string> GlobalIDMD;
4040
std::unordered_map<std::string, std::tuple<uint32_t, uint32_t, uint32_t>>
4141
KernelReqdWorkGroupSizeMD;
42+
std::unordered_map<std::string, uint32_t> KernelReqdSubGroupSizeMD;
4243

4344
constexpr static size_t MAX_LOG_SIZE = 8192u;
4445

@@ -48,7 +49,8 @@ struct ur_program_handle_t_ {
4849

4950
ur_program_handle_t_(ur_context_handle_t Ctxt, ur_device_handle_t Device)
5051
: Module{nullptr}, Binary{}, BinarySizeInBytes{0}, RefCount{1},
51-
Context{Ctxt}, Device{Device}, KernelReqdWorkGroupSizeMD{} {
52+
Context{Ctxt}, Device{Device}, KernelReqdWorkGroupSizeMD{},
53+
KernelReqdSubGroupSizeMD{} {
5254
urContextRetain(Context);
5355
urDeviceRetain(Device);
5456
}

source/ur/ur.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ const ur_command_t UR_EXT_COMMAND_TYPE_USER =
5757
"@max_work_group_size"
5858
#define __SYCL_UR_PROGRAM_METADATA_TAG_MAX_LINEAR_WORK_GROUP_SIZE \
5959
"@max_linear_work_group_size"
60+
#define __SYCL_UR_PROGRAM_METADATA_TAG_REQD_SUB_GROUP_SIZE \
61+
"@reqd_sub_group_size"
6062
#define __SYCL_UR_PROGRAM_METADATA_TAG_NEED_FINALIZATION "Requires finalization"
6163

6264
// Terminates the process with a catastrophic error message.

0 commit comments

Comments
 (0)