Skip to content

Commit 2a345ee

Browse files
authored
ENH: sparse: enhance dtype checking in constructors (scipy#22113)
* tests of invalid dtype checking * complete dtype checking in constructor * fix linting
1 parent 1600a37 commit 2a345ee

File tree

3 files changed

+13
-8
lines changed

3 files changed

+13
-8
lines changed

scipy/sparse/_bsr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def __init__(self, arg1, shape=None, dtype=None, copy=False,
137137
self._shape = check_shape(shape)
138138

139139
if dtype is not None:
140-
self.data = self.data.astype(dtype, copy=False)
140+
self.data = self.data.astype(getdtype(dtype, self.data), copy=False)
141141

142142
self.check_format(full_check=False)
143143

scipy/sparse/_coo.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,13 @@ def __init__(self, arg1, shape=None, dtype=None, copy=False, *, maxprint=None):
6565
if issparse(arg1):
6666
if arg1.format == self.format and copy:
6767
self.coords = tuple(idx.copy() for idx in arg1.coords)
68-
self.data = arg1.data.copy()
68+
self.data = arg1.data.astype(getdtype(dtype, arg1)) # copy=True
6969
self._shape = check_shape(arg1.shape, allow_nd=self._allow_nd)
7070
self.has_canonical_format = arg1.has_canonical_format
7171
else:
7272
coo = arg1.tocoo()
7373
self.coords = tuple(coo.coords)
74-
self.data = coo.data
74+
self.data = coo.data.astype(getdtype(dtype, coo), copy=False)
7575
self._shape = check_shape(coo.shape, allow_nd=self._allow_nd)
7676
self.has_canonical_format = False
7777
else:
@@ -92,16 +92,12 @@ def __init__(self, arg1, shape=None, dtype=None, copy=False, *, maxprint=None):
9292
coords = M.nonzero()
9393
self.coords = tuple(idx.astype(index_dtype, copy=False)
9494
for idx in coords)
95-
self.data = M[coords]
95+
self.data = getdata(M[coords], copy=copy, dtype=dtype)
9696
self.has_canonical_format = True
9797

9898
if len(self._shape) > 2:
9999
self.coords = tuple(idx.astype(np.int64, copy=False) for idx in self.coords)
100100

101-
if dtype is not None:
102-
newdtype = getdtype(dtype)
103-
self.data = self.data.astype(newdtype, copy=False)
104-
105101
self._check()
106102

107103
@property

scipy/sparse/tests/test_base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2183,6 +2183,15 @@ def test_size_zero_conversions(self):
21832183
assert_array_equal(spm.todok().toarray(), m)
21842184
assert_array_equal(spm.tobsr().toarray(), m)
21852185

2186+
def test_dtype_check(self):
2187+
a = np.array([[3.5, 0, 1.1], [0, 0, 0]], dtype=np.float16)
2188+
with assert_raises(ValueError, match="does not support dtype"):
2189+
self.spcreator(a)
2190+
2191+
A32 = self.spcreator(a.astype(np.float32))
2192+
with assert_raises(ValueError, match="does not support dtype"):
2193+
self.spcreator(A32, dtype=np.float16)
2194+
21862195
def test_pickle(self):
21872196
import pickle
21882197
sup = suppress_warnings()

0 commit comments

Comments
 (0)