2
2
import operator
3
3
from collections .abc import Iterable
4
4
from functools import reduce
5
+ from typing import Union
5
6
6
7
import numpy as np
7
8
from numpy .lib .mixins import NDArrayOperatorsMixin
@@ -876,11 +877,14 @@ def from_scipy_sparse(cls, x):
876
877
x = x .asformat ("csr" , copy = False )
877
878
return cls ((x .data , x .indices , x .indptr ), shape = x .shape )
878
879
879
- def transpose (self , axes : None = None , copy : bool = False ) -> "CSC" :
880
- if axes is not None :
881
- raise ValueError
880
+ def transpose (self , axes : None = None , copy : bool = False ) -> Union ["CSC" , "CSR" ]:
881
+ axes = normalize_axis (axes , self .ndim )
882
+ if axes not in [(0 , 1 ), (1 , 0 ), None ]:
883
+ raise ValueError (f"Invalid transpose axes: { axes } " )
882
884
if copy :
883
885
self = self .copy ()
886
+ if axes == (0 , 1 ):
887
+ return self
884
888
return CSC ((self .data , self .indices , self .indptr ), self .shape [::- 1 ])
885
889
886
890
@@ -905,9 +909,12 @@ def from_scipy_sparse(cls, x):
905
909
x = x .asformat ("csc" , copy = False )
906
910
return cls ((x .data , x .indices , x .indptr ), shape = x .shape )
907
911
908
- def transpose (self , axes : None = None , copy : bool = False ) -> CSR :
909
- if axes is not None :
910
- raise ValueError
912
+ def transpose (self , axes : None = None , copy : bool = False ) -> Union ["CSC" , "CSR" ]:
913
+ axes = normalize_axis (axes , self .ndim )
914
+ if axes not in [(0 , 1 ), (1 , 0 ), None ]:
915
+ raise ValueError (f"Invalid transpose axes: { axes } " )
911
916
if copy :
912
917
self = self .copy ()
918
+ if axes == (0 , 1 ):
919
+ return self
913
920
return CSR ((self .data , self .indices , self .indptr ), self .shape [::- 1 ])
0 commit comments