Skip to content

Commit d206e0f

Browse files
authored
Fix multiply (#502)
* start change
1 parent af49a9f commit d206e0f

File tree

3 files changed

+59
-60
lines changed

3 files changed

+59
-60
lines changed

dpnp/dpnp_algo/dpnp_algo_mathematical.pyx

Lines changed: 13 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -326,58 +326,20 @@ cpdef tuple dpnp_modf(dparray x1):
326326

327327

328328
cpdef dparray dpnp_multiply(dparray x1, x2):
329-
x2_is_scalar = dpnp.isscalar(x2)
330-
331-
x1_dtype_ = x1.dtype
332-
x2_dtype_ = type(x2) if x2_is_scalar else x2.dtype
333-
334-
types_map = {float: dpnp.float64, int: dpnp.int64}
335-
x1_dtype = types_map.get(x1_dtype_, x1_dtype_)
336-
x2_dtype = types_map.get(x2_dtype_, x2_dtype_)
337-
338-
if x1_dtype == dpnp.float64:
339-
if x2_dtype == dpnp.float64:
340-
res_type = dpnp.float64
341-
elif x2_dtype == dpnp.float32:
342-
res_type = dpnp.float64
343-
elif x2_dtype == dpnp.int64:
344-
res_type = dpnp.float64
345-
elif x2_dtype == dpnp.int32:
346-
res_type = dpnp.float64
347-
elif x1_dtype == dpnp.float32:
348-
if x2_dtype == dpnp.float64:
349-
res_type = dpnp.float32
350-
elif x2_dtype == dpnp.float32:
351-
res_type = dpnp.float32
352-
elif x2_dtype == dpnp.int64:
353-
res_type = dpnp.float32
354-
elif x2_dtype == dpnp.int32:
355-
res_type = dpnp.float32
356-
elif x1_dtype == dpnp.int64:
357-
if x2_dtype == dpnp.float64:
358-
res_type = dpnp.float64
359-
elif x2_dtype == dpnp.float32:
360-
res_type = dpnp.float32
361-
elif x2_dtype == dpnp.int64:
362-
res_type = dpnp.int64
363-
elif x2_dtype == dpnp.int32:
364-
res_type = dpnp.int64
365-
elif x1_dtype == dpnp.int32:
366-
if x2_dtype == dpnp.float64:
367-
res_type = dpnp.float64
368-
elif x2_dtype == dpnp.float32:
369-
res_type = dpnp.float32
370-
elif x2_dtype == dpnp.int64:
371-
res_type = dpnp.int32
372-
elif x2_dtype == dpnp.int32:
373-
res_type = dpnp.int32
374-
375-
cdef dparray result = dparray(x1.shape, dtype=res_type)
376-
377-
if x2_is_scalar:
378-
for i in range(result.size):
329+
cdef dparray result
330+
if dpnp.isscalar(x2):
331+
x2_ = dpnp.array([x2])
332+
333+
types_map = {
334+
(dpnp.int32, dpnp.float64): dpnp.float64,
335+
(dpnp.int64, dpnp.float64): dpnp.float64,
336+
}
337+
338+
res_type = types_map.get((x1.dtype.type, x2_.dtype.type), x1.dtype)
339+
result = dparray(x1.shape, dtype=res_type)
340+
for i in range(x1.size):
379341
result[i] = x1[i] * x2
380-
return result
342+
return result.reshape(x1.shape)
381343
else:
382344
return call_fptr_2in_1out(DPNP_FN_MULTIPLY, x1, x2, x1.shape)
383345

dpnp/dpnp_iface_mathematical.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -966,16 +966,24 @@ def multiply(x1, x2, **kwargs):
966966
is_x1_scalar = dpnp.isscalar(x1)
967967
is_x2_scalar = dpnp.isscalar(x2)
968968

969-
if (not use_origin_backend(x1) and (is_x1_dparray or is_x1_scalar)) and \
970-
(not use_origin_backend(x2) and (is_x2_dparray or is_x2_scalar)) and \
971-
not (is_x1_scalar and is_x2_scalar) and not kwargs:
972-
973-
if is_x1_scalar:
974-
return dpnp_multiply(x2, x1)
969+
if not use_origin_backend(x1):
970+
if kwargs:
971+
pass
972+
elif not (is_x1_dparray or is_x1_scalar):
973+
pass
974+
elif not (is_x2_dparray or is_x2_scalar):
975+
pass
976+
elif is_x1_scalar and is_x2_scalar:
977+
pass
978+
elif (is_x1_dparray and is_x2_dparray) and (x1.size != x2.size):
979+
pass
980+
elif (is_x1_dparray and is_x2_dparray) and (x1.shape != x2.shape):
981+
pass
975982
else:
976-
if is_x1_dparray and is_x2_dparray:
977-
if (x1.size == x2.size) and (x1.shape == x2.shape):
978-
return dpnp_multiply(x1, x2)
983+
if is_x1_scalar:
984+
return dpnp_multiply(x2, x1)
985+
else:
986+
return dpnp_multiply(x1, x2)
979987

980988
return call_origin(numpy.multiply, x1, x2, **kwargs)
981989

tests/test_mathematical.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,35 @@ def test_diff(array):
2626
numpy.testing.assert_allclose(expected, result)
2727

2828

29+
@pytest.mark.parametrize("val_type",
30+
[numpy.float64, numpy.float32, numpy.int64, numpy.int32],
31+
ids=['numpy.float64', 'numpy.float32', 'numpy.int64', 'numpy.int32'])
32+
@pytest.mark.parametrize("data_type",
33+
[numpy.float64, numpy.float32, numpy.int64, numpy.int32],
34+
ids=['numpy.float64', 'numpy.float32', 'numpy.int64', 'numpy.int32'])
35+
@pytest.mark.parametrize("val",
36+
[0, 1, 5],
37+
ids=['0', '1', '5'])
38+
@pytest.mark.parametrize("array",
39+
[[[0, 0], [0, 0]],
40+
[[1, 2], [1, 2]],
41+
[[1, 2], [3, 4]],
42+
[[[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]]],
43+
[[[[1, 2], [3, 4]], [[1, 2], [2, 1]]], [[[1, 3], [3, 1]], [[0, 1], [1, 3]]]]],
44+
ids=['[[0, 0], [0, 0]]',
45+
'[[1, 2], [1, 2]]',
46+
'[[1, 2], [3, 4]]',
47+
'[[[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]]]',
48+
'[[[[1, 2], [3, 4]], [[1, 2], [2, 1]]], [[[1, 3], [3, 1]], [[0, 1], [1, 3]]]]'])
49+
def test_multiply(array, val, data_type, val_type):
50+
a = numpy.array(array, dtype=data_type)
51+
ia = inp.array(a)
52+
val_ = val_type(val)
53+
result = inp.multiply(ia, val_)
54+
expected = numpy.multiply(ia, val_)
55+
numpy.testing.assert_array_equal(expected, result)
56+
57+
2958
@pytest.mark.parametrize("array", [[1, 2, 3, 4, 5],
3059
[1, 2, numpy.nan, 4, 5],
3160
[[1, 2, numpy.nan], [3, -4, -5]]])

0 commit comments

Comments
 (0)