@@ -76,11 +76,12 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {
7676 }
7777
7878 auto &Allocation = Allocations[Device];
79+ ur_result_t URes = UR_RESULT_SUCCESS;
7980 if (!Allocation) {
8081 ur_usm_desc_t USMDesc{};
8182 USMDesc.align = getAlignment ();
8283 ur_usm_pool_handle_t Pool{};
83- ur_result_t URes = getContext ()->interceptor ->allocateMemory (
84+ URes = getContext ()->interceptor ->allocateMemory (
8485 Context, Device, &USMDesc, Pool, Size, AllocType::MEM_BUFFER,
8586 ur_cast<void **>(&Allocation));
8687 if (URes != UR_RESULT_SUCCESS) {
@@ -103,9 +104,57 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {
103104 }
104105 }
105106
107+ // If the device required to allocate memory is not the previous one, we
108+ // need to do data migration.
109+ if (Device != LastSyncedDevice && LastSyncedDevice != nullptr ) {
110+ auto &HostAllocation = Allocations[nullptr ];
111+ if (!HostAllocation) {
112+ ur_usm_desc_t USMDesc{};
113+ USMDesc.align = getAlignment ();
114+ ur_usm_pool_handle_t Pool{};
115+ URes = getContext ()->interceptor ->allocateMemory (
116+ Context, nullptr , &USMDesc, Pool, Size, AllocType::HOST_USM,
117+ ur_cast<void **>(&HostAllocation));
118+ if (URes != UR_RESULT_SUCCESS) {
119+ getContext ()->logger .error (" Failed to allocate {} bytes host "
120+ " USM for buffer {} migration" ,
121+ Size, this );
122+ return URes;
123+ }
124+ }
125+
126+ // Copy data from last synced device to host
127+ {
128+ ManagedQueue Queue (Context, LastSyncedDevice);
129+ char *Handle;
130+ UR_CALL (getHandle (LastSyncedDevice, Handle));
131+ URes = getContext ()->urDdiTable .Enqueue .pfnUSMMemcpy (
132+ Queue, true , HostAllocation, Handle, Size, 0 , nullptr , nullptr );
133+ if (URes != UR_RESULT_SUCCESS) {
134+ getContext ()->logger .error (
135+ " Failed to migrate memory buffer data" );
136+ return URes;
137+ }
138+ }
139+
140+ // Sync data back to device
141+ {
142+ ManagedQueue Queue (Context, Device);
143+ URes = getContext ()->urDdiTable .Enqueue .pfnUSMMemcpy (
144+ Queue, true , Allocation, HostAllocation, Size, 0 , nullptr ,
145+ nullptr );
146+ if (URes != UR_RESULT_SUCCESS) {
147+ getContext ()->logger .error (
148+ " Failed to migrate memory buffer data" );
149+ return URes;
150+ }
151+ }
152+ }
153+
154+ LastSyncedDevice = Device;
106155 Handle = Allocation;
107156
108- return UR_RESULT_SUCCESS ;
157+ return URes ;
109158}
110159
111160ur_result_t MemBuffer::free () {
0 commit comments