@@ -2238,10 +2238,6 @@ class SparseMatrix(_SuppliedAxesAndTagsMixin, _SuppliedShapeAndDtypeMixin, ABC):
22382238
22392239 .. automethod:: __matmul__
22402240 """
2241- if __debug__ :
2242- def __post_init__ (self ) -> None :
2243- pass
2244-
22452241 def __matmul__ (self , other : Array ) -> SparseMatmul :
22462242 return sparse_matmul (self , other )
22472243
@@ -2308,17 +2304,6 @@ class CSRMatrix(SparseMatrix):
23082304 elem_col_indices : Array
23092305 row_starts : Array
23102306
2311- if __debug__ :
2312- @override
2313- def __post_init__ (self ) -> None :
2314- if self .elem_values .ndim != 1 :
2315- raise ValueError ("elem_values must be a 1D array." )
2316- if self .elem_col_indices .ndim != 1 :
2317- raise ValueError ("elem_col_indices must be a 1D array." )
2318- if self .row_starts .ndim != 1 :
2319- raise ValueError ("row_starts must be a 1D array." )
2320- super ().__post_init__ ()
2321-
23222307
23232308@array_dataclass ()
23242309class CSRMatmul (SparseMatmul ):
@@ -2756,7 +2741,7 @@ def make_csr_matrix(shape: ConvertibleToShape,
27562741 axes : AxesT | None = None ) -> CSRMatrix :
27572742 """Make a :class:`CSRMatrix` object.
27582743
2759- :param shape: the shape of the matrix
2744+ :param shape: the (two-dimensional) shape of the matrix
27602745 :param elem_values: a one-dimensional array containing the values of all of the
27612746 nonzero entries of the matrix, grouped by row.
27622747 :param elem_col_indices: a one-dimensional array containing the column index
@@ -2768,13 +2753,27 @@ def make_csr_matrix(shape: ConvertibleToShape,
27682753 shape = normalize_shape (shape )
27692754 dtype = elem_values .dtype
27702755
2756+ if len (shape ) != 2 :
2757+ raise ValueError ("'shape' must be 2D." )
2758+
27712759 if axes is None :
27722760 axes = _get_default_axes (len (shape ))
27732761
27742762 if len (axes ) != len (shape ):
27752763 raise ValueError ("'axes' dimensionality mismatch:"
27762764 f" expected { len (shape )} , got { len (axes )} ." )
27772765
2766+ if elem_values .ndim != 1 :
2767+ raise ValueError ("'elem_values' must be 1D." )
2768+ if elem_col_indices .ndim != 1 :
2769+ raise ValueError ("'elem_col_indices' must be 1D." )
2770+ if row_starts .ndim != 1 :
2771+ raise ValueError ("'row_starts' must be 1D." )
2772+
2773+ if len (row_starts ) != shape [0 ] + 1 :
2774+ raise ValueError (
2775+ "'row_starts' must have length equal to the number of rows plus one." )
2776+
27782777 return CSRMatrix (
27792778 shape = shape ,
27802779 elem_values = elem_values ,
0 commit comments