@@ -102,26 +102,41 @@ urUSMSharedAlloc(ur_context_handle_t hContext, ur_device_handle_t hDevice,
102102 return UR_RESULT_SUCCESS;
103103}
104104
105- ur_result_t USMFreeImpl (ur_context_handle_t , void *Pointer) {
105+ ur_result_t USMFreeImpl (ur_context_handle_t hContext , void *Pointer) {
106106 ur_result_t Result = UR_RESULT_SUCCESS;
107107 try {
108108 unsigned int IsManaged;
109109 unsigned int Type;
110- void *AttributeValues[2 ] = {&IsManaged, &Type};
111- CUpointer_attribute Attributes[2 ] = {CU_POINTER_ATTRIBUTE_IS_MANAGED,
112- CU_POINTER_ATTRIBUTE_MEMORY_TYPE};
113- UR_CHECK_ERROR (cuPointerGetAttributes (2 , Attributes, AttributeValues,
114- (CUdeviceptr)Pointer));
110+ unsigned int DeviceOrdinal;
111+ const int NumAttributes = 3 ;
112+ void *AttributeValues[NumAttributes] = {&IsManaged, &Type, &DeviceOrdinal};
113+
114+ CUpointer_attribute Attributes[NumAttributes] = {
115+ CU_POINTER_ATTRIBUTE_IS_MANAGED, CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
116+ CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL};
117+ UR_CHECK_ERROR (cuPointerGetAttributes (
118+ NumAttributes, Attributes, AttributeValues, (CUdeviceptr)Pointer));
115119 UR_ASSERT (Type == CU_MEMORYTYPE_DEVICE || Type == CU_MEMORYTYPE_HOST,
116120 UR_RESULT_ERROR_INVALID_MEM_OBJECT);
117- if (IsManaged || Type == CU_MEMORYTYPE_DEVICE) {
118- // Memory allocated with cuMemAlloc and cuMemAllocManaged must be freed
119- // with cuMemFree
120- UR_CHECK_ERROR (cuMemFree ((CUdeviceptr)Pointer));
121+
122+ std::vector<ur_device_handle_t > ContextDevices = hContext->getDevices ();
123+ ur_platform_handle_t Platform = ContextDevices[0 ]->getPlatform ();
124+ unsigned int NumDevices = Platform->Devices .size ();
125+ UR_ASSERT (DeviceOrdinal < NumDevices, UR_RESULT_ERROR_INVALID_DEVICE);
126+
127+ ur_device_handle_t Device = Platform->Devices [DeviceOrdinal].get ();
128+ umf_memory_provider_handle_t MemoryProvider;
129+
130+ if (IsManaged) {
131+ MemoryProvider = Device->MemoryProviderShared ;
132+ } else if (Type == CU_MEMORYTYPE_DEVICE) {
133+ MemoryProvider = Device->MemoryProviderDevice ;
121134 } else {
122- // Memory allocated with cuMemAllocHost must be freed with cuMemFreeHost
123- UR_CHECK_ERROR (cuMemFreeHost (Pointer));
135+ MemoryProvider = hContext->MemoryProviderHost ;
124136 }
137+
138+ UMF_CHECK_ERROR (umfMemoryProviderFree (MemoryProvider, Pointer,
139+ 0 /* size is unknown */ ));
125140 } catch (ur_result_t Err) {
126141 Result = Err;
127142 }
@@ -143,7 +158,8 @@ ur_result_t USMDeviceAllocImpl(void **ResultPtr, ur_context_handle_t,
143158 uint32_t Alignment) {
144159 try {
145160 ScopedContext Active (Device);
146- UR_CHECK_ERROR (cuMemAlloc ((CUdeviceptr *)ResultPtr, Size));
161+ UMF_CHECK_ERROR (umfMemoryProviderAlloc (Device->MemoryProviderDevice , Size,
162+ Alignment, ResultPtr));
147163 } catch (ur_result_t Err) {
148164 return Err;
149165 }
@@ -164,8 +180,8 @@ ur_result_t USMSharedAllocImpl(void **ResultPtr, ur_context_handle_t,
164180 uint32_t Alignment) {
165181 try {
166182 ScopedContext Active (Device);
167- UR_CHECK_ERROR ( cuMemAllocManaged ((CUdeviceptr *)ResultPtr , Size,
168- CU_MEM_ATTACH_GLOBAL ));
183+ UMF_CHECK_ERROR ( umfMemoryProviderAlloc (Device-> MemoryProviderShared , Size,
184+ Alignment, ResultPtr ));
169185 } catch (ur_result_t Err) {
170186 return Err;
171187 }
@@ -179,11 +195,12 @@ ur_result_t USMSharedAllocImpl(void **ResultPtr, ur_context_handle_t,
179195 return UR_RESULT_SUCCESS;
180196}
181197
182- ur_result_t USMHostAllocImpl (void **ResultPtr, ur_context_handle_t ,
198+ ur_result_t USMHostAllocImpl (void **ResultPtr, ur_context_handle_t hContext ,
183199 ur_usm_host_mem_flags_t , size_t Size,
184200 uint32_t Alignment) {
185201 try {
186- UR_CHECK_ERROR (cuMemAllocHost (ResultPtr, Size));
202+ UMF_CHECK_ERROR (umfMemoryProviderAlloc (hContext->MemoryProviderHost , Size,
203+ Alignment, ResultPtr));
187204 } catch (ur_result_t Err) {
188205 return Err;
189206 }
0 commit comments