Skip to content

Commit b951efd

Browse files
committed
fix tests, deps
1 parent e210a7b commit b951efd

File tree

7 files changed

+69
-83
lines changed

7 files changed

+69
-83
lines changed

DESCRIPTION

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ Suggests:
8787
methods,
8888
vtreat,
8989
future
90+
Remotes:
91+
mlr-org/paradox@feat/inner_valid,
92+
mlr-org/mlr3@feat/inner_valid
9093
ByteCompile: true
9194
Encoding: UTF-8
9295
Config/testthat/edition: 3

NAMESPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ S3method(as_pipeop,Filter)
1111
S3method(as_pipeop,Learner)
1212
S3method(as_pipeop,PipeOp)
1313
S3method(as_pipeop,default)
14+
S3method(disable_inner_tuning,GraphLearner)
1415
S3method(po,"NULL")
1516
S3method(po,Filter)
1617
S3method(po,Learner)
@@ -23,7 +24,6 @@ S3method(pos,list)
2324
S3method(predict,Graph)
2425
S3method(print,Multiplicity)
2526
S3method(print,Selector)
26-
S3method(set_inner_tuning,GraphLearner)
2727
S3method(set_validate,GraphLearner)
2828
export("%>>!%")
2929
export("%>>%")

R/GraphLearner.R

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@
5353
#' * `inner_valid_scores` :: named `list()` or `NULL`\cr
5454
#' The inner tuned parameter values.
5555
#' `NULL` is returned if the learner is not trained or none of the wrapped learners supports internal validation.
56+
#' * `validate` :: `numeric(1)`, `"inner_valid"`, `"test"` or `NULL`\cr
57+
#' How to construct the validation data.
5658
#'
5759
#' @section Internals:
5860
#' [`as_graph()`] is called on the `graph` argument, so it can technically also be a `list` of things, which is
@@ -224,12 +226,8 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
224226
}
225227
}
226228
), recursive = FALSE)
227-
228-
if (is.null(itvs) || !length(itvs)) {
229-
return(named_list())
230-
}
229+
if (is.null(itvs) || !length(itvs)) return(named_list())
231230
itvs
232-
233231
},
234232
.extract_inner_valid_scores = function() {
235233
ivs = unlist(map(
@@ -239,12 +237,8 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
239237
}
240238
}
241239
), recursive = FALSE)
242-
243-
if (is.null(ivs) || !length(ivs)) {
244-
return(named_list())
245-
}
240+
if (is.null(ivs) || !length(ivs)) return(named_list())
246241
ivs
247-
248242
},
249243
deep_clone = function(name, value) {
250244
private$.param_set = NULL
@@ -323,15 +317,15 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
323317
#' Configure validation for a graph learner.
324318
#'
325319
#' In a [`GraphLearner`], validation can be configured on two levels:
326-
#' 1. On the [`GraphLearner`] level, which specifies **how** the validation set is constructed.
320+
#' 1. On the [`GraphLearner`] level, which specifies **how** the validation set is constructed before entering the graph.
327321
#' 2. On the level of the [`Learner`]s that are wrapped by [`PipeOpLearner`] and [`PipeOpLearnerCV`], which specifies
328322
#' which pipeops actually make use of the validation set.
329-
#' All learners wrapped by [`PipeOpLearner`] and [`PipeOpLearnerCV`] can only set it to `NULL` (disable) or
330-
#' `"inner_valid"` (enable).
323+
#' All learners wrapped by [`PipeOpLearner`] and [`PipeOpLearnerCV`] should in almost all cases either set it
324+
#' to `NULL` (disable) or `"inner_valid"` (enable).
331325
#'
332326
#' @param learner ([`GraphLearner`])\cr
333327
#' The graph learner to configure.
334-
#' @param validate (`numeric(1)`, `"inner_valid"` or `NULL`)\cr
328+
#' @param validate (`numeric(1)`, `"inner_valid"`, `"test"`, or `NULL`)\cr
335329
#' How to set the `$validate` field of the learner.
336330
#' If set to `NULL` all validation is disabled, both on the graph learner level, but also for all pipeops.
337331
#' @param ids (`NULL` or `character()`)\cr
@@ -340,7 +334,8 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
340334
#' By default, validation is enabled for the base learner.
341335
#' @param args (named `list()`)\cr
342336
#' Rarely needed.
343-
#' A named list of lists, specifying additional argments to be passed to [`set_validate()`] for the respective pipeops.
337+
#' A named list of lists, specifying additional argments to be passed to [`set_validate()`] for the respective learners.
338+
#' Names must be a subset of the `ids`.
344339
#' @param ... (any)\cr
345340
#' Currently unused.
346341
#'
@@ -357,9 +352,9 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
357352
#' glrn$graph$pipeops$classif.debug$learner$validate
358353
#'
359354
#' # complex
360-
#' glrn = as_learner(ppl("stacking", lrns(c("classif.debug", "classif.featureless")),
355+
#' glrn = as_learner(ppl("stacking", list(lrn("classif.debug"), lrn("classif.featureless")),
361356
#' lrn("classif.debug", id = "final")))
362-
#' set_validate(glrn, 0.2, which = c("classif.debug", "final"))
357+
#' set_validate(glrn, 0.2, ids = c("classif.debug", "final"))
363358
#' glrn$validate
364359
#' glrn$graph$pipeops$classif.debug$learner$validate
365360
#' glrn$graph$pipeops$final$learner$validate
@@ -378,7 +373,6 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(
378373
ids = base_pipeop(learner)$id
379374
} else {
380375
assert_subset(ids, ids(keep(learner_wrapping_pipeops(learner), function(po) "validation" %in% po$learner$properties)))
381-
assert_true(length(ids) > 0)
382376
}
383377

384378
assert_list(args, types = "list")
@@ -388,8 +382,8 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(
388382
prev_validate = learner$validate
389383

390384
on.exit({
391-
iwalk(prev_validate_pos, function(val, poid) learner$graph$pipeops[[poid]] = val)
392-
learner$valiate = prev_validate
385+
iwalk(prev_validate_pos, function(val, poid) learner$graph$pipeops[[poid]]$learner$validate = val)
386+
learner$validate = prev_validate
393387
}, add = TRUE)
394388

395389
learner$validate = validate
@@ -415,15 +409,17 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(
415409

416410
#' @export
417411
disable_inner_tuning.GraphLearner = function(learner, ids, ...) {
412+
pvs = learner$param_set$values
413+
on.exit({learner$param_set$values = pvs}, add = TRUE)
418414
if (length(ids)) {
419415
walk(learner_wrapping_pipeops(learner$graph$pipeops), function(po) {
420416
disable_inner_tuning(
421417
learner$graph$pipeops[[po$id]]$learner,
422-
ids = po$param_set$ids()[sprintf("%s.%s", po$id, po$param_set$ids()) %in% ids],
423-
...
418+
ids = po$param_set$ids()[sprintf("%s.%s", po$id, po$param_set$ids()) %in% ids]
424419
)
425420
})
426421
}
422+
on.exit()
427423
invisible(learner)
428424
}
429425

man/mlr_learners_graph.Rd

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/set_inner_tuning.GraphLearner.Rd

Lines changed: 0 additions & 33 deletions
This file was deleted.

man/set_validate.GraphLearner.Rd

Lines changed: 8 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_GraphLearner.R

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -569,25 +569,6 @@ test_that("GraphLearner hashes", {
569569

570570
})
571571

572-
573-
test_that("disable_inner_tuning", {
574-
glrn = as_learner(as_pipeop(lrn("classif.debug", validate = 0.2, early_stopping = TRUE, iter = 100)))
575-
disable_inner_tuning(glrn)
576-
577-
set_inner_tuning(glrn, .disable = TRUE)
578-
expect_equal(glrn$base_learner()$validate, NULL)
579-
set_inner_tuning(glrn, validate = 0.21, args = list(classif.debug = list(validate = "inner_valid", iter = 99)))
580-
expect_equal(glrn$validate, 0.21)
581-
expect_equal(glrn$graph$pipeops$classif.debug$learner$validate, "inner_valid")
582-
expect_equal(glrn$graph$pipeops$classif.debug$learner$param_set$values$iter, 99)
583-
584-
glrn2 = as_learner(as_pipeop(lrn("classif.debug")))
585-
expect_error(
586-
set_inner_tuning(glrn2, validate = 0.21),
587-
"for PipeOp 'classif\\.debug'"
588-
)
589-
})
590-
591572
test_that("validation, inner_valid_scores", {
592573
# None of the Learners can do validation -> NULL
593574
glrn1 = as_learner(as_graph(lrn("classif.rpart")))$train(tsk("iris"))
@@ -628,10 +609,46 @@ test_that("inner_tuned_values", {
628609
expect_equal(glrn2$inner_tuned_values, NULL)
629610
glrn2$train(task)
630611
expect_equal(glrn2$inner_tuned_values, named_list())
631-
set_inner_tuning(glrn2, args = list(classif.debug = list(early_stopping = TRUE, iter = 1000)), validate = 0.2)
612+
glrn2$param_set$set_values(classif.debug.early_stopping = TRUE, classif.debug.iter = 1000)
613+
set_validate(glrn2, 0.2)
614+
glrn2$train(task)
632615
expect_equal(names(glrn2$inner_tuned_values), "classif.debug.iter")
633616
})
634617

618+
test_that("disable_inner_tuning", {
619+
glrn = as_learner(as_pipeop(lrn("classif.debug", iter = 100, early_stopping = TRUE)))
620+
disable_inner_tuning(glrn, "classif.debug.iter")
621+
expect_false(glrn$graph$pipeops$classif.debug$param_set$values$early_stopping)
622+
expect_error(disable_inner_tuning(glrn, "classif.debug.abc"), "subset of")
623+
})
624+
635625
test_that("set_validate", {
626+
glrn = as_learner(as_pipeop(lrn("classif.debug", validate = 0.3)))
627+
set_validate(glrn, "test")
628+
expect_equal(glrn$validate, "test")
629+
expect_equal(glrn$graph$pipeops$classif.debug$learner$validate, "inner_valid")
630+
set_validate(glrn, NULL)
631+
expect_equal(glrn$validate, NULL)
632+
expect_equal(glrn$graph$pipeops$classif.debug$learner$validate, NULL)
633+
set_validate(glrn, 0.2, ids = "classif.debug")
634+
expect_equal(glrn$validate, 0.2)
635+
expect_equal(glrn$graph$pipeops$classif.debug$learner$validate, "inner_valid")
636+
637+
638+
glrn = as_learner(ppl("stacking", list(lrn("classif.debug"), lrn("classif.featureless")),
639+
lrn("classif.debug", id = "final")))
640+
set_validate(glrn, 0.3, ids = c("classif.debug", "final"))
641+
expect_equal(glrn$validate, 0.3)
642+
expect_equal(glrn$graph$pipeops$classif.debug$learner$validate, "inner_valid")
643+
expect_equal(glrn$graph$pipeops$final$learner$validate, "inner_valid")
644+
636645

646+
glrn = as_learner(ppl("stacking", list(lrn("classif.debug"), lrn("classif.featureless")),
647+
lrn("classif.debug", id = "final")))
648+
glrn2 = as_learner(po("learner", glrn, id = "polearner"))
649+
set_validate(glrn2, validate = 0.25, ids = "polearner", args = list(polearner = list(ids = "final")))
650+
expect_equal(glrn2$validate, 0.25)
651+
expect_equal(glrn2$graph$pipeops$polearner$learner$validate, "inner_valid")
652+
expect_equal(glrn2$graph$pipeops$polearner$learner$graph$pipeops$final$learner$validate, "inner_valid")
653+
expect_equal(glrn2$graph$pipeops$polearner$learner$graph$pipeops$classif.debug$learner$validate, NULL)
637654
})

0 commit comments

Comments
 (0)