Skip to content

Commit 811338f

Browse files
Statistics.countnans: Fix sparse implementation and add axis support
1 parent 1f97b66 commit 811338f

File tree

2 files changed

+97
-16
lines changed

2 files changed

+97
-16
lines changed

Orange/statistics/util.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,12 @@
1212

1313
def _count_nans_per_row_sparse(X, weights):
1414
""" Count the number of nans (undefined) values per row. """
15-
items_per_row = 1 if X.ndim == 1 else X.shape[1]
16-
counts = np.ones(X.shape[0]) * items_per_row
17-
nnz_per_row = np.bincount(X.indices, minlength=len(counts))
18-
counts -= nnz_per_row
15+
counts = np.fromiter((np.isnan(row.data).sum() for row in X), dtype=np.float)
16+
1917
if weights is not None:
2018
counts *= weights
21-
return np.sum(counts)
19+
20+
return counts
2221

2322

2423
def bincount(X, max_val=None, weights=None, minlength=None):
@@ -56,34 +55,52 @@ def bincount(X, max_val=None, weights=None, minlength=None):
5655

5756
def countnans(X, weights=None, axis=None, dtype=None, keepdims=False):
5857
"""
59-
Count the undefined elements in arr along given axis.
58+
Count the undefined elements in an array along given axis.
6059
6160
Parameters
6261
----------
6362
X : array_like
6463
weights : array_like
6564
Weights to weight the nans with, before or after counting (depending
6665
on the weights shape).
66+
axis : Optional[int]
6767
6868
Returns
6969
-------
70-
counts
70+
Union[np.ndarray, float]
71+
7172
"""
7273
if not sp.issparse(X):
7374
X = np.asanyarray(X)
7475
isnan = np.isnan(X)
7576
if weights is not None and weights.shape == X.shape:
7677
isnan = isnan * weights
78+
79+
# In order to keep return types consistent with sparse vectors, we will
80+
# handle `axis=1` given a regular 1d numpy array equivallently as
81+
# `axis=0`. If we didn't do this, this would raise error, whereas the
82+
# sparse counterpart would return the appropriate value.
83+
axis = axis and min(axis, isnan.ndim - 1)
84+
7785
counts = isnan.sum(axis=axis, dtype=dtype, keepdims=keepdims)
7886
if weights is not None and weights.shape != X.shape:
7987
counts = counts * weights
8088
else:
81-
if any(attr is not None for attr in [axis, dtype]) or \
82-
keepdims is not False:
83-
raise ValueError('Arguments axis, dtype and keepdims'
84-
'are not yet supported on sparse data!')
89+
assert axis in [None, 0, 1], 'Only axis 0 and 1 are currently supported'
90+
arr = X if axis == 1 else X.T
91+
92+
if weights is not None:
93+
weights = weights if axis == 1 else weights.T
94+
95+
arr = arr.tocsr()
96+
counts = _count_nans_per_row_sparse(arr, weights)
97+
98+
# We want a scalar value if `axis=None` or if the sparse matrix is
99+
# actually a vector (e.g. [[1 2 3]]), but has `ndim=2` due to scipy
100+
# implementation
101+
if axis is None or X.shape[0] == 1:
102+
counts = counts.sum()
85103

86-
counts = _count_nans_per_row_sparse(X, weights)
87104
return counts
88105

89106

Orange/tests/test_statistics.py

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,6 @@ def test_bincount(self):
3131
self.assertEqual(n_nans, 0)
3232
np.testing.assert_equal(hist, [1, 1, 0, 1])
3333

34-
def test_countnans(self):
35-
np.testing.assert_equal(countnans([[1, np.nan],
36-
[2, np.nan]], axis=0), [0, 2])
37-
3834
def test_contingency(self):
3935
x = np.array([0, 1, 0, 2, np.nan])
4036
y = np.array([0, 0, 1, np.nan, 0])
@@ -241,3 +237,71 @@ def test_var(self):
241237
var(csr_matrix(data), axis=axis),
242238
np.var(data, axis=axis)
243239
)
240+
241+
242+
class TestCountnans(unittest.TestCase):
243+
def test_countnans_on_sparse_array(self):
244+
x = csr_matrix([0, 1, 0, 2, 2, np.nan, 1, np.nan, 0, 1])
245+
self.assertEqual(countnans(x), 2)
246+
247+
def test_countnans_on_dense_array(self):
248+
x = np.array([0, 1, 0, 2, 2, np.nan, 1, np.nan, 0, 1])
249+
self.assertEqual(countnans(x), 2)
250+
251+
def test_shape_matches_dense_and_sparse_given_array_and_axis_None(self):
252+
dense = np.array([0, 1, 0, 2, 2, np.nan, 1, np.nan, 0, 1])
253+
sparse = csr_matrix(dense)
254+
np.testing.assert_equal(countnans(dense), countnans(sparse))
255+
self.assertEqual(countnans(dense), 2)
256+
257+
def test_shape_matches_dense_and_sparse_given_array_and_axis_0(self):
258+
dense = np.array([0, 1, 0, 2, 2, np.nan, 1, np.nan, 0, 1])
259+
sparse = csr_matrix(dense)
260+
np.testing.assert_equal(countnans(dense, axis=0), countnans(sparse, axis=0))
261+
self.assertEqual(countnans(dense, axis=0), 2)
262+
263+
def test_shape_matches_dense_and_sparse_given_array_and_axis_1(self):
264+
dense = np.array([0, 1, 0, 2, 2, np.nan, 1, np.nan, 0, 1])
265+
sparse = csr_matrix(dense)
266+
np.testing.assert_equal(countnans(dense, axis=1), countnans(sparse, axis=1))
267+
self.assertEqual(countnans(dense, axis=1), 2)
268+
269+
def test_countnans(self):
270+
x = [[1, np.nan],
271+
[2, np.nan]]
272+
np.testing.assert_equal(
273+
countnans(x), 2, 'Countnans fails on dense data')
274+
np.testing.assert_equal(
275+
countnans(csr_matrix(x)), 2, 'Countnans fails on sparse data.')
276+
277+
def test_countnans_columns(self):
278+
x = [[1, np.nan],
279+
[2, np.nan]]
280+
np.testing.assert_equal(
281+
countnans(x, axis=0), [0, 2],
282+
'Countnans fails on dense data with `axis=0`')
283+
np.testing.assert_equal(
284+
countnans(csr_matrix(x), axis=0), [0, 2],
285+
'Countnans fails on sparse data with `axis=0`')
286+
287+
def test_countnans_rows(self):
288+
x = [[1, np.nan],
289+
[2, np.nan]]
290+
np.testing.assert_equal(
291+
countnans(x, axis=1), [1, 1],
292+
'Countnans fails on dense data with `axis=1`')
293+
np.testing.assert_equal(
294+
countnans(csr_matrix(x), axis=1), [1, 1],
295+
'Countnans fails on sparse data with `axis=1`')
296+
297+
def test_countnans_weights(self):
298+
x = [[1, np.nan],
299+
[2, np.nan]]
300+
w = np.array([[1, 1],
301+
[2, 2]])
302+
np.testing.assert_equal(countnans(x, weights=w, axis=0), [0, 3])
303+
np.testing.assert_equal(countnans(x, weights=w, axis=1), [1, 2])
304+
305+
w = np.array([1, 2])
306+
np.testing.assert_equal(countnans(x, weights=w, axis=0), [0, 4])
307+
np.testing.assert_equal(countnans(x, weights=w, axis=1), [1, 2])

0 commit comments

Comments
 (0)