4747# ' contain the model. Use `graph_model` to access the trained [`Graph`] after `$train()`. Read-only.
4848# ' * `graph_model` :: [`Learner`][mlr3::Learner]\cr
4949# ' [`Graph`] that is being wrapped. This [`Graph`] contains a trained state after `$train()`. Read-only.
50+ # ' * `inner_tuned_values` :: named `list()` or `NULL`\cr
51+ # ' The inner tuned parameter values.
52+ # ' `NULL` is returned if the learner is not trained or none of the wrapped learners supports inner tuning.
53+ # ' * `inner_valid_scores` :: named `list()` or `NULL`\cr
54+ # ' The inner tuned parameter values.
55+ # ' `NULL` is returned if the learner is not trained or none of the wrapped learners supports internal validation.
5056# '
5157# ' @section Internals:
5258# ' [`as_graph()`] is called on the `graph` argument, so it can technically also be a `list` of things, which is
@@ -110,7 +116,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
110116 )
111117
112118 properties = setdiff(mlr_reflections $ learner_properties [[task_type ]],
113- c(" validation" , " inner_tuning" )[! c(private $ .validate , inner_tuning )])
119+ c(" validation" , " inner_tuning" )[! c(private $ .can_validate , inner_tuning )])
114120
115121 super $ initialize(id = id , task_type = task_type ,
116122 feature_types = mlr_reflections $ task_feature_types ,
@@ -128,9 +134,6 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
128134 if (! is.null(predict_type )) self $ predict_type = predict_type
129135 },
130136 base_learner = function (recursive = Inf ) {
131- self $ base_pipeop(recursive = recursive )$ learner_model
132- },
133- base_pipeop = function (recursive = Inf ) {
134137 assert(check_numeric(recursive , lower = Inf ), check_int(recursive ))
135138 if (recursive < = 0 ) return (self )
136139 gm = self $ graph_model
@@ -150,30 +153,18 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
150153 if (length(last_pipeop_id ) > 1 ) stop(" Graph has no unique PipeOp containing a Learner" )
151154 if (length(last_pipeop_id ) == 0 ) stop(" No Learner PipeOp found." )
152155 }
153- last_pipeop $ base_pipeop(recursive - 1 )
154-
155- },
156-
157- # ' @description
158- # ' Retrieves the inner validation scores as a named `list()`.
156+ learner_model $ base_learner(recursive - 1 )
157+ }
158+ ),
159+ active = list (
159160 inner_valid_scores = function (rhs ) {
160161 assert_ro_binding(rhs )
161- if (is.null(self $ state )) {
162- stopf(" Learner not trained" )
163- }
164162 self $ state $ inner_valid_scores
165163 },
166- # ' @description
167- # ' Retrieves the inner tuned values as a named `list()`.
168164 inner_tuned_values = function (rhs ) {
169165 assert_ro_binding(rhs )
170- if (is.null(self $ state )) {
171- stopf(" Learner not trained" )
172- }
173166 self $ state $ inner_tuned_values
174- }
175- ),
176- active = list (
167+ },
177168 validate = function (rhs ) {
178169 if (! missing(rhs )) {
179170 if (! private $ .can_validate ) {
@@ -185,11 +176,11 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
185176
186177 },
187178 hash = function () {
188- digest(list (class(self ), self $ id , self $ graph $ hash , private $ .predict_type ,
179+ digest(list (class(self ), self $ id , self $ graph $ hash , private $ .predict_type , private $ .validate ,
189180 self $ fallback $ hash , self $ parallel_predict ), algo = " xxhash64" )
190181 },
191182 phash = function () {
192- digest(list (class(self ), self $ id , self $ graph $ phash , private $ .predict_type ,
183+ digest(list (class(self ), self $ id , self $ graph $ phash , private $ .predict_type , private $ .validate ,
193184 self $ fallback $ hash , self $ parallel_predict ), algo = " xxhash64" )
194185 },
195186 predict_type = function (rhs ) {
@@ -226,21 +217,34 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
226217 .validate = NULL ,
227218 .can_validate = NULL ,
228219 .extract_inner_tuned_values = function () {
220+ itvs = unlist(map(
221+ learner_wrapping_pipeops(self $ graph_model ), function (po ) {
222+ if (exists(" inner_tuned_values" , po $ learner )) {
223+ po $ learner_model $ inner_tuned_values
224+ }
225+ }
226+ ), recursive = FALSE )
229227
230-
231- warningf(" Implementthis" )
232- list ()
228+ if (is.null(itvs ) || ! length(itvs )) {
229+ return (named_list())
230+ }
231+ itvs
233232
234233 },
235234 .extract_inner_valid_scores = function () {
236- warningf(" Implementthis" )
237- list ()
238- # map(
239- # keep(self$graph$pipeops, function(po) inherits(po, "PipeOpLearnerCV") || inherits(po, "PipeOpLearner")),
240- # function(po) {
241- # po$inner_
242- # }
243- # )
235+ ivs = unlist(map(
236+ learner_wrapping_pipeops(self $ graph_model ), function (po ) {
237+ if (exists(" inner_valid_scores" , po $ learner )) {
238+ po $ learner_model $ inner_valid_scores
239+ }
240+ }
241+ ), recursive = FALSE )
242+
243+ if (is.null(ivs ) || ! length(ivs )) {
244+ return (named_list())
245+ }
246+ ivs
247+
244248 },
245249 deep_clone = function (name , value ) {
246250 private $ .param_set = NULL
@@ -256,8 +260,8 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
256260
257261 .train = function (task ) {
258262 if (! is.null(get0(" validate" , self ))) {
259- some_pipeops_validate = map (
260- filter (self $ graph $ pipeops , function (po ) inherits(po , " PipeOpLearner" ) || inherits(po , " PipeOpLearnerCV" )),
263+ some_pipeops_validate = map_lgl (
264+ keep (self $ graph $ pipeops , function (po ) inherits(po , " PipeOpLearner" ) || inherits(po , " PipeOpLearnerCV" )),
261265 function (po ) ! is.null(get0(" validate" , po $ learner ))
262266 )
263267
@@ -319,30 +323,30 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
319323# ' Configure validation for a graph learner.
320324# '
321325# ' In a [`GraphLearner`], validation can be configured on two levels:
322- # ' 1. On the [`GraphLearner`] level.
323- # ' 2. On the level of the [`Learner`]s that are wrapped by [`PipeOpLearner`] and [`PipeOpLearnerCV`].
324- # '
325- # ' Therefore, enabling validation requires to specify not only how to create the validation set (1), but also which
326- # ' pipeops should actually use it.
327- # ' Only the [`GraphLearner`] can specify **how** to create the validation set.
328- # ' All learners wrapped by [`PipeOpLearner`] and [`PipeOpLearnerCV`] can only set it to `NULL` (disable) or
329- # ' `"inner_valid"` (enable).
326+ # ' 1. On the [`GraphLearner`] level, which specifies **how** the validation set is constructed.
327+ # ' 2. On the level of the [`Learner`]s that are wrapped by [`PipeOpLearner`] and [`PipeOpLearnerCV`], which specifies
328+ # ' which pipeops actually make use of the validation set.
329+ # ' All learners wrapped by [`PipeOpLearner`] and [`PipeOpLearnerCV`] can only set it to `NULL` (disable) or
330+ # ' `"inner_valid"` (enable).
330331# '
331332# ' @param learner ([`GraphLearner`])\cr
332333# ' The graph learner to configure.
333334# ' @param validate (`numeric(1)`, `"inner_valid"` or `NULL`)\cr
334335# ' How to set the `$validate` field of the learner.
335- # ' If set to `NULL` all validation is disabled.
336+ # ' If set to `NULL` all validation is disabled, both on the graph learner level, but also for all pipeops .
336337# ' @param ids (`NULL` or `character()`)\cr
337338# ' For which pipeops to enable validation.
338339# ' This parameter is ignored when `validate` is set to `NULL`.
339340# ' By default, validation is enabled for the base learner.
340341# ' @param args (named `list()`)\cr
341342# ' Rarely needed.
342- # ' A named list of lists, specifying additional argments to be passed to [`set_validate()`] for the pipeops.
343- # ' The names must be a subset of `ids`.
343+ # ' A named list of lists, specifying additional argments to be passed to [`set_validate()`] for the respective pipeops.
344+ # ' @param ... (any)\cr
345+ # ' Currently unused.
346+ # '
344347# ' @export
345348# ' @examples
349+ # ' library(mlr3)
346350# ' # simple
347351# ' glrn = as_learner(po("pca") %>>% lrn("classif.debug"))
348352# ' set_validate(glrn, 0.3)
@@ -353,13 +357,14 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
353357# ' glrn$graph$pipeops$classif.debug$learner$validate
354358# '
355359# ' # complex
356- # ' glrn = as_learner(ppl("stacking", lrns(c("classif.debug", "classif.featureless")), lrn("classif.debug", id = "final")))
360+ # ' glrn = as_learner(ppl("stacking", lrns(c("classif.debug", "classif.featureless")),
361+ # ' lrn("classif.debug", id = "final")))
357362# ' set_validate(glrn, 0.2, which = c("classif.debug", "final"))
358363# ' glrn$validate
359364# ' glrn$graph$pipeops$classif.debug$learner$validate
360365# ' glrn$graph$pipeops$final$learner$validate
361- set_validate.GraphLearner = function (learner , validate , ids = NULL , args = list ()) {
362- if (is.null(learner $ validate )) {
366+ set_validate.GraphLearner = function (learner , validate , ids = NULL , args = list (), ... ) {
367+ if (is.null(validate )) {
363368 learner $ validate = NULL
364369 walk(learner_wrapping_pipeops(learner ), function (po ) {
365370 if (exists(" validate" , po $ learner )) {
@@ -370,7 +375,7 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(
370375 }
371376
372377 if (is.null(ids )) {
373- which = learner $ base_pipeop()$ id
378+ ids = base_pipeop(learner )$ id
374379 } else {
375380 assert_subset(ids , ids(keep(learner_wrapping_pipeops(learner ), function (po ) " validation" %in% po $ learner $ properties )))
376381 assert_true(length(ids ) > 0 )
@@ -379,7 +384,7 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(
379384 assert_list(args , types = " list" )
380385 assert_subset(names(args ), ids )
381386
382- prev_validate_pos = discard(map(learner_wrapping_pipeops(learner ), function (po ) get0(" validate" , po $ learner ), is.null ) )
387+ prev_validate_pos = discard(map(learner_wrapping_pipeops(learner ), function (po ) get0(" validate" , po $ learner )) , is.null )
383388 prev_validate = learner $ validate
384389
385390 on.exit({
@@ -391,18 +396,29 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(
391396
392397 walk(ids , function (poid ) {
393398 # learner might be another GraphLearner / AutoTuner
394- invoke(set_validate learner = learner $ graph $ pipeops [[poid ]]$ learner , validate = " inner_valid" , .args = args [[poid ]])
399+ withCallingHandlers({
400+ invoke(set_validate , learner = learner $ graph $ pipeops [[poid ]]$ learner , validate = " inner_valid" , .args = args [[poid ]])
401+ }, error = function (e ) {
402+ e $ message = sprintf(" Failed to set validate for PipeOp '%s':\n %s" , poid , e $ message )
403+ stop(e )
404+ }, warning = function (w ) {
405+ w $ message = sprintf(" Failed to set validate for PipeOp '%s':\n %s" , po $ id , w $ message )
406+ warning(w )
407+ invokeRestart(" muffleWarning" )
408+ })
395409 })
396410 on.exit()
397411
398412 invisible (learner )
399413}
400414
401415
402- # ' @title Set Inner Tuning of a GraphLearner
416+ # ' @title Set Inner Tuning for a Graph Learner
403417# ' @description
404418# ' First, all values specified by `...` are
405419# ' All [`PipeOpLearner`] and [`PipeOpLearnerCV`]
420+ # '
421+ # ' @inheritParams mlr3::set_inner_tuning
406422# ' @param validate (`numeric(1)`, `"inner_valid"`, or `NULL`)\cr
407423# ' How to set the `$validate` field of the learner.
408424# ' @param args (named `list()`)\cr
@@ -415,7 +431,7 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(
415431# ' @export
416432set_inner_tuning.GraphLearner = function (.learner , .disable = FALSE , validate = NA , args = NULL , ... ) {
417433 if (is.null(args )) {
418- args = set_names(list (list ()), .learner $ base_pipeops( )$ id
434+ args = set_names(list (list ()), base_pipeop( .learner )$ id )
419435 }
420436 all_pipeops = .learner $ graph $ pipeops
421437 lrn_pipeops = learner_wrapping_pipeops(all_pipeops )
@@ -433,10 +449,18 @@ set_inner_tuning.GraphLearner = function(.learner, .disable = FALSE, validate =
433449 }, add = TRUE )
434450
435451 walk(lrn_pipeops [names(args )], function (po ) {
436- browser()
437- invoke(set_inner_tuning , .learner = po $ learner ,
438- .args = insert_named(list (validate = validate , .disable = .disable ), args [[po $ id ]])
439- )
452+ withCallingHandlers({
453+ invoke(set_inner_tuning , .learner = po $ learner ,
454+ .args = insert_named(list (validate = validate , .disable = .disable ), args [[po $ id ]])
455+ )
456+ }, error = function (e ) {
457+ e $ message = sprintf(" Failed to set inner tuning for PipeOp '%s':\n %s" , po $ id , e $ message )
458+ stop(e )
459+ }, warning = function (w ) {
460+ w $ message = sprintf(" Failed to set inner tuning for PipeOp '%s':\n %s" , po $ id , w $ message )
461+ warning(w )
462+ invokeRestart(" muffleWarning" )
463+ })
440464 })
441465
442466 # Now:
0 commit comments