Skip to content

Commit a3356c1

Browse files
authored
feat(conditions): improved error handling (#1365)
1 parent 0c0e579 commit a3356c1

File tree

9 files changed

+135
-13
lines changed

9 files changed

+135
-13
lines changed

R/Learner.R

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,9 @@ Learner = R6Class("Learner",
572572
#' If the training step fails, the `$model` field of the original learner is `NULL`.
573573
#' The results are reproducible across the different encapsulation methods.
574574
#'
575+
#' Note that for errors of class `Mlr3ErrorConfig`, the function always errs and no fallback learner
576+
#' is trained.
577+
#'
575578
#' Also see the section on error handling the mlr3book:
576579
#' \url{https://mlr3book.mlr-org.com/chapters/chapter10/advanced_technical_aspects_of_mlr3.html#sec-error-handling}
577580
#'
@@ -580,15 +583,20 @@ Learner = R6Class("Learner",
580583
#' See the description for details.
581584
#' @param fallback [Learner]\cr
582585
#' The fallback learner for failed predictions.
586+
#' @param when (`function(condition)`)\cr
587+
#' Function that takes in the condition and returns `logical(1)` indicating whether to run the fallback learner.
588+
#' If `NULL` (default), the fallback is always trained, except for errors of class `Mlr3ErrorConfig`.
583589
#'
584590
#' @return `self` (invisibly).
585591
#' @examples
586592
#' learner = lrn("classif.rpart")
587593
#' fallback = lrn("classif.featureless")
588594
#' learner$encapsulate("try", fallback = fallback)
589-
encapsulate = function(method, fallback = NULL) {
595+
encapsulate = function(method, fallback = NULL, when = NULL) {
590596
assert_choice(method, c("none", "try", "evaluate", "callr", "mirai"))
591597

598+
private$.when = assert_function(when, null.ok = TRUE)
599+
592600
if (method != "none") {
593601
assert_learner(fallback, task_type = self$task_type)
594602

@@ -702,7 +710,6 @@ Learner = R6Class("Learner",
702710
private$.use_weights
703711
},
704712

705-
706713
#' @field model (any)\cr
707714
#' The fitted model. Only available after `$train()` has been called.
708715
model = function(rhs) {
@@ -750,7 +757,6 @@ Learner = R6Class("Learner",
750757
get_log_condition(self$state, "error")
751758
},
752759

753-
754760
#' @field hash (`character(1)`)\cr
755761
#' Hash (unique identifier) for this object.
756762
#' The hash is calculated based on the learner id, the parameter settings, the predict type, the fallback hash, the parallel predict setting, the validate setting, and the predict sets.
@@ -780,7 +786,6 @@ Learner = R6Class("Learner",
780786

781787
assert_string(rhs, .var.name = "predict_type")
782788
if (rhs %nin% self$predict_types) {
783-
784789
stopf("Learner '%s' does not support predict type '%s'", self$id, rhs)
785790
}
786791
private$.predict_type = rhs
@@ -840,6 +845,7 @@ Learner = R6Class("Learner",
840845
),
841846

842847
private = list(
848+
.when = NULL,
843849
.use_weights = NULL,
844850
.encapsulation = c(train = "none", predict = "none"),
845851
.fallback = NULL,

R/LearnerClassifDebug.R

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
8080
iter = p_iter,
8181
early_stopping = p_lgl(default = FALSE, tags = "train"),
8282
count_marshaling = p_lgl(default = FALSE, tags = "train"),
83-
check_pid = p_lgl(default = TRUE, tags = "train")
83+
check_pid = p_lgl(default = TRUE, tags = "train"),
84+
config_error = p_lgl(default = FALSE, tags = "train")
8485
)
8586
super$initialize(
8687
id = "classif.debug",
@@ -161,6 +162,10 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
161162
.validate = NULL,
162163
.train = function(task) {
163164
pv = self$param_set$get_values(tags = "train")
165+
if (isTRUE(pv$config_error)) {
166+
error_config("You misconfigured the learner")
167+
}
168+
164169
pv$count_marshaling = pv$count_marshaling %??% FALSE
165170
roll = function(name) {
166171
name %chin% names(pv) && pv[[name]] > runif(1L)

R/benchmark_grid.R

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
#' The grid will be generated based on the Cartesian product of learners and pairs.
1616
#'
1717
#' @section Errors and Warnings:
18-
#' * `varying_predict_types`: This warning will be thrown if the learners have different `predict_type`s.
18+
#' * `Mlr3WarningVaryingPredictTypes`: This warning will be thrown if the learners have different `predict_type`s.
1919
#'
2020
#' @param tasks (list of [Task]).
2121
#' @param learners (list of [Learner]).
@@ -79,7 +79,10 @@ benchmark_grid = function(tasks, learners, resamplings, param_values = NULL, pai
7979
assert_param_values(param_values, n_learners = length(learners))
8080
}
8181
if (length(unique(map_chr(unique(learners), "predict_type"))) > 1) {
82-
warningf("Multiple predict types detected, this will mean that you cannot evaluate the same measures on all learners.", class = "varying_predict_types") # nolint
82+
warning_config(
83+
"Multiple predict types detected, this will mean that you cannot evaluate the same measures on all learners.",
84+
class = "Mlr3WarningVaryingPredictTypes"
85+
)
8386
}
8487

8588
if (assert_flag(paired)) {

R/worker.R

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,22 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL
108108
.compute = getOption("mlr3.mirai_encapsulation", "mlr3_encapsulation")
109109
)
110110

111-
log = append_log(NULL, "train", result$log$class, result$log$msg)
111+
cond = result$log[class == "error", "condition"][[1L]]
112+
113+
cond = if (length(cond)) {
114+
cond = cond[[1L]]
115+
}
116+
117+
when = get_private(learner)$.when
118+
catch_error = (!is.null(cond)) && (!inherits(cond, "Mlr3ErrorConfig")) && (is.null(when) || when(cond))
119+
120+
log = append_log(NULL, "train", result$log$class, result$log$msg, log_error = catch_error)
112121
train_time = result$elapsed
113122

123+
if (!is.null(cond) && !catch_error) {
124+
stop(cond)
125+
}
126+
114127
learner$state = set_class(insert_named(learner$state, list(
115128
model = result$result$model,
116129
log = log,
@@ -495,7 +508,7 @@ process_model_after_predict = function(learner, store_models, is_sequential, unm
495508
}
496509
}
497510

498-
append_log = function(log = NULL, stage = NA_character_, class = NA_character_, msg = character()) {
511+
append_log = function(log = NULL, stage = NA_character_, class = NA_character_, msg = character(), log_error = TRUE) {
499512
if (is.null(log)) {
500513
log = data.table(
501514
stage = factor(levels = c("train", "predict")),
@@ -506,7 +519,7 @@ append_log = function(log = NULL, stage = NA_character_, class = NA_character_,
506519

507520
if (length(msg)) {
508521
pwalk(list(stage, class, msg), function(s, c, m) {
509-
if (c == "error") lg$error("%s: %s", s, m)
522+
if (c == "error" && log_error) lg$error("%s: %s", s, m)
510523
if (c == "warning") lg$warn("%s: %s", s, m)
511524
})
512525
log = rbindlist(list(log, data.table(stage = stage, class = class, msg = msg)), use.names = TRUE)

man/Learner.Rd

Lines changed: 8 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/benchmark_grid.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/mlr_learners_classif.debug.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_Learner.R

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -839,6 +839,86 @@ test_that("Learner printer for encapsulation", {
839839
expect_output(print(lrn("classif.rpart")$encapsulate("none")), "Encapsulation: none \\(fallback: -\\)")
840840
})
841841

842+
test_that("error conditions are working: callr", {
843+
l = lrn("classif.debug",
844+
timeout = c(train = 0.01),
845+
# Sys.sleep does not get interrupted reliably
846+
sleep_train = function() while (TRUE) NULL
847+
)
848+
849+
l$encapsulate(
850+
"callr",
851+
lrn("classif.featureless"),
852+
when = function(cond) {
853+
!inherits(cond, "Mlr3ErrorTimeout")
854+
}
855+
)
856+
857+
expect_error(l$train(tsk("iris")), regexp = "reached elapsed time limit")
858+
l$configure(error_train = 1, sleep_train = NULL, timeout = c(train = Inf, predict = Inf))
859+
expect_error(l$train(tsk("iris")), regexp = NA)
860+
})
861+
862+
test_that("error conditions are working: evaluate", {
863+
l = lrn("classif.debug",
864+
timeout = c(train = 0.2),
865+
# Sys.sleep does not get interrupted reliably
866+
sleep_train = function() while (TRUE) NULL
867+
)
868+
869+
l$encapsulate(
870+
"evaluate",
871+
lrn("classif.featureless"),
872+
function(x) {
873+
!inherits(x, "Mlr3ErrorTimeout")
874+
}
875+
)
876+
877+
expect_error(l$train(tsk("iris")), regexp = "reached elapsed time limit")
878+
l$configure(error_train = 1, sleep_train = NULL, timeout = c(train = Inf, predict = Inf))
879+
expect_error(l$train(tsk("iris")), regexp = NA)
880+
})
881+
882+
test_that("error conditions are working: try", {
883+
l = lrn("classif.debug",
884+
timeout = c(train = 0.01),
885+
# Sys.sleep does not get interrupted reliably
886+
sleep_train = function() while (TRUE) NULL
887+
)
888+
889+
l$encapsulate(
890+
"try",
891+
lrn("classif.featureless"),
892+
function(x) {
893+
!inherits(x, "Mlr3ErrorTimeout")
894+
}
895+
)
896+
897+
expect_error(l$train(tsk("iris")), regexp = "reached elapsed time limit")
898+
l$configure(error_train = 1, sleep_train = NULL, timeout = c(train = Inf, predict = Inf))
899+
expect_error(l$train(tsk("iris")), regexp = NA)
900+
})
901+
902+
test_that("error conditions are working: mirai", {
903+
l = lrn("classif.debug",
904+
timeout = c(train = 0.01),
905+
# Sys.sleep does not get interrupted reliably
906+
sleep_train = function() while (TRUE) NULL
907+
)
908+
909+
l$encapsulate(
910+
"mirai",
911+
lrn("classif.featureless"),
912+
function(x) {
913+
!inherits(x, "Mlr3ErrorTimeout")
914+
}
915+
)
916+
917+
expect_error(l$train(tsk("iris")), regexp = "reached elapsed time limit")
918+
l$configure(error_train = 1, sleep_train = NULL, timeout = c(train = Inf, predict = Inf))
919+
expect_error(l$train(tsk("iris")), regexp = NA)
920+
})
921+
842922
test_that("oob_error is available without storing models via $.extract_oob_error()", {
843923
LearnerDummyOOB = R6::R6Class("LearnerDummyOOB", inherit = LearnerClassif,
844924
public = list(
@@ -872,3 +952,11 @@ test_that("oob_error is available without storing models via $.extract_oob_error
872952

873953
expect_equal(rr$aggregate(msr("oob_error")), c(oob_error = 0.123))
874954
})
955+
956+
test_that("config error does not trigger callback", {
957+
l = lrn("classif.debug", config_error = TRUE)
958+
l$encapsulate("evaluate", lrn("classif.featureless"), function(x) TRUE)
959+
expect_error(l$train(tsk("iris")), regexp = "You misconfigured the learner")
960+
l$encapsulate("evaluate", lrn("classif.featureless"))
961+
expect_error(l$train(tsk("iris")), regexp = "You misconfigured the learner")
962+
})

tests/testthat/test_errorhandling.R

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,4 +88,3 @@ test_that("encapsulation / benchmark", {
8888
expect_equal(aggr$warnings, 3L)
8989
expect_equal(aggr$errors, 3L)
9090
})
91-

0 commit comments

Comments
 (0)