Skip to content

Commit f54cdc1

Browse files
author
Diptorup Deb
authored
Merge pull request #1032 from IntelPython/feature/support_sycl_queue_in_array_ctors
Feature/support sycl queue in array constructor functions
2 parents c359fa4 + 50dcaec commit f54cdc1

20 files changed

+1637
-529
lines changed

numba_dpex/core/runtime/_dpexrt_python.c

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,15 @@ static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
3939
bool dest_is_float,
4040
bool value_is_float,
4141
int64_t value,
42-
const char *device);
42+
const DPCTLSyclQueueRef qref);
4343
static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(PyObject *ndarrobj,
4444
void *data,
4545
npy_intp nitems,
4646
npy_intp itemsize,
4747
DPCTLSyclQueueRef qref);
48+
static NRT_MemInfo *DPEXRT_MemInfo_alloc(npy_intp size,
49+
size_t usm_type,
50+
const DPCTLSyclQueueRef qref);
4851
static void usmndarray_meminfo_dtor(void *ptr, size_t size, void *info);
4952
static PyObject *box_from_arystruct_parent(arystruct_t *arystruct,
5053
int ndim,
@@ -477,17 +480,23 @@ static NRT_MemInfo *NRT_MemInfo_new_from_usmndarray(PyObject *ndarrobj,
477480
* @param size The size of memory (data) owned by the NRT_MemInfo
478481
* object.
479482
* @param usm_type The usm type of the memory.
480-
* @param device The device on which the memory was allocated.
483+
* @param qref The sycl queue on which the memory was allocated. Note
484+
* that the ownership of the qref object is passed to
485+
* the NRT_MemInfo. As such, it is the caller's
486+
* responsibility to ensure the qref is nt owned by any
487+
* other object and is not deallocated. For such cases,
488+
* the caller should copy the DpctlSyclQueueRef and
489+
* pass a copy of the original qref.
481490
* @return {return} A new NRT_MemInfo object, NULL if no NRT_MemInfo
482491
* object could be created.
483492
*/
484-
static NRT_MemInfo *
485-
DPEXRT_MemInfo_alloc(npy_intp size, size_t usm_type, const char *device)
493+
static NRT_MemInfo *DPEXRT_MemInfo_alloc(npy_intp size,
494+
size_t usm_type,
495+
const DPCTLSyclQueueRef qref)
486496
{
487497
NRT_MemInfo *mi = NULL;
488498
NRT_ExternalAllocator *ext_alloca = NULL;
489499
MemInfoDtorInfo *midtor_info = NULL;
490-
DPCTLSyclQueueRef qref = NULL;
491500

492501
DPEXRT_DEBUG(drt_debug_print(
493502
"DPEXRT-DEBUG: Inside DPEXRT_MemInfo_alloc %s, line %d\n", __FILE__,
@@ -499,15 +508,6 @@ DPEXRT_MemInfo_alloc(npy_intp size, size_t usm_type, const char *device)
499508
goto error;
500509
}
501510

502-
if (!(qref = (DPCTLSyclQueueRef)DPEXRTQueue_CreateFromFilterString(device)))
503-
{
504-
DPEXRT_DEBUG(
505-
drt_debug_print("DPEXRT-ERROR: Could not create a sycl::queue from "
506-
"filter string: %s at %s %d.\n",
507-
device, __FILE__, __LINE__));
508-
goto error;
509-
}
510-
511511
// Allocate a new NRT_ExternalAllocator
512512
if (!(ext_alloca = NRT_ExternalAllocator_new_for_usm(qref, usm_type)))
513513
goto error;
@@ -520,15 +520,22 @@ DPEXRT_MemInfo_alloc(npy_intp size, size_t usm_type, const char *device)
520520
mi->dtor_info = midtor_info;
521521
mi->data = ext_alloca->malloc(size, qref);
522522

523+
DPEXRT_DEBUG(
524+
DPCTLSyclDeviceRef device_ref; device_ref = DPCTLQueue_GetDevice(qref);
525+
drt_debug_print(
526+
"DPEXRT-DEBUG: DPEXRT_MemInfo_alloc, device info in %s at %d:\n%s",
527+
__FILE__, __LINE__, DPCTLDeviceMgr_GetDeviceInfoStr(device_ref));
528+
DPCTLDevice_Delete(device_ref););
529+
523530
if (mi->data == NULL)
524531
goto error;
525532

526533
mi->size = size;
527534
mi->external_allocator = ext_alloca;
528535
DPEXRT_DEBUG(drt_debug_print(
529536
"DPEXRT-DEBUG: DPEXRT_MemInfo_alloc mi=%p "
530-
"external_allocator=%p for usm_type %zu on device %s, %s at %d\n",
531-
mi, ext_alloca, usm_type, device, __FILE__, __LINE__));
537+
"external_allocator=%p for usm_type=%zu on queue=%p, %s at %d\n",
538+
mi, ext_alloca, usm_type, DPCTLQueue_Hash(qref), __FILE__, __LINE__));
532539

533540
return mi;
534541

@@ -551,7 +558,7 @@ DPEXRT_MemInfo_alloc(npy_intp size, size_t usm_type, const char *device)
551558
* @param dest_is_float True if the destination array's dtype is float.
552559
* @param value_is_float True if the value to be filled is float.
553560
* @param value The value to be used to fill an array.
554-
* @param device The device on which the memory was allocated.
561+
* @param qref The queue on which the memory was allocated.
555562
* @return NRT_MemInfo* A new NRT_MemInfo object, NULL if no NRT_MemInfo
556563
* object could be created.
557564
*/
@@ -560,9 +567,8 @@ static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
560567
bool dest_is_float,
561568
bool value_is_float,
562569
int64_t value,
563-
const char *device)
570+
const DPCTLSyclQueueRef qref)
564571
{
565-
DPCTLSyclQueueRef qref = NULL;
566572
DPCTLSyclEventRef eref = NULL;
567573
size_t count = 0, size = 0, exp = 0;
568574

@@ -603,9 +609,6 @@ static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
603609
goto error;
604610
}
605611

606-
if (!(qref = (DPCTLSyclQueueRef)DPEXRTQueue_CreateFromFilterString(device)))
607-
goto error;
608-
609612
switch (exp) {
610613
case 3:
611614
{
@@ -621,7 +624,7 @@ static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
621624
}
622625
else if (!dest_is_float && value_is_float) {
623626
double *p = (double *)&value;
624-
bc.i64_ = *p;
627+
bc.i64_ = (int64_t)*p;
625628
}
626629
else {
627630
bc.i64_ = value;
@@ -635,7 +638,7 @@ static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
635638
{
636639
if (dest_is_float && value_is_float) {
637640
double *p = (double *)(&value);
638-
bc.f_ = *p;
641+
bc.f_ = (float)*p;
639642
}
640643
else if (dest_is_float && !value_is_float) {
641644
// To stop warning: dereferencing type-punned pointer
@@ -645,7 +648,7 @@ static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
645648
}
646649
else if (!dest_is_float && value_is_float) {
647650
double *p = (double *)&value;
648-
bc.i32_ = *p;
651+
bc.i32_ = (int32_t)*p;
649652
}
650653
else {
651654
bc.i32_ = (int32_t)value;
@@ -662,7 +665,7 @@ static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
662665

663666
if (value_is_float) {
664667
double *p = (double *)&value;
665-
bc.i16_ = *p;
668+
bc.i16_ = (int16_t)*p;
666669
}
667670
else {
668671
bc.i16_ = (int16_t)value;
@@ -679,7 +682,7 @@ static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
679682

680683
if (value_is_float) {
681684
double *p = (double *)&value;
682-
bc.i8_ = *p;
685+
bc.i8_ = (int8_t)*p;
683686
}
684687
else {
685688
bc.i8_ = (int8_t)value;
@@ -694,8 +697,6 @@ static NRT_MemInfo *DPEXRT_MemInfo_fill(NRT_MemInfo *mi,
694697
}
695698

696699
DPCTLEvent_Wait(eref);
697-
698-
DPCTLQueue_Delete(qref);
699700
DPCTLEvent_Delete(eref);
700701

701702
return mi;
@@ -1198,6 +1199,14 @@ static int DPEXRT_sycl_queue_from_python(PyObject *obj,
11981199
goto error;
11991200
}
12001201

1202+
DPEXRT_DEBUG(DPCTLSyclDeviceRef device_ref;
1203+
device_ref = DPCTLQueue_GetDevice(queue_ref);
1204+
drt_debug_print("DPEXRT-DEBUG: DPEXRT_sycl_queue_from_python, "
1205+
"device info in %s at %d:\n%s",
1206+
__FILE__, __LINE__,
1207+
DPCTLDeviceMgr_GetDeviceInfoStr(device_ref));
1208+
DPCTLDevice_Delete(device_ref););
1209+
12011210
queue_struct->parent = obj;
12021211
queue_struct->queue_ref = queue_ref;
12031212

numba_dpex/core/runtime/context.py

Lines changed: 76 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,20 @@ def _check_null_result(func):
2020
@functools.wraps(func)
2121
def wrap(self, builder, *args, **kwargs):
2222
memptr = func(self, builder, *args, **kwargs)
23-
msg = "USM allocation failed. Check the usm_type and filter "
24-
"string values."
23+
msg = "USM allocation failed. Check the usm_type and queue."
2524
cgutils.guard_memory_error(self._context, builder, memptr, msg=msg)
2625
return memptr
2726

2827
return wrap
2928

3029
@_check_null_result
31-
def meminfo_alloc(self, builder, size, usm_type, device):
30+
def meminfo_alloc(self, builder, size, usm_type, queue_ref):
3231
"""
3332
Wrapper to call :func:`~context.DpexRTContext.meminfo_alloc_unchecked`
3433
with null checking of the returned value.
3534
"""
36-
return self.meminfo_alloc_unchecked(builder, size, usm_type, device)
35+
36+
return self.meminfo_alloc_unchecked(builder, size, usm_type, queue_ref)
3737

3838
@_check_null_result
3939
def meminfo_fill(
@@ -44,7 +44,7 @@ def meminfo_fill(
4444
dest_is_float,
4545
value_is_float,
4646
value,
47-
device,
47+
queue_ref,
4848
):
4949
"""
5050
Wrapper to call :func:`~context.DpexRTContext.meminfo_fill_unchecked`
@@ -57,28 +57,34 @@ def meminfo_fill(
5757
dest_is_float,
5858
value_is_float,
5959
value,
60-
device,
60+
queue_ref,
6161
)
6262

63-
def meminfo_alloc_unchecked(self, builder, size, usm_type, device):
63+
def meminfo_alloc_unchecked(self, builder, size, usm_type, queue_ref):
6464
"""Allocate a new MemInfo with a data payload of `size` bytes.
6565
6666
The result of the call is checked and if it is NULL, i.e. allocation
6767
failed, then a MemoryError is raised. If the allocation succeeded then
6868
a pointer to the MemInfo is returned.
6969
7070
Args:
71-
builder (_type_): LLVM IR builder
72-
size (_type_): LLVM uint64 Value specifying the size in bytes for
73-
the data payload.
74-
usm_type (_type_): An LLVM Constant Value specifying the type of the
75-
usm allocator. The constant value should match the values in
76-
``dpctl's`` ``libsyclinterface::DPCTLSyclUSMType`` enum.
77-
device (_type_): An LLVM ArrayType storing a const string for a
78-
DPC++ filter selector string.
79-
80-
Returns: A pointer to the MemInfo is returned.
71+
builder (`llvmlite.ir.builder.IRBuilder`): LLVM IR builder.
72+
size (`llvmlite.ir.values.Argument`): LLVM uint64 value specifying
73+
the size in bytes for the data payload, i.e. i64 %"arg.allocsize"
74+
usm_type (`llvmlite.ir.values.Argument`): An LLVM Argument object
75+
specifying the type of the usm allocator. The constant value
76+
should match the values in
77+
``dpctl's`` ``libsyclinterface::DPCTLSyclUSMType`` enum,
78+
i.e. i64 %"arg.usm_type".
79+
queue_ref (`llvmlite.ir.values.Argument`): An LLVM argument value storing
80+
the pointer to the address of the queue object, the object can be
81+
`dpctl.SyclQueue()`, i.e. i8* %"arg.queue".
82+
83+
Returns:
84+
ret (`llvmlite.ir.instructions.CallInstr`): A pointer to the `MemInfo`
85+
is returned from the `DPEXRT_MemInfo_alloc` C function call.
8186
"""
87+
8288
mod = builder.module
8389
u64 = llvmir.IntType(64)
8490
fnty = llvmir.FunctionType(
@@ -87,7 +93,7 @@ def meminfo_alloc_unchecked(self, builder, size, usm_type, device):
8793
fn = cgutils.get_or_insert_function(mod, fnty, "DPEXRT_MemInfo_alloc")
8894
fn.return_value.add_attribute("noalias")
8995

90-
ret = builder.call(fn, [size, usm_type, device])
96+
ret = builder.call(fn, [size, usm_type, queue_ref])
9197

9298
return ret
9399

@@ -99,7 +105,7 @@ def meminfo_fill_unchecked(
99105
dest_is_float,
100106
value_is_float,
101107
value,
102-
device,
108+
queue_ref,
103109
):
104110
"""Fills an allocated `MemInfo` with the value specified.
105111
@@ -108,17 +114,29 @@ def meminfo_fill_unchecked(
108114
is succeeded then a pointer to the `MemInfo` is returned.
109115
110116
Args:
111-
builder (llvmlite.ir.builder.IRBuilder): LLVM IR builder
112-
meminfo (llvmlite.ir.instructions.LoadInstr): LLVM uint64 value
117+
builder (`llvmlite.ir.builder.IRBuilder`): LLVM IR builder.
118+
meminfo (`llvmlite.ir.instructions.LoadInstr`): LLVM uint64 value
113119
specifying the size in bytes for the data payload.
114-
itemsize (llvmlite.ir.values.Constant): An LLVM Constant value
120+
itemsize (`llvmlite.ir.values.Constant`): An LLVM Constant value
115121
specifying the size of the each data item allocated by the
116122
usm allocator.
117-
device (llvmlite.ir.values.FormattedConstant): An LLVM ArrayType
118-
storing a const string for a DPC++ filter selector string.
123+
dest_is_float (`llvmlite.ir.values.Constant`): An LLVM Constant
124+
value specifying if the destination array type is floating
125+
point.
126+
value_is_float (`llvmlite.ir.values.Constant`): An LLVM Constant
127+
value specifying if the input value is a floating point.
128+
value (`llvmlite.ir.values.Constant`): An LLVM Constant value
129+
specifying if the input value that will be used to fill
130+
the array.
131+
queue_ref (`llvmlite.ir.instructions.ExtractValue`): An LLVM ExtractValue
132+
instruction object to extract the pointer to the queue from the
133+
DpctlSyclQueue type, i.e. %".74" = extractvalue {i8*, i8*} %".73", 1.
119134
120-
Returns: A pointer to the `MemInfo` is returned.
135+
Returns:
136+
ret (`llvmlite.ir.instructions.CallInstr`): A pointer to the `MemInfo`
137+
is returned from the `DPEXRT_MemInfo_fill` C function call.
121138
"""
139+
122140
mod = builder.module
123141
u64 = llvmir.IntType(64)
124142
b = llvmir.IntType(1)
@@ -131,7 +149,14 @@ def meminfo_fill_unchecked(
131149

132150
ret = builder.call(
133151
fn,
134-
[meminfo, itemsize, dest_is_float, value_is_float, value, device],
152+
[
153+
meminfo,
154+
itemsize,
155+
dest_is_float,
156+
value_is_float,
157+
value,
158+
queue_ref,
159+
],
135160
)
136161

137162
return ret
@@ -154,7 +179,6 @@ def arraystruct_from_python(self, pyapi, obj, ptr):
154179

155180
def queuestruct_from_python(self, pyapi, obj, ptr):
156181
"""Calls the c function DPEXRT_sycl_queue_from_python"""
157-
158182
fnty = llvmir.FunctionType(
159183
llvmir.IntType(32), [pyapi.pyobj, pyapi.voidptr]
160184
)
@@ -164,7 +188,6 @@ def queuestruct_from_python(self, pyapi, obj, ptr):
164188
fn.args[1].add_attribute("nocapture")
165189

166190
self.error = pyapi.builder.call(fn, (obj, ptr))
167-
168191
return self.error
169192

170193
def queuestruct_to_python(self, pyapi, val):
@@ -258,7 +281,7 @@ def submit_range(
258281
"""Calls DPEXRTQueue_CreateFromFilterString to create a new sycl::queue
259282
from a given filter string.
260283
261-
Returns: A LLVM IR call inst.
284+
Returns: A DPCTLSyclQueueRef pointer.
262285
"""
263286
mod = builder.module
264287
fnty = llvmir.FunctionType(
@@ -353,3 +376,27 @@ def submit_ndrange(
353376
)
354377

355378
return ret
379+
380+
def copy_queue(self, builder, queue_ref):
381+
"""Calls DPCTLQueue_Copy to create a copy of the DpctlSyclQueueRef
382+
pointer passed in to the function.
383+
384+
Args:
385+
builder: The llvmlite.IRBuilder used to generate the LLVM IR for the
386+
call.
387+
queue_ref: An LLVM value for a DpctlSyclQueueRef pointer that will
388+
be passed to the DPCTLQueue_Copy function.
389+
390+
Returns: A DPCTLSyclQueueRef pointer.
391+
"""
392+
mod = builder.module
393+
fnty = llvmir.FunctionType(
394+
cgutils.voidptr_t,
395+
[cgutils.voidptr_t],
396+
)
397+
fn = cgutils.get_or_insert_function(mod, fnty, "DPCTLQueue_Copy")
398+
fn.return_value.add_attribute("noalias")
399+
400+
ret = builder.call(fn, [queue_ref])
401+
402+
return ret

0 commit comments

Comments
 (0)