Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ Suggests:
rpart,
testthat (>= 3.2.0)
Remotes:
mlr-org/mlr3misc@errors
mlr-org/mlr3misc
Encoding: UTF-8
Config/testthat/edition: 3
Config/testthat/parallel: false
Expand Down
14 changes: 10 additions & 4 deletions R/Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,9 @@ Learner = R6Class("Learner",
#' If the training step fails, the `$model` field of the original learner is `NULL`.
#' The results are reproducible across the different encapsulation methods.
#'
#' Note that for errors of class `Mlr3ErrorConfig`, the function always errs and no fallback learner
#' is trained.
#'
#' Also see the section on error handling the mlr3book:
#' \url{https://mlr3book.mlr-org.com/chapters/chapter10/advanced_technical_aspects_of_mlr3.html#sec-error-handling}
#'
Expand All @@ -580,15 +583,20 @@ Learner = R6Class("Learner",
#' See the description for details.
#' @param fallback [Learner]\cr
#' The fallback learner for failed predictions.
#' @param when (`function(condition)`)\cr
#' Function that takes in the condition and returns `logical(1)` indicating whether to run the fallback learner.
#' If `NULL` (default), the fallback is always trained, except for errors of class `Mlr3ErrorConfig`.
#'
#' @return `self` (invisibly).
#' @examples
#' learner = lrn("classif.rpart")
#' fallback = lrn("classif.featureless")
#' learner$encapsulate("try", fallback = fallback)
encapsulate = function(method, fallback = NULL) {
encapsulate = function(method, fallback = NULL, when = NULL) {
assert_choice(method, c("none", "try", "evaluate", "callr", "mirai"))

private$.when = assert_function(when, null.ok = TRUE)

if (method != "none") {
assert_learner(fallback, task_type = self$task_type)

Expand Down Expand Up @@ -702,7 +710,6 @@ Learner = R6Class("Learner",
private$.use_weights
},


#' @field model (any)\cr
#' The fitted model. Only available after `$train()` has been called.
model = function(rhs) {
Expand Down Expand Up @@ -750,7 +757,6 @@ Learner = R6Class("Learner",
get_log_condition(self$state, "error")
},


#' @field hash (`character(1)`)\cr
#' Hash (unique identifier) for this object.
#' The hash is calculated based on the learner id, the parameter settings, the predict type, the fallback hash, the parallel predict setting, the validate setting, and the predict sets.
Expand Down Expand Up @@ -780,7 +786,6 @@ Learner = R6Class("Learner",

assert_string(rhs, .var.name = "predict_type")
if (rhs %nin% self$predict_types) {

stopf("Learner '%s' does not support predict type '%s'", self$id, rhs)
}
private$.predict_type = rhs
Expand Down Expand Up @@ -840,6 +845,7 @@ Learner = R6Class("Learner",
),

private = list(
.when = NULL,
.use_weights = NULL,
.encapsulation = c(train = "none", predict = "none"),
.fallback = NULL,
Expand Down
7 changes: 6 additions & 1 deletion R/LearnerClassifDebug.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
iter = p_iter,
early_stopping = p_lgl(default = FALSE, tags = "train"),
count_marshaling = p_lgl(default = FALSE, tags = "train"),
check_pid = p_lgl(default = TRUE, tags = "train")
check_pid = p_lgl(default = TRUE, tags = "train"),
config_error = p_lgl(default = FALSE, tags = "train")
)
super$initialize(
id = "classif.debug",
Expand Down Expand Up @@ -161,6 +162,10 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
.validate = NULL,
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
if (isTRUE(pv$config_error)) {
error_config("You misconfigured the learner")
}

pv$count_marshaling = pv$count_marshaling %??% FALSE
roll = function(name) {
name %chin% names(pv) && pv[[name]] > runif(1L)
Expand Down
7 changes: 5 additions & 2 deletions R/benchmark_grid.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#' The grid will be generated based on the Cartesian product of learners and pairs.
#'
#' @section Errors and Warnings:
#' * `varying_predict_types`: This warning will be thrown if the learners have different `predict_type`s.
#' * `Mlr3WarningVaryingPredictTypes`: This warning will be thrown if the learners have different `predict_type`s.
#'
#' @param tasks (list of [Task]).
#' @param learners (list of [Learner]).
Expand Down Expand Up @@ -79,7 +79,10 @@ benchmark_grid = function(tasks, learners, resamplings, param_values = NULL, pai
assert_param_values(param_values, n_learners = length(learners))
}
if (length(unique(map_chr(unique(learners), "predict_type"))) > 1) {
warningf("Multiple predict types detected, this will mean that you cannot evaluate the same measures on all learners.", class = "varying_predict_types") # nolint
warning_config(
"Multiple predict types detected, this will mean that you cannot evaluate the same measures on all learners.",
class = "Mlr3WarningVaryingPredictTypes"
)
}

if (assert_flag(paired)) {
Expand Down
19 changes: 16 additions & 3 deletions R/worker.R
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,22 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL
.compute = getOption("mlr3.mirai_encapsulation", "mlr3_encapsulation")
)

log = append_log(NULL, "train", result$log$class, result$log$msg)
cond = result$log[class == "error", "condition"][[1L]]

cond = if (length(cond)) {
cond = cond[[1L]]
}

when = get_private(learner)$.when
catch_error = (!is.null(cond)) && (!inherits(cond, "Mlr3ErrorConfig")) && (is.null(when) || when(cond))

log = append_log(NULL, "train", result$log$class, result$log$msg, log_error = catch_error)
train_time = result$elapsed

if (!is.null(cond) && !catch_error) {
stop(cond)
}

learner$state = set_class(insert_named(learner$state, list(
model = result$result$model,
log = log,
Expand Down Expand Up @@ -495,7 +508,7 @@ process_model_after_predict = function(learner, store_models, is_sequential, unm
}
}

append_log = function(log = NULL, stage = NA_character_, class = NA_character_, msg = character()) {
append_log = function(log = NULL, stage = NA_character_, class = NA_character_, msg = character(), log_error = TRUE) {
if (is.null(log)) {
log = data.table(
stage = factor(levels = c("train", "predict")),
Expand All @@ -506,7 +519,7 @@ append_log = function(log = NULL, stage = NA_character_, class = NA_character_,

if (length(msg)) {
pwalk(list(stage, class, msg), function(s, c, m) {
if (c == "error") lg$error("%s: %s", s, m)
if (c == "error" && log_error) lg$error("%s: %s", s, m)
if (c == "warning") lg$warn("%s: %s", s, m)
})
log = rbindlist(list(log, data.table(stage = stage, class = class, msg = msg)), use.names = TRUE)
Expand Down
9 changes: 8 additions & 1 deletion man/Learner.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/benchmark_grid.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_learners_classif.debug.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

88 changes: 88 additions & 0 deletions tests/testthat/test_Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,86 @@ test_that("Learner printer for encapsulation", {
expect_output(print(lrn("classif.rpart")$encapsulate("none")), "Encapsulation: none \\(fallback: -\\)")
})

test_that("error conditions are working: callr", {
l = lrn("classif.debug",
timeout = c(train = 0.01),
# Sys.sleep does not get interrupted reliably
sleep_train = function() while (TRUE) NULL
)

l$encapsulate(
"callr",
lrn("classif.featureless"),
when = function(cond) {
!inherits(cond, "Mlr3ErrorTimeout")
}
)

expect_error(l$train(tsk("iris")), regexp = "reached elapsed time limit")
l$configure(error_train = 1, sleep_train = NULL, timeout = c(train = Inf, predict = Inf))
expect_error(l$train(tsk("iris")), regexp = NA)
})

test_that("error conditions are working: evaluate", {
l = lrn("classif.debug",
timeout = c(train = 0.2),
# Sys.sleep does not get interrupted reliably
sleep_train = function() while (TRUE) NULL
)

l$encapsulate(
"evaluate",
lrn("classif.featureless"),
function(x) {
!inherits(x, "Mlr3ErrorTimeout")
}
)

expect_error(l$train(tsk("iris")), regexp = "reached elapsed time limit")
l$configure(error_train = 1, sleep_train = NULL, timeout = c(train = Inf, predict = Inf))
expect_error(l$train(tsk("iris")), regexp = NA)
})

test_that("error conditions are working: try", {
l = lrn("classif.debug",
timeout = c(train = 0.01),
# Sys.sleep does not get interrupted reliably
sleep_train = function() while (TRUE) NULL
)

l$encapsulate(
"try",
lrn("classif.featureless"),
function(x) {
!inherits(x, "Mlr3ErrorTimeout")
}
)

expect_error(l$train(tsk("iris")), regexp = "reached elapsed time limit")
l$configure(error_train = 1, sleep_train = NULL, timeout = c(train = Inf, predict = Inf))
expect_error(l$train(tsk("iris")), regexp = NA)
})

test_that("error conditions are working: mirai", {
l = lrn("classif.debug",
timeout = c(train = 0.01),
# Sys.sleep does not get interrupted reliably
sleep_train = function() while (TRUE) NULL
)

l$encapsulate(
"mirai",
lrn("classif.featureless"),
function(x) {
!inherits(x, "Mlr3ErrorTimeout")
}
)

expect_error(l$train(tsk("iris")), regexp = "reached elapsed time limit")
l$configure(error_train = 1, sleep_train = NULL, timeout = c(train = Inf, predict = Inf))
expect_error(l$train(tsk("iris")), regexp = NA)
})

test_that("oob_error is available without storing models via $.extract_oob_error()", {
LearnerDummyOOB = R6::R6Class("LearnerDummyOOB", inherit = LearnerClassif,
public = list(
Expand Down Expand Up @@ -872,3 +952,11 @@ test_that("oob_error is available without storing models via $.extract_oob_error

expect_equal(rr$aggregate(msr("oob_error")), c(oob_error = 0.123))
})

test_that("config error does not trigger callback", {
l = lrn("classif.debug", config_error = TRUE)
l$encapsulate("evaluate", lrn("classif.featureless"), function(x) TRUE)
expect_error(l$train(tsk("iris")), regexp = "You misconfigured the learner")
l$encapsulate("evaluate", lrn("classif.featureless"))
expect_error(l$train(tsk("iris")), regexp = "You misconfigured the learner")
})
1 change: 0 additions & 1 deletion tests/testthat/test_errorhandling.R
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,3 @@ test_that("encapsulation / benchmark", {
expect_equal(aggr$warnings, 3L)
expect_equal(aggr$errors, 3L)
})