131 changes: 111 additions & 20 deletions R/PipeOpLearnerCV.R
@@ -7,7 +7,7 @@
#' @description
#' Wraps an [`mlr3::Learner`] into a [`PipeOp`].
#'
#' Returns cross-validated predictions during training as a [`Task`][mlr3::Task] and stores a model of the
#' Returns resampled predictions during training as a [`Task`][mlr3::Task] and stores a model of the
#' [`Learner`][mlr3::Learner] trained on the whole data in `$state`. This is used to create a similar
#' [`Task`][mlr3::Task] during prediction.
#'
@@ -19,7 +19,7 @@
#' Inherits the `$param_set` (and therefore `$param_set$values`) from the [`Learner`][mlr3::Learner] it is constructed from.
#'
#' [`PipeOpLearnerCV`] can be used to create "stacking" or "super learning" [`Graph`]s that use the output of one [`Learner`][mlr3::Learner]
#' as feature for another [`Learner`][mlr3::Learner]. Because the [`PipeOpLearnerCV`] erases the original input features, it is often
#' as features for another [`Learner`][mlr3::Learner]. Because the [`PipeOpLearnerCV`] erases the original input features, it is often
#' useful to use [`PipeOpFeatureUnion`] to bind the prediction [`Task`][mlr3::Task] to the original input [`Task`][mlr3::Task].
#'
#' @section Construction:
@@ -28,8 +28,7 @@
#' ```
#'
#' * `learner` :: [`Learner`][mlr3::Learner] \cr
#' [`Learner`][mlr3::Learner] to use for cross validation / prediction, or a string identifying a
#' [`Learner`][mlr3::Learner] in the [`mlr3::mlr_learners`] [`Dictionary`][mlr3misc::Dictionary].
#' [`Learner`][mlr3::Learner] to use for resampling / prediction.
#' * `id` :: `character(1)`
#' Identifier of the resulting object, internally defaulting to the `id` of the [`Learner`][mlr3::Learner] being wrapped.
#' * `param_vals` :: named `list`\cr
@@ -43,7 +42,7 @@
#' type given to `learner` during construction; both during training and prediction.
#'
#' The output is a task with the same target as the input task, with features replaced by predictions made by the [`Learner`][mlr3::Learner].
#' During training, this prediction is the out-of-sample prediction made by [`resample`][mlr3::resample], during prediction, this is the
#' During training, this prediction is the prediction made by [`resample`][mlr3::resample]; during prediction, this is the
#' ordinary prediction made on the data by a [`Learner`][mlr3::Learner] trained on the training phase data.
#'
#' @section State:
@@ -64,10 +63,24 @@
#' The parameters are the parameters inherited from the [`PipeOpTaskPreproc`], as well as the parameters of the [`Learner`][mlr3::Learner] wrapped by this object.
#' Besides that, parameters introduced are:
#' * `resampling.method` :: `character(1)`\cr
#' Which resampling method do we want to use. Currently only supports `"cv"` and `"insample"`. `"insample"` generates
#' predictions with the model trained on all training data.
#' * `resampling.folds` :: `numeric(1)`\cr
#' Number of cross validation folds. Initialized to 3. Only used for `resampling.method = "cv"`.
#' Which resampling method to use. Supports `"cv"`, `"bootstrap"`, `"holdout"`, `"loo"`, `"repeated_cv"`, `"subsampling"`, `"custom"` and `"insample"`.
#' See [`mlr_resamplings`][mlr3::mlr_resamplings].
#' `"insample"` generates predictions with the model trained on all training data.
#' In the case of the resampling method returning multiple predictions per row id, the predictions are aggregated via their mean
#' (except for the `"response"` in the case of a [classification Task][mlr3::TaskClassif], which is aggregated using the mode).
#' In the case of the resampling method not returning predictions for all row ids as given in the input [`Task`][mlr3::Task], these predictions are added as missing.
#' * `resampling.repeats` :: `integer(1)`\cr
#' Number of repetitions. Initialized to 30. Only used for `resampling.method = "bootstrap"`, or `"repeated_cv"`, or `"subsampling"`.
#' * `resampling.folds` :: `integer(1)`\cr
#' Number of cross validation folds. Initialized to 3. Only used for `resampling.method = "cv"`, or `"repeated_cv"`.
#' * `resampling.ratio` :: `numeric(1)`\cr
#' Ratio of observations to put into the training set. Initialized to 2/3. Only used for `resampling.method = "bootstrap"`, or `"holdout"` or `"subsampling"`.
#' * `resampling.custom.train_sets` :: `list()`\cr
#' List with row ids for training, one list element per iteration. Must have the same length as `resampling.custom.test_sets`.
#' Only used for `resampling.method = "custom"`.
#' * `resampling.custom.test_sets` :: `list()`\cr
#' List with row ids for testing, one list element per iteration. Must have the same length as `resampling.custom.train_sets`.
#' Only used for `resampling.method = "custom"`.
#' * `keep_response` :: `logical(1)`\cr
#' Only effective during `"prob"` prediction: Whether to keep response values, if available. Initialized to `FALSE`.
#'
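The aggregation rule documented for `resampling.method` above (mean for numeric prediction columns, mode for a classification `"response"`) can be illustrated on a toy table. This is a minimal base-R sketch, not the code from this diff: the data, the `vote` helper and the `agg` table are hypothetical names introduced only for illustration.

```r
# Hypothetical toy predictions: row 1 is predicted twice, row 2 once, row 3 three times.
prds = data.frame(
  row_id = c(1L, 1L, 2L, 3L, 3L, 3L),
  response = c("a", "b", "b", "a", "a", "b"),
  prob.a = c(0.6, 0.4, 0.2, 0.9, 0.7, 0.2),
  stringsAsFactors = FALSE
)

# Mode of a character vector (illustrative helper).
# Ties resolve to the first entry in table() order, because which.max() returns the first maximum.
vote = function(x) {
  tt = table(x)
  names(tt)[which.max(tt)]
}

# One row per row_id: probabilities are averaged, responses are decided by majority vote.
agg = data.frame(
  row_id = sort(unique(prds$row_id)),
  prob.a = as.numeric(tapply(prds$prob.a, prds$row_id, mean)),
  response = as.character(tapply(prds$response, prds$row_id, vote)),
  stringsAsFactors = FALSE
)
print(agg)
```

Note that the tie for row 1 is broken by ordering rather than at random, which is also how the `which.max(tt)` call in the diff behaves.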
@@ -121,11 +134,15 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV",
task_type = mlr_reflections$task_types[get("type") == private$.learner$task_type][order(get("package"))][1L]$task

private$.crossval_param_set = ParamSet$new(params = list(
ParamFct$new("method", levels = c("cv", "insample"), tags = c("train", "required")),
ParamFct$new("method", levels = c("bootstrap", "custom", "cv", "holdout", "insample", "loo", "repeated_cv", "subsampling"), tags = c("train", "required")),
ParamInt$new("repeats", lower = 1L, tags = c("train", "required")),
ParamInt$new("folds", lower = 2L, upper = Inf, tags = c("train", "required")),
ParamLgl$new("keep_response", tags = c("train", "required"))
ParamDbl$new("ratio", lower = 0, upper = 1, tags = c("train", "required")),
ParamLgl$new("keep_response", tags = c("train", "required")),
ParamUty$new("custom.train_sets", tags = "train", custom_check = function(x) check_list(x, types = "atomicvector", any.missing = FALSE)),
ParamUty$new("custom.test_sets", tags = "train", custom_check = function(x) check_list(x, types = "atomicvector", any.missing = FALSE))
))
private$.crossval_param_set$values = list(method = "cv", folds = 3, keep_response = FALSE)
private$.crossval_param_set$values = list(method = "cv", repeats = 30L, folds = 3L, ratio = 2 / 3, keep_response = FALSE)
private$.crossval_param_set$set_id = "resampling"
# Dependencies in paradox have been broken from the start and this is known since at least a year:
# https://github.com/mlr-org/paradox/issues/216
@@ -169,17 +186,83 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV",
self$state = private$.learner$train(task)$state
pv = private$.crossval_param_set$values

# Compute CV Predictions
if (pv$method != "insample") {
rdesc = mlr_resamplings$get(pv$method)
if (pv$method == "cv") rdesc$param_set$values = list(folds = pv$folds)
rr = resample(task, private$.learner, rdesc)
prds = as.data.table(rr$prediction(predict_sets = "test"))
if (pv$method == "insample") {
return(private$pred_to_task(as.data.table(private$.learner$predict(task)), task)) # early exit
}

# Compute resampled Predictions
rdesc = mlr_resamplings$get(pv$method)
rdesc$param_set$values = switch(pv$method,
"bootstrap" = list(repeats = pv$repeats, ratio = pv$ratio),
"custom" = list(),
"cv" = list(folds = pv$folds),
"holdout" = list(ratio = pv$ratio),
"loo" = list(),
"repeated_cv" = list(repeats = pv$repeats, folds = pv$folds),
"subsampling" = list(repeats = pv$repeats, ratio = pv$ratio))
if (pv$method == "custom") {
rdesc$instantiate(task, train_sets = private$.crossval_param_set$values$custom.train_sets, test_sets = private$.crossval_param_set$values$custom.test_sets)
}
# FIXME: we may want to instantiate here in general for safety reasons
rr = resample(task, private$.learner, rdesc)
prds = as.data.table(rr$prediction(predict_sets = "test"))
nrows_multiple = length(prds$row_id[duplicated(prds$row_id)])
missing_rows = setdiff(task$row_ids, prds$row_id)
nrows_missing = length(missing_rows)

if (!nrows_multiple && !nrows_missing) {
return(private$pred_to_task(prds, task)) # early exit
}

# Some resamplings will result in rows being sampled multiple times and some being missing
task_type = task$task_type
prds_names = colnames(prds)

prds_corrected = if (nrows_multiple) {
# classif: prob, regr: response, (se)
SDcols_multiple = setdiff(prds_names, if (task_type == "classif") c("row_id", "truth", "response") else c("row_id", "truth"))

# aggregation functions:
# - mean for prob, response (regr), se
# - mode for response (classif)
prds_corrected = prds[, map(.SD, function(x) {
if (length(x) == 1L) return(x) # early exit
mean(x, na.rm = TRUE)
}), by = "row_id", .SDcols = SDcols_multiple]

if (NROW(prds_corrected) == 0L) prds_corrected = unique(prds[, "row_id"])

if (task_type == "classif") {
cbind(prds_corrected, prds[, map(.SD, function(x) {
if (length(x) == 1L) return(as.character(x)) # early exit
tt = table(x)
names(tt[which.max(tt)])
}), by = "row_id", .SDcols = "response"][, "response"])
} else {
prds_corrected
}
} else {
prds = as.data.table(private$.learner$predict(task))
if (task_type == "classif") {
prds[, "response" := as.character(response)]
}
prds[, !"truth"]
}

if (nrows_missing) {
SDcols_missing = setdiff(prds_names, "truth")
# add missings
prds_corrected = prds_corrected[, map(.SD, add_missings, len = nrows_missing), .SDcols = SDcols_missing]
prds_corrected$row_id[is.na(prds_corrected$row_id)] = missing_rows
}

private$pred_to_task(prds, task)
if (task_type == "classif") {
target = task$truth(prds_corrected$row_id)
prds_corrected$response = factor(prds_corrected$response, levels = levels(target), ordered = is.ordered(target))
}

# FIXME: safety checks?

private$pred_to_task(prds_corrected, task)
},
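
The bookkeeping that decides whether any correction is needed (`nrows_multiple`, `missing_rows`, `nrows_missing`) can be tried in isolation. The sketch below reuses the same expressions as the diff on hypothetical row ids; `task_row_ids` and `pred_row_ids` are stand-ins for `task$row_ids` and `prds$row_id`.

```r
# Hypothetical ids: the task has rows 1..6, but the resampled test predictions
# cover row 2 twice and never touch rows 1, 3 and 4 (as a bootstrap or holdout might).
task_row_ids = 1:6
pred_row_ids = c(2L, 2L, 5L, 6L)

# Number of surplus predictions (every occurrence after the first counts once).
nrows_multiple = length(pred_row_ids[duplicated(pred_row_ids)])

# Row ids of the task that received no prediction at all.
missing_rows = setdiff(task_row_ids, pred_row_ids)
nrows_missing = length(missing_rows)

# Only when both counts are zero can the predictions be used without correction.
needs_correction = nrows_multiple > 0L || nrows_missing > 0L
```

With `"cv"` or `"loo"` every row is predicted exactly once, so both counts are zero and the early exit applies.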

.predict_task = function(task) {
@@ -204,4 +287,12 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV",
)
)

# Helper function to add missings to predictions based on their storage mode
add_missings = function(x, len) {
c(x, switch(typeof(x),
"character" = rep_len(NA_character_, length.out = len),
"double" = rep_len(NA_real_, length.out = len),
"integer" = rep_len(NA_integer_, length.out = len)))
}

mlr_pipeops$add("learner_cv", PipeOpLearnerCV, list(R6Class("Learner", public = list(id = "learner_cv", task_type = "classif", param_set = ParamSet$new()))$new()))
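
The `add_missings` helper above can be exercised on its own. The sketch below copies its definition from the diff and applies it to hypothetical vectors; the `padded_*` names are illustrative. It also shows one consequence of the `switch()` having no default branch: a storage mode other than character, double or integer is returned unpadded.

```r
# Definition copied from the diff: pad x with `len` NAs of the matching storage mode.
add_missings = function(x, len) {
  c(x, switch(typeof(x),
    "character" = rep_len(NA_character_, length.out = len),
    "double" = rep_len(NA_real_, length.out = len),
    "integer" = rep_len(NA_integer_, length.out = len)))
}

padded_chr = add_missings(c("a", "b"), len = 2L)
padded_dbl = add_missings(c(0.1, 0.9), len = 1L)
padded_int = add_missings(1:3, len = 2L)

# typeof() of a logical vector matches no branch, so switch() yields NULL
# and c(x, NULL) leaves the vector at its original length.
padded_lgl = add_missings(c(TRUE, FALSE), len = 2L)
```

Because `typeof()` of a factor is `"integer"`, a factor passed to this helper would go through the integer branch, which is presumably why the diff converts `response` to character before padding and rebuilds the factor afterwards.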
4 changes: 2 additions & 2 deletions man/Graph.Rd

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion man/mlr_pipeops_histbin.Rd
31 changes: 22 additions & 9 deletions man/mlr_pipeops_learner_cv.Rd
2 changes: 1 addition & 1 deletion man/mlr_pipeops_nmf.Rd
2 changes: 1 addition & 1 deletion man/mlr_pipeops_targetmutate.Rd