@@ -32,22 +32,24 @@ ur_usm_handle_t::ur_usm_handle_t(ur_context_handle_t hContext, size_t size,
3232 : ur_mem_buffer_t (hContext, size, device_access_mode_t ::read_write),
3333 ptr (const_cast <void *>(ptr)) {}
3434
35- void *ur_usm_handle_t ::getDevicePtr(
36- ur_device_handle_t /* hDevice*/ , device_access_mode_t /* access*/ ,
37- size_t offset, size_t /* size*/ ,
38- std::function<void (void *src, void *dst, size_t )> /* migrate*/ ) {
35+ void *ur_usm_handle_t ::getDevicePtr(ur_device_handle_t /* hDevice*/ ,
36+ device_access_mode_t /* access*/ ,
37+ size_t offset, size_t /* size*/ ,
38+ ze_command_list_handle_t /* cmdList*/ ,
39+ wait_list_view & /* waitListView*/ ) {
3940 return ur_cast<char *>(ptr) + offset;
4041}
4142
42- void *
43- ur_usm_handle_t ::mapHostPtr( ur_map_flags_t /* flags */ , size_t offset ,
44- size_t /* size */ ,
45- std::function< void ( void *src, void *dst, size_t )> ) {
43+ void *ur_usm_handle_t ::mapHostPtr( ur_map_flags_t /* flags */ , size_t offset,
44+ size_t /* size */ ,
45+ ze_command_list_handle_t /* cmdList */ ,
46+ wait_list_view & /* waitListView */ ) {
4647 return ur_cast<char *>(ptr) + offset;
4748}
4849
49- void ur_usm_handle_t::unmapHostPtr (
50- void * /* pMappedPtr*/ , std::function<void (void *src, void *dst, size_t )>) {
50+ void ur_usm_handle_t::unmapHostPtr (void * /* pMappedPtr*/ ,
51+ ze_command_list_handle_t /* cmdList*/ ,
52+ wait_list_view & /* waitListView*/ ) {
5153 /* nop */
5254}
5355
@@ -106,14 +108,14 @@ ur_integrated_buffer_handle_t::~ur_integrated_buffer_handle_t() {
106108
107109void *ur_integrated_buffer_handle_t ::getDevicePtr(
108110 ur_device_handle_t /* hDevice*/ , device_access_mode_t /* access*/ ,
109- size_t offset, size_t /* size*/ ,
110- std::function< void ( void *src, void *dst, size_t )> /* migrate */ ) {
111+ size_t offset, size_t /* size*/ , ze_command_list_handle_t /* cmdList */ ,
112+ wait_list_view & /* waitListView */ ) {
111113 return ur_cast<char *>(ptr.get ()) + offset;
112114}
113115
114116void *ur_integrated_buffer_handle_t ::mapHostPtr(
115117 ur_map_flags_t /* flags*/ , size_t offset, size_t /* size*/ ,
116- std::function< void ( void *src, void *dst, size_t )> /* migrate */ ) {
118+ ze_command_list_handle_t /* cmdList */ , wait_list_view & /* waitListView */ ) {
117119 // TODO: if writeBackPtr is set, we should map to that pointer
118120 // because that's what SYCL expects, SYCL will attempt to call free
119121 // on the resulting pointer leading to double free with the current
@@ -122,7 +124,8 @@ void *ur_integrated_buffer_handle_t::mapHostPtr(
122124}
123125
124126void ur_integrated_buffer_handle_t::unmapHostPtr (
125- void * /* pMappedPtr*/ , std::function<void (void *src, void *dst, size_t )>) {
127+ void * /* pMappedPtr*/ , ze_command_list_handle_t /* cmdList*/ ,
128+ wait_list_view & /* waitListView*/ ) {
126129 // TODO: if writeBackPtr is set, we should copy the data back
127130 /* nop */
128131}
@@ -250,8 +253,8 @@ void *ur_discrete_buffer_handle_t::getActiveDeviceAlloc(size_t offset) {
250253
251254void *ur_discrete_buffer_handle_t ::getDevicePtr(
252255 ur_device_handle_t hDevice, device_access_mode_t /* access*/ , size_t offset,
253- size_t /* size*/ ,
254- std::function< void ( void *src, void *dst, size_t )> /* migrate */ ) {
256+ size_t /* size*/ , ze_command_list_handle_t /* cmdList */ ,
257+ wait_list_view & /* waitListView */ ) {
255258 TRACK_SCOPE_LATENCY (" ur_discrete_buffer_handle_t::getDevicePtr" );
256259
257260 if (!activeAllocationDevice) {
@@ -283,9 +286,22 @@ void *ur_discrete_buffer_handle_t::getDevicePtr(
283286 return getActiveDeviceAlloc (offset);
284287}
285288
286- void *ur_discrete_buffer_handle_t ::mapHostPtr(
287- ur_map_flags_t flags, size_t offset, size_t size,
288- std::function<void (void *src, void *dst, size_t )> migrate) {
289+ static void migrateMemory (ze_command_list_handle_t cmdList, void *src,
290+ void *dst, size_t size,
291+ wait_list_view &waitListView) {
292+ if (!cmdList) {
293+ throw UR_RESULT_ERROR_INVALID_NULL_HANDLE;
294+ }
295+ ZE2UR_CALL_THROWS (zeCommandListAppendMemoryCopy,
296+ (cmdList, dst, src, size, nullptr , waitListView.num ,
297+ waitListView.handles ));
298+ waitListView.clear ();
299+ }
300+
301+ void *ur_discrete_buffer_handle_t ::mapHostPtr(ur_map_flags_t flags,
302+ size_t offset, size_t size,
303+ ze_command_list_handle_t cmdList,
304+ wait_list_view &waitListView) {
289305 TRACK_SCOPE_LATENCY (" ur_discrete_buffer_handle_t::mapHostPtr" );
290306 // TODO: use async alloc?
291307
@@ -309,15 +325,16 @@ void *ur_discrete_buffer_handle_t::mapHostPtr(
309325
310326 if (activeAllocationDevice && (flags & UR_MAP_FLAG_READ)) {
311327 auto srcPtr = getActiveDeviceAlloc (offset);
312- migrate (srcPtr, hostAllocations.back ().ptr .get (), size);
328+ migrateMemory (cmdList, srcPtr, hostAllocations.back ().ptr .get (), size,
329+ waitListView);
313330 }
314331
315332 return hostAllocations.back ().ptr .get ();
316333}
317334
318- void ur_discrete_buffer_handle_t::unmapHostPtr (
319- void *pMappedPtr ,
320- std::function< void ( void *src, void *dst, size_t )> migrate ) {
335+ void ur_discrete_buffer_handle_t::unmapHostPtr (void *pMappedPtr,
336+ ze_command_list_handle_t cmdList ,
337+ wait_list_view &waitListView ) {
321338 TRACK_SCOPE_LATENCY (" ur_discrete_buffer_handle_t::unmapHostPtr" );
322339
323340 auto hostAlloc =
@@ -341,8 +358,9 @@ void ur_discrete_buffer_handle_t::unmapHostPtr(
341358 // UR_MAP_FLAG_WRITE_INVALIDATE_REGION when there is an active device
342359 // allocation. is this correct?
343360 if (activeAllocationDevice) {
344- migrate (hostAlloc->ptr .get (), getActiveDeviceAlloc (hostAlloc->offset ),
345- hostAlloc->size );
361+ migrateMemory (cmdList, hostAlloc->ptr .get (),
362+ getActiveDeviceAlloc (hostAlloc->offset ), hostAlloc->size ,
363+ waitListView);
346364 }
347365
348366 hostAllocations.erase (hostAlloc);
@@ -361,18 +379,20 @@ ur_shared_buffer_handle_t::ur_shared_buffer_handle_t(
361379
362380void *ur_shared_buffer_handle_t ::getDevicePtr(
363381 ur_device_handle_t , device_access_mode_t , size_t offset, size_t ,
364- std::function< void ( void *src, void *dst, size_t )> ) {
382+ ze_command_list_handle_t /* cmdList */ , wait_list_view & /* waitListView */ ) {
365383 return reinterpret_cast <char *>(ptr.get ()) + offset;
366384}
367385
368- void *ur_shared_buffer_handle_t ::mapHostPtr(
369- ur_map_flags_t , size_t offset, size_t ,
370- std::function<void (void *src, void *dst, size_t )>) {
386+ void *
387+ ur_shared_buffer_handle_t ::mapHostPtr(ur_map_flags_t , size_t offset, size_t ,
388+ ze_command_list_handle_t /* cmdList*/ ,
389+ wait_list_view & /* waitListView*/ ) {
371390 return reinterpret_cast <char *>(ptr.get ()) + offset;
372391}
373392
374393void ur_shared_buffer_handle_t::unmapHostPtr (
375- void *, std::function<void (void *src, void *dst, size_t )>) {
394+ void *, ze_command_list_handle_t /* cmdList*/ ,
395+ wait_list_view & /* waitListView*/ ) {
376396 // nop
377397}
378398
@@ -403,24 +423,27 @@ ur_mem_sub_buffer_t::~ur_mem_sub_buffer_t() {
403423 ur::level_zero::urMemRelease (hParent);
404424}
405425
406- void *ur_mem_sub_buffer_t ::getDevicePtr(
407- ur_device_handle_t hDevice, device_access_mode_t access, size_t offset,
408- size_t size, std::function<void (void *src, void *dst, size_t )> migrate) {
426+ void *ur_mem_sub_buffer_t ::getDevicePtr(ur_device_handle_t hDevice,
427+ device_access_mode_t access,
428+ size_t offset, size_t size,
429+ ze_command_list_handle_t cmdList,
430+ wait_list_view &waitListView) {
409431 return hParent->getBuffer ()->getDevicePtr (
410- hDevice, access, offset + this ->offset , size, std::move (migrate) );
432+ hDevice, access, offset + this ->offset , size, cmdList, waitListView );
411433}
412434
413- void *ur_mem_sub_buffer_t ::mapHostPtr(
414- ur_map_flags_t flags, size_t offset, size_t size,
415- std::function<void (void *src, void *dst, size_t )> migrate) {
435+ void *ur_mem_sub_buffer_t ::mapHostPtr(ur_map_flags_t flags, size_t offset,
436+ size_t size,
437+ ze_command_list_handle_t cmdList,
438+ wait_list_view &waitListView) {
416439 return hParent->getBuffer ()->mapHostPtr (flags, offset + this ->offset , size,
417- std::move (migrate) );
440+ cmdList, waitListView );
418441}
419442
420- void ur_mem_sub_buffer_t::unmapHostPtr (
421- void *pMappedPtr ,
422- std::function< void ( void *src, void *dst, size_t )> migrate ) {
423- return hParent->getBuffer ()->unmapHostPtr (pMappedPtr, std::move (migrate) );
443+ void ur_mem_sub_buffer_t::unmapHostPtr (void *pMappedPtr,
444+ ze_command_list_handle_t cmdList ,
445+ wait_list_view &waitListView ) {
446+ return hParent->getBuffer ()->unmapHostPtr (pMappedPtr, cmdList, waitListView );
424447}
425448
426449ur_shared_mutex &ur_mem_sub_buffer_t ::getMutex() {
@@ -690,9 +713,10 @@ ur_result_t urMemGetNativeHandle(ur_mem_handle_t hMem,
690713
691714 std::scoped_lock<ur_shared_mutex> lock (hBuffer->getMutex ());
692715
716+ wait_list_view emptyWaitListView (nullptr , 0 );
693717 auto ptr = hBuffer->getDevicePtr (
694718 hDevice, ur_mem_buffer_t ::device_access_mode_t ::read_write, 0 ,
695- hBuffer->getSize (), nullptr );
719+ hBuffer->getSize (), nullptr , emptyWaitListView );
696720 *phNativeMem = reinterpret_cast <ur_native_handle_t >(ptr);
697721 return UR_RESULT_SUCCESS;
698722} catch (...) {
0 commit comments