|
37 | 37 |
|
38 | 38 | """ |
39 | 39 |
|
| 40 | +# pylint: disable=no-name-in-module |
40 | 41 | import numpy |
41 | 42 |
|
42 | 43 | import dpnp |
43 | 44 |
|
| 45 | +from .dpnp_utils import map_dtype_to_device |
44 | 46 | from .dpnp_utils.dpnp_utils_einsum import dpnp_einsum |
45 | 47 | from .dpnp_utils.dpnp_utils_linearalgebra import ( |
46 | 48 | dpnp_dot, |
|
64 | 66 | ] |
65 | 67 |
|
66 | 68 |
|
# TODO: implement a specific scalar-array kernel
def _call_multiply(a, b, out=None):
    """
    Compute the element-wise product for the scalar-array special cases of
    the dot-like routines (``dot``, ``inner``, ``kron``, ``tensordot``,
    ``vdot``).

    The result dtype is derived by first mapping the Python scalar's type to
    a dtype supported on the array's SYCL device, then applying the usual
    type-promotion rules against the array operand.  The product is written
    into `out` (if given) with ``casting="no"`` so no silent down-cast can
    occur.
    """

    # Identify which operand is the scalar and which is the array.
    if dpnp.isscalar(a):
        scalar, array = a, b
    else:
        scalar, array = b, a

    # Promote the scalar's Python type to a device-supported dtype before
    # computing the common result dtype with the array operand.
    scalar_dtype = map_dtype_to_device(type(scalar), array.sycl_device)
    result_dtype = dpnp.result_type(scalar_dtype, array)

    product = dpnp.multiply(a, b, dtype=result_dtype)
    return dpnp.get_result_array(product, out, casting="no")
| 79 | + |
67 | 80 | def dot(a, b, out=None): |
68 | 81 | """ |
69 | 82 | Dot product of `a` and `b`. |
@@ -137,8 +150,7 @@ def dot(a, b, out=None): |
137 | 150 | raise ValueError("Only C-contiguous array is acceptable.") |
138 | 151 |
|
139 | 152 | if dpnp.isscalar(a) or dpnp.isscalar(b): |
140 | | - # TODO: use specific scalar-vector kernel |
141 | | - return dpnp.multiply(a, b, out=out) |
| 153 | + return _call_multiply(a, b, out=out) |
142 | 154 |
|
143 | 155 | a_ndim = a.ndim |
144 | 156 | b_ndim = b.ndim |
@@ -627,8 +639,7 @@ def inner(a, b): |
627 | 639 | dpnp.check_supported_arrays_type(a, b, scalar_type=True) |
628 | 640 |
|
629 | 641 | if dpnp.isscalar(a) or dpnp.isscalar(b): |
630 | | - # TODO: use specific scalar-vector kernel |
631 | | - return dpnp.multiply(a, b) |
| 642 | + return _call_multiply(a, b) |
632 | 643 |
|
633 | 644 | if a.ndim == 0 or b.ndim == 0: |
634 | 645 | # TODO: use specific scalar-vector kernel |
@@ -706,8 +717,7 @@ def kron(a, b): |
706 | 717 | dpnp.check_supported_arrays_type(a, b, scalar_type=True) |
707 | 718 |
|
708 | 719 | if dpnp.isscalar(a) or dpnp.isscalar(b): |
709 | | - # TODO: use specific scalar-vector kernel |
710 | | - return dpnp.multiply(a, b) |
| 720 | + return _call_multiply(a, b) |
711 | 721 |
|
712 | 722 | a_ndim = a.ndim |
713 | 723 | b_ndim = b.ndim |
@@ -1043,8 +1053,7 @@ def tensordot(a, b, axes=2): |
1043 | 1053 | raise ValueError( |
1044 | 1054 | "One of the inputs is scalar, axes should be zero." |
1045 | 1055 | ) |
1046 | | - # TODO: use specific scalar-vector kernel |
1047 | | - return dpnp.multiply(a, b) |
| 1056 | + return _call_multiply(a, b) |
1048 | 1057 |
|
1049 | 1058 | return dpnp_tensordot(a, b, axes=axes) |
1050 | 1059 |
|
@@ -1107,13 +1116,13 @@ def vdot(a, b): |
1107 | 1116 | if b.size != 1: |
1108 | 1117 | raise ValueError("The second array should be of size one.") |
1109 | 1118 | a_conj = numpy.conj(a) |
1110 | | - return dpnp.multiply(a_conj, b) |
| 1119 | + return _call_multiply(a_conj, b) |
1111 | 1120 |
|
1112 | 1121 | if dpnp.isscalar(b): |
1113 | 1122 | if a.size != 1: |
1114 | 1123 | raise ValueError("The first array should be of size one.") |
1115 | 1124 | a_conj = dpnp.conj(a) |
1116 | | - return dpnp.multiply(a_conj, b) |
| 1125 | + return _call_multiply(a_conj, b) |
1117 | 1126 |
|
1118 | 1127 | if a.ndim == 1 and b.ndim == 1: |
1119 | 1128 | return dpnp_dot(a, b, out=None, conjugate=True) |
|
0 commit comments