5151# ' The internal tuned parameter values.
5252# ' `NULL` is returned if the learner is not trained or none of the wrapped learners supports internal tuning.
5353# ' * `internal_valid_scores` :: named `list()` or `NULL`\cr
54- # ' The internal tuned parameter values.
54+ # ' The internal validation scores as retrieved from the `PipeOps`.
55+ # ' The names are prefixed with the respective IDs of the `PipeOp`s.
5556# ' `NULL` is returned if the learner is not trained or none of the wrapped learners supports internal validation.
5657# ' * `validate` :: `numeric(1)`, `"predefined"`, `"test"` or `NULL`\cr
57- # ' How to construct the validation data. This also has to be configured in the individual learners wrapped by
58+ # ' How to construct the validation data. This also has to be configured in the individual `PipeOp`s such as
5859# ' `PipeOpLearner`, see [`set_validate.GraphLearner`] on how to configure this.
5960# ' For more details on the possible values, see [`mlr3::Learner`].
6061# ' * `marshaled` :: `logical(1)`\cr
@@ -121,7 +122,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
121122 assert_subset(task_type , mlr_reflections $ task_types $ type )
122123
123124 private $ .can_validate = some(graph $ pipeops , function (po ) " validation" %in% po $ properties )
124- private $ .can_validate = some(graph $ pipeops , function (po ) " internal_tuning" %in% po $ properties )
125+ private $ .can_internal_tuning = some(graph $ pipeops , function (po ) " internal_tuning" %in% po $ properties )
125126
126127 properties = setdiff(mlr_reflections $ learner_properties [[task_type ]],
127128 c(" validation" , " internal_tuning" )[! c(private $ .can_validate , private $ .can_internal_tuning )])
@@ -139,11 +140,9 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
139140 }
140141 if (! is.null(predict_type )) self $ predict_type = predict_type
141142 },
142- base_learner = function (recursive = Inf ) {
143- self $ base_pipeop(recursive = recursive )$ learner_model
144- },
145- base_pipeop = function (recursive = Inf ) {
143+ base_learner = function (recursive = Inf , return_po = FALSE ) {
146144 assert(check_numeric(recursive , lower = Inf ), check_int(recursive ))
145+ assert_flag(return_po )
147146 if (recursive < = 0 ) return (self )
148147 gm = self $ graph_model
149148 gm_output = gm $ output
@@ -162,7 +161,11 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
162161 if (length(last_pipeop_id ) > 1 ) stop(" Graph has no unique PipeOp containing a Learner" )
163162 if (length(last_pipeop_id ) == 0 ) stop(" No Learner PipeOp found." )
164163 }
165- learner_model $ base_pipeop(recursive - 1 )
164+ if (return_po ) {
165+ last_pipeop
166+ } else {
167+ learner_model $ base_learner(recursive - 1 )
168+ }
166169 },
167170 marshal = function (... ) {
168171 learner_marshal(.learner = self , ... )
@@ -236,13 +239,13 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
236239 .can_internal_tuning = NULL ,
# Extract the internal tuned parameter values from the trained graph.
#
# Collects `internal_tuned_values` from every `PipeOp` in the trained
# `graph_model` that has the `"internal_tuning"` property; `unlist(recursive
# = FALSE)` flattens the per-PipeOp lists into one named list whose names are
# prefixed with the respective PipeOp IDs.
#
# Returns `NULL` when no wrapped PipeOp supports internal tuning, an empty
# named list when the supporting PipeOps produced no values, and the named
# list of tuned values otherwise.
.extract_internal_tuned_values = function() {
  # BUG FIX: guard must check .can_internal_tuning (was .can_validate):
  # this method is about internal tuning, not validation.
  if (!private$.can_internal_tuning) return(NULL)
  itvs = unlist(map(pos_with_property(self$graph_model, "internal_tuning"), "internal_tuned_values"), recursive = FALSE)
  if (!length(itvs)) return(named_list())
  itvs
},
# Extract the internal validation scores from the trained graph.
#
# Collects `internal_valid_scores` from every `PipeOp` in the trained
# `graph_model` that has the `"validation"` property; `unlist(recursive =
# FALSE)` flattens the per-PipeOp lists into one named list whose names are
# prefixed with the respective PipeOp IDs.
#
# Returns `NULL` when no wrapped PipeOp supports validation, an empty named
# list when the supporting PipeOps produced no scores, and the named list of
# scores otherwise.
.extract_internal_valid_scores = function() {
  # BUG FIX: guard must check .can_validate (was .can_internal_tuning):
  # this method is about validation scores, not internal tuning.
  if (!private$.can_validate) return(NULL)
  ivs = unlist(map(pos_with_property(self$graph_model, "validation"), "internal_valid_scores"), recursive = FALSE)
  if (is.null(ivs) || !length(ivs)) return(named_list())
  ivs
},
@@ -367,30 +370,28 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args_all = l
367370 if (is.null(validate )) {
368371 learner $ validate = NULL
369372 walk(pos_with_property(learner $ graph $ pipeops , " validation" ), function (po ) {
370- # disabling needs no extra arguments
371373 invoke(set_validate , po , validate = NULL , args_all = args_all , args = args [[po $ id ]] %??% list ())
372374 })
373375 return (invisible (learner ))
374376 }
375377
376378 if (is.null(ids )) {
377- ids = learner $ base_pipeop (recursive = 1 )$ id
379+ ids = learner $ base_learner (recursive = 1 , return_po = TRUE )$ id
378380 } else {
379381 assert_subset(ids , ids(pos_with_property(learner $ graph $ pipeops , " validation" )))
380382 }
381383
382384 assert_list(args , types = " list" )
383- assert_list(args_all , types = " list " )
385+ assert_list(args_all )
384386 assert_subset(names(args ), ids )
385387
386388 prev_validate_pos = map(pos_with_property(learner $ graph $ pipeops , " validation" ), " validate" )
387389 prev_validate = learner $ validate
388390 on.exit({
389- iwalk(prev_validate_pos , function (val , poid ) {
390- # passing the args here is just a heuristic that can in principle fail, but this should be extremely
391- # rare
392- args = args [[poid ]] %??% list ()
393- set_validate(learner $ graph $ pipeops [[poid ]], validate = val , args = args , args_all = args_all )
391+ iwalk(prev_validate_pos , function (prev_val , poid ) {
392+ # Here we don't call into set_validate() as this also does not ensure that we are able to correctly
393+ # reset the configuration to the previous state (e.g. for AutoTuner) and is less transparent
394+ learner $ graph $ pipeops [[poid ]]$ validate = prev_val
394395 })
395396 learner $ validate = prev_validate
396397 }, add = TRUE )
@@ -400,13 +401,17 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args_all = l
400401 walk(ids , function (poid ) {
401402 # learner might be another GraphLearner / AutoTuner so we call into set_validate() again
402403 withCallingHandlers({
403- args = c( args [[poid ]], args_all ) %??% list ( )
404- set_validate( learner $ graph $ pipeops [[poid ]], .args = insert_named( list ( validate = " predefined " ), args ) )
404+ args = insert_named(c( list ( validate = " predefined " ), args_all ), args [[poid ]])
405+ invoke( set_validate , learner $ graph $ pipeops [[poid ]], .args = args )
405406 }, error = function (e ) {
406- e $ message = sprintf(" Failed to set validate for PipeOp '%s':\n %s" , poid , e $ message )
407+ e $ message = sprintf(paste0(
408+ " Failed to set validate for PipeOp '%s':\n %s\n " ,
409+ " Trying to heuristically reset validation to its previous state, please check the results" ), poid , e $ message )
407410 stop(e )
408411 }, warning = function (w ) {
409- w $ message = sprintf(" Failed to set validate for PipeOp '%s':\n %s" , po $ id , w $ message )
412+ w $ message = sprintf(paste0(
413+ " Failed to set validate for PipeOp '%s':\n %s\n " ,
414+ " Trying to heuristically reset validation to its previous state, please check the results" ), poid , w $ message )
410415 warning(w )
411416 invokeRestart(" muffleWarning" )
412417 })
@@ -487,4 +492,4 @@ infer_task_type = function(graph) {
487492 task_type = get_po_task_type(graph $ pipeops [[graph $ rhs ]])
488493 }
489494 c(task_type , " classif" )[[1 ]] # "classif" as final fallback
490- }
495+ }
0 commit comments