|
37 | 37 |
|
38 | 38 | """
|
39 | 39 |
|
| 40 | +# pylint: disable=no-name-in-module |
40 | 41 | import numpy
|
41 | 42 |
|
42 | 43 | import dpnp
|
43 | 44 |
|
| 45 | +from .dpnp_utils import map_dtype_to_device |
44 | 46 | from .dpnp_utils.dpnp_utils_einsum import dpnp_einsum
|
45 | 47 | from .dpnp_utils.dpnp_utils_linearalgebra import (
|
46 | 48 | dpnp_dot,
|
|
66 | 68 | ]
|
67 | 69 |
|
68 | 70 |
|
| 71 | +# TODO: implement a specific scalar-array kernel |
| 72 | +def _call_multiply(a, b, out=None): |
| 73 | + """Call multiply function for special cases of scalar-array dots.""" |
| 74 | + |
| 75 | + sc, arr = (a, b) if dpnp.isscalar(a) else (b, a) |
| 76 | + sc_dtype = map_dtype_to_device(type(sc), arr.sycl_device) |
| 77 | + res_dtype = dpnp.result_type(sc_dtype, arr) |
| 78 | + if out is not None and out.dtype == arr.dtype: |
| 79 | + res = dpnp.multiply(a, b, out=out) |
| 80 | + else: |
| 81 | + res = dpnp.multiply(a, b, dtype=res_dtype) |
| 82 | + return dpnp.get_result_array(res, out, casting="no") |
| 83 | + |
| 84 | + |
69 | 85 | def dot(a, b, out=None):
|
70 | 86 | """
|
71 | 87 | Dot product of `a` and `b`.
|
@@ -139,8 +155,7 @@ def dot(a, b, out=None):
|
139 | 155 | raise ValueError("Only C-contiguous array is acceptable.")
|
140 | 156 |
|
141 | 157 | if dpnp.isscalar(a) or dpnp.isscalar(b):
|
142 |
| - # TODO: use specific scalar-vector kernel |
143 |
| - return dpnp.multiply(a, b, out=out) |
| 158 | + return _call_multiply(a, b, out=out) |
144 | 159 |
|
145 | 160 | a_ndim = a.ndim
|
146 | 161 | b_ndim = b.ndim
|
@@ -635,8 +650,7 @@ def inner(a, b):
|
635 | 650 | dpnp.check_supported_arrays_type(a, b, scalar_type=True)
|
636 | 651 |
|
637 | 652 | if dpnp.isscalar(a) or dpnp.isscalar(b):
|
638 |
| - # TODO: use specific scalar-vector kernel |
639 |
| - return dpnp.multiply(a, b) |
| 653 | + return _call_multiply(a, b) |
640 | 654 |
|
641 | 655 | if a.ndim == 0 or b.ndim == 0:
|
642 | 656 | # TODO: use specific scalar-vector kernel
|
@@ -714,8 +728,7 @@ def kron(a, b):
|
714 | 728 | dpnp.check_supported_arrays_type(a, b, scalar_type=True)
|
715 | 729 |
|
716 | 730 | if dpnp.isscalar(a) or dpnp.isscalar(b):
|
717 |
| - # TODO: use specific scalar-vector kernel |
718 |
| - return dpnp.multiply(a, b) |
| 731 | + return _call_multiply(a, b) |
719 | 732 |
|
720 | 733 | a_ndim = a.ndim
|
721 | 734 | b_ndim = b.ndim
|
@@ -1199,8 +1212,7 @@ def tensordot(a, b, axes=2):
|
1199 | 1212 | raise ValueError(
|
1200 | 1213 | "One of the inputs is scalar, axes should be zero."
|
1201 | 1214 | )
|
1202 |
| - # TODO: use specific scalar-vector kernel |
1203 |
| - return dpnp.multiply(a, b) |
| 1215 | + return _call_multiply(a, b) |
1204 | 1216 |
|
1205 | 1217 | return dpnp_tensordot(a, b, axes=axes)
|
1206 | 1218 |
|
@@ -1263,13 +1275,13 @@ def vdot(a, b):
|
1263 | 1275 | if b.size != 1:
|
1264 | 1276 | raise ValueError("The second array should be of size one.")
|
1265 | 1277 | a_conj = numpy.conj(a)
|
1266 |
| - return dpnp.multiply(a_conj, b) |
| 1278 | + return _call_multiply(a_conj, b) |
1267 | 1279 |
|
1268 | 1280 | if dpnp.isscalar(b):
|
1269 | 1281 | if a.size != 1:
|
1270 | 1282 | raise ValueError("The first array should be of size one.")
|
1271 | 1283 | a_conj = dpnp.conj(a)
|
1272 |
| - return dpnp.multiply(a_conj, b) |
| 1284 | + return _call_multiply(a_conj, b) |
1273 | 1285 |
|
1274 | 1286 | if a.ndim == 1 and b.ndim == 1:
|
1275 | 1287 | return dpnp_dot(a, b, out=None, conjugate=True)
|
|
0 commit comments