Skip to content

Commit ddf85d7

Browse files
authored
Merge pull request #838 from mlr-org/pipeop_learnercvplus
PipeOpLearnerCVPlus
2 parents dbec945 + f00027e commit ddf85d7

File tree

87 files changed

+757
-3
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

87 files changed

+757
-3
lines changed

DESCRIPTION

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@ Collate:
161161
'PipeOpKernelPCA.R'
162162
'PipeOpLearner.R'
163163
'PipeOpLearnerCV.R'
164+
'PipeOpLearnerPICVPlus.R'
164165
'PipeOpLearnerQuantiles.R'
165166
'PipeOpMissingIndicators.R'
166167
'PipeOpModelMatrix.R'

NAMESPACE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ S3method(marshal_model,Multiplicity)
5353
S3method(marshal_model,graph_learner_model)
5454
S3method(marshal_model,pipeop_impute_learner_state)
5555
S3method(marshal_model,pipeop_learner_cv_state)
56+
S3method(marshal_model,pipeop_learner_pi_cvplus_state)
5657
S3method(marshal_model,pipeop_learner_quantiles_state)
5758
S3method(po,"NULL")
5859
S3method(po,Filter)
@@ -77,6 +78,7 @@ S3method(unmarshal_model,Multiplicity_marshaled)
7778
S3method(unmarshal_model,graph_learner_model_marshaled)
7879
S3method(unmarshal_model,pipeop_impute_learner_state_marshaled)
7980
S3method(unmarshal_model,pipeop_learner_cv_state_marshaled)
81+
S3method(unmarshal_model,pipeop_learner_pi_cvplus_state_marshaled)
8082
S3method(unmarshal_model,pipeop_learner_quantiles_state_marshaled)
8183
export("%>>!%")
8284
export("%>>%")
@@ -127,6 +129,7 @@ export(PipeOpImputeSample)
127129
export(PipeOpKernelPCA)
128130
export(PipeOpLearner)
129131
export(PipeOpLearnerCV)
132+
export(PipeOpLearnerPICVPlus)
130133
export(PipeOpLearnerQuantiles)
131134
export(PipeOpMissInd)
132135
export(PipeOpModelMatrix)

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# mlr3pipelines 0.7.0-9000
22

33
* New down-sampling PipeOps for inbalanced data: `PipeOpTomek` / `po("tomek")` and `PipeOpNearmiss` / `po("nearmiss")`
4+
* New PipeOp `PipeOpLearnerPICVPlus / po("learner_pi_cvplus")`
45
* New PipeOp for Quantile Regression `PipeOpLearnerQuantiles` / `po(learner_quantiles)`
56
* `GraphLearner` has new active bindings/methods as shortcuts for active bindings/methods of the underlying `Graph`:
67
`$pipeops`, `$edges`, `$pipeops_param_set`, and `$pipeops_param_set_values` as well as `$ids()` and `$plot()`.

R/PipeOpLearnerPICVPlus.R

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
#' @title Wrap a Learner into a PipeOp with Cross-validation Plus Confidence Intervals as Predictions
2+
#'
3+
#' @usage NULL
4+
#' @name mlr_pipeops_learner_pi_cvplus
5+
#' @format [`R6Class`][R6::R6Class] object inheriting from [`PipeOp`].
6+
#'
7+
#' @description
8+
#' Wraps an [`mlr3::Learner`] into a [`PipeOp`].
9+
#'
10+
#' Inherits the `$param_set` (and therefore `$param_set$values`) from the [`Learner`][mlr3::Learner] it is constructed from.
11+
#'
12+
#' Using [`PipeOpLearnerPICVPlus`], it is possible to embed a [`mlr3::Learner`] into a [`Graph`].
13+
#' [`PipeOpLearnerPICVPlus`] can then be used to perform cross validation plus (or jackknife plus).
14+
#' During training, [`PipeOpLearnerPICVPlus`] performs cross validation on the training data.
15+
#' During prediction, the models from the training stage are used to construct predictive confidence intervals for the prediction data based on
16+
#' out-of-fold residuals and out-of-fold predictions.
17+
#'
18+
#' @section Construction:
19+
#' ```
20+
#' PipeOpLearnerPICVPlus$new(learner, id = NULL, param_vals = list())
21+
#' ```
22+
#'
23+
#' * `learner` :: [`LearnerRegr`][mlr3::LearnerRegr]
24+
#' [`LearnerRegr`][mlr3::LearnerRegr] to use for the cross validation models in the Cross Validation Plus method.
25+
#' This argument is always cloned; to access the [`Learner`][mlr3::Learner] inside `PipeOpLearnerPICVPlus` by-reference, use `$learner`.\cr
26+
#' * `id` :: `character(1)`
27+
#' Identifier of the resulting object, internally defaulting to the `id` of the [`Learner`][mlr3::Learner] being wrapped.
28+
#' * `param_vals` :: named `list`\cr
29+
#' List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction.
30+
#' Default is `list()`.
31+
#'
32+
#' @section Input and Output Channels:
33+
#' [`PipeOpLearnerPICVPlus`] has one input channel named `"input"`, taking a [`Task`][mlr3::Task] specific to the [`Learner`][mlr3::Learner]
34+
#' type given to `learner` during construction; both during training and prediction.
35+
#'
36+
#' [`PipeOpLearnerPICVPlus`] has one output channel named `"output"`, producing `NULL` during training and a [`PredictionRegr`][mlr3::PredictionRegr]
37+
#' during prediction.
38+
#'
39+
#' The output during prediction is a [`PredictionRegr`][mlr3::PredictionRegr] with `predict_type` `quantiles` on the prediction input data.
40+
#' The `alpha` and `1 - alpha` quantiles are the `quantiles` of the prediction interval produced by the cross validation plus method.
41+
#' The `response` is the median of the prediction of all cross validation models on the prediction data.
42+
#'
43+
#' @section State:
44+
#' The `$state` is a named `list` with members:
45+
#' * `cv_model_states` :: `list`\cr
46+
#' List of the state of each cross validation model created by the [`Learner`][`mlr3::Learner`]'s `$.train()` function during resampling with method `"cv"`.
47+
#' * `residuals` :: `data.table`\cr
48+
#' `data.table` with columns `fold` and `residual`. Lists the Regression residuals for each observation and cross validation fold.
49+
#'
50+
#' This state is given the class `"pipeop_learner_cv_state"`.
51+
#'
52+
#' @section Parameters:
53+
#' The parameters of the [`Learner`][mlr3::Learner] wrapped by this object, as well as:
54+
#' * `folds` :: `numeric(1)`\cr
55+
#' Number of cross validation folds. Initialized to 3.
56+
#' * `alpha` :: `numeric(1)`\cr
57+
#' Quantile to use for the cross validation plus prediction intervals. Initialized to 0.05.
58+
#'
59+
#' @section Internals:
60+
#' The `$state` is updated during training.
61+
#'
62+
#' @section Fields:
63+
#' Fields inherited from [`PipeOp`], as well as:
64+
#' * `learner` :: [`Learner`][mlr3::Learner]\cr
65+
#' [`Learner`][mlr3::Learner] that is being wrapped.
66+
#' Read-only.
67+
#' * `learner_model` :: [`Learner`][mlr3::Learner] or `list`\cr
68+
#' If the [`PipeOpLearnerPICVPlus`] has been trained, this is a `list` containing the [`Learner`][mlr3::Learner]s of the cross validation models.
69+
#' Otherwise, this contains the [`Learner`][mlr3::Learner] that is being wrapped.
70+
#' Read-only.
71+
#' * `predict_type`\cr
72+
#' Predict type of the [`PipeOpLearnerPICVPlus`], which is always `"response" "quantiles"`.
73+
#' This can be different to the predict type of the [`Learner`][mlr3::Learner] that is being wrapped.
74+
#'
75+
#' @section Methods:
76+
#' Methods inherited from [`PipeOp`].
77+
#'
78+
#' @references
79+
#' `r format_bib("barber_2021")`
80+
#'
81+
#' @family PipeOps
82+
#' @family Meta PipeOps
83+
#' @template seealso_pipeopslist
84+
#' @include PipeOp.R
85+
#' @export
86+
#' @examples
87+
#' \dontshow{ if (requireNamespace("rpart")) \{ }
88+
#' library("mlr3")
89+
#'
90+
#' task = tsk("mtcars")
91+
#' learner = lrn("regr.rpart")
92+
#' lrncvplus_po = mlr_pipeops$get("learner_pi_cvplus", learner)
93+
#'
94+
#' lrncvplus_po$train(list(task))
95+
#' lrncvplus_po$predict(list(task))
96+
#' \dontshow{ \} }
97+
PipeOpLearnerPICVPlus = R6Class("PipeOpLearnerPICVPlus",
98+
inherit = PipeOp,
99+
public = list(
100+
initialize = function(learner, id = NULL, param_vals = list()) {
101+
private$.learner = as_learner(learner, clone = TRUE)
102+
id = id %??% private$.learner$id
103+
type = private$.learner$task_type
104+
105+
if ("regr" != type) {
106+
stop("PipeOpLearnerPICVPlus only supports regression.")
107+
}
108+
109+
task_type = mlr_reflections$task_types[type, mult = "first"]$task
110+
out_type = mlr_reflections$task_types[type, mult = "first"]$prediction
111+
112+
# paradox requirements 1.0
113+
private$.picvplus_param_set = ps(
114+
folds = p_int(lower = 2L, upper = Inf, tags = c("train", "required")),
115+
alpha = p_dbl(lower = 0L, upper = 1L, tags = c("predict", "required"))
116+
)
117+
118+
private$.picvplus_param_set$values = list(folds = 3, alpha = 0.05) # default
119+
120+
super$initialize(id, param_set = alist(picvplus = private$.picvplus_param_set, private$.learner$param_set),
121+
param_vals = param_vals,
122+
input = data.table(name = "input", train = task_type, predict = task_type),
123+
output = data.table(name = "output", train = "NULL", predict = out_type),
124+
packages = learner$packages,
125+
tags = c("learner", "ensemble")
126+
)
127+
}
128+
),
129+
active = list(
130+
learner = function(val) {
131+
if (!missing(val)) {
132+
if (!identical(val, private$.learner)) {
133+
stop("$learner is read-only.")
134+
}
135+
}
136+
private$.learner
137+
},
138+
139+
learner_model = function(val) {
140+
if (!missing(val)) {
141+
if (!identical(val, private$.learner)) {
142+
stop("$learner_model is read-only.")
143+
}
144+
}
145+
if (is.null(self$state) || is_noop(self$state)) {
146+
private$.learner
147+
} else {
148+
multiplicity_recurse(self$state, function(state) {
149+
map(state$cv_model_states, clone_with_state, learner = private$.learner)
150+
})
151+
}
152+
},
153+
predict_type = function(val) {
154+
if (!missing(val)) {
155+
stop("$predict_type is read-only.")
156+
}
157+
mlr_reflections$learner_predict_types$regr$quantiles # Returns c("response", "quantiles")
158+
}
159+
),
160+
private = list(
161+
.state_class = "pipeop_learner_pi_cvplus_state",
162+
163+
.train = function(inputs) {
164+
task = inputs[[1L]]
165+
pv = private$.picvplus_param_set$values
166+
167+
# Compute CV Predictions
168+
rdesc = rsmp("cv", folds = pv$folds)
169+
rr = resample(task, private$.learner, rdesc, store_models = TRUE)
170+
171+
prds = rbindlist(map(rr$predictions(predict_sets = "test"), as.data.table), idcol = "fold")
172+
173+
# Add states of trained models and residuals to PipeOp state
174+
self$state = list(cv_model_states = map(rr$learners, "state"),
175+
residuals = prds[, .(fold, residual = abs(truth - response))])
176+
177+
list(NULL)
178+
},
179+
180+
.predict = function(inputs) {
181+
task = inputs[[1L]]
182+
pv = private$.picvplus_param_set$values
183+
184+
mu_hat = map(self$state$cv_model_states, function(state) {
185+
on.exit({private$.learner$state = NULL})
186+
private$.learner$state = state
187+
as.data.table(private$.learner$predict(task))
188+
})
189+
190+
get_quantiles = function(observation) {
191+
quantiles = pmap_dtr(self$state$residuals, function(fold, residual) {
192+
list(lower = mu_hat[[fold]][observation, response] - residual,
193+
upper = mu_hat[[fold]][observation, response] + residual)
194+
})
195+
list(q_lower = stats::quantile(quantiles$lower, probs = pv$alpha),
196+
q_upper = stats::quantile(quantiles$upper, probs = 1 - pv$alpha))
197+
}
198+
199+
quantiles = as.matrix(map_dtr(seq_len(task$nrow), get_quantiles))
200+
quantiles = unname(quantiles)
201+
attr(quantiles, "probs") = c(pv$alpha, 1 - pv$alpha)
202+
203+
response = map_dbl(seq_len(task$nrow), function(observation) {
204+
stats::quantile(map_dbl(mu_hat, function(fold) {fold[observation, response]}), probs = 0.5)
205+
})
206+
207+
list(PredictionRegr$new(
208+
row_ids = task$row_ids, truth = task$truth(),response = response, quantiles = quantiles
209+
))
210+
},
211+
212+
.picvplus_param_set = NULL,
213+
.learner = NULL,
214+
.additional_phash_input = function() private$.learner$phash
215+
)
216+
)
217+
218+
#' @export
219+
marshal_model.pipeop_learner_pi_cvplus_state = function(model, inplace = FALSE, ...) {
220+
# Note that a Learner state contains other reference objects, but we don't clone them here, even when inplace
221+
# is FALSE. For our use-case this is just not necessary and would cause unnecessary overhead in the mlr3
222+
# workhorse function
223+
model$cv_model_states = map(model$cv_model_states, marshal_model, inplace = inplace)
224+
# only wrap this in a marshaled class if the model was actually marshaled above
225+
# (the default marshal method does nothing)
226+
if (some(model$cv_model_states, is_marshaled_model)) {
227+
model = structure(
228+
list(marshaled = model, packages = "mlr3pipelines"),
229+
class = c(paste0(class(model), "_marshaled"), "marshaled")
230+
)
231+
}
232+
model
233+
}
234+
235+
#' @export
236+
unmarshal_model.pipeop_learner_pi_cvplus_state_marshaled = function(model, inplace = FALSE, ...) {
237+
state_marshaled = model$marshaled
238+
state_marshaled$cv_model_states = map(state_marshaled$cv_model_states, unmarshal_model, inplace = inplace)
239+
state_marshaled
240+
}
241+
242+
mlr_pipeops$add("learner_pi_cvplus", PipeOpLearnerPICVPlus, list(R6Class("Learner", public = list(id = "learner_pi_cvplus", task_type = "regr", param_set = ps(), packages = "mlr3pipelines"))$new()))
243+
244+

R/bibentries.R

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,16 @@
11
#' @importFrom utils bibentry
22
bibentries = c(
3+
barber_2021 = bibentry("article",
4+
doi = "10.1214/20-AOS1965",
5+
year = "2021",
6+
month = "02",
7+
volume = "49",
8+
pages = "486--507",
9+
author = "Rina Foygel Barber and Emmanuel J. Candes and Aaditya Ramdasa and Ryan J. Tibshirani",
10+
title = "Predictive inference with the jackknife+",
11+
journal = "Annals of Statistics"
12+
),
13+
314
chawla_2002 = bibentry("article",
415
doi = "10.1613/jair.953",
516
year = "2002",

man/PipeOp.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/PipeOpEnsemble.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/PipeOpImpute.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/PipeOpTargetTrafo.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/PipeOpTaskPreproc.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)