Skip to content

Commit 19bff36

Browse files
authored
Merge pull request #5102 from janezd/impute-constant-all
[ENH] Impute: Allow setting a default value for all numeric and time variables
2 parents 4a9f492 + 8286474 commit 19bff36

File tree

4 files changed

+200
-12
lines changed

4 files changed

+200
-12
lines changed

Orange/preprocess/impute.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .transformation import Transformation, Lookup
88

99
__all__ = ["ReplaceUnknowns", "Average", "DoNotImpute", "DropInstances",
10-
"Model", "AsValue", "Random", "Default"]
10+
"Model", "AsValue", "Random", "Default", "FixedValueByType"]
1111

1212

1313
class ReplaceUnknowns(Transformation):
@@ -113,6 +113,10 @@ def __call__(self, data, variable, value=None):
113113
a.to_sql = ImputeSql(variable, value)
114114
return a
115115

116+
@staticmethod
117+
def supports_variable(variable):
118+
return variable.is_primitive()
119+
116120

117121
class ImputeSql(Reprable):
118122
def __init__(self, var, default):
@@ -124,7 +128,7 @@ def __call__(self):
124128

125129

126130
class Default(BaseImputeMethod):
127-
name = "Value"
131+
name = "Fixed value"
128132
short_name = "value"
129133
description = ""
130134
columns_only = True
@@ -142,6 +146,32 @@ def copy(self):
142146
return Default(self.default)
143147

144148

149+
class FixedValueByType(BaseImputeMethod):
150+
name = "Fixed value"
151+
short_name = "Fixed Value"
152+
format = "{var.name}"
153+
154+
def __init__(self,
155+
default_discrete=np.nan, default_continuous=np.nan,
156+
default_string=None, default_time=np.nan):
157+
# If you change the order of args or in dict, also fix method copy
158+
self.defaults = {
159+
Orange.data.DiscreteVariable: default_discrete,
160+
Orange.data.ContinuousVariable: default_continuous,
161+
Orange.data.StringVariable: default_string,
162+
Orange.data.TimeVariable: default_time
163+
}
164+
165+
def __call__(self, data, variable, *, default=None):
166+
variable = data.domain[variable]
167+
if default is None:
168+
default = self.defaults[type(variable)]
169+
return variable.copy(compute_value=ReplaceUnknowns(variable, default))
170+
171+
def copy(self):
172+
return FixedValueByType(*self.defaults.values())
173+
174+
145175
class ReplaceUnknownsModel(Reprable):
146176
"""
147177
Replace unknown values with predicted values using a `Orange.base.Model`
@@ -272,6 +302,9 @@ def __call__(self, data, variable):
272302
else:
273303
raise TypeError(type(variable))
274304

305+
@staticmethod
306+
def supports_variable(variable):
307+
return variable.is_primitive()
275308

276309
class ReplaceUnknownsRandom(Transformation):
277310
"""
@@ -354,3 +387,7 @@ def __call__(self, data, variable):
354387
dist[1, :] += 1 / dist.shape[1]
355388
return variable.copy(
356389
compute_value=ReplaceUnknownsRandom(variable, dist))
390+
391+
@staticmethod
392+
def supports_variable(variable):
393+
return variable.is_primitive()

Orange/preprocess/tests/test_impute.py

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
import unittest
22

3-
from Orange.data import DiscreteVariable, ContinuousVariable
4-
from Orange.preprocess.impute import ReplaceUnknownsRandom, ReplaceUnknowns
3+
import numpy as np
4+
5+
from Orange.data import \
6+
Domain, Table, \
7+
DiscreteVariable, ContinuousVariable, TimeVariable, StringVariable
8+
from Orange.preprocess.impute import ReplaceUnknownsRandom, ReplaceUnknowns, \
9+
FixedValueByType
510
from Orange.statistics.distribution import Discrete
611

712

@@ -52,5 +57,59 @@ def test_equality(self):
5257
self.assertNotEqual(hash(t1), hash(t1a))
5358

5459

60+
class TestFixedValuesByType(unittest.TestCase):
61+
def setUp(self):
62+
domain = Domain(
63+
[DiscreteVariable("d", values=tuple("abc")),
64+
ContinuousVariable("c"),
65+
TimeVariable("t")],
66+
[],
67+
[StringVariable("s")]
68+
)
69+
n = np.nan
70+
self.data = Table(
71+
domain,
72+
np.array([[1, n, 15], [n, 42, n]]),
73+
np.empty((2, 0)),
74+
np.array([["foo"], [""]]))
75+
76+
def test_none_defined(self):
77+
d, c, t = self.data.domain.attributes
78+
s, = self.data.domain.metas
79+
80+
imputer = FixedValueByType()
81+
for var in (d, c, t):
82+
imp = imputer(self.data, var)
83+
self.assertIsInstance(imp.compute_value, ReplaceUnknowns)
84+
self.assertTrue(np.isnan(imp.compute_value.value))
85+
imp = imputer(self.data, s)
86+
self.assertIsInstance(imp.compute_value, ReplaceUnknowns)
87+
self.assertIsNone(imp.compute_value.value)
88+
89+
def test_all_defined(self):
90+
d, c, t = self.data.domain.attributes
91+
s, = self.data.domain.metas
92+
93+
imputer = FixedValueByType(
94+
default_discrete=1, default_continuous=42,
95+
default_string="foo", default_time=3.14)
96+
97+
self.assertEqual(imputer(self.data, d).compute_value.value, 1)
98+
self.assertEqual(imputer(self.data, c).compute_value.value, 42)
99+
self.assertEqual(imputer(self.data, t).compute_value.value, 3.14)
100+
self.assertEqual(imputer(self.data, s).compute_value.value, "foo")
101+
102+
def test_with_default(self):
103+
s, = self.data.domain.metas
104+
105+
imputer = FixedValueByType(
106+
default_discrete=1, default_continuous=42,
107+
default_string="foo", default_time=3.14)
108+
109+
self.assertEqual(
110+
imputer(self.data, s, default="bar").compute_value.value,
111+
"bar")
112+
113+
55114
if __name__ == "__main__":
56115
unittest.main()

Orange/widgets/data/owimpute.py

Lines changed: 73 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@
1111
from AnyQt.QtWidgets import (
1212
QGroupBox, QRadioButton, QPushButton, QHBoxLayout, QGridLayout,
1313
QVBoxLayout, QStackedWidget, QComboBox,
14-
QButtonGroup, QStyledItemDelegate, QListView, QDoubleSpinBox
15-
)
16-
from AnyQt.QtCore import Qt, QThread, QModelIndex
14+
QButtonGroup, QStyledItemDelegate, QListView, QDoubleSpinBox, QLabel)
15+
from AnyQt.QtCore import Qt, QThread, QModelIndex, QDateTime, QLocale
1716
from AnyQt.QtCore import pyqtSlot as Slot
17+
from AnyQt.QtGui import QDoubleValidator
18+
1819
from orangewidget.utils.listview import ListViewSearch
1920

2021
import Orange.data
@@ -154,6 +155,8 @@ class Warning(OWWidget.Warning):
154155
_variable_imputation_state = settings.ContextSetting({}) # type: VariableState
155156

156157
autocommit = settings.Setting(True)
158+
default_numeric = settings.Setting("")
159+
default_time = settings.Setting(0)
157160

158161
want_main_area = False
159162
resizing_enabled = False
@@ -171,11 +174,13 @@ def __init__(self):
171174
main_layout.setContentsMargins(10, 10, 10, 10)
172175
self.controlArea.layout().addLayout(main_layout)
173176

174-
box = QGroupBox(title=self.tr("Default Method"), flat=False)
175-
box_layout = QGridLayout(box)
176-
box_layout.setContentsMargins(5, 0, 0, 0)
177+
box = gui.vBox(None, "Default Method")
177178
main_layout.addWidget(box)
178179

180+
box_layout = QGridLayout(box)
181+
box_layout.setSpacing(8)
182+
box.layout().addLayout(box_layout)
183+
179184
button_group = QButtonGroup()
180185
button_group.buttonClicked[int].connect(self.set_default_method)
181186

@@ -186,6 +191,41 @@ def __init__(self):
186191
button_group.addButton(button, method)
187192
box_layout.addWidget(button, i % 3, i // 3)
188193

194+
def set_to_fixed_value():
195+
self.set_default_method(Method.Default)
196+
197+
def set_default_time(datetime):
198+
self.default_time = datetime.toSecsSinceEpoch()
199+
if self.default_method_index != Method.Default:
200+
set_to_fixed_value()
201+
else:
202+
self._invalidate()
203+
204+
hlayout = QHBoxLayout()
205+
box.layout().addLayout(hlayout)
206+
button = QRadioButton("Fixed values; numeric variables:")
207+
button_group.addButton(button, Method.Default)
208+
button.setChecked(Method.Default == self.default_method_index)
209+
hlayout.addWidget(button)
210+
211+
locale = QLocale()
212+
locale.setNumberOptions(locale.NumberOption.RejectGroupSeparator)
213+
validator = QDoubleValidator()
214+
validator.setLocale(locale)
215+
le = gui.lineEdit(
216+
None, self, "default_numeric",
217+
validator=validator, alignment=Qt.AlignRight,
218+
callback=self._invalidate, focusInCallback=set_to_fixed_value)
219+
hlayout.addWidget(le)
220+
221+
hlayout.addWidget(QLabel(", time:"))
222+
223+
self.time_widget = gui.DateTimeEditWCalendarTime(self)
224+
self.time_widget.setContentsMargins(0, 0, 0, 0)
225+
self.default_time = QDateTime.currentDateTime().toSecsSinceEpoch()
226+
self.time_widget.dateTimeChanged.connect(set_default_time)
227+
hlayout.addWidget(self.time_widget)
228+
189229
self.default_button_group = button_group
190230

191231
box = QGroupBox(title=self.tr("Individual Attribute Settings"),
@@ -267,6 +307,17 @@ def create_imputer(self, method, *args):
267307
m = AsDefault()
268308
m.method = default
269309
return m
310+
elif method == Method.Default and not args: # global default values
311+
if self.default_numeric == "":
312+
default_num = np.nan
313+
else:
314+
default_num, ok = QLocale().toDouble(self.default_numeric)
315+
if not ok:
316+
default_num = np.nan
317+
return impute.FixedValueByType(
318+
default_continuous=default_num,
319+
default_time=self.default_time or np.nan
320+
)
270321
else:
271322
return METHODS[method](*args)
272323

@@ -302,6 +353,8 @@ def set_data(self, data):
302353
if data is not None:
303354
self.varmodel[:] = data.domain.variables
304355
self.openContext(data.domain)
356+
self.time_widget.set_datetime(
357+
QDateTime.fromSecsSinceEpoch(self.default_time))
305358
# restore per variable imputation state
306359
self._restore_state(self._variable_imputation_state)
307360

@@ -660,5 +713,19 @@ def storeSpecificSettings(self):
660713
super().storeSpecificSettings()
661714

662715

716+
def __sample_data(): # pragma: no cover
717+
domain = Orange.data.Domain(
718+
[Orange.data.ContinuousVariable(f"c{i}") for i in range(3)]
719+
+ [Orange.data.TimeVariable(f"t{i}") for i in range(3)],
720+
[])
721+
n = np.nan
722+
x = np.array([
723+
[1, 2, n, 1000, n, n],
724+
[2, n, 1, n, 2000, 2000]
725+
])
726+
return Orange.data.Table(domain, x, np.empty((2, 0)))
727+
728+
663729
if __name__ == "__main__": # pragma: no cover
730+
# WidgetPreview(OWImpute).run(__sample_data())
664731
WidgetPreview(OWImpute).run(Orange.data.Table("brown-selected"))

Orange/widgets/data/tests/test_owimpute.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
from unittest.mock import Mock
44
import numpy as np
55

6-
from AnyQt.QtCore import Qt, QItemSelection
6+
from AnyQt.QtCore import Qt, QItemSelection, QLocale
77
from AnyQt.QtTest import QTest
88

9-
from Orange.data import Table, Domain
9+
from Orange.data import Table, Domain, ContinuousVariable, TimeVariable
1010
from Orange.preprocess import impute
1111
from Orange.widgets.data.owimpute import OWImpute, AsDefault, Learner, Method
1212
from Orange.widgets.tests.base import WidgetTest
@@ -119,6 +119,31 @@ def test_select_method(self):
119119
self.assertIsInstance(widget.get_method_for_column(0), AsDefault)
120120
self.assertIsInstance(widget.get_method_for_column(2), AsDefault)
121121

122+
def test_overall_default(self):
123+
domain = Domain(
124+
[ContinuousVariable(f"c{i}") for i in range(3)]
125+
+ [TimeVariable(f"t{i}") for i in range(3)],
126+
[])
127+
n = np.nan
128+
x = np.array([
129+
[1, 2, n, 1000, n, n],
130+
[2, n, 1, n, 2000, 2000]
131+
])
132+
data = Table(domain, x, np.empty((2, 0)))
133+
134+
widget = self.widget
135+
widget.default_numeric = QLocale().toString(3.14)
136+
widget.default_time = 42
137+
widget.default_method_index = Method.Default
138+
139+
self.send_signal(self.widget.Inputs.data, data)
140+
imp_data = self.get_output(self.widget.Outputs.data)
141+
np.testing.assert_almost_equal(
142+
imp_data.X,
143+
[[1, 2, 3.14, 1000, 42, 42],
144+
[2, 3.14, 1, 42, 2000, 2000]]
145+
)
146+
122147
def test_value_edit(self):
123148
data = Table("heart_disease")[::10]
124149
self.send_signal(self.widget.Inputs.data, data)

0 commit comments

Comments
 (0)