@@ -75,12 +75,14 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {
7575 return UR_RESULT_SUCCESS;
7676 }
7777
78+ std::scoped_lock<ur_shared_mutex> Guard (Mutex);
7879 auto &Allocation = Allocations[Device];
80+ ur_result_t URes = UR_RESULT_SUCCESS;
7981 if (!Allocation) {
8082 ur_usm_desc_t USMDesc{};
8183 USMDesc.align = getAlignment ();
8284 ur_usm_pool_handle_t Pool{};
83- ur_result_t URes = getContext ()->interceptor ->allocateMemory (
85+ URes = getContext ()->interceptor ->allocateMemory (
8486 Context, Device, &USMDesc, Pool, Size, AllocType::MEM_BUFFER,
8587 ur_cast<void **>(&Allocation));
8688 if (URes != UR_RESULT_SUCCESS) {
@@ -105,7 +107,60 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {
105107
106108 Handle = Allocation;
107109
108- return UR_RESULT_SUCCESS;
110+ if (!LastSyncedDevice.hDevice ) {
111+ LastSyncedDevice = MemBuffer::Device_t{Device, Handle};
112+ return URes;
113+ }
114+
115+ // If the device required to allocate memory is not the previous one, we
116+ // need to do data migration.
117+ if (Device != LastSyncedDevice.hDevice ) {
118+ auto &HostAllocation = Allocations[nullptr ];
119+ if (!HostAllocation) {
120+ ur_usm_desc_t USMDesc{};
121+ USMDesc.align = getAlignment ();
122+ ur_usm_pool_handle_t Pool{};
123+ URes = getContext ()->interceptor ->allocateMemory (
124+ Context, nullptr , &USMDesc, Pool, Size, AllocType::HOST_USM,
125+ ur_cast<void **>(&HostAllocation));
126+ if (URes != UR_RESULT_SUCCESS) {
127+ getContext ()->logger .error (" Failed to allocate {} bytes host "
128+ " USM for buffer {} migration" ,
129+ Size, this );
130+ return URes;
131+ }
132+ }
133+
134+ // Copy data from last synced device to host
135+ {
136+ ManagedQueue Queue (Context, LastSyncedDevice.hDevice );
137+ URes = getContext ()->urDdiTable .Enqueue .pfnUSMMemcpy (
138+ Queue, true , HostAllocation, LastSyncedDevice.MemHandle , Size,
139+ 0 , nullptr , nullptr );
140+ if (URes != UR_RESULT_SUCCESS) {
141+ getContext ()->logger .error (
142+ " Failed to migrate memory buffer data" );
143+ return URes;
144+ }
145+ }
146+
147+ // Sync data back to device
148+ {
149+ ManagedQueue Queue (Context, Device);
150+ URes = getContext ()->urDdiTable .Enqueue .pfnUSMMemcpy (
151+ Queue, true , Allocation, HostAllocation, Size, 0 , nullptr ,
152+ nullptr );
153+ if (URes != UR_RESULT_SUCCESS) {
154+ getContext ()->logger .error (
155+ " Failed to migrate memory buffer data" );
156+ return URes;
157+ }
158+ }
159+ }
160+
161+ LastSyncedDevice = MemBuffer::Device_t{Device, Handle};
162+
163+ return URes;
109164}
110165
111166ur_result_t MemBuffer::free () {
0 commit comments