Skip to content

Commit 73bbf5c

Browse files
authored
Fix outer result dtype (#582)
1 parent c1391c0 commit 73bbf5c

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

cubed/array_api/linalg.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from cubed.array_api.array_object import Array
44

55
# These functions are in both the main and linalg namespaces
6+
from cubed.array_api.data_type_functions import result_type
67
from cubed.array_api.linear_algebra_functions import ( # noqa: F401
78
matmul,
89
matrix_transpose,
@@ -15,7 +16,9 @@
1516

1617

1718
def outer(x1, x2, /):
18-
return blockwise(nxp.linalg.outer, "ij", x1, "i", x2, "j", dtype=x1.dtype)
19+
return blockwise(
20+
nxp.linalg.outer, "ij", x1, "i", x2, "j", dtype=result_type(x1, x2)
21+
)
1922

2023

2124
class QRResult(NamedTuple):

0 commit comments

Comments
 (0)