Skip to content

Commit e9f1abc

Browse files
authored
[R] keep row names in predictions (dmlc#10727)
1 parent adf87b2 commit e9f1abc

File tree

6 files changed

+84
-5
lines changed

6 files changed

+84
-5
lines changed

R-package/R/xgb.Booster.R

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,11 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
354354
" Should be passed as argument to 'xgb.DMatrix' constructor."
355355
)
356356
}
357+
if (is_dmatrix) {
358+
rnames <- NULL
359+
} else {
360+
rnames <- row.names(newdata)
361+
}
357362

358363
use_as_df <- FALSE
359364
use_as_dense_matrix <- FALSE
@@ -501,6 +506,19 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
501506
.Call(XGSetArrayDimNamesInplace_R, arr, dim_names)
502507
}
503508

509+
if (NROW(rnames)) {
510+
if (is.null(dim(arr))) {
511+
.Call(XGSetVectorNamesInplace_R, arr, rnames)
512+
} else {
513+
dim_names <- dimnames(arr)
514+
if (is.null(dim_names)) {
515+
dim_names <- vector(mode = "list", length = length(dim(arr)))
516+
}
517+
dim_names[[length(dim_names)]] <- rnames
518+
.Call(XGSetArrayDimNamesInplace_R, arr, dim_names)
519+
}
520+
}
521+
504522
if (!avoid_transpose && is.array(arr)) {
505523
arr <- aperm(arr)
506524
}

R-package/src/init.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ extern SEXP XGBoosterSetParam_R(SEXP, SEXP, SEXP);
4646
extern SEXP XGBoosterUpdateOneIter_R(SEXP, SEXP, SEXP);
4747
extern SEXP XGCheckNullPtr_R(SEXP);
4848
extern SEXP XGSetArrayDimNamesInplace_R(SEXP, SEXP);
49+
extern SEXP XGSetVectorNamesInplace_R(SEXP, SEXP);
4950
extern SEXP XGDMatrixCreateFromCSC_R(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP);
5051
extern SEXP XGDMatrixCreateFromCSR_R(SEXP, SEXP, SEXP, SEXP, SEXP, SEXP);
5152
extern SEXP XGDMatrixCreateFromURI_R(SEXP, SEXP, SEXP);
@@ -108,6 +109,7 @@ static const R_CallMethodDef CallEntries[] = {
108109
{"XGBoosterUpdateOneIter_R", (DL_FUNC) &XGBoosterUpdateOneIter_R, 3},
109110
{"XGCheckNullPtr_R", (DL_FUNC) &XGCheckNullPtr_R, 1},
110111
{"XGSetArrayDimNamesInplace_R", (DL_FUNC) &XGSetArrayDimNamesInplace_R, 2},
112+
{"XGSetVectorNamesInplace_R", (DL_FUNC) &XGSetVectorNamesInplace_R, 2},
111113
{"XGDMatrixCreateFromCSC_R", (DL_FUNC) &XGDMatrixCreateFromCSC_R, 6},
112114
{"XGDMatrixCreateFromCSR_R", (DL_FUNC) &XGDMatrixCreateFromCSR_R, 6},
113115
{"XGDMatrixCreateFromURI_R", (DL_FUNC) &XGDMatrixCreateFromURI_R, 3},

R-package/src/xgboost_R.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,11 @@ XGB_DLL SEXP XGSetArrayDimNamesInplace_R(SEXP arr, SEXP dim_names) {
335335
return R_NilValue;
336336
}
337337

338+
XGB_DLL SEXP XGSetVectorNamesInplace_R(SEXP arr, SEXP names) {
339+
Rf_setAttrib(arr, R_NamesSymbol, names);
340+
return R_NilValue;
341+
}
342+
338343
namespace {
339344
void _DMatrixFinalizer(SEXP ext) {
340345
R_API_BEGIN();

R-package/src/xgboost_R.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@ XGB_DLL SEXP XGCheckNullPtr_R(SEXP handle);
3434
*/
3535
XGB_DLL SEXP XGSetArrayDimNamesInplace_R(SEXP arr, SEXP dim_names);
3636

37+
/*!
38+
* \brief set the names of a vector in-place
39+
* \param arr
40+
* \param names names for the dimensions to set
41+
* \return NULL value
42+
*/
43+
XGB_DLL SEXP XGSetVectorNamesInplace_R(SEXP arr, SEXP names);
44+
3745
/*!
3846
* \brief Set global configuration
3947
* \param json_str a JSON string representing the list of key-value pairs

R-package/tests/testthat/test_basic.R

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ test_that("Can predict on data.frame objects", {
678678

679679
pred_mat <- predict(model, xgb.DMatrix(x_mat))
680680
pred_df <- predict(model, x_df)
681-
expect_equal(pred_mat, pred_df)
681+
expect_equal(pred_mat, unname(pred_df))
682682
})
683683

684684
test_that("'base_margin' gives the same result in DMatrix as in inplace_predict", {
@@ -702,7 +702,7 @@ test_that("'base_margin' gives the same result in DMatrix as in inplace_predict"
702702
pred_from_dm <- predict(model, dm_w_base)
703703
pred_from_mat <- predict(model, x, base_margin = base_margin)
704704

705-
expect_equal(pred_from_dm, pred_from_mat)
705+
expect_equal(pred_from_dm, unname(pred_from_mat))
706706
})
707707

708708
test_that("Coefficients from gblinear have the expected shape and names", {
@@ -725,7 +725,7 @@ test_that("Coefficients from gblinear have the expected shape and names", {
725725
expect_equal(names(coefs), c("(Intercept)", colnames(x)))
726726
pred_auto <- predict(model, x)
727727
pred_manual <- as.numeric(mm %*% coefs)
728-
expect_equal(pred_manual, pred_auto, tolerance = 1e-5)
728+
expect_equal(pred_manual, unname(pred_auto), tolerance = 1e-5)
729729

730730
# Multi-column coefficients
731731
data(iris)
@@ -949,3 +949,47 @@ test_that("xgb.cv works for ranking", {
949949
)
950950
expect_equal(length(res$folds), 2L)
951951
})
952+
953+
test_that("Row names are preserved in outputs", {
954+
data(iris)
955+
x <- iris[, -5]
956+
y <- as.numeric(iris$Species) - 1
957+
dm <- xgb.DMatrix(x, label = y, nthread = 1)
958+
model <- xgb.train(
959+
data = dm,
960+
params = list(
961+
objective = "multi:softprob",
962+
num_class = 3,
963+
max_depth = 2,
964+
nthread = 1
965+
),
966+
nrounds = 3
967+
)
968+
row.names(x) <- paste0("r", seq(1, nrow(x)))
969+
pred <- predict(model, x)
970+
expect_equal(row.names(pred), row.names(x))
971+
pred <- predict(model, x, avoid_transpose = TRUE)
972+
expect_equal(colnames(pred), row.names(x))
973+
974+
data(mtcars)
975+
y <- mtcars[, 1]
976+
x <- as.matrix(mtcars[, -1])
977+
dm <- xgb.DMatrix(data = x, label = y)
978+
model <- xgb.train(
979+
data = dm,
980+
params = list(
981+
max_depth = 2,
982+
nthread = 1
983+
),
984+
nrounds = 3
985+
)
986+
row.names(x) <- paste0("r", seq(1, nrow(x)))
987+
pred <- predict(model, x)
988+
expect_equal(names(pred), row.names(x))
989+
pred <- predict(model, x, avoid_transpose = TRUE)
990+
expect_equal(names(pred), row.names(x))
991+
pred <- predict(model, x, predleaf = TRUE)
992+
expect_equal(row.names(pred), row.names(x))
993+
pred <- predict(model, x, predleaf = TRUE, avoid_transpose = TRUE)
994+
expect_equal(colnames(pred), row.names(x))
995+
})

R-package/tests/testthat/test_dmatrix.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,7 @@ test_that("xgb.DMatrix: ExternalDMatrix produces the same results as regular DMa
493493
nrounds = 5
494494
)
495495
pred <- predict(model, x)
496+
pred <- unname(pred)
496497

497498
iterator_env <- as.environment(
498499
list(
@@ -538,7 +539,7 @@ test_that("xgb.DMatrix: ExternalDMatrix produces the same results as regular DMa
538539
)
539540

540541
pred_model1_edm <- predict(model, edm)
541-
pred_model2_mat <- predict(model_ext, x)
542+
pred_model2_mat <- predict(model_ext, x) |> unname()
542543
pred_model2_edm <- predict(model_ext, edm)
543544

544545
expect_equal(pred_model1_edm, pred)
@@ -567,6 +568,7 @@ test_that("xgb.DMatrix: External QDM produces same results as regular QDM", {
567568
nrounds = 5
568569
)
569570
pred <- predict(model, x)
571+
pred <- unname(pred)
570572

571573
iterator_env <- as.environment(
572574
list(
@@ -616,7 +618,7 @@ test_that("xgb.DMatrix: External QDM produces same results as regular QDM", {
616618
)
617619

618620
pred_model1_qdm <- predict(model, qdm)
619-
pred_model2_mat <- predict(model_ext, x)
621+
pred_model2_mat <- predict(model_ext, x) |> unname()
620622
pred_model2_qdm <- predict(model_ext, qdm)
621623

622624
expect_equal(pred_model1_qdm, pred)

0 commit comments

Comments
 (0)