Skip to content

Commit 2545c92

Browse files
committed
[FIX] Tree: Sparse Support
1 parent 5afe1c2 commit 2545c92

File tree

2 files changed

+30
-5
lines changed

2 files changed

+30
-5
lines changed

Orange/classification/tree.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
"""Tree inducers: SKL and Orange's own inducer"""
22
import numpy as np
3+
import scipy.sparse as sp
34
import sklearn.tree as skl_tree
45

56
from Orange.base import TreeModel as TreeModelInterface
@@ -76,6 +77,7 @@ def _select_attr(self, data):
7677
"""
7778
# Prevent false warnings by pylint
7879
attr = attr_no = None
80+
col_x = None
7981
REJECT_ATTRIBUTE = 0, None, None, 0
8082

8183
def _score_disc():
@@ -89,8 +91,7 @@ def _score_disc():
8991
if n_values < 2:
9092
return REJECT_ATTRIBUTE
9193

92-
x = data.X[:, attr_no].flatten()
93-
cont = _tree_scorers.contingency(x, len(data.domain.attributes[attr_no].values),
94+
cont = _tree_scorers.contingency(col_x, len(data.domain.attributes[attr_no].values),
9495
data.Y, len(data.domain.class_var.values))
9596
attr_distr = np.sum(cont, axis=0)
9697
null_nodes = attr_distr <= self.min_samples_leaf
@@ -111,7 +112,7 @@ def _score_disc():
111112
cont_entr = np.sum(cont * np.log(cont))
112113
score = (class_entr - attr_entr + cont_entr) / n / np.log(2)
113114
score *= n / len(data) # punishment for missing values
114-
branches = x
115+
branches = col_x
115116
branches[np.isnan(branches)] = -1
116117
if score == 0:
117118
return REJECT_ATTRIBUTE
@@ -135,13 +136,12 @@ def _score_disc_bin():
135136
return REJECT_ATTRIBUTE
136137
best_score *= 1 - np.sum(cont.unknowns) / len(data)
137138
mapping, branches = MappedDiscreteNode.branches_from_mapping(
138-
data.X[:, attr_no], best_mapping, n_values)
139+
col_x, best_mapping, n_values)
139140
node = MappedDiscreteNode(attr, attr_no, mapping, None)
140141
return best_score, node, branches, 2
141142

142143
def _score_cont():
143144
"""Scoring for numeric attributes"""
144-
col_x = data.X[:, attr_no]
145145
nans = np.sum(np.isnan(col_x))
146146
non_nans = len(col_x) - nans
147147
arginds = np.argsort(col_x)[:non_nans]
@@ -159,12 +159,17 @@ def _score_cont():
159159

160160
#######################################
161161
# The real _select_attr starts here
162+
is_sparse = sp.issparse(data.X)
162163
domain = data.domain
163164
class_var = domain.class_var
164165
best_score, *best_res = REJECT_ATTRIBUTE
165166
best_res = [Node(None, None, None)] + best_res[1:]
166167
disc_scorer = _score_disc_bin if self.binarize else _score_disc
167168
for attr_no, attr in enumerate(domain.attributes):
169+
col_x = data.X[:, attr_no]
170+
if is_sparse:
171+
col_x = col_x.toarray()
172+
col_x = col_x.flatten()
168173
sc, *res = disc_scorer() if attr.is_discrete else _score_cont()
169174
if res[0] is not None and sc > best_score:
170175
best_score, best_res = sc, res

Orange/widgets/model/tests/test_tree.py

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,9 @@
1+
# pylint: disable=protected-access
2+
import numpy as np
3+
import scipy.sparse as sp
4+
15
from Orange.base import Model
6+
from Orange.data import Table
27
from Orange.widgets.model.owtree import OWTreeLearner
38
from Orange.widgets.tests.base import (
49
DefaultParameterMapping,
@@ -35,3 +40,18 @@ def test_parameters_unchecked(self):
3540
self.parameters = [DefaultParameterMapping(par.name, val)
3641
for par, val in zip(self.parameters, (None, 2, 1))]
3742
self.test_parameters()
43+
44+
def test_sparse_data(self):
    """
    Tree can handle sparse data.
    GH-2430

    The iris data is sent to the widget once with a dense X and once with
    X converted to a scipy CSR matrix; the models output for the two inputs
    must have identical _code, _thresholds and _values arrays.
    """
    # Reference model: dense input.
    table1 = Table("iris")
    self.send_signal("Data", table1)
    model_dense = self.get_output("Model")

    # Same data with X stored as a sparse matrix.
    table2 = Table("iris")
    table2.X = sp.csr_matrix(table2.X)
    # The sparse table has to be sent to the widget. Otherwise no new input
    # arrives between the two get_output calls and the comparisons below
    # never exercise the sparse code path.
    self.send_signal("Data", table2)
    model_sparse = self.get_output("Model")

    self.assertTrue(np.array_equal(model_dense._code, model_sparse._code))
    self.assertTrue(np.array_equal(model_dense._thresholds, model_sparse._thresholds))
    self.assertTrue(np.array_equal(model_dense._values, model_sparse._values))

0 commit comments

Comments
 (0)