1313
1414#include " ../helpers/memory_helpers.hpp"
1515
16- ur_mem_handle_t_::ur_mem_handle_t_ (ur_context_handle_t hContext, size_t size)
17- : hContext(hContext), size(size) {}
16+ static ur_mem_handle_t_::device_access_mode_t
17+ getDeviceAccessMode (ur_mem_flags_t memFlag) {
18+ if (memFlag & UR_MEM_FLAG_READ_WRITE) {
19+ return ur_mem_handle_t_::device_access_mode_t ::read_write;
20+ } else if (memFlag & UR_MEM_FLAG_READ_ONLY) {
21+ return ur_mem_handle_t_::device_access_mode_t ::read_only;
22+ } else if (memFlag & UR_MEM_FLAG_WRITE_ONLY) {
23+ return ur_mem_handle_t_::device_access_mode_t ::write_only;
24+ } else {
25+ return ur_mem_handle_t_::device_access_mode_t ::read_write;
26+ }
27+ }
28+
29+ static bool isAccessCompatible (ur_mem_handle_t_::device_access_mode_t requested,
30+ ur_mem_handle_t_::device_access_mode_t actual) {
31+ return requested == actual ||
32+ actual == ur_mem_handle_t_::device_access_mode_t ::read_write;
33+ }
34+
35+ ur_mem_handle_t_::ur_mem_handle_t_ (ur_context_handle_t hContext, size_t size,
36+ device_access_mode_t accessMode)
37+ : accessMode(accessMode), hContext(hContext), size(size) {}
38+
39+ size_t ur_mem_handle_t_::getSize () const { return size; }
40+
41+ ur_shared_mutex &ur_mem_handle_t_::getMutex () { return Mutex; }
1842
1943ur_usm_handle_t_::ur_usm_handle_t_ (ur_context_handle_t hContext, size_t size,
2044 const void *ptr)
21- : ur_mem_handle_t_(hContext, size), ptr(const_cast <void *>(ptr)) {}
45+ : ur_mem_handle_t_(hContext, size, device_access_mode_t ::read_write),
46+ ptr(const_cast <void *>(ptr)) {}
2247
2348ur_usm_handle_t_::~ur_usm_handle_t_ () {}
2449
2550void *ur_usm_handle_t_::getDevicePtr (
26- ur_device_handle_t hDevice, access_mode_t access, size_t offset,
51+ ur_device_handle_t hDevice, device_access_mode_t access, size_t offset,
2752 size_t size, std::function<void (void *src, void *dst, size_t )> migrate) {
2853 std::ignore = hDevice;
2954 std::ignore = access;
@@ -34,9 +59,9 @@ void *ur_usm_handle_t_::getDevicePtr(
3459}
3560
3661void *ur_usm_handle_t_::mapHostPtr (
37- access_mode_t access , size_t offset, size_t size,
62+ ur_map_flags_t flags , size_t offset, size_t size,
3863 std::function<void (void *src, void *dst, size_t )>) {
39- std::ignore = access ;
64+ std::ignore = flags ;
4065 std::ignore = offset;
4166 std::ignore = size;
4267 return ptr;
@@ -50,8 +75,8 @@ void ur_usm_handle_t_::unmapHostPtr(
5075
5176ur_integrated_mem_handle_t ::ur_integrated_mem_handle_t (
5277 ur_context_handle_t hContext, void *hostPtr, size_t size,
53- host_ptr_action_t hostPtrAction)
54- : ur_mem_handle_t_(hContext, size) {
78+ host_ptr_action_t hostPtrAction, device_access_mode_t accessMode )
79+ : ur_mem_handle_t_(hContext, size, accessMode ) {
5580 bool hostPtrImported = false ;
5681 if (hostPtrAction == host_ptr_action_t ::import ) {
5782 hostPtrImported =
@@ -83,8 +108,9 @@ ur_integrated_mem_handle_t::ur_integrated_mem_handle_t(
83108}
84109
85110ur_integrated_mem_handle_t ::ur_integrated_mem_handle_t (
86- ur_context_handle_t hContext, void *hostPtr, size_t size, bool ownHostPtr)
87- : ur_mem_handle_t_(hContext, size) {
111+ ur_context_handle_t hContext, void *hostPtr, size_t size,
112+ device_access_mode_t accessMode, bool ownHostPtr)
113+ : ur_mem_handle_t_(hContext, size, accessMode) {
88114 this ->ptr = usm_unique_ptr_t (hostPtr, [hContext, ownHostPtr](void *ptr) {
89115 if (!ownHostPtr) {
90116 return ;
@@ -97,7 +123,7 @@ ur_integrated_mem_handle_t::ur_integrated_mem_handle_t(
97123}
98124
99125void *ur_integrated_mem_handle_t ::getDevicePtr(
100- ur_device_handle_t hDevice, access_mode_t access, size_t offset,
126+ ur_device_handle_t hDevice, device_access_mode_t access, size_t offset,
101127 size_t size, std::function<void (void *src, void *dst, size_t )> migrate) {
102128 std::ignore = hDevice;
103129 std::ignore = access;
@@ -108,9 +134,9 @@ void *ur_integrated_mem_handle_t::getDevicePtr(
108134}
109135
110136void *ur_integrated_mem_handle_t ::mapHostPtr(
111- access_mode_t access , size_t offset, size_t size,
137+ ur_map_flags_t flags , size_t offset, size_t size,
112138 std::function<void (void *src, void *dst, size_t )> migrate) {
113- std::ignore = access ;
139+ std::ignore = flags ;
114140 std::ignore = offset;
115141 std::ignore = size;
116142 std::ignore = migrate;
@@ -178,9 +204,10 @@ ur_discrete_mem_handle_t::migrateBufferTo(ur_device_handle_t hDevice, void *src,
178204 return UR_RESULT_SUCCESS;
179205}
180206
181- ur_discrete_mem_handle_t ::ur_discrete_mem_handle_t (ur_context_handle_t hContext,
182- void *hostPtr, size_t size)
183- : ur_mem_handle_t_(hContext, size),
207+ ur_discrete_mem_handle_t ::ur_discrete_mem_handle_t (
208+ ur_context_handle_t hContext, void *hostPtr, size_t size,
209+ device_access_mode_t accessMode)
210+ : ur_mem_handle_t_(hContext, size, accessMode),
184211 deviceAllocations (hContext->getPlatform ()->getNumDevices()),
185212 activeAllocationDevice(nullptr ), hostAllocations() {
186213 if (hostPtr) {
@@ -189,12 +216,11 @@ ur_discrete_mem_handle_t::ur_discrete_mem_handle_t(ur_context_handle_t hContext,
189216 }
190217}
191218
192- ur_discrete_mem_handle_t ::ur_discrete_mem_handle_t (ur_context_handle_t hContext,
193- ur_device_handle_t hDevice,
194- void *devicePtr, size_t size,
195- void *writeBackMemory,
196- bool ownZePtr)
197- : ur_mem_handle_t_(hContext, size),
219+ ur_discrete_mem_handle_t ::ur_discrete_mem_handle_t (
220+ ur_context_handle_t hContext, ur_device_handle_t hDevice, void *devicePtr,
221+ size_t size, device_access_mode_t accessMode, void *writeBackMemory,
222+ bool ownZePtr)
223+ : ur_mem_handle_t_(hContext, size, accessMode),
198224 deviceAllocations(hContext->getPlatform ()->getNumDevices()),
199225 activeAllocationDevice(hDevice), writeBackPtr(writeBackMemory),
200226 hostAllocations() {
@@ -227,7 +253,7 @@ ur_discrete_mem_handle_t::~ur_discrete_mem_handle_t() {
227253}
228254
229255void *ur_discrete_mem_handle_t ::getDevicePtr(
230- ur_device_handle_t hDevice, access_mode_t access, size_t offset,
256+ ur_device_handle_t hDevice, device_access_mode_t access, size_t offset,
231257 size_t size, std::function<void (void *src, void *dst, size_t )> migrate) {
232258 TRACK_SCOPE_LATENCY (" ur_discrete_mem_handle_t::getDevicePtr" );
233259
@@ -265,19 +291,18 @@ void *ur_discrete_mem_handle_t::getDevicePtr(
265291}
266292
267293void *ur_discrete_mem_handle_t ::mapHostPtr(
268- access_mode_t access , size_t offset, size_t size,
294+ ur_map_flags_t flags , size_t offset, size_t size,
269295 std::function<void (void *src, void *dst, size_t )> migrate) {
270296 TRACK_SCOPE_LATENCY (" ur_discrete_mem_handle_t::mapHostPtr" );
271-
272297 // TODO: use async alloc?
273298
274299 void *ptr;
275300 UR_CALL_THROWS (hContext->getDefaultUSMPool ()->allocate (
276301 hContext, nullptr , nullptr , UR_USM_TYPE_HOST, size, &ptr));
277302
278- hostAllocations.emplace_back (ptr, size, offset, access );
303+ hostAllocations.emplace_back (ptr, size, offset, flags );
279304
280- if (activeAllocationDevice && access != access_mode_t ::write_only ) {
305+ if (activeAllocationDevice && (flags & UR_MAP_FLAG_READ) ) {
281306 auto srcPtr =
282307 ur_cast<char *>(
283308 deviceAllocations[activeAllocationDevice->Id .value ()].get ()) +
@@ -301,10 +326,11 @@ void ur_discrete_mem_handle_t::unmapHostPtr(
301326 ur_cast<char *>(
302327 deviceAllocations[activeAllocationDevice->Id .value ()].get ()) +
303328 hostAllocation.offset ;
304- } else if (hostAllocation.access != access_mode_t ::write_invalidate) {
305- devicePtr = ur_cast<char *>(
306- getDevicePtr (hContext->getDevices ()[0 ], access_mode_t ::read_only,
307- hostAllocation.offset , hostAllocation.size , migrate));
329+ } else if (!(hostAllocation.flags &
330+ UR_MAP_FLAG_WRITE_INVALIDATE_REGION)) {
331+ devicePtr = ur_cast<char *>(getDevicePtr (
332+ hContext->getDevices ()[0 ], device_access_mode_t ::read_only,
333+ hostAllocation.offset , hostAllocation.size , migrate));
308334 }
309335
310336 if (devicePtr) {
@@ -332,6 +358,46 @@ static bool useHostBuffer(ur_context_handle_t hContext) {
332358 ZE_DEVICE_PROPERTY_FLAG_INTEGRATED;
333359}
334360
361+ namespace ur ::level_zero {
362+ ur_result_t urMemRetain (ur_mem_handle_t hMem);
363+ ur_result_t urMemRelease (ur_mem_handle_t hMem);
364+ } // namespace ur::level_zero
365+
366+ ur_mem_sub_buffer_t ::ur_mem_sub_buffer_t (ur_mem_handle_t hParent, size_t offset,
367+ size_t size,
368+ device_access_mode_t accessMode)
369+ : ur_mem_handle_t_(hParent->getContext (), size, accessMode),
370+ hParent(hParent), offset(offset), size(size) {
371+ ur::level_zero::urMemRetain (hParent);
372+ }
373+
374+ ur_mem_sub_buffer_t ::~ur_mem_sub_buffer_t () {
375+ ur::level_zero::urMemRelease (hParent);
376+ }
377+
378+ void *ur_mem_sub_buffer_t ::getDevicePtr(
379+ ur_device_handle_t hDevice, device_access_mode_t access, size_t offset,
380+ size_t size, std::function<void (void *src, void *dst, size_t )> migrate) {
381+ return hParent->getDevicePtr (hDevice, access, offset + this ->offset , size,
382+ migrate);
383+ }
384+
385+ void *ur_mem_sub_buffer_t ::mapHostPtr(
386+ ur_map_flags_t flags, size_t offset, size_t size,
387+ std::function<void (void *src, void *dst, size_t )> migrate) {
388+ return hParent->mapHostPtr (flags, offset + this ->offset , size, migrate);
389+ }
390+
391+ void ur_mem_sub_buffer_t::unmapHostPtr (
392+ void *pMappedPtr,
393+ std::function<void (void *src, void *dst, size_t )> migrate) {
394+ return hParent->unmapHostPtr (pMappedPtr, migrate);
395+ }
396+
397+ size_t ur_mem_sub_buffer_t::getSize () const { return size; }
398+
399+ ur_shared_mutex &ur_mem_sub_buffer_t ::getMutex() { return hParent->getMutex (); }
400+
335401namespace ur ::level_zero {
336402ur_result_t urMemBufferCreate (ur_context_handle_t hContext,
337403 ur_mem_flags_t flags, size_t size,
@@ -347,6 +413,7 @@ ur_result_t urMemBufferCreate(ur_context_handle_t hContext,
347413 }
348414
349415 void *hostPtr = pProperties ? pProperties->pHost : nullptr ;
416+ auto accessMode = getDeviceAccessMode (flags);
350417
351418 if (useHostBuffer (hContext)) {
352419 // TODO: assert that if hostPtr is set, either UR_MEM_FLAG_USE_HOST_POINTER
@@ -355,10 +422,11 @@ ur_result_t urMemBufferCreate(ur_context_handle_t hContext,
355422 flags & UR_MEM_FLAG_USE_HOST_POINTER
356423 ? ur_integrated_mem_handle_t ::host_ptr_action_t ::import
357424 : ur_integrated_mem_handle_t ::host_ptr_action_t ::copy;
358- *phBuffer =
359- new ur_integrated_mem_handle_t (hContext, hostPtr, size, hostPtrAction);
425+ *phBuffer = new ur_integrated_mem_handle_t (hContext, hostPtr, size,
426+ hostPtrAction, accessMode );
360427 } else {
361- *phBuffer = new ur_discrete_mem_handle_t (hContext, hostPtr, size);
428+ *phBuffer =
429+ new ur_discrete_mem_handle_t (hContext, hostPtr, size, accessMode);
362430 }
363431
364432 return UR_RESULT_SUCCESS;
@@ -368,13 +436,21 @@ ur_result_t urMemBufferPartition(ur_mem_handle_t hBuffer, ur_mem_flags_t flags,
368436 ur_buffer_create_type_t bufferCreateType,
369437 const ur_buffer_region_t *pRegion,
370438 ur_mem_handle_t *phMem) {
371- std::ignore = hBuffer;
372- std::ignore = flags;
373- std::ignore = bufferCreateType;
374- std::ignore = pRegion;
375- std::ignore = phMem;
376- logger::error (" {} function not implemented!" , __FUNCTION__);
377- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
439+ UR_ASSERT (bufferCreateType == UR_BUFFER_CREATE_TYPE_REGION,
440+ UR_RESULT_ERROR_INVALID_ENUMERATION);
441+ UR_ASSERT ((pRegion->origin < hBuffer->getSize () &&
442+ pRegion->size <= hBuffer->getSize ()),
443+ UR_RESULT_ERROR_INVALID_BUFFER_SIZE);
444+
445+ auto accessMode = getDeviceAccessMode (flags);
446+
447+ UR_ASSERT (isAccessCompatible (accessMode, hBuffer->getDeviceAccessMode ()),
448+ UR_RESULT_ERROR_INVALID_VALUE);
449+
450+ *phMem = new ur_mem_sub_buffer_t (hBuffer, pRegion->origin , pRegion->size ,
451+ accessMode);
452+
453+ return UR_RESULT_SUCCESS;
378454}
379455
380456ur_result_t urMemBufferCreateWithNativeHandle (
@@ -407,21 +483,24 @@ ur_result_t urMemBufferCreateWithNativeHandle(
407483 UR_RESULT_ERROR_INVALID_CONTEXT);
408484 }
409485
486+ // assume read-write
487+ auto accessMode = ur_mem_handle_t_::device_access_mode_t ::read_write;
488+
410489 if (useHostBuffer (hContext) && memoryAttrs.type == ZE_MEMORY_TYPE_HOST) {
411- *phMem =
412- new ur_integrated_mem_handle_t (hContext, ptr, size, ownNativeHandle);
490+ *phMem = new ur_integrated_mem_handle_t (hContext, ptr, size, accessMode,
491+ ownNativeHandle);
413492 // if useHostBuffer(hContext) is true but the allocation is on device, we'll
414493 // treat it as discrete memory
415494 } else {
416495 if (memoryAttrs.type == ZE_MEMORY_TYPE_HOST) {
417496 // For host allocation, we need to copy the data to a device buffer
418497 // and then copy it back on release
419498 *phMem = new ur_discrete_mem_handle_t (hContext, hDevice, nullptr , size,
420- ptr, ownNativeHandle);
499+ accessMode, ptr, ownNativeHandle);
421500 } else {
422501 // For device/shared allocation, we can use it directly
423- *phMem = new ur_discrete_mem_handle_t (hContext, hDevice, ptr, size,
424- nullptr , ownNativeHandle);
502+ *phMem = new ur_discrete_mem_handle_t (
503+ hContext, hDevice, ptr, size, accessMode, nullptr , ownNativeHandle);
425504 }
426505 }
427506
@@ -452,12 +531,12 @@ ur_result_t urMemGetInfo(ur_mem_handle_t hMemory, ur_mem_info_t propName,
452531}
453532
454533ur_result_t urMemRetain (ur_mem_handle_t hMem) {
455- hMem->RefCount .increment ();
534+ hMem->getRefCount () .increment ();
456535 return UR_RESULT_SUCCESS;
457536}
458537
459538ur_result_t urMemRelease (ur_mem_handle_t hMem) {
460- if (hMem->RefCount .decrementAndTest ()) {
539+ if (hMem->getRefCount () .decrementAndTest ()) {
461540 delete hMem;
462541 }
463542 return UR_RESULT_SUCCESS;
@@ -468,11 +547,11 @@ ur_result_t urMemGetNativeHandle(ur_mem_handle_t hMem,
468547 ur_native_handle_t *phNativeMem) {
469548 std::ignore = hDevice;
470549
471- std::scoped_lock<ur_shared_mutex> lock (hMem->Mutex );
550+ std::scoped_lock<ur_shared_mutex> lock (hMem->getMutex () );
472551
473- auto ptr =
474- hMem-> getDevicePtr ( nullptr , ur_mem_handle_t_::access_mode_t ::read_write,
475- 0 , hMem->getSize (), nullptr );
552+ auto ptr = hMem-> getDevicePtr (
553+ nullptr , ur_mem_handle_t_::device_access_mode_t ::read_write, 0 ,
554+ hMem->getSize (), nullptr );
476555 *phNativeMem = reinterpret_cast <ur_native_handle_t >(ptr);
477556 return UR_RESULT_SUCCESS;
478557}
0 commit comments