5353# ' * `inner_valid_scores` :: named `list()` or `NULL`\cr
5454# ' The inner tuned parameter values.
5555# ' `NULL` is returned if the learner is not trained or none of the wrapped learners supports internal validation.
56+ # ' * `validate` :: `numeric(1)`, `"inner_valid"`, `"test"` or `NULL`\cr
57+ # ' How to construct the validation data.
5658# '
5759# ' @section Internals:
5860# ' [`as_graph()`] is called on the `graph` argument, so it can technically also be a `list` of things, which is
@@ -224,12 +226,8 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
224226 }
225227 }
226228 ), recursive = FALSE )
227-
228- if (is.null(itvs ) || ! length(itvs )) {
229- return (named_list())
230- }
229+ if (is.null(itvs ) || ! length(itvs )) return (named_list())
231230 itvs
232-
233231 },
234232 .extract_inner_valid_scores = function () {
235233 ivs = unlist(map(
@@ -239,12 +237,8 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
239237 }
240238 }
241239 ), recursive = FALSE )
242-
243- if (is.null(ivs ) || ! length(ivs )) {
244- return (named_list())
245- }
240+ if (is.null(ivs ) || ! length(ivs )) return (named_list())
246241 ivs
247-
248242 },
249243 deep_clone = function (name , value ) {
250244 private $ .param_set = NULL
@@ -323,15 +317,15 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
323317# ' Configure validation for a graph learner.
324318# '
325319# ' In a [`GraphLearner`], validation can be configured on two levels:
326- # ' 1. On the [`GraphLearner`] level, which specifies **how** the validation set is constructed.
320+ # ' 1. On the [`GraphLearner`] level, which specifies **how** the validation set is constructed before entering the graph .
327321# ' 2. On the level of the [`Learner`]s that are wrapped by [`PipeOpLearner`] and [`PipeOpLearnerCV`], which specifies
328322# ' which pipeops actually make use of the validation set.
329- # ' All learners wrapped by [`PipeOpLearner`] and [`PipeOpLearnerCV`] can only set it to `NULL` (disable) or
330- # ' `"inner_valid"` (enable).
323+ # ' All learners wrapped by [`PipeOpLearner`] and [`PipeOpLearnerCV`] should in almost all cases either set it
324+ # ' to `NULL` (disable) or `"inner_valid"` (enable).
331325# '
332326# ' @param learner ([`GraphLearner`])\cr
333327# ' The graph learner to configure.
334- # ' @param validate (`numeric(1)`, `"inner_valid"` or `NULL`)\cr
328+ # ' @param validate (`numeric(1)`, `"inner_valid"`, `"test"`, or `NULL`)\cr
335329# ' How to set the `$validate` field of the learner.
336330# ' If set to `NULL` all validation is disabled, both on the graph learner level, but also for all pipeops.
337331# ' @param ids (`NULL` or `character()`)\cr
@@ -340,7 +334,8 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
340334# ' By default, validation is enabled for the base learner.
341335# ' @param args (named `list()`)\cr
342336# ' Rarely needed.
343- # ' A named list of lists, specifying additional argments to be passed to [`set_validate()`] for the respective pipeops.
337+ # ' A named list of lists, specifying additional argments to be passed to [`set_validate()`] for the respective learners.
338+ # ' Names must be a subset of the `ids`.
344339# ' @param ... (any)\cr
345340# ' Currently unused.
346341# '
@@ -357,9 +352,9 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
357352# ' glrn$graph$pipeops$classif.debug$learner$validate
358353# '
359354# ' # complex
360- # ' glrn = as_learner(ppl("stacking", lrns(c ("classif.debug", "classif.featureless")),
355+ # ' glrn = as_learner(ppl("stacking", list(lrn ("classif.debug"), lrn( "classif.featureless")),
361356# ' lrn("classif.debug", id = "final")))
362- # ' set_validate(glrn, 0.2, which = c("classif.debug", "final"))
357+ # ' set_validate(glrn, 0.2, ids = c("classif.debug", "final"))
363358# ' glrn$validate
364359# ' glrn$graph$pipeops$classif.debug$learner$validate
365360# ' glrn$graph$pipeops$final$learner$validate
@@ -378,7 +373,6 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(
378373 ids = base_pipeop(learner )$ id
379374 } else {
380375 assert_subset(ids , ids(keep(learner_wrapping_pipeops(learner ), function (po ) " validation" %in% po $ learner $ properties )))
381- assert_true(length(ids ) > 0 )
382376 }
383377
384378 assert_list(args , types = " list" )
@@ -388,8 +382,8 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(
388382 prev_validate = learner $ validate
389383
390384 on.exit({
391- iwalk(prev_validate_pos , function (val , poid ) learner $ graph $ pipeops [[poid ]] = val )
392- learner $ valiate = prev_validate
385+ iwalk(prev_validate_pos , function (val , poid ) learner $ graph $ pipeops [[poid ]]$ learner $ validate = val )
386+ learner $ validate = prev_validate
393387 }, add = TRUE )
394388
395389 learner $ validate = validate
@@ -415,15 +409,17 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(
415409
416410# ' @export
417411disable_inner_tuning.GraphLearner = function (learner , ids , ... ) {
412+ pvs = learner $ param_set $ values
413+ on.exit({learner $ param_set $ values = pvs }, add = TRUE )
418414 if (length(ids )) {
419415 walk(learner_wrapping_pipeops(learner $ graph $ pipeops ), function (po ) {
420416 disable_inner_tuning(
421417 learner $ graph $ pipeops [[po $ id ]]$ learner ,
422- ids = po $ param_set $ ids()[sprintf(" %s.%s" , po $ id , po $ param_set $ ids()) %in% ids ],
423- ...
418+ ids = po $ param_set $ ids()[sprintf(" %s.%s" , po $ id , po $ param_set $ ids()) %in% ids ]
424419 )
425420 })
426421 }
422+ on.exit()
427423 invisible (learner )
428424}
429425
0 commit comments