Skip to content

Commit 04c18bf

Browse files
authored
Merge pull request #1219 from IntelPython/feature/async_kernel_submition
Feature/async kernel submition
2 parents 0c3620e + 707412b commit 04c18bf

File tree

11 files changed

+483
-56
lines changed

11 files changed

+483
-56
lines changed

numba_dpex/core/runtime/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ python_add_library(${PROJECT_NAME} MODULE ${SOURCES})
109109

110110
# Add SYCL to target, this must come after python_add_library()
111111
# FIXME: sources incompatible with sycl include?
112-
# add_sycl_to_target(TARGET ${PROJECT_NAME} SOURCES ${KERNEL_SOURCES})
112+
add_sycl_to_target(TARGET ${PROJECT_NAME} SOURCES ${KERNEL_SOURCES})
113113

114114
# Link the DPCTLSyclInterface library to target
115115
target_link_libraries(${PROJECT_NAME} PRIVATE DPCTLSyclInterface)

numba_dpex/core/runtime/_dpexrt_python.c

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "_queuestruct.h"
2525
#include "_usmarraystruct.h"
2626

27+
#include "experimental/nrt_reserve_meminfo.h"
2728
#include "numba/core/runtime/nrt_external.h"
2829

2930
// forward declarations
@@ -1490,6 +1491,8 @@ static PyObject *build_c_helpers_dict(void)
14901491
&DPEXRT_sycl_event_from_python);
14911492
_declpointer("DPEXRT_sycl_event_to_python", &DPEXRT_sycl_event_to_python);
14921493
_declpointer("DPEXRT_sycl_event_init", &DPEXRT_sycl_event_init);
1494+
_declpointer("DPEXRT_nrt_acquire_meminfo_and_schedule_release",
1495+
&DPEXRT_nrt_acquire_meminfo_and_schedule_release);
14931496

14941497
#undef _declpointer
14951498
return dct;
@@ -1557,6 +1560,9 @@ MOD_INIT(_dpexrt_python)
15571560
PyLong_FromVoidPtr(&DPEXRT_MemInfo_alloc));
15581561
PyModule_AddObject(m, "DPEXRT_MemInfo_fill",
15591562
PyLong_FromVoidPtr(&DPEXRT_MemInfo_fill));
1563+
PyModule_AddObject(
1564+
m, "DPEXRT_nrt_acquire_meminfo_and_schedule_release",
1565+
PyLong_FromVoidPtr(&DPEXRT_nrt_acquire_meminfo_and_schedule_release));
15601566
PyModule_AddObject(m, "c_helpers", build_c_helpers_dict());
15611567
return MOD_SUCCESS_VAL(m);
15621568
}

numba_dpex/core/runtime/context.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,3 +433,41 @@ def submit_ndrange(
433433
)
434434

435435
return ret
436+
437+
def acquire_meminfo_and_schedule_release(
438+
self, builder: llvmir.IRBuilder, args
439+
):
440+
"""Inserts LLVM IR to call nrt_acquire_meminfo_and_schedule_release.
441+
442+
DPCTLSyclEventRef
443+
DPEXRT_nrt_acquire_meminfo_and_schedule_release(
444+
NRT_api_functions *nrt,
445+
DPCTLSyclQueueRef QRef,
446+
NRT_MemInfo **meminfo_array,
447+
size_t meminfo_array_size,
448+
DPCTLSyclEventRef *depERefs,
449+
size_t nDepERefs,
450+
int *status,
451+
);
452+
453+
"""
454+
mod = builder.module
455+
456+
func_ty = llvmir.FunctionType(
457+
cgutils.voidptr_t,
458+
[
459+
cgutils.voidptr_t,
460+
cgutils.voidptr_t,
461+
cgutils.voidptr_t.as_pointer(),
462+
llvmir.IntType(64),
463+
cgutils.voidptr_t.as_pointer(),
464+
llvmir.IntType(64),
465+
llvmir.IntType(64).as_pointer(),
466+
],
467+
)
468+
fn = cgutils.get_or_insert_function(
469+
mod, func_ty, "DPEXRT_nrt_acquire_meminfo_and_schedule_release"
470+
)
471+
ret = builder.call(fn, args)
472+
473+
return ret
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// SPDX-FileCopyrightText: 2023 Intel Corporation
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
#include "nrt_reserve_meminfo.h"
6+
7+
#include "_dbg_printer.h"
8+
#include "syclinterface/dpctl_sycl_type_casters.hpp"
9+
#include <CL/sycl.hpp>
10+
11+
extern "C"
12+
{
13+
DPCTLSyclEventRef
14+
DPEXRT_nrt_acquire_meminfo_and_schedule_release(NRT_api_functions *nrt,
15+
DPCTLSyclQueueRef QRef,
16+
NRT_MemInfo **meminfo_array,
17+
size_t meminfo_array_size,
18+
DPCTLSyclEventRef *depERefs,
19+
size_t nDepERefs,
20+
int *status)
21+
{
22+
DPEXRT_DEBUG(drt_debug_print(
23+
"DPEXRT-DEBUG: scheduling nrt meminfo release.\n"););
24+
25+
using dpctl::syclinterface::unwrap;
26+
using dpctl::syclinterface::wrap;
27+
28+
sycl::queue *q = unwrap<sycl::queue>(QRef);
29+
30+
std::vector<NRT_MemInfo *> meminfo_vec(
31+
meminfo_array, meminfo_array + meminfo_array_size);
32+
33+
for (size_t i = 0; i < meminfo_array_size; ++i) {
34+
nrt->acquire(meminfo_vec[i]);
35+
}
36+
37+
DPEXRT_DEBUG(drt_debug_print("DPEXRT-DEBUG: acquired meminfo.\n"););
38+
39+
try {
40+
sycl::event ht_ev = q->submit([&](sycl::handler &cgh) {
41+
for (size_t ev_id = 0; ev_id < nDepERefs; ++ev_id) {
42+
cgh.depends_on(*(unwrap<sycl::event>(depERefs[ev_id])));
43+
}
44+
cgh.host_task([meminfo_array_size, meminfo_vec, nrt]() {
45+
for (size_t i = 0; i < meminfo_array_size; ++i) {
46+
nrt->release(meminfo_vec[i]);
47+
DPEXRT_DEBUG(
48+
drt_debug_print("DPEXRT-DEBUG: released meminfo "
49+
"from host_task.\n"););
50+
}
51+
});
52+
});
53+
54+
constexpr int result_ok = 0;
55+
56+
*status = result_ok;
57+
auto e_ptr = new sycl::event(ht_ev);
58+
return wrap<sycl::event>(e_ptr);
59+
} catch (const std::exception &e) {
60+
constexpr int result_std_exception = 1;
61+
62+
*status = result_std_exception;
63+
return nullptr;
64+
}
65+
66+
constexpr int result_other_abnormal = 2;
67+
68+
*status = result_other_abnormal;
69+
return nullptr;
70+
}
71+
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// SPDX-FileCopyrightText: 2023 Intel Corporation
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
//===----------------------------------------------------------------------===//
6+
///
7+
/// \file
8+
/// Defines dpctl style function(s) that interruct with nrt meminfo and sycl.
9+
///
10+
//===----------------------------------------------------------------------===//
11+
12+
#ifndef _EXPERIMENTAL_H_
13+
#define _EXPERIMENTAL_H_
14+
15+
#include "dpctl_capi.h"
16+
#include "numba/core/runtime/nrt_external.h"
17+
18+
#ifdef __cplusplus
19+
extern "C"
20+
{
21+
#endif
22+
23+
/*!
24+
* @brief Acquires meminfos and schedules a host task to release them.
25+
*
26+
* @param nrt NRT public API functions,
27+
* @param QRef Queue reference,
28+
* @param meminfo_array Array of meminfo pointers to perform actions on,
29+
* @param meminfo_array_size Length of meminfo_array,
30+
* @param depERefs Array of dependant events for the host task,
31+
* @param nDepERefs Length of depERefs,
32+
* @param status Variable to write status to. Same style as
33+
* dpctl,
34+
* @return {return} Event reference to the host task.
35+
*/
36+
DPCTLSyclEventRef
37+
DPEXRT_nrt_acquire_meminfo_and_schedule_release(NRT_api_functions *nrt,
38+
DPCTLSyclQueueRef QRef,
39+
NRT_MemInfo **meminfo_array,
40+
size_t meminfo_array_size,
41+
DPCTLSyclEventRef *depERefs,
42+
size_t nDepERefs,
43+
int *status);
44+
#ifdef __cplusplus
45+
}
46+
#endif
47+
48+
#endif /* _EXPERIMENTAL_H_ */

numba_dpex/dpctl_iface/_intrinsic.py

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,12 @@
55
import dpctl
66
from llvmlite.ir import IRBuilder
77
from numba import types
8-
from numba.core import cgutils, imputils
98
from numba.core.datamodel import default_manager
10-
from numba.extending import intrinsic, overload, overload_method, type_callable
9+
from numba.extending import intrinsic, overload, overload_method
1110

1211
import numba_dpex.dpctl_iface.libsyclinterface_bindings as sycl
1312
from numba_dpex.core import types as dpex_types
14-
from numba_dpex.core.runtime import context as dpexrt
13+
from numba_dpex.dpctl_iface.wrappers import wrap_event_reference
1514

1615

1716
@intrinsic
@@ -33,23 +32,8 @@ def sycl_event_create(
3332
sig = ty_event(types.void)
3433

3534
def codegen(context, builder: IRBuilder, sig, args: list):
36-
pyapi = context.get_python_api(builder)
37-
38-
event_struct_proxy = cgutils.create_struct_proxy(ty_event)(
39-
context, builder
40-
)
41-
4235
event = sycl.dpctl_event_create(builder)
43-
dpexrtCtx = dpexrt.DpexRTContext(context)
44-
45-
# Ref count after the call is equal to 1.
46-
dpexrtCtx.eventstruct_init(
47-
pyapi, event, event_struct_proxy._getpointer()
48-
)
49-
50-
event_value = event_struct_proxy._getvalue()
51-
52-
return event_value
36+
return wrap_event_reference(context, builder, event)
5337

5438
return sig, codegen
5539

numba_dpex/dpctl_iface/wrappers.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# SPDX-FileCopyrightText: 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from numba.core import cgutils
6+
7+
from numba_dpex.core.runtime import context as dpexrt
8+
from numba_dpex.core.types import DpctlSyclEvent
9+
10+
11+
def wrap_event_reference(ctx, builder, eref):
12+
"""Wrap dpctl event reference into datamodel so it can be boxed to
13+
Python."""
14+
15+
ty_event = DpctlSyclEvent()
16+
17+
pyapi = ctx.get_python_api(builder)
18+
19+
event_struct_proxy = cgutils.create_struct_proxy(ty_event)(ctx, builder)
20+
21+
# Ref count after the call is equal to 1.
22+
# TODO: get dpex RT from cached property once the PR is merged
23+
# https://github.com/IntelPython/numba-dpex/pull/1027
24+
# ctx.dpexrt.eventstruct_init( # noqa: W0621
25+
dpexrt.DpexRTContext(ctx).eventstruct_init(
26+
pyapi,
27+
eref,
28+
# calling _<method>() is by numba's design
29+
event_struct_proxy._getpointer(), # pylint: disable=W0212
30+
)
31+
32+
# calling _<method>() is by numba's design
33+
event_value = event_struct_proxy._getvalue() # pylint: disable=W0212
34+
35+
return event_value

numba_dpex/experimental/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from .decorators import kernel
1212
from .kernel_dispatcher import KernelDispatcher
13-
from .launcher import call_kernel
13+
from .launcher import call_kernel, call_kernel_async
1414
from .models import *
1515
from .types import KernelDispatcherType
1616

@@ -26,4 +26,4 @@ def dpex_dispatcher_const(context):
2626
return context.get_dummy_value()
2727

2828

29-
__all__ = ["kernel", "KernelDispatcher", "call_kernel"]
29+
__all__ = ["kernel", "KernelDispatcher", "call_kernel", "call_kernel_async"]

numba_dpex/experimental/kernel_dispatcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def get_overload_device_ir(self, sig):
254254
args, _ = sigutils.normalize_signature(sig)
255255
return self.overloads[tuple(args)].kernel_device_ir_module
256256

257-
def compile(self, sig) -> _KernelCompileResult:
257+
def compile(self, sig) -> any:
258258
disp = self._get_dispatcher_for_current_target()
259259
if disp is not self:
260260
return disp.compile(sig)

0 commit comments

Comments
 (0)