4747# ' contain the model. Use `graph_model` to access the trained [`Graph`] after `$train()`. Read-only.
4848# ' * `graph_model` :: [`Graph`]\cr
4949# ' [`Graph`] that is being wrapped. This [`Graph`] contains a trained state after `$train()`. Read-only.
50+ # ' * `internal_tuned_values` :: named `list()` or `NULL`\cr
51+ # ' The internal tuned parameter values collected from all `PipeOp`s.
52+ # ' `NULL` is returned if the learner is not trained or none of the wrapped learners supports internal tuning.
53+ # ' * `internal_valid_scores` :: named `list()` or `NULL`\cr
54+ # ' The internal validation scores as retrieved from the `PipeOp`s.
55+ # ' The names are prefixed with the respective IDs of the `PipeOp`s.
56+ # ' `NULL` is returned if the learner is not trained or none of the wrapped learners supports internal validation.
57+ # ' * `validate` :: `numeric(1)`, `"predefined"`, `"test"` or `NULL`\cr
58+ # ' How to construct the validation data. This also has to be configured for the individual `PipeOp`s such as
59+ # ' `PipeOpLearner`, see [`set_validate.GraphLearner`].
60+ # ' For more details on the possible values, see [`mlr3::Learner`].
5061# ' * `marshaled` :: `logical(1)`\cr
5162# ' Whether the learner is marshaled.
5263# '
@@ -110,11 +121,17 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
110121 }
111122 assert_subset(task_type , mlr_reflections $ task_types $ type )
112123
124+ private $ .can_validate = some(graph $ pipeops , function (po ) " validation" %in% po $ properties )
125+ private $ .can_internal_tuning = some(graph $ pipeops , function (po ) " internal_tuning" %in% po $ properties )
126+
127+ properties = setdiff(mlr_reflections $ learner_properties [[task_type ]],
128+ c(" validation" , " internal_tuning" )[! c(private $ .can_validate , private $ .can_internal_tuning )])
129+
113130 super $ initialize(id = id , task_type = task_type ,
114131 feature_types = mlr_reflections $ task_feature_types ,
115132 predict_types = names(mlr_reflections $ learner_predict_types [[task_type ]]),
116133 packages = graph $ packages ,
117- properties = mlr_reflections $ learner_properties [[ task_type ]] ,
134+ properties = properties ,
118135 man = " mlr3pipelines::GraphLearner"
119136 )
120137
@@ -123,8 +140,9 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
123140 }
124141 if (! is.null(predict_type )) self $ predict_type = predict_type
125142 },
126- base_learner = function (recursive = Inf ) {
143+ base_learner = function (recursive = Inf , return_po = FALSE ) {
127144 assert(check_numeric(recursive , lower = Inf ), check_int(recursive ))
145+ assert_flag(return_po )
128146 if (recursive < = 0 ) return (self )
129147 gm = self $ graph_model
130148 gm_output = gm $ output
@@ -143,7 +161,11 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
143161 if (length(last_pipeop_id ) > 1 ) stop(" Graph has no unique PipeOp containing a Learner" )
144162 if (length(last_pipeop_id ) == 0 ) stop(" No Learner PipeOp found." )
145163 }
146- learner_model $ base_learner(recursive - 1 )
164+ if (return_po ) {
165+ last_pipeop
166+ } else {
167+ learner_model $ base_learner(recursive - 1 )
168+ }
147169 },
148170 marshal = function (... ) {
149171 learner_marshal(.learner = self , ... )
@@ -153,15 +175,32 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
153175 }
154176 ),
155177 active = list (
# Read-only active binding: exposes the internal validation scores that are
# stored in the learner's state. Passing a value (i.e. trying to assign to the
# binding) is rejected by assert_ro_binding().
internal_valid_scores = function(rhs) {
  assert_ro_binding(rhs)
  learner_state = self$state
  learner_state$internal_valid_scores
},
# Read-only active binding: exposes the internally tuned parameter values that
# are stored in the learner's state. Passing a value (i.e. trying to assign to
# the binding) is rejected by assert_ro_binding().
internal_tuned_values = function(rhs) {
  assert_ro_binding(rhs)
  learner_state = self$state
  learner_state$internal_tuned_values
},
# Active binding for the validation configuration of the GraphLearner.
# Reading returns the stored value; writing is only permitted when at least
# one PipeOp of the graph has the "validation" property (as recorded in
# private$.can_validate), and the new value is checked by assert_validate().
validate = function(rhs) {
  # Getter: nothing was assigned, just report the current setting.
  if (missing(rhs)) {
    return(private$.validate)
  }
  # Setter: refuse when no PipeOp could make use of validation data.
  if (!private$.can_validate) {
    stopf("None of the PipeOps in Graph '%s' supports validation.", self$id)
  }
  private$.validate = assert_validate(rhs)
  private$.validate
},
# Active binding reporting whether the learner is marshaled; the answer is
# delegated to learner_marshaled().
marshaled = function() {
  is_marshaled = learner_marshaled(self)
  is_marshaled
},
# Active binding computing the learner's hash. The digest covers the class,
# the id, the graph's hash, the predict type, the validation setting, the
# fallback learner's hash and the parallel_predict flag.
hash = function() {
  hash_components = list(
    class(self),
    self$id,
    self$graph$hash,
    private$.predict_type,
    private$.validate,
    self$fallback$hash,
    self$parallel_predict
  )
  digest(hash_components, algo = "xxhash64")
},
# Active binding computing the learner's phash. Same components as `hash`,
# except that the graph contributes its `phash` instead of its `hash`.
phash = function() {
  phash_components = list(
    class(self),
    self$id,
    self$graph$phash,
    private$.predict_type,
    private$.validate,
    self$fallback$hash,
    self$parallel_predict
  )
  digest(phash_components, algo = "xxhash64")
},
167206 predict_type = function (rhs ) {
@@ -195,6 +234,21 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
195234 ),
196235 private = list (
197236 .graph = NULL ,
237+ .validate = NULL ,
238+ .can_validate = NULL ,
239+ .can_internal_tuning = NULL ,
# Collects the `internal_tuned_values` of every PipeOp in the trained graph
# (`self$graph_model`) that has the "internal_tuning" property.
#
# Returns:
# * `NULL` if no PipeOp of the graph supports internal tuning,
# * an empty named list if such PipeOps exist but none reported any values,
# * otherwise the per-PipeOp lists flattened by one level via unlist().
#   NOTE(review): names are presumably prefixed with the PipeOp ids by
#   unlist(); this relies on pos_with_property() returning a named list.
.extract_internal_tuned_values = function() {
  # The guard must match what is extracted: tuned values depend on the
  # "internal_tuning" capability, not on the "validation" capability.
  if (!private$.can_internal_tuning) return(NULL)
  itvs = unlist(map(pos_with_property(self$graph_model, "internal_tuning"), "internal_tuned_values"), recursive = FALSE)
  # unlist() of an empty list yields NULL; normalize that to an empty named list.
  if (!length(itvs)) return(named_list())
  itvs
},
# Collects the `internal_valid_scores` of every PipeOp in the trained graph
# (`self$graph_model`) that has the "validation" property.
#
# Returns:
# * `NULL` if no PipeOp of the graph supports validation,
# * an empty named list if such PipeOps exist but none reported any scores,
# * otherwise the per-PipeOp lists flattened by one level via unlist().
#   NOTE(review): names are presumably prefixed with the PipeOp ids by
#   unlist(); this relies on pos_with_property() returning a named list.
.extract_internal_valid_scores = function() {
  # The guard must match what is extracted: validation scores depend on the
  # "validation" capability, not on the "internal_tuning" capability.
  if (!private$.can_validate) return(NULL)
  ivs = unlist(map(pos_with_property(self$graph_model, "validation"), "internal_valid_scores"), recursive = FALSE)
  # unlist() of an empty list yields NULL; normalize that to an empty named list.
  if (!length(ivs)) return(named_list())
  ivs
},
198252 deep_clone = function (name , value ) {
199253 # FIXME this repairs the mlr3::Learner deep_clone() method which is broken.
200254 if (is.environment(value ) && ! is.null(value [[" .__enclos_env__" ]])) {
@@ -207,6 +261,13 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
207261 },
208262
209263 .train = function (task ) {
264+ if (! is.null(get0(" validate" , self ))) {
265+ some_pipeops_validate = some(pos_with_property(self , " validation" ), function (po ) ! is.null(po $ validate ))
266+ if (! some_pipeops_validate ) {
267+ lg $ warn(" GraphLearner '%s' specifies a validation set, but none of its PipeOps use it." , self $ id )
268+ }
269+ }
270+
210271 on.exit({self $ graph $ state = NULL })
211272 self $ graph $ train(task )
212273 state = self $ graph $ state
@@ -255,6 +316,109 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
255316 )
256317)
257318
319+ # ' @title Configure Validation for a GraphLearner
320+ # '
321+ # ' @description
322+ # ' Configure validation for a graph learner.
323+ # '
324+ # ' In a [`GraphLearner`], validation can be configured on two levels:
325+ # ' 1. On the [`GraphLearner`] level, which specifies **how** the validation set is constructed before entering the graph.
326+ # ' 2. On the level of the individual `PipeOp`s (such as `PipeOpLearner`), which specifies
327+ # ' which pipeops actually make use of the validation data (set its `$validate` field to `"predefined"`) or not (set it to `NULL`).
328+ # ' This can be specified via the argument `ids`.
329+ # '
330+ # ' @param learner ([`GraphLearner`])\cr
331+ # ' The graph learner to configure.
332+ # ' @param validate (`numeric(1)`, `"predefined"`, `"test"`, or `NULL`)\cr
333+ # ' How to set the `$validate` field of the learner.
334+ # ' If set to `NULL`, all validation is disabled, both on the graph learner level and for all pipeops.
335+ # ' @param ids (`NULL` or `character()`)\cr
336+ # ' For which pipeops to enable validation.
337+ # ' This parameter is ignored when `validate` is set to `NULL`.
338+ # ' By default, validation is enabled for the final `PipeOp` in the `Graph`.
339+ # ' @param args_all (`list()`)\cr
340+ # ' Rarely needed. A named list of parameter values that are passed to all subsequent [`set_validate()`] calls on the individual
341+ # ' `PipeOp`s.
342+ # ' @param args (named `list()`)\cr
343+ # ' Rarely needed.
344+ # ' A named list of lists, specifying additional arguments to be passed to [`set_validate()`] when calling it on the individual
345+ # ' `PipeOp`s.
346+ # ' @param ... (any)\cr
347+ # ' Currently unused.
348+ # '
349+ # ' @export
350+ # ' @examples
351+ # ' library(mlr3)
352+ # '
353+ # ' glrn = as_learner(po("pca") %>>% lrn("classif.debug"))
354+ # ' set_validate(glrn, 0.3)
355+ # ' glrn$validate
356+ # ' glrn$graph$pipeops$classif.debug$learner$validate
357+ # '
358+ # ' set_validate(glrn, NULL)
359+ # ' glrn$validate
360+ # ' glrn$graph$pipeops$classif.debug$learner$validate
361+ # '
362+ # ' set_validate(glrn, 0.2, ids = "classif.debug")
363+ # ' glrn$validate
364+ # ' glrn$graph$pipeops$classif.debug$learner$validate
set_validate.GraphLearner = function(learner, validate, ids = NULL, args_all = list(), args = list(), ...) {
  # Snapshot the current validation configuration (per PipeOp and on the
  # GraphLearner itself) so it can be restored if anything below fails.
  old_po_validate = map(pos_with_property(learner$graph$pipeops, "validation"), "validate")
  old_validate = learner$validate
  restore_previous = function() {
    # The fields are assigned directly rather than via set_validate(): calling
    # set_validate() again would not guarantee a faithful reset, is less
    # transparent and might fail a second time. The condition messages
    # produced further down tell the user that this reset is heuristic.
    iwalk(old_po_validate, function(old_value, po_id) {
      learner$graph$pipeops[[po_id]]$validate = old_value
    })
    learner$validate = old_validate
  }
  on.exit(restore_previous(), add = TRUE)

  # validate = NULL: switch validation off on the GraphLearner and on every
  # PipeOp that has the "validation" property; `ids` is ignored here.
  if (is.null(validate)) {
    learner$validate = NULL
    walk(pos_with_property(learner$graph$pipeops, "validation"), function(po) {
      po_args = args[[po$id]] %??% list()
      invoke(set_validate, po, validate = NULL, args_all = args_all, args = po_args)
    })
    # Success: drop the restore handler.
    on.exit()
    return(invisible(learner))
  }

  # Determine the PipeOps for which validation is enabled: either the
  # user-supplied ids (which must have the "validation" property) or, by
  # default, the PipeOp returned by base_learner(return_po = TRUE).
  if (!is.null(ids)) {
    assert_subset(ids, ids(pos_with_property(learner$graph$pipeops, "validation")))
  } else {
    ids = learner$base_learner(return_po = TRUE)$id
  }

  assert_list(args, types = "list")
  assert_list(args_all)
  assert_subset(names(args), ids)

  learner$validate = validate

  # Prefixes a condition's message with the id of the PipeOp being configured
  # and returns the modified condition.
  describe_failure = function(cond, po_id) {
    cond$message = sprintf(paste0(
      "Failed to set validate for PipeOp '%s':\n%s\n",
      "Trying to heuristically reset validation to its previous state, please check the results"), po_id, cond$message)
    cond
  }

  walk(ids, function(po_id) {
    # The wrapped learner might itself be a GraphLearner / AutoTuner, so the
    # generic set_validate() is called again instead of setting the field.
    withCallingHandlers({
      # Precedence (lowest to highest): validate = "predefined", args_all, args[[po_id]].
      call_args = insert_named(insert_named(list(validate = "predefined"), args_all), args[[po_id]])
      invoke(set_validate, learner$graph$pipeops[[po_id]], .args = call_args)
    }, error = function(e) {
      stop(describe_failure(e, po_id))
    }, warning = function(w) {
      # Re-signal the annotated warning and muffle the original one.
      warning(describe_failure(w, po_id))
      invokeRestart("muffleWarning")
    })
  })
  # Success: drop the restore handler.
  on.exit()

  invisible(learner)
}
421+
258422# ' @export
259423marshal_model.graph_learner_model = function (model , inplace = FALSE , ... ) {
260424 xm = map(.x = model , .f = marshal_model , inplace = inplace , ... )
0 commit comments