Skip to content

Commit 23703a6

Browse files
authored
Merge pull request #3867 from pavlin-policar/sparse-target
Densify sparse targets when fitting models
2 parents 0dd8df8 + a18f380 commit 23703a6

File tree

5 files changed

+32
-6
lines changed

5 files changed

+32
-6
lines changed

Orange/data/table.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,8 @@ def Y(self, value):
204204
value = value[:, None]
205205
if sp.issparse(value) and len(self) != value.shape[0]:
206206
value = value.T
207+
if sp.issparse(value):
208+
value = value.toarray()
207209
self._Y = value
208210

209211
def __new__(cls, *args, **kwargs):

Orange/tests/test_sparse_table.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def test_Y_setter_1d(self):
6767
assert iris.Y.shape == (150,)
6868
iris.Y = csr_matrix(iris.Y)
6969
# We expect the number of rows in Y to match that of X, which is 150 in iris
70-
self.assertEqual(iris.Y.shape, (150, 1))
70+
self.assertEqual(iris.Y.shape, (150,))
7171

7272
def test_Y_setter_2d(self):
7373
iris = Table('iris')

Orange/tests/test_table.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2810,7 +2810,7 @@ def test_sparse_dense_transformation(self):
28102810

28112811
iris_sparse = iris.to_sparse(sparse_attributes=True, sparse_class=True)
28122812
self.assertTrue(sp.issparse(iris_sparse.X))
2813-
self.assertTrue(sp.issparse(iris_sparse.Y))
2813+
self.assertFalse(sp.issparse(iris_sparse.Y))
28142814
self.assertFalse(sp.issparse(iris_sparse.metas))
28152815

28162816
dense_iris = iris_sparse.to_dense()

Orange/widgets/utils/owlearnerwidget.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
from copy import deepcopy
22

3-
import numpy as np
4-
53
from AnyQt.QtCore import QTimer, Qt
64

75
from Orange.data import Table
86
from Orange.modelling import Fitter, Learner, Model
97
from Orange.preprocess.preprocess import Preprocess
8+
from Orange.statistics import util as ut
109
from Orange.widgets import gui
1110
from Orange.widgets.settings import Setting
1211
from Orange.widgets.utils import getmembers
@@ -133,6 +132,7 @@ def set_data(self, data):
133132
"""Set the input train dataset."""
134133
self.Error.data_error.clear()
135134
self.data = data
135+
136136
if data is not None and data.domain.class_var is None:
137137
self.Error.data_error("Data has no target variable.")
138138
self.data = None
@@ -181,7 +181,7 @@ def check_data(self):
181181
self.Error.data_error(self.learner.learner_adequacy_err_msg)
182182
elif not len(self.data):
183183
self.Error.data_error("Dataset is empty.")
184-
elif len(np.unique(self.data.Y)) < 2:
184+
elif len(ut.unique(self.data.Y)) < 2:
185185
self.Error.data_error("Data contains a single target value.")
186186
elif self.data.X.size == 0:
187187
self.Error.data_error("Data has no features to learn from.")

Orange/widgets/utils/tests/test_owlearnerwidget.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
import scipy.sparse as sp
2+
13
# pylint: disable=missing-docstring
24
from Orange.base import Learner, Model
35
from Orange.classification import KNNLearner
46
from Orange.data import Table, Domain
57
from Orange.modelling import TreeLearner
6-
from Orange.regression import MeanLearner
8+
from Orange.preprocess import continuize
9+
from Orange.regression import MeanLearner, LinearRegressionLearner
710
from Orange.widgets.utils.owlearnerwidget import OWBaseLearner
811
from Orange.widgets.tests.base import WidgetTest
912
from Orange.widgets.utils.signals import Output
@@ -108,3 +111,24 @@ class WidgetA(OWBaseLearner):
108111
settings = w1.settingsHandler.pack_data(w1)
109112
w2 = self.create_widget(WidgetA, settings)
110113
self.assertEqual(w2.learner_name, w1.learner_name)
114+
115+
def test_converts_sparse_targets_to_dense(self):
116+
class WidgetLR(OWBaseLearner):
117+
name = "lr"
118+
LEARNER = LinearRegressionLearner
119+
120+
w = self.create_widget(WidgetLR)
121+
122+
# Orange will want to do one-hot encoding when continuizing discrete variables
123+
pp = continuize.DomainContinuizer(
124+
multinomial_treatment=continuize.Continuize.AsOrdinal,
125+
transform_class=True,
126+
)
127+
data = self.iris.transform(pp(self.iris))
128+
data.Y = sp.csr_matrix(data.Y)
129+
130+
self.send_signal(w.Inputs.data, data, widget=w)
131+
self.assertFalse(any(w.Error.active))
132+
133+
model = self.get_output(w.Outputs.model, widget=w)
134+
self.assertIsNotNone(model)

0 commit comments

Comments
 (0)