Skip to content

Commit 2cfb6ed

Browse files
authored
PipeOpImpute predict-imputes factors when empty_level_control == "never"
1 parent 41d75be commit 2cfb6ed

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

R/PipeOpImpute.R

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
#' during prediction will *not* be imputed.
3131
#' - If set to `"always"`, an unseen level is added to the feature during training and missing values are imputed as
3232
#' that value during prediction.
33-
#' - Finally, if set to `"param"`, the hyperparameter `create_empty_levels` is added and control over this behavior is
33+
#' - Finally, if set to `"param"`, the hyperparameter `create_empty_level` is added and control over this behavior is
3434
#' left to the user.
3535
#'
3636
#' For implementation details, see Internals below. Default is `"never"`.
@@ -160,6 +160,7 @@ PipeOpImpute = R6Class("PipeOpImpute",
160160
private$.create_empty_level = FALSE
161161
emplvls_control_ps = ps()
162162
} else if (empty_level_control == "param") {
163+
private$.create_empty_level = NULL
163164
# Setting create_empty_level modifies private$.create_empty_field later in train and predict
164165
emplvls_control_ps = ps(create_empty_level = p_lgl(init = FALSE, tags = c("train", "predict")))
165166
}
@@ -206,10 +207,10 @@ PipeOpImpute = R6Class("PipeOpImpute",
206207
intask = inputs[[1]]$clone(deep = TRUE)
207208
pv = self$param_set$get_values(tags = "train")
208209

209-
# If the hyperparameter exists, we overwrite the private field here, and can simply check the private field after
210-
# this without having to check conditions on both the hyperparameter and the private field
210+
# If the hyperparameter exists, then private$.create_empty_level is NULL and will be ignored
211+
create_empty_level = private$.create_empty_level
211212
if (!is.null(pv$create_empty_level)) {
212-
private$.create_empty_level = pv$create_empty_level
213+
create_empty_level = pv$create_empty_level
213214
}
214215

215216
affected_cols = (pv$affect_columns %??% selector_all())(intask)
@@ -227,9 +228,9 @@ PipeOpImpute = R6Class("PipeOpImpute",
227228
}
228229

229230
imputanda = intask$data(cols = affected_cols)
230-
if (private$.create_empty_level) {
231+
if (create_empty_level) {
231232
# Also run impute on all factor/ordered columns that don't have any NAs
232-
imputanda = imputanda[, map_lgl(imputanda, function(x) anyMissing(x) || is.factor(x)), with = FALSE]
233+
imputanda = imputanda[, map_lgl(imputanda, function(x) is.factor(x) || anyMissing(x)), with = FALSE]
233234
} else {
234235
imputanda = imputanda[, map_lgl(imputanda, function(x) anyMissing(x)), with = FALSE]
235236
}
@@ -278,18 +279,21 @@ PipeOpImpute = R6Class("PipeOpImpute",
278279
context_data = intask$data(cols = self$state$context_cols)
279280
}
280281

281-
# If the hyperparameter exists, we overwrite the private field here, and can simply check the private field after
282-
# this without having to check conditions on both the hyperparameter and the private field
282+
# If the hyperparameter exists and is set to FALSE, we do not impute factor cols that had no missings during train.
283+
# If the HP does not exist, then we always call impute, since imputing will either not add a new factor
284+
# (empty_level_control = "never") or the new factor will have been taken care of (empty_level_control = "always")
283285
pv = self$param_set$get_values(tags = "predict")
284286
if (!is.null(pv$create_empty_level)) {
285-
private$.create_empty_level = pv$create_empty_level
287+
predict_all_factors = pv$create_empty_level
288+
} else {
289+
predict_all_factors = TRUE
286290
}
287291

288292
imputanda = intask$data(cols = self$state$affected_cols)
289-
if (!private$.create_empty_level) {
293+
if (!predict_all_factors) {
290294
# Don't run impute for factor/ordered columns that were not imputed during training
291295
imputanda = imputanda[,
292-
colnames(imputanda) %in% self$state$imputed_train | map_lgl(imputanda, function(x) anyMissing(x) && !is.factor(x)),
296+
colnames(imputanda) %in% self$state$imputed_train | map_lgl(imputanda, function(x) !is.factor(x) && anyMissing(x)),
293297
with = FALSE]
294298
} else {
295299
imputanda = imputanda[,

0 commit comments

Comments
 (0)