99// ===----------------------------------------------------------------------===//
1010
1111#include " program.hpp"
12+ #include " ur_util.hpp"
1213
1314#ifdef SYCL_ENABLE_KERNEL_FUSION
1415#ifdef UR_COMGR_VERSION4_INCLUDE
@@ -78,15 +79,6 @@ void getCoMgrBuildLog(const amd_comgr_data_set_t BuildDataSet, char *BuildLog,
7879} // namespace
7980#endif
8081
81- std::pair<std::string, std::string>
82- splitMetadataName (const std::string &metadataName) {
83- size_t splitPos = metadataName.rfind (' @' );
84- if (splitPos == std::string::npos)
85- return std::make_pair (metadataName, std::string{});
86- return std::make_pair (metadataName.substr (0 , splitPos),
87- metadataName.substr (splitPos, metadataName.length ()));
88- }
89-
9082ur_result_t
9183ur_program_handle_t_::setMetadata (const ur_program_metadata_t *Metadata,
9284 size_t Length) {
@@ -107,8 +99,29 @@ ur_program_handle_t_::setMetadata(const ur_program_metadata_t *Metadata,
10799 const char *MetadataValPtrEnd =
108100 MetadataValPtr + MetadataElement.size - sizeof (std::uint64_t );
109101 GlobalIDMD[Prefix] = std::string{MetadataValPtr, MetadataValPtrEnd};
102+ } else if (Tag == __SYCL_UR_PROGRAM_METADATA_TAG_REQD_WORK_GROUP_SIZE) {
103+ // If metadata is reqd_work_group_size, record it for the corresponding
104+ // kernel name.
105+ size_t MDElemsSize = MetadataElement.size - sizeof (std::uint64_t );
106+
107+ // Expect between 1 and 3 32-bit integer values.
108+ UR_ASSERT (MDElemsSize >= sizeof (std::uint32_t ) &&
109+ MDElemsSize <= sizeof (std::uint32_t ) * 3 ,
110+ UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE);
111+
112+ // Get pointer to data, skipping 64-bit size at the start of the data.
113+ const char *ValuePtr =
114+ reinterpret_cast <const char *>(MetadataElement.value .pData ) +
115+ sizeof (std::uint64_t );
116+ // Read values and pad with 1's for values not present.
117+ std::uint32_t ReqdWorkGroupElements[] = {1 , 1 , 1 };
118+ std::memcpy (ReqdWorkGroupElements, ValuePtr, MDElemsSize);
119+ KernelReqdWorkGroupSizeMD[Prefix] =
120+ std::make_tuple (ReqdWorkGroupElements[0 ], ReqdWorkGroupElements[1 ],
121+ ReqdWorkGroupElements[2 ]);
110122 }
111123 }
124+
112125 return UR_RESULT_SUCCESS;
113126}
114127
@@ -459,8 +472,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
459472 std::unique_ptr<ur_program_handle_t_> RetProgram{
460473 new ur_program_handle_t_{hContext, hDevice}};
461474
462- // TODO: Set metadata here and use reqd_work_group_size information.
463- // See urProgramCreateWithBinary in CUDA adapter.
464475 if (pProperties) {
465476 if (pProperties->count > 0 && pProperties->pMetadatas == nullptr ) {
466477 return UR_RESULT_ERROR_INVALID_NULL_POINTER;
@@ -469,8 +480,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramCreateWithBinary(
469480 }
470481 Result =
471482 RetProgram->setMetadata (pProperties->pMetadatas , pProperties->count );
483+ UR_ASSERT (Result == UR_RESULT_SUCCESS, Result);
472484 }
473- UR_ASSERT (Result == UR_RESULT_SUCCESS, Result);
474485
475486 auto pBinary_string = reinterpret_cast <const char *>(pBinary);
476487 if (size == 0 ) {
0 commit comments