Skip to content

Commit d386917

Browse files
authored
Bring vecdot implementation in line with the one in array-api-compat (#402)
1 parent b4e94b0 commit d386917

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

cubed/array_api/linear_algebra_functions.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33

44
from cubed.array_api.data_type_functions import result_type
55
from cubed.array_api.dtypes import _numeric_dtypes
6-
from cubed.array_api.manipulation_functions import expand_dims
6+
from cubed.array_api.manipulation_functions import (
7+
broadcast_arrays,
8+
expand_dims,
9+
moveaxis,
10+
)
711
from cubed.backend_array_api import namespace as nxp
812
from cubed.core import blockwise, reduction, squeeze
913

@@ -158,12 +162,21 @@ def _tensordot(a, b, axes):
158162

159163

160164
def vecdot(x1, x2, /, *, axis=-1, use_new_impl=True, split_every=None):
165+
# based on the implementation in array-api-compat
161166
if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
162167
raise TypeError("Only numeric dtypes are allowed in vecdot")
163-
return tensordot(
164-
x1,
165-
x2,
166-
axes=((axis,), (axis,)),
168+
169+
if x1.shape[axis] != x2.shape[axis]:
170+
raise ValueError("x1 and x2 must have the same size along the given axis")
171+
172+
x1_ = moveaxis(x1, axis, -1)
173+
x2_ = moveaxis(x2, axis, -1)
174+
x1_, x2_ = broadcast_arrays(x1_, x2_)
175+
176+
res = matmul(
177+
x1_[..., None, :],
178+
x2_[..., None],
167179
use_new_impl=use_new_impl,
168180
split_every=split_every,
169181
)
182+
return res[..., 0, 0]

0 commit comments

Comments
 (0)