5656# ' * `validate` :: `numeric(1)`, `"predefined"`, `"test"` or `NULL`\cr
5757# ' How to construct the validation data. This also has to be configured in the individual learners wrapped by
5858# ' `PipeOpLearner`, see [`set_validate.GraphLearner`] on how to configure this.
59+ # ' For more details on the possible values, see [`mlr3::Learner`].
5960# ' * `marshaled` :: `logical(1)`\cr
6061# ' Whether the learner is marshaled.
6162# '
@@ -119,8 +120,8 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
119120 }
120121 assert_subset(task_type , mlr_reflections $ task_types $ type )
121122
122- private $ .can_validate = some(learner_wrapping_pipeops( graph ) , function (po ) " validation" %in% po $ learner $ properties )
123- private $ .can_internal_tuning = some(learner_wrapping_pipeops( graph ) , function (po ) " internal_tuning" %in% po $ learner $ properties )
123+ private $ .can_validate = some(graph $ pipeops , function (po ) " validation" %in% po $ properties )
124+ private $ .can_validate = some(graph $ pipeops , function (po ) " internal_tuning" %in% po $ properties )
124125
125126 properties = setdiff(mlr_reflections $ learner_properties [[task_type ]],
126127 c(" validation" , " internal_tuning" )[! c(private $ .can_validate , private $ .can_internal_tuning )])
@@ -139,6 +140,9 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
139140 if (! is.null(predict_type )) self $ predict_type = predict_type
140141 },
141142 base_learner = function (recursive = Inf ) {
143+ self $ base_pipeop(recursive = recursive )$ learner_model
144+ },
145+ base_pipeop = function (recursive = Inf ) {
142146 assert(check_numeric(recursive , lower = Inf ), check_int(recursive ))
143147 if (recursive < = 0 ) return (self )
144148 gm = self $ graph_model
@@ -158,7 +162,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
158162 if (length(last_pipeop_id ) > 1 ) stop(" Graph has no unique PipeOp containing a Learner" )
159163 if (length(last_pipeop_id ) == 0 ) stop(" No Learner PipeOp found." )
160164 }
161- learner_model $ base_learner (recursive - 1 )
165+ learner_model $ base_pipeop (recursive - 1 )
162166 },
163167 marshal = function (... ) {
164168 learner_marshal(.learner = self , ... )
@@ -179,7 +183,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
179183 validate = function (rhs ) {
180184 if (! missing(rhs )) {
181185 if (! private $ .can_validate ) {
182- stopf(" None of the Learners wrapped by GraphLearner '%s' support validation." , self $ id )
186+ stopf(" None of the PipeOps in Graph '%s' supports validation." , self $ id )
183187 }
184188 private $ .validate = assert_validate(rhs )
185189 }
@@ -232,30 +236,17 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
232236 .can_internal_tuning = NULL ,
233237 .extract_internal_tuned_values = function () {
234238 if (! private $ .can_validate ) return (NULL )
235- itvs = unlist(map(
236- learner_wrapping_pipeops(self $ graph_model ), function (po ) {
237- if (exists(" internal_tuned_values" , po $ learner )) {
238- po $ learner_model $ internal_tuned_values
239- }
240- }
241- ), recursive = FALSE )
242- if (is.null(itvs ) || ! length(itvs )) return (named_list())
239+ itvs = unlist(map(pos_with_property(self , " internal_tuning" ), " internal_tuned_values" ), recursive = FALSE )
240+ if (! length(itvs )) return (named_list())
243241 itvs
244242 },
245243 .extract_internal_valid_scores = function () {
246244 if (! private $ .can_internal_tuning ) return (NULL )
247- ivs = unlist(map(
248- learner_wrapping_pipeops(self $ graph_model ), function (po ) {
249- if (exists(" internal_valid_scores" , po $ learner )) {
250- po $ learner_model $ internal_valid_scores
251- }
252- }
253- ), recursive = FALSE )
245+ its = unlist(map(pos_with_property(self , " validation" ), " internal_valid_scores" ), recursive = FALSE )
254246 if (is.null(ivs ) || ! length(ivs )) return (named_list())
255247 ivs
256248 },
257249 deep_clone = function (name , value ) {
258- private $ .param_set = NULL
259250 # FIXME this repairs the mlr3::Learner deep_clone() method which is broken.
260251 if (is.environment(value ) && ! is.null(value [[" .__enclos_env__" ]])) {
261252 return (value $ clone(deep = TRUE ))
@@ -268,17 +259,10 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
268259
269260 .train = function (task ) {
270261 if (! is.null(get0(" validate" , self ))) {
271- some_pipeops_validate = some(learner_wrapping_pipeops (self ), function (po ) ! is.null(get0( " validate " , po $ learner ) ))
262+ some_pipeops_validate = some(pos_with_property (self , " validation " ), function (po ) ! is.null(po $ validate ))
272263 if (! some_pipeops_validate ) {
273264 lg $ warn(" GraphLearner '%s' specifies a validation set, but none of its Learners use it." , self $ id )
274265 }
275- } else {
276- # otherwise the pipeops will preprocess this unnecessarily
277- if (! is.null(task $ internal_valid_task )) {
278- prev_itv = task $ internal_valid_task
279- on.exit({task $ internal_valid_task = prev_itv }, add = TRUE )
280- task $ internal_valid_task = NULL
281- }
282266 }
283267
284268 on.exit({self $ graph $ state = NULL })
@@ -350,6 +334,9 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
350334# ' For which pipeops to enable validation.
351335# ' This parameter is ignored when `validate` is set to `NULL`.
352336# ' By default, validation is enabled for the base learner.
337+ # ' @param args_all (`list()`)\cr
338+ # ' Rarely needed. A named list of parameter values that are passed to all subsequet [`set_validate()`] calls on the individual
339+ # ' `PipeOp`s.
353340# ' @param args (named `list()`)\cr
354341# ' Rarely needed.
355342# ' A named list of lists, specifying additional argments to be passed to [`set_validate()`] for the respective learners.
@@ -376,31 +363,35 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
376363# ' glrn$validate
377364# ' glrn$graph$pipeops$classif.debug$learner$validate
378365# ' glrn$graph$pipeops$final$learner$validate
379- set_validate.GraphLearner = function (learner , validate , ids = NULL , args = list (), ... ) {
366+ set_validate.GraphLearner = function (learner , validate , ids = NULL , args_all = list (), args = list (), ... ) {
380367 if (is.null(validate )) {
381368 learner $ validate = NULL
382- walk(learner_wrapping_pipeops(learner ), function (po ) {
383- po $ learner $ validate = NULL
369+ walk(pos_with_property(learner $ graph $ pipeops , " validation" ), function (po ) {
370+ # disabling needs no extra arguments
371+ invoke(set_validate , po , validate = NULL , args_all = args_all , args = args [[po $ id ]] %??% list ())
384372 })
385373 return (invisible (learner ))
386374 }
387375
388376 if (is.null(ids )) {
389- ids = base_pipeop(learner )$ id
377+ ids = learner $ base_pipeop(recursive = 1 )$ id
390378 } else {
391- assert_subset(ids , ids(keep(learner_wrapping_pipeops( learner ), function ( po ) " validation" %in% po $ learner $ properties )))
379+ assert_subset(ids , ids(pos_with_property( learner $ graph $ pipeops , " validation" )))
392380 }
393381
394382 assert_list(args , types = " list" )
383+ assert_list(args_all , types = " list" )
395384 assert_subset(names(args ), ids )
396385
397- prev_validate_pos = discard(map(learner_wrapping_pipeops(learner ), function (po ) get0(" validate" , po $ learner , ifnotfound = NA )),
398- function (x ) identical(x , NA ))
399-
386+ prev_validate_pos = map(pos_with_property(learner $ graph $ pipeops , " validation" ), " validate" )
400387 prev_validate = learner $ validate
401-
402388 on.exit({
403- iwalk(prev_validate_pos , function (val , poid ) learner $ graph $ pipeops [[poid ]]$ learner $ validate = val )
389+ iwalk(prev_validate_pos , function (val , poid ) {
390+ # passing the args here is just a heuristic that can in principle fail, but this should be extremely
391+ # rare
392+ args = args [[poid ]] %??% list ()
393+ set_validate(learner $ graph $ pipeops [[poid ]], validate = val , args = args , args_all = args_all )
394+ })
404395 learner $ validate = prev_validate
405396 }, add = TRUE )
406397
@@ -409,7 +400,8 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(
409400 walk(ids , function (poid ) {
410401 # learner might be another GraphLearner / AutoTuner so we call into set_validate() again
411402 withCallingHandlers({
412- invoke(set_validate , learner = learner $ graph $ pipeops [[poid ]]$ learner , .args = insert_named(list (validate = " predefined" ), args [[poid ]]))
403+ args = c(args [[poid ]], args_all ) %??% list ()
404+ set_validate(learner $ graph $ pipeops [[poid ]], .args = insert_named(list (validate = " predefined" ), args ))
413405 }, error = function (e ) {
414406 e $ message = sprintf(" Failed to set validate for PipeOp '%s':\n %s" , poid , e $ message )
415407 stop(e )
0 commit comments