Skip to content

Commit 428f6cb

Browse files
authored
[R] remove default values in internal booster manipulation functions (dmlc#9461)
1 parent d638535 commit 428f6cb

File tree

5 files changed

+32
-13
lines changed

5 files changed

+32
-13
lines changed

R-package/R/callbacks.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,7 @@ cb.cv.predict <- function(save_models = FALSE) {
511511
if (save_models) {
512512
env$basket$models <- lapply(env$bst_folds, function(fd) {
513513
xgb.attr(fd$bst, 'niter') <- env$end_iteration - 1
514-
xgb.Booster.complete(xgb.handleToBooster(fd$bst), saveraw = TRUE)
514+
xgb.Booster.complete(xgb.handleToBooster(handle = fd$bst, raw = NULL), saveraw = TRUE)
515515
})
516516
}
517517
}
@@ -659,7 +659,7 @@ cb.gblinear.history <- function(sparse = FALSE) {
659659
} else { # xgb.cv:
660660
cf <- vector("list", length(env$bst_folds))
661661
for (i in seq_along(env$bst_folds)) {
662-
dmp <- xgb.dump(xgb.handleToBooster(env$bst_folds[[i]]$bst))
662+
dmp <- xgb.dump(xgb.handleToBooster(handle = env$bst_folds[[i]]$bst, raw = NULL))
663663
cf[[i]] <- as.numeric(grep('(booster|bias|weigh)', dmp, invert = TRUE, value = TRUE))
664664
if (sparse) cf[[i]] <- as(cf[[i]], "sparseVector")
665665
}

R-package/R/xgb.Booster.R

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# Construct an internal xgboost Booster and return a handle to it.
22
# internal utility function
3-
xgb.Booster.handle <- function(params = list(), cachelist = list(),
4-
modelfile = NULL, handle = NULL) {
3+
xgb.Booster.handle <- function(params, cachelist, modelfile, handle) {
54
if (typeof(cachelist) != "list" ||
65
!all(vapply(cachelist, inherits, logical(1), what = 'xgb.DMatrix'))) {
76
stop("cachelist must be a list of xgb.DMatrix objects")
@@ -44,7 +43,7 @@ xgb.Booster.handle <- function(params = list(), cachelist = list(),
4443

4544
# Convert xgb.Booster.handle to xgb.Booster
4645
# internal utility function
47-
xgb.handleToBooster <- function(handle, raw = NULL) {
46+
xgb.handleToBooster <- function(handle, raw) {
4847
bst <- list(handle = handle, raw = raw)
4948
class(bst) <- "xgb.Booster"
5049
return(bst)
@@ -129,7 +128,12 @@ xgb.Booster.complete <- function(object, saveraw = TRUE) {
129128
stop("argument type must be xgb.Booster")
130129

131130
if (is.null.handle(object$handle)) {
132-
object$handle <- xgb.Booster.handle(modelfile = object$raw, handle = object$handle)
131+
object$handle <- xgb.Booster.handle(
132+
params = list(),
133+
cachelist = list(),
134+
modelfile = object$raw,
135+
handle = object$handle
136+
)
133137
} else {
134138
if (is.null(object$raw) && saveraw) {
135139
object$raw <- xgb.serialize(object$handle)
@@ -475,7 +479,7 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
475479
#' @export
476480
predict.xgb.Booster.handle <- function(object, ...) {
477481

478-
bst <- xgb.handleToBooster(object)
482+
bst <- xgb.handleToBooster(handle = object, raw = NULL)
479483

480484
ret <- predict(bst, ...)
481485
return(ret)

R-package/R/xgb.cv.R

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,12 @@ xgb.cv <- function(params = list(), data, nrounds, nfold, label = NULL, missing
202202
dtrain <- slice(dall, unlist(folds[-k]))
203203
else
204204
dtrain <- slice(dall, train_folds[[k]])
205-
handle <- xgb.Booster.handle(params, list(dtrain, dtest))
205+
handle <- xgb.Booster.handle(
206+
params = params,
207+
cachelist = list(dtrain, dtest),
208+
modelfile = NULL,
209+
handle = NULL
210+
)
206211
list(dtrain = dtrain, bst = handle, watchlist = list(train = dtrain, test = dtest), index = folds[[k]])
207212
})
208213
rm(dall)

R-package/R/xgb.load.R

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,12 @@ xgb.load <- function(modelfile) {
3535
if (is.null(modelfile))
3636
stop("xgb.load: modelfile cannot be NULL")
3737

38-
handle <- xgb.Booster.handle(modelfile = modelfile)
38+
handle <- xgb.Booster.handle(
39+
params = list(),
40+
cachelist = list(),
41+
modelfile = modelfile,
42+
handle = NULL
43+
)
3944
# re-use modelfile if it is raw so we do not need to serialize
4045
if (typeof(modelfile) == "raw") {
4146
warning(
@@ -45,9 +50,9 @@ xgb.load <- function(modelfile) {
4550
" `xgb.unserialize` instead. "
4651
)
4752
)
48-
bst <- xgb.handleToBooster(handle, modelfile)
53+
bst <- xgb.handleToBooster(handle = handle, raw = modelfile)
4954
} else {
50-
bst <- xgb.handleToBooster(handle, NULL)
55+
bst <- xgb.handleToBooster(handle = handle, raw = NULL)
5156
}
5257
bst <- xgb.Booster.complete(bst, saveraw = TRUE)
5358
return(bst)

R-package/R/xgb.train.R

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -363,8 +363,13 @@ xgb.train <- function(params = list(), data, nrounds, watchlist = list(),
363363
is_update <- NVL(params[['process_type']], '.') == 'update'
364364

365365
# Construct a booster (either a new one or load from xgb_model)
366-
handle <- xgb.Booster.handle(params, append(watchlist, dtrain), xgb_model)
367-
bst <- xgb.handleToBooster(handle)
366+
handle <- xgb.Booster.handle(
367+
params = params,
368+
cachelist = append(watchlist, dtrain),
369+
modelfile = xgb_model,
370+
handle = NULL
371+
)
372+
bst <- xgb.handleToBooster(handle = handle, raw = NULL)
368373

369374
# extract parameters that can affect the relationship b/w #trees and #iterations
370375
num_class <- max(as.numeric(NVL(params[['num_class']], 1)), 1)

0 commit comments

Comments
 (0)