Skip to content

Commit 361a078

Browse files
authored
BUG: Fix CSR/CSC matmul (#660)
1 parent 18c1596 commit 361a078

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

sparse/pydata_backend/_compressed/compressed.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import operator
33
from collections.abc import Iterable
44
from functools import reduce
5+
from typing import Union
56

67
import numpy as np
78
from numpy.lib.mixins import NDArrayOperatorsMixin
@@ -876,11 +877,14 @@ def from_scipy_sparse(cls, x):
876877
x = x.asformat("csr", copy=False)
877878
return cls((x.data, x.indices, x.indptr), shape=x.shape)
878879

879-
def transpose(self, axes: None = None, copy: bool = False) -> "CSC":
880-
if axes is not None:
881-
raise ValueError
880+
def transpose(self, axes: None = None, copy: bool = False) -> Union["CSC", "CSR"]:
881+
axes = normalize_axis(axes, self.ndim)
882+
if axes not in [(0, 1), (1, 0), None]:
883+
raise ValueError(f"Invalid transpose axes: {axes}")
882884
if copy:
883885
self = self.copy()
886+
if axes == (0, 1):
887+
return self
884888
return CSC((self.data, self.indices, self.indptr), self.shape[::-1])
885889

886890

@@ -905,9 +909,12 @@ def from_scipy_sparse(cls, x):
905909
x = x.asformat("csc", copy=False)
906910
return cls((x.data, x.indices, x.indptr), shape=x.shape)
907911

908-
def transpose(self, axes: None = None, copy: bool = False) -> CSR:
909-
if axes is not None:
910-
raise ValueError
912+
def transpose(self, axes: None = None, copy: bool = False) -> Union["CSC", "CSR"]:
913+
axes = normalize_axis(axes, self.ndim)
914+
if axes not in [(0, 1), (1, 0), None]:
915+
raise ValueError(f"Invalid transpose axes: {axes}")
911916
if copy:
912917
self = self.copy()
918+
if axes == (0, 1):
919+
return self
913920
return CSR((self.data, self.indices, self.indptr), self.shape[::-1])

sparse/pydata_backend/tests/test_compressed_2d.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def data_rvs(n):
4141

4242
else:
4343
data_rvs = None
44-
return cls(sparse.random((20, 30, 40), density=0.25, data_rvs=data_rvs).astype(dtype))
44+
return cls(sparse.random((20, 20), density=0.25, data_rvs=data_rvs).astype(dtype))
4545

4646

4747
def test_repr(random_sparse):
@@ -111,7 +111,21 @@ def test_transpose(random_sparse, copy):
111111
assert_eq(random_sparse, tt)
112112
assert type(random_sparse) == type(tt)
113113

114+
assert_eq(random_sparse.transpose(axes=(0, 1)), random_sparse)
115+
assert_eq(random_sparse.transpose(axes=(1, 0)), t)
116+
with pytest.raises(ValueError, match="Invalid transpose axes"):
117+
random_sparse.transpose(axes=0)
118+
114119

115120
def test_transpose_error(random_sparse):
116121
with pytest.raises(ValueError):
117122
random_sparse.transpose(axes=1)
123+
124+
125+
def test_matmul(random_sparse_small):
126+
arr = random_sparse_small.todense()
127+
128+
actual = random_sparse_small @ random_sparse_small
129+
expected = arr @ arr
130+
131+
assert_eq(actual, expected)

0 commit comments

Comments
 (0)