Skip to content

Commit e210a7b

Browse files
committed
...
1 parent 5e7b85e commit e210a7b

File tree

2 files changed

+12
-88
lines changed

2 files changed

+12
-88
lines changed

R/GraphLearner.R

Lines changed: 9 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -413,99 +413,21 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(
413413
}
414414

415415

416-
#' @title Set Inner Tuning for a Graph Learner
417-
#' @description
418-
#' First, all values specified by `...` are
419-
#' All [`PipeOpLearner`] and [`PipeOpLearnerCV`]
420-
#'
421-
#' @inheritParams mlr3::set_inner_tuning
422-
#' @param validate (`numeric(1)`, `"inner_valid"`, or `NULL`)\cr
423-
#' How to set the `$validate` field of the learner.
424-
#' @param args (named `list()`)\cr
425-
#' Names are ids of the [`GraphLearner`]'s `PipeOps` and values are lists containing arguments passed to the
426-
#' respective wrapped [`Learner`].
427-
#' By default, the values `.disable` and `validate` are used, but can be overwritten on a per-pipeop basis.
428-
#'
429-
#' When enabling, the inner tuning of the `$base_learner()` is enabled by default.
430-
#' When disabling, all inner tuning is disable by default.
431416
#' @export
432-
set_inner_tuning.GraphLearner = function(.learner, .disable = FALSE, validate = NA, args = NULL, ...) {
433-
if (is.null(args)) {
434-
args = set_names(list(list()), base_pipeop(.learner)$id)
435-
}
436-
all_pipeops = .learner$graph$pipeops
437-
lrn_pipeops = learner_wrapping_pipeops(all_pipeops)
438-
439-
assert_list(args, names = "unique")
440-
assert_subset(names(args), ids(lrn_pipeops))
441-
442-
443-
# clean up when something goes wrong
444-
prev_pvs = .learner$param_set$values
445-
prev_validate = discard(map(lrn_pipeops, function(po) if (exists("validate", po$learner)) po$learner$validate), is.null)
446-
on.exit({
447-
.learner$param_set$set_values(.values = prev_pvs)
448-
iwalk(prev_validate, function(val, poid) .learner$graph$pipeops[[poid]]$learner$validate = val)
449-
}, add = TRUE)
450-
451-
walk(lrn_pipeops[names(args)], function(po) {
452-
withCallingHandlers({
453-
invoke(set_inner_tuning, .learner = po$learner,
454-
.args = insert_named(list(validate = validate, .disable = .disable), args[[po$id]])
417+
disable_inner_tuning.GraphLearner = function(learner, ids, ...) {
418+
if (length(ids)) {
419+
walk(learner_wrapping_pipeops(learner$graph$pipeops), function(po) {
420+
disable_inner_tuning(
421+
learner$graph$pipeops[[po$id]]$learner,
422+
ids = po$param_set$ids()[sprintf("%s.%s", po$id, po$param_set$ids()) %in% ids],
423+
...
455424
)
456-
}, error = function(e) {
457-
e$message = sprintf("Failed to set inner tuning for PipeOp '%s':\n%s", po$id, e$message)
458-
stop(e)
459-
}, warning = function(w) {
460-
w$message = sprintf("Failed to set inner tuning for PipeOp '%s':\n%s", po$id, w$message)
461-
warning(w)
462-
invokeRestart("muffleWarning")
463425
})
464-
})
465-
466-
# Now:
467-
# Set validate for GraphLearner and verify that the configuration is reasonable
468-
469-
if (.disable) {
470-
.learner$validate = if (identical(validate, NA)) NULL else validate
471-
some_pipeops_validate = some(lrn_pipeops, function(po) {
472-
if (!exists("validate", po$learner)) {
473-
return(FALSE)
474-
}
475-
!is.null(po$learner$validate)
476-
})
477-
# if none of the pipeops does validation, we also disable it in the GraphLearner
478-
# (unless a value was explicitly specified)
479-
if (!some_pipeops_validate && identical(validate, NA)) {
480-
.learner$validate = NULL
481-
}
482-
} else {
483-
if (!identical(validate, NA)) {
484-
.learner$validate = validate
485-
}
486-
487-
some_pipeops_validate = some(lrn_pipeops, function(po) {
488-
if (is.null(get0("validate", po$learner))) return(FALSE)
489-
if (is.null(.learner$validate)) {
490-
warningf("PipeOp '%s' from GraphLearner '%s' wants a validation set but GraphLearner does not specify one. This likely not what you want.",
491-
po$id, .learner$id)
492-
}
493-
if (!identical(po$learner$validate, "inner_valid")) {
494-
warningf("PipeOp '%s' from GraphLearner '%s' specifies validation set other than 'inner_valid'. This is likely not what you want.",
495-
po$id, .learner$id)
496-
}
497-
TRUE
498-
})
499-
500-
if (!is.null(.learner$param_set$values$validate) && !some_pipeops_validate) {
501-
warningf("GraphLearner '%s' specifies a validation set, but none of its Learners use it. This is likely not what you want.", .learner$id)
502-
}
503426
}
504-
505-
on.exit()
506-
invisible(.learner)
427+
invisible(learner)
507428
}
508429

430+
509431
#' @export
510432
as_learner.Graph = function(x, clone = FALSE, ...) {
511433
GraphLearner$new(x, clone_graph = clone)

tests/testthat/test_GraphLearner.R

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -570,8 +570,10 @@ test_that("GraphLearner hashes", {
570570
})
571571

572572

573-
test_that("set_inner_tuning", {
573+
test_that("disable_inner_tuning", {
574574
glrn = as_learner(as_pipeop(lrn("classif.debug", validate = 0.2, early_stopping = TRUE, iter = 100)))
575+
disable_inner_tuning(glrn)
576+
575577
set_inner_tuning(glrn, .disable = TRUE)
576578
expect_equal(glrn$base_learner()$validate, NULL)
577579
set_inner_tuning(glrn, validate = 0.21, args = list(classif.debug = list(validate = "inner_valid", iter = 99)))

0 commit comments

Comments
 (0)