From 592af56c30c42fe2583d43f32403d0bddc8852d3 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Mon, 11 Aug 2025 15:13:01 +0200 Subject: [PATCH 01/12] feat(conditions): improved error handling Summary: * Depends on: https://github.com/mlr-org/mlr3misc/pull/141, which makes `encapsulate()` also return the condition objects * Introduce custom error/error condition constructors (#1283) * Allows to specify for which conditions a fallback learner should be triggered (#1335) TODO: * [ ] add condition objects for warnings (varying_predict_types) --- DESCRIPTION | 3 ++ NAMESPACE | 6 +++ R/Learner.R | 26 +++++++------ R/conditions.R | 42 +++++++++++++++++++++ R/worker.R | 19 ++++++++-- man/Learner.Rd | 5 ++- man/mlr_learner_conditions.Rd | 37 +++++++++++++++++++ tests/testthat/test_Learner.R | 57 +++++++++++++++++++++++++++++ tests/testthat/test_errorhandling.R | 1 - 9 files changed, 179 insertions(+), 17 deletions(-) create mode 100644 R/conditions.R create mode 100644 man/mlr_learner_conditions.Rd diff --git a/DESCRIPTION b/DESCRIPTION index 5039e5e5a..6a3d83bf9 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -73,6 +73,8 @@ Suggests: RhpcBLASctl, rpart, testthat (>= 3.2.0) +Remotes: + mlr-org/mlr3misc@errors Encoding: UTF-8 Config/testthat/edition: 3 Config/testthat/parallel: false @@ -186,6 +188,7 @@ Collate: 'benchmark.R' 'benchmark_grid.R' 'bibentries.R' + 'conditions.R' 'default_fallback.R' 'default_measures.R' 'fix_factor_levels.R' diff --git a/NAMESPACE b/NAMESPACE index e1ac0cb16..0bb1146e7 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -232,11 +232,17 @@ export(check_prediction_data) export(clbk) export(clbks) export(col_info) +export(condition_learner) +export(condition_learner_predict) +export(condition_learner_train) export(convert_task) export(create_empty_prediction_data) export(data.table) export(default_measures) export(deprecated_binding) +export(error_learner) +export(error_learner_predict) +export(error_learner_train) export(extract_pkgs) export(filter_prediction_data) export(install_pkgs) diff --git a/R/Learner.R b/R/Learner.R index 3eb85150c..590808554 100644 --- a/R/Learner.R +++ b/R/Learner.R @@ -377,11 +377,11 @@ Learner = R6Class("Learner", # to it original state model_was_marshaled = is_marshaled_model(self$model) on.exit({ - if (model_was_marshaled) { - self$model = marshal_model(self$model, inplace = TRUE) - } else { - self$model = unmarshal_model(self$model, inplace = TRUE) - } + if (model_was_marshaled) { + self$model = marshal_model(self$model, inplace = TRUE) + } else { + self$model = unmarshal_model(self$model, inplace = TRUE) + } }, add = TRUE) # reset learner predict time; this is only cumulative for multiple predict sets, @@ -464,7 +464,7 @@ Learner = R6Class("Learner", ci = task$col_info[list(keep_cols), ][ get("type") != col_info(newdata)[list(keep_cols), on = "id"]$type] tab2 = do.call(data.table, Map(auto_convert, - value = as.list(newdata$data(rows = newdata$rownames, cols = ci$id)), + value = as.list(newdata$data(rows = newdata$rownames, cols = ci$id)), id = ci$id, type = ci$type, levels = ci$levels)) tab = cbind(tab1, tab2) @@ -554,11 +554,15 @@ Learner = R6Class("Learner", #' See the description for details. #' @param fallback [Learner]\cr #' The fallback learner for failed predictions. + #' @param should_catch (`function(condition)`)\cr + #' Function that takes in the condition and returns `logical(1)` indicating whether to run the fallback learner. #' #' @return `self` (invisibly). - encapsulate = function(method, fallback = NULL) { + encapsulate = function(method, fallback = NULL, should_catch = NULL) { assert_choice(method, c("none", "try", "evaluate", "callr")) + private$.should_catch = assert_function(should_catch, null.ok = TRUE) + if (method != "none") { assert_learner(fallback, task_type = self$task_type) @@ -673,7 +677,6 @@ Learner = R6Class("Learner", #' This is deprecated and will be removed in the future. data_formats = deprecated_binding("Learner$data_formats", "data.table"), - #' @field model (any)\cr #' The fitted model. Only available after `$train()` has been called. model = function(rhs) { @@ -721,7 +724,6 @@ Learner = R6Class("Learner", get_log_condition(self$state, "error") }, - #' @field hash (`character(1)`)\cr #' Hash (unique identifier) for this object. #' The hash is calculated based on the learner id, the parameter settings, the predict type, the fallback hash, the parallel predict setting, the validate setting, and the predict sets. @@ -751,7 +753,6 @@ Learner = R6Class("Learner", assert_string(rhs, .var.name = "predict_type") if (rhs %nin% self$predict_types) { - stopf("Learner '%s' does not support predict type '%s'", self$id, rhs) } private$.predict_type = rhs @@ -811,6 +812,7 @@ Learner = R6Class("Learner", ), private = list( + .should_catch = NULL, .use_weights = NULL, .encapsulation = c(train = "none", predict = "none"), .fallback = NULL, @@ -900,8 +902,8 @@ marshal_model.learner_state = function(model, inplace = FALSE, ...) { } model$model = mm structure(list( - marshaled = model, - packages = "mlr3" + marshaled = model, + packages = "mlr3" ), class = c("learner_state_marshaled", "list_marshaled", "marshaled")) } diff --git a/R/conditions.R b/R/conditions.R new file mode 100644 index 000000000..d608d4806 --- /dev/null +++ b/R/conditions.R @@ -0,0 +1,42 @@ +#' @title Learner Classes +#' @name mlr_learner_conditions +#' @param msg (`character(1)`)\cr +#' Message to be printed. +#' @param ... (`any`)\cr +#' Additional arguments to be passed to [mlr3misc::stopf()]. +#' @param class (`character()`)\cr +#' Additional classes to be added to the error. +#' @export +error_learner = function(msg, ..., class = NULL) { + stop(condition_learner(msg, ..., class = class)) +} + +#' @rdname mlr_learner_conditions +#' @export +condition_learner = function(msg, ..., class = NULL) { + mlr3misc::condition_mlr3(msg, ..., class = c(class, "mlr3ErrorLearner")) +} + +#' @rdname mlr_learner_conditions +#' @export +error_learner_train = function(msg, ..., class = NULL) { + stop(condition_learner_train(msg, ..., class = class)) +} + +#' @rdname mlr_learner_conditions +#' @export +condition_learner_train = function(msg, ..., class = NULL) { + condition_learner(msg, ..., class = c(class, "mlr3ConditionLearnerTrain")) +} + +#' @rdname mlr_learner_conditions +#' @export +error_learner_predict = function(msg, ..., class = NULL) { + stop(condition_learner_predict(msg, ..., class = class)) +} + +#' @rdname mlr_learner_conditions +#' @export +condition_learner_predict = function(msg, ..., class = NULL) { + condition_learner(msg, ..., class = c(class, "mlr3ConditionLearnerPredict")) +} diff --git a/R/worker.R b/R/worker.R index d624f00a3..483c13ebe 100644 --- a/R/worker.R +++ b/R/worker.R @@ -100,10 +100,23 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL .seed = NA_integer_, .timeout = learner$timeout["train"] ) + # select rows where column 'class' is equal to "error" in data.table style + cond = result$log[class == "error", "condition"] + cond = if (nrow(cond)) { + cond = cond[[1L]][[1L]] + } + + should_catch = get_private(learner)$.should_catch + would_catch = is.null(should_catch) || (!is.null(cond) && should_catch(cond)) - log = append_log(NULL, "train", result$log$class, result$log$msg) + log = append_log(NULL, "train", result$log$class, result$log$msg, would_catch = would_catch) train_time = result$elapsed + if (!would_catch) { + stop(cond) + } + + learner$state = set_class(insert_named(learner$state, list( model = result$result$model, log = log, @@ -486,7 +499,7 @@ process_model_after_predict = function(learner, store_models, is_sequential, unm } } -append_log = function(log = NULL, stage = NA_character_, class = NA_character_, msg = character()) { +append_log = function(log = NULL, stage = NA_character_, class = NA_character_, msg = character(), would_catch = TRUE) { if (is.null(log)) { log = data.table( stage = factor(levels = c("train", "predict")), @@ -497,7 +510,7 @@ append_log = function(log = NULL, stage = NA_character_, class = NA_character_, if (length(msg)) { pwalk(list(stage, class, msg), function(s, c, m) { - if (c == "error") lg$error("%s: %s", s, m) + if (c == "error" && would_catch) lg$error("%s: %s", s, m) if (c == "warning") lg$warn("%s: %s", s, m) }) log = rbindlist(list(log, data.table(stage = stage, class = class, msg = msg)), use.names = TRUE) diff --git a/man/Learner.Rd b/man/Learner.Rd index 96027fbd9..2c92b3c08 100644 --- a/man/Learner.Rd +++ b/man/Learner.Rd @@ -649,7 +649,7 @@ If the training step fails, the \verb{$model} field of the original learner is \ Also see the section on error handling the mlr3book: \url{https://mlr3book.mlr-org.com/chapters/chapter10/advanced_technical_aspects_of_mlr3.html#sec-error-handling} \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{Learner$encapsulate(method, fallback = NULL)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{Learner$encapsulate(method, fallback = NULL, should_catch = NULL)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -661,6 +661,9 @@ See the description for details.} \item{\code{fallback}}{\link{Learner}\cr The fallback learner for failed predictions.} + +\item{\code{should_catch}}{(\verb{function(condition)})\cr +Function that takes in the condition and returns \code{logical(1)} indicating whether to run the fallback learner.} } \if{html}{\out{}} } diff --git a/man/mlr_learner_conditions.Rd b/man/mlr_learner_conditions.Rd new file mode 100644 index 000000000..00b12beb6 --- /dev/null +++ b/man/mlr_learner_conditions.Rd @@ -0,0 +1,37 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/conditions.R +\name{mlr_learner_conditions} +\alias{mlr_learner_conditions} +\alias{error_learner} +\alias{condition_learner} +\alias{error_learner_train} +\alias{condition_learner_train} +\alias{error_learner_predict} +\alias{condition_learner_predict} +\title{Learner Classes} +\usage{ +error_learner(msg, ..., class = NULL) + +condition_learner(msg, ..., class = NULL) + +error_learner_train(msg, ..., class = NULL) + +condition_learner_train(msg, ..., class = NULL) + +error_learner_predict(msg, ..., class = NULL) + +condition_learner_predict(msg, ..., class = NULL) +} +\arguments{ +\item{msg}{(\code{character(1)})\cr +Message to be printed.} + +\item{...}{(\code{any})\cr +Additional arguments to be passed to \code{\link[mlr3misc:printf]{mlr3misc::stopf()}}.} + +\item{class}{(\code{character()})\cr +Additional classes to be added to the error.} +} +\description{ +Learner Classes +} diff --git a/tests/testthat/test_Learner.R b/tests/testthat/test_Learner.R index b48772568..52d3ac7c1 100644 --- a/tests/testthat/test_Learner.R +++ b/tests/testthat/test_Learner.R @@ -838,3 +838,60 @@ test_that("Learner printer for encapsulation", { expect_output(print(lrn("classif.rpart")$encapsulate("evaluate", lrn("classif.featureless"))), "Encapsulation: evaluate \\(fallback: LearnerClassifFeatureless\\)") expect_output(print(lrn("classif.rpart")$encapsulate("none")), "Encapsulation: none \\(fallback: -\\)") }) + +test_that("error conditions are working: callr", { + l = lrn("classif.debug", + timeout = c(train = 0.01), + sleep_train = function() while (TRUE) NULL + ) + + l$encapsulate( + "callr", + lrn("classif.featureless"), + should_catch = function(cond) { + !inherits(cond, "mlr3ErrorTimeout") + } + ) + + expect_error(l$train(tsk("iris")), regexp = "reached elapsed time limit") + l$configure(error_train = 1, sleep_train = NULL, timeout = c(train = Inf, predict = Inf)) + expect_error(l$train(tsk("iris")), regexp = NA) +}) + +test_that("error conditions are working: evaluate", { + l = lrn("classif.debug", + timeout = c(train = 0.2), + sleep_train = function() while (TRUE) NULL + ) + + l$encapsulate( + "evaluate", + lrn("classif.featureless"), + function(x) { + !inherits(x, "mlr3ErrorTimeout") + } + ) + + expect_error(l$train(tsk("iris")), regexp = "reached elapsed time limit") + l$configure(error_train = 1, sleep_train = NULL, timeout = c(train = Inf, predict = Inf)) + expect_error(l$train(tsk("iris")), regexp = NA) +}) + +test_that("error conditions are working: try", { + l = lrn("classif.debug", + timeout = c(train = 0.01), + sleep_train = function() while (TRUE) NULL + ) + + l$encapsulate( + "try", + lrn("classif.featureless"), + function(x) { + !inherits(x, "mlr3ErrorTimeout") + } + ) + + expect_error(l$train(tsk("iris")), regexp = "reached elapsed time limit") + l$configure(error_train = 1, sleep_train = NULL, timeout = c(train = Inf, predict = Inf)) + expect_error(l$train(tsk("iris")), regexp = NA) +}) diff --git a/tests/testthat/test_errorhandling.R b/tests/testthat/test_errorhandling.R index 970c0f149..268212e7c 100644 --- a/tests/testthat/test_errorhandling.R +++ b/tests/testthat/test_errorhandling.R @@ -88,4 +88,3 @@ test_that("encapsulation / benchmark", { expect_equal(aggr$warnings, 3L) expect_equal(aggr$errors, 3L) }) - From 74afb131bacf1b442d82184a7e773c36624c915e Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Mon, 11 Aug 2025 17:07:39 +0200 Subject: [PATCH 02/12] update --- NAMESPACE | 3 --- R/benchmark_grid.R | 7 ++++-- R/conditions.R | 45 ++++++++++++++++------------------- man/mlr_learner_conditions.Rd | 18 +++++--------- tests/testthat/test_Learner.R | 3 +++ 5 files changed, 34 insertions(+), 42 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 0bb1146e7..c3b49e3c7 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -232,9 +232,6 @@ export(check_prediction_data) export(clbk) export(clbks) export(col_info) -export(condition_learner) -export(condition_learner_predict) -export(condition_learner_train) export(convert_task) export(create_empty_prediction_data) export(data.table) diff --git a/R/benchmark_grid.R b/R/benchmark_grid.R index f7777628b..781668e49 100644 --- a/R/benchmark_grid.R +++ b/R/benchmark_grid.R @@ -15,7 +15,7 @@ #' The grid will be generated based on the Cartesian product of learners and pairs. #' #' @section Errors and Warnings: -#' * `varying_predict_types`: This warning will be thrown if the learners have different `predict_type`s. +#' * `mlr3WarningVaryingPredictTypes`: This warning will be thrown if the learners have different `predict_type`s. #' #' @param tasks (list of [Task]). #' @param learners (list of [Learner]). @@ -79,7 +79,10 @@ benchmark_grid = function(tasks, learners, resamplings, param_values = NULL, pai assert_param_values(param_values, n_learners = length(learners)) } if (length(unique(map_chr(unique(learners), "predict_type"))) > 1) { - warningf("Multiple predict types detected, this will mean that you cannot evaluate the same measures on all learners.", class = "varying_predict_types") # nolint + mlr3_warning( + "Multiple predict types detected, this will mean that you cannot evaluate the same measures on all learners.", + class = "mlr3WarningVaryingPredictTypes" + ) } if (assert_flag(paired)) { diff --git a/R/conditions.R b/R/conditions.R index d608d4806..da34662ba 100644 --- a/R/conditions.R +++ b/R/conditions.R @@ -6,37 +6,32 @@ #' Additional arguments to be passed to [mlr3misc::stopf()]. #' @param class (`character()`)\cr #' Additional classes to be added to the error. +#' @param silent (`logical(1)`)\cr +#' If `TRUE`, the error is returned as a condition object instead of being stopped. #' @export -error_learner = function(msg, ..., class = NULL) { - stop(condition_learner(msg, ..., class = class)) +error_learner = function(msg, ..., class = NULL, silent = FALSE) { + condition = mlr3_error(msg, ..., class = c(class, "mlr3ErrorLearner")) + if (silent) { + return(condition) + } + stop(condition) } - -#' @rdname mlr_learner_conditions -#' @export -condition_learner = function(msg, ..., class = NULL) { - mlr3misc::condition_mlr3(msg, ..., class = c(class, "mlr3ErrorLearner")) -} - -#' @rdname mlr_learner_conditions -#' @export -error_learner_train = function(msg, ..., class = NULL) { - stop(condition_learner_train(msg, ..., class = class)) -} - -#' @rdname mlr_learner_conditions -#' @export -condition_learner_train = function(msg, ..., class = NULL) { - condition_learner(msg, ..., class = c(class, "mlr3ConditionLearnerTrain")) -} - #' @rdname mlr_learner_conditions #' @export -error_learner_predict = function(msg, ..., class = NULL) { - stop(condition_learner_predict(msg, ..., class = class)) +error_learner_train = function(msg, ..., class = NULL, silent = FALSE) { + condition = error_learner(msg, ..., class = c(class, "mlr3ErrorLearnerTrain"), silent = TRUE) + if (silent) { + return(condition) + } + stop(condition) } #' @rdname mlr_learner_conditions #' @export -condition_learner_predict = function(msg, ..., class = NULL) { - condition_learner(msg, ..., class = c(class, "mlr3ConditionLearnerPredict")) +error_learner_predict = function(msg, ..., class = NULL, silent = FALSE) { + condition = error_learner(msg, ..., class = c(class, "mlr3ErrorLearnerPredict"), silent = TRUE) + if (silent) { + return(condition) + } + stop(condition) } diff --git a/man/mlr_learner_conditions.Rd b/man/mlr_learner_conditions.Rd index 00b12beb6..170905817 100644 --- a/man/mlr_learner_conditions.Rd +++ b/man/mlr_learner_conditions.Rd @@ -3,24 +3,15 @@ \name{mlr_learner_conditions} \alias{mlr_learner_conditions} \alias{error_learner} -\alias{condition_learner} \alias{error_learner_train} -\alias{condition_learner_train} \alias{error_learner_predict} -\alias{condition_learner_predict} \title{Learner Classes} \usage{ -error_learner(msg, ..., class = NULL) +error_learner(msg, ..., class = NULL, silent = FALSE) -condition_learner(msg, ..., class = NULL) +error_learner_train(msg, ..., class = NULL, silent = FALSE) -error_learner_train(msg, ..., class = NULL) - -condition_learner_train(msg, ..., class = NULL) - -error_learner_predict(msg, ..., class = NULL) - -condition_learner_predict(msg, ..., class = NULL) +error_learner_predict(msg, ..., class = NULL, silent = FALSE) } \arguments{ \item{msg}{(\code{character(1)})\cr @@ -31,6 +22,9 @@ Additional arguments to be passed to \code{\link[mlr3misc:printf]{mlr3misc::stop \item{class}{(\code{character()})\cr Additional classes to be added to the error.} + +\item{silent}{(\code{logical(1)})\cr +If \code{TRUE}, the error is returned as a condition object instead of being stopped.} } \description{ Learner Classes diff --git a/tests/testthat/test_Learner.R b/tests/testthat/test_Learner.R index 52d3ac7c1..7b7c0e92d 100644 --- a/tests/testthat/test_Learner.R +++ b/tests/testthat/test_Learner.R @@ -842,6 +842,7 @@ test_that("Learner printer for encapsulation", { test_that("error conditions are working: callr", { l = lrn("classif.debug", timeout = c(train = 0.01), + # Sys.sleep does not get interrupted reliably sleep_train = function() while (TRUE) NULL ) @@ -861,6 +862,7 @@ test_that("error conditions are working: callr", { test_that("error conditions are working: evaluate", { l = lrn("classif.debug", timeout = c(train = 0.2), + # Sys.sleep does not get interrupted reliably sleep_train = function() while (TRUE) NULL ) @@ -880,6 +882,7 @@ test_that("error conditions are working: evaluate", { test_that("error conditions are working: try", { l = lrn("classif.debug", timeout = c(train = 0.01), + # Sys.sleep does not get interrupted reliably sleep_train = function() while (TRUE) NULL ) From 31bc072f89bad2b669a1b720ca6200ca4dc335af Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Tue, 12 Aug 2025 13:21:47 +0200 Subject: [PATCH 03/12] fixes --- R/Learner.R | 8 +++--- R/LearnerClassifDebug.R | 8 +++++- R/benchmark_grid.R | 6 ++--- R/conditions.R | 43 +++++++++++++++---------------- R/worker.R | 15 +++++------ man/Learner.Rd | 4 +-- man/benchmark_grid.Rd | 2 +- man/mlr_learner_conditions.Rd | 26 +++++++++++++------ man/mlr_learners_classif.debug.Rd | 1 + man/mlr_measures_regr.bias.Rd | 4 +-- man/mlr_measures_regr.pbias.Rd | 4 +-- tests/testthat/test_Learner.R | 17 ++++++++---- 12 files changed, 80 insertions(+), 58 deletions(-) diff --git a/R/Learner.R b/R/Learner.R index 077c31915..b03309494 100644 --- a/R/Learner.R +++ b/R/Learner.R @@ -576,7 +576,7 @@ Learner = R6Class("Learner", #' See the description for details. #' @param fallback [Learner]\cr #' The fallback learner for failed predictions. - #' @param should_catch (`function(condition)`)\cr + #' @param when (`function(condition)`)\cr #' Function that takes in the condition and returns `logical(1)` indicating whether to run the fallback learner. #' #' @return `self` (invisibly). @@ -584,10 +584,10 @@ Learner = R6Class("Learner", #' learner = lrn("classif.rpart") #' fallback = lrn("classif.featureless") #' learner$encapsulate("try", fallback = fallback) - encapsulate = function(method, fallback = NULL, should_catch = NULL) { + encapsulate = function(method, fallback = NULL, when = NULL) { assert_choice(method, c("none", "try", "evaluate", "callr")) - private$.should_catch = assert_function(should_catch, null.ok = TRUE) + private$.when = assert_function(when, null.ok = TRUE) if (method != "none") { assert_learner(fallback, task_type = self$task_type) @@ -842,7 +842,7 @@ Learner = R6Class("Learner", ), private = list( - .should_catch = NULL, + .when = NULL, .use_weights = NULL, .encapsulation = c(train = "none", predict = "none"), .fallback = NULL, diff --git a/R/LearnerClassifDebug.R b/R/LearnerClassifDebug.R index e9693022a..a034c8b90 100644 --- a/R/LearnerClassifDebug.R +++ b/R/LearnerClassifDebug.R @@ -80,7 +80,8 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif, iter = p_iter, early_stopping = p_lgl(default = FALSE, tags = "train"), count_marshaling = p_lgl(default = FALSE, tags = "train"), - check_pid = p_lgl(default = TRUE, tags = "train") + check_pid = p_lgl(default = TRUE, tags = "train"), + config_error = p_lgl(default = FALSE, tags = "train") ) super$initialize( id = "classif.debug", @@ -161,6 +162,11 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif, .validate = NULL, .train = function(task) { pv = self$param_set$get_values(tags = "train") + config_error = pv$config_error %??% FALSE + if (isTRUE(config_error)) { + error_config("You misconfigured the learner") + } + pv$count_marshaling = pv$count_marshaling %??% FALSE roll = function(name) { name %chin% names(pv) && pv[[name]] > runif(1L) diff --git a/R/benchmark_grid.R b/R/benchmark_grid.R index 781668e49..41f3a40f8 100644 --- a/R/benchmark_grid.R +++ b/R/benchmark_grid.R @@ -15,7 +15,7 @@ #' The grid will be generated based on the Cartesian product of learners and pairs. #' #' @section Errors and Warnings: -#' * `mlr3WarningVaryingPredictTypes`: This warning will be thrown if the learners have different `predict_type`s. +#' * `Mlr3WarningVaryingPredictTypes`: This warning will be thrown if the learners have different `predict_type`s. #' #' @param tasks (list of [Task]). #' @param learners (list of [Learner]). @@ -79,9 +79,9 @@ benchmark_grid = function(tasks, learners, resamplings, param_values = NULL, pai assert_param_values(param_values, n_learners = length(learners)) } if (length(unique(map_chr(unique(learners), "predict_type"))) > 1) { - mlr3_warning( + warning_config( "Multiple predict types detected, this will mean that you cannot evaluate the same measures on all learners.", - class = "mlr3WarningVaryingPredictTypes" + class = "Mlr3WarningVaryingPredictTypes" ) } diff --git a/R/conditions.R b/R/conditions.R index da34662ba..39a31cb1a 100644 --- a/R/conditions.R +++ b/R/conditions.R @@ -1,37 +1,36 @@ -#' @title Learner Classes +#' @title Learner Conditions #' @name mlr_learner_conditions +#' +#' @description +#' These functions are used to create conditions for errors related to [`Learner`]s. +#' This extends those from [mlr3misc::mlr_conditions]. +#' +#' @section Errors: +#' * `Mlr3ErrorLearner`: Base class for mlr3 errors. +#' * `Mlr3ErrorLearnerTrain`: Errors during training. +#' * `Mlr3ErrorLearnerPredict`: Errors during prediction. +#' #' @param msg (`character(1)`)\cr -#' Message to be printed. +#' Message to be printed. (formatted with `sprintf()`) #' @param ... (`any`)\cr -#' Additional arguments to be passed to [mlr3misc::stopf()]. +#' Additional arguments to be passed to `sprintf()`. #' @param class (`character()`)\cr #' Additional classes to be added to the error. -#' @param silent (`logical(1)`)\cr +#' @param signal (`logical(1)`)\cr #' If `TRUE`, the error is returned as a condition object instead of being stopped. #' @export -error_learner = function(msg, ..., class = NULL, silent = FALSE) { - condition = mlr3_error(msg, ..., class = c(class, "mlr3ErrorLearner")) - if (silent) { - return(condition) - } - stop(condition) +error_learner = function(msg, ..., class = NULL, signal = TRUE) { + error_mlr3(msg, ..., class = c(class, "Mlr3ErrorLearner"), signal = signal) } + #' @rdname mlr_learner_conditions #' @export -error_learner_train = function(msg, ..., class = NULL, silent = FALSE) { - condition = error_learner(msg, ..., class = c(class, "mlr3ErrorLearnerTrain"), silent = TRUE) - if (silent) { - return(condition) - } - stop(condition) +error_learner_train = function(msg, ..., class = NULL, signal = TRUE) { + error_learner(msg, ..., class = c(class, "Mlr3ErrorLearnerTrain"), signal = signal) } #' @rdname mlr_learner_conditions #' @export -error_learner_predict = function(msg, ..., class = NULL, silent = FALSE) { - condition = error_learner(msg, ..., class = c(class, "mlr3ErrorLearnerPredict"), silent = TRUE) - if (silent) { - return(condition) - } - stop(condition) +error_learner_predict = function(msg, ..., class = NULL, signal = TRUE) { + error_learner(msg, ..., class = c(class, "Mlr3ErrorLearnerPredict"), signal = signal) } diff --git a/R/worker.R b/R/worker.R index 73037947f..4b4f40d7a 100644 --- a/R/worker.R +++ b/R/worker.R @@ -109,20 +109,19 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL # select rows where column 'class' is equal to "error" in data.table style cond = result$log[class == "error", "condition"] cond = if (nrow(cond)) { - cond = cond[[1L]][[1L]] + cond[[1L]][[1L]] } - should_catch = get_private(learner)$.should_catch - would_catch = is.null(should_catch) || (!is.null(cond) && should_catch(cond)) + when = get_private(learner)$.when + catch_error = is.null(when) || (!is.null(cond) && (!inherits(cond, "Mlr3ErrorConfig") && when(cond))) - log = append_log(NULL, "train", result$log$class, result$log$msg, would_catch = would_catch) + log = append_log(NULL, "train", result$log$class, result$log$msg, log_error = catch_error) train_time = result$elapsed - if (!would_catch) { + if (!catch_error) { stop(cond) } - learner$state = set_class(insert_named(learner$state, list( model = result$result$model, log = log, @@ -506,7 +505,7 @@ process_model_after_predict = function(learner, store_models, is_sequential, unm } } -append_log = function(log = NULL, stage = NA_character_, class = NA_character_, msg = character(), would_catch = TRUE) { +append_log = function(log = NULL, stage = NA_character_, class = NA_character_, msg = character(), log_error = TRUE) { if (is.null(log)) { log = data.table( stage = factor(levels = c("train", "predict")), @@ -517,7 +516,7 @@ append_log = function(log = NULL, stage = NA_character_, class = NA_character_, if (length(msg)) { pwalk(list(stage, class, msg), function(s, c, m) { - if (c == "error" && would_catch) lg$error("%s: %s", s, m) + if (c == "error" && log_error) lg$error("%s: %s", s, m) if (c == "warning") lg$warn("%s: %s", s, m) }) log = rbindlist(list(log, data.table(stage = stage, class = class, msg = msg)), use.names = TRUE) diff --git a/man/Learner.Rd b/man/Learner.Rd index fc29fe240..79e21008c 100644 --- a/man/Learner.Rd +++ b/man/Learner.Rd @@ -749,7 +749,7 @@ If the training step fails, the \verb{$model} field of the original learner is \ Also see the section on error handling the mlr3book: \url{https://mlr3book.mlr-org.com/chapters/chapter10/advanced_technical_aspects_of_mlr3.html#sec-error-handling} \subsection{Usage}{ -\if{html}{\out{
}}\preformatted{Learner$encapsulate(method, fallback = NULL, should_catch = NULL)}\if{html}{\out{
}} +\if{html}{\out{
}}\preformatted{Learner$encapsulate(method, fallback = NULL, when = NULL)}\if{html}{\out{
}} } \subsection{Arguments}{ @@ -762,7 +762,7 @@ See the description for details.} \item{\code{fallback}}{\link{Learner}\cr The fallback learner for failed predictions.} -\item{\code{should_catch}}{(\verb{function(condition)})\cr +\item{\code{when}}{(\verb{function(condition)})\cr Function that takes in the condition and returns \code{logical(1)} indicating whether to run the fallback learner.} } \if{html}{\out{}} diff --git a/man/benchmark_grid.Rd b/man/benchmark_grid.Rd index db72c5613..19c630508 100644 --- a/man/benchmark_grid.Rd +++ b/man/benchmark_grid.Rd @@ -56,7 +56,7 @@ The grid will be generated based on the Cartesian product of learners and pairs. \section{Errors and Warnings}{ \itemize{ -\item \code{varying_predict_types}: This warning will be thrown if the learners have different \code{predict_type}s. +\item \code{Mlr3WarningVaryingPredictTypes}: This warning will be thrown if the learners have different \code{predict_type}s. } } diff --git a/man/mlr_learner_conditions.Rd b/man/mlr_learner_conditions.Rd index 170905817..f52d87dab 100644 --- a/man/mlr_learner_conditions.Rd +++ b/man/mlr_learner_conditions.Rd @@ -5,27 +5,37 @@ \alias{error_learner} \alias{error_learner_train} \alias{error_learner_predict} -\title{Learner Classes} +\title{Learner Conditions} \usage{ -error_learner(msg, ..., class = NULL, silent = FALSE) +error_learner(msg, ..., class = NULL, signal = TRUE) -error_learner_train(msg, ..., class = NULL, silent = FALSE) +error_learner_train(msg, ..., class = NULL, signal = TRUE) -error_learner_predict(msg, ..., class = NULL, silent = FALSE) +error_learner_predict(msg, ..., class = NULL, signal = TRUE) } \arguments{ \item{msg}{(\code{character(1)})\cr -Message to be printed.} +Message to be printed. (formatted with \code{sprintf()})} \item{...}{(\code{any})\cr -Additional arguments to be passed to \code{\link[mlr3misc:printf]{mlr3misc::stopf()}}.} +Additional arguments to be passed to \code{sprintf()}.} \item{class}{(\code{character()})\cr Additional classes to be added to the error.} -\item{silent}{(\code{logical(1)})\cr +\item{signal}{(\code{logical(1)})\cr If \code{TRUE}, the error is returned as a condition object instead of being stopped.} } \description{ -Learner Classes +These functions are used to create conditions for errors related to \code{\link{Learner}}s. +This extends those from \link[mlr3misc:mlr_conditions]{mlr3misc::mlr_conditions}. } +\section{Errors}{ + +\itemize{ +\item \code{Mlr3ErrorLearner}: Base class for mlr3 errors. +\item \code{Mlr3ErrorLearnerTrain}: Errors during training. +\item \code{Mlr3ErrorLearnerPredict}: Errors during prediction. +} +} + diff --git a/man/mlr_learners_classif.debug.Rd b/man/mlr_learners_classif.debug.Rd index 1d157c024..cc3353671 100644 --- a/man/mlr_learners_classif.debug.Rd +++ b/man/mlr_learners_classif.debug.Rd @@ -72,6 +72,7 @@ lrn("classif.debug") early_stopping \tab logical \tab FALSE \tab TRUE, FALSE \tab - \cr count_marshaling \tab logical \tab FALSE \tab TRUE, FALSE \tab - \cr check_pid \tab logical \tab TRUE \tab TRUE, FALSE \tab - \cr + config_error \tab logical \tab FALSE \tab TRUE, FALSE \tab - \cr } } diff --git a/man/mlr_measures_regr.bias.Rd b/man/mlr_measures_regr.bias.Rd index d0d7bd160..f5451f5f3 100644 --- a/man/mlr_measures_regr.bias.Rd +++ b/man/mlr_measures_regr.bias.Rd @@ -8,9 +8,9 @@ Measure to compare true observed response with predicted response in regression } \details{ The Bias is defined as \deqn{ - \frac{1}{n} \sum_{i=1}^n w_i \left( r_i - t_i \right), + \frac{1}{n} \sum_{i=1}^n w_i \left( t_i - r_i \right), }{ - weighted.mean(r - t, w), + weighted.mean(t - r, w), } where \eqn{w_i} are normalized sample weights. Good predictions score close to 0. diff --git a/man/mlr_measures_regr.pbias.Rd b/man/mlr_measures_regr.pbias.Rd index 7e659709e..f0d5f5b0f 100644 --- a/man/mlr_measures_regr.pbias.Rd +++ b/man/mlr_measures_regr.pbias.Rd @@ -8,9 +8,9 @@ Measure to compare true observed response with predicted response in regression } \details{ The Percent Bias is defined as \deqn{ - \frac{1}{n} \sum_{i=1}^n w_i \frac{\left( r_i - t_i \right)}{\left| t_i \right|}, + \frac{1}{n} \sum_{i=1}^n w_i \frac{\left( t_i - r_i \right)}{\left| t_i \right|}, }{ - weighted.mean((r - t) / abs(t), w), + weighted.mean((t - r) / abs(t), w), } where \eqn{w_i} are normalized sample weights. Good predictions score close to 0. diff --git a/tests/testthat/test_Learner.R b/tests/testthat/test_Learner.R index 7fda2308f..022c35f7d 100644 --- a/tests/testthat/test_Learner.R +++ b/tests/testthat/test_Learner.R @@ -849,8 +849,8 @@ test_that("error conditions are working: callr", { l$encapsulate( "callr", lrn("classif.featureless"), - should_catch = function(cond) { - !inherits(cond, "mlr3ErrorTimeout") + when = function(cond) { + !inherits(cond, "Mlr3ErrorTimeout") } ) @@ -870,7 +870,7 @@ test_that("error conditions are working: evaluate", { "evaluate", lrn("classif.featureless"), function(x) { - !inherits(x, "mlr3ErrorTimeout") + !inherits(x, "Mlr3ErrorTimeout") } ) @@ -890,14 +890,15 @@ test_that("error conditions are working: try", { "try", lrn("classif.featureless"), function(x) { - !inherits(x, "mlr3ErrorTimeout") + !inherits(x, "Mlr3ErrorTimeout") } ) expect_error(l$train(tsk("iris")), regexp = "reached elapsed time limit") l$configure(error_train = 1, sleep_train = NULL, timeout = c(train = Inf, predict = Inf)) expect_error(l$train(tsk("iris")), regexp = NA) -======= +}) + test_that("oob_error is available without storing models via $.extract_oob_error()", { LearnerDummyOOB = R6::R6Class("LearnerDummyOOB", inherit = LearnerClassif, public = list( @@ -931,3 +932,9 @@ test_that("oob_error is available without storing models via $.extract_oob_error expect_equal(rr$aggregate(msr("oob_error")), c(oob_error = 0.123)) }) + +test_that("config error does not trigger callback", { + l = lrn("classif.debug", config_error = TRUE) + l$encapsulate("evaluate", lrn("classif.featureless"), function(x) TRUE) + expect_error(l$train(tsk("iris")), regexp = "You misconfigured the learner") +}) From 7cbec788189fb81534cf81ea08fa2d30ae75726d Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Tue, 12 Aug 2025 13:27:25 +0200 Subject: [PATCH 04/12] trigger ci --- R/conditions.R | 2 +- man/mlr_learner_conditions.Rd | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/conditions.R b/R/conditions.R index 39a31cb1a..e502198c2 100644 --- a/R/conditions.R +++ b/R/conditions.R @@ -17,7 +17,7 @@ #' @param class (`character()`)\cr #' Additional classes to be added to the error. #' @param signal (`logical(1)`)\cr -#' If `TRUE`, the error is returned as a condition object instead of being stopped. +#' If `TRUE`, the error object is returned without stopping the interpreter. #' @export error_learner = function(msg, ..., class = NULL, signal = TRUE) { error_mlr3(msg, ..., class = c(class, "Mlr3ErrorLearner"), signal = signal) diff --git a/man/mlr_learner_conditions.Rd b/man/mlr_learner_conditions.Rd index f52d87dab..d778d6f25 100644 --- a/man/mlr_learner_conditions.Rd +++ b/man/mlr_learner_conditions.Rd @@ -24,7 +24,7 @@ Additional arguments to be passed to \code{sprintf()}.} Additional classes to be added to the error.} \item{signal}{(\code{logical(1)})\cr -If \code{TRUE}, the error is returned as a condition object instead of being stopped.} +If \code{TRUE}, the error object is returned without stopping the interpreter.} } \description{ These functions are used to create conditions for errors related to \code{\link{Learner}}s. From db992b3fa85a83dbc446610c78dc3bcb23ed7560 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Tue, 12 Aug 2025 13:59:22 +0200 Subject: [PATCH 05/12] fix pkgdown --- pkgdown/_pkgdown.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pkgdown/_pkgdown.yml b/pkgdown/_pkgdown.yml index 2dd77cc6a..202da61b6 100644 --- a/pkgdown/_pkgdown.yml +++ b/pkgdown/_pkgdown.yml @@ -123,6 +123,10 @@ reference: - assert_resample_callbacks - mlr3.model_extractor - mlr3.holdout_task + - title: Conditions + contents: + - starts_with("error_") + - starts_with("warning_") - title: Internal Objects and Functions contents: - marshaling From aa5c338ea402ba774b76453d87b9dffb7646dddb Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Wed, 13 Aug 2025 14:50:25 +0200 Subject: [PATCH 06/12] docs --- R/Learner.R | 4 ++++ man/Learner.Rd | 6 +++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/R/Learner.R b/R/Learner.R index b03309494..c93811078 100644 --- a/R/Learner.R +++ b/R/Learner.R @@ -568,6 +568,9 @@ Learner = R6Class("Learner", #' Note that the fallback is always trained, as we do not know in advance whether prediction will fail. #' If the training step fails, the `$model` field of the original learner is `NULL`. #' + #' Note that for errors of class `Mlr3ErrorConfig`, the function always errs and no fallback learner + #' is trained. + #' #' Also see the section on error handling the mlr3book: #' \url{https://mlr3book.mlr-org.com/chapters/chapter10/advanced_technical_aspects_of_mlr3.html#sec-error-handling} #' @@ -578,6 +581,7 @@ Learner = R6Class("Learner", #' The fallback learner for failed predictions. #' @param when (`function(condition)`)\cr #' Function that takes in the condition and returns `logical(1)` indicating whether to run the fallback learner. + #' If `NULL` (default), the fallback is always trained, except for errors of class `Mlr3ErrorConfig`. #' #' @return `self` (invisibly). #' @examples diff --git a/man/Learner.Rd b/man/Learner.Rd index 79e21008c..0065145b9 100644 --- a/man/Learner.Rd +++ b/man/Learner.Rd @@ -746,6 +746,9 @@ If the original learner only partially fails during predict step (usually in the Note that the fallback is always trained, as we do not know in advance whether prediction will fail. If the training step fails, the \verb{$model} field of the original learner is \code{NULL}. +Note that for errors of class \code{Mlr3ErrorConfig}, the function always errs and no fallback learner +is trained. + Also see the section on error handling the mlr3book: \url{https://mlr3book.mlr-org.com/chapters/chapter10/advanced_technical_aspects_of_mlr3.html#sec-error-handling} \subsection{Usage}{ @@ -763,7 +766,8 @@ See the description for details.} The fallback learner for failed predictions.} \item{\code{when}}{(\verb{function(condition)})\cr -Function that takes in the condition and returns \code{logical(1)} indicating whether to run the fallback learner.} +Function that takes in the condition and returns \code{logical(1)} indicating whether to run the fallback learner. +If \code{NULL} (default), the fallback is always trained, except for errors of class \code{Mlr3ErrorConfig}.} } \if{html}{\out{}} } From 4859d2b0a11a6e67249d4586f91e89b8faf6715d Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Fri, 15 Aug 2025 18:30:54 +0200 Subject: [PATCH 07/12] add mirai test --- tests/testthat/test_Learner.R | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/testthat/test_Learner.R b/tests/testthat/test_Learner.R index 022c35f7d..b3811205b 100644 --- a/tests/testthat/test_Learner.R +++ b/tests/testthat/test_Learner.R @@ -899,6 +899,26 @@ test_that("error conditions are working: try", { expect_error(l$train(tsk("iris")), regexp = NA) }) +test_that("error conditions are working: mirai", { + l = lrn("classif.debug", + timeout = c(train = 0.01), + # Sys.sleep does not get interrupted reliably + sleep_train = function() while (TRUE) NULL + ) + + l$encapsulate( + "mirai", + lrn("classif.featureless"), + function(x) { + !inherits(x, "Mlr3ErrorTimeout") + } + ) + + expect_error(l$train(tsk("iris")), regexp = "reached elapsed time limit") + l$configure(error_train = 1, sleep_train = NULL, timeout = c(train = Inf, predict = Inf)) + expect_error(l$train(tsk("iris")), regexp = NA) +}) + test_that("oob_error is available without storing models via $.extract_oob_error()", { LearnerDummyOOB = R6::R6Class("LearnerDummyOOB", inherit = LearnerClassif, public = list( From 87406c447abda741d82f7bea39d55821c482f874 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Fri, 15 Aug 2025 18:38:36 +0200 Subject: [PATCH 08/12] depend on mlr3misc main --- DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index 885cf0623..6aa595bcc 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -74,7 +74,7 @@ Suggests: rpart, testthat (>= 3.2.0) Remotes: - mlr-org/mlr3misc@errors + mlr-org/mlr3misc Encoding: UTF-8 Config/testthat/edition: 3 Config/testthat/parallel: false From 86a7b8c33bd779061f344eea6ec0d236592e3f80 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Fri, 15 Aug 2025 18:55:32 +0200 Subject: [PATCH 09/12] cleanup --- DESCRIPTION | 1 - NAMESPACE | 3 --- R/Learner.R | 16 +++++++------- R/conditions.R | 36 ------------------------------ man/mlr_learner_conditions.Rd | 41 ----------------------------------- 5 files changed, 8 insertions(+), 89 deletions(-) delete mode 100644 R/conditions.R delete mode 100644 man/mlr_learner_conditions.Rd diff --git a/DESCRIPTION b/DESCRIPTION index 6aa595bcc..f5b6c9196 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -188,7 +188,6 @@ Collate: 'benchmark.R' 'benchmark_grid.R' 'bibentries.R' - 'conditions.R' 'default_fallback.R' 'default_measures.R' 'fix_factor_levels.R' diff --git a/NAMESPACE b/NAMESPACE index e29467c56..12a1e0dc8 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -236,9 +236,6 @@ export(data.table) export(default_fallback) export(default_measures) export(deprecated_binding) -export(error_learner) -export(error_learner_predict) -export(error_learner_train) export(extract_pkgs) export(filter_prediction_data) export(install_pkgs) diff --git a/R/Learner.R b/R/Learner.R index 18e9a39e5..2db4358ff 100644 --- a/R/Learner.R +++ b/R/Learner.R @@ -389,11 +389,11 @@ Learner = R6Class("Learner", # to it original state model_was_marshaled = is_marshaled_model(self$model) on.exit({ - if (model_was_marshaled) { - self$model = marshal_model(self$model, inplace = TRUE) - } else { - self$model = unmarshal_model(self$model, inplace = TRUE) - } + if (model_was_marshaled) { + self$model = marshal_model(self$model, inplace = TRUE) + } else { + self$model = unmarshal_model(self$model, inplace = TRUE) + } }, add = TRUE) # reset learner predict time; this is only cumulative for multiple predict sets, @@ -480,7 +480,7 @@ Learner = R6Class("Learner", ci = task$col_info[list(keep_cols), ][ get("type") != col_info(newdata)[list(keep_cols), on = "id"]$type] tab2 = do.call(data.table, Map(auto_convert, - value = as.list(newdata$data(rows = newdata$rownames, cols = ci$id)), + value = as.list(newdata$data(rows = newdata$rownames, cols = ci$id)), id = ci$id, type = ci$type, levels = ci$levels)) tab = cbind(tab1, tab2) @@ -930,8 +930,8 @@ marshal_model.learner_state = function(model, inplace = FALSE, ...) { } model$model = mm structure(list( - marshaled = model, - packages = "mlr3" + marshaled = model, + packages = "mlr3" ), class = c("learner_state_marshaled", "list_marshaled", "marshaled")) } diff --git a/R/conditions.R b/R/conditions.R deleted file mode 100644 index e502198c2..000000000 --- a/R/conditions.R +++ /dev/null @@ -1,36 +0,0 @@ -#' @title Learner Conditions -#' @name mlr_learner_conditions -#' -#' @description -#' These functions are used to create conditions for errors related to [`Learner`]s. -#' This extends those from [mlr3misc::mlr_conditions]. -#' -#' @section Errors: -#' * `Mlr3ErrorLearner`: Base class for mlr3 errors. -#' * `Mlr3ErrorLearnerTrain`: Errors during training. -#' * `Mlr3ErrorLearnerPredict`: Errors during prediction. -#' -#' @param msg (`character(1)`)\cr -#' Message to be printed. (formatted with `sprintf()`) -#' @param ... (`any`)\cr -#' Additional arguments to be passed to `sprintf()`. -#' @param class (`character()`)\cr -#' Additional classes to be added to the error. -#' @param signal (`logical(1)`)\cr -#' If `TRUE`, the error object is returned without stopping the interpreter. -#' @export -error_learner = function(msg, ..., class = NULL, signal = TRUE) { - error_mlr3(msg, ..., class = c(class, "Mlr3ErrorLearner"), signal = signal) -} - -#' @rdname mlr_learner_conditions -#' @export -error_learner_train = function(msg, ..., class = NULL, signal = TRUE) { - error_learner(msg, ..., class = c(class, "Mlr3ErrorLearnerTrain"), signal = signal) -} - -#' @rdname mlr_learner_conditions -#' @export -error_learner_predict = function(msg, ..., class = NULL, signal = TRUE) { - error_learner(msg, ..., class = c(class, "Mlr3ErrorLearnerPredict"), signal = signal) -} diff --git a/man/mlr_learner_conditions.Rd b/man/mlr_learner_conditions.Rd deleted file mode 100644 index d778d6f25..000000000 --- a/man/mlr_learner_conditions.Rd +++ /dev/null @@ -1,41 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/conditions.R -\name{mlr_learner_conditions} -\alias{mlr_learner_conditions} -\alias{error_learner} -\alias{error_learner_train} -\alias{error_learner_predict} -\title{Learner Conditions} -\usage{ -error_learner(msg, ..., class = NULL, signal = TRUE) - -error_learner_train(msg, ..., class = NULL, signal = TRUE) - -error_learner_predict(msg, ..., class = NULL, signal = TRUE) -} -\arguments{ -\item{msg}{(\code{character(1)})\cr -Message to be printed. (formatted with \code{sprintf()})} - -\item{...}{(\code{any})\cr -Additional arguments to be passed to \code{sprintf()}.} - -\item{class}{(\code{character()})\cr -Additional classes to be added to the error.} - -\item{signal}{(\code{logical(1)})\cr -If \code{TRUE}, the error object is returned without stopping the interpreter.} -} -\description{ -These functions are used to create conditions for errors related to \code{\link{Learner}}s. -This extends those from \link[mlr3misc:mlr_conditions]{mlr3misc::mlr_conditions}. -} -\section{Errors}{ - -\itemize{ -\item \code{Mlr3ErrorLearner}: Base class for mlr3 errors. -\item \code{Mlr3ErrorLearnerTrain}: Errors during training. -\item \code{Mlr3ErrorLearnerPredict}: Errors during prediction. -} -} - From 815c51b2d3a1110459d6b5a4b401441926c0a2e2 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Fri, 15 Aug 2025 20:01:34 +0200 Subject: [PATCH 10/12] Apply suggestions from code review --- R/LearnerClassifDebug.R | 3 +-- R/worker.R | 1 - pkgdown/_pkgdown.yml | 4 ---- 3 files changed, 1 insertion(+), 7 deletions(-) diff --git a/R/LearnerClassifDebug.R b/R/LearnerClassifDebug.R index a034c8b90..c91e37d49 100644 --- a/R/LearnerClassifDebug.R +++ b/R/LearnerClassifDebug.R @@ -162,8 +162,7 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif, .validate = NULL, .train = function(task) { pv = self$param_set$get_values(tags = "train") - config_error = pv$config_error %??% FALSE - if (isTRUE(config_error)) { + if (isTRUE(pv$config_error)) { error_config("You misconfigured the learner") } diff --git a/R/worker.R b/R/worker.R index 090b369c4..f68a96196 100644 --- a/R/worker.R +++ b/R/worker.R @@ -107,7 +107,6 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL .timeout = learner$timeout["train"], .compute = getOption("mlr3.mirai_encapsulation", "mlr3_encapsulation") ) - # select rows where column 'class' is equal to "error" in data.table style cond = result$log[class == "error", "condition"] cond = if (nrow(cond)) { cond[[1L]][[1L]] diff --git a/pkgdown/_pkgdown.yml b/pkgdown/_pkgdown.yml index 202da61b6..2dd77cc6a 100644 --- a/pkgdown/_pkgdown.yml +++ b/pkgdown/_pkgdown.yml @@ -123,10 +123,6 @@ reference: - assert_resample_callbacks - mlr3.model_extractor - mlr3.holdout_task - - title: Conditions - contents: - - starts_with("error_") - - starts_with("warning_") - title: Internal Objects and Functions contents: - marshaling From 88e4dc266f1b809ad55a88abad8fbaf92c7b4eac Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Fri, 15 Aug 2025 22:26:06 +0200 Subject: [PATCH 11/12] fix issues --- R/worker.R | 14 ++++++++------ tests/testthat/test_Learner.R | 2 ++ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/R/worker.R b/R/worker.R index f68a96196..d60323d9b 100644 --- a/R/worker.R +++ b/R/worker.R @@ -107,18 +107,20 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL .timeout = learner$timeout["train"], .compute = getOption("mlr3.mirai_encapsulation", "mlr3_encapsulation") ) - cond = result$log[class == "error", "condition"] - cond = if (nrow(cond)) { - cond[[1L]][[1L]] + + cond = result$log[class == "error", "condition"][[1L]] + + cond = if (length(cond)) { + cond = cond[[1L]] } when = get_private(learner)$.when - catch_error = is.null(when) || (!is.null(cond) && (!inherits(cond, "Mlr3ErrorConfig") && when(cond))) + would_catch_error = (!is.null(cond)) && (!inherits(cond, "Mlr3ErrorConfig")) && (is.null(when) || when(cond)) - log = append_log(NULL, "train", result$log$class, result$log$msg, log_error = catch_error) + log = append_log(NULL, "train", result$log$class, result$log$msg, log_error = would_catch_error) train_time = result$elapsed - if (!catch_error) { + if (!is.null(cond) && !would_catch_error) { stop(cond) } diff --git a/tests/testthat/test_Learner.R b/tests/testthat/test_Learner.R index b3811205b..2b450fcd1 100644 --- a/tests/testthat/test_Learner.R +++ b/tests/testthat/test_Learner.R @@ -957,4 +957,6 @@ test_that("config error does not trigger callback", { l = lrn("classif.debug", config_error = TRUE) l$encapsulate("evaluate", lrn("classif.featureless"), function(x) TRUE) expect_error(l$train(tsk("iris")), regexp = "You misconfigured the learner") + l$encapsulate("evaluate", lrn("classif.featureless")) + expect_error(l$train(tsk("iris")), regexp = "You misconfigured the learner") }) From 96627cf643a8d0fca0c371779e1e7326ca4d46a9 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Sat, 16 Aug 2025 10:24:07 +0200 Subject: [PATCH 12/12] simplify --- R/worker.R | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/R/worker.R b/R/worker.R index d60323d9b..27633f1ee 100644 --- a/R/worker.R +++ b/R/worker.R @@ -115,12 +115,12 @@ learner_train = function(learner, task, train_row_ids = NULL, test_row_ids = NUL } when = get_private(learner)$.when - would_catch_error = (!is.null(cond)) && (!inherits(cond, "Mlr3ErrorConfig")) && (is.null(when) || when(cond)) + catch_error = (!is.null(cond)) && (!inherits(cond, "Mlr3ErrorConfig")) && (is.null(when) || when(cond)) - log = append_log(NULL, "train", result$log$class, result$log$msg, log_error = would_catch_error) + log = append_log(NULL, "train", result$log$class, result$log$msg, log_error = catch_error) train_time = result$elapsed - if (!is.null(cond) && !would_catch_error) { + if (!is.null(cond) && !catch_error) { stop(cond) }