#' @title Wrap a Learner into a PipeOp with Cross-validation Plus Confidence Intervals as Predictions
#'
#' @usage NULL
#' @name mlr_pipeops_learner_pi_cvplus
#' @format [`R6Class`][R6::R6Class] object inheriting from [`PipeOp`].
#'
#' @description
#' Wraps an [`mlr3::Learner`] into a [`PipeOp`].
#'
#' Inherits the `$param_set` (and therefore `$param_set$values`) from the [`Learner`][mlr3::Learner] it is constructed from.
#'
#' Using [`PipeOpLearnerPICVPlus`], it is possible to embed an [`mlr3::Learner`] into a [`Graph`].
#' [`PipeOpLearnerPICVPlus`] can then be used to perform cross validation plus (or jackknife plus).
#' During training, [`PipeOpLearnerPICVPlus`] performs cross validation on the training data.
#' During prediction, the models from the training stage are used to construct predictive confidence intervals for the prediction data based on
#' out-of-fold residuals and out-of-fold predictions.
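#'
#' Concretely, for a new observation \eqn{x}, the lower and upper interval bounds are the empirical
#' \eqn{\alpha} and \eqn{1 - \alpha} quantiles of \eqn{\hat{\mu}_{-k(i)}(x) - R_i} and
#' \eqn{\hat{\mu}_{-k(i)}(x) + R_i}, respectively, where \eqn{\hat{\mu}_{-k(i)}} is the model fitted
#' without the fold containing training observation \eqn{i} and \eqn{R_i} is that observation's
#' out-of-fold absolute residual (notation following Barber et al., 2021).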
#'
#' @section Construction:
#' ```
#' PipeOpLearnerPICVPlus$new(learner, id = NULL, param_vals = list())
#' ```
#'
#' * `learner` :: [`LearnerRegr`][mlr3::LearnerRegr]\cr
#'   [`LearnerRegr`][mlr3::LearnerRegr] to use for the cross validation models in the cross validation plus method.
#'   This argument is always cloned; to access the [`Learner`][mlr3::Learner] inside `PipeOpLearnerPICVPlus` by-reference, use `$learner`.
#' * `id` :: `character(1)`\cr
#'   Identifier of the resulting object, internally defaulting to the `id` of the [`Learner`][mlr3::Learner] being wrapped.
#' * `param_vals` :: named `list`\cr
#'   List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction.
#'   Default is `list()`.
#'
#' @section Input and Output Channels:
#' [`PipeOpLearnerPICVPlus`] has one input channel named `"input"`, taking a [`Task`][mlr3::Task] specific to the [`Learner`][mlr3::Learner]
#' type given to `learner` during construction; both during training and prediction.
#'
#' [`PipeOpLearnerPICVPlus`] has one output channel named `"output"`, producing `NULL` during training and a [`PredictionRegr`][mlr3::PredictionRegr]
#' during prediction.
#'
#' The output during prediction is a [`PredictionRegr`][mlr3::PredictionRegr] with `predict_type` `"quantiles"` on the prediction input data.
#' The `quantiles` are the `alpha` and `1 - alpha` quantiles of the prediction interval produced by the cross validation plus method,
#' i.e. the lower and upper interval bounds.
#' The `response` is the median of the predictions of all cross validation models on the prediction data.
#'
#' @section State:
#' The `$state` is a named `list` with members:
#' * `cv_model_states` :: `list`\cr
#'   List of the states of the cross validation models created by the [`Learner`][mlr3::Learner]'s `$.train()` function during resampling with method `"cv"`.
#' * `residuals` :: `data.table`\cr
#'   `data.table` with columns `fold` and `residual`. Lists the absolute regression residuals of the out-of-fold predictions for each observation and cross validation fold.
#'
#' This state is given the class `"pipeop_learner_pi_cvplus_state"`.
#'
#' @section Parameters:
#' The parameters of the [`Learner`][mlr3::Learner] wrapped by this object, as well as:
#' * `folds` :: `integer(1)`\cr
#'   Number of cross validation folds. Initialized to 3.
#' * `alpha` :: `numeric(1)`\cr
#'   Significance level for the cross validation plus prediction intervals: the lower and upper bounds are the `alpha` and `1 - alpha` quantiles. Initialized to 0.05.
#'
#' @section Internals:
#' During training, the wrapped [`Learner`][mlr3::Learner] is resampled with `folds`-fold cross validation and the
#' states of the fold models, together with the out-of-fold absolute residuals, are stored in the `$state`.
#' During prediction, each fold model is restored from its stored state to predict on the new data.
#'
#' @section Fields:
#' Fields inherited from [`PipeOp`], as well as:
#' * `learner` :: [`Learner`][mlr3::Learner]\cr
#'   [`Learner`][mlr3::Learner] that is being wrapped.
#'   Read-only.
#' * `learner_model` :: [`Learner`][mlr3::Learner] or `list`\cr
#'   If the [`PipeOpLearnerPICVPlus`] has been trained, this is a `list` containing the [`Learner`][mlr3::Learner]s of the cross validation models.
#'   Otherwise, this contains the [`Learner`][mlr3::Learner] that is being wrapped.
#'   Read-only.
#' * `predict_type`\cr
#'   Predict type of the [`PipeOpLearnerPICVPlus`], which is always `c("response", "quantiles")`.
#'   This can differ from the predict type of the [`Learner`][mlr3::Learner] that is being wrapped.
#'
#' @section Methods:
#' Methods inherited from [`PipeOp`].
#'
#' @references
#' `r format_bib("barber_2021")`
#'
#' @family PipeOps
#' @family Meta PipeOps
#' @template seealso_pipeopslist
#' @include PipeOp.R
#' @export
#' @examples
#' \dontshow{ if (requireNamespace("rpart")) \{ }
#' library("mlr3")
#'
#' task = tsk("mtcars")
#' learner = lrn("regr.rpart")
#' lrncvplus_po = mlr_pipeops$get("learner_pi_cvplus", learner)
#'
#' lrncvplus_po$train(list(task))
#' lrncvplus_po$predict(list(task))
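#'
#' # The trained state stores the fold model states and the out-of-fold residuals
#' lrncvplus_po$state$residuals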
#' \dontshow{ \} }
PipeOpLearnerPICVPlus = R6Class("PipeOpLearnerPICVPlus",
  inherit = PipeOp,
  public = list(
    initialize = function(learner, id = NULL, param_vals = list()) {
      private$.learner = as_learner(learner, clone = TRUE)
      id = id %??% private$.learner$id
      type = private$.learner$task_type

      if (type != "regr") {
        stop("PipeOpLearnerPICVPlus only supports regression.")
      }

      task_type = mlr_reflections$task_types[type, mult = "first"]$task
      out_type = mlr_reflections$task_types[type, mult = "first"]$prediction

      # requires paradox >= 1.0
      private$.picvplus_param_set = ps(
        folds = p_int(lower = 2L, upper = Inf, tags = c("train", "required")),
        alpha = p_dbl(lower = 0, upper = 1, tags = c("predict", "required"))
      )

      private$.picvplus_param_set$values = list(folds = 3, alpha = 0.05)  # defaults

      super$initialize(id,
        param_set = alist(picvplus = private$.picvplus_param_set, private$.learner$param_set),
        param_vals = param_vals,
        input = data.table(name = "input", train = task_type, predict = task_type),
        output = data.table(name = "output", train = "NULL", predict = out_type),
        packages = learner$packages,
        tags = c("learner", "ensemble")
      )
    }
  ),
  active = list(
    learner = function(val) {
      if (!missing(val)) {
        if (!identical(val, private$.learner)) {
          stop("$learner is read-only.")
        }
      }
      private$.learner
    },

    learner_model = function(val) {
      if (!missing(val)) {
        if (!identical(val, private$.learner)) {
          stop("$learner_model is read-only.")
        }
      }
      if (is.null(self$state) || is_noop(self$state)) {
        private$.learner
      } else {
        multiplicity_recurse(self$state, function(state) {
          map(state$cv_model_states, clone_with_state, learner = private$.learner)
        })
      }
    },

    predict_type = function(val) {
      if (!missing(val)) {
        stop("$predict_type is read-only.")
      }
      mlr_reflections$learner_predict_types$regr$quantiles  # returns c("response", "quantiles")
    }
  ),
  private = list(
    .state_class = "pipeop_learner_pi_cvplus_state",

    .train = function(inputs) {
      task = inputs[[1L]]
      pv = private$.picvplus_param_set$values

      # Compute CV predictions
      rdesc = rsmp("cv", folds = pv$folds)
      rr = resample(task, private$.learner, rdesc, store_models = TRUE)

      prds = rbindlist(map(rr$predictions(predict_sets = "test"), as.data.table), idcol = "fold")

      # Add states of trained models and residuals to PipeOp state
      self$state = list(
        cv_model_states = map(rr$learners, "state"),
        residuals = prds[, .(fold, residual = abs(truth - response))]
      )

      list(NULL)
    },

    .predict = function(inputs) {
      task = inputs[[1L]]
      pv = private$.picvplus_param_set$values

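      # Restore each fold model from its stored state and predict on the new data;
      # mu_hat[[k]] holds the predictions of the model trained without fold k.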
      mu_hat = map(self$state$cv_model_states, function(state) {
        on.exit({private$.learner$state = NULL})
        private$.learner$state = state
        as.data.table(private$.learner$predict(task))
      })

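      # CV+ interval for a single prediction row: pair every out-of-fold residual with the
      # prediction of the model from the same fold, then take the alpha / (1 - alpha) empirical
      # quantiles of prediction -/+ residual as the lower / upper bound.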
      get_quantiles = function(observation) {
        quantiles = pmap_dtr(self$state$residuals, function(fold, residual) {
          list(
            lower = mu_hat[[fold]][observation, response] - residual,
            upper = mu_hat[[fold]][observation, response] + residual
          )
        })
        list(
          q_lower = stats::quantile(quantiles$lower, probs = pv$alpha),
          q_upper = stats::quantile(quantiles$upper, probs = 1 - pv$alpha)
        )
      }

      quantiles = as.matrix(map_dtr(seq_len(task$nrow), get_quantiles))
      quantiles = unname(quantiles)
      attr(quantiles, "probs") = c(pv$alpha, 1 - pv$alpha)

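      # Point prediction: the median over the fold models' predictions for each row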
      response = map_dbl(seq_len(task$nrow), function(observation) {
        stats::quantile(map_dbl(mu_hat, function(fold) fold[observation, response]), probs = 0.5)
      })

      list(PredictionRegr$new(
        row_ids = task$row_ids, truth = task$truth(), response = response, quantiles = quantiles
      ))
    },

    .picvplus_param_set = NULL,
    .learner = NULL,
    .additional_phash_input = function() private$.learner$phash
  )
)

#' @export
marshal_model.pipeop_learner_pi_cvplus_state = function(model, inplace = FALSE, ...) {
  # Note that a Learner state contains other reference objects, but we don't clone them here, even when
  # inplace is FALSE. For our use-case this is just not necessary and would cause unnecessary overhead in
  # the mlr3 workhorse function.
  model$cv_model_states = map(model$cv_model_states, marshal_model, inplace = inplace)
  # Only wrap this in a marshaled class if the model was actually marshaled above
  # (the default marshal method does nothing).
  if (some(model$cv_model_states, is_marshaled_model)) {
    model = structure(
      list(marshaled = model, packages = "mlr3pipelines"),
      class = c(paste0(class(model), "_marshaled"), "marshaled")
    )
  }
  model
}

#' @export
unmarshal_model.pipeop_learner_pi_cvplus_state_marshaled = function(model, inplace = FALSE, ...) {
  state_marshaled = model$marshaled
  state_marshaled$cv_model_states = map(state_marshaled$cv_model_states, unmarshal_model, inplace = inplace)
  state_marshaled
}
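
# A minimal usage sketch (hypothetical object name `po_pi`, assuming a trained PipeOpLearnerPICVPlus):
# marshaling converts the stored fold model states into a serializable form, e.g. before the state is
# sent to another process, and unmarshaling restores the original state list:
#   marshaled = marshal_model(po_pi$state, inplace = FALSE)
#   state = unmarshal_model(marshaled)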

mlr_pipeops$add("learner_pi_cvplus", PipeOpLearnerPICVPlus, list(R6Class("Learner", public = list(id = "learner_pi_cvplus", task_type = "regr", param_set = ps(), packages = "mlr3pipelines"))$new()))