Skip to content

Commit d6d1b19

Browse files
authored
API: Add kwargs to sparse.einsum (#620)
1 parent 925112f commit d6d1b19

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

sparse/_common.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1393,7 +1393,7 @@ def _einsum_single(lhs, rhs, operand):
13931393
return to_output_format(COO(new_coords, new_data, shape=new_shape, has_duplicates=True))
13941394

13951395

1396-
def einsum(*operands):
1396+
def einsum(*operands, **kwargs):
13971397
"""
13981398
Perform the equivalent of :obj:`numpy.einsum`.
13991399
@@ -1406,6 +1406,11 @@ def einsum(*operands):
14061406
included as well as subscript labels of the precise output form.
14071407
operands : sequence of SparseArray
14081408
These are the arrays for the operation.
1409+
dtype : data-type, optional
1410+
If provided, forces the calculation to use the data type specified.
1411+
Default is ``None``.
1412+
**kwargs : dict, optional
1413+
Any additional arguments to pass to the function.
14091414
14101415
Returns
14111416
-------
@@ -1417,6 +1422,9 @@ def einsum(*operands):
14171422

14181423
check_zero_fill_value(*operands)
14191424

1425+
if "dtype" in kwargs and kwargs["dtype"] is not None:
1426+
operands = [o.astype(kwargs["dtype"]) for o in operands]
1427+
14201428
if len(operands) == 1:
14211429
return _einsum_single(lhs, rhs, operands[0])
14221430

sparse/tests/test_einsum.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,3 +190,15 @@ def test_einsum_shape_check():
190190
y = sparse.random((2, 3, 4), density=0.5)
191191
with pytest.raises(ValueError):
192192
sparse.einsum("abc,acb", x, y)
193+
194+
195+
@pytest.mark.parametrize("dtype", [np.int64, np.complex128])
196+
def test_einsum_dtype(dtype):
197+
x = sparse.random((3, 3), density=0.5) * 10.0
198+
x = x.astype(np.float64)
199+
200+
y = sparse.COO.from_numpy(np.ones((3, 1), dtype=np.float64))
201+
202+
result = sparse.einsum("ij,i->j", x, y, dtype=dtype)
203+
204+
assert result.dtype == dtype

0 commit comments

Comments
 (0)