Skip to content

Commit af49a9f

Browse files
authored
Fix power (#500)
* change power
1 parent 77df63f commit af49a9f

File tree

4 files changed

+61
-14
lines changed

4 files changed

+61
-14
lines changed

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ cpdef dparray dpnp_maximum(dparray array1, dparray array2)
238238
cpdef dparray dpnp_minimum(dparray array1, dparray array2)
239239
cpdef dparray dpnp_multiply(dparray array1, array2)
240240
cpdef dparray dpnp_negative(dparray array1)
241-
cpdef dparray dpnp_power(dparray array1, dparray array2)
241+
cpdef dparray dpnp_power(dparray array1, array2)
242242
cpdef dparray dpnp_remainder(dparray array1, dparray array2)
243243
cpdef dparray dpnp_sin(dparray array1)
244244
cpdef dparray dpnp_subtract(dparray array1, dparray array2)

dpnp/dpnp_algo/dpnp_algo_mathematical.pyx

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -422,8 +422,24 @@ cpdef dparray dpnp_negative(dparray array1):
422422
return result
423423

424424

425-
cpdef dparray dpnp_power(dparray x1, dparray x2):
426-
return call_fptr_2in_1out(DPNP_FN_POWER, x1, x2, x1.shape)
425+
cpdef dparray dpnp_power(dparray x1, x2):
426+
cdef dparray result
427+
if dpnp.isscalar(x2):
428+
x2_ = dpnp.array([x2])
429+
430+
types_map = {
431+
(dpnp.int32, dpnp.float64): dpnp.float64,
432+
(dpnp.int64, dpnp.float64): dpnp.float64,
433+
}
434+
435+
res_type = types_map.get((x1.dtype.type, x2_.dtype.type), x1.dtype)
436+
437+
result = dparray(x1.shape, dtype=res_type)
438+
for i in range(x1.size):
439+
result[i] = x1[i] ** x2
440+
return result
441+
else:
442+
return call_fptr_2in_1out(DPNP_FN_POWER, x1, x2, x1.shape)
427443

428444

429445
cpdef dpnp_prod(dparray x1):

dpnp/dpnp_iface_mathematical.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,17 +1183,19 @@ def power(x1, x2, **kwargs):
11831183
11841184
"""
11851185

1186-
is_x1_dparray = isinstance(x1, dparray)
1187-
is_x2_dparray = isinstance(x2, dparray)
1188-
1189-
if (not use_origin_backend(x1) and is_x1_dparray and is_x2_dparray and not kwargs):
1190-
if (x1.size != x2.size):
1191-
checker_throw_value_error("power", "size", x1.size, x2.size)
1192-
1193-
if (x1.shape != x2.shape):
1194-
checker_throw_value_error("power", "shape", x1.shape, x2.shape)
1195-
1196-
return dpnp_power(x1, x2)
1186+
if not use_origin_backend(x1):
1187+
if kwargs:
1188+
pass
1189+
elif not isinstance(x1, dparray):
1190+
pass
1191+
elif not isinstance(x2, dparray) and not dpnp.isscalar(x2):
1192+
pass
1193+
elif isinstance(x2, dparray) and x1.size != x2.size:
1194+
pass
1195+
elif isinstance(x2, dparray) and x1.shape != x2.shape:
1196+
pass
1197+
else:
1198+
return dpnp_power(x1, x2)
11971199

11981200
return call_origin(numpy.power, x1, x2, **kwargs)
11991201

tests/test_mathematical.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,35 @@ def test_nancumsum(array):
5050
numpy.testing.assert_array_equal(expected, result)
5151

5252

53+
@pytest.mark.parametrize("val_type",
54+
[numpy.float64, numpy.float32, numpy.int64, numpy.int32],
55+
ids=['numpy.float64', 'numpy.float32', 'numpy.int64', 'numpy.int32'])
56+
@pytest.mark.parametrize("data_type",
57+
[numpy.float64, numpy.float32, numpy.int64, numpy.int32],
58+
ids=['numpy.float64', 'numpy.float32', 'numpy.int64', 'numpy.int32'])
59+
@pytest.mark.parametrize("val",
60+
[0, 1, 5],
61+
ids=['0', '1', '5'])
62+
@pytest.mark.parametrize("array",
63+
[[[0, 0], [0, 0]],
64+
[[1, 2], [1, 2]],
65+
[[1, 2], [3, 4]],
66+
[[[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]]],
67+
[[[[1, 2], [3, 4]], [[1, 2], [2, 1]]], [[[1, 3], [3, 1]], [[0, 1], [1, 3]]]]],
68+
ids=['[[0, 0], [0, 0]]',
69+
'[[1, 2], [1, 2]]',
70+
'[[1, 2], [3, 4]]',
71+
'[[[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]]]',
72+
'[[[[1, 2], [3, 4]], [[1, 2], [2, 1]]], [[[1, 3], [3, 1]], [[0, 1], [1, 3]]]]'])
73+
def test_power(array, val, data_type, val_type):
74+
a = numpy.array(array, dtype=data_type)
75+
ia = inp.array(a)
76+
val_ = val_type(val)
77+
result = inp.power(ia, val_)
78+
expected = numpy.power(ia, val_)
79+
numpy.testing.assert_array_equal(expected, result)
80+
81+
5382
class TestEdiff1d:
5483

5584
def test_ediff1d_int(self):

0 commit comments

Comments
 (0)