Skip to content

Commit d18f0f1

Browse files
Statistics.countnans: Add dtype param support to sparse
1 parent: 5c39006 · commit: d18f0f1

File tree

2 files changed: 29 additions (+29) and 17 deletions (-17)

2 files changed: 29 additions (+29) and 17 deletions (-17)

Orange/statistics/util.py

Lines changed: 19 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
import bottleneck as bn
1111

1212

13-
def _count_nans_per_row_sparse(X, weights):
13+
def _count_nans_per_row_sparse(X, weights, dtype=None):
1414
""" Count the number of nans (undefined) values per row. """
1515
if weights is not None:
1616
X = X.tocoo(copy=False)
@@ -25,9 +25,9 @@ def _count_nans_per_row_sparse(X, weights):
2525
w = sp.coo_matrix((data_weights, (nan_rows, nan_cols)), shape=X.shape)
2626
w = w.tocsr(copy=False)
2727

28-
return np.fromiter((np.sum(row.data) for row in w), dtype=np.float)
28+
return np.fromiter((np.sum(row.data) for row in w), dtype=dtype)
2929

30-
return np.fromiter((np.isnan(row.data).sum() for row in X), dtype=np.float)
30+
return np.fromiter((np.isnan(row.data).sum() for row in X), dtype=dtype)
3131

3232

3333
def sparse_count_zeros(x):
@@ -116,52 +116,54 @@ def bincount(x, weights=None, max_val=None, minlength=None):
116116
return bc, nans
117117

118118

119-
def countnans(X, weights=None, axis=None, dtype=None, keepdims=False):
119+
def countnans(x, weights=None, axis=None, dtype=None, keepdims=False):
120120
"""
121121
Count the undefined elements in an array along given axis.
122122
123123
Parameters
124124
----------
125-
X : array_like
126-
weights : array_like
125+
x : array_like
126+
weights : array_like, optional
127127
Weights to weight the nans with, before or after counting (depending
128128
on the weights shape).
129-
axis : Optional[int]
129+
axis : int, optional
130+
dtype : dtype, optional
131+
The data type of the returned array.
130132
131133
Returns
132134
-------
133135
Union[np.ndarray, float]
134136
135137
"""
136-
if not sp.issparse(X):
137-
X = np.asanyarray(X)
138-
isnan = np.isnan(X)
139-
if weights is not None and weights.shape == X.shape:
138+
if not sp.issparse(x):
139+
x = np.asanyarray(x)
140+
isnan = np.isnan(x)
141+
if weights is not None and weights.shape == x.shape:
140142
isnan = isnan * weights
141143

142144
counts = isnan.sum(axis=axis, dtype=dtype, keepdims=keepdims)
143-
if weights is not None and weights.shape != X.shape:
145+
if weights is not None and weights.shape != x.shape:
144146
counts = counts * weights
145147
else:
146148
assert axis in [None, 0, 1], 'Only axis 0 and 1 are currently supported'
147149
# To have consistent behaviour with dense matrices, raise error when
148150
# `axis=1` and the array is 1d (e.g. [[1 2 3]])
149-
if X.shape[0] == 1 and axis == 1:
151+
if x.shape[0] == 1 and axis == 1:
150152
raise ValueError('Axis %d is out of bounds' % axis)
151153

152-
arr = X if axis == 1 else X.T
154+
arr = x if axis == 1 else x.T
153155

154156
if weights is not None:
155157
weights = weights if axis == 1 else weights.T
156158

157159
arr = arr.tocsr()
158-
counts = _count_nans_per_row_sparse(arr, weights)
160+
counts = _count_nans_per_row_sparse(arr, weights, dtype=dtype)
159161

160162
# We want a scalar value if `axis=None` or if the sparse matrix is
161163
# actually a vector (e.g. [[1 2 3]]), but has `ndim=2` due to scipy
162164
# implementation
163-
if axis is None or X.shape[0] == 1:
164-
counts = counts.sum()
165+
if axis is None or x.shape[0] == 1:
166+
counts = counts.sum(dtype=dtype)
165167

166168
return counts
167169

Orange/tests/test_statistics.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -332,6 +332,16 @@ def test_2d_weights(self, array):
332332
np.testing.assert_equal(countnans(x, w, axis=0), [1, 8, 0, 8])
333333
np.testing.assert_equal(countnans(x, w, axis=1), [3, 14])
334334

335+
@dense_sparse
336+
def test_dtype(self, array):
337+
x = array([0, np.nan, 2, 3])
338+
w = np.array([0, 1.5, 0, 0])
339+
340+
self.assertIsInstance(countnans(x, w, dtype=np.int32), np.int32)
341+
self.assertEqual(countnans(x, w, dtype=np.int32), 1)
342+
self.assertIsInstance(countnans(x, w, dtype=np.float64), np.float64)
343+
self.assertEqual(countnans(x, w, dtype=np.float64), 1.5)
344+
335345

336346
class TestBincount(unittest.TestCase):
337347
@dense_sparse

0 commit comments

Comments
 (0)