@@ -78,6 +78,12 @@ __all__ = [
78
78
79
79
include " _sycl_usm_array_interface_utils.pxi"
80
80
81
cdef extern from "_opaque_smart_ptr.hpp":
    # C-level helpers declared in _opaque_smart_ptr.hpp. Each one works on an
    # opaque ``void *`` handle and is declared ``nogil`` so that it may be
    # called without holding the GIL.
    # NOTE(review): the exact ownership/ref-counting semantics are defined in
    # the header, not here -- confirm against _opaque_smart_ptr.hpp.

    # Build a new opaque handle from a raw pointer and a queue reference.
    void * OpaqueSmartPtr_Make(void *, DPCTLSyclQueueRef) nogil
    # Produce another opaque handle from an existing one.
    void * OpaqueSmartPtr_Copy(void *) nogil
    # Dispose of an opaque handle.
    void OpaqueSmartPtr_Delete(void *) nogil
    # Retrieve the pointer associated with an opaque handle.
    void * OpaqueSmartPtr_Get(void *) nogil
86
+
81
87
class USMAllocationError (Exception ):
82
88
"""
83
89
An exception raised when Universal Shared Memory (USM) allocation
@@ -152,7 +158,8 @@ cdef class _Memory:
152
158
MemoryUSMShared, MemoryUSMDevice, MemoryUSMHost
153
159
"""
154
160
cdef _cinit_empty(self ):
155
- self .memory_ptr = NULL
161
+ self ._memory_ptr = NULL
162
+ self ._opaque_ptr = NULL
156
163
self .nbytes = 0
157
164
self .queue = None
158
165
self .refobj = None
@@ -198,7 +205,8 @@ cdef class _Memory:
198
205
)
199
206
200
207
if (p):
201
- self .memory_ptr = p
208
+ self ._memory_ptr = p
209
+ self ._opaque_ptr = OpaqueSmartPtr_Make(p, QRef)
202
210
self .nbytes = nbytes
203
211
self .queue = queue
204
212
else :
@@ -214,18 +222,22 @@ cdef class _Memory:
214
222
cdef _Memory other_mem
215
223
if isinstance (other, _Memory):
216
224
other_mem = < _Memory> other
217
- self .memory_ptr = other_mem.memory_ptr
218
225
self .nbytes = other_mem.nbytes
219
226
self .queue = other_mem.queue
220
- if other_mem.refobj is None :
221
- self .refobj = other
227
+ if other_mem._opaque_ptr is NULL :
228
+ self ._memory_ptr = other_mem._memory_ptr
229
+ self ._opaque_ptr = NULL
230
+ self .refobj = other.reference_obj
222
231
else :
223
- self .refobj = other_mem.refobj
232
+ self ._memory_ptr = other_mem._memory_ptr
233
+ self ._opaque_ptr = OpaqueSmartPtr_Copy(other_mem._opaque_ptr)
234
+ self .refobj = None
224
235
elif hasattr (other, ' __sycl_usm_array_interface__' ):
225
236
other_iface = other.__sycl_usm_array_interface__
226
237
if isinstance (other_iface, dict ):
227
238
other_buf = _USMBufferData.from_sycl_usm_ary_iface(other_iface)
228
- self .memory_ptr = other_buf.p
239
+ self ._opaque_ptr = NULL
240
+ self ._memory_ptr = < DPCTLSyclUSMRef> other_buf.p
229
241
self .nbytes = other_buf.nbytes
230
242
self .queue = other_buf.queue
231
243
self .refobj = other
@@ -241,23 +253,25 @@ cdef class _Memory:
241
253
)
242
254
243
255
def __dealloc__(self):
    # Only instances holding a non-NULL opaque handle have something to
    # release; all other instances skip straight to the field reset.
    if self._opaque_ptr is not NULL:
        OpaqueSmartPtr_Delete(self._opaque_ptr)
    # Reset every field so no stale pointer value is left behind.
    self._cinit_empty()
251
259
260
cdef DPCTLSyclUSMRef get_data_ptr(self):
    # C-level accessor for the raw USM pointer stored on this instance.
    return self._memory_ptr
262
+
263
cdef void * get_opaque_ptr(self):
    # C-level accessor for the opaque handle stored on this instance;
    # the stored value may be NULL.
    return self._opaque_ptr
265
+
252
266
cdef _getbuffer(self , Py_buffer * buffer , int flags):
253
267
# _memory_ptr is a Ref, i.e. a pointer to a SYCL type. For USM it is void*.
254
268
cdef SyclContext ctx = self ._context
255
269
cdef _usm_type UsmTy = DPCTLUSM_GetPointerType(
256
- self .memory_ptr , ctx.get_context_ref()
270
+ self ._memory_ptr , ctx.get_context_ref()
257
271
)
258
272
if UsmTy == _usm_type._USM_DEVICE:
259
273
raise ValueError (" USM Device memory is not host accessible" )
260
- buffer .buf = < char * > self .memory_ptr
274
+ buffer .buf = < void * > self ._memory_ptr
261
275
buffer .format = ' B' # byte
262
276
buffer .internal = NULL # see References
263
277
buffer .itemsize = 1
@@ -285,7 +299,7 @@ cdef class _Memory:
285
299
represented as Python integer.
286
300
"""
287
301
def __get__ (self ):
288
- return < size_t> (self .memory_ptr )
302
+ return < size_t> (self ._memory_ptr )
289
303
290
304
property _context :
291
305
""" :class:`dpctl.SyclContext` the USM pointer is bound to. """
@@ -333,7 +347,7 @@ cdef class _Memory:
333
347
.format(
334
348
self .get_usm_type(),
335
349
self .nbytes,
336
- hex (< object > (< size_t> self .memory_ptr ))
350
+ hex (< object > (< size_t> self ._memory_ptr ))
337
351
)
338
352
)
339
353
@@ -377,7 +391,7 @@ cdef class _Memory:
377
391
"""
378
392
def __get__ (self ):
379
393
cdef dict iface = {
380
- " data" : (< size_t> (< void * > self .memory_ptr ),
394
+ " data" : (< size_t> (< void * > self ._memory_ptr ),
381
395
True ), # bool(self.writable)),
382
396
" shape" : (self .nbytes,),
383
397
" strides" : None ,
@@ -402,18 +416,18 @@ cdef class _Memory:
402
416
if syclobj is None :
403
417
ctx = self ._context
404
418
return _Memory.get_pointer_type(
405
- self .memory_ptr , ctx
419
+ self ._memory_ptr , ctx
406
420
).decode(" UTF-8" )
407
421
elif isinstance (syclobj, SyclContext):
408
422
ctx = < SyclContext> (syclobj)
409
423
return _Memory.get_pointer_type(
410
- self .memory_ptr , ctx
424
+ self ._memory_ptr , ctx
411
425
).decode(" UTF-8" )
412
426
elif isinstance (syclobj, SyclQueue):
413
427
q = < SyclQueue> (syclobj)
414
428
ctx = q.get_sycl_context()
415
429
return _Memory.get_pointer_type(
416
- self .memory_ptr , ctx
430
+ self ._memory_ptr , ctx
417
431
).decode(" UTF-8" )
418
432
raise TypeError (
419
433
" syclobj keyword can be either None, or an instance of "
@@ -435,18 +449,18 @@ cdef class _Memory:
435
449
if syclobj is None :
436
450
ctx = self ._context
437
451
return _Memory.get_pointer_type_enum(
438
- self .memory_ptr , ctx
452
+ self ._memory_ptr , ctx
439
453
)
440
454
elif isinstance (syclobj, SyclContext):
441
455
ctx = < SyclContext> (syclobj)
442
456
return _Memory.get_pointer_type_enum(
443
- self .memory_ptr , ctx
457
+ self ._memory_ptr , ctx
444
458
)
445
459
elif isinstance (syclobj, SyclQueue):
446
460
q = < SyclQueue> (syclobj)
447
461
ctx = q.get_sycl_context()
448
462
return _Memory.get_pointer_type_enum(
449
- self .memory_ptr , ctx
463
+ self ._memory_ptr , ctx
450
464
)
451
465
raise TypeError (
452
466
" syclobj keyword can be either None, or an instance of "
@@ -475,8 +489,8 @@ cdef class _Memory:
475
489
# call kernel to copy from
476
490
ERef = DPCTLQueue_Memcpy(
477
491
self .queue.get_queue_ref(),
478
- < void * > & host_buf[0 ], # destination
479
- < void * > self .memory_ptr , # source
492
+ < void * > & host_buf[0 ], # destination
493
+ < void * > self ._memory_ptr , # source
480
494
< size_t> self .nbytes
481
495
)
482
496
with nogil: DPCTLEvent_Wait(ERef)
@@ -500,8 +514,8 @@ cdef class _Memory:
500
514
# call kernel to copy from
501
515
ERef = DPCTLQueue_Memcpy(
502
516
self .queue.get_queue_ref(),
503
- < void * > self .memory_ptr , # destination
504
- < void * > & host_buf[0 ], # source
517
+ < void * > self ._memory_ptr , # destination
518
+ < void * > & host_buf[0 ], # source
505
519
< size_t> buf_len
506
520
)
507
521
with nogil: DPCTLEvent_Wait(ERef)
@@ -542,16 +556,16 @@ cdef class _Memory:
542
556
if (same_contexts):
543
557
ERef = DPCTLQueue_Memcpy(
544
558
this_queue.get_queue_ref(),
545
- < void * > self .memory_ptr ,
559
+ < void * > self ._memory_ptr ,
546
560
< void * > src_buf.p,
547
561
< size_t> src_buf.nbytes
548
562
)
549
563
with nogil: DPCTLEvent_Wait(ERef)
550
564
DPCTLEvent_Delete(ERef)
551
565
else :
552
566
copy_via_host(
553
- < void * > self .memory_ptr , this_queue, # dest
554
- < void * > src_buf.p, src_queue, # src
567
+ < void * > self ._memory_ptr , this_queue, # dest
568
+ < void * > src_buf.p, src_queue, # src
555
569
< size_t> src_buf.nbytes
556
570
)
557
571
else :
@@ -565,7 +579,7 @@ cdef class _Memory:
565
579
566
580
ERef = DPCTLQueue_Memset(
567
581
self .queue.get_queue_ref(),
568
- < void * > self .memory_ptr , # destination
582
+ < void * > self ._memory_ptr , # destination
569
583
< int > val,
570
584
self .nbytes)
571
585
@@ -703,20 +717,29 @@ cdef class _Memory:
703
717
res = _Memory.__new__ (_Memory)
704
718
_mem = < _Memory> res
705
719
_mem._cinit_empty()
706
- _mem.memory_ptr = USMRef
707
720
_mem.nbytes = nbytes
708
721
QRef_copy = DPCTLQueue_Copy(QRef)
709
722
if QRef_copy is NULL :
710
723
raise ValueError (" Referenced queue could not be copied." )
711
724
try :
712
- _mem.queue = SyclQueue._create(QRef_copy) # consumes the copy
725
+ # _create steals ownership of QRef_copy
726
+ _mem.queue = SyclQueue._create(QRef_copy)
713
727
except dpctl.SyclQueueCreationError as sqce:
714
728
raise ValueError (
715
729
" SyclQueue object could not be created from "
716
730
" copy of referenced queue"
717
731
) from sqce
718
- _mem.refobj = memory_owner
719
- return mem_ty(res)
732
+ if memory_owner is None :
733
+ _mem._memory_ptr = USMRef
734
+ # assume ownership of USM allocation via smart pointer
735
+ _mem._opaque_ptr = OpaqueSmartPtr_Make(< void * > USMRef, QRef)
736
+ _mem.refobj = None
737
+ else :
738
+ _mem._memory_ptr = USMRef
739
+ _mem._opaque_ptr = NULL
740
+ _mem.refobj = memory_owner
741
+ _out = mem_ty(< object > _mem)
742
+ return _out
720
743
721
744
722
745
cdef class MemoryUSMShared(_Memory):
@@ -908,10 +931,13 @@ def as_usm_memory(obj):
908
931
format(obj)
909
932
)
910
933
934
cdef api void * Memory_GetOpaquePointer(_Memory obj):
    "Return the opaque pointer value held by `obj`"
    cdef void *opaque_handle = obj.get_opaque_ptr()
    return opaque_handle
911
937
912
938
cdef api DPCTLSyclUSMRef Memory_GetUsmPointer(_Memory obj):
    "Return the pointer of the USM allocation held by `obj`"
    cdef DPCTLSyclUSMRef usm_ref = obj.get_data_ptr()
    return usm_ref
915
941
916
942
cdef api DPCTLSyclContextRef Memory_GetContextRef(_Memory obj):
917
943
" Context reference to which USM allocation is bound"
0 commit comments