Skip to content

Commit e5bef4f

Browse files
authored
[backport] Fix threads in DMatrix slice. (dmlc#8667) (dmlc#8679)
1 parent 10bb0a7 commit e5bef4f

File tree

2 files changed

+24
-6
lines changed

2 files changed

+24
-6
lines changed

src/data/simple_dmatrix.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ DMatrix* SimpleDMatrix::Slice(common::Span<int32_t const> ridxs) {
4242
out->Info() = this->Info().Slice(ridxs);
4343
out->Info().num_nonzero_ = h_offset.back();
4444
}
45+
out->ctx_ = this->ctx_;
4546
return out;
4647
}
4748

tests/cpp/data/test_simple_dmatrix.cc

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
1-
// Copyright by Contributors
1+
/**
2+
* Copyright 2016-2023 by XGBoost Contributors
3+
*/
24
#include <xgboost/data.h>
35

4-
#include <array>
6+
#include <array> // std::array
7+
#include <limits> // std::numeric_limits
8+
#include <memory> // std::unique_ptr
59

6-
#include "../../../src/data/adapter.h"
7-
#include "../../../src/data/simple_dmatrix.h"
8-
#include "../filesystem.h" // dmlc::TemporaryDirectory
9-
#include "../helpers.h"
10+
#include "../../../src/data/adapter.h" // ArrayAdapter
11+
#include "../../../src/data/simple_dmatrix.h" // SimpleDMatrix
12+
#include "../filesystem.h" // dmlc::TemporaryDirectory
13+
#include "../helpers.h" // RandomDataGenerator,CreateSimpleTestData
1014
#include "xgboost/base.h"
15+
#include "xgboost/host_device_vector.h" // HostDeviceVector
16+
#include "xgboost/string_view.h" // StringView
1117

1218
using namespace xgboost; // NOLINT
1319

@@ -298,6 +304,17 @@ TEST(SimpleDMatrix, Slice) {
298304
ASSERT_EQ(out->Info().num_col_, out->Info().num_col_);
299305
ASSERT_EQ(out->Info().num_row_, ridxs.size());
300306
ASSERT_EQ(out->Info().num_nonzero_, ridxs.size() * kCols); // dense
307+
308+
{
309+
HostDeviceVector<float> data;
310+
auto arr_str = RandomDataGenerator{kRows, kCols, 0.0}.GenerateArrayInterface(&data);
311+
auto adapter = data::ArrayAdapter{StringView{arr_str}};
312+
auto n_threads = 2;
313+
std::unique_ptr<DMatrix> p_fmat{
314+
DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), n_threads, "")};
315+
std::unique_ptr<DMatrix> slice{p_fmat->Slice(ridxs)};
316+
ASSERT_LE(slice->Ctx()->Threads(), n_threads);
317+
}
301318
}
302319

303320
TEST(SimpleDMatrix, SaveLoadBinary) {

0 commit comments

Comments
 (0)