Skip to content

Commit 53f6509

Browse files
sdravedschult
andauthored
BUG: sparse: fix setdiag for matrices with missing diagonal entries (scipy#21792)
* BUG: sparse: fix setdiag for matrices with missing diagonal entries We need to first set existing diagonal entries before adding the missing ones, as inserting new values invalidates the computed offsets. * add tests for fix of csr.setdiag --------- Co-authored-by: Dan Schult <[email protected]>
1 parent 0346a93 commit 53f6509

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

scipy/sparse/_compressed.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -970,17 +970,17 @@ def _setdiag(self, values, k):
970970
self.data[offsets] = x
971971
return
972972

973-
mask = (offsets <= -1)
973+
mask = (offsets >= 0)
974974
# Boundary between csc and convert to coo
975975
# The value 0.001 is justified in gh-19962#issuecomment-1920499678
976-
if mask.sum() < self.nnz * 0.001:
976+
if self.nnz - mask.sum() < self.nnz * 0.001:
977+
# replace existing entries
978+
self.data[offsets[mask]] = x[mask]
977979
# create new entries
980+
mask = ~mask
978981
i = i[mask]
979982
j = j[mask]
980983
self._insert_many(i, j, x[mask])
981-
# replace existing entries
982-
mask = ~mask
983-
self.data[offsets[mask]] = x[mask]
984984
else:
985985
# convert to coo for _set_diag
986986
coo = self.tocoo()

scipy/sparse/tests/test_base.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4272,6 +4272,13 @@ def test_scalar_idx_dtype(self):
42724272
for x in [a, b, c, d, e, f]:
42734273
x + x
42744274

4275+
def test_setdiag_csr(self):
4276+
# see gh-21791 setting mixture of existing and not when new_values < 0.001*nnz
4277+
D = self.dia_container(([np.arange(1002)], [0]), shape=(1002, 1002))
4278+
A = self.spcreator(D)
4279+
A.setdiag(5 * np.ones(A.shape[0]))
4280+
assert A[-1, -1] == 5
4281+
42754282
def test_binop_explicit_zeros(self):
42764283
# Check that binary ops don't introduce spurious explicit zeros.
42774284
# See gh-9619 for context.
@@ -4441,6 +4448,13 @@ def test_scalar_idx_dtype(self):
44414448
for x in [a, b, c, d, e, f]:
44424449
x + x
44434450

4451+
def test_setdiag_csc(self):
4452+
# see gh-21791 setting mixture of existing and not when new_values < 0.001*nnz
4453+
D = self.dia_container(([np.arange(1002)], [0]), shape=(1002, 1002))
4454+
A = self.spcreator(D)
4455+
A.setdiag(5 * np.ones(A.shape[0]))
4456+
assert A[-1, -1] == 5
4457+
44444458

44454459
TestCSC.init_class()
44464460

0 commit comments

Comments
 (0)