@@ -278,10 +278,13 @@ def _generate_kernel_stub_as_string(self):
278278 )
279279
280280 for redvar in self ._redvars :
281+ rtyp = str (self ._typemap [redvar ])
281282 legal_redvar = self ._redvars_dict [redvar ]
282283 gufunc_txt += " "
283284 gufunc_txt += legal_redvar + " = "
284- gufunc_txt += f"{ self ._parfor_reddict [redvar ].init_val } \n "
285+ gufunc_txt += (
286+ f"dpnp.{ rtyp } ({ self ._parfor_reddict [redvar ].init_val } )\n "
287+ )
285288
286289 gufunc_txt += (
287290 " "
@@ -290,32 +293,17 @@ def _generate_kernel_stub_as_string(self):
290293 + f"{ self ._global_size_var_name [0 ]} + j\n "
291294 )
292295
293- for redvar in self ._redvars :
294- rtyp = str (self ._typemap [redvar ])
295- redvar = self ._redvars_dict [redvar ]
296- gufunc_txt += (
297- " "
298- + f"local_sums_{ redvar } = "
299- + f"dpex.local.array(1, dpnp.{ rtyp } )\n "
300- )
301-
302296 gufunc_txt += " " + self ._sentinel_name + " = 0\n "
303297
304- for i , redvar in enumerate (self ._redvars ):
305- legal_redvar = self ._redvars_dict [redvar ]
306- gufunc_txt += (
307- " " + f"local_sums_{ legal_redvar } [0] = { legal_redvar } \n "
308- )
309-
310298 for i , redvar in enumerate (self ._redvars ):
311299 legal_redvar = self ._redvars_dict [redvar ]
312300 redop = self ._parfor_reddict [redvar ].redop
313301 if redop == operator .iadd :
314302 gufunc_txt += f" { self ._final_sum_var_name [i ]} [0] += \
315- local_sums_ { legal_redvar } [0] \n "
303+ { legal_redvar } \n "
316304 elif redop == operator .imul :
317305 gufunc_txt += f" { self ._final_sum_var_name [i ]} [0] *= \
318- local_sums_ { legal_redvar } [0] \n "
306+ { legal_redvar } \n "
319307 else :
320308 raise NotImplementedError
321309
0 commit comments