|
5 | 5 |
|
6 | 6 | #include <utility>
|
7 | 7 |
|
8 |
| -#include "../../../src/common/categorical.h" |
| 8 | +#include "../../../src/common/categorical.h" // for AsCat |
| 9 | +#include "../../../src/common/compressed_iterator.h" // for CompressedByteT |
9 | 10 | #include "../../../src/common/hist_util.h"
|
10 | 11 | #include "../../../src/data/ellpack_page.cuh"
|
11 | 12 | #include "../../../src/data/ellpack_page.h"
|
12 |
| -#include "../../../src/tree/param.h" // TrainParam |
| 13 | +#include "../../../src/data/gradient_index.h" // for GHistIndexMatrix |
| 14 | +#include "../../../src/tree/param.h" // TrainParam |
13 | 15 | #include "../helpers.h"
|
14 | 16 | #include "../histogram_helpers.h"
|
15 | 17 | #include "gtest/gtest.h"
|
@@ -91,7 +93,7 @@ TEST(EllpackPage, FromCategoricalBasic) {
|
91 | 93 | auto& h_ft = m->Info().feature_types.HostVector();
|
92 | 94 | h_ft.resize(kCols, FeatureType::kCategorical);
|
93 | 95 |
|
94 |
| - Context ctx{MakeCUDACtx(0)}; |
| 96 | + auto ctx = MakeCUDACtx(0); |
95 | 97 | auto p = BatchParam{max_bins, tree::TrainParam::DftSparseThreshold()};
|
96 | 98 | auto ellpack = EllpackPage(&ctx, m.get(), p);
|
97 | 99 | auto accessor = ellpack.Impl()->GetDeviceAccessor(FstCU());
|
@@ -122,6 +124,37 @@ TEST(EllpackPage, FromCategoricalBasic) {
|
122 | 124 | }
|
123 | 125 | }
|
124 | 126 |
|
| 127 | +TEST(EllpackPage, FromCategoricalMissing) { |
| 128 | + auto ctx = MakeCUDACtx(0); |
| 129 | + |
| 130 | + std::shared_ptr<common::HistogramCuts> cuts; |
| 131 | + auto nan = std::numeric_limits<float>::quiet_NaN(); |
| 132 | + // 2 rows and 3 columns. The second column is nan, row_stride is 2. |
| 133 | + std::vector<float> data{{0.1, nan, 1, 0.2, nan, 0}}; |
| 134 | + auto p_fmat = GetDMatrixFromData(data, 2, 3); |
| 135 | + p_fmat->Info().feature_types.HostVector() = {FeatureType::kNumerical, FeatureType::kNumerical, |
| 136 | + FeatureType::kCategorical}; |
| 137 | + p_fmat->Info().feature_types.SetDevice(ctx.Device()); |
| 138 | + |
| 139 | + auto p = BatchParam{256, tree::TrainParam::DftSparseThreshold()}; |
| 140 | + for (auto const& page : p_fmat->GetBatches<GHistIndexMatrix>(&ctx, p)) { |
| 141 | + cuts = std::make_shared<common::HistogramCuts>(page.Cuts()); |
| 142 | + } |
| 143 | + cuts->cut_ptrs_.SetDevice(ctx.Device()); |
| 144 | + cuts->cut_values_.SetDevice(ctx.Device()); |
| 145 | + cuts->min_vals_.SetDevice(ctx.Device()); |
| 146 | + for (auto const& page : p_fmat->GetBatches<EllpackPage>(&ctx, p)) { |
| 147 | + std::vector<common::CompressedByteT> h_buffer; |
| 148 | + auto h_acc = page.Impl()->GetHostAccessor(p_fmat->Info().feature_types.ConstDeviceSpan()); |
| 149 | + ASSERT_EQ(h_acc.n_rows, 2); |
| 150 | + ASSERT_EQ(h_acc.row_stride, 2); |
| 151 | + ASSERT_EQ(h_acc.gidx_iter[0], 0); |
| 152 | + ASSERT_EQ(h_acc.gidx_iter[1], 4); // cat 1 |
| 153 | + ASSERT_EQ(h_acc.gidx_iter[2], 1); |
| 154 | + ASSERT_EQ(h_acc.gidx_iter[3], 3); // cat 0 |
| 155 | + } |
| 156 | +} |
| 157 | + |
125 | 158 | struct ReadRowFunction {
|
126 | 159 | EllpackDeviceAccessor matrix;
|
127 | 160 | int row;
|
|
0 commit comments