|
| 1 | +#' @title Wrap a Learner into a PipeOp to predict multiple Quantiles |
| 2 | +#' |
| 3 | +#' @usage NULL |
| 4 | +#' @name mlr_pipeops_learner_quantiles |
| 5 | +#' @format [`R6Class`][R6::R6Class] object inheriting from [`PipeOp`]. |
| 6 | +#' |
| 7 | +#' @description |
| 8 | +#' Wraps a [`LearnerRegr`][mlr3::LearnerRegr] into a [`PipeOp`] to predict multiple quantiles. |
| 9 | +#' |
| 10 | +#' [`PipeOpLearnerQuantiles`] only supports [`LearnerRegr`][mlr3::LearnerRegr]s that have `quantiles` as a possible `predict_type`. |
| 11 | +#' |
| 12 | +#' It produces quantile-based predictions for multiple quantiles in one [`PredictionRegr`][mlr3::Prediction]. This is especially helpful if the [`LearnerRegr`][mlr3::LearnerRegr] can only predict one quantile (like for example [`LearnerRegrGBM`][mlr3extralearners::LearnerRegrGBM]). |
| 13 | +#' |
| 14 | +#' Inherits the `$param_set` (and therefore `$param_set$values`) from the [`Learner`][mlr3::Learner] it is constructed from. |
| 15 | +#' |
| 16 | +#' @section Construction: |
| 17 | +#' ``` |
| 18 | +#' PipeOpLearnerQuantiles$new(learner, id = NULL, param_vals = list()) |
| 19 | +#' ``` |
| 20 | +#' |
| 21 | +#' * `learner` :: [`Learner`][mlr3::Learner] | `character(1)`\cr |
| 22 | +#' [`Learner`][mlr3::Learner] to wrap, or a string identifying a [`Learner`][mlr3::Learner] in the [`mlr3::mlr_learners`] [`Dictionary`][mlr3misc::Dictionary]. |
| 23 | +#' The [`Learner`][mlr3::Learner] has to be a [`LearnerRegr`][mlr3::LearnerRegr] with `predict_type` `"quantiles"`. |
| 24 | +#' This argument is always cloned; to access the [`Learner`][mlr3::Learner] inside `PipeOpLearnerQuantiles` by-reference, use `$learner`. |
| 25 | +#' * `id` :: `character(1)`\cr |
| 26 | +#' Identifier of the resulting object, internally defaulting to the `id` of the [`Learner`][mlr3::Learner] being wrapped. |
| 27 | +#' * `param_vals` :: named `list`\cr |
| 28 | +#' List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction. Default `list()`. |
| 29 | +#' |
| 30 | +#' @section Input and Output Channels: |
| 31 | +#' [`PipeOpLearnerQuantiles`] has one input channel named `"input"`, taking a [`TaskRegr`][mlr3::TaskRegr] specific to the [`Learner`][mlr3::Learner] |
| 32 | +#' type given to `learner` during construction; both during training and prediction. |
| 33 | +#' |
| 34 | +#' [`PipeOpLearnerQuantiles`] has one output channel named `"output"`, producing `NULL` during training and a [`PredictionRegr`][mlr3::Prediction] object |
| 35 | +#' during prediction. |
| 36 | +#' |
| 37 | +#' The output during prediction is a [`PredictionRegr`][mlr3::PredictionRegr] on the prediction input data that aggregates all `result`s produced by the [`Learner`][mlr3::Learner] |
| 38 | +#' trained on the training input data, one for each quantile in `q_vals`. |
| 39 | +#' |
| 40 | +#' @section State: |
| 41 | +#' The `$state` is set during training. It is a named `list` with the member: |
| 42 | +#' * `model_states` :: `list`\cr |
| 43 | +#' List of the states of all models created by the [`Learner`][mlr3::Learner]'s `$.train()` function. |
| 44 | +#' |
| 45 | +#' @section Parameters: |
| 46 | +#' The parameters are the parameters of the [`Learner`][mlr3::Learner] wrapped by this object, as well as: |
| 47 | +#' * `q_vals` :: `numeric`\cr |
| 48 | +#' Quantiles to use for training and prediction. |
| 49 | +#' Initialized to `c(0.05, 0.5, 0.95)` |
| 50 | +#' |
| 51 | +#' * `q_response` :: `numeric(1)`\cr |
| 52 | +#' Which quantile in `q_vals` to use as a `response` for the [`PredictionRegr`][mlr3::PredictionRegr] during prediction. |
| 53 | +#' Initialized to `0.5`. |
| 54 | +#' |
| 55 | +#' @section Internals: |
| 56 | +#' The `$state` is updated during training. |
| 57 | +#' |
| 58 | +#' @section Fields: |
| 59 | +#' Fields inherited from [`PipeOp`], as well as: |
| 60 | +#' * `learner` :: [`LearnerRegr`][mlr3::LearnerRegr]\cr |
| 61 | +#' [`Learner`][mlr3::Learner] that is being wrapped. Read-only. |
| 62 | +#' * `learner_model` :: [`Learner`][mlr3::Learner]\cr |
| 63 | +#' If [`PipeOpLearnerQuantiles`] has been trained, this is a `list` containing the [`Learner`][mlr3::Learner]s for each quantile. |
| 64 | +#' Otherwise, this contains the [`Learner`][mlr3::Learner] that is being wrapped. |
| 65 | +#' Read-only. |
| 66 | +#' * `predict_type` :: `character(1)`\cr |
| 67 | +#' Predict type of the [`PipeOpLearnerQuantiles`], which is always `c("response", "quantiles")`. |
| 68 | +#' |
| 69 | +#' @section Methods: |
| 70 | +#' Methods inherited from [`PipeOp`]. |
| 71 | +#' |
| 72 | +#' @family PipeOps |
| 73 | +#' @family Meta PipeOps |
| 74 | +#' @template seealso_pipeopslist |
| 75 | +#' @include PipeOp.R |
| 76 | +#' @export |
| 77 | +#' @examples |
| 78 | +#' library("mlr3") |
| 79 | +#' |
| 80 | +#' task = tsk("boston_housing") |
| 81 | +#' learner = lrn("regr.debug") |
| 82 | +#' po = mlr_pipeops$get("learner_quantiles", learner) |
| 83 | +#' |
| 84 | +#' po$train(list(task)) |
| 85 | +#' po$predict(list(task)) |
PipeOpLearnerQuantiles = R6Class("PipeOpLearnerQuantiles",
  inherit = PipeOp,
  public = list(
    # Build the PipeOp around a cloned regression learner.
    #
    # learner    : Learner object or character key; converted (and cloned) via as_learner().
    # id         : identifier of the PipeOp; falls back to the id of the wrapped learner.
    # param_vals : named list of hyperparameter values forwarded to PipeOp's initializer.
    #
    # Errors if the wrapped learner's task type is not "regr".
    initialize = function(learner, id = NULL, param_vals = list()) {
      private$.learner = as_learner(learner, clone = TRUE)
      id = id %??% private$.learner$id
      type = private$.learner$task_type

      if ("regr" != type) {
        stop("PipeOpLearnerQuantiles only supports regression.")
      }

      task_type = mlr_reflections$task_types[type, mult = "first"]$task
      out_type = mlr_reflections$task_types[type, mult = "first"]$prediction

      # paradox requirements 1.0
      private$.quantiles_param_set = ps(
        q_vals = p_uty(custom_check = crate(function(x) {
          checkmate::check_numeric(x, lower = 0L, upper = 1L, any.missing = FALSE, min.len = 1L, sorted = TRUE)
        }), tags = c("train", "predict", "required")),
        q_response = p_dbl(lower = 0L, upper = 1L, tags = c("train", "predict", "required"))
      )

      private$.quantiles_param_set$values = list(q_vals = c(0.05, 0.5, 0.95), q_response = 0.5)  # default

      super$initialize(id, param_set = alist(quantiles = private$.quantiles_param_set, private$.learner$param_set),
        param_vals = param_vals,
        input = data.table(name = "input", train = task_type, predict = task_type),
        output = data.table(name = "output", train = "NULL", predict = out_type),
        # Read packages from the converted learner: the `learner` argument may be a
        # character key, on which `$packages` would fail.
        packages = private$.learner$packages, tags = c("learner", "ensemble")
      )
    }
  ),
  active = list(
    # Read-only access to the wrapped learner; assigning anything other than the
    # identical object is an error.
    learner = function(val) {
      if (!missing(val)) {
        if (!identical(val, private$.learner)) {
          stop("$learner is read-only.")
        }
      }
      private$.learner
    },

    # Untrained (or NO_OP state): the wrapped learner itself.
    # Trained: one clone of the wrapped learner per stored model state.
    learner_model = function(val) {
      if (!missing(val)) {
        if (!identical(val, private$.learner)) {
          stop("$learner_model is read-only.")
        }
      }
      if (is.null(self$state) || is_noop(self$state)) {
        private$.learner
      } else {
        multiplicity_recurse(self$state, function(state) {
          map(state$model_states, clone_with_state, learner = private$.learner)
        })
      }
    },

    # Read-only; looked up from mlr_reflections rather than from the wrapped learner.
    predict_type = function(val) {
      if (!missing(val)) {
        stop("$predict_type is read-only.")
      }
      mlr_reflections$learner_predict_types$regr$quantiles  # Returns c("response", "quantiles")
    }
  ),
  private = list(
    .state_class = "pipeop_learner_quantiles_state",

    # Train the wrapped learner once per entry of `q_vals` and keep every resulting state.
    .train = function(inputs) {
      task = inputs[[1L]]
      pv = private$.quantiles_param_set$values

      # the response quantile has to be one of the trained quantiles
      assert_subset(pv$q_response, pv$q_vals, empty.ok = FALSE)
      if ("quantiles" %nin% private$.learner$predict_types) {
        stopf("Learner needs to be able to predict quantiles.")
      }
      private$.learner$predict_type = "quantiles"

      # train learner on all quantiles in q_vals
      states = map(pv$q_vals, function(quantile) {
        # never leave a trained state behind on the shared wrapped learner
        on.exit({private$.learner$state = NULL})
        private$.learner$quantiles = quantile
        private$.learner$train(task)$state
      })

      # add states of trained models to PipeOp state
      self$state = list(model_states = states)

      list(NULL)
    },

    # Predict with every stored state and bind the per-quantile responses into one
    # quantile prediction.
    .predict = function(inputs) {
      task = inputs[[1L]]
      pv = private$.quantiles_param_set$values

      prds = pmap(list(self$state$model_states, pv$q_vals), function(state, quantile) {
        # restore the stateless learner once this quantile has been predicted
        on.exit({private$.learner$state = NULL})
        private$.learner$state = state
        private$.learner$quantiles = quantile
        as.data.table(private$.learner$predict(task))
      })

      quantiles = as.matrix(map_dtc(prds, "response"))
      # Drop the column names inherited from the per-quantile tables; the quantile
      # levels are carried by the "probs" attribute set below. The result of unname()
      # must be assigned, otherwise the call has no effect.
      quantiles = unname(quantiles)
      attr(quantiles, "probs") = pv$q_vals
      attr(quantiles, "response") = pv$q_response

      # return quantile PredictionRegr with all requested quantiles
      list(as_prediction(as_prediction_data(list(quantiles = quantiles), task = task)))
    },

    .quantiles_param_set = NULL,
    .learner = NULL,
    .additional_phash_input = function() private$.learner$phash
  )
)
| 201 | + |
#' @export
marshal_model.pipeop_learner_quantiles_state = function(model, inplace = FALSE, ...) {
  # The learner states hold further reference objects. They are deliberately not cloned,
  # not even for inplace = FALSE: it is not needed for this use-case and would only add
  # overhead to the mlr3 workhorse function.
  model$model_states = map(model$model_states, marshal_model, inplace = inplace)

  # The default marshal method leaves a state untouched, so when nothing was actually
  # marshaled the state is handed back without a wrapper.
  if (!some(model$model_states, is_marshaled_model)) {
    return(model)
  }

  marshaled_classes = c(paste0(class(model), "_marshaled"), "marshaled")
  structure(
    list(marshaled = model, packages = "mlr3pipelines"),
    class = marshaled_classes
  )
}
| 218 | + |
#' @export
unmarshal_model.pipeop_learner_quantiles_state_marshaled = function(model, inplace = FALSE, ...) {
  # Unwrap the stored state, then unmarshal each per-quantile learner state in it.
  restored = model$marshaled
  restored$model_states = map(restored$model_states, unmarshal_model, inplace = inplace)
  restored
}
| 225 | + |
| 226 | + |
# Register the PipeOp in the dictionary. The third argument supplies a minimal stand-in
# "Learner" object as the default construction argument.
mlr_pipeops$add(
  "learner_quantiles",
  PipeOpLearnerQuantiles,
  list(
    R6Class("Learner",
      public = list(
        id = "learner_quantiles",
        task_type = "regr",
        param_set = ps(),
        packages = "mlr3pipelines"
      )
    )$new()
  )
)
0 commit comments