@@ -187,8 +187,15 @@ static ur_result_t USMDeviceAllocImpl(void **ResultPtr,
187187 ZeDesc.pNext = &RelaxedDesc;
188188 }
189189
190- ZE2UR_CALL (zeMemAllocDevice, (Context->ZeContext , &ZeDesc, Size, Alignment,
191- Device->ZeDevice , ResultPtr));
190+ ze_result_t ZeResult =
191+ zeMemAllocDevice (Context->ZeContext , &ZeDesc, Size, Alignment,
192+ Device->ZeDevice , ResultPtr);
193+ if (ZeResult != ZE_RESULT_SUCCESS) {
194+ if (ZeResult == ZE_RESULT_ERROR_UNSUPPORTED_SIZE) {
195+ return UR_RESULT_ERROR_INVALID_USM_SIZE;
196+ }
197+ return ze2urResult (ZeResult);
198+ }
192199
193200 UR_ASSERT (Alignment == 0 ||
194201 reinterpret_cast <std::uintptr_t >(*ResultPtr) % Alignment == 0 ,
@@ -226,8 +233,15 @@ static ur_result_t USMSharedAllocImpl(void **ResultPtr,
226233 ZeDevDesc.pNext = &RelaxedDesc;
227234 }
228235
229- ZE2UR_CALL (zeMemAllocShared, (Context->ZeContext , &ZeDevDesc, &ZeHostDesc,
230- Size, Alignment, Device->ZeDevice , ResultPtr));
236+ ze_result_t ZeResult =
237+ zeMemAllocShared (Context->ZeContext , &ZeDevDesc, &ZeHostDesc, Size,
238+ Alignment, Device->ZeDevice , ResultPtr);
239+ if (ZeResult != ZE_RESULT_SUCCESS) {
240+ if (ZeResult == ZE_RESULT_ERROR_UNSUPPORTED_SIZE) {
241+ return UR_RESULT_ERROR_INVALID_USM_SIZE;
242+ }
243+ return ze2urResult (ZeResult);
244+ }
231245
232246 UR_ASSERT (Alignment == 0 ||
233247 reinterpret_cast <std::uintptr_t >(*ResultPtr) % Alignment == 0 ,
@@ -254,8 +268,14 @@ static ur_result_t USMHostAllocImpl(void **ResultPtr,
254268 // TODO: translate PI properties to Level Zero flags
255269 ZeStruct<ze_host_mem_alloc_desc_t > ZeHostDesc;
256270 ZeHostDesc.flags = 0 ;
257- ZE2UR_CALL (zeMemAllocHost,
258- (Context->ZeContext , &ZeHostDesc, Size, Alignment, ResultPtr));
271+ ze_result_t ZeResult = zeMemAllocHost (Context->ZeContext , &ZeHostDesc, Size,
272+ Alignment, ResultPtr);
273+ if (ZeResult != ZE_RESULT_SUCCESS) {
274+ if (ZeResult == ZE_RESULT_ERROR_UNSUPPORTED_SIZE) {
275+ return UR_RESULT_ERROR_INVALID_USM_SIZE;
276+ }
277+ return ze2urResult (ZeResult);
278+ }
259279
260280 UR_ASSERT (Alignment == 0 ||
261281 reinterpret_cast <std::uintptr_t >(*ResultPtr) % Alignment == 0 ,
@@ -599,6 +619,40 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMGetMemAllocInfo(
599619 ZE2UR_CALL (zeMemGetAddressRange, (Context->ZeContext , Ptr, nullptr , &Size));
600620 return ReturnValue (Size);
601621 }
622+ case UR_USM_ALLOC_INFO_POOL: {
623+ auto UMFPool = umfPoolByPtr (Ptr);
624+ if (!UMFPool) {
625+ return UR_RESULT_ERROR_INVALID_VALUE;
626+ }
627+
628+ std::shared_lock<ur_shared_mutex> ContextLock (Context->Mutex );
629+
630+ auto SearchMatchingPool =
631+ [](std::unordered_map<ur_device_handle_t , umf::pool_unique_handle_t >
632+ &PoolMap,
633+ umf_memory_pool_handle_t UMFPool) {
634+ for (auto &PoolPair : PoolMap) {
635+ if (PoolPair.second .get () == UMFPool) {
636+ return true ;
637+ }
638+ }
639+ return false ;
640+ };
641+
642+ for (auto &Pool : Context->UsmPoolHandles ) {
643+ if (SearchMatchingPool (Pool->DeviceMemPools , UMFPool)) {
644+ return ReturnValue (Pool);
645+ }
646+ if (SearchMatchingPool (Pool->SharedMemPools , UMFPool)) {
647+ return ReturnValue (Pool);
648+ }
649+ if (Pool->HostMemPool .get () == UMFPool) {
650+ return ReturnValue (Pool);
651+ }
652+ }
653+
654+ return UR_RESULT_ERROR_INVALID_VALUE;
655+ }
602656 default :
603657 urPrint (" urUSMGetMemAllocInfo: unsupported ParamName\n " );
604658 return UR_RESULT_ERROR_INVALID_VALUE;
@@ -748,6 +802,7 @@ ur_result_t L0HostMemoryProvider::allocateImpl(void **ResultPtr, size_t Size,
748802ur_usm_pool_handle_t_::ur_usm_pool_handle_t_ (ur_context_handle_t Context,
749803 ur_usm_pool_desc_t *PoolDesc) {
750804
805+ this ->Context = Context;
751806 zeroInit = static_cast <uint32_t >(PoolDesc->flags &
752807 UR_USM_POOL_FLAG_ZERO_INITIALIZE_BLOCK);
753808
@@ -831,6 +886,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolCreate(
831886 try {
832887 *Pool = reinterpret_cast <ur_usm_pool_handle_t >(
833888 new ur_usm_pool_handle_t_ (Context, PoolDesc));
889+
890+ std::shared_lock<ur_shared_mutex> ContextLock (Context->Mutex );
891+ Context->UsmPoolHandles .insert (Context->UsmPoolHandles .cend (), *Pool);
892+
834893 } catch (const UsmAllocationException &Ex) {
835894 return Ex.getError ();
836895 }
@@ -848,6 +907,8 @@ ur_result_t
848907urUSMPoolRelease (ur_usm_pool_handle_t Pool // /< [in] pointer to USM memory pool
849908) {
850909 if (Pool->RefCount .decrementAndTest ()) {
910+ std::shared_lock<ur_shared_mutex> ContextLock (Pool->Context ->Mutex );
911+ Pool->Context ->UsmPoolHandles .remove (Pool);
851912 delete Pool;
852913 }
853914 return UR_RESULT_SUCCESS;
@@ -861,13 +922,19 @@ ur_result_t urUSMPoolGetInfo(
861922 // /< property
862923 size_t *PropSizeRet // /< [out] size in bytes returned in pool property value
863924) {
864- std::ignore = Pool;
865- std::ignore = PropName;
866- std::ignore = PropSize;
867- std::ignore = PropValue;
868- std::ignore = PropSizeRet;
869- urPrint (" [UR][L0] %s function not implemented!\n " , __FUNCTION__);
870- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
925+ UrReturnHelper ReturnValue (PropSize, PropValue, PropSizeRet);
926+
927+ switch (PropName) {
928+ case UR_USM_POOL_INFO_REFERENCE_COUNT: {
929+ return ReturnValue (Pool->RefCount .load ());
930+ }
931+ case UR_USM_POOL_INFO_CONTEXT: {
932+ return ReturnValue (Pool->Context );
933+ }
934+ default : {
935+ return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION;
936+ }
937+ }
871938}
872939
873940// If indirect access tracking is not enabled then this functions just performs
0 commit comments