
Commit 142bdc7

[EM] Support SHAP contribution with QDM. (dmlc#10724)
- Add GPU support.
- Add external memory support.
- Update the GPU tree shap.
1 parent cb54374 commit 142bdc7

File tree: 13 files changed (+274, -159 lines)


src/predictor/gpu_predictor.cu

Lines changed: 70 additions & 67 deletions
@@ -143,10 +143,9 @@ struct SparsePageLoader {
 };
 
 struct EllpackLoader {
-  EllpackDeviceAccessor const& matrix;
-  XGBOOST_DEVICE EllpackLoader(EllpackDeviceAccessor const& m, bool, bst_feature_t, bst_idx_t,
-                               float)
-      : matrix{m} {}
+  EllpackDeviceAccessor matrix;
+  XGBOOST_DEVICE EllpackLoader(EllpackDeviceAccessor m, bool, bst_feature_t, bst_idx_t, float)
+      : matrix{std::move(m)} {}
   [[nodiscard]] XGBOOST_DEV_INLINE float GetElement(size_t ridx, size_t fidx) const {
     auto gidx = matrix.GetBinIndex<false>(ridx, fidx);
     if (gidx == -1) {
@@ -162,6 +161,8 @@ struct EllpackLoader {
     }
     return matrix.gidx_fvalue_map[gidx - 1];
   }
+  [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumCols() const { return this->matrix.NumFeatures(); }
+  [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumRows() const { return this->matrix.n_rows; }
 };
 
 template <typename Batch>
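
The two added accessors round out the read-only matrix interface that the GPU TreeShap kernels need from their dataset argument: element lookup by (row, feature) plus row and column counts. Holding the accessor by value instead of by const reference also keeps the loader usable when it is constructed from a temporary accessor inside a batch loop. As a rough, hypothetical illustration of that three-call shape (not XGBoost code), a dense row-major view could look like this:

#include <cstddef>
#include <vector>

// Hypothetical stand-in for the dataset interface used above:
// GetElement(row, col), NumRows(), NumCols().  Not part of XGBoost.
struct DenseView {
  std::vector<float> data;  // row-major; NaN marks a missing value
  std::size_t n_rows{0};
  std::size_t n_cols{0};

  float GetElement(std::size_t ridx, std::size_t fidx) const {
    return data[ridx * n_cols + fidx];
  }
  std::size_t NumRows() const { return n_rows; }
  std::size_t NumCols() const { return n_cols; }
};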
@@ -1031,9 +1032,6 @@ class GPUPredictor : public xgboost::Predictor {
     if (tree_weights != nullptr) {
       LOG(FATAL) << "Dart booster feature " << not_implemented;
     }
-    if (!p_fmat->PageExists<SparsePage>()) {
-      LOG(FATAL) << "SHAP value for QuantileDMatrix is not yet implemented for GPU.";
-    }
     CHECK(!p_fmat->Info().IsColumnSplit())
         << "Predict contribution support for column-wise data split is not yet implemented.";
     dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
@@ -1047,8 +1045,8 @@ class GPUPredictor : public xgboost::Predictor {
     // allocate space for (number of features + bias) times the number of rows
     size_t contributions_columns =
         model.learner_model_param->num_feature + 1;  // +1 for bias
-    out_contribs->Resize(p_fmat->Info().num_row_ * contributions_columns *
-                         model.learner_model_param->num_output_group);
+    auto dim_size = contributions_columns * model.learner_model_param->num_output_group;
+    out_contribs->Resize(p_fmat->Info().num_row_ * dim_size);
     out_contribs->Fill(0.0f);
     auto phis = out_contribs->DeviceSpan();
 
@@ -1058,16 +1056,27 @@ class GPUPredictor : public xgboost::Predictor {
     d_model.Init(model, 0, tree_end, ctx_->Device());
     dh::device_vector<uint32_t> categories;
     ExtractPaths(&device_paths, &d_model, &categories, ctx_->Device());
-    for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
-      batch.data.SetDevice(ctx_->Device());
-      batch.offset.SetDevice(ctx_->Device());
-      SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
-                       model.learner_model_param->num_feature);
-      auto begin = dh::tbegin(phis) + batch.base_rowid * contributions_columns;
-      gpu_treeshap::GPUTreeShap<dh::XGBDeviceAllocator<int>>(
-          X, device_paths.begin(), device_paths.end(), ngroup, begin,
-          dh::tend(phis));
+    if (p_fmat->PageExists<SparsePage>()) {
+      for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
+        batch.data.SetDevice(ctx_->Device());
+        batch.offset.SetDevice(ctx_->Device());
+        SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
+                         model.learner_model_param->num_feature);
+        auto begin = dh::tbegin(phis) + batch.base_rowid * dim_size;
+        gpu_treeshap::GPUTreeShap<dh::XGBDeviceAllocator<int>>(
+            X, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis));
+      }
+    } else {
+      for (auto& batch : p_fmat->GetBatches<EllpackPage>(ctx_, {})) {
+        EllpackDeviceAccessor acc{batch.Impl()->GetDeviceAccessor(ctx_->Device())};
+        auto X = EllpackLoader{acc, true, model.learner_model_param->num_feature, batch.Size(),
+                               std::numeric_limits<float>::quiet_NaN()};
+        auto begin = dh::tbegin(phis) + batch.BaseRowId() * dim_size;
+        gpu_treeshap::GPUTreeShap<dh::XGBDeviceAllocator<int>>(
+            X, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis));
+      }
     }
+
     // Add the base margin term to last column
     p_fmat->Info().base_margin_.SetDevice(ctx_->Device());
     const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan();
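
With the PageExists<SparsePage>() check removed above, a QuantileDMatrix or external-memory DMatrix now takes the else branch: each EllpackPage is wrapped in an EllpackLoader and handed to gpu_treeshap directly. The per-batch write offset follows from dim_size: every row owns (num_feature + 1) * num_output_group contribution values, so a batch starting at base_rowid writes from base_rowid * dim_size. A standalone sketch of that arithmetic (illustration only; the function name is hypothetical):

#include <cstddef>
#include <cstdint>

// Each row owns dim_size = (n_features + 1) * n_groups contribution values
// (the +1 is the bias column), so a batch that starts at base_rowid writes
// into the flat phis buffer beginning at base_rowid * dim_size.
std::size_t ContribBatchOffset(std::size_t base_rowid, std::uint32_t n_features,
                               std::uint32_t n_groups) {
  std::size_t dim_size = static_cast<std::size_t>(n_features + 1) * n_groups;
  return base_rowid * dim_size;
}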
@@ -1094,9 +1103,6 @@ class GPUPredictor : public xgboost::Predictor {
     if (tree_weights != nullptr) {
       LOG(FATAL) << "Dart booster feature " << not_implemented;
     }
-    if (!p_fmat->PageExists<SparsePage>()) {
-      LOG(FATAL) << "SHAP value for QuantileDMatrix is not yet implemented for GPU.";
-    }
     dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
     out_contribs->SetDevice(ctx_->Device());
     if (tree_end == 0 || tree_end > model.trees.size()) {
@@ -1108,9 +1114,9 @@ class GPUPredictor : public xgboost::Predictor {
     // allocate space for (number of features + bias) times the number of rows
     size_t contributions_columns =
         model.learner_model_param->num_feature + 1;  // +1 for bias
-    out_contribs->Resize(p_fmat->Info().num_row_ * contributions_columns *
-                         contributions_columns *
-                         model.learner_model_param->num_output_group);
+    auto dim_size =
+        contributions_columns * contributions_columns * model.learner_model_param->num_output_group;
+    out_contribs->Resize(p_fmat->Info().num_row_ * dim_size);
     out_contribs->Fill(0.0f);
     auto phis = out_contribs->DeviceSpan();
 
@@ -1120,16 +1126,29 @@ class GPUPredictor : public xgboost::Predictor {
     d_model.Init(model, 0, tree_end, ctx_->Device());
     dh::device_vector<uint32_t> categories;
     ExtractPaths(&device_paths, &d_model, &categories, ctx_->Device());
-    for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
-      batch.data.SetDevice(ctx_->Device());
-      batch.offset.SetDevice(ctx_->Device());
-      SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
-                       model.learner_model_param->num_feature);
-      auto begin = dh::tbegin(phis) + batch.base_rowid * contributions_columns;
-      gpu_treeshap::GPUTreeShapInteractions<dh::XGBDeviceAllocator<int>>(
-          X, device_paths.begin(), device_paths.end(), ngroup, begin,
-          dh::tend(phis));
+    if (p_fmat->PageExists<SparsePage>()) {
+      for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
+        batch.data.SetDevice(ctx_->Device());
+        batch.offset.SetDevice(ctx_->Device());
+        SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
+                         model.learner_model_param->num_feature);
+        auto begin = dh::tbegin(phis) + batch.base_rowid * dim_size;
+        gpu_treeshap::GPUTreeShapInteractions<dh::XGBDeviceAllocator<int>>(
+            X, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis));
+      }
+    } else {
+      for (auto const& batch : p_fmat->GetBatches<EllpackPage>(ctx_, {})) {
+        auto impl = batch.Impl();
+        auto acc =
+            impl->GetDeviceAccessor(ctx_->Device(), p_fmat->Info().feature_types.ConstDeviceSpan());
+        auto begin = dh::tbegin(phis) + batch.BaseRowId() * dim_size;
+        auto X = EllpackLoader{acc, true, model.learner_model_param->num_feature, batch.Size(),
+                               std::numeric_limits<float>::quiet_NaN()};
+        gpu_treeshap::GPUTreeShapInteractions<dh::XGBDeviceAllocator<int>>(
+            X, device_paths.begin(), device_paths.end(), ngroup, begin, dh::tend(phis));
+      }
     }
+
     // Add the base margin term to last column
     p_fmat->Info().base_margin_.SetDevice(ctx_->Device());
     const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan();
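
The interaction path mirrors the contribution path, but dim_size squares the (num_feature + 1) term: each row stores an (n_features + 1) x (n_features + 1) matrix per output group, with the extra row and column holding the bias. A small sketch of the resulting buffer size, matching the Resize() call above (illustration only):

#include <cstddef>
#include <cstdint>

// Total number of floats needed for SHAP interaction values: one
// (n_features + 1) x (n_features + 1) matrix per row and output group.
std::size_t InteractionBufferSize(std::size_t n_rows, std::uint32_t n_features,
                                  std::uint32_t n_groups) {
  std::size_t cols = static_cast<std::size_t>(n_features) + 1;
  std::size_t dim_size = cols * cols * n_groups;
  return n_rows * dim_size;
}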
@@ -1180,51 +1199,35 @@ class GPUPredictor : public xgboost::Predictor {
     bool use_shared = shared_memory_bytes != 0;
     bst_feature_t num_features = info.num_col_;
 
+    auto launch = [&](auto fn, std::uint32_t grid, auto data, bst_idx_t batch_offset) {
+      dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes}(
+          fn, data, d_model.nodes.ConstDeviceSpan(),
+          predictions->DeviceSpan().subspan(batch_offset), d_model.tree_segments.ConstDeviceSpan(),
+
+          d_model.split_types.ConstDeviceSpan(), d_model.categories_tree_segments.ConstDeviceSpan(),
+          d_model.categories_node_segments.ConstDeviceSpan(), d_model.categories.ConstDeviceSpan(),
+
+          d_model.tree_beg_, d_model.tree_end_, num_features, num_rows, use_shared,
+          std::numeric_limits<float>::quiet_NaN());
+    };
+
     if (p_fmat->PageExists<SparsePage>()) {
+      bst_idx_t batch_offset = 0;
       for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
         batch.data.SetDevice(ctx_->Device());
         batch.offset.SetDevice(ctx_->Device());
-        bst_idx_t batch_offset = 0;
         SparsePageView data{batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
                             model.learner_model_param->num_feature};
-        size_t num_rows = batch.Size();
-        auto grid =
-            static_cast<uint32_t>(common::DivRoundUp(num_rows, kBlockThreads));
-        dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes} (
-            PredictLeafKernel<SparsePageLoader, SparsePageView>, data,
-            d_model.nodes.ConstDeviceSpan(),
-            predictions->DeviceSpan().subspan(batch_offset),
-            d_model.tree_segments.ConstDeviceSpan(),
-
-            d_model.split_types.ConstDeviceSpan(),
-            d_model.categories_tree_segments.ConstDeviceSpan(),
-            d_model.categories_node_segments.ConstDeviceSpan(),
-            d_model.categories.ConstDeviceSpan(),
-
-            d_model.tree_beg_, d_model.tree_end_, num_features, num_rows,
-            use_shared, std::numeric_limits<float>::quiet_NaN());
+        auto grid = static_cast<std::uint32_t>(common::DivRoundUp(batch.Size(), kBlockThreads));
+        launch(PredictLeafKernel<SparsePageLoader, SparsePageView>, grid, data, batch_offset);
         batch_offset += batch.Size();
       }
     } else {
+      bst_idx_t batch_offset = 0;
       for (auto const& batch : p_fmat->GetBatches<EllpackPage>(ctx_, BatchParam{})) {
-        bst_idx_t batch_offset = 0;
         EllpackDeviceAccessor data{batch.Impl()->GetDeviceAccessor(ctx_->Device())};
-        size_t num_rows = batch.Size();
-        auto grid =
-            static_cast<uint32_t>(common::DivRoundUp(num_rows, kBlockThreads));
-        dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes} (
-            PredictLeafKernel<EllpackLoader, EllpackDeviceAccessor>, data,
-            d_model.nodes.ConstDeviceSpan(),
-            predictions->DeviceSpan().subspan(batch_offset),
-            d_model.tree_segments.ConstDeviceSpan(),
-
-            d_model.split_types.ConstDeviceSpan(),
-            d_model.categories_tree_segments.ConstDeviceSpan(),
-            d_model.categories_node_segments.ConstDeviceSpan(),
-            d_model.categories.ConstDeviceSpan(),
-
-            d_model.tree_beg_, d_model.tree_end_, num_features, num_rows,
-            use_shared, std::numeric_limits<float>::quiet_NaN());
+        auto grid = static_cast<std::uint32_t>(common::DivRoundUp(batch.Size(), kBlockThreads));
+        launch(PredictLeafKernel<EllpackLoader, EllpackDeviceAccessor>, grid, data, batch_offset);
         batch_offset += batch.Size();
       }
     }
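
The PredictLeaf refactor hoists the kernel launch into a generic launch lambda, so the SparsePage and EllpackPage branches now differ only in the data view they build and the PredictLeafKernel instantiation they pass; batch_offset is also moved outside the loops so it accumulates across batches. A toy, self-contained sketch of the same deduplication pattern (not XGBoost code; all names below are made up):

#include <cstdint>
#include <iostream>

// Two unrelated view types stand in for SparsePageView and
// EllpackDeviceAccessor; the generic lambda accepts either.
struct SparseView { std::uint32_t rows; };
struct DenseView  { std::uint32_t rows; };

void ProcessSparse(SparseView v) { std::cout << "sparse: " << v.rows << " rows\n"; }
void ProcessDense(DenseView v)   { std::cout << "dense: "  << v.rows << " rows\n"; }

int main() {
  // The lambda takes "which function" and "which data view" as arguments,
  // so the branches below only differ in the view they construct.
  auto launch = [](auto fn, auto view) { fn(view); };
  bool has_sparse = true;  // stands in for p_fmat->PageExists<SparsePage>()
  if (has_sparse) {
    launch(ProcessSparse, SparseView{128});
  } else {
    launch(ProcessDense, DenseView{128});
  }
  return 0;
}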

tests/cpp/data/test_simple_dmatrix.cc

Lines changed: 5 additions & 6 deletions
@@ -1,5 +1,5 @@
 /**
- * Copyright 2016-2023 by XGBoost Contributors
+ * Copyright 2016-2024, XGBoost Contributors
  */
 #include <xgboost/data.h>

@@ -434,12 +434,11 @@ namespace {
 void VerifyColumnSplit() {
   size_t constexpr kRows {16};
   size_t constexpr kCols {8};
-  auto dmat =
-      RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(false, false, 1, DataSplitMode::kCol);
+  auto p_fmat = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(false, DataSplitMode::kCol);
 
-  ASSERT_EQ(dmat->Info().num_col_, kCols * collective::GetWorldSize());
-  ASSERT_EQ(dmat->Info().num_row_, kRows);
-  ASSERT_EQ(dmat->Info().data_split_mode, DataSplitMode::kCol);
+  ASSERT_EQ(p_fmat->Info().num_col_, kCols * collective::GetWorldSize());
+  ASSERT_EQ(p_fmat->Info().num_row_, kRows);
+  ASSERT_EQ(p_fmat->Info().data_split_mode, DataSplitMode::kCol);
 }
 }  // anonymous namespace

tests/cpp/gbm/test_gbtree.cc

Lines changed: 5 additions & 5 deletions
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019-2023, XGBoost contributors
+ * Copyright 2019-2024, XGBoost contributors
  */
 #include <gtest/gtest.h>
 #include <xgboost/context.h>
@@ -463,7 +463,7 @@ INSTANTIATE_TEST_SUITE_P(PredictorTypes, Dart, testing::Values("CPU"));
 
 std::pair<Json, Json> TestModelSlice(std::string booster) {
   size_t constexpr kRows = 1000, kCols = 100, kForest = 2, kClasses = 3;
-  auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true, false, kClasses);
+  auto m = RandomDataGenerator{kRows, kCols, 0}.Classes(kClasses).GenerateDMatrix(true);
 
   int32_t kIters = 10;
   std::unique_ptr<Learner> learner {
@@ -592,7 +592,7 @@ TEST(Dart, Slice) {
 
 TEST(GBTree, FeatureScore) {
   size_t n_samples = 1000, n_features = 10, n_classes = 4;
-  auto m = RandomDataGenerator{n_samples, n_features, 0.5}.GenerateDMatrix(true, false, n_classes);
+  auto m = RandomDataGenerator{n_samples, n_features, 0.5}.Classes(n_classes).GenerateDMatrix(true);
 
   std::unique_ptr<Learner> learner{ Learner::Create({m}) };
   learner->SetParam("num_class", std::to_string(n_classes));
@@ -629,7 +629,7 @@ TEST(GBTree, FeatureScore) {
 
 TEST(GBTree, PredictRange) {
   size_t n_samples = 1000, n_features = 10, n_classes = 4;
-  auto m = RandomDataGenerator{n_samples, n_features, 0.5}.GenerateDMatrix(true, false, n_classes);
+  auto m = RandomDataGenerator{n_samples, n_features, 0.5}.Classes(n_classes).GenerateDMatrix(true);
 
   std::unique_ptr<Learner> learner{Learner::Create({m})};
   learner->SetParam("num_class", std::to_string(n_classes));
@@ -642,7 +642,7 @@ TEST(GBTree, PredictRange) {
   ASSERT_THROW(learner->Predict(m, false, &out_predt, 0, 3), dmlc::Error);
 
   auto m_1 =
-      RandomDataGenerator{n_samples, n_features, 0.5}.GenerateDMatrix(true, false, n_classes);
+      RandomDataGenerator{n_samples, n_features, 0.5}.Classes(n_classes).GenerateDMatrix(true);
   HostDeviceVector<float> out_predt_full;
   learner->Predict(m_1, false, &out_predt_full, 0, 0);
   ASSERT_TRUE(std::equal(out_predt.HostVector().begin(), out_predt.HostVector().end(),

tests/cpp/helpers.cc

Lines changed: 35 additions & 28 deletions
@@ -376,8 +376,33 @@ void RandomDataGenerator::GenerateCSR(
   CHECK_EQ(columns->Size(), value->Size());
 }
 
+namespace {
+void MakeLabels(DeviceOrd device, bst_idx_t n_samples, bst_target_t n_classes,
+                bst_target_t n_targets, std::shared_ptr<DMatrix> out) {
+  RandomDataGenerator gen{n_samples, n_targets, 0.0f};
+  if (n_classes != 0) {
+    gen.Lower(0).Upper(n_classes).GenerateDense(out->Info().labels.Data());
+    out->Info().labels.Reshape(n_samples, n_targets);
+    auto& h_labels = out->Info().labels.Data()->HostVector();
+    for (auto& v : h_labels) {
+      v = static_cast<float>(static_cast<uint32_t>(v));
+    }
+  } else {
+    gen.GenerateDense(out->Info().labels.Data());
+    CHECK_EQ(out->Info().labels.Size(), n_samples * n_targets);
+    out->Info().labels.Reshape(n_samples, n_targets);
+  }
+  if (device.IsCUDA()) {
+    out->Info().labels.Data()->SetDevice(device);
+    out->Info().labels.Data()->ConstDevicePointer();
+    out->Info().feature_types.SetDevice(device);
+    out->Info().feature_types.ConstDevicePointer();
+  }
+}
+}  // namespace
+
 [[nodiscard]] std::shared_ptr<DMatrix> RandomDataGenerator::GenerateDMatrix(
-    bool with_label, bool float_label, size_t classes, DataSplitMode data_split_mode) const {
+    bool with_label, DataSplitMode data_split_mode) const {
   HostDeviceVector<float> data;
   HostDeviceVector<std::size_t> rptrs;
   HostDeviceVector<bst_feature_t> columns;
@@ -388,19 +413,7 @@ void RandomDataGenerator::GenerateCSR(
       DMatrix::Create(&adapter, std::numeric_limits<float>::quiet_NaN(), 1, "", data_split_mode)};
 
   if (with_label) {
-    RandomDataGenerator gen{rows_, n_targets_, 0.0f};
-    if (!float_label) {
-      gen.Lower(0).Upper(classes).GenerateDense(out->Info().labels.Data());
-      out->Info().labels.Reshape(this->rows_, this->n_targets_);
-      auto& h_labels = out->Info().labels.Data()->HostVector();
-      for (auto& v : h_labels) {
-        v = static_cast<float>(static_cast<uint32_t>(v));
-      }
-    } else {
-      gen.GenerateDense(out->Info().labels.Data());
-      CHECK_EQ(out->Info().labels.Size(), this->rows_ * this->n_targets_);
-      out->Info().labels.Reshape(this->rows_, this->n_targets_);
-    }
+    MakeLabels(this->device_, this->rows_, this->n_classes_, this->n_targets_, out);
   }
   if (device_.IsCUDA()) {
     out->Info().labels.SetDevice(device_);
@@ -435,34 +448,31 @@ void RandomDataGenerator::GenerateCSR(
 #endif  // defined(XGBOOST_USE_CUDA)
   }
 
-  std::unique_ptr<DMatrix> dmat{DMatrix::Create(
+  std::shared_ptr<DMatrix> p_fmat{DMatrix::Create(
       static_cast<DataIterHandle>(iter.get()), iter->Proxy(), Reset, Next,
      std::numeric_limits<float>::quiet_NaN(), Context{}.Threads(), prefix, on_host_)};
 
   auto row_page_path =
-      data::MakeId(prefix, dynamic_cast<data::SparsePageDMatrix*>(dmat.get())) + ".row.page";
+      data::MakeId(prefix, dynamic_cast<data::SparsePageDMatrix*>(p_fmat.get())) + ".row.page";
   EXPECT_TRUE(FileExists(row_page_path)) << row_page_path;
 
   // Loop over the batches and count the number of pages
   std::size_t batch_count = 0;
   bst_idx_t row_count = 0;
-  for (const auto& batch : dmat->GetBatches<xgboost::SparsePage>()) {
+  for (const auto& batch : p_fmat->GetBatches<xgboost::SparsePage>()) {
     batch_count++;
     row_count += batch.Size();
     CHECK_NE(batch.data.Size(), 0);
   }
 
   EXPECT_EQ(batch_count, n_batches_);
-  EXPECT_EQ(dmat->NumBatches(), n_batches_);
-  EXPECT_EQ(row_count, dmat->Info().num_row_);
+  EXPECT_EQ(p_fmat->NumBatches(), n_batches_);
+  EXPECT_EQ(row_count, p_fmat->Info().num_row_);
 
   if (with_label) {
-    RandomDataGenerator{static_cast<bst_idx_t>(dmat->Info().num_row_), this->n_targets_, 0.0f}.GenerateDense(
-        dmat->Info().labels.Data());
-    CHECK_EQ(dmat->Info().labels.Size(), this->rows_ * this->n_targets_);
-    dmat->Info().labels.Reshape(this->rows_, this->n_targets_);
+    MakeLabels(this->device_, this->rows_, this->n_classes_, this->n_targets_, p_fmat);
   }
-  return dmat;
+  return p_fmat;
 }
 
 [[nodiscard]] std::shared_ptr<DMatrix> RandomDataGenerator::GenerateExtMemQuantileDMatrix(
@@ -492,10 +502,7 @@
   }
 
   if (with_label) {
-    RandomDataGenerator{static_cast<bst_idx_t>(p_fmat->Info().num_row_), this->n_targets_, 0.0f}
-        .GenerateDense(p_fmat->Info().labels.Data());
-    CHECK_EQ(p_fmat->Info().labels.Size(), this->rows_ * this->n_targets_);
-    p_fmat->Info().labels.Reshape(this->rows_, this->n_targets_);
+    MakeLabels(this->device_, this->rows_, this->n_classes_, this->n_targets_, p_fmat);
  }
  return p_fmat;
 }
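
Label generation, previously duplicated across the DMatrix generators in helpers.cc, is now centralized in the MakeLabels helper above, and the class count moves from a positional GenerateDMatrix argument to the Classes() setter used at the test_gbtree.cc call sites. A rough sketch of the behaviour MakeLabels encodes (illustration only, not the helper itself): with n_classes == 0 labels stay continuous, otherwise they are drawn from [0, n_classes) and truncated to integers.

#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

// Hypothetical stand-alone version of the label logic: n_classes == 0 means
// regression labels; otherwise uniform draws in [0, n_classes) are truncated
// to integer class ids, matching the static_cast chain in the diff above.
std::vector<float> MakeLabelsSketch(std::size_t n_samples, std::uint32_t n_classes,
                                    std::uint32_t seed = 0) {
  std::mt19937 rng{seed};
  std::vector<float> labels(n_samples);
  if (n_classes == 0) {
    std::uniform_real_distribution<float> dist{0.0f, 1.0f};
    for (auto& v : labels) { v = dist(rng); }
  } else {
    std::uniform_real_distribution<float> dist{0.0f, static_cast<float>(n_classes)};
    for (auto& v : labels) { v = static_cast<float>(static_cast<std::uint32_t>(dist(rng))); }
  }
  return labels;
}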
