@@ -280,6 +280,11 @@ hipError_t ihipMalloc(void** ptr, size_t sizeBytes, unsigned int flags)
280280
281281 *ptr = amd::SvmBuffer::malloc (*amdContext, flags, sizeBytes, amdContext->devices ()[0 ]->info ().memBaseAddrAlign_ ,
282282 useHostDevice ? curDevContext->svmDevices ()[0 ] : nullptr );
283+ size_t offset = 0 ; // this is ignored
284+ amd::Memory* memObj = getMemoryObject (*ptr, offset);
285+ // saves the current device id so that it can be accessed later
286+ memObj->getUserData ().deviceId = hip::getCurrentDevice ()->deviceId ();
287+
283288 if (*ptr == nullptr ) {
284289 size_t free = 0 , total =0 ;
285290 hipMemGetInfo (&free, &total);
@@ -649,6 +654,10 @@ hipError_t ihipMallocPitch(void** ptr, size_t* pitch, size_t width, size_t heigh
649654
650655 *ptr = amd::SvmBuffer::malloc (*hip::getCurrentDevice ()->asContext (), 0 , sizeBytes,
651656 device->info ().memBaseAddrAlign_ );
657+ size_t offset = 0 ; // this is ignored
658+ amd::Memory* memObj = getMemoryObject (*ptr, offset);
659+ // saves the current device id so that it can be accessed later
660+ memObj->getUserData ().deviceId = hip::getCurrentDevice ()->deviceId ();
652661
653662 if (*ptr == nullptr ) {
654663 return hipErrorOutOfMemory;
@@ -2473,19 +2482,8 @@ hipError_t hipPointerGetAttributes(hipPointerAttribute_t* attributes, const void
24732482 attributes->allocationFlags = memObj->getMemFlags () >> 16 ;
24742483
24752484 amd::Context* memObjCtx = &memObj->getContext ();
2476- if (hip::host_device->asContext () == memObjCtx) {
2477- attributes->device = ihipGetDevice ();
2478- HIP_RETURN (hipSuccess);
2479- }
2480- for (auto & ctx : g_devices) {
2481- if (ctx->asContext () == memObjCtx) {
2482- attributes->device = device;
2483- HIP_RETURN (hipSuccess);
2484- }
2485- ++device;
2486- }
2487- LogPrintfError (" Cannot find memory object context, memObjCtx: 0x%x \n " , memObjCtx);
2488- HIP_RETURN (hipErrorInvalidDevice);
2485+ attributes->device = memObj->getUserData ().deviceId ;
2486+ HIP_RETURN (hipSuccess);
24892487 }
24902488
24912489 LogPrintfError (" Cannot get amd_mem_obj for ptr: 0x%x \n " , ptr);
0 commit comments