5454# ' The inner tuned parameter values.
5555# ' `NULL` is returned if the learner is not trained or none of the wrapped learners supports internal validation.
5656# ' * `validate` :: `numeric(1)`, `"inner_valid"`, `"test"` or `NULL`\cr
57- # ' How to construct the validation data.
57+ # ' How to construct the validation data. This also has to be configured in the individual learners wrapped by
58+ # ' `PipeOpLearner`, see [`set_validate.GraphLearner`] on how to configure this.
59+ # '
5860# '
5961# ' @section Internals:
6062# ' [`as_graph()`] is called on the `graph` argument, so it can technically also be a `list` of things, which is
@@ -108,19 +110,11 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
108110 }
109111 assert_subset(task_type , mlr_reflections $ task_types $ type )
110112
111-
112- private $ .can_validate = some(
113- keep(graph $ pipeops , function (x ) inherits(x , " PipeOpLearner" ) || inherits(x , " PipeOpLearnerCV" )),
114- function (po ) " validation" %in% po $ learner $ properties
115- )
116-
117- inner_tuning = some(
118- keep(graph $ pipeops , function (x ) inherits(x , " PipeOpLearner" ) || inherits(x , " PipeOpLearnerCV" )),
119- function (po ) " inner_tuning" %in% po $ learner $ properties
120- )
113+ private $ .can_validate = some(learner_wrapping_pipeops(graph ), function (po ) " validation" %in% po $ learner $ properties )
114+ private $ .can_inner_tuning = some(learner_wrapping_pipeops(graph ), function (po ) " inner_tuning" %in% po $ learner $ properties )
121115
122116 properties = setdiff(mlr_reflections $ learner_properties [[task_type ]],
123- c(" validation" , " inner_tuning" )[! c(private $ .can_validate , inner_tuning )])
117+ c(" validation" , " inner_tuning" )[! c(private $ .can_validate , private $ .can_inner_tuning )])
124118
125119 super $ initialize(id = id , task_type = task_type ,
126120 feature_types = mlr_reflections $ task_feature_types ,
@@ -130,8 +124,6 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
130124 man = " mlr3pipelines::GraphLearner"
131125 )
132126
133- private $ .param_set = NULL
134-
135127 if (length(param_vals )) {
136128 private $ .graph $ param_set $ values = insert_named(private $ .graph $ param_set $ values , param_vals )
137129 }
@@ -220,7 +212,9 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
220212 .graph = NULL ,
221213 .validate = NULL ,
222214 .can_validate = NULL ,
215+ .can_inner_tuning = NULL ,
223216 .extract_inner_tuned_values = function () {
217+ if (! private $ .can_validate ) return (NULL )
224218 itvs = unlist(map(
225219 learner_wrapping_pipeops(self $ graph_model ), function (po ) {
226220 if (exists(" inner_tuned_values" , po $ learner )) {
@@ -232,6 +226,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
232226 itvs
233227 },
234228 .extract_inner_valid_scores = function () {
229+ if (! private $ .can_inner_tuning ) return (NULL )
235230 ivs = unlist(map(
236231 learner_wrapping_pipeops(self $ graph_model ), function (po ) {
237232 if (exists(" inner_valid_scores" , po $ learner )) {
@@ -256,11 +251,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
256251
257252 .train = function (task ) {
258253 if (! is.null(get0(" validate" , self ))) {
259- some_pipeops_validate = map_lgl(
260- keep(self $ graph $ pipeops , function (po ) inherits(po , " PipeOpLearner" ) || inherits(po , " PipeOpLearnerCV" )),
261- function (po ) ! is.null(get0(" validate" , po $ learner ))
262- )
263-
254+ some_pipeops_validate = some(learner_wrapping_pipeops(self ), function (po ) ! is.null(get0(" validate" , po $ learner )))
264255 if (! some_pipeops_validate ) {
265256 lg $ warn(" GraphLearner '%s' specifies a validation set, but none of its Learners use it." , self $ id )
266257 }
@@ -321,7 +312,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
321312# ' In a [`GraphLearner`], validation can be configured on two levels:
322313# ' 1. On the [`GraphLearner`] level, which specifies **how** the validation set is constructed before entering the graph.
323314# ' 2. On the level of the [`Learner`]s that are wrapped by [`PipeOpLearner`] and [`PipeOpLearnerCV`], which specifies
324- # ' which pipeops actually make use of the validation set .
315+ # ' which pipeops actually make use of the validation data .
325316# ' All learners wrapped by [`PipeOpLearner`] and [`PipeOpLearnerCV`] should in almost all cases either set it
326317# ' to `NULL` (disable) or `"inner_valid"` (enable).
327318# '
@@ -364,9 +355,7 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(
364355 if (is.null(validate )) {
365356 learner $ validate = NULL
366357 walk(learner_wrapping_pipeops(learner ), function (po ) {
367- if (exists(" validate" , po $ learner )) {
368- po $ learner $ validate = NULL
369- }
358+ po $ learner $ validate = NULL
370359 })
371360 return (invisible (learner ))
372361 }
@@ -380,7 +369,9 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(
380369 assert_list(args , types = " list" )
381370 assert_subset(names(args ), ids )
382371
383- prev_validate_pos = discard(map(learner_wrapping_pipeops(learner ), function (po ) get0(" validate" , po $ learner )), is.null )
372+ prev_validate_pos = discard(map(learner_wrapping_pipeops(learner ), function (po ) get0(" validate" , po $ learner , ifnotfound = NA )),
373+ function (x ) identical(x , NA ))
374+
384375 prev_validate = learner $ validate
385376
386377 on.exit({
@@ -391,9 +382,9 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(
391382 learner $ validate = validate
392383
393384 walk(ids , function (poid ) {
394- # learner might be another GraphLearner / AutoTuner
385+ # learner might be another GraphLearner / AutoTuner so we call into set_validate() again
395386 withCallingHandlers({
396- invoke(set_validate , learner = learner $ graph $ pipeops [[poid ]]$ learner , validate = " inner_valid" , . args = args [[poid ]])
387+ invoke(set_validate , learner = learner $ graph $ pipeops [[poid ]]$ learner , .args = insert_named( list ( validate = " inner_valid" ), args [[poid ]]) )
397388 }, error = function (e ) {
398389 e $ message = sprintf(" Failed to set validate for PipeOp '%s':\n %s" , poid , e $ message )
399390 stop(e )
@@ -414,11 +405,8 @@ disable_inner_tuning.GraphLearner = function(learner, ids, ...) {
414405 pvs = learner $ param_set $ values
415406 on.exit({learner $ param_set $ values = pvs }, add = TRUE )
416407 if (length(ids )) {
417- walk(learner_wrapping_pipeops(learner $ graph $ pipeops ), function (po ) {
418- disable_inner_tuning(
419- learner $ graph $ pipeops [[po $ id ]]$ learner ,
420- ids = po $ param_set $ ids()[sprintf(" %s.%s" , po $ id , po $ param_set $ ids()) %in% ids ]
421- )
408+ walk(learner_wrapping_pipeops(learner ), function (po ) {
409+ disable_inner_tuning(po $ learner , ids = po $ param_set $ ids()[sprintf(" %s.%s" , po $ id , po $ param_set $ ids()) %in% ids ])
422410 })
423411 }
424412 on.exit()
0 commit comments