1212
1313#include " common.hpp"
1414#include " program.hpp"
15+ #include < cstdint>
1516
1617UR_APIEXPORT ur_result_t UR_APICALL
1718urProgramCreateWithIL (ur_context_handle_t hContext, const void *pIL,
@@ -26,6 +27,39 @@ urProgramCreateWithIL(ur_context_handle_t hContext, const void *pIL,
2627 DIE_NO_IMPLEMENTATION
2728}
2829
30+ // TODO: taken from CUDA adapter, move this to a common header?
31+ static std::pair<std::string, std::string>
32+ splitMetadataName (const std::string &metadataName) {
33+ size_t splitPos = metadataName.rfind (' @' );
34+ if (splitPos == std::string::npos)
35+ return std::make_pair (metadataName, std::string{});
36+ return std::make_pair (metadataName.substr (0 , splitPos),
37+ metadataName.substr (splitPos, metadataName.length ()));
38+ }
39+
40+ static ur_result_t getReqdWGSize (const ur_program_metadata_t &MetadataElement,
41+ native_cpu::ReqdWGSize_t &res) {
42+ size_t MDElemsSize = MetadataElement.size - sizeof (std::uint64_t );
43+
44+ // Expect between 1 and 3 32-bit integer values.
45+ UR_ASSERT (MDElemsSize == sizeof (std::uint32_t ) ||
46+ MDElemsSize == sizeof (std::uint32_t ) * 2 ||
47+ MDElemsSize == sizeof (std::uint32_t ) * 3 ,
48+ UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE);
49+
50+ // Get pointer to data, skipping 64-bit size at the start of the data.
51+ const char *ValuePtr =
52+ reinterpret_cast <const char *>(MetadataElement.value .pData ) +
53+ sizeof (std::uint64_t );
54+ // Read values and pad with 1's for values not present.
55+ std::uint32_t ReqdWorkGroupElements[] = {1 , 1 , 1 };
56+ std::memcpy (ReqdWorkGroupElements, ValuePtr, MDElemsSize);
57+ std::get<0 >(res) = ReqdWorkGroupElements[0 ];
58+ std::get<1 >(res) = ReqdWorkGroupElements[1 ];
59+ std::get<2 >(res) = ReqdWorkGroupElements[2 ];
60+ return UR_RESULT_SUCCESS;
61+ }
62+
2963UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary (
3064 ur_context_handle_t hContext, ur_device_handle_t hDevice, size_t size,
3165 const uint8_t *pBinary, const ur_program_properties_t *pProperties,
@@ -40,6 +74,21 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
4074
4175 auto hProgram = new ur_program_handle_t_ (
4276 hContext, reinterpret_cast <const unsigned char *>(pBinary));
77+ if (pProperties != nullptr ) {
78+ for (uint32_t i = 0 ; i < pProperties->count ; i++) {
79+ auto mdNode = pProperties->pMetadatas [i];
80+ std::string mdName (mdNode.pName );
81+ auto [Prefix, Tag] = splitMetadataName (mdName);
82+ if (Tag == __SYCL_UR_PROGRAM_METADATA_TAG_REQD_WORK_GROUP_SIZE) {
83+ native_cpu::ReqdWGSize_t reqdWGSize;
84+ auto res = getReqdWGSize (mdNode, reqdWGSize);
85+ if (res != UR_RESULT_SUCCESS) {
86+ return res;
87+ }
88+ hProgram->KernelReqdWorkGroupSizeMD [Prefix] = std::move (reqdWGSize);
89+ }
90+ }
91+ }
4392
4493 const nativecpu_entry *nativecpu_it =
4594 reinterpret_cast <const nativecpu_entry *>(pBinary);
0 commit comments