Skip to content

Commit edec7b9

Browse files
committed
remove_unused_values: Support sparse matrices
1 parent 49d04b1 commit edec7b9

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

Orange/preprocess/remove.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,11 @@
11
from collections import namedtuple
2+
23
import numpy as np
34

4-
from .preprocess import Preprocess
55
from Orange.data import Domain, DiscreteVariable, Table
66
from Orange.preprocess.transformation import Lookup
7+
from Orange.statistics.util import nanunique
8+
from .preprocess import Preprocess
79

810
__all__ = ["Remove"]
911

@@ -234,10 +236,7 @@ def remove_unused_values(var, data):
234236
Domain([var]),
235237
data
236238
)
237-
array = column_data.X.ravel()
238-
mask = np.isfinite(array)
239-
unique = np.array(np.unique(array[mask]), dtype=int)
240-
239+
unique = nanunique(column_data.X).astype(int)
241240
if len(unique) == len(var.values):
242241
return var
243242

Orange/tests/test_remove.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,8 @@
44
import unittest
55

66
import numpy as np
7+
import scipy.sparse as sp
8+
79
from Orange.data import Table
810
from Orange.preprocess import Remove
911
from Orange.tests import test_filename
@@ -133,3 +135,19 @@ def test_remove_unused_values_metas(self):
133135
self.assertEqual(res.domain["b"].values, res.domain["c"].values)
134136
self.assertEqual(res.domain["d"].values, ["1", "2"])
135137
self.assertEqual(res.domain["f"].values, ['1', 'hey'])
138+
139+
def test_remove_unused_values_attr_sparse(self):
140+
data = self.test8
141+
data = data[1:]
142+
data.X = sp.csr_matrix(data.X)
143+
remover = Remove(Remove.RemoveUnusedValues)
144+
new_data = remover(data)
145+
attr_res = remover.attr_results
146+
147+
self.assertEqual((new_data.X != data.X).nnz, 0)
148+
self.assertEqual([a.values for a in new_data.domain.attributes
149+
if a.is_discrete], [['1'], ['4']])
150+
self.assertEqual([c.values for c in new_data.domain.class_vars
151+
if c.is_discrete], [['1', '2', '3'], ['2']])
152+
self.assertDictEqual(attr_res,
153+
{'removed': 0, 'reduced': 1, 'sorted': 0})

0 commit comments

Comments
 (0)