Skip to content

Commit 40c465b

Browse files
committed
Remove local arrays in reduction remainder kernel
1 parent d5d48c0 commit 40c465b

File tree

1 file changed

+6
-18
lines changed

1 file changed

+6
-18
lines changed

numba_dpex/core/parfors/kernel_templates/reduction_template.py

Lines changed: 6 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -278,10 +278,13 @@ def _generate_kernel_stub_as_string(self):
278278
)
279279

280280
for redvar in self._redvars:
281+
rtyp = str(self._typemap[redvar])
281282
legal_redvar = self._redvars_dict[redvar]
282283
gufunc_txt += " "
283284
gufunc_txt += legal_redvar + " = "
284-
gufunc_txt += f"{self._parfor_reddict[redvar].init_val}\n"
285+
gufunc_txt += (
286+
f"dpnp.{rtyp}({self._parfor_reddict[redvar].init_val})\n"
287+
)
285288

286289
gufunc_txt += (
287290
" "
@@ -290,32 +293,17 @@ def _generate_kernel_stub_as_string(self):
290293
+ f"{self._global_size_var_name[0]} + j\n"
291294
)
292295

293-
for redvar in self._redvars:
294-
rtyp = str(self._typemap[redvar])
295-
redvar = self._redvars_dict[redvar]
296-
gufunc_txt += (
297-
" "
298-
+ f"local_sums_{redvar} = "
299-
+ f"dpex.local.array(1, dpnp.{rtyp})\n"
300-
)
301-
302296
gufunc_txt += " " + self._sentinel_name + " = 0\n"
303297

304-
for i, redvar in enumerate(self._redvars):
305-
legal_redvar = self._redvars_dict[redvar]
306-
gufunc_txt += (
307-
" " + f"local_sums_{legal_redvar}[0] = {legal_redvar}\n"
308-
)
309-
310298
for i, redvar in enumerate(self._redvars):
311299
legal_redvar = self._redvars_dict[redvar]
312300
redop = self._parfor_reddict[redvar].redop
313301
if redop == operator.iadd:
314302
gufunc_txt += f" {self._final_sum_var_name[i]}[0] += \
315-
local_sums_{legal_redvar}[0]\n"
303+
{legal_redvar}\n"
316304
elif redop == operator.imul:
317305
gufunc_txt += f" {self._final_sum_var_name[i]}[0] *= \
318-
local_sums_{legal_redvar}[0]\n"
306+
{legal_redvar}\n"
319307
else:
320308
raise NotImplementedError
321309

0 commit comments

Comments
 (0)