|
3 | 3 |
|
4 | 4 | from cubed.array_api.data_type_functions import result_type
|
5 | 5 | from cubed.array_api.dtypes import _numeric_dtypes
|
6 |
| -from cubed.array_api.manipulation_functions import expand_dims |
| 6 | +from cubed.array_api.manipulation_functions import ( |
| 7 | + broadcast_arrays, |
| 8 | + expand_dims, |
| 9 | + moveaxis, |
| 10 | +) |
7 | 11 | from cubed.backend_array_api import namespace as nxp
|
8 | 12 | from cubed.core import blockwise, reduction, squeeze
|
9 | 13 |
|
@@ -158,12 +162,21 @@ def _tensordot(a, b, axes):
|
158 | 162 |
|
159 | 163 |
|
160 | 164 | def vecdot(x1, x2, /, *, axis=-1, use_new_impl=True, split_every=None):
|
| 165 | + # based on the implementation in array-api-compat |
161 | 166 | if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
|
162 | 167 | raise TypeError("Only numeric dtypes are allowed in vecdot")
|
163 |
| - return tensordot( |
164 |
| - x1, |
165 |
| - x2, |
166 |
| - axes=((axis,), (axis,)), |
| 168 | + |
| 169 | + if x1.shape[axis] != x2.shape[axis]: |
| 170 | + raise ValueError("x1 and x2 must have the same size along the given axis") |
| 171 | + |
| 172 | + x1_ = moveaxis(x1, axis, -1) |
| 173 | + x2_ = moveaxis(x2, axis, -1) |
| 174 | + x1_, x2_ = broadcast_arrays(x1_, x2_) |
| 175 | + |
| 176 | + res = matmul( |
| 177 | + x1_[..., None, :], |
| 178 | + x2_[..., None], |
167 | 179 | use_new_impl=use_new_impl,
|
168 | 180 | split_every=split_every,
|
169 | 181 | )
|
| 182 | + return res[..., 0, 0] |
0 commit comments