Skip to content

Commit 1d80036

Browse files
committed
Fix impute.Model for derived domains
The compute_value did not transform the incoming data into the domain of the variable it operates on before reading that variable's column.
1 parent ab868dc commit 1d80036

File tree

2 files changed

+28
-5
lines changed

2 files changed

+28
-5
lines changed

Orange/preprocess/impute.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def copy(self):
172172
return FixedValueByType(*self.defaults.values())
173173

174174

175-
class ReplaceUnknownsModel(Reprable):
175+
class ReplaceUnknownsModel(Transformation):
176176
"""
177177
Replace unknown values with predicted values using a `Orange.base.Model`
178178
@@ -185,15 +185,14 @@ class ReplaceUnknownsModel(Reprable):
185185
"""
186186
def __init__(self, variable, model):
187187
assert model.domain.class_var == variable
188-
self.variable = variable
188+
super().__init__(variable)
189189
self.model = model
190190

191191
def __call__(self, data):
192192
if isinstance(data, Orange.data.Instance):
193193
data = Orange.data.Table.from_list(data.domain, [data])
194194
domain = data.domain
195-
column = data.get_column(self.variable, copy=True)
196-
195+
column = data.transform(self._target_domain).get_column(self.variable, copy=True)
197196
mask = np.isnan(column)
198197
if not np.any(mask):
199198
return column
@@ -207,6 +206,9 @@ def __call__(self, data):
207206
column[mask] = predicted
208207
return column
209208

209+
def transform(self, c):
210+
assert False, "abstract in Transformation, never used here"
211+
210212
def __eq__(self, other):
211213
return type(self) is type(other) \
212214
and self.variable == other.variable \

Orange/tests/test_impute.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from Orange import preprocess
1010
from Orange.preprocess import impute, SklImpute
1111
from Orange import data
12-
from Orange.data import Unknown, Table
12+
from Orange.data import Unknown, Table, Domain
1313

1414
from Orange.classification import MajorityLearner, SimpleTreeLearner
1515
from Orange.regression import MeanLearner
@@ -293,6 +293,27 @@ def test_bad_domain(self):
293293
self.assertRaises(ValueError, imputer, data=table,
294294
variable=table.domain[0])
295295

296+
def test_missing_imputed_columns(self):
297+
housing = Table("housing")
298+
299+
learner = SimpleTreeLearner(min_instances=10, max_depth=10)
300+
method = preprocess.impute.Model(learner)
301+
302+
ivar = method(housing, housing.domain.attributes[0])
303+
imputed = housing.transform(
304+
Domain([ivar],
305+
housing.domain.class_var)
306+
)
307+
removed_imputed = imputed.transform(
308+
Domain([], housing.domain.class_var))
309+
310+
r = removed_imputed.transform(imputed.domain)
311+
312+
no_class = removed_imputed.transform(Domain(removed_imputed.domain.attributes, None))
313+
model_prediction_for_unknowns = ivar.compute_value.model(no_class[0])
314+
315+
np.testing.assert_equal(r.X, model_prediction_for_unknowns)
316+
296317

297318
class TestRandom(unittest.TestCase):
298319
def test_replacement(self):

0 commit comments

Comments
 (0)