Skip to content

Commit caabee2

Browse files
authored
[R] remove 'reshape' argument, let shapes be handled by core cpp library (dmlc#10330)
1 parent fd365c1 commit caabee2

File tree

13 files changed

+240
-249
lines changed

13 files changed

+240
-249
lines changed

R-package/R/callbacks.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -853,8 +853,7 @@ xgb.cb.cv.predict <- function(save_models = FALSE, outputmargin = FALSE) {
853853
pr <- predict(
854854
fd$bst,
855855
fd$evals[[2L]],
856-
outputmargin = env$outputmargin,
857-
reshape = TRUE
856+
outputmargin = env$outputmargin
858857
)
859858
if (is.null(pred)) {
860859
if (NCOL(pr) > 1L) {

R-package/R/utils.R

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,7 @@ xgb.iter.update <- function(bst, dtrain, iter, obj) {
199199
bst,
200200
dtrain,
201201
outputmargin = TRUE,
202-
training = TRUE,
203-
reshape = TRUE
202+
training = TRUE
204203
)
205204
gpair <- obj(pred, dtrain)
206205
n_samples <- dim(dtrain)[1]
@@ -246,7 +245,7 @@ xgb.iter.eval <- function(bst, evals, iter, feval) {
246245
res <- sapply(seq_along(evals), function(j) {
247246
w <- evals[[j]]
248247
## predict using all trees
249-
preds <- predict(bst, w, outputmargin = TRUE, reshape = TRUE, iterationrange = "all")
248+
preds <- predict(bst, w, outputmargin = TRUE, iterationrange = "all")
250249
eval_res <- feval(preds, w)
251250
out <- eval_res$value
252251
names(out) <- paste0(evnames[j], "-", eval_res$metric)

R-package/R/xgb.Booster.R

Lines changed: 76 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,6 @@ xgb.get.handle <- function(object) {
112112
#' @param predcontrib Whether to return feature contributions to individual predictions (see Details).
113113
#' @param approxcontrib Whether to use a fast approximation for feature contributions (see Details).
114114
#' @param predinteraction Whether to return contributions of feature interactions to individual predictions (see Details).
115-
#' @param reshape Whether to reshape the vector of predictions to matrix form when there are several
116-
#' prediction outputs per case. No effect if `predleaf`, `predcontrib`,
117-
#' or `predinteraction` is `TRUE`.
118115
#' @param training Whether the prediction result is used for training. For dart booster,
119116
#' training predicting will perform dropout.
120117
#' @param iterationrange Sequence of rounds/iterations from the model to use for prediction, specified by passing
@@ -128,8 +125,24 @@ xgb.get.handle <- function(object) {
128125
#' of the iterations (rounds) otherwise.
129126
#'
130127
#' If passing "all", will use all of the rounds regardless of whether the model had early stopping or not.
131-
#' @param strict_shape Default is `FALSE`. When set to `TRUE`, the output
132-
#' type and shape of predictions are invariant to the model type.
128+
#' @param strict_shape Whether to always return an array with the same dimensions for the given prediction mode
129+
#' regardless of the model type - meaning that, for example, both a multi-class and a binary classification
130+
#' model would generate output arrays with the same number of dimensions, with the 'class' dimension having
131+
#' size equal to '1' for the binary model.
132+
#'
133+
#' If passing `FALSE` (the default), dimensions will be simplified according to the model type, so that a
134+
#' binary classification model for example would not have a redundant dimension for 'class'.
135+
#'
136+
#' See documentation for the return type for the exact shape of the output arrays for each prediction mode.
137+
#' @param avoid_transpose Whether to output the resulting predictions in the same memory layout in which they
138+
#' are generated by the core XGBoost library, without transposing them to match the expected output shape.
139+
#'
140+
#' Internally, XGBoost uses row-major order for the predictions it generates, while R arrays use column-major
141+
#' order, hence the result needs to be transposed in order to have the expected shape when represented as
142+
#' an R array or matrix, which might be a slow operation.
143+
#'
144+
#' If passing `TRUE`, then the result will have dimensions in reverse order - for example, rows
145+
#' will be the last dimension instead of the first dimension.
133146
#' @param base_margin Base margin used for boosting from existing model.
134147
#'
135148
#' Note that, if `newdata` is an `xgb.DMatrix` object, this argument will
@@ -180,28 +193,46 @@ xgb.get.handle <- function(object) {
180193
#' Note that converting a matrix to [xgb.DMatrix()] uses multiple threads too.
181194
#'
182195
#' @return
183-
#' The return type depends on `strict_shape`. If `FALSE` (default):
184-
#' - For regression or binary classification: A vector of length `nrows(newdata)`.
185-
#' - For multiclass classification: A vector of length `num_class * nrows(newdata)` or
186-
#' a `(nrows(newdata), num_class)` matrix, depending on the `reshape` value.
187-
#' - When `predleaf = TRUE`: A matrix with one column per tree.
188-
#' - When `predcontrib = TRUE`: When not multiclass, a matrix with
189-
#' ` num_features + 1` columns. The last "+ 1" column corresponds to the baseline value.
190-
#' In the multiclass case, a list of `num_class` such matrices.
191-
#' The contribution values are on the scale of untransformed margin
192-
#' (e.g., for binary classification, the values are log-odds deviations from the baseline).
193-
#' - When `predinteraction = TRUE`: When not multiclass, the output is a 3d array of
194-
#' dimension `c(nrow, num_features + 1, num_features + 1)`. The off-diagonal (in the last two dimensions)
195-
#' elements represent different feature interaction contributions. The array is symmetric WRT the last
196-
#' two dimensions. The "+ 1" columns corresponds to the baselines. Summing this array along the last dimension should
197-
#' produce practically the same result as `predcontrib = TRUE`.
198-
#' In the multiclass case, a list of `num_class` such arrays.
199-
#'
200-
#' When `strict_shape = TRUE`, the output is always an array:
201-
#' - For normal predictions, the output has dimension `(num_class, nrow(newdata))`.
202-
#' - For `predcontrib = TRUE`, the dimension is `(ncol(newdata) + 1, num_class, nrow(newdata))`.
203-
#' - For `predinteraction = TRUE`, the dimension is `(ncol(newdata) + 1, ncol(newdata) + 1, num_class, nrow(newdata))`.
204-
#' - For `predleaf = TRUE`, the dimension is `(n_trees_in_forest, num_class, n_iterations, nrow(newdata))`.
196+
#' A numeric vector or array, with corresponding dimensions depending on the prediction mode and on
197+
#' parameter `strict_shape` as follows:
198+
#'
199+
#' If passing `strict_shape=FALSE`:\itemize{
200+
#' \item For regression or binary classification: a vector of length `nrows`.
201+
#' \item For multi-class and multi-target objectives: a matrix of dimensions `[nrows, ngroups]`.
202+
#'
203+
#' Note that objective variant `multi:softmax` defaults towards predicting most likely class (a vector
204+
#' of length `nrows`) instead of per-class probabilities.
205+
#' \item For `predleaf`: a matrix with one column per tree.
206+
#'
207+
#' For multi-class / multi-target, they will be arranged so that columns in the output will have
208+
#' the leafs from one group followed by leafs of the other group (e.g. order will be `group1:feat1`,
209+
#' `group1:feat2`, ..., `group2:feat1`, `group2:feat2`, ...).
210+
#' \item For `predcontrib`: when not multi-class / multi-target, a matrix with dimensions
211+
#' `[nrows, nfeats+1]`. The last "+ 1" column corresponds to the baseline value.
212+
#'
213+
#' For multi-class and multi-target objectives, will be an array with dimensions `[nrows, ngroups, nfeats+1]`.
214+
#'
215+
#' The contribution values are on the scale of untransformed margin (e.g., for binary classification,
216+
#' the values are log-odds deviations from the baseline).
217+
#' \item For `predinteraction`: when not multi-class / multi-target, the output is a 3D array of
218+
#' dimensions `[nrows, nfeats+1, nfeats+1]`. The off-diagonal (in the last two dimensions)
219+
#' elements represent different feature interaction contributions. The array is symmetric w.r.t. the last
220+
#' two dimensions. The "+ 1" columns correspond to the baselines. Summing this array along the last
221+
#' dimension should produce practically the same result as `predcontrib = TRUE`.
222+
#'
223+
#' For multi-class and multi-target, will be a 4D array with dimensions `[nrows, ngroups, nfeats+1, nfeats+1]`.
224+
#' }
225+
#'
226+
#' If passing `strict_shape=TRUE`, the result is always an array:\itemize{
227+
#' \item For normal predictions, the dimension is `[nrows, ngroups]`.
228+
#' \item For `predcontrib=TRUE`, the dimension is `[nrows, ngroups, nfeats+1]`.
229+
#' \item For `predinteraction=TRUE`, the dimension is `[nrows, ngroups, nfeats+1, nfeats+1]`.
230+
#' \item For `predleaf=TRUE`, the dimension is `[nrows, niter, ngroups, num_parallel_tree]`.
231+
#' }
232+
#'
233+
#' If passing `avoid_transpose=TRUE`, then the dimensions in all cases will be in reverse order - for
234+
#' example, for `predinteraction`, they will be `[nfeats+1, nfeats+1, ngroups, nrows]`
235+
#' instead of `[nrows, ngroups, nfeats+1, nfeats+1]`.
205236
#' @seealso [xgb.train()]
206237
#' @references
207238
#' 1. Scott M. Lundberg, Su-In Lee, "A Unified Approach to Interpreting Model Predictions",
@@ -279,8 +310,6 @@ xgb.get.handle <- function(object) {
279310
#' # predict for softmax returns num_class probability numbers per case:
280311
#' pred <- predict(bst, as.matrix(iris[, -5]))
281312
#' str(pred)
282-
#' # reshape it to a num_class-columns matrix
283-
#' pred <- matrix(pred, ncol = num_class, byrow = TRUE)
284313
#' # convert the probabilities to softmax labels
285314
#' pred_labels <- max.col(pred) - 1
286315
#' # the following should result in the same error as seen in the last iteration
@@ -311,8 +340,11 @@ xgb.get.handle <- function(object) {
311340
#' @export
312341
predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FALSE,
313342
predleaf = FALSE, predcontrib = FALSE, approxcontrib = FALSE, predinteraction = FALSE,
314-
reshape = FALSE, training = FALSE, iterationrange = NULL, strict_shape = FALSE,
343+
training = FALSE, iterationrange = NULL, strict_shape = FALSE, avoid_transpose = FALSE,
315344
validate_features = FALSE, base_margin = NULL, ...) {
345+
if (NROW(list(...))) {
346+
warning("Passed unused prediction arguments: ", paste(names(list(...)), collapse = ", "), ".")
347+
}
316348
if (validate_features) {
317349
newdata <- validate.features(object, newdata)
318350
}
@@ -415,10 +447,9 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
415447
return(val)
416448
}
417449

418-
## We set strict_shape to TRUE then drop the dimensions conditionally
419450
args <- list(
420451
training = box(training),
421-
strict_shape = box(TRUE),
452+
strict_shape = as.logical(strict_shape),
422453
iteration_begin = box(as.integer(iterationrange[1])),
423454
iteration_end = box(as.integer(iterationrange[2])),
424455
type = box(as.integer(0))
@@ -445,96 +476,36 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
445476

446477
json_conf <- jsonlite::toJSON(args, auto_unbox = TRUE)
447478
if (is_dmatrix) {
448-
predts <- .Call(
479+
arr <- .Call(
449480
XGBoosterPredictFromDMatrix_R, xgb.get.handle(object), newdata, json_conf
450481
)
451482
} else if (use_as_dense_matrix) {
452-
predts <- .Call(
483+
arr <- .Call(
453484
XGBoosterPredictFromDense_R, xgb.get.handle(object), newdata, missing, json_conf, base_margin
454485
)
455486
} else if (use_as_csr_matrix) {
456-
predts <- .Call(
487+
arr <- .Call(
457488
XGBoosterPredictFromCSR_R, xgb.get.handle(object), csr_data, missing, json_conf, base_margin
458489
)
459490
} else if (use_as_df) {
460-
predts <- .Call(
491+
arr <- .Call(
461492
XGBoosterPredictFromColumnar_R, xgb.get.handle(object), newdata, missing, json_conf, base_margin
462493
)
463494
}
464495

465-
names(predts) <- c("shape", "results")
466-
shape <- predts$shape
467-
arr <- predts$results
468-
469-
n_ret <- length(arr)
470-
if (n_row != shape[1]) {
471-
stop("Incorrect predict shape.")
472-
}
473-
474-
.Call(XGSetArrayDimInplace_R, arr, rev(shape))
475-
476-
cnames <- if (!is.null(colnames(newdata))) c(colnames(newdata), "(Intercept)") else NULL
477-
n_groups <- shape[2]
478-
479496
## Needed regardless of whether strict shape is being used.
480-
if (predcontrib) {
481-
.Call(XGSetArrayDimNamesInplace_R, arr, list(cnames, NULL, NULL))
482-
} else if (predinteraction) {
483-
.Call(XGSetArrayDimNamesInplace_R, arr, list(cnames, cnames, NULL, NULL))
484-
}
485-
if (strict_shape) {
486-
return(arr) # strict shape is calculated by libxgboost uniformly.
497+
if ((predcontrib || predinteraction) && !is.null(colnames(newdata))) {
498+
cnames <- c(colnames(newdata), "(Intercept)")
499+
dim_names <- vector(mode = "list", length = length(dim(arr)))
500+
dim_names[[1L]] <- cnames
501+
if (predinteraction) dim_names[[2L]] <- cnames
502+
.Call(XGSetArrayDimNamesInplace_R, arr, dim_names)
487503
}
488504

489-
if (predleaf) {
490-
## Predict leaf
491-
if (n_ret == n_row) {
492-
.Call(XGSetArrayDimInplace_R, arr, c(n_row, 1L))
493-
} else {
494-
arr <- matrix(arr, nrow = n_row, byrow = TRUE)
495-
}
496-
} else if (predcontrib) {
497-
## Predict contribution
498-
arr <- aperm(a = arr, perm = c(2, 3, 1)) # [group, row, col]
499-
if (n_ret == n_row) {
500-
.Call(XGSetArrayDimInplace_R, arr, c(n_row, 1L))
501-
.Call(XGSetArrayDimNamesInplace_R, arr, list(NULL, cnames))
502-
} else if (n_groups != 1) {
503-
## turns array into list of matrices
504-
arr <- lapply(seq_len(n_groups), function(g) arr[g, , ])
505-
} else {
506-
## remove the first axis (group)
507-
newdim <- dim(arr)[2:3]
508-
newdn <- dimnames(arr)[2:3]
509-
arr <- arr[1, , ]
510-
.Call(XGSetArrayDimInplace_R, arr, newdim)
511-
.Call(XGSetArrayDimNamesInplace_R, arr, newdn)
512-
}
513-
} else if (predinteraction) {
514-
## Predict interaction
515-
arr <- aperm(a = arr, perm = c(3, 4, 1, 2)) # [group, row, col, col]
516-
if (n_ret == n_row) {
517-
.Call(XGSetArrayDimInplace_R, arr, c(n_row, 1L))
518-
.Call(XGSetArrayDimNamesInplace_R, arr, list(NULL, cnames))
519-
} else if (n_groups != 1) {
520-
## turns array into list of matrices
521-
arr <- lapply(seq_len(n_groups), function(g) arr[g, , , ])
522-
} else {
523-
## remove the first axis (group)
524-
arr <- arr[1, , , , drop = FALSE]
525-
newdim <- dim(arr)[2:4]
526-
newdn <- dimnames(arr)[2:4]
527-
.Call(XGSetArrayDimInplace_R, arr, newdim)
528-
.Call(XGSetArrayDimNamesInplace_R, arr, newdn)
529-
}
530-
} else {
531-
## Normal prediction
532-
if (reshape && n_groups != 1) {
533-
arr <- matrix(arr, ncol = n_groups, byrow = TRUE)
534-
} else {
535-
.Call(XGSetArrayDimInplace_R, arr, NULL)
536-
}
505+
if (!avoid_transpose && is.array(arr)) {
506+
arr <- aperm(arr)
537507
}
508+
538509
return(arr)
539510
}
540511

R-package/R/xgb.plot.shap.R

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -294,8 +294,10 @@ xgb.shap.data <- function(data, shap_contrib = NULL, features = NULL, top_n = 1,
294294
if (is.null(features) && (is.null(model) || !inherits(model, "xgb.Booster")))
295295
stop("when features are not provided, one must provide an xgb.Booster model to rank the features")
296296

297+
last_dim <- function(v) dim(v)[length(dim(v))]
298+
297299
if (!is.null(shap_contrib) &&
298-
(!is.matrix(shap_contrib) || nrow(shap_contrib) != nrow(data) || ncol(shap_contrib) != ncol(data) + 1))
300+
(!is.array(shap_contrib) || nrow(shap_contrib) != nrow(data) || last_dim(shap_contrib) != ncol(data) + 1))
299301
stop("shap_contrib is not compatible with the provided data")
300302

301303
if (is.character(features) && is.null(colnames(data)))
@@ -318,19 +320,39 @@ xgb.shap.data <- function(data, shap_contrib = NULL, features = NULL, top_n = 1,
318320
colnames(data) <- paste0("X", seq_len(ncol(data)))
319321
}
320322

321-
if (!is.null(shap_contrib)) {
322-
if (is.list(shap_contrib)) { # multiclass: either choose a class or merge
323-
shap_contrib <- if (!is.null(target_class)) shap_contrib[[target_class + 1]] else Reduce("+", lapply(shap_contrib, abs))
324-
}
325-
shap_contrib <- shap_contrib[idx, ]
326-
if (is.null(colnames(shap_contrib))) {
327-
colnames(shap_contrib) <- paste0("X", seq_len(ncol(data)))
328-
}
329-
} else {
330-
shap_contrib <- predict(model, newdata = data, predcontrib = TRUE, approxcontrib = approxcontrib)
331-
if (is.list(shap_contrib)) { # multiclass: either choose a class or merge
332-
shap_contrib <- if (!is.null(target_class)) shap_contrib[[target_class + 1]] else Reduce("+", lapply(shap_contrib, abs))
323+
reshape_3d_shap_contrib <- function(shap_contrib, target_class) {
324+
# multiclass: either choose a class or merge
325+
if (is.list(shap_contrib)) {
326+
if (!is.null(target_class)) {
327+
shap_contrib <- shap_contrib[[target_class + 1]]
328+
} else {
329+
shap_contrib <- Reduce("+", lapply(shap_contrib, abs))
330+
}
331+
} else if (length(dim(shap_contrib)) > 2) {
332+
if (!is.null(target_class)) {
333+
orig_shape <- dim(shap_contrib)
334+
shap_contrib <- shap_contrib[, target_class + 1, , drop = TRUE]
335+
if (!is.matrix(shap_contrib)) {
336+
shap_contrib <- matrix(shap_contrib, orig_shape[c(1L, 3L)])
337+
}
338+
} else {
339+
shap_contrib <- apply(abs(shap_contrib), c(1L, 3L), sum)
340+
}
333341
}
342+
return(shap_contrib)
343+
}
344+
345+
if (is.null(shap_contrib)) {
346+
shap_contrib <- predict(
347+
model,
348+
newdata = data,
349+
predcontrib = TRUE,
350+
approxcontrib = approxcontrib
351+
)
352+
}
353+
shap_contrib <- reshape_3d_shap_contrib(shap_contrib, target_class)
354+
if (is.null(colnames(shap_contrib))) {
355+
colnames(shap_contrib) <- paste0("X", seq_len(ncol(data)))
334356
}
335357

336358
if (is.null(features)) {

0 commit comments

Comments
 (0)