Skip to content

Commit a4b79f8

Browse files
authored
TST: sparse: more tests for has_canonical_format flag (scipy#22936)
1 parent 7bed15d commit a4b79f8

File tree

1 file changed

+152
-11
lines changed

1 file changed

+152
-11
lines changed

scipy/sparse/tests/test_base.py

Lines changed: 152 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4356,31 +4356,70 @@ def test_has_sorted_indices(self):
43564356
def test_has_canonical_format(self):
43574357
"Ensure has_canonical_format memoizes state for sum_duplicates"
43584358

4359-
M = self.csr_container((np.array([2]), np.array([0]), np.array([0, 1])))
4360-
assert_equal(True, M.has_canonical_format)
4359+
info_no_dups = (np.array([2]), np.array([0]), np.array([0, 1]))
4360+
info_with_dups = (np.array([1, 1]), np.array([0, 0]), np.array([0, 2]))
43614361

4362-
indices = np.array([0, 0]) # contains duplicate
4363-
data = np.array([1, 1])
4364-
indptr = np.array([0, 2])
4362+
M = self.csr_container(info_no_dups)
4363+
assert_equal(True, M.has_canonical_format)
43654364

4366-
M = self.csr_container((data, indices, indptr)).copy()
4365+
M = self.csr_container(info_with_dups).copy()
43674366
assert_equal(False, M.has_canonical_format)
43684367
assert isinstance(M.has_canonical_format, bool)
43694368

4370-
# set by deduplicating
4369+
# set flag by deduplicating
43714370
M.sum_duplicates()
43724371
assert_equal(True, M.has_canonical_format)
43734372
assert_equal(1, len(M.indices))
43744373

4375-
M = self.csr_container((data, indices, indptr)).copy()
4376-
# set manually (although underlyingly duplicated)
4374+
# manually set flag True (although underlyingly duplicated)
4375+
M = self.csr_container(info_with_dups).copy()
43774376
M.has_canonical_format = True
43784377
assert_equal(True, M.has_canonical_format)
43794378
assert_equal(2, len(M.indices)) # unaffected content
4380-
43814379
# ensure deduplication bypassed when has_canonical_format == True
43824380
M.sum_duplicates()
4383-
assert_equal(2, len(M.indices)) # unaffected content
4381+
assert_equal(2, len(M.indices)) # still has duplicates!!!!
4382+
# ensure deduplication reenabled when has_canonical_format == False
4383+
M.has_canonical_format = False
4384+
M.sum_duplicates()
4385+
assert_equal(1, len(M.indices))
4386+
assert_equal(True, M.has_canonical_format)
4387+
4388+
# manually set flag False (although underlyingly canonical)
4389+
M.has_canonical_format = False
4390+
assert_equal(False, M.has_canonical_format)
4391+
Mcheck = self.csr_container((M.data, M.indices, M.indptr))
4392+
assert_equal(True, Mcheck.has_canonical_format)
4393+
# sum_duplicates does not complain when no work to do
4394+
M.sum_duplicates()
4395+
assert_equal(True, M.has_canonical_format)
4396+
4397+
# check assignments maintain canonical format
4398+
M = self.csr_container((np.array([2]), np.array([2]), np.array([0, 1, 1, 1])))
4399+
assert_equal(M.shape, (3, 3))
4400+
with suppress_warnings() as sup:
4401+
sup.filter(SparseEfficiencyWarning, "Changing the sparsity structure")
4402+
M[0, 1] = 2
4403+
M[1, :] *= 5
4404+
M[0, 2] = 3
4405+
assert_equal(True, M.has_canonical_format)
4406+
Mcheck = self.csr_container((M.data, M.indices, M.indptr))
4407+
assert_equal(True, Mcheck.has_canonical_format)
4408+
4409+
# resetting index arrays before accessing M.has_canonical_format is OK
4410+
M = self.csr_container(info_no_dups)
4411+
M.data, M.indices, M.indptr = info_with_dups
4412+
assert_equal(False, M.has_canonical_format)
4413+
assert_equal(2, len(M.indices)) # dups and has_canonical_format is False
4414+
4415+
# but reset after accessing M.has_canonical_format can break flag
4416+
M = self.csr_container(info_no_dups)
4417+
M.has_canonical_format # underlying attr is set here
4418+
M.data, M.indices, M.indptr = info_with_dups
4419+
assert_equal(True, M.has_canonical_format)
4420+
assert_equal(2, len(M.indices)) # dups but has_canonical_format is True
4421+
M.sum_duplicates()
4422+
assert_equal(2, len(M.indices)) # still has duplicates!!!!
43844423

43854424
def test_scalar_idx_dtype(self):
43864425
# Check that index dtype takes into account all parameters
@@ -4938,6 +4977,51 @@ def test_tocompressed_duplicates(self):
49384977
csc = coo.tocsc()
49394978
assert_equal(csc.nnz + 2, coo.nnz)
49404979

4980+
def test_has_canonical_format(self):
4981+
"Ensure has_canonical_format memoizes state for sum_duplicates"
4982+
4983+
A = self.coo_container((2, 3))
4984+
assert_equal(A.has_canonical_format, True)
4985+
4986+
A_array = np.array([[0, 2, 0]])
4987+
A_coords_form = (np.array([2]), (np.array([0]), np.array([1])))
4988+
A_coords_dups = (np.array([1, 1]), (np.array([0, 0]), np.array([1, 1])))
4989+
4990+
A = self.coo_container(A_array)
4991+
assert A.has_canonical_format is True
4992+
A = self.coo_container(A_coords_form)
4993+
assert A.has_canonical_format is False
4994+
A.sum_duplicates()
4995+
assert A.has_canonical_format is True
4996+
4997+
A = self.coo_container(A, copy=True)
4998+
assert A.has_canonical_format is True
4999+
A = self.coo_container(A, copy=False)
5000+
assert A.has_canonical_format is False
5001+
A.sum_duplicates()
5002+
assert A.has_canonical_format is True
5003+
5004+
A = self.coo_container(A_coords_dups)
5005+
assert A.has_canonical_format is False
5006+
assert_equal(A.nnz, 2) # duplicates
5007+
A.sum_duplicates()
5008+
assert A.has_canonical_format is True
5009+
assert_equal(A.nnz, 1)
5010+
5011+
# manually set
5012+
A.has_canonical_format = False
5013+
assert_equal(A.has_canonical_format, False)
5014+
assert_equal(A.nnz, 1) # incorrectly False
5015+
A.sum_duplicates() # check flag updated
5016+
assert_equal(A.has_canonical_format, True)
5017+
5018+
A = self.coo_container(A_coords_dups)
5019+
A.has_canonical_format = True
5020+
assert_equal(A.has_canonical_format, True)
5021+
assert_equal(A.nnz, 2) # incorrectly True
5022+
A.sum_duplicates() # check dups not removed due to flag
5023+
assert_equal(A.nnz, 2) # still has duplicates!!!!
5024+
49415025
def test_eliminate_zeros(self):
49425026
data = array([1, 0, 0, 0, 2, 0, 3, 0])
49435027
row = array([0, 0, 0, 1, 1, 1, 1, 1])
@@ -5347,6 +5431,63 @@ def test_eliminate_zeros_all_zero(self):
53475431
assert_array_equal(m.data.shape, (0, 2, 3))
53485432
assert_array_equal(m.toarray(), np.zeros((12, 12)))
53495433

5434+
def test_has_canonical_format(self):
5435+
"Ensure has_canonical_format memoizes state for sum_duplicates"
5436+
5437+
A = np.array([[2, 3, 2], [0, 2, 1], [-4, 0, 2]])
5438+
M = self.bsr_container(A)
5439+
assert_equal(True, M.has_canonical_format)
5440+
5441+
indices = np.array([0, 0]) # contains duplicate
5442+
data = np.array([A, A*0])
5443+
indptr = np.array([0, 2])
5444+
5445+
M = self.bsr_container((data, indices, indptr)).copy()
5446+
assert_equal(False, M.has_canonical_format)
5447+
assert isinstance(M.has_canonical_format, bool)
5448+
# set flag by deduplicating
5449+
M.sum_duplicates()
5450+
assert_equal(True, M.has_canonical_format)
5451+
assert_equal(1, len(M.indices))
5452+
5453+
# manually set flag True (although underlyingly duplicated)
5454+
M = self.bsr_container((data, indices, indptr)).copy()
5455+
M.has_canonical_format = True
5456+
assert_equal(True, M.has_canonical_format)
5457+
assert_equal(2, len(M.indices)) # unaffected content
5458+
# ensure deduplication bypassed when has_canonical_format == True
5459+
M.sum_duplicates()
5460+
assert_equal(2, len(M.indices)) # still has duplicates!!!!
5461+
# ensure deduplication reenabled when has_canonical_format == False
5462+
M.has_canonical_format = False
5463+
M.sum_duplicates()
5464+
assert_equal(1, len(M.indices))
5465+
assert_equal(True, M.has_canonical_format)
5466+
5467+
# manually set flag False (although underlyingly canonical)
5468+
M = self.bsr_container(A)
5469+
M.has_canonical_format = False
5470+
assert_equal(False, M.has_canonical_format)
5471+
assert_equal(1, len(M.indices))
5472+
# sum_duplicates does not complain when no work to do
5473+
M.sum_duplicates()
5474+
assert_equal(True, M.has_canonical_format)
5475+
5476+
# manually reset index arrays before accessing M.has_canonical_format is OK
5477+
M = self.bsr_container(A)
5478+
M.data, M.indices, M.indptr = data, indices, indptr
5479+
assert_equal(False, M.has_canonical_format)
5480+
assert_equal(2, len(M.indices)) # dups and has_canonical_format is False
5481+
5482+
# but reset after accessing M.has_canonical_format can break flag
5483+
M = self.bsr_container(A)
5484+
M.has_canonical_format # underlying attr is set here
5485+
M.data, M.indices, M.indptr = data, indices, indptr
5486+
assert_equal(True, M.has_canonical_format)
5487+
assert_equal(2, len(M.indices)) # dups but has_canonical_format is True
5488+
M.sum_duplicates()
5489+
assert_equal(2, len(M.indices)) # still has duplicates!!!!
5490+
53505491
def test_bsr_matvec(self):
53515492
A = self.bsr_container(arange(2*3*4*5).reshape(2*4,3*5), blocksize=(4,5))
53525493
x = arange(A.shape[1]).reshape(-1,1)

0 commit comments

Comments
 (0)