Skip to content

Commit ef3465d

Browse files
authored
Merge pull request #6668 from markotoplak/fix-impute-model
[FIX] Fix impute.Model for derived domains
2 parents d5de749 + 0547537 commit ef3465d

File tree

3 files changed

+38
-7
lines changed

3 files changed

+38
-7
lines changed

Orange/preprocess/impute.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import scipy.sparse as sp
33

44
import Orange.data
5+
from Orange.data.table import DomainTransformationError
56
from Orange.statistics import distribution, basic_stats
67
from Orange.util import Reprable
78
from .transformation import Transformation, Lookup
@@ -172,7 +173,7 @@ def copy(self):
172173
return FixedValueByType(*self.defaults.values())
173174

174175

175-
class ReplaceUnknownsModel(Reprable):
176+
class ReplaceUnknownsModel(Transformation):
176177
"""
177178
Replace unknown values with predicted values using a `Orange.base.Model`
178179
@@ -185,15 +186,14 @@ class ReplaceUnknownsModel(Reprable):
185186
"""
186187
def __init__(self, variable, model):
187188
assert model.domain.class_var == variable
188-
self.variable = variable
189+
super().__init__(variable)
189190
self.model = model
190191

191192
def __call__(self, data):
192193
if isinstance(data, Orange.data.Instance):
193194
data = Orange.data.Table.from_list(data.domain, [data])
194195
domain = data.domain
195-
column = data.get_column(self.variable, copy=True)
196-
196+
column = data.transform(self._target_domain).get_column(self.variable, copy=True)
197197
mask = np.isnan(column)
198198
if not np.any(mask):
199199
return column
@@ -203,10 +203,17 @@ def __call__(self, data):
203203
data = data.transform(
204204
Orange.data.Domain(domain.attributes, None, domain.metas)
205205
)
206-
predicted = self.model(data[mask])
207-
column[mask] = predicted
206+
try:
207+
column[mask] = self.model(data[mask])
208+
except DomainTransformationError:
209+
# owpredictions showed error when imputing target using a Model
210+
# based imputer (owpredictions removes the target before predicing)
211+
pass
208212
return column
209213

214+
def transform(self, c):
215+
assert False, "abstract in Transformation, never used here"
216+
210217
def __eq__(self, other):
211218
return type(self) is type(other) \
212219
and self.variable == other.variable \

Orange/tests/test_impute.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from Orange import preprocess
1010
from Orange.preprocess import impute, SklImpute
1111
from Orange import data
12-
from Orange.data import Unknown, Table
12+
from Orange.data import Unknown, Table, Domain
1313

1414
from Orange.classification import MajorityLearner, SimpleTreeLearner
1515
from Orange.regression import MeanLearner
@@ -293,6 +293,27 @@ def test_bad_domain(self):
293293
self.assertRaises(ValueError, imputer, data=table,
294294
variable=table.domain[0])
295295

296+
def test_missing_imputed_columns(self):
297+
housing = Table("housing")
298+
299+
learner = SimpleTreeLearner(min_instances=10, max_depth=10)
300+
method = preprocess.impute.Model(learner)
301+
302+
ivar = method(housing, housing.domain.attributes[0])
303+
imputed = housing.transform(
304+
Domain([ivar],
305+
housing.domain.class_var)
306+
)
307+
removed_imputed = imputed.transform(
308+
Domain([], housing.domain.class_var))
309+
310+
r = removed_imputed.transform(imputed.domain)
311+
312+
no_class = removed_imputed.transform(Domain(removed_imputed.domain.attributes, None))
313+
model_prediction_for_unknowns = ivar.compute_value.model(no_class[0])
314+
315+
np.testing.assert_equal(r.X, model_prediction_for_unknowns)
316+
296317

297318
class TestRandom(unittest.TestCase):
298319
def test_replacement(self):

i18n/si.jaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2685,6 +2685,9 @@ preprocess/impute.py:
26852685
def `__call__`:
26862686
"'{}' has no values": false
26872687
"'{}' has an unknown distribution": false
2688+
class `ReplaceUnknownsModel`:
2689+
def `transform`:
2690+
abstract in Transformation, never used here: false
26882691
preprocess/normalize.py:
26892692
Normalizer: false
26902693
preprocess/preprocess.py:

0 commit comments

Comments
 (0)