|
69 | 69 |
|
70 | 70 |
|
# TODO: implement a specific scalar-array kernel
def _call_multiply(a, b, out=None, outer_calc=False):
    """
    Call the multiply function for special cases of scalar-array dots.

    If `sc` is a scalar and `a` is an array of type float32, we have
    dpnp.multiply(a, sc).dtype == dpnp.float32 and
    numpy.multiply(a, sc).dtype == dpnp.float32.

    However, for scalar-array dots such as the dot function we have
    dpnp.dot(a, sc).dtype == dpnp.float32 while
    numpy.dot(a, sc).dtype == dpnp.float64.

    We need to adjust the behavior of the multiply function when it is
    being used for special cases of scalar-array dots.

    Parameters
    ----------
    a, b : {dpnp array, scalar}
        Operands of the product; exactly one of them is expected to be
        a scalar (the branch below assumes this).
    out : {dpnp array, None}, optional
        Output array. It is passed through to the multiply call only
        when its dtype already matches the array operand's dtype;
        otherwise the result is computed in the adjusted dtype and
        copied/validated by ``dpnp.get_result_array``.
    outer_calc : bool, optional
        When ``True``, use ``dpnp.multiply.outer`` instead of
        ``dpnp.multiply`` so this helper can also serve the
        scalar branches of ``dpnp.outer``. Default: ``False``.

    Returns
    -------
    out : dpnp array
        The scalar-array product, with dtype promoted as NumPy's dot
        would promote it (via ``dpnp.result_type`` on the
        device-mapped scalar dtype).

    """

    # Identify which operand is the scalar and which is the array.
    sc, arr = (a, b) if dpnp.isscalar(a) else (b, a)
    # Map the Python scalar type to a dtype supported on the array's
    # device, then compute the NumPy-compatible promoted result dtype.
    sc_dtype = map_dtype_to_device(type(sc), arr.sycl_device)
    res_dtype = dpnp.result_type(sc_dtype, arr)
    # Dispatch to the outer-product variant when requested by dpnp.outer.
    multiply_func = dpnp.multiply.outer if outer_calc else dpnp.multiply
    if out is not None and out.dtype == arr.dtype:
        # Fast path: write directly into `out` when dtypes already agree.
        res = multiply_func(a, b, out=out)
    else:
        # Otherwise compute in the adjusted dtype; get_result_array
        # below enforces the "no" casting rule against `out`.
        res = multiply_func(a, b, dtype=res_dtype)
    return dpnp.get_result_array(res, out, casting="no")
83 | 98 |
|
84 | 99 |
|
@@ -1109,16 +1124,15 @@ def outer(a, b, out=None): |
1109 | 1124 |
|
1110 | 1125 | dpnp.check_supported_arrays_type(a, b, scalar_type=True, all_scalars=False) |
1111 | 1126 | if dpnp.isscalar(a): |
1112 | | - x1 = a |
1113 | 1127 | x2 = dpnp.ravel(b)[None, :] |
| 1128 | + result = _call_multiply(a, x2, out=out, outer_calc=True) |
1114 | 1129 | elif dpnp.isscalar(b): |
1115 | 1130 | x1 = dpnp.ravel(a)[:, None] |
1116 | | - x2 = b |
| 1131 | + result = _call_multiply(x1, b, out=out, outer_calc=True) |
1117 | 1132 | else: |
1118 | | - x1 = dpnp.ravel(a) |
1119 | | - x2 = dpnp.ravel(b) |
| 1133 | + result = dpnp.multiply.outer(dpnp.ravel(a), dpnp.ravel(b), out=out) |
1120 | 1134 |
|
1121 | | - return dpnp.multiply.outer(x1, x2, out=out) |
| 1135 | + return result |
1122 | 1136 |
|
1123 | 1137 |
|
1124 | 1138 | def tensordot(a, b, axes=2): |
@@ -1288,13 +1302,13 @@ def vdot(a, b): |
1288 | 1302 | if b.size != 1: |
1289 | 1303 | raise ValueError("The second array should be of size one.") |
1290 | 1304 | a_conj = numpy.conj(a) |
1291 | | - return _call_multiply(a_conj, b) |
| 1305 | + return dpnp.squeeze(_call_multiply(a_conj, b)) |
1292 | 1306 |
|
1293 | 1307 | if dpnp.isscalar(b): |
1294 | 1308 | if a.size != 1: |
1295 | 1309 | raise ValueError("The first array should be of size one.") |
1296 | 1310 | a_conj = dpnp.conj(a) |
1297 | | - return _call_multiply(a_conj, b) |
| 1311 | + return dpnp.squeeze(_call_multiply(a_conj, b)) |
1298 | 1312 |
|
1299 | 1313 | if a.ndim == 1 and b.ndim == 1: |
1300 | 1314 | return dpnp_dot(a, b, out=None, conjugate=True) |
|
0 commit comments