@@ -152,41 +152,21 @@ UR_APIEXPORT ur_result_t UR_APICALL
152152urUSMGetMemAllocInfo (ur_context_handle_t hContext, const void *pMem,
153153 ur_usm_alloc_info_t propName, size_t propValueSize,
154154 void *pPropValue, size_t *pPropValueSizeRet) {
155- ur_result_t Result = UR_RESULT_SUCCESS;
156- hipPointerAttribute_t hipPointerAttributeType;
157-
158155 UrReturnHelper ReturnValue (propValueSize, pPropValue, pPropValueSizeRet);
159156
160157 try {
161158 switch (propName) {
162159 case UR_USM_ALLOC_INFO_TYPE: {
163- // do not throw if hipPointerGetAttribute returns hipErrorInvalidValue
164- hipError_t Ret = hipPointerGetAttributes (&hipPointerAttributeType, pMem);
165- if (Ret == hipErrorInvalidValue) {
166- // pointer not known to the HIP subsystem
167- return ReturnValue (UR_USM_TYPE_UNKNOWN);
168- }
169- // Direct usage of the function, instead of UR_CHECK_ERROR, so we can get
170- // the line offset.
171- checkErrorUR (Ret, __func__, __LINE__ - 5 , __FILE__);
172- // ROCm 6.0.0 introduces hipMemoryTypeUnregistered in the hipMemoryType
173- // enum to mark unregistered allocations (i.e., via system allocators).
174- #if HIP_VERSION_MAJOR >= 6
175- if (hipPointerAttributeType.type == hipMemoryTypeUnregistered) {
160+ auto MaybePointerAttrs = getPointerAttributes (pMem);
161+ if (!MaybePointerAttrs.has_value ()) {
176162 // pointer not known to the HIP subsystem
177163 return ReturnValue (UR_USM_TYPE_UNKNOWN);
178164 }
179- #endif
180- unsigned int Value;
181- #if HIP_VERSION >= 50600000
182- Value = hipPointerAttributeType.type ;
183- #else
184- Value = hipPointerAttributeType.memoryType ;
185- #endif
165+ auto Value = getMemoryType (*MaybePointerAttrs);
186166 UR_ASSERT (Value == hipMemoryTypeDevice || Value == hipMemoryTypeHost ||
187167 Value == hipMemoryTypeManaged,
188168 UR_RESULT_ERROR_INVALID_MEM_OBJECT);
189- if (hipPointerAttributeType. isManaged || Value == hipMemoryTypeManaged) {
169+ if (MaybePointerAttrs-> isManaged || Value == hipMemoryTypeManaged) {
190170 // pointer to managed memory
191171 return ReturnValue (UR_USM_TYPE_SHARED);
192172 }
@@ -202,15 +182,18 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem,
202182 ur::unreachable ();
203183 }
204184 case UR_USM_ALLOC_INFO_DEVICE: {
205- // get device index associated with this pointer
206- UR_CHECK_ERROR (hipPointerGetAttributes (&hipPointerAttributeType, pMem));
185+ auto MaybePointerAttrs = getPointerAttributes (pMem);
186+ if (!MaybePointerAttrs.has_value ()) {
187+ // pointer not known to the HIP subsystem
188+ return ReturnValue (UR_USM_TYPE_UNKNOWN);
189+ }
207190
208- int DeviceIdx = hipPointerAttributeType. device ;
191+ int DeviceIdx = MaybePointerAttrs-> device ;
209192
210193 // hip backend has only one platform containing all devices
211194 ur_platform_handle_t platform;
212195 ur_adapter_handle_t AdapterHandle = &adapter;
213- Result = urPlatformGet (&AdapterHandle, 1 , 1 , &platform, nullptr );
196+ UR_CHECK_ERROR ( urPlatformGet (&AdapterHandle, 1 , 1 , &platform, nullptr ) );
214197
215198 // get the device from the platform
216199 ur_device_handle_t Device = platform->Devices [DeviceIdx].get ();
@@ -227,20 +210,32 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem,
227210 }
228211 return ReturnValue (Pool);
229212 }
213+ case UR_USM_ALLOC_INFO_BASE_PTR:
214+ // HIP gives us the ability to query the base pointer for a device
215+ // pointer, so check whether we've got one of those.
216+ if (auto MaybePointerAttrs = getPointerAttributes (pMem)) {
217+ if (getMemoryType (*MaybePointerAttrs) == hipMemoryTypeDevice) {
218+ void *Base = nullptr ;
219+ UR_CHECK_ERROR (hipPointerGetAttribute (
220+ &Base, HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR,
221+ (hipDeviceptr_t)pMem));
222+ return ReturnValue (Base);
223+ }
224+ }
225+ // If not, we can't be sure.
226+ return UR_RESULT_ERROR_INVALID_VALUE;
230227 case UR_USM_ALLOC_INFO_SIZE: {
231228 size_t RangeSize = 0 ;
232229 UR_CHECK_ERROR (hipMemPtrGetInfo (const_cast <void *>(pMem), &RangeSize));
233230 return ReturnValue (RangeSize);
234231 }
235- case UR_USM_ALLOC_INFO_BASE_PTR:
236- return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION;
237232 default :
238233 return UR_RESULT_ERROR_INVALID_ENUMERATION;
239234 }
240235 } catch (ur_result_t Error) {
241- Result = Error;
236+ return Error;
242237 }
243- return Result ;
238+ return UR_RESULT_SUCCESS ;
244239}
245240
246241UR_APIEXPORT ur_result_t UR_APICALL urUSMImportExp (ur_context_handle_t Context,
0 commit comments