@@ -278,10 +278,13 @@ def _generate_kernel_stub_as_string(self):
278
278
)
279
279
280
280
for redvar in self ._redvars :
281
+ rtyp = str (self ._typemap [redvar ])
281
282
legal_redvar = self ._redvars_dict [redvar ]
282
283
gufunc_txt += " "
283
284
gufunc_txt += legal_redvar + " = "
284
- gufunc_txt += f"{ self ._parfor_reddict [redvar ].init_val } \n "
285
+ gufunc_txt += (
286
+ f"dpnp.{ rtyp } ({ self ._parfor_reddict [redvar ].init_val } )\n "
287
+ )
285
288
286
289
gufunc_txt += (
287
290
" "
@@ -290,32 +293,17 @@ def _generate_kernel_stub_as_string(self):
290
293
+ f"{ self ._global_size_var_name [0 ]} + j\n "
291
294
)
292
295
293
- for redvar in self ._redvars :
294
- rtyp = str (self ._typemap [redvar ])
295
- redvar = self ._redvars_dict [redvar ]
296
- gufunc_txt += (
297
- " "
298
- + f"local_sums_{ redvar } = "
299
- + f"dpex.local.array(1, dpnp.{ rtyp } )\n "
300
- )
301
-
302
296
gufunc_txt += " " + self ._sentinel_name + " = 0\n "
303
297
304
- for i , redvar in enumerate (self ._redvars ):
305
- legal_redvar = self ._redvars_dict [redvar ]
306
- gufunc_txt += (
307
- " " + f"local_sums_{ legal_redvar } [0] = { legal_redvar } \n "
308
- )
309
-
310
298
for i , redvar in enumerate (self ._redvars ):
311
299
legal_redvar = self ._redvars_dict [redvar ]
312
300
redop = self ._parfor_reddict [redvar ].redop
313
301
if redop == operator .iadd :
314
302
gufunc_txt += f" { self ._final_sum_var_name [i ]} [0] += \
315
- local_sums_ { legal_redvar } [0] \n "
303
+ { legal_redvar } \n "
316
304
elif redop == operator .imul :
317
305
gufunc_txt += f" { self ._final_sum_var_name [i ]} [0] *= \
318
- local_sums_ { legal_redvar } [0] \n "
306
+ { legal_redvar } \n "
319
307
else :
320
308
raise NotImplementedError
321
309
0 commit comments