@@ -99,7 +99,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
9999 assert_subset(task_type , mlr_reflections $ task_types $ type )
100100
101101
102- private $ .validate = some(
102+ private $ .can_validate = some(
103103 keep(graph $ pipeops , function (x ) inherits(x , " PipeOpLearner" ) || inherits(x , " PipeOpLearnerCV" )),
104104 function (po ) " validation" %in% po $ learner $ properties
105105 )
@@ -110,7 +110,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
110110 )
111111
112112 properties = setdiff(mlr_reflections $ learner_properties [[task_type ]],
113- c(" validation" , " inner_tuning" )[c( ! private $ .validate , ! inner_tuning )])
113+ c(" validation" , " inner_tuning" )[! c( private $ .validate , inner_tuning )])
114114
115115 super $ initialize(id = id , task_type = task_type ,
116116 feature_types = mlr_reflections $ task_feature_types ,
@@ -128,6 +128,9 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
128128 if (! is.null(predict_type )) self $ predict_type = predict_type
129129 },
130130 base_learner = function (recursive = Inf ) {
131+ self $ base_pipeop(recursive = recursive )$ learner_model
132+ },
133+ base_pipeop = function (recursive = Inf ) {
131134 assert(check_numeric(recursive , lower = Inf ), check_int(recursive ))
132135 if (recursive < = 0 ) return (self )
133136 gm = self $ graph_model
@@ -147,7 +150,8 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
147150 if (length(last_pipeop_id ) > 1 ) stop(" Graph has no unique PipeOp containing a Learner" )
148151 if (length(last_pipeop_id ) == 0 ) stop(" No Learner PipeOp found." )
149152 }
150- learner_model $ base_learner(recursive - 1 )
153+ last_pipeop $ base_pipeop(recursive - 1 )
154+
151155 },
152156
153157 # ' @description
@@ -170,6 +174,16 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
170174 }
171175 ),
172176 active = list (
177+ validate = function (rhs ) {
178+ if (! missing(rhs )) {
179+ if (! private $ .can_validate ) {
180+ stopf(" None of the Learners wrapped by GraphLearner '%s' support validation." , self $ id )
181+ }
182+ private $ .validate = assert_validate(rhs )
183+ }
184+ private $ .validate
185+
186+ },
173187 hash = function () {
174188 digest(list (class(self ), self $ id , self $ graph $ hash , private $ .predict_type ,
175189 self $ fallback $ hash , self $ parallel_predict ), algo = " xxhash64" )
@@ -188,12 +202,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
188202 if (! missing(rhs ) && ! identical(rhs , self $ graph $ param_set )) {
189203 stop(" param_set is read-only." )
190204 }
191- if (is.null(private $ .param_set )) {
192- private $ .param_set = ParamSetCollection $ new(sets = c(list (self $ graph $ param_set ),
193- if (private $ .validate ) ps(validate = p_uty(default = NULL , tags = " train" , custom_check = check_validate )
194- )))
195- }
196- private $ .param_set
205+ self $ graph $ param_set
197206 },
198207 graph = function (rhs ) {
199208 if (! missing(rhs ) && ! identical(rhs , private $ .graph )) stop(" graph is read-only" )
@@ -215,12 +224,17 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
215224 private = list (
216225 .graph = NULL ,
217226 .validate = NULL ,
218- .param_set = NULL ,
227+ .can_validate = NULL ,
219228 .extract_inner_tuned_values = function () {
220229
230+
231+ warningf(" Implementthis" )
232+ list ()
233+
221234 },
222235 .extract_inner_valid_scores = function () {
223- .NotYetImplemented()
236+ warningf(" Implementthis" )
237+ list ()
224238 # map(
225239 # keep(self$graph$pipeops, function(po) inherits(po, "PipeOpLearnerCV") || inherits(po, "PipeOpLearner")),
226240 # function(po) {
@@ -241,6 +255,17 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
241255 },
242256
243257 .train = function (task ) {
258+ if (! is.null(get0(" validate" , self ))) {
259+ some_pipeops_validate = map(
260+ filter(self $ graph $ pipeops , function (po ) inherits(po , " PipeOpLearner" ) || inherits(po , " PipeOpLearnerCV" )),
261+ function (po ) ! is.null(get0(" validate" , po $ learner ))
262+ )
263+
264+ if (! some_pipeops_validate ) {
265+ lg $ warn(" GraphLearner '%s' specifies a validation set, but none of its Learners use it." , self $ id )
266+ }
267+ }
268+
244269 on.exit({self $ graph $ state = NULL })
245270 self $ graph $ train(task )
246271 state = self $ graph $ state
@@ -288,90 +313,173 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
288313 )
289314)
290315
316+ # ' @title Configure Validation for a GraphLearner
317+ # '
318+ # ' @description
319+ # ' Configure validation for a graph learner.
320+ # '
321+ # ' In a [`GraphLearner`], validation can be configured on two levels:
322+ # ' 1. On the [`GraphLearner`] level.
323+ # ' 2. On the level of the [`Learner`]s that are wrapped by [`PipeOpLearner`] and [`PipeOpLearnerCV`].
324+ # '
325+ # ' Therefore, enabling validation requires to specify not only how to create the validation set (1), but also which
326+ # ' pipeops should actually use it.
327+ # ' Only the [`GraphLearner`] can specify **how** to create the validation set.
328+ # ' All learners wrapped by [`PipeOpLearner`] and [`PipeOpLearnerCV`] can only set it to `NULL` (disable) or
329+ # ' `"inner_valid"` (enable).
330+ # '
331+ # ' @param learner ([`GraphLearner`])\cr
332+ # ' The graph learner to configure.
333+ # ' @param validate (`numeric(1)`, `"inner_valid"` or `NULL`)\cr
334+ # ' How to set the `$validate` field of the learner.
335+ # ' If set to `NULL` all validation is disabled.
336+ # ' @param ids (`NULL` or `character()`)\cr
337+ # ' For which pipeops to enable validation.
338+ # ' This parameter is ignored when `validate` is set to `NULL`.
339+ # ' By default, validation is enabled for the base learner.
340+ # ' @param args (named `list()`)\cr
341+ # ' Rarely needed.
342+ # ' A named list of lists, specifying additional argments to be passed to [`set_validate()`] for the pipeops.
343+ # ' The names must be a subset of `ids`.
344+ # ' @export
345+ # ' @examples
346+ # ' # simple
347+ # ' glrn = as_learner(po("pca") %>>% lrn("classif.debug"))
348+ # ' set_validate(glrn, 0.3)
349+ # ' glrn$validate
350+ # ' glrn$graph$pipeops$classif.debug$learner$validate
351+ # ' set_validate(glrn, NULL)
352+ # ' glrn$validate
353+ # ' glrn$graph$pipeops$classif.debug$learner$validate
354+ # '
355+ # ' # complex
356+ # ' glrn = as_learner(ppl("stacking", lrns(c("classif.debug", "classif.featureless")), lrn("classif.debug", id = "final")))
357+ # ' set_validate(glrn, 0.2, which = c("classif.debug", "final"))
358+ # ' glrn$validate
359+ # ' glrn$graph$pipeops$classif.debug$learner$validate
360+ # ' glrn$graph$pipeops$final$learner$validate
361+ set_validate.GraphLearner = function (learner , validate , ids = NULL , args = list ()) {
362+ if (is.null(learner $ validate )) {
363+ learner $ validate = NULL
364+ walk(learner_wrapping_pipeops(learner ), function (po ) {
365+ if (exists(" validate" , po $ learner )) {
366+ po $ learner $ validate = NULL
367+ }
368+ })
369+ return (invisible (learner ))
370+ }
291371
292- # ' @param ids (`character(1)`)\cr
293- # ' The ids of the parameters to disable.
372+ if (is.null(ids )) {
373+ which = learner $ base_pipeop()$ id
374+ } else {
375+ assert_subset(ids , ids(keep(learner_wrapping_pipeops(learner ), function (po ) " validation" %in% po $ learner $ properties )))
376+ assert_true(length(ids ) > 0 )
377+ }
378+
379+ assert_list(args , types = " list" )
380+ assert_subset(names(args ), ids )
381+
382+ prev_validate_pos = discard(map(learner_wrapping_pipeops(learner ), function (po ) get0(" validate" , po $ learner ), is.null ))
383+ prev_validate = learner $ validate
384+
385+ on.exit({
386+ iwalk(prev_validate_pos , function (val , poid ) learner $ graph $ pipeops [[poid ]] = val )
387+ learner $ valiate = prev_validate
388+ }, add = TRUE )
389+
390+ learner $ validate = validate
391+
392+ walk(ids , function (poid ) {
393+ # learner might be another GraphLearner / AutoTuner
394+ invoke(set_validate learner = learner $ graph $ pipeops [[poid ]]$ learner , validate = " inner_valid" , .args = args [[poid ]])
395+ })
396+ on.exit()
397+
398+ invisible (learner )
399+ }
400+
401+
402+ # ' @title Set Inner Tuning of a GraphLearner
403+ # ' @description
404+ # ' First, all values specified by `...` are
405+ # ' All [`PipeOpLearner`] and [`PipeOpLearnerCV`]
406+ # ' @param validate (`numeric(1)`, `"inner_valid"`, or `NULL`)\cr
407+ # ' How to set the `$validate` field of the learner.
408+ # ' @param args (named `list()`)\cr
409+ # ' Names are ids of the [`GraphLearner`]'s `PipeOps` and values are lists containing arguments passed to the
410+ # ' respective wrapped [`Learner`].
411+ # ' By default, the values `.disable` and `validate` are used, but can be overwritten on a per-pipeop basis.
412+ # '
294413# ' When enabling, the inner tuning of the `$base_learner()` is enabled by default.
295414# ' When disabling, all inner tuning is disable by default.
296415# ' @export
297- set_inner_tuning.GraphLearner = function (learner , disable = FALSE , ids = NULL , param_vals = list (), ... ) {
298- all_pipeops = learner $ graph $ pipeops
299- lrn_pipeops = all_pipeops [inherits(all_pipeops , " PipeOpLearner" ) | inherits(all_pipeops , " PipeOpLearnerCV" )]
300-
301- if (is.null(ids ) && disable ) {
302- ids = as.character(unlist(imap(lrn_pipeops , function (po , prefix ) {
303- sprintf(" %s.%s" , prefix , names(po $ param_set $ tags [map_lgl(po $ param_set $ tags , function (t ) " inner_tuning" %in% t )]))
304- })))
305- } else if (is.null(ids ) && ! disable ) {
306- lrn_base = learner $ base_learner()
307-
308- # need to find the pipeop that is the base learner. Cannot directly use id, because id of pipeop might
309- # differ from learner id
310- po_baselrn = NULL
311- for (po in lrn_pipeops [inherits(po , " PipeOpLearner" )]) {
312- if (identical(po $ learner , lrn_base )) {
313- po_baselrn = po
314- break
315- }
316- }
317- ids = paste0(
318- po_baselrn $ id , " ." ,
319- names(po_baselrn $ param_set $ tags [map_lgl(po_baselrn $ param_set $ tags , function (tags ) " inner_tuning" %in% tags )])
320- )
416+ set_inner_tuning.GraphLearner = function (.learner , .disable = FALSE , validate = NA , args = NULL , ... ) {
417+ if (is.null(args )) {
418+ args = set_names(list (list ()), .learner $ base_pipeops()$ id
321419 }
322- assert_subset( ids , learner $ param_set $ ids())
323- pv_prev = learner $ param_set $ values
420+ all_pipeops = . learner$ graph $ pipeops
421+ lrn_pipeops = learner_wrapping_pipeops( all_pipeops )
324422
325- # reset to previous pvs if anything goes wrong
326- on.exit({ learner $ param_set $ set_values( .values = pv_prev )}, add = TRUE )
423+ assert_list( args , names = " unique " )
424+ assert_subset(names( args ), ids( lrn_pipeops ) )
327425
328- learner $ param_set $ set_values(.values = param_vals )
329426
427+ # clean up when something goes wrong
428+ prev_pvs = .learner $ param_set $ values
429+ prev_validate = discard(map(lrn_pipeops , function (po ) if (exists(" validate" , po $ learner )) po $ learner $ validate ), is.null )
430+ on.exit({
431+ .learner $ param_set $ set_values(.values = prev_pvs )
432+ iwalk(prev_validate , function (val , poid ) .learner $ graph $ pipeops [[poid ]]$ learner $ validate = val )
433+ }, add = TRUE )
330434
331- # pipeop_ids are those learners that wrap a learner and have a parameter that is containes in ids
332- po_ids = as.character(unlist(discard(map(lrn_pipeops , function (po ) {
333- if (some(names(param_vals ) %in% sprintf(" %s.%s" , po $ id , po $ param_set $ ids()))) po $ id
334- }), is.null )))
335-
336- # now we walk through the learners and call set_inner_tuning() WITHOUT passing the parameters, as we have already
337- # set them above
338- walk(lrn_pipeops [po_ids ], set_inner_tuning , disable = disable )
435+ walk(lrn_pipeops [names(args )], function (po ) {
436+ browser()
437+ invoke(set_inner_tuning , .learner = po $ learner ,
438+ .args = insert_named(list (validate = validate , .disable = .disable ), args [[po $ id ]])
439+ )
440+ })
339441
340- # now put up some extra guardrails because it is not intuitive how to configure validation in the GraphLearner
442+ # Now:
443+ # Set validate for GraphLearner and verify that the configuration is reasonable
341444
342- some_pipeops_validate = FALSE
343- if (disable ) {
344- for (po in lrn_pipeops ) {
345- if (! is.null(po $ param_set $ values $ validate )) {
346- some_pipeops_validate = TRUE
347- break
445+ if (.disable ) {
446+ .learner $ validate = if (identical(validate , NA )) NULL else validate
447+ some_pipeops_validate = some(lrn_pipeops , function (po ) {
448+ if (! exists(" validate" , po $ learner )) {
449+ return (FALSE )
348450 }
349- }
451+ ! is.null(po $ learner $ validate )
452+ })
350453 # if none of the pipeops does validation, we also disable it in the GraphLearner
351- # (unless a value was explicitly passed via param_vals )
352- if (! some_pipeops_validate && is.null( param_vals $ validate )) {
353- learner $ param_set $ set_values( validate = NULL )
454+ # (unless a value was explicitly specified )
455+ if (! some_pipeops_validate && identical( validate , NA )) {
456+ . learner$ validate = NULL
354457 }
355458 } else {
356- for (po in lrn_pipeops ) {
357- if (! is.null(po $ param_set $ values $ validate ) && is.null(learner $ param_set $ values $ validate )) {
459+ if (! identical(validate , NA )) {
460+ .learner $ validate = validate
461+ }
462+
463+ some_pipeops_validate = some(lrn_pipeops , function (po ) {
464+ if (is.null(get0(" validate" , po $ learner ))) return (FALSE )
465+ if (is.null(.learner $ validate )) {
358466 warningf(" PipeOp '%s' from GraphLearner '%s' wants a validation set but GraphLearner does not specify one. This likely not what you want." ,
359- po $ id , learner $ id )
467+ po $ id , . learner$ id )
360468 }
361- if (! is.null(po $ param_set $ values $ validate )) {
362- if (! identical(po $ param_set $ values $ validate , " inner_valid" )) {
363- warningf(" PipeOp '%s' from GraphLearner '%s' specifies validation set other than 'inner_valid'. This is likely not what you want." )
364- }
365- some_pipeops_validate = TRUE
469+ if (! identical(po $ learner $ validate , " inner_valid" )) {
470+ warningf(" PipeOp '%s' from GraphLearner '%s' specifies validation set other than 'inner_valid'. This is likely not what you want." ,
471+ po $ id , .learner $ id )
366472 }
367- }
368- if (! is.null(learner $ param_set $ values $ validate ) && ! some_pipeops_validate ) {
369- warningf(" GraphLearner '%s' specifies a validation set, but none of its Learners use it." , learner $ id )
473+ TRUE
474+ })
475+
476+ if (! is.null(.learner $ param_set $ values $ validate ) && ! some_pipeops_validate ) {
477+ warningf(" GraphLearner '%s' specifies a validation set, but none of its Learners use it. This is likely not what you want." , .learner $ id )
370478 }
371479 }
372480
373481 on.exit()
374- invisible (learner )
482+ invisible (. learner )
375483}
376484
377485# ' @export
0 commit comments