@@ -187,8 +187,15 @@ static ur_result_t USMDeviceAllocImpl(void **ResultPtr,
187187 ZeDesc.pNext = &RelaxedDesc;
188188 }
189189
190- ZE2UR_CALL (zeMemAllocDevice, (Context->ZeContext , &ZeDesc, Size, Alignment,
191- Device->ZeDevice , ResultPtr));
190+ ze_result_t ZeResult = ZE_CALL_NOCHECK (
191+ zeMemAllocDevice, (Context->ZeContext , &ZeDesc, Size, Alignment,
192+ Device->ZeDevice , ResultPtr));
193+ if (ZeResult != ZE_RESULT_SUCCESS) {
194+ if (ZeResult == ZE_RESULT_ERROR_UNSUPPORTED_SIZE) {
195+ return UR_RESULT_ERROR_INVALID_USM_SIZE;
196+ }
197+ return ze2urResult (ZeResult);
198+ }
192199
193200 UR_ASSERT (Alignment == 0 ||
194201 reinterpret_cast <std::uintptr_t >(*ResultPtr) % Alignment == 0 ,
@@ -226,8 +233,15 @@ static ur_result_t USMSharedAllocImpl(void **ResultPtr,
226233 ZeDevDesc.pNext = &RelaxedDesc;
227234 }
228235
229- ZE2UR_CALL (zeMemAllocShared, (Context->ZeContext , &ZeDevDesc, &ZeHostDesc,
230- Size, Alignment, Device->ZeDevice , ResultPtr));
236+ ze_result_t ZeResult = ZE_CALL_NOCHECK (
237+ zeMemAllocShared, (Context->ZeContext , &ZeDevDesc, &ZeHostDesc, Size,
238+ Alignment, Device->ZeDevice , ResultPtr));
239+ if (ZeResult != ZE_RESULT_SUCCESS) {
240+ if (ZeResult == ZE_RESULT_ERROR_UNSUPPORTED_SIZE) {
241+ return UR_RESULT_ERROR_INVALID_USM_SIZE;
242+ }
243+ return ze2urResult (ZeResult);
244+ }
231245
232246 UR_ASSERT (Alignment == 0 ||
233247 reinterpret_cast <std::uintptr_t >(*ResultPtr) % Alignment == 0 ,
@@ -254,8 +268,15 @@ static ur_result_t USMHostAllocImpl(void **ResultPtr,
254268 // TODO: translate PI properties to Level Zero flags
255269 ZeStruct<ze_host_mem_alloc_desc_t > ZeHostDesc;
256270 ZeHostDesc.flags = 0 ;
257- ZE2UR_CALL (zeMemAllocHost,
258- (Context->ZeContext , &ZeHostDesc, Size, Alignment, ResultPtr));
271+ ze_result_t ZeResult =
272+ ZE_CALL_NOCHECK (zeMemAllocHost, (Context->ZeContext , &ZeHostDesc, Size,
273+ Alignment, ResultPtr));
274+ if (ZeResult != ZE_RESULT_SUCCESS) {
275+ if (ZeResult == ZE_RESULT_ERROR_UNSUPPORTED_SIZE) {
276+ return UR_RESULT_ERROR_INVALID_USM_SIZE;
277+ }
278+ return ze2urResult (ZeResult);
279+ }
259280
260281 UR_ASSERT (Alignment == 0 ||
261282 reinterpret_cast <std::uintptr_t >(*ResultPtr) % Alignment == 0 ,
@@ -599,6 +620,40 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMGetMemAllocInfo(
599620 ZE2UR_CALL (zeMemGetAddressRange, (Context->ZeContext , Ptr, nullptr , &Size));
600621 return ReturnValue (Size);
601622 }
623+ case UR_USM_ALLOC_INFO_POOL: {
624+ auto UMFPool = umfPoolByPtr (Ptr);
625+ if (!UMFPool) {
626+ return UR_RESULT_ERROR_INVALID_VALUE;
627+ }
628+
629+ std::shared_lock<ur_shared_mutex> ContextLock (Context->Mutex );
630+
631+ auto SearchMatchingPool =
632+ [](std::unordered_map<ur_device_handle_t , umf::pool_unique_handle_t >
633+ &PoolMap,
634+ umf_memory_pool_handle_t UMFPool) {
635+ for (auto &PoolPair : PoolMap) {
636+ if (PoolPair.second .get () == UMFPool) {
637+ return true ;
638+ }
639+ }
640+ return false ;
641+ };
642+
643+ for (auto &Pool : Context->UsmPoolHandles ) {
644+ if (SearchMatchingPool (Pool->DeviceMemPools , UMFPool)) {
645+ return ReturnValue (Pool);
646+ }
647+ if (SearchMatchingPool (Pool->SharedMemPools , UMFPool)) {
648+ return ReturnValue (Pool);
649+ }
650+ if (Pool->HostMemPool .get () == UMFPool) {
651+ return ReturnValue (Pool);
652+ }
653+ }
654+
655+ return UR_RESULT_ERROR_INVALID_VALUE;
656+ }
602657 default :
603658 urPrint (" urUSMGetMemAllocInfo: unsupported ParamName\n " );
604659 return UR_RESULT_ERROR_INVALID_VALUE;
@@ -748,6 +803,7 @@ ur_result_t L0HostMemoryProvider::allocateImpl(void **ResultPtr, size_t Size,
748803ur_usm_pool_handle_t_::ur_usm_pool_handle_t_ (ur_context_handle_t Context,
749804 ur_usm_pool_desc_t *PoolDesc) {
750805
806+ this ->Context = Context;
751807 zeroInit = static_cast <uint32_t >(PoolDesc->flags &
752808 UR_USM_POOL_FLAG_ZERO_INITIALIZE_BLOCK);
753809
@@ -831,6 +887,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolCreate(
831887 try {
832888 *Pool = reinterpret_cast <ur_usm_pool_handle_t >(
833889 new ur_usm_pool_handle_t_ (Context, PoolDesc));
890+
891+ std::shared_lock<ur_shared_mutex> ContextLock (Context->Mutex );
892+ Context->UsmPoolHandles .insert (Context->UsmPoolHandles .cend (), *Pool);
893+
834894 } catch (const UsmAllocationException &Ex) {
835895 return Ex.getError ();
836896 }
@@ -848,6 +908,8 @@ ur_result_t
848908urUSMPoolRelease (ur_usm_pool_handle_t Pool // /< [in] pointer to USM memory pool
849909) {
850910 if (Pool->RefCount .decrementAndTest ()) {
911+ std::shared_lock<ur_shared_mutex> ContextLock (Pool->Context ->Mutex );
912+ Pool->Context ->UsmPoolHandles .remove (Pool);
851913 delete Pool;
852914 }
853915 return UR_RESULT_SUCCESS;
@@ -861,13 +923,19 @@ ur_result_t urUSMPoolGetInfo(
861923 // /< property
862924 size_t *PropSizeRet // /< [out] size in bytes returned in pool property value
863925) {
864- std::ignore = Pool;
865- std::ignore = PropName;
866- std::ignore = PropSize;
867- std::ignore = PropValue;
868- std::ignore = PropSizeRet;
869- urPrint (" [UR][L0] %s function not implemented!\n " , __FUNCTION__);
870- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
926+ UrReturnHelper ReturnValue (PropSize, PropValue, PropSizeRet);
927+
928+ switch (PropName) {
929+ case UR_USM_POOL_INFO_REFERENCE_COUNT: {
930+ return ReturnValue (Pool->RefCount .load ());
931+ }
932+ case UR_USM_POOL_INFO_CONTEXT: {
933+ return ReturnValue (Pool->Context );
934+ }
935+ default : {
936+ return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION;
937+ }
938+ }
871939}
872940
873941// If indirect access tracking is not enabled then this functions just performs
0 commit comments