Skip to content

Commit 2a65b1f

Browse files
committed
Attempt fix for array API bug.
1 parent 905f179 commit 2a65b1f

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

sparse/numba_backend/_common.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,12 @@ def tensordot(a, b, axes=2, *, return_type=None):
157157
ndb = b.ndim
158158
equal = True
159159
if nda == 0 or ndb == 0:
160+
if axes_a == [] and axes_b == []:
161+
if nda == 0 and isinstance(a, SparseArray):
162+
a = a.todense()
163+
if ndb == 0 and isinstance(b, SparseArray):
164+
b = b.todense()
165+
return a * b
160166
pos = int(nda != 0)
161167
raise ValueError(f"Input {pos} operand does not have enough dimensions")
162168
if na != nb:

0 commit comments

Comments
 (0)