Skip to content

Commit 479ae80

Browse files
authored
[R] Add class names to coefficients (dmlc#10745)
1 parent fd0138c commit 479ae80

File tree

2 files changed

+25
-4
lines changed

2 files changed

+25
-4
lines changed

R-package/R/xgb.Booster.R

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1109,17 +1109,25 @@ coef.xgb.Booster <- function(object, ...) {
11091109
if (n_cols == 1L) {
11101110
out <- c(intercepts, coefs)
11111111
if (add_names) {
1112-
names(out) <- feature_names
1112+
.Call(XGSetVectorNamesInplace_R, out, feature_names)
11131113
}
11141114
} else {
11151115
coefs <- matrix(coefs, nrow = num_feature, byrow = TRUE)
11161116
dim(intercepts) <- c(1L, n_cols)
11171117
out <- rbind(intercepts, coefs)
1118+
out_names <- vector(mode = "list", length = 2)
11181119
if (add_names) {
1119-
row.names(out) <- feature_names
1120+
out_names[[1L]] <- feature_names
11201121
}
1121-
# TODO: if a class names attributes is added,
1122-
# should use those names here.
1122+
if (inherits(object, "xgboost")) {
1123+
metadata <- attributes(object)$metadata
1124+
if (NROW(metadata$y_levels)) {
1125+
out_names[[2L]] <- metadata$y_levels
1126+
} else if (NROW(metadata$y_names)) {
1127+
out_names[[2L]] <- metadata$y_names
1128+
}
1129+
}
1130+
.Call(XGSetArrayDimNamesInplace_R, out, out_names)
11231131
}
11241132
return(out)
11251133
}

R-package/tests/testthat/test_basic.R

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -750,6 +750,19 @@ test_that("Coefficients from gblinear have the expected shape and names", {
750750
pred_auto <- predict(model, x, outputmargin = TRUE)
751751
pred_manual <- unname(mm %*% coefs)
752752
expect_equal(pred_manual, pred_auto, tolerance = 1e-7)
753+
754+
# xgboost() with additional metadata
755+
model <- xgboost(
756+
iris[, -5],
757+
iris$Species,
758+
booster = "gblinear",
759+
objective = "multi:softprob",
760+
nrounds = 3,
761+
nthread = 1
762+
)
763+
coefs <- coef(model)
764+
expect_equal(row.names(coefs), c("(Intercept)", colnames(x)))
765+
expect_equal(colnames(coefs), levels(iris$Species))
753766
})
754767

755768
test_that("Deep copies work as expected", {

0 commit comments

Comments
 (0)