Skip to content

Commit e50ccc4

Browse files
authored
[R] Fix integer inputs with NA. (dmlc#9522) (dmlc#9534)
1 parent add57f8 commit e50ccc4

File tree

2 files changed

+55
-5
lines changed

2 files changed

+55
-5
lines changed

R-package/src/xgboost_R.cc

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,11 +120,25 @@ XGB_DLL SEXP XGDMatrixCreateFromMat_R(SEXP mat, SEXP missing, SEXP n_threads) {
120120
ctx.nthread = asInteger(n_threads);
121121
std::int32_t threads = ctx.Threads();
122122

123-
xgboost::common::ParallelFor(nrow, threads, [&](xgboost::omp_ulong i) {
124-
for (size_t j = 0; j < ncol; ++j) {
125-
data[i * ncol + j] = is_int ? static_cast<float>(iin[i + nrow * j]) : din[i + nrow * j];
126-
}
127-
});
123+
if (is_int) {
124+
xgboost::common::ParallelFor(nrow, threads, [&](xgboost::omp_ulong i) {
125+
for (size_t j = 0; j < ncol; ++j) {
126+
auto v = iin[i + nrow * j];
127+
if (v == NA_INTEGER) {
128+
data[i * ncol + j] = std::numeric_limits<float>::quiet_NaN();
129+
} else {
130+
data[i * ncol + j] = static_cast<float>(v);
131+
}
132+
}
133+
});
134+
} else {
135+
xgboost::common::ParallelFor(nrow, threads, [&](xgboost::omp_ulong i) {
136+
for (size_t j = 0; j < ncol; ++j) {
137+
data[i * ncol + j] = din[i + nrow * j];
138+
}
139+
});
140+
}
141+
128142
DMatrixHandle handle;
129143
CHECK_CALL(XGDMatrixCreateFromMat_omp(BeginPtr(data), nrow, ncol,
130144
asReal(missing), &handle, threads));

R-package/tests/testthat/test_dmatrix.R

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,42 @@ test_that("xgb.DMatrix: basic construction", {
5656
expect_equal(raw_fd, raw_dgc)
5757
})
5858

59+
test_that("xgb.DMatrix: NA", {
60+
n_samples <- 3
61+
x <- cbind(
62+
x1 = sample(x = 4, size = n_samples, replace = TRUE),
63+
x2 = sample(x = 4, size = n_samples, replace = TRUE)
64+
)
65+
x[1, "x1"] <- NA
66+
67+
m <- xgb.DMatrix(x)
68+
xgb.DMatrix.save(m, "int.dmatrix")
69+
70+
x <- matrix(as.numeric(x), nrow = n_samples, ncol = 2)
71+
colnames(x) <- c("x1", "x2")
72+
m <- xgb.DMatrix(x)
73+
74+
xgb.DMatrix.save(m, "float.dmatrix")
75+
76+
iconn <- file("int.dmatrix", "rb")
77+
fconn <- file("float.dmatrix", "rb")
78+
79+
expect_equal(file.size("int.dmatrix"), file.size("float.dmatrix"))
80+
81+
bytes <- file.size("int.dmatrix")
82+
idmatrix <- readBin(iconn, "raw", n = bytes)
83+
fdmatrix <- readBin(fconn, "raw", n = bytes)
84+
85+
expect_equal(length(idmatrix), length(fdmatrix))
86+
expect_equal(idmatrix, fdmatrix)
87+
88+
close(iconn)
89+
close(fconn)
90+
91+
file.remove("int.dmatrix")
92+
file.remove("float.dmatrix")
93+
})
94+
5995
test_that("xgb.DMatrix: saving, loading", {
6096
# save to a local file
6197
dtest1 <- xgb.DMatrix(test_data, label = test_label)

0 commit comments

Comments
 (0)