Skip to content

Commit 3de86cd

Browse files
authored
Merge pull request #770 from mlr-org/feat/inner_valid
Feat/inner valid
2 parents d0a5495 + ac09cae commit 3de86cd

22 files changed

+621
-18
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ Authors@R:
2828
comment = c(ORCID = "0000-0001-9754-0393")),
2929
person(given = "Sebastian",
3030
family = "Fischer",
31-
role = "ctb",
31+
role = "aut",
3232
email = "[email protected]",
3333
comment = c(ORCID = "0000-0002-9609-3197")),
3434
person(given = "Susanne",

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ S3method(pos,list)
2727
S3method(predict,Graph)
2828
S3method(print,Multiplicity)
2929
S3method(print,Selector)
30+
S3method(set_validate,GraphLearner)
31+
S3method(set_validate,PipeOpLearner)
3032
S3method(unmarshal_model,Multiplicity_marshaled)
3133
S3method(unmarshal_model,graph_learner_model_marshaled)
3234
S3method(unmarshal_model,pipeop_impute_learner_state_marshaled)

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
* Compatibility with new `bbotk` release.
44
* Added marshaling support to `GraphLearner`
5+
* Support internal tuning and validation
56

67
# mlr3pipelines 0.5.2
78

R/GraphLearner.R

Lines changed: 169 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,17 @@
4747
#' contain the model. Use `graph_model` to access the trained [`Graph`] after `$train()`. Read-only.
4848
#' * `graph_model` :: [`Graph`]\cr
4949
#' [`Graph`] that is being wrapped. This [`Graph`] contains a trained state after `$train()`. Read-only.
50+
#' * `internal_tuned_values` :: named `list()` or `NULL`\cr
51+
#' The internal tuned parameter values collected from all `PipeOp`s.
52+
#' `NULL` is returned if the learner is not trained or none of the wrapped learners supports internal tuning.
53+
#' * `internal_valid_scores` :: named `list()` or `NULL`\cr
54+
#' The internal validation scores as retrieved from the `PipeOps`.
55+
#' The names are prefixed with the respective IDs of the `PipeOp`s.
56+
#' `NULL` is returned if the learner is not trained or none of the wrapped learners supports internal validation.
57+
#' * `validate` :: `numeric(1)`, `"predefined"`, `"test"` or `NULL`\cr
58+
#' How to construct the validation data. This also has to be configured for the individual `PipeOp`s such as
59+
#' `PipeOpLearner`, see [`set_validate.GraphLearner`].
60+
#' For more details on the possible values, see [`mlr3::Learner`].
5061
#' * `marshaled` :: `logical(1)`\cr
5162
#' Whether the learner is marshaled.
5263
#'
@@ -110,11 +121,17 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
110121
}
111122
assert_subset(task_type, mlr_reflections$task_types$type)
112123

124+
private$.can_validate = some(graph$pipeops, function(po) "validation" %in% po$properties)
125+
private$.can_internal_tuning = some(graph$pipeops, function(po) "internal_tuning" %in% po$properties)
126+
127+
properties = setdiff(mlr_reflections$learner_properties[[task_type]],
128+
c("validation", "internal_tuning")[!c(private$.can_validate, private$.can_internal_tuning)])
129+
113130
super$initialize(id = id, task_type = task_type,
114131
feature_types = mlr_reflections$task_feature_types,
115132
predict_types = names(mlr_reflections$learner_predict_types[[task_type]]),
116133
packages = graph$packages,
117-
properties = mlr_reflections$learner_properties[[task_type]],
134+
properties = properties,
118135
man = "mlr3pipelines::GraphLearner"
119136
)
120137

@@ -123,8 +140,9 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
123140
}
124141
if (!is.null(predict_type)) self$predict_type = predict_type
125142
},
126-
base_learner = function(recursive = Inf) {
143+
base_learner = function(recursive = Inf, return_po = FALSE) {
127144
assert(check_numeric(recursive, lower = Inf), check_int(recursive))
145+
assert_flag(return_po)
128146
if (recursive <= 0) return(self)
129147
gm = self$graph_model
130148
gm_output = gm$output
@@ -143,7 +161,11 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
143161
if (length(last_pipeop_id) > 1) stop("Graph has no unique PipeOp containing a Learner")
144162
if (length(last_pipeop_id) == 0) stop("No Learner PipeOp found.")
145163
}
146-
learner_model$base_learner(recursive - 1)
164+
if (return_po) {
165+
last_pipeop
166+
} else {
167+
learner_model$base_learner(recursive - 1)
168+
}
147169
},
148170
marshal = function(...) {
149171
learner_marshal(.learner = self, ...)
@@ -153,15 +175,32 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
153175
}
154176
),
155177
active = list(
178+
internal_valid_scores = function(rhs) {
179+
assert_ro_binding(rhs)
180+
self$state$internal_valid_scores
181+
},
182+
internal_tuned_values = function(rhs) {
183+
assert_ro_binding(rhs)
184+
self$state$internal_tuned_values
185+
},
186+
validate = function(rhs) {
187+
if (!missing(rhs)) {
188+
if (!private$.can_validate) {
189+
stopf("None of the PipeOps in Graph '%s' supports validation.", self$id)
190+
}
191+
private$.validate = assert_validate(rhs)
192+
}
193+
private$.validate
194+
},
156195
marshaled = function() {
157196
learner_marshaled(self)
158197
},
159198
hash = function() {
160-
digest(list(class(self), self$id, self$graph$hash, private$.predict_type,
199+
digest(list(class(self), self$id, self$graph$hash, private$.predict_type, private$.validate,
161200
self$fallback$hash, self$parallel_predict), algo = "xxhash64")
162201
},
163202
phash = function() {
164-
digest(list(class(self), self$id, self$graph$phash, private$.predict_type,
203+
digest(list(class(self), self$id, self$graph$phash, private$.predict_type, private$.validate,
165204
self$fallback$hash, self$parallel_predict), algo = "xxhash64")
166205
},
167206
predict_type = function(rhs) {
@@ -195,6 +234,21 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
195234
),
196235
private = list(
197236
.graph = NULL,
237+
.validate = NULL,
238+
.can_validate = NULL,
239+
.can_internal_tuning = NULL,
240+
# Collect the internally tuned parameter values of the trained graph.
# Reads `$internal_tuned_values` from every PipeOp in `self$graph_model` that
# has the "internal_tuning" property and flattens them into one list.
# Returns NULL when no PipeOp of the graph supports internal tuning, and the
# result of `named_list()` when such PipeOps exist but report no values.
.extract_internal_tuned_values = function() {
  # Guard on the internal-tuning capability: the values gathered below only
  # exist on PipeOps with the "internal_tuning" property, so validation
  # support (`.can_validate`) is the wrong flag to test here.
  if (!private$.can_internal_tuning) return(NULL)
  itvs = unlist(map(pos_with_property(self$graph_model, "internal_tuning"), "internal_tuned_values"), recursive = FALSE)
  if (!length(itvs)) return(named_list())
  itvs
},
246+
# Collect the internal validation scores of the trained graph.
# Reads `$internal_valid_scores` from every PipeOp in `self$graph_model` that
# has the "validation" property and flattens them into one list.
# Returns NULL when no PipeOp of the graph supports validation, and the
# result of `named_list()` when such PipeOps exist but report no scores.
.extract_internal_valid_scores = function() {
  # Guard on the validation capability: the scores gathered below only exist
  # on PipeOps with the "validation" property, so internal-tuning support
  # (`.can_internal_tuning`) is the wrong flag to test here.
  if (!private$.can_validate) return(NULL)
  ivs = unlist(map(pos_with_property(self$graph_model, "validation"), "internal_valid_scores"), recursive = FALSE)
  if (!length(ivs)) return(named_list())
  ivs
},
198252
deep_clone = function(name, value) {
199253
# FIXME this repairs the mlr3::Learner deep_clone() method which is broken.
200254
if (is.environment(value) && !is.null(value[[".__enclos_env__"]])) {
@@ -207,6 +261,13 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
207261
},
208262

209263
.train = function(task) {
264+
if (!is.null(get0("validate", self))) {
265+
some_pipeops_validate = some(pos_with_property(self, "validation"), function(po) !is.null(po$validate))
266+
if (!some_pipeops_validate) {
267+
lg$warn("GraphLearner '%s' specifies a validation set, but none of its PipeOps use it.", self$id)
268+
}
269+
}
270+
210271
on.exit({self$graph$state = NULL})
211272
self$graph$train(task)
212273
state = self$graph$state
@@ -255,6 +316,109 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
255316
)
256317
)
257318

319+
#' @title Configure Validation for a GraphLearner
320+
#'
321+
#' @description
322+
#' Configure validation for a graph learner.
323+
#'
324+
#' In a [`GraphLearner`], validation can be configured on two levels:
325+
#' 1. On the [`GraphLearner`] level, which specifies **how** the validation set is constructed before entering the graph.
326+
#' 2. On the level of the individual `PipeOp`s (such as `PipeOpLearner`), which specifies
327+
#' which pipeops actually make use of the validation data (set its `$validate` field to `"predefined"`) or not (set it to `NULL`).
328+
#' This can be specified via the argument `ids`.
329+
#'
330+
#' @param learner ([`GraphLearner`])\cr
331+
#' The graph learner to configure.
332+
#' @param validate (`numeric(1)`, `"predefined"`, `"test"`, or `NULL`)\cr
333+
#' How to set the `$validate` field of the learner.
334+
#' If set to `NULL`, all validation is disabled, both on the graph learner level and for all pipeops.
335+
#' @param ids (`NULL` or `character()`)\cr
336+
#' For which pipeops to enable validation.
337+
#' This parameter is ignored when `validate` is set to `NULL`.
338+
#' By default, validation is enabled for the final `PipeOp` in the `Graph`.
339+
#' @param args_all (`list()`)\cr
340+
#' Rarely needed. A named list of parameter values that are passed to all subsequent [`set_validate()`] calls on the individual
341+
#' `PipeOp`s.
342+
#' @param args (named `list()`)\cr
343+
#' Rarely needed.
344+
#' A named list of lists, specifying additional arguments to be passed to [`set_validate()`] when calling it on the individual
345+
#' `PipeOp`s.
346+
#' @param ... (any)\cr
347+
#' Currently unused.
348+
#'
349+
#' @export
350+
#' @examples
351+
#' library(mlr3)
352+
#'
353+
#' glrn = as_learner(po("pca") %>>% lrn("classif.debug"))
354+
#' set_validate(glrn, 0.3)
355+
#' glrn$validate
356+
#' glrn$graph$pipeops$classif.debug$learner$validate
357+
#'
358+
#' set_validate(glrn, NULL)
359+
#' glrn$validate
360+
#' glrn$graph$pipeops$classif.debug$learner$validate
361+
#'
362+
#' set_validate(glrn, 0.2, ids = "classif.debug")
363+
#' glrn$validate
364+
#' glrn$graph$pipeops$classif.debug$learner$validate
365+
set_validate.GraphLearner = function(learner, validate, ids = NULL, args_all = list(), args = list(), ...) {
366+
prev_validate_pos = map(pos_with_property(learner$graph$pipeops, "validation"), "validate")
367+
prev_validate = learner$validate
368+
on.exit({
369+
iwalk(prev_validate_pos, function(prev_val, poid) {
370+
# Here we don't call into set_validate() as this also does not ensure that we are able to correctly
371+
# reset the configuration to the previous state, is less transparent and might fail again
372+
# The error message informs the user about this though via the calling handlers below
373+
learner$graph$pipeops[[poid]]$validate = prev_val
374+
})
375+
learner$validate = prev_validate
376+
}, add = TRUE)
377+
378+
if (is.null(validate)) {
379+
learner$validate = NULL
380+
walk(pos_with_property(learner$graph$pipeops, "validation"), function(po) {
381+
invoke(set_validate, po, validate = NULL, args_all = args_all, args = args[[po$id]] %??% list())
382+
})
383+
on.exit()
384+
return(invisible(learner))
385+
}
386+
387+
if (is.null(ids)) {
388+
ids = learner$base_learner(return_po = TRUE)$id
389+
} else {
390+
assert_subset(ids, ids(pos_with_property(learner$graph$pipeops, "validation")))
391+
}
392+
393+
assert_list(args, types = "list")
394+
assert_list(args_all)
395+
assert_subset(names(args), ids)
396+
397+
learner$validate = validate
398+
399+
walk(ids, function(poid) {
400+
# learner might be another GraphLearner / AutoTuner so we call into set_validate() again
401+
withCallingHandlers({
402+
args = insert_named(insert_named(list(validate = "predefined"), args_all), args[[poid]])
403+
invoke(set_validate, learner$graph$pipeops[[poid]], .args = args)
404+
}, error = function(e) {
405+
e$message = sprintf(paste0(
406+
"Failed to set validate for PipeOp '%s':\n%s\n",
407+
"Trying to heuristically reset validation to its previous state, please check the results"), poid, e$message)
408+
stop(e)
409+
}, warning = function(w) {
410+
w$message = sprintf(paste0(
411+
"Failed to set validate for PipeOp '%s':\n%s\n",
412+
"Trying to heuristically reset validation to its previous state, please check the results"), poid, w$message)
413+
warning(w)
414+
invokeRestart("muffleWarning")
415+
})
416+
})
417+
on.exit()
418+
419+
invisible(learner)
420+
}
421+
258422
#' @export
259423
marshal_model.graph_learner_model = function(model, inplace = FALSE, ...) {
260424
xm = map(.x = model, .f = marshal_model, inplace = inplace, ...)

R/PipeOp.R

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,20 @@
135135
#' and done, if requested, by the [`Graph`] backend itself; it should *not* be done explicitly by `private$.train()` or `private$.predict()`.
136136
#' * `man` :: `character(1)`\cr
137137
#' Identifying string of the help page that shows with `help()`.
138+
#' * `properties` :: `character()`\cr
139+
#' The properties of the pipeop.
140+
#' Currently supported values are:
141+
#' * `"validation"`: the `PipeOp` can make use of the `$internal_valid_task` of an [`mlr3::Task`].
142+
#' This is for example used for `PipeOpLearner`s that wrap a `Learner` with this property, see [`mlr3::Learner`].
143+
#' `PipeOp`s that have this property, also have a `$validate` field, which controls whether to use the validation task,
144+
#' as well as a `$internal_valid_scores` field, which gives access to the internal validation scores after training.
145+
#' * `"internal_tuning"`: the `PipeOp` is able to internally optimize hyperparameters.
146+
#' This works analogously to the internal tuning implementation for [`mlr3::Learner`].
147+
#' `PipeOp`s with that property also implement the standardized accessor `$internal_tuned_values` and have at least one
148+
#' parameter tagged with `"internal_tuning"`.
149+
#' An example for such a `PipeOp` is a `PipeOpLearner` that wraps a `Learner` with the `"internal_tuning"` property.
150+
#'
151+
#' Programmatic access to all available properties is possible via `mlr_reflections$pipeops$properties`.
138152
#'
139153
#' @section Methods:
140154
#' * `train(input)`\cr
@@ -235,8 +249,9 @@ PipeOp = R6Class("PipeOp",
235249
output = NULL,
236250
.result = NULL,
237251
tags = NULL,
252+
properties = NULL,
238253

239-
initialize = function(id, param_set = ps(), param_vals = list(), input, output, packages = character(0), tags = "abstract") {
254+
initialize = function(id, param_set = ps(), param_vals = list(), input, output, packages = character(0), tags = "abstract", properties = character(0)) {
240255
if (inherits(param_set, "ParamSet")) {
241256
private$.param_set = assert_param_set(param_set)
242257
private$.param_set_source = NULL
@@ -246,6 +261,7 @@ PipeOp = R6Class("PipeOp",
246261
}
247262
self$id = assert_string(id)
248263

264+
self$properties = assert_subset(properties, mlr_reflections$pipeops$properties)
249265
self$param_set$values = insert_named(self$param_set$values, param_vals)
250266
self$input = assert_connection_table(input)
251267
self$output = assert_connection_table(output)
@@ -601,4 +617,3 @@ evaluate_multiplicities = function(self, unpacked, evalcall, instate) {
601617
map(transpose_list(map(result, "output")), as.Multiplicity)
602618
}
603619
}
604-

R/PipeOpImpute.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,10 @@ PipeOpImpute = R6Class("PipeOpImpute",
195195

196196
self$state$outtasklayout = copy(intask$feature_types)
197197

198+
if (!is.null(intask$internal_valid_task)) {
199+
intask$internal_valid_task = private$.predict(list(intask$internal_valid_task))[[1L]]
200+
}
201+
198202
list(intask)
199203
},
200204

0 commit comments

Comments
 (0)