Skip to content

Commit 7167922

Browse files
committed
minor fixes
1 parent 38d561a commit 7167922

File tree

3 files changed

+22
-33
lines changed

3 files changed

+22
-33
lines changed

R/GraphLearner.R

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@
5454
#' The inner tuned parameter values.
5555
#' `NULL` is returned if the learner is not trained or none of the wrapped learners supports internal validation.
5656
#' * `validate` :: `numeric(1)`, `"inner_valid"`, `"test"` or `NULL`\cr
57-
#' How to construct the validation data.
57+
#' How to construct the validation data. This also has to be configured in the individual learners wrapped by
58+
#' `PipeOpLearner`, see [`set_validate.GraphLearner`] on how to configure this.
59+
#'
5860
#'
5961
#' @section Internals:
6062
#' [`as_graph()`] is called on the `graph` argument, so it can technically also be a `list` of things, which is
@@ -108,19 +110,11 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
108110
}
109111
assert_subset(task_type, mlr_reflections$task_types$type)
110112

111-
112-
private$.can_validate = some(
113-
keep(graph$pipeops, function(x) inherits(x, "PipeOpLearner") || inherits(x, "PipeOpLearnerCV")),
114-
function(po) "validation" %in% po$learner$properties
115-
)
116-
117-
inner_tuning = some(
118-
keep(graph$pipeops, function(x) inherits(x, "PipeOpLearner") || inherits(x, "PipeOpLearnerCV")),
119-
function(po) "inner_tuning" %in% po$learner$properties
120-
)
113+
private$.can_validate = some(learner_wrapping_pipeops(graph), function(po) "validation" %in% po$learner$properties)
114+
private$.can_inner_tuning = some(learner_wrapping_pipeops(graph), function(po) "inner_tuning" %in% po$learner$properties)
121115

122116
properties = setdiff(mlr_reflections$learner_properties[[task_type]],
123-
c("validation", "inner_tuning")[!c(private$.can_validate, inner_tuning)])
117+
c("validation", "inner_tuning")[!c(private$.can_validate, private$.can_inner_tuning)])
124118

125119
super$initialize(id = id, task_type = task_type,
126120
feature_types = mlr_reflections$task_feature_types,
@@ -130,8 +124,6 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
130124
man = "mlr3pipelines::GraphLearner"
131125
)
132126

133-
private$.param_set = NULL
134-
135127
if (length(param_vals)) {
136128
private$.graph$param_set$values = insert_named(private$.graph$param_set$values, param_vals)
137129
}
@@ -220,7 +212,9 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
220212
.graph = NULL,
221213
.validate = NULL,
222214
.can_validate = NULL,
215+
.can_inner_tuning = NULL,
223216
.extract_inner_tuned_values = function() {
217+
if (!private$.can_validate) return(NULL)
224218
itvs = unlist(map(
225219
learner_wrapping_pipeops(self$graph_model), function(po) {
226220
if (exists("inner_tuned_values", po$learner)) {
@@ -232,6 +226,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
232226
itvs
233227
},
234228
.extract_inner_valid_scores = function() {
229+
if (!private$.can_inner_tuning) return(NULL)
235230
ivs = unlist(map(
236231
learner_wrapping_pipeops(self$graph_model), function(po) {
237232
if (exists("inner_valid_scores", po$learner)) {
@@ -256,11 +251,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
256251

257252
.train = function(task) {
258253
if (!is.null(get0("validate", self))) {
259-
some_pipeops_validate = map_lgl(
260-
keep(self$graph$pipeops, function(po) inherits(po, "PipeOpLearner") || inherits(po, "PipeOpLearnerCV")),
261-
function(po) !is.null(get0("validate", po$learner))
262-
)
263-
254+
some_pipeops_validate = some(learner_wrapping_pipeops(self), function(po) !is.null(get0("validate", po$learner)))
264255
if (!some_pipeops_validate) {
265256
lg$warn("GraphLearner '%s' specifies a validation set, but none of its Learners use it.", self$id)
266257
}
@@ -321,7 +312,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
321312
#' In a [`GraphLearner`], validation can be configured on two levels:
322313
#' 1. On the [`GraphLearner`] level, which specifies **how** the validation set is constructed before entering the graph.
323314
#' 2. On the level of the [`Learner`]s that are wrapped by [`PipeOpLearner`] and [`PipeOpLearnerCV`], which specifies
324-
#' which pipeops actually make use of the validation set.
315+
#' which pipeops actually make use of the validation data.
325316
#' All learners wrapped by [`PipeOpLearner`] and [`PipeOpLearnerCV`] should in almost all cases either set it
326317
#' to `NULL` (disable) or `"inner_valid"` (enable).
327318
#'
@@ -364,9 +355,7 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(
364355
if (is.null(validate)) {
365356
learner$validate = NULL
366357
walk(learner_wrapping_pipeops(learner), function(po) {
367-
if (exists("validate", po$learner)) {
368-
po$learner$validate = NULL
369-
}
358+
po$learner$validate = NULL
370359
})
371360
return(invisible(learner))
372361
}
@@ -380,7 +369,9 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(
380369
assert_list(args, types = "list")
381370
assert_subset(names(args), ids)
382371

383-
prev_validate_pos = discard(map(learner_wrapping_pipeops(learner), function(po) get0("validate", po$learner)), is.null)
372+
prev_validate_pos = discard(map(learner_wrapping_pipeops(learner), function(po) get0("validate", po$learner, ifnotfound = NA)),
373+
function(x) identical(x, NA))
374+
384375
prev_validate = learner$validate
385376

386377
on.exit({
@@ -391,9 +382,9 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(
391382
learner$validate = validate
392383

393384
walk(ids, function(poid) {
394-
# learner might be another GraphLearner / AutoTuner
385+
# learner might be another GraphLearner / AutoTuner so we call into set_validate() again
395386
withCallingHandlers({
396-
invoke(set_validate, learner = learner$graph$pipeops[[poid]]$learner, validate = "inner_valid", .args = args[[poid]])
387+
invoke(set_validate, learner = learner$graph$pipeops[[poid]]$learner, .args = insert_named(list(validate = "inner_valid"), args[[poid]]))
397388
}, error = function(e) {
398389
e$message = sprintf("Failed to set validate for PipeOp '%s':\n%s", poid, e$message)
399390
stop(e)
@@ -414,11 +405,8 @@ disable_inner_tuning.GraphLearner = function(learner, ids, ...) {
414405
pvs = learner$param_set$values
415406
on.exit({learner$param_set$values = pvs}, add = TRUE)
416407
if (length(ids)) {
417-
walk(learner_wrapping_pipeops(learner$graph$pipeops), function(po) {
418-
disable_inner_tuning(
419-
learner$graph$pipeops[[po$id]]$learner,
420-
ids = po$param_set$ids()[sprintf("%s.%s", po$id, po$param_set$ids()) %in% ids]
421-
)
408+
walk(learner_wrapping_pipeops(learner), function(po) {
409+
disable_inner_tuning(po$learner, ids = po$param_set$ids()[sprintf("%s.%s", po$id, po$param_set$ids()) %in% ids])
422410
})
423411
}
424412
on.exit()

man/mlr_learners_graph.Rd

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/set_validate.GraphLearner.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)