Commit dbec945

Merge pull request #841 from mlr-org/pipeop_quantiles
PipeOpLearnerQuantiles for Quantile Regression
2 parents 750636f + 9e8947d commit dbec945

86 files changed: +723 -4 lines changed

DESCRIPTION

Lines changed: 6 additions & 1 deletion
@@ -43,7 +43,11 @@ Authors@R:
            family = "Mücke",
            role = "ctb",
            email = "[email protected]",
-           comment = c(ORCID = "0009-0000-9432-9795")))
+           comment = c(ORCID = "0009-0000-9432-9795")),
+    person(given = "Lona",
+           family = "Koers",
+           role = "ctb",
+           email = "[email protected]"))
 Description: Dataflow programming toolkit that enriches 'mlr3' with a diverse
     set of pipelining operators ('PipeOps') that can be composed into graphs.
     Operations exist for data preprocessing, model fitting, and ensemble
@@ -157,6 +161,7 @@ Collate:
     'PipeOpKernelPCA.R'
     'PipeOpLearner.R'
     'PipeOpLearnerCV.R'
+    'PipeOpLearnerQuantiles.R'
     'PipeOpMissingIndicators.R'
     'PipeOpModelMatrix.R'
     'PipeOpMultiplicity.R'

NAMESPACE

Lines changed: 3 additions & 0 deletions
@@ -53,6 +53,7 @@ S3method(marshal_model,Multiplicity)
 S3method(marshal_model,graph_learner_model)
 S3method(marshal_model,pipeop_impute_learner_state)
 S3method(marshal_model,pipeop_learner_cv_state)
+S3method(marshal_model,pipeop_learner_quantiles_state)
 S3method(po,"NULL")
 S3method(po,Filter)
 S3method(po,Learner)
@@ -76,6 +77,7 @@ S3method(unmarshal_model,Multiplicity_marshaled)
 S3method(unmarshal_model,graph_learner_model_marshaled)
 S3method(unmarshal_model,pipeop_impute_learner_state_marshaled)
 S3method(unmarshal_model,pipeop_learner_cv_state_marshaled)
+S3method(unmarshal_model,pipeop_learner_quantiles_state_marshaled)
 export("%>>!%")
 export("%>>%")
 export("%among%")
@@ -125,6 +127,7 @@ export(PipeOpImputeSample)
 export(PipeOpKernelPCA)
 export(PipeOpLearner)
 export(PipeOpLearnerCV)
+export(PipeOpLearnerQuantiles)
 export(PipeOpMissInd)
 export(PipeOpModelMatrix)
 export(PipeOpMultiplicityExply)

NEWS.md

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 # mlr3pipelines 0.7.0-9000
 
 * New down-sampling PipeOps for inbalanced data: `PipeOpTomek` / `po("tomek")` and `PipeOpNearmiss` / `po("nearmiss")`
+* New PipeOp for Quantile Regression: `PipeOpLearnerQuantiles` / `po("learner_quantiles")`
 * `GraphLearner` has new active bindings/methods as shortcuts for active bindings/methods of the underlying `Graph`:
   `$pipeops`, `$edges`, `$pipeops_param_set`, and `$pipeops_param_set_values` as well as `$ids()` and `$plot()`.
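
To make the new NEWS entry concrete, the following is a minimal base-R sketch of the idea behind predicting several quantiles in one prediction: validate the requested quantiles, predict each one separately, then bind the per-quantile predictions into one matrix. It does not use mlr3pipelines. `predict_one_quantile()` and `predict_quantiles()` are hypothetical helpers, and an empirical-quantile "model" is assumed in place of a real learner. The argument names `q_vals` and `q_response` and the `probs` / `response` attributes mirror those in `R/PipeOpLearnerQuantiles.R` from this commit.

```r
# Hypothetical stand-in for a learner that can predict only one quantile:
# it ignores features and returns the empirical quantile of the training target.
predict_one_quantile = function(y_train, n_new, prob) {
  rep(unname(stats::quantile(y_train, probs = prob)), n_new)
}

# Validate the requested quantiles, predict each one separately, and combine
# the results into a single matrix with one column per quantile.
predict_quantiles = function(y_train, n_new, q_vals = c(0.05, 0.5, 0.95), q_response = 0.5) {
  stopifnot(
    is.numeric(q_vals), length(q_vals) >= 1L, !anyNA(q_vals),
    all(q_vals >= 0 & q_vals <= 1), !is.unsorted(q_vals),
    length(q_response) == 1L, q_response %in% q_vals
  )
  cols = lapply(q_vals, function(p) predict_one_quantile(y_train, n_new, p))
  quantiles = do.call(cbind, cols)
  attr(quantiles, "probs") = q_vals          # which quantile each column holds
  attr(quantiles, "response") = q_response   # which quantile acts as point prediction
  quantiles
}

res = predict_quantiles(y_train = c(1, 2, 3, 4, 5), n_new = 2L)
dim(res)            # rows: new observations, columns: requested quantiles
attr(res, "probs")  # the requested quantiles
```

Requiring `q_response` to be one of `q_vals` corresponds to the `assert_subset()` check in `.train()`: the point prediction is read from one of the already computed quantile columns rather than from a separate model.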

R/PipeOpLearnerQuantiles.R

Lines changed: 227 additions & 0 deletions
@@ -0,0 +1,227 @@
+#' @title Wrap a Learner into a PipeOp to predict multiple Quantiles
+#'
+#' @usage NULL
+#' @name mlr_pipeops_learner_quantiles
+#' @format [`R6Class`][R6::R6Class] object inheriting from [`PipeOp`].
+#'
+#' @description
+#' Wraps a [`LearnerRegr`][mlr3::LearnerRegr] into a [`PipeOp`] to predict multiple quantiles.
+#'
+#' [`PipeOpLearnerQuantiles`] only supports [`LearnerRegr`][mlr3::LearnerRegr]s that have `quantiles` as a possible `predict_type`.
+#'
+#' It produces quantile-based predictions for multiple quantiles in one [`PredictionRegr`][mlr3::Prediction]. This is especially helpful if the [`LearnerRegr`][mlr3::LearnerRegr] can only predict one quantile (for example [`LearnerRegrGBM`][mlr3extralearners::LearnerRegrGBM]).
+#'
+#' Inherits the `$param_set` (and therefore `$param_set$values`) from the [`Learner`][mlr3::Learner] it is constructed from.
+#'
+#' @section Construction:
+#' ```
+#' PipeOpLearnerQuantiles$new(learner, id = NULL, param_vals = list())
+#' ```
+#'
+#' * `learner` :: [`Learner`][mlr3::Learner] | `character(1)`\cr
+#'   [`Learner`][mlr3::Learner] to wrap, or a string identifying a [`Learner`][mlr3::Learner] in the [`mlr3::mlr_learners`] [`Dictionary`][mlr3misc::Dictionary].
+#'   The [`Learner`][mlr3::Learner] has to be a [`LearnerRegr`][mlr3::LearnerRegr] with `predict_type` `"quantiles"`.
+#'   This argument is always cloned; to access the [`Learner`][mlr3::Learner] inside `PipeOpLearnerQuantiles` by-reference, use `$learner`.
+#' * `id` :: `character(1)`\cr
+#'   Identifier of the resulting object, internally defaulting to the `id` of the [`Learner`][mlr3::Learner] being wrapped.
+#' * `param_vals` :: named `list`\cr
+#'   List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction. Default `list()`.
+#'
+#' @section Input and Output Channels:
+#' [`PipeOpLearnerQuantiles`] has one input channel named `"input"`, taking a [`TaskRegr`][mlr3::TaskRegr] specific to the [`Learner`][mlr3::Learner]
+#' type given to `learner` during construction; both during training and prediction.
+#'
+#' [`PipeOpLearnerQuantiles`] has one output channel named `"output"`, producing `NULL` during training and a [`PredictionRegr`][mlr3::Prediction] object
+#' during prediction.
+#'
+#' The output during prediction is a [`PredictionRegr`][mlr3::PredictionRegr] on the prediction input data that aggregates the predictions of the [`Learner`][mlr3::Learner], trained on the training input data,
+#' for each quantile in `q_vals`.
+#'
+#' @section State:
+#' The `$state` is set during training. It is a named `list` with the member:
+#' * `model_states` :: `list`\cr
+#'   List of the states of all models created by the [`Learner`][mlr3::Learner]'s `$.train()` function, one per quantile in `q_vals`.
+#'
+#' @section Parameters:
+#' The parameters are the parameters of the [`Learner`][mlr3::Learner] wrapped by this object, as well as:
+#' * `q_vals` :: `numeric`\cr
+#'   Quantiles to use for training and prediction. Must be sorted and lie within `[0, 1]`.
+#'   Initialized to `c(0.05, 0.5, 0.95)`.
+#'
+#' * `q_response` :: `numeric(1)`\cr
+#'   Which quantile in `q_vals` to use as a `response` for the [`PredictionRegr`][mlr3::PredictionRegr] during prediction. Must be an element of `q_vals`.
+#'   Initialized to `0.5`.
+#'
+#' @section Internals:
+#' The `$state` is updated during training.
+#'
+#' @section Fields:
+#' Fields inherited from [`PipeOp`], as well as:
+#' * `learner` :: [`LearnerRegr`][mlr3::LearnerRegr]\cr
+#'   [`Learner`][mlr3::Learner] that is being wrapped. Read-only.
+#' * `learner_model` :: [`Learner`][mlr3::Learner]\cr
+#'   If [`PipeOpLearnerQuantiles`] has been trained, this is a `list` containing the [`Learner`][mlr3::Learner]s for each quantile.
+#'   Otherwise, this contains the [`Learner`][mlr3::Learner] that is being wrapped.
+#'   Read-only.
+#' * `predict_type` :: `character`\cr
+#'   Predict type of the [`PipeOpLearnerQuantiles`], which is always `c("response", "quantiles")`.
+#'
+#' @section Methods:
+#' Methods inherited from [`PipeOp`].
+#'
+#' @family PipeOps
+#' @family Meta PipeOps
+#' @template seealso_pipeopslist
+#' @include PipeOp.R
+#' @export
+#' @examples
+#' library("mlr3")
+#'
+#' task = tsk("boston_housing")
+#' learner = lrn("regr.debug")
+#' po = mlr_pipeops$get("learner_quantiles", learner)
+#'
+#' po$train(list(task))
+#' po$predict(list(task))
+PipeOpLearnerQuantiles = R6Class("PipeOpLearnerQuantiles",
+  inherit = PipeOp,
+  public = list(
+    initialize = function(learner, id = NULL, param_vals = list()) {
+      private$.learner = as_learner(learner, clone = TRUE)
+      id = id %??% private$.learner$id
+      type = private$.learner$task_type
+
+      if ("regr" != type) {
+        stop("PipeOpLearnerQuantiles only supports regression.")
+      }
+
+      task_type = mlr_reflections$task_types[type, mult = "first"]$task
+      out_type = mlr_reflections$task_types[type, mult = "first"]$prediction
+
+      # paradox requirements 1.0
+      private$.quantiles_param_set = ps(
+        q_vals = p_uty(custom_check = crate(function(x) {
+          checkmate::check_numeric(x, lower = 0L, upper = 1L, any.missing = FALSE, min.len = 1L, sorted = TRUE)
+        }), tags = c("train", "predict", "required")),
+        q_response = p_dbl(lower = 0L, upper = 1L, tags = c("train", "predict", "required"))
+      )
+
+      private$.quantiles_param_set$values = list(q_vals = c(0.05, 0.5, 0.95), q_response = 0.5) # default
+
+      super$initialize(id, param_set = alist(quantiles = private$.quantiles_param_set, private$.learner$param_set),
+        param_vals = param_vals,
+        input = data.table(name = "input", train = task_type, predict = task_type),
+        output = data.table(name = "output", train = "NULL", predict = out_type),
+        packages = private$.learner$packages, tags = c("learner", "ensemble") # use the converted learner: `learner` may be a string
+      )
+    }
+  ),
+  active = list(
+    learner = function(val) {
+      if (!missing(val)) {
+        if (!identical(val, private$.learner)) {
+          stop("$learner is read-only.")
+        }
+      }
+      private$.learner
+    },
+
+    learner_model = function(val) {
+      if (!missing(val)) {
+        if (!identical(val, private$.learner)) {
+          stop("$learner_model is read-only.")
+        }
+      }
+      if (is.null(self$state) || is_noop(self$state)) {
+        private$.learner
+      } else {
+        multiplicity_recurse(self$state, function(state) {
+          map(state$model_states, clone_with_state, learner = private$.learner)
+        })
+      }
+    },
+    predict_type = function(val) {
+      if (!missing(val)) {
+        stop("$predict_type is read-only.")
+      }
+      mlr_reflections$learner_predict_types$regr$quantiles # Returns c("response", "quantiles")
+    }
+  ),
+  private = list(
+    .state_class = "pipeop_learner_quantiles_state",
+
+    .train = function(inputs) {
+      task = inputs[[1L]]
+      pv = private$.quantiles_param_set$values
+
+      assert_subset(pv$q_response, pv$q_vals, empty.ok = FALSE)
+      if ("quantiles" %nin% private$.learner$predict_types) {
+        stopf("Learner needs to be able to predict quantiles.")
+      }
+      private$.learner$predict_type = "quantiles"
+
+      # train learner on all quantiles in q_vals
+      states = map(pv$q_vals, function(quantile) {
+        on.exit({private$.learner$state = NULL})
+        private$.learner$quantiles = quantile
+        private$.learner$train(task)$state
+      })
+
+      # add states of trained models to PipeOp state
+      self$state = list(model_states = states)
+
+      list(NULL)
+    },
+
+    .predict = function(inputs) {
+      task = inputs[[1L]]
+      pv = private$.quantiles_param_set$values
+
+      prds = pmap(list(self$state$model_states, pv$q_vals), function(state, quantile) {
+        on.exit({private$.learner$state = NULL})
+        private$.learner$state = state
+        private$.learner$quantiles = quantile
+        as.data.table(private$.learner$predict(task))
+      })
+
+      quantiles = as.matrix(map_dtc(prds, "response"))
+      quantiles = unname(quantiles) # assign the result; a bare unname() call has no effect
+      attr(quantiles, "probs") = pv$q_vals
+      attr(quantiles, "response") = pv$q_response
+
+      # return quantile PredictionRegr with all requested quantiles
+      list(as_prediction(as_prediction_data(list(quantiles = quantiles), task = task)))
+    },
+
+    .quantiles_param_set = NULL,
+    .learner = NULL,
+    .additional_phash_input = function() private$.learner$phash
+  )
+)
+
+#' @export
+marshal_model.pipeop_learner_quantiles_state = function(model, inplace = FALSE, ...) {
+  # Note that a Learner state contains other reference objects, but we don't clone them here, even when inplace
+  # is FALSE. For our use-case this is just not necessary and would cause unnecessary overhead in the mlr3
+  # workhorse function
+  model$model_states = map(model$model_states, marshal_model, inplace = inplace)
+  # only wrap this in a marshaled class if the model was actually marshaled above
+  # (the default marshal method does nothing)
+  if (some(model$model_states, is_marshaled_model)) {
+    model = structure(
+      list(marshaled = model, packages = "mlr3pipelines"),
+      class = c(paste0(class(model), "_marshaled"), "marshaled")
+    )
+  }
+  model
+}
+
+#' @export
+unmarshal_model.pipeop_learner_quantiles_state_marshaled = function(model, inplace = FALSE, ...) {
+  state_marshaled = model$marshaled
+  state_marshaled$model_states = map(state_marshaled$model_states, unmarshal_model, inplace = inplace)
+  state_marshaled
+}
+
+
+mlr_pipeops$add("learner_quantiles", PipeOpLearnerQuantiles, list(R6Class("Learner", public = list(id = "learner_quantiles", task_type = "regr", param_set = ps(), packages = "mlr3pipelines"))$new()))
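
The marshaling pair above wraps the state only when at least one inner model was actually marshaled, and unwraps it symmetrically. The following base-R sketch illustrates that pattern in isolation. It is not the mlr3 implementation: `marshal_sketch()` / `unmarshal_sketch()` are hypothetical stand-ins for mlr3's `marshal_model()` / `unmarshal_model()` generics, and the `toy_model` and `quantiles_state` classes are assumptions made for illustration only.

```r
# Hypothetical stand-ins for the marshal_model() / unmarshal_model() generics.
marshal_sketch = function(model, ...) UseMethod("marshal_sketch")
unmarshal_sketch = function(model, ...) UseMethod("unmarshal_sketch")

# Default methods do nothing, so models without special needs pass through unchanged.
marshal_sketch.default = function(model, ...) model
unmarshal_sketch.default = function(model, ...) model

# A toy model class that "needs" marshaling: it is serialized to a raw vector.
marshal_sketch.toy_model = function(model, ...) {
  structure(list(marshaled = serialize(model, NULL)),
    class = c("toy_model_marshaled", "marshaled"))
}
unmarshal_sketch.toy_model_marshaled = function(model, ...) {
  unserialize(model$marshaled)
}

# Container state holding one model state per quantile.
marshal_sketch.quantiles_state = function(model, ...) {
  model$model_states = lapply(model$model_states, marshal_sketch)
  # Wrap only if at least one inner state was really marshaled.
  if (any(vapply(model$model_states, inherits, logical(1), what = "marshaled"))) {
    model = structure(list(marshaled = model),
      class = c(paste0(class(model), "_marshaled"), "marshaled"))
  }
  model
}
unmarshal_sketch.quantiles_state_marshaled = function(model, ...) {
  state = model$marshaled
  state$model_states = lapply(state$model_states, unmarshal_sketch)
  state
}

# Round trip: a state whose inner models need marshaling.
state = structure(
  list(model_states = list(
    structure(list(coef = 1.5), class = "toy_model"),
    structure(list(coef = 2.5), class = "toy_model")
  )),
  class = "quantiles_state"
)
marshaled = marshal_sketch(state)
restored = unmarshal_sketch(marshaled)

# A state whose inner models need no marshaling is returned without a wrapper.
plain = structure(list(model_states = list(list(a = 1))), class = "quantiles_state")
plain_out = marshal_sketch(plain)
```

Skipping the wrapper when nothing was marshaled keeps the common case cheap: callers can keep using the state directly, and the unmarshal method is only dispatched for objects carrying the `_marshaled` class.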

man/PipeOp.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/PipeOpEnsemble.Rd

Lines changed: 1 addition & 0 deletions

man/PipeOpImpute.Rd

Lines changed: 1 addition & 0 deletions

man/PipeOpTargetTrafo.Rd

Lines changed: 1 addition & 0 deletions

man/PipeOpTaskPreproc.Rd

Lines changed: 1 addition & 0 deletions

man/PipeOpTaskPreprocSimple.Rd

Lines changed: 1 addition & 0 deletions
