Skip to content

Commit ad2cde1

Browse files
author
Diptorup Deb
authored
Merge pull request #1178 from IntelPython/feature/KernelDispatcher
An experimental kernel dispatcher for numba_dpex.kernel decorator
2 parents ac05c1a + 199183c commit ad2cde1

21 files changed

+1525
-108
lines changed

.github/workflows/pre-commit.yml

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,19 @@ on:
88
jobs:
99
pre-commit:
1010
runs-on: ubuntu-20.04
11+
defaults:
12+
run:
13+
shell: bash -el {0}
1114
steps:
1215
- uses: actions/checkout@v3
13-
- uses: actions/setup-python@v3
16+
- uses: conda-incubator/setup-miniconda@v2
1417
with:
1518
python-version: '3.11'
16-
- uses: pre-commit/action@v3.0.0
19+
activate-environment: "coverage"
20+
channel-priority: "disabled"
21+
environment-file: environment/pre-commit.yml
22+
- uses: actions/cache@v3
23+
with:
24+
path: ~/.cache/pre-commit
25+
key: pre-commit-3|${{ env.pythonLocation }}|${{ hashFiles('.pre-commit-config.yaml') }}
26+
- run: pre-commit run --show-diff-on-failure --color=always --all-files

.pre-commit-config.yaml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,17 @@ repos:
4747
args: ["-i"]
4848
exclude: "numba_dpex/dpnp_iface"
4949
types_or: [c++, c]
50+
- repo: local
51+
hooks:
52+
- id: pylint
53+
name: pylint
54+
entry: pylint
55+
files: ^numba_dpex/experimental
56+
language: system
57+
types: [python]
58+
require_serial: true
59+
args:
60+
[
61+
"-rn", # Only display messages
62+
"-sn", # Don't display the score
63+
]

environment/pre-commit.yml

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
name: dev
2+
channels:
3+
- dppy/label/dev
4+
- numba
5+
- intel
6+
- conda-forge
7+
- nodefaults
8+
dependencies:
9+
- libffi
10+
- gxx_linux-64
11+
- dpcpp_linux-64
12+
- numba==0.58*
13+
- dpctl
14+
- dpnp
15+
- dpcpp-llvm-spirv
16+
- opencl_rt
17+
- coverage
18+
- pytest
19+
- pytest-cov
20+
- pytest-xdist
21+
- pexpect
22+
- scikit-build>=0.15*
23+
- cmake>=3.26*
24+
- pre-commit
25+
- pylint

numba_dpex/config.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,15 @@ def __getattr__(name):
5959
"NUMBA_DPEX_DEBUGINFO", int, config.DEBUGINFO_DEFAULT
6060
)
6161

62-
# Emit LLVM assembly language format(.ll)
63-
DUMP_KERNEL_LLVM = _readenv(
64-
"NUMBA_DPEX_DUMP_KERNEL_LLVM", int, config.DUMP_OPTIMIZED
65-
)
62+
# Emit LLVM IR generated for kernel decorated function
63+
DUMP_KERNEL_LLVM = _readenv("NUMBA_DPEX_DUMP_KERNEL_LLVM", int, 0)
64+
65+
# Emit LLVM module generated to launch a kernel decorated function
66+
DUMP_KERNEL_LAUNCHER = _readenv("NUMBA_DPEX_DUMP_KERNEL_LAUNCHER", int, 0)
67+
68+
# Enables debug printf messages inside the kernel launcher module generated for
69+
# a kernel decorated function
70+
DEBUG_KERNEL_LAUNCHER = _readenv("NUMBA_DPEX_DEBUG_KERNEL_LAUNCHER", int, 0)
6671

6772
# configs for caching
6873
# To see the debug messages for the caching.

numba_dpex/core/descriptor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,13 @@ def _inherit_if_not_set(flags, options, name, default=targetconfig._NotSet):
3838
class DpexTargetOptions(CPUTargetOptions):
3939
experimental = _option_mapping("experimental")
4040
release_gil = _option_mapping("release_gil")
41+
no_compile = _option_mapping("no_compile")
4142

4243
def finalize(self, flags, options):
4344
super().finalize(flags, options)
4445
_inherit_if_not_set(flags, options, "experimental", False)
4546
_inherit_if_not_set(flags, options, "release_gil", False)
47+
_inherit_if_not_set(flags, options, "no_compile", True)
4648

4749

4850
class DpexKernelTarget(TargetDescriptor):

numba_dpex/core/exceptions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,12 @@ def __init__(self, kernel_name, *, usmarray_argnum_list) -> None:
215215
f"usm_ndarray arguments {usmarray_args} were not allocated "
216216
"on the same queue."
217217
)
218+
else:
219+
self.message = (
220+
f'Execution queue for kernel "{kernel_name}" could not '
221+
"be deduced using compute follows data programming model. The "
222+
"kernel has no USMNdArray argument."
223+
)
218224
super().__init__(self.message)
219225

220226

numba_dpex/core/parfors/parfor_lowerer.py

Lines changed: 90 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import copy
6+
from collections import namedtuple
67

78
from llvmlite import ir as llvmir
8-
from numba.core import ir, types
9+
from numba.core import cgutils, ir, types
910
from numba.parfors.parfor import (
1011
find_potential_aliases_parfor,
1112
get_parfor_outputs,
@@ -27,6 +28,12 @@
2728
create_reduction_remainder_kernel_for_parfor,
2829
)
2930

31+
_KernelArgs = namedtuple(
32+
"_KernelArgs",
33+
["num_flattened_args", "arg_vals", "arg_types"],
34+
)
35+
36+
3037
# A global list of kernels to keep the objects alive indefinitely.
3138
keep_alive_kernels = []
3239

@@ -84,21 +91,7 @@ class ParforLowerImpl:
8491
for a parfor and submits it to a queue.
8592
"""
8693

87-
def _get_exec_queue(self, kernel_fn, lowerer):
88-
"""Creates a stack variable storing the sycl queue pointer used to
89-
launch the kernel function.
90-
"""
91-
self.kernel_builder = KernelLaunchIRBuilder(
92-
lowerer.context, lowerer.builder, kernel_fn.kernel.addressof_ref()
93-
)
94-
95-
# Create a local variable storing a pointer to a DPCTLSyclQueueRef
96-
# pointer.
97-
self.curr_queue = self.kernel_builder.get_queue(
98-
exec_queue=kernel_fn.queue
99-
)
100-
101-
def _build_kernel_arglist(self, kernel_fn, lowerer):
94+
def _build_kernel_arglist(self, kernel_fn, lowerer, kernel_builder):
10295
"""Creates local variables for all the arguments and the argument types
10396
that are passed to the kernel function.
10497
@@ -110,39 +103,43 @@ def _build_kernel_arglist(self, kernel_fn, lowerer):
110103
AssertionError: If the LLVM IR Value for an argument defined in
111104
Numba IR is not found.
112105
"""
113-
self.num_flattened_args = 0
106+
num_flattened_args = 0
114107

115108
# Compute number of args to be passed to the kernel. Note that the
116109
# actual number of kernel arguments is greater than the count of
117110
# kernel_fn.kernel_args as arrays get flattened.
118111
for arg_type in kernel_fn.kernel_arg_types:
119112
if isinstance(arg_type, DpnpNdArray):
120113
datamodel = dpex_dmm.lookup(arg_type)
121-
self.num_flattened_args += datamodel.flattened_field_count
114+
num_flattened_args += datamodel.flattened_field_count
122115
elif arg_type == types.complex64 or arg_type == types.complex128:
123-
self.num_flattened_args += 2
116+
num_flattened_args += 2
124117
else:
125-
self.num_flattened_args += 1
118+
num_flattened_args += 1
126119

127120
# Create LLVM values for the kernel args list and kernel arg types list
128-
self.args_list = self.kernel_builder.allocate_kernel_arg_array(
129-
self.num_flattened_args
130-
)
131-
self.args_ty_list = self.kernel_builder.allocate_kernel_arg_ty_array(
132-
self.num_flattened_args
121+
args_list = kernel_builder.allocate_kernel_arg_array(num_flattened_args)
122+
args_ty_list = kernel_builder.allocate_kernel_arg_ty_array(
123+
num_flattened_args
133124
)
134125
callargs_ptrs = []
135126
for arg in kernel_fn.kernel_args:
136127
callargs_ptrs.append(_getvar(lowerer, arg))
137128

138-
self.kernel_builder.populate_kernel_args_and_args_ty_arrays(
129+
kernel_builder.populate_kernel_args_and_args_ty_arrays(
139130
kernel_argtys=kernel_fn.kernel_arg_types,
140131
callargs_ptrs=callargs_ptrs,
141-
args_list=self.args_list,
142-
args_ty_list=self.args_ty_list,
132+
args_list=args_list,
133+
args_ty_list=args_ty_list,
143134
datamodel_mgr=dpex_dmm,
144135
)
145136

137+
return _KernelArgs(
138+
num_flattened_args=num_flattened_args,
139+
arg_vals=args_list,
140+
arg_types=args_ty_list,
141+
)
142+
146143
def _submit_parfor_kernel(
147144
self,
148145
lowerer,
@@ -156,9 +153,11 @@ def _submit_parfor_kernel(
156153
# Ensure that the Python arguments are kept alive for the duration of
157154
# the kernel execution
158155
keep_alive_kernels.append(kernel_fn.kernel)
156+
kernel_builder = KernelLaunchIRBuilder(lowerer.context, lowerer.builder)
157+
158+
ptr_to_queue_ref = kernel_builder.get_queue(exec_queue=kernel_fn.queue)
159+
args = self._build_kernel_arglist(kernel_fn, lowerer, kernel_builder)
159160

160-
self._get_exec_queue(kernel_fn, lowerer)
161-
self._build_kernel_arglist(kernel_fn, lowerer)
162161
# Create a global range over which to submit the kernel based on the
163162
# loop_ranges of the parfor
164163
global_range = []
@@ -178,18 +177,26 @@ def _submit_parfor_kernel(
178177

179178
local_range = []
180179

180+
kernel_ref_addr = kernel_fn.kernel.addressof_ref()
181+
kernel_ref = lowerer.builder.inttoptr(
182+
lowerer.context.get_constant(types.uintp, kernel_ref_addr),
183+
cgutils.voidptr_t,
184+
)
185+
curr_queue_ref = lowerer.builder.load(ptr_to_queue_ref)
186+
181187
# Submit a synchronous kernel
182-
self.kernel_builder.submit_sync_kernel(
183-
self.curr_queue,
184-
self.num_flattened_args,
185-
self.args_list,
186-
self.args_ty_list,
187-
global_range,
188-
local_range,
188+
kernel_builder.submit_sycl_kernel(
189+
sycl_kernel_ref=kernel_ref,
190+
sycl_queue_ref=curr_queue_ref,
191+
total_kernel_args=args.num_flattened_args,
192+
arg_list=args.arg_vals,
193+
arg_ty_list=args.arg_types,
194+
global_range=global_range,
195+
local_range=local_range,
189196
)
190197

191198
# At this point we can free the DPCTLSyclQueueRef held in ptr_to_queue_ref
192-
self.kernel_builder.free_queue(sycl_queue_val=self.curr_queue)
199+
kernel_builder.free_queue(ptr_to_sycl_queue_ref=ptr_to_queue_ref)
193200

194201
def _submit_reduction_main_parfor_kernel(
195202
self,
@@ -204,9 +211,11 @@ def _submit_reduction_main_parfor_kernel(
204211
# Ensure that the Python arguments are kept alive for the duration of
205212
# the kernel execution
206213
keep_alive_kernels.append(kernel_fn.kernel)
214+
kernel_builder = KernelLaunchIRBuilder(lowerer.context, lowerer.builder)
215+
216+
ptr_to_queue_ref = kernel_builder.get_queue(exec_queue=kernel_fn.queue)
207217

208-
self._get_exec_queue(kernel_fn, lowerer)
209-
self._build_kernel_arglist(kernel_fn, lowerer)
218+
args = self._build_kernel_arglist(kernel_fn, lowerer, kernel_builder)
210219
# Create a global range over which to submit the kernel based on the
211220
# loop_ranges of the parfor
212221
global_range = []
@@ -220,16 +229,27 @@ def _submit_reduction_main_parfor_kernel(
220229
_load_range(lowerer, reductionHelper.work_group_size)
221230
)
222231

232+
kernel_ref_addr = kernel_fn.kernel.addressof_ref()
233+
kernel_ref = lowerer.builder.inttoptr(
234+
lowerer.context.get_constant(types.uintp, kernel_ref_addr),
235+
cgutils.voidptr_t,
236+
)
237+
curr_queue_ref = lowerer.builder.load(ptr_to_queue_ref)
238+
223239
# Submit a synchronous kernel
224-
self.kernel_builder.submit_sync_kernel(
225-
self.curr_queue,
226-
self.num_flattened_args,
227-
self.args_list,
228-
self.args_ty_list,
229-
global_range,
230-
local_range,
240+
kernel_builder.submit_sycl_kernel(
241+
sycl_kernel_ref=kernel_ref,
242+
sycl_queue_ref=curr_queue_ref,
243+
total_kernel_args=args.num_flattened_args,
244+
arg_list=args.arg_vals,
245+
arg_ty_list=args.arg_types,
246+
global_range=global_range,
247+
local_range=local_range,
231248
)
232249

250+
# At this point we can free the DPCTLSyclQueueRef held in ptr_to_queue_ref
251+
kernel_builder.free_queue(ptr_to_sycl_queue_ref=ptr_to_queue_ref)
252+
233253
def _submit_reduction_remainder_parfor_kernel(
234254
self,
235255
lowerer,
@@ -243,8 +263,11 @@ def _submit_reduction_remainder_parfor_kernel(
243263
# the kernel execution
244264
keep_alive_kernels.append(kernel_fn.kernel)
245265

246-
self._get_exec_queue(kernel_fn, lowerer)
247-
self._build_kernel_arglist(kernel_fn, lowerer)
266+
kernel_builder = KernelLaunchIRBuilder(lowerer.context, lowerer.builder)
267+
268+
ptr_to_queue_ref = kernel_builder.get_queue(exec_queue=kernel_fn.queue)
269+
270+
args = self._build_kernel_arglist(kernel_fn, lowerer, kernel_builder)
248271
# Create a global range over which to submit the kernel based on the
249272
# loop_ranges of the parfor
250273
global_range = []
@@ -255,16 +278,27 @@ def _submit_reduction_remainder_parfor_kernel(
255278

256279
local_range = []
257280

281+
kernel_ref_addr = kernel_fn.kernel.addressof_ref()
282+
kernel_ref = lowerer.builder.inttoptr(
283+
lowerer.context.get_constant(types.uintp, kernel_ref_addr),
284+
cgutils.voidptr_t,
285+
)
286+
curr_queue_ref = lowerer.builder.load(ptr_to_queue_ref)
287+
258288
# Submit a synchronous kernel
259-
self.kernel_builder.submit_sync_kernel(
260-
self.curr_queue,
261-
self.num_flattened_args,
262-
self.args_list,
263-
self.args_ty_list,
264-
global_range,
265-
local_range,
289+
kernel_builder.submit_sycl_kernel(
290+
sycl_kernel_ref=kernel_ref,
291+
sycl_queue_ref=curr_queue_ref,
292+
total_kernel_args=args.num_flattened_args,
293+
arg_list=args.arg_vals,
294+
arg_ty_list=args.arg_types,
295+
global_range=global_range,
296+
local_range=local_range,
266297
)
267298

299+
# At this point we can free the DPCTLSyclQueueRef held in ptr_to_queue_ref
300+
kernel_builder.free_queue(ptr_to_sycl_queue_ref=ptr_to_queue_ref)
301+
268302
def _reduction_codegen(
269303
self,
270304
parfor,

numba_dpex/core/parfors/reduction_helper.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -395,11 +395,7 @@ def work_group_size(self):
395395

396396
def copy_final_sum_to_host(self, parfor_kernel):
397397
lowerer = self.lowerer
398-
ir_builder = KernelLaunchIRBuilder(
399-
lowerer.context,
400-
lowerer.builder,
401-
parfor_kernel.kernel.addressof_ref(),
402-
)
398+
ir_builder = KernelLaunchIRBuilder(lowerer.context, lowerer.builder)
403399

404400
# Create a local variable storing a pointer to a DPCTLSyclQueueRef
405401
# pointer.
@@ -447,4 +443,4 @@ def copy_final_sum_to_host(self, parfor_kernel):
447443
sycl.dpctl_event_wait(builder, event_ref)
448444
sycl.dpctl_event_delete(builder, event_ref)
449445

450-
ir_builder.free_queue(sycl_queue_val=curr_queue)
446+
ir_builder.free_queue(ptr_to_sycl_queue_ref=curr_queue)

0 commit comments

Comments
 (0)