Skip to content

Commit a792cd0

Browse files
authored
[R] Check invalid input for the cv function. (dmlc#11264)
- rename the es parameter. - check qdm input.
1 parent e1ce3f6 commit a792cd0

File tree

5 files changed

+61
-14
lines changed

5 files changed

+61
-14
lines changed

R-package/R/callbacks.R

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -613,17 +613,19 @@ xgb.cb.reset.parameters <- function(new_params) {
613613
#' `metric_name = 'dtest-auc'` or `metric_name = 'dtest_auc'`.
614614
#' All dash '-' characters in metric names are considered equivalent to '_'.
615615
#' @param verbose Whether to print the early stopping information.
616-
#' @param keep_all_iter Whether to keep all of the boosting rounds that were produced
617-
#' in the resulting object. If passing `FALSE`, will only keep the boosting rounds
618-
#' up to the detected best iteration, discarding the ones that come after.
616+
#'
617+
#' @param save_best Whether training should return the best model or the last model. If
618+
#' set to `TRUE`, it will only keep the boosting rounds up to the detected best
619+
#' iteration, discarding the ones that come after. This parameter is not supported by
620+
#' the `xgb.cv` function and the `gblinear` booster yet.
619621
#' @return An `xgb.Callback` object, which can be passed to [xgb.train()] or [xgb.cv()].
620622
#' @export
621623
xgb.cb.early.stop <- function(
622624
stopping_rounds,
623625
maximize = FALSE,
624626
metric_name = NULL,
625627
verbose = TRUE,
626-
keep_all_iter = TRUE
628+
save_best = FALSE
627629
) {
628630
if (!is.null(metric_name)) {
629631
stopifnot(is.character(metric_name))
@@ -639,14 +641,17 @@ xgb.cb.early.stop <- function(
639641
maximize = maximize,
640642
metric_name = metric_name,
641643
verbose = verbose,
642-
keep_all_iter = keep_all_iter,
644+
save_best = save_best,
643645
stopped_by_max_rounds = FALSE
644646
)
645647
),
646648
f_before_training = function(env, model, data, evals, begin_iteration, end_iteration) {
647649
if (inherits(model, "xgb.Booster") && !length(evals)) {
648650
stop("For early stopping, 'evals' must have at least one element")
649651
}
652+
if (!inherits(model, "xgb.Booster") && save_best) {
653+
stop("'save_best' must be set to FALSE when using early stopping in 'xgb.cv'.")
654+
}
650655
env$begin_iteration <- begin_iteration
651656
return(NULL)
652657
},
@@ -731,7 +736,7 @@ xgb.cb.early.stop <- function(
731736
return(FALSE)
732737
},
733738
f_after_training = function(env, model, data, evals, iteration, final_feval, prev_cb_res) {
734-
if (inherits(model, "xgb.Booster") && !env$keep_all_iter && env$best_iteration < iteration) {
739+
if (inherits(model, "xgb.Booster") && env$save_best && env$best_iteration < iteration) {
735740
# Note: it loses the attributes after being sliced,
736741
# so they have to be re-assigned afterwards.
737742
prev_attr <- xgb.attributes(model)

R-package/R/xgb.cv.R

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,13 @@ xgb.cv <- function(params = xgb.params(), data, nrounds, nfold,
113113
check.deprecation(deprecated_cv_params, match.call(), ...)
114114

115115
stopifnot(inherits(data, "xgb.DMatrix"))
116+
116117
if (inherits(data, "xgb.DMatrix") && .Call(XGCheckNullPtr_R, data)) {
117118
stop("'data' is an invalid 'xgb.DMatrix' object. Must be constructed again.")
118119
}
120+
if (inherits(data, "xgb.QuantileDMatrix")) {
121+
stop("'xgb.QuantileDMatrix' is not supported as input to 'xgb.cv'.")
122+
}
119123

120124
params <- check.booster.params(params)
121125
# TODO: should we deprecate the redundant 'metrics' parameter?
@@ -171,7 +175,8 @@ xgb.cv <- function(params = xgb.params(), data, nrounds, nfold,
171175
xgb.cb.early.stop(
172176
early_stopping_rounds,
173177
maximize = maximize,
174-
verbose = verbose
178+
verbose = verbose,
179+
save_best = FALSE
175180
),
176181
as_first_elt = TRUE
177182
)

R-package/man/xgb.cb.early.stop.Rd

Lines changed: 5 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

R-package/tests/testthat/test_basic.R

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,40 @@ test_that("xgb.cv works", {
402402
expect_false(is.null(cv$call))
403403
})
404404

405+
test_that("xgb.cv invalid inputs", {
406+
data("mtcars")
407+
y <- mtcars$mpg
408+
x_df <- mtcars[, -1]
409+
410+
expect_error(
411+
cv <- xgb.cv(
412+
data = xgb.QuantileDMatrix(x_df, label = y),
413+
nfold = 5,
414+
nrounds = 2,
415+
params = xgb.params(
416+
max_depth = 2,
417+
nthread = n_threads
418+
)
419+
),
420+
regexp = ".*QuantileDMatrix.*"
421+
)
422+
expect_error(
423+
cv <- xgb.cv(
424+
data = xgb.DMatrix(x_df, label = y),
425+
nfold = 5,
426+
nrounds = 2,
427+
params = xgb.params(
428+
max_depth = 2,
429+
nthread = n_threads,
430+
),
431+
callbacks = list(
432+
xgb.cb.early.stop(stopping_rounds = 3, save_best = TRUE)
433+
)
434+
),
435+
regexp = ".*save_best.*"
436+
)
437+
})
438+
405439
test_that("xgb.cv works with stratified folds", {
406440
dtrain <- xgb.DMatrix(train$data, label = train$label, nthread = n_threads)
407441
set.seed(314159)

python-package/xgboost/callback.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -325,9 +325,11 @@ class EarlyStopping(TrainingCallback):
325325
maximize :
326326
Whether to maximize evaluation metric. None means auto (discouraged).
327327
save_best :
328-
Whether training should return the best model or the last model. This is only
329-
supported with tree methods. Also, the `cv` function doesn't return a model, the
330-
parameter is not applicable.
328+
Whether training should return the best model or the last model. If set to
329+
`True`, it will only keep the boosting rounds up to the detected best iteration,
330+
discarding the ones that come after. This is only supported with tree methods
331+
(not `gblinear`). Since the `cv` function doesn't return a model, the parameter
332+
is not applicable.
331333
min_delta :
332334
333335
.. versionadded:: 1.5.0

0 commit comments

Comments
 (0)