@@ -279,16 +279,32 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemGetInfo(ur_mem_handle_t hMemory,
279279// / \param[out] phNativeMem Set to the native handle of the UR mem object.
280280// /
281281// / \return UR_RESULT_SUCCESS
282- UR_APIEXPORT ur_result_t UR_APICALL urMemGetNativeHandle (ur_mem_handle_t ,
283- ur_native_handle_t *) {
284- // FIXME: there is no good way of doing this with a multi device context.
285- // If we return a single pointer, how would we know which device's allocation
286- // it should be?
287- // If we return a vector of pointers, this is OK for read only access but if
288- // we write to a buffer, how would we know which one had been written to?
289- // Should unused allocations be updated afterwards? We have no way of knowing
290- // any of these things in the current API design.
291- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
282+ UR_APIEXPORT ur_result_t UR_APICALL
283+ urMemGetNativeHandle (ur_mem_handle_t hMem, ur_device_handle_t Device,
284+ ur_native_handle_t *phNativeMem) {
285+ #if defined(__HIP_PLATFORM_NVIDIA__)
286+ if (sizeof (BufferMem::native_type) > sizeof (ur_native_handle_t )) {
287+ // Check that all the upper bits that cannot be represented by
288+ // ur_native_handle_t are empty.
289+ // NOTE: The following shift might trigger a warning, but the check in the
290+ // if above makes sure that this does not underflow.
291+ BufferMem::native_type UpperBits =
292+ std::get<BufferMem>(hMem->Mem ).getPtr (Device) >>
293+ (sizeof (ur_native_handle_t ) * CHAR_BIT);
294+ if (UpperBits) {
295+ // Return an error if any of the remaining bits is non-zero.
296+ return UR_RESULT_ERROR_INVALID_MEM_OBJECT;
297+ }
298+ }
299+ *phNativeMem = reinterpret_cast <ur_native_handle_t >(
300+ std::get<BufferMem>(hMem->Mem ).getPtr (Device));
301+ #elif defined(__HIP_PLATFORM_AMD__)
302+ *phNativeMem = reinterpret_cast <ur_native_handle_t >(
303+ std::get<BufferMem>(hMem->Mem ).getPtr (Device));
304+ #else
305+ #error ("Must define exactly one of __HIP_PLATFORM_AMD__ or __HIP_PLATFORM_NVIDIA__");
306+ #endif
307+ return UR_RESULT_SUCCESS;
292308}
293309
294310UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreateWithNativeHandle (
0 commit comments