Skip to content

Commit 23e78a5

Browse files
committed
Remove code duplicates from parfor lowerer
1 parent f4b618f commit 23e78a5

File tree

1 file changed

+42
-87
lines changed

1 file changed

+42
-87
lines changed

numba_dpex/core/parfors/parfor_lowerer.py

Lines changed: 42 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from ..exceptions import UnsupportedParforError
2424
from ..types.dpnp_ndarray_type import DpnpNdArray
25-
from .kernel_builder import create_kernel_for_parfor
25+
from .kernel_builder import ParforKernel, create_kernel_for_parfor
2626
from .reduction_kernel_builder import (
2727
create_reduction_main_kernel_for_parfor,
2828
create_reduction_remainder_kernel_for_parfor,
@@ -91,7 +91,9 @@ class ParforLowerImpl:
9191
for a parfor and submits it to a queue.
9292
"""
9393

94-
def _build_kernel_arglist(self, kernel_fn, lowerer, kernel_builder):
94+
def _build_kernel_arglist(
95+
self, kernel_fn, lowerer, kernel_builder: KernelLaunchIRBuilder
96+
):
9597
"""Creates local variables for all the arguments and the argument types
9698
that are passed to the kernel function.
9799
@@ -139,27 +141,15 @@ def _build_kernel_arglist(self, kernel_fn, lowerer, kernel_builder):
139141
arg_types=args_ty_list,
140142
)
141143

142-
def _submit_parfor_kernel(
144+
def _loop_ranges(
143145
self,
144146
lowerer,
145-
kernel_fn,
146147
loop_ranges,
147148
):
148-
"""
149-
Adds a call to submit a kernel function into the function body of the
150-
current Numba JIT compiled function.
151-
"""
152-
# Ensure that the Python arguments are kept alive for the duration of
153-
# the kernel execution
154-
keep_alive_kernels.append(kernel_fn.kernel)
155-
kernel_builder = KernelLaunchIRBuilder(lowerer.context, lowerer.builder)
156-
157-
ptr_to_queue_ref = kernel_builder.get_queue(exec_queue=kernel_fn.queue)
158-
args = self._build_kernel_arglist(kernel_fn, lowerer, kernel_builder)
159-
160149
# Create a global range over which to submit the kernel based on the
161150
# loop_ranges of the parfor
162151
global_range = []
152+
163153
# SYCL ranges can have at most 3 dimensions. If the parfor is of a higher
164154
# dimension then the indexing for the higher dimensions is done inside
165155
# the kernel.
@@ -176,45 +166,13 @@ def _submit_parfor_kernel(
176166

177167
local_range = []
178168

179-
kernel_ref_addr = kernel_fn.kernel.addressof_ref()
180-
kernel_ref = lowerer.builder.inttoptr(
181-
lowerer.context.get_constant(types.uintp, kernel_ref_addr),
182-
cgutils.voidptr_t,
183-
)
184-
curr_queue_ref = lowerer.builder.load(ptr_to_queue_ref)
185-
186-
# Submit a synchronous kernel
187-
kernel_builder.submit_sycl_kernel(
188-
sycl_kernel_ref=kernel_ref,
189-
sycl_queue_ref=curr_queue_ref,
190-
total_kernel_args=args.num_flattened_args,
191-
arg_list=args.arg_vals,
192-
arg_ty_list=args.arg_types,
193-
global_range=global_range,
194-
local_range=local_range,
195-
)
196-
197-
# At this point we can free the DPCTLSyclQueueRef (curr_queue)
198-
kernel_builder.free_queue(ptr_to_sycl_queue_ref=ptr_to_queue_ref)
169+
return global_range, local_range
199170

200-
def _submit_reduction_main_parfor_kernel(
171+
def _reduction_ranges(
201172
self,
202173
lowerer,
203-
kernel_fn,
204174
reductionHelper=None,
205175
):
206-
"""
207-
Adds a call to submit the main kernel of a parfor reduction into the
208-
function body of the current Numba JIT compiled function.
209-
"""
210-
# Ensure that the Python arguments are kept alive for the duration of
211-
# the kernel execution
212-
keep_alive_kernels.append(kernel_fn.kernel)
213-
kernel_builder = KernelLaunchIRBuilder(lowerer.context, lowerer.builder)
214-
215-
ptr_to_queue_ref = kernel_builder.get_queue(exec_queue=kernel_fn.queue)
216-
217-
args = self._build_kernel_arglist(kernel_fn, lowerer, kernel_builder)
218176
# Create a global range over which to submit the kernel based on the
219177
# loop_ranges of the parfor
220178
global_range = []
@@ -228,54 +186,39 @@ def _submit_reduction_main_parfor_kernel(
228186
_load_range(lowerer, reductionHelper.work_group_size)
229187
)
230188

231-
kernel_ref_addr = kernel_fn.kernel.addressof_ref()
232-
kernel_ref = lowerer.builder.inttoptr(
233-
lowerer.context.get_constant(types.uintp, kernel_ref_addr),
234-
cgutils.voidptr_t,
235-
)
236-
curr_queue_ref = lowerer.builder.load(ptr_to_queue_ref)
189+
return global_range, local_range
237190

238-
# Submit a synchronous kernel
239-
kernel_builder.submit_sycl_kernel(
240-
sycl_kernel_ref=kernel_ref,
241-
sycl_queue_ref=curr_queue_ref,
242-
total_kernel_args=args.num_flattened_args,
243-
arg_list=args.arg_vals,
244-
arg_ty_list=args.arg_types,
245-
global_range=global_range,
246-
local_range=local_range,
247-
)
191+
def _remainder_ranges(self, lowerer):
192+
# Create a global range over which to submit the kernel based on the
193+
# loop_ranges of the parfor
194+
global_range = []
248195

249-
# At this point we can free the DPCTLSyclQueueRef (curr_queue)
250-
kernel_builder.free_queue(ptr_to_sycl_queue_ref=ptr_to_queue_ref)
196+
stop = _load_range(lowerer, 1)
197+
198+
global_range.append(stop)
251199

252-
def _submit_reduction_remainder_parfor_kernel(
200+
local_range = []
201+
202+
return global_range, local_range
203+
204+
def _submit_parfor_kernel(
253205
self,
254206
lowerer,
255-
kernel_fn,
207+
kernel_fn: ParforKernel,
208+
global_range,
209+
local_range,
256210
):
257211
"""
258-
Adds a call to submit the remainder kernel of a parfor reduction into
259-
the function body of the current Numba JIT compiled function.
212+
Adds a call to submit a kernel function into the function body of the
213+
current Numba JIT compiled function.
260214
"""
261215
# Ensure that the Python arguments are kept alive for the duration of
262216
# the kernel execution
263217
keep_alive_kernels.append(kernel_fn.kernel)
264-
265218
kernel_builder = KernelLaunchIRBuilder(lowerer.context, lowerer.builder)
266219

267220
ptr_to_queue_ref = kernel_builder.get_queue(exec_queue=kernel_fn.queue)
268-
269221
args = self._build_kernel_arglist(kernel_fn, lowerer, kernel_builder)
270-
# Create a global range over which to submit the kernel based on the
271-
# loop_ranges of the parfor
272-
global_range = []
273-
274-
stop = _load_range(lowerer, 1)
275-
276-
global_range.append(stop)
277-
278-
local_range = []
279222

280223
kernel_ref_addr = kernel_fn.kernel.addressof_ref()
281224
kernel_ref = lowerer.builder.inttoptr(
@@ -360,10 +303,15 @@ def _reduction_codegen(
360303
parfor_reddict,
361304
)
362305

363-
self._submit_reduction_main_parfor_kernel(
306+
global_range, local_range = self._reduction_ranges(
307+
lowerer, reductionHelperList[0]
308+
)
309+
310+
self._submit_parfor_kernel(
364311
lowerer,
365312
parfor_kernel,
366-
reductionHelperList[0],
313+
global_range,
314+
local_range,
367315
)
368316

369317
parfor_kernel = create_reduction_remainder_kernel_for_parfor(
@@ -376,9 +324,13 @@ def _reduction_codegen(
376324
reductionHelperList,
377325
)
378326

379-
self._submit_reduction_remainder_parfor_kernel(
327+
global_range, local_range = self._remainder_ranges(lowerer)
328+
329+
self._submit_parfor_kernel(
380330
lowerer,
381331
parfor_kernel,
332+
global_range,
333+
local_range,
382334
)
383335

384336
reductionKernelVar.copy_final_sum_to_host(parfor_kernel)
@@ -492,11 +444,14 @@ def _lower_parfor_as_kernel(self, lowerer, parfor):
492444
# FIXME: Make the exception more informative
493445
raise UnsupportedParforError
494446

447+
global_range, local_range = self._loop_ranges(lowerer, loop_ranges)
448+
495449
# Finally submit the kernel
496450
self._submit_parfor_kernel(
497451
lowerer,
498452
parfor_kernel,
499-
loop_ranges,
453+
global_range,
454+
local_range,
500455
)
501456

502457
# TODO: free the kernel at this point

0 commit comments

Comments
 (0)