diff --git a/DESCRIPTION b/DESCRIPTION index 841d154fa..f5b6c9196 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -74,7 +74,7 @@ Suggests: rpart, testthat (>= 3.2.0) Remotes: - mlr-org/mlr3misc@errors + mlr-org/mlr3misc Encoding: UTF-8 Config/testthat/edition: 3 Config/testthat/parallel: false diff --git a/R/Learner.R b/R/Learner.R index c39c316de..2db4358ff 100644 --- a/R/Learner.R +++ b/R/Learner.R @@ -572,6 +572,9 @@ Learner = R6Class("Learner", #' If the training step fails, the `$model` field of the original learner is `NULL`. #' The results are reproducible across the different encapsulation methods. #' + #' Note that for errors of class `Mlr3ErrorConfig`, the function always errs and no fallback learner + #' is trained. + #' #' Also see the section on error handling the mlr3book: #' \url{https://mlr3book.mlr-org.com/chapters/chapter10/advanced_technical_aspects_of_mlr3.html#sec-error-handling} #' @@ -580,15 +583,20 @@ Learner = R6Class("Learner", #' See the description for details. #' @param fallback [Learner]\cr #' The fallback learner for failed predictions. + #' @param when (`function(condition)`)\cr + #' Function that takes in the condition and returns `logical(1)` indicating whether to run the fallback learner. + #' If `NULL` (default), the fallback is always trained, except for errors of class `Mlr3ErrorConfig`. #' #' @return `self` (invisibly). #' @examples #' learner = lrn("classif.rpart") #' fallback = lrn("classif.featureless") #' learner$encapsulate("try", fallback = fallback) - encapsulate = function(method, fallback = NULL) { + encapsulate = function(method, fallback = NULL, when = NULL) { assert_choice(method, c("none", "try", "evaluate", "callr", "mirai")) + private$.when = assert_function(when, null.ok = TRUE) + if (method != "none") { assert_learner(fallback, task_type = self$task_type) @@ -702,7 +710,6 @@ Learner = R6Class("Learner", private$.use_weights }, - #' @field model (any)\cr #' The fitted model. Only available after `$train()` has been called. model = function(rhs) { @@ -750,7 +757,6 @@ Learner = R6Class("Learner", get_log_condition(self$state, "error") }, - #' @field hash (`character(1)`)\cr #' Hash (unique identifier) for this object. #' The hash is calculated based on the learner id, the parameter settings, the predict type, the fallback hash, the parallel predict setting, the validate setting, and the predict sets. @@ -780,7 +786,6 @@ Learner = R6Class("Learner", assert_string(rhs, .var.name = "predict_type") if (rhs %nin% self$predict_types) { - stopf("Learner '%s' does not support predict type '%s'", self$id, rhs) } private$.predict_type = rhs @@ -840,6 +845,7 @@ Learner = R6Class("Learner", ), private = list( + .when = NULL, .use_weights = NULL, .encapsulation = c(train = "none", predict = "none"), .fallback = NULL, diff --git a/R/LearnerClassifDebug.R b/R/LearnerClassifDebug.R index e9693022a..c91e37d49 100644 --- a/R/LearnerClassifDebug.R +++ b/R/LearnerClassifDebug.R @@ -80,7 +80,8 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif, iter = p_iter, early_stopping = p_lgl(default = FALSE, tags = "train"), count_marshaling = p_lgl(default = FALSE, tags = "train"), - check_pid = p_lgl(default = TRUE, tags = "train") + check_pid = p_lgl(default = TRUE, tags = "train"), + config_error = p_lgl(default = FALSE, tags = "train") ) super$initialize( id = "classif.debug", @@ -161,6 +162,10 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif, .validate = NULL, .train = function(task) { pv = self$param_set$get_values(tags = "train") + if (isTRUE(pv$config_error)) { + error_config("You misconfigured the learner") + } + pv$count_marshaling = pv$count_marshaling %??% FALSE roll = function(name) { name %chin% names(pv) && pv[[name]] > runif(1L) diff --git a/R/benchmark_grid.R b/R/benchmark_grid.R index f7777628b..41f3a40f8 100644 --- a/R/benchmark_grid.R +++ b/R/benchmark_grid.R @@ -15,7 +15,7 @@ #' The grid will be generated based on the Cartesian product of learners and pairs. #' #' @section Errors and Warnings: -#' * `varying_predict_types`: This warning will be thrown if the learners have different `predict_type`s. +#' * `Mlr3WarningVaryingPredictTypes`: This warning will be thrown if the learners have different `predict_type`s. #' #' @param tasks (list of [Task]). #' @param learners (list of [Learner]). @@ -79,7 +79,10 @@ benchmark_grid = function(tasks, learners, resamplings, param_values = NULL, pai assert_param_values(param_values, n_learners = length(learners)) } if (length(unique(map_chr(unique(learners), "predict_type"))) > 1) { - warningf("Multiple predict types detected, this will mean that you cannot evaluate the same measures on all learners.", class = "varying_predict_types") # nolint + warning_config( + "Multiple predict types detected, this will mean that you cannot evaluate the same measures on all learners.", + class = "Mlr3WarningVaryingPredictTypes" + ) } if (assert_flag(paired)) { diff --git a/R/worker.R b/R/worker.R index 8bf591794..27633f1ee 100644 --- a/R/worker.R +++ b/R/worker.R @@ -108,9 +108,22 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL .compute = getOption("mlr3.mirai_encapsulation", "mlr3_encapsulation") ) - log = append_log(NULL, "train", result$log$class, result$log$msg) + cond = result$log[class == "error", "condition"][[1L]] + + cond = if (length(cond)) { + cond = cond[[1L]] + } + + when = get_private(learner)$.when + catch_error = (!is.null(cond)) && (!inherits(cond, "Mlr3ErrorConfig")) && (is.null(when) || when(cond)) + + log = append_log(NULL, "train", result$log$class, result$log$msg, log_error = catch_error) train_time = result$elapsed + if (!is.null(cond) && !catch_error) { + stop(cond) + } + learner$state = set_class(insert_named(learner$state, list( model = result$result$model, log = log, @@ -495,7 +508,7 @@ process_model_after_predict = function(learner, store_models, is_sequential, unm } } -append_log = function(log = NULL, stage = NA_character_, class = NA_character_, msg = character()) { +append_log = function(log = NULL, stage = NA_character_, class = NA_character_, msg = character(), log_error = TRUE) { if (is.null(log)) { log = data.table( stage = factor(levels = c("train", "predict")), @@ -506,7 +519,7 @@ append_log = function(log = NULL, stage = NA_character_, class = NA_character_, if (length(msg)) { pwalk(list(stage, class, msg), function(s, c, m) { - if (c == "error") lg$error("%s: %s", s, m) + if (c == "error" && log_error) lg$error("%s: %s", s, m) if (c == "warning") lg$warn("%s: %s", s, m) }) log = rbindlist(list(log, data.table(stage = stage, class = class, msg = msg)), use.names = TRUE) diff --git a/man/Learner.Rd b/man/Learner.Rd index 06630cdf7..3794ad0a5 100644 --- a/man/Learner.Rd +++ b/man/Learner.Rd @@ -744,10 +744,13 @@ Note that the fallback is always trained, as we do not know in advance whether p If the training step fails, the \verb{$model} field of the original learner is \code{NULL}. The results are reproducible across the different encapsulation methods. +Note that for errors of class \code{Mlr3ErrorConfig}, the function always errs and no fallback learner +is trained. + Also see the section on error handling the mlr3book: \url{https://mlr3book.mlr-org.com/chapters/chapter10/advanced_technical_aspects_of_mlr3.html#sec-error-handling} \subsection{Usage}{ -\if{html}{\out{