Skip to content

Commit 838e5fa

Browse files
authored
Merge pull request #5083 from ales-erjavec/fixes/model-impute-no-backmap
[FIX] impute: Remove class vars from input data for ReplaceUnknownsModel
2 parents 6e88010 + d061f53 commit 838e5fa

File tree

2 files changed

+14
-10
lines changed

2 files changed

+14
-10
lines changed

Orange/preprocess/impute.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -160,19 +160,20 @@ def __init__(self, variable, model):
160160

161161
def __call__(self, data):
162162
if isinstance(data, Orange.data.Instance):
163-
column = np.array([float(data[self.variable])])
164-
else:
165-
column = np.array(data.get_column_view(self.variable)[0],
166-
copy=True)
163+
data = Orange.data.Table.from_list(data.domain, [data])
164+
domain = data.domain
165+
column = np.array(data.get_column_view(self.variable)[0], copy=True)
167166

168167
mask = np.isnan(column)
169168
if not np.any(mask):
170169
return column
171170

172-
if isinstance(data, Orange.data.Instance):
173-
predicted = self.model(data)
174-
else:
175-
predicted = self.model(data[mask])
171+
if domain.class_vars:
172+
# cannot have class var in domain (due to backmappers in model)
173+
data = data.transform(
174+
Orange.data.Domain(domain.attributes, None, domain.metas)
175+
)
176+
predicted = self.model(data[mask])
176177
column[mask] = predicted
177178
return column
178179

Orange/tests/test_impute.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,9 +229,12 @@ def test_replacement(self):
229229
domain = data.Domain(
230230
(data.DiscreteVariable("A", values=("0", "1", "2")),
231231
data.ContinuousVariable("B"),
232-
data.ContinuousVariable("C"))
232+
data.ContinuousVariable("C")),
233+
# the class is here to ensure the backmapper in model does not
234+
# run and raise exception
235+
data.DiscreteVariable("Z", values=("P", "M"))
233236
)
234-
table = data.Table.from_numpy(domain, np.array(X))
237+
table = data.Table.from_numpy(domain, np.array(X), [0,] * 3)
235238

236239
v = impute.Model(MajorityLearner())(table, domain[0])
237240
self.assertTrue(np.all(np.isfinite(v.compute_value(table))))

0 commit comments

Comments
 (0)