@@ -152,7 +152,7 @@ def _get_max_min(dtype):
152152 return f .max , f .min
153153
154154
155- def _get_reduction_res_dt (a , dtype , _out ):
155+ def _get_reduction_res_dt (a , dtype ):
156156 """Get a data type used by dpctl for result array in reduction function."""
157157
158158 if dtype is None :
@@ -1106,11 +1106,10 @@ def cumprod(a, axis=None, dtype=None, out=None):
11061106 usm_a = dpnp .get_usm_ndarray (a )
11071107
11081108 return dpnp_wrap_reduction_call (
1109- a ,
1109+ usm_a ,
11101110 out ,
11111111 dpt .cumulative_prod ,
1112- _get_reduction_res_dt ,
1113- usm_a ,
1112+ _get_reduction_res_dt (a , dtype ),
11141113 axis = axis ,
11151114 dtype = dtype ,
11161115 )
@@ -1196,11 +1195,10 @@ def cumsum(a, axis=None, dtype=None, out=None):
11961195 usm_a = dpnp .get_usm_ndarray (a )
11971196
11981197 return dpnp_wrap_reduction_call (
1199- a ,
1198+ usm_a ,
12001199 out ,
12011200 dpt .cumulative_sum ,
1202- _get_reduction_res_dt ,
1203- usm_a ,
1201+ _get_reduction_res_dt (a , dtype ),
12041202 axis = axis ,
12051203 dtype = dtype ,
12061204 )
@@ -1281,11 +1279,10 @@ def cumulative_prod(
12811279 """
12821280
12831281 return dpnp_wrap_reduction_call (
1284- x ,
1282+ dpnp . get_usm_ndarray ( x ) ,
12851283 out ,
12861284 dpt .cumulative_prod ,
1287- _get_reduction_res_dt ,
1288- dpnp .get_usm_ndarray (x ),
1285+ _get_reduction_res_dt (x , dtype ),
12891286 axis = axis ,
12901287 dtype = dtype ,
12911288 include_initial = include_initial ,
@@ -1373,11 +1370,10 @@ def cumulative_sum(
13731370 """
13741371
13751372 return dpnp_wrap_reduction_call (
1376- x ,
1373+ dpnp . get_usm_ndarray ( x ) ,
13771374 out ,
13781375 dpt .cumulative_sum ,
1379- _get_reduction_res_dt ,
1380- dpnp .get_usm_ndarray (x ),
1376+ _get_reduction_res_dt (x , dtype ),
13811377 axis = axis ,
13821378 dtype = dtype ,
13831379 include_initial = include_initial ,
@@ -3524,11 +3520,10 @@ def prod(
35243520 usm_a = dpnp .get_usm_ndarray (a )
35253521
35263522 return dpnp_wrap_reduction_call (
3527- a ,
3523+ usm_a ,
35283524 out ,
35293525 dpt .prod ,
3530- _get_reduction_res_dt ,
3531- usm_a ,
3526+ _get_reduction_res_dt (a , dtype ),
35323527 axis = axis ,
35333528 dtype = dtype ,
35343529 keepdims = keepdims ,
@@ -4297,11 +4292,10 @@ def sum(
42974292
42984293 usm_a = dpnp .get_usm_ndarray (a )
42994294 return dpnp_wrap_reduction_call (
4300- a ,
4295+ usm_a ,
43014296 out ,
43024297 dpt .sum ,
4303- _get_reduction_res_dt ,
4304- usm_a ,
4298+ _get_reduction_res_dt (a , dtype ),
43054299 axis = axis ,
43064300 dtype = dtype ,
43074301 keepdims = keepdims ,
0 commit comments