Skip to content

Commit a69c051

Browse files
committed
add some checks for shapes of things in make_csr_matrix, and move some other checks over from constructor
1 parent cf4f219 commit a69c051

File tree

1 file changed

+15
-16
lines changed

1 file changed

+15
-16
lines changed

pytato/array.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2238,10 +2238,6 @@ class SparseMatrix(_SuppliedAxesAndTagsMixin, _SuppliedShapeAndDtypeMixin, ABC):
22382238
22392239
.. automethod:: __matmul__
22402240
"""
2241-
if __debug__:
2242-
def __post_init__(self) -> None:
2243-
pass
2244-
22452241
def __matmul__(self, other: Array) -> SparseMatmul:
22462242
return sparse_matmul(self, other)
22472243

@@ -2308,17 +2304,6 @@ class CSRMatrix(SparseMatrix):
23082304
elem_col_indices: Array
23092305
row_starts: Array
23102306

2311-
if __debug__:
2312-
@override
2313-
def __post_init__(self) -> None:
2314-
if self.elem_values.ndim != 1:
2315-
raise ValueError("elem_values must be a 1D array.")
2316-
if self.elem_col_indices.ndim != 1:
2317-
raise ValueError("elem_col_indices must be a 1D array.")
2318-
if self.row_starts.ndim != 1:
2319-
raise ValueError("row_starts must be a 1D array.")
2320-
super().__post_init__()
2321-
23222307

23232308
@array_dataclass()
23242309
class CSRMatmul(SparseMatmul):
@@ -2756,7 +2741,7 @@ def make_csr_matrix(shape: ConvertibleToShape,
27562741
axes: AxesT | None = None) -> CSRMatrix:
27572742
"""Make a :class:`CSRMatrix` object.
27582743
2759-
:param shape: the shape of the matrix
2744+
:param shape: the (two-dimensional) shape of the matrix
27602745
:param elem_values: a one-dimensional array containing the values of all of the
27612746
nonzero entries of the matrix, grouped by row.
27622747
:param elem_col_indices: a one-dimensional array containing the column index
@@ -2768,13 +2753,27 @@ def make_csr_matrix(shape: ConvertibleToShape,
27682753
shape = normalize_shape(shape)
27692754
dtype = elem_values.dtype
27702755

2756+
if len(shape) != 2:
2757+
raise ValueError("'shape' must be 2D.")
2758+
27712759
if axes is None:
27722760
axes = _get_default_axes(len(shape))
27732761

27742762
if len(axes) != len(shape):
27752763
raise ValueError("'axes' dimensionality mismatch:"
27762764
f" expected {len(shape)}, got {len(axes)}.")
27772765

2766+
if elem_values.ndim != 1:
2767+
raise ValueError("'elem_values' must be 1D.")
2768+
if elem_col_indices.ndim != 1:
2769+
raise ValueError("'elem_col_indices' must be 1D.")
2770+
if row_starts.ndim != 1:
2771+
raise ValueError("'row_starts' must be 1D.")
2772+
2773+
if len(row_starts) != shape[0] + 1:
2774+
raise ValueError(
2775+
"'row_starts' must have length equal to the number of rows plus one.")
2776+
27782777
return CSRMatrix(
27792778
shape=shape,
27802779
elem_values=elem_values,

0 commit comments

Comments
 (0)