Skip to content

Commit ce3a196

Browse files
Statistics.unique: Fix incorrect handling of negative values in sparse matrices
1 parent e4206e2 commit ce3a196

File tree

2 files changed

+44
-38
lines changed

2 files changed

+44
-38
lines changed

Orange/statistics/util.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -428,15 +428,18 @@ def unique(x, return_counts=False):
428428
r = np.unique(x.data, return_counts=return_counts)
429429
if not implicit_zeros:
430430
return r
431+
431432
if return_counts:
433+
zero_index = np.searchsorted(r[0], 0)
432434
if explicit_zeros:
433435
r[1][r[0] == 0.] += implicit_zeros
434436
return r
435-
return np.insert(r[0], 0, 0), np.insert(r[1], 0, implicit_zeros)
437+
return np.insert(r[0], zero_index, 0), np.insert(r[1], zero_index, implicit_zeros)
436438
else:
437439
if explicit_zeros:
438440
return r
439-
return np.insert(r, 0, 0)
441+
zero_index = np.searchsorted(r, 0)
442+
return np.insert(r, zero_index, 0)
440443

441444

442445
def nanunique(x):

Orange/tests/test_statistics.py

Lines changed: 39 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
import unittest
2-
import warnings
3-
from functools import wraps, partial
4-
from itertools import chain
2+
from functools import partial, wraps
53

64
import numpy as np
7-
import scipy as sp
8-
from scipy.sparse import csr_matrix, issparse, csc_matrix
5+
from itertools import chain
6+
from scipy.sparse import csr_matrix, issparse, lil_matrix, csc_matrix
97

108
from Orange.statistics.util import bincount, countnans, contingency, stats, \
119
nanmin, nanmax, unique, nanunique, mean, nanmean, digitize, var
@@ -128,37 +126,6 @@ def test_nanmin_nanmax(self):
128126
nanmax(X_sparse, axis=axis),
129127
np.nanmax(X, axis=axis))
130128

131-
def test_unique(self):
132-
for X in self.data:
133-
X_sparse = csr_matrix(X)
134-
np.testing.assert_array_equal(
135-
unique(X_sparse, return_counts=False),
136-
np.unique(X, return_counts=False))
137-
138-
for a1, a2 in zip(unique(X_sparse, return_counts=True),
139-
np.unique(X, return_counts=True)):
140-
np.testing.assert_array_equal(a1, a2)
141-
142-
def test_unique_explicit_zeros(self):
143-
x1 = csr_matrix(np.eye(3))
144-
x2 = csr_matrix(np.eye(3))
145-
146-
# set some off-diagonal elements to explicit zeros
147-
with warnings.catch_warnings():
148-
warnings.filterwarnings("ignore",
149-
category=sp.sparse.SparseEfficiencyWarning)
150-
x2[0, 1] = 0
151-
x2[1, 0] = 0
152-
153-
np.testing.assert_array_equal(
154-
unique(x1, return_counts=False),
155-
unique(x2, return_counts=False),
156-
)
157-
np.testing.assert_array_equal(
158-
unique(x1, return_counts=True),
159-
unique(x2, return_counts=True),
160-
)
161-
162129
def test_nanunique(self):
163130
x = csr_matrix(np.array([0, 1, 1, np.nan]))
164131
np.testing.assert_array_equal(
@@ -412,3 +379,39 @@ def test_weights_with_nans(self, array):
412379

413380
expected = [3, 0, 1, 1]
414381
np.testing.assert_equal(bincount(x, w)[0], expected)
382+
383+
384+
class TestUnique(unittest.TestCase):
385+
@dense_sparse
386+
def test_returns_unique_values(self, array):
387+
# pylint: disable=bad-whitespace
388+
x = array([[-1., 1., 0., 2., 3., np.nan],
389+
[ 0., 0., 0., 3., 5., np.nan],
390+
[-1., 0., 0., 1., 7., 6.]])
391+
expected = [-1, 0, 1, 2, 3, 5, 6, 7, np.nan, np.nan]
392+
393+
np.testing.assert_equal(unique(x, return_counts=False), expected)
394+
395+
@dense_sparse
396+
def test_returns_counts(self, array):
397+
# pylint: disable=bad-whitespace
398+
x = array([[-1., 1., 0., 2., 3., np.nan],
399+
[ 0., 0., 0., 3., 5., np.nan],
400+
[-1., 0., 0., 1., 7., 6.]])
401+
expected = [2, 6, 2, 1, 2, 1, 1, 1, 1, 1]
402+
403+
np.testing.assert_equal(unique(x, return_counts=True)[1], expected)
404+
405+
def test_sparse_explicit_zeros(self):
406+
# Build with `lil_matrix` so that assigning individual elements does not trigger a SparseEfficiencyWarning
407+
x = lil_matrix(np.eye(3))
408+
x[0, 1] = 0
409+
x[1, 0] = 0
410+
x = x.tocsr()
411+
# Test against identity matrix
412+
y = csr_matrix(np.eye(3))
413+
414+
np.testing.assert_array_equal(
415+
unique(y, return_counts=True),
416+
unique(x, return_counts=True),
417+
)

0 commit comments

Comments
 (0)