Skip to content

Commit d6059f4

Browse files
authored
[BP] Test ellpack categorical feature with missing values. (dmlc#10906) (dmlc#10912)
1 parent 40742a9 commit d6059f4

File tree

2 files changed

+37
-4
lines changed

2 files changed

+37
-4
lines changed

src/data/ellpack_page.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ __global__ void CompressBinEllpackKernel(
7070
// {feature_cuts, ncuts} forms the array of cuts of `feature`.
7171
const float* feature_cuts = &cuts[cut_ptrs[feature]];
7272
int ncuts = cut_ptrs[feature + 1] - cut_ptrs[feature];
73-
bool is_cat = common::IsCat(feature_types, ifeature);
73+
bool is_cat = common::IsCat(feature_types, feature);
7474
// Assigning the bin in current entry.
7575
// S.t.: fvalue < feature_cuts[bin]
7676
if (is_cat) {

tests/cpp/data/test_ellpack_page.cu

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55

66
#include <utility>
77

8-
#include "../../../src/common/categorical.h"
8+
#include "../../../src/common/categorical.h" // for AsCat
9+
#include "../../../src/common/compressed_iterator.h" // for CompressedByteT
910
#include "../../../src/common/hist_util.h"
1011
#include "../../../src/data/ellpack_page.cuh"
1112
#include "../../../src/data/ellpack_page.h"
12-
#include "../../../src/tree/param.h" // TrainParam
13+
#include "../../../src/data/gradient_index.h" // for GHistIndexMatrix
14+
#include "../../../src/tree/param.h" // TrainParam
1315
#include "../helpers.h"
1416
#include "../histogram_helpers.h"
1517
#include "gtest/gtest.h"
@@ -91,7 +93,7 @@ TEST(EllpackPage, FromCategoricalBasic) {
9193
auto& h_ft = m->Info().feature_types.HostVector();
9294
h_ft.resize(kCols, FeatureType::kCategorical);
9395

94-
Context ctx{MakeCUDACtx(0)};
96+
auto ctx = MakeCUDACtx(0);
9597
auto p = BatchParam{max_bins, tree::TrainParam::DftSparseThreshold()};
9698
auto ellpack = EllpackPage(&ctx, m.get(), p);
9799
auto accessor = ellpack.Impl()->GetDeviceAccessor(FstCU());
@@ -122,6 +124,37 @@ TEST(EllpackPage, FromCategoricalBasic) {
122124
}
123125
}
124126

127+
TEST(EllpackPage, FromCategoricalMissing) {
128+
auto ctx = MakeCUDACtx(0);
129+
130+
std::shared_ptr<common::HistogramCuts> cuts;
131+
auto nan = std::numeric_limits<float>::quiet_NaN();
132+
// 2 rows and 3 columns. The second column is entirely NaN (missing), so row_stride is 2.
133+
std::vector<float> data{{0.1, nan, 1, 0.2, nan, 0}};
134+
auto p_fmat = GetDMatrixFromData(data, 2, 3);
135+
p_fmat->Info().feature_types.HostVector() = {FeatureType::kNumerical, FeatureType::kNumerical,
136+
FeatureType::kCategorical};
137+
p_fmat->Info().feature_types.SetDevice(ctx.Device());
138+
139+
auto p = BatchParam{256, tree::TrainParam::DftSparseThreshold()};
140+
for (auto const& page : p_fmat->GetBatches<GHistIndexMatrix>(&ctx, p)) {
141+
cuts = std::make_shared<common::HistogramCuts>(page.Cuts());
142+
}
143+
cuts->cut_ptrs_.SetDevice(ctx.Device());
144+
cuts->cut_values_.SetDevice(ctx.Device());
145+
cuts->min_vals_.SetDevice(ctx.Device());
146+
for (auto const& page : p_fmat->GetBatches<EllpackPage>(&ctx, p)) {
147+
std::vector<common::CompressedByteT> h_buffer;
148+
auto h_acc = page.Impl()->GetHostAccessor(p_fmat->Info().feature_types.ConstDeviceSpan());
149+
ASSERT_EQ(h_acc.n_rows, 2);
150+
ASSERT_EQ(h_acc.row_stride, 2);
151+
ASSERT_EQ(h_acc.gidx_iter[0], 0);
152+
ASSERT_EQ(h_acc.gidx_iter[1], 4); // cat 1
153+
ASSERT_EQ(h_acc.gidx_iter[2], 1);
154+
ASSERT_EQ(h_acc.gidx_iter[3], 3); // cat 0
155+
}
156+
}
157+
125158
struct ReadRowFunction {
126159
EllpackDeviceAccessor matrix;
127160
int row;

0 commit comments

Comments
 (0)