@@ -222,18 +222,17 @@ TEST(CPUHistogram, SyncHist) {
222
222
TestSyncHist (false );
223
223
}
224
224
225
- void TestBuildHistogram (bool is_distributed, bool force_read_by_column, bool is_col_split) {
225
+ void TestBuildHistogram (Context const * ctx, bool is_distributed, bool force_read_by_column, bool is_col_split) {
226
226
size_t constexpr kNRows = 8 , kNCols = 16 ;
227
227
int32_t constexpr kMaxBins = 4 ;
228
- Context ctx;
229
228
auto p_fmat =
230
229
RandomDataGenerator (kNRows , kNCols , 0.8 ).Seed (3 ).GenerateDMatrix ();
231
230
if (is_col_split) {
232
231
p_fmat = std::shared_ptr<DMatrix>{
233
232
p_fmat->SliceCol (collective::GetWorldSize (), collective::GetRank ())};
234
233
}
235
234
auto const &gmat =
236
- *(p_fmat->GetBatches <GHistIndexMatrix>(& ctx, BatchParam{kMaxBins , 0.5 }).begin ());
235
+ *(p_fmat->GetBatches <GHistIndexMatrix>(ctx, BatchParam{kMaxBins , 0.5 }).begin ());
237
236
uint32_t total_bins = gmat.cut .Ptrs ().back ();
238
237
239
238
static double constexpr kEps = 1e-6 ;
@@ -244,7 +243,7 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_
244
243
bst_node_t nid = 0 ;
245
244
HistogramBuilder histogram;
246
245
HistMakerTrainParam hist_param;
247
- histogram.Reset (& ctx, total_bins, {kMaxBins , 0.5 }, is_distributed, is_col_split, &hist_param);
246
+ histogram.Reset (ctx, total_bins, {kMaxBins , 0.5 }, is_distributed, is_col_split, &hist_param);
248
247
249
248
RegTree tree;
250
249
@@ -262,11 +261,11 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_
262
261
histogram.AddHistRows (&tree, &nodes_to_build, &dummy_sub, false );
263
262
common::BlockedSpace2d space{
264
263
1 , [&](std::size_t nidx_in_set) { return row_set_collection[nidx_in_set].Size (); }, 256 };
265
- for (auto const &gidx : p_fmat->GetBatches <GHistIndexMatrix>(& ctx, {kMaxBins , 0.5 })) {
264
+ for (auto const &gidx : p_fmat->GetBatches <GHistIndexMatrix>(ctx, {kMaxBins , 0.5 })) {
266
265
histogram.BuildHist (0 , space, gidx, row_set_collection, nodes_to_build,
267
- linalg::MakeTensorView (& ctx, gpair, gpair.size ()), force_read_by_column);
266
+ linalg::MakeTensorView (ctx, gpair, gpair.size ()), force_read_by_column);
268
267
}
269
- histogram.SyncHistogram (& ctx, &tree, nodes_to_build, {});
268
+ histogram.SyncHistogram (ctx, &tree, nodes_to_build, {});
270
269
271
270
// Check if number of histogram bins is correct
272
271
ASSERT_EQ (histogram.Histogram ()[nid].size (), gmat.cut .Ptrs ().back ());
@@ -292,16 +291,21 @@ void TestBuildHistogram(bool is_distributed, bool force_read_by_column, bool is_
292
291
}
293
292
// Diff hunk for TEST(CPUHistogram, BuildHist). Bare numbers are the old/new
// line-number gutters of the diff view; lines starting with "- " are the
// removed version and lines starting with "+ " are the added version.
//
// Both versions call TestBuildHistogram four times with is_col_split = false,
// covering every combination of is_distributed and force_read_by_column.
// The change: the test body now owns a single Context and passes its address
// as the new leading argument, matching the updated helper signature
// (Context const* ctx, bool is_distributed, bool force_read_by_column,
// bool is_col_split).
294
293
TEST (CPUHistogram, BuildHist) {
295
- TestBuildHistogram (true , false , false );
296
- TestBuildHistogram (false , false , false );
297
- TestBuildHistogram (true , true , false );
298
- TestBuildHistogram (false , true , false );
294
// Added: one default-constructed Context shared by all four calls below.
+ Context ctx;
295
+ TestBuildHistogram (&ctx, true , false , false );
296
+ TestBuildHistogram (&ctx, false , false , false );
297
+ TestBuildHistogram (&ctx, true , true , false );
298
+ TestBuildHistogram (&ctx, false , true , false );
299
299
}
300
300
// Diff hunk for the column-split histogram test. Bare numbers are the old/new
// line-number gutters; "- " lines are removed, "+ " lines are added.
//
// Changes visible in this hunk:
//   * the test is renamed from BuildHistColSplit to BuildHistColumnSplit;
//   * the test now builds a Context whose thread-count setting is the hardware
//     thread count divided by kWorkers (never below 1), and hands that Context
//     to TestBuildHistogram;
//   * the lambdas capture by reference ([&]) so they can take &ctx.
// Both versions invoke TestBuildHistogram with is_distributed = true and
// is_col_split = true, once with force_read_by_column = true and once false.
301
- TEST (CPUHistogram, BuildHistColSplit ) {
301
+ TEST (CPUHistogram, BuildHistColumnSplit ) {
302
302
auto constexpr kWorkers = 4 ;
303
- collective::TestDistributedGlobal (kWorkers , [] { TestBuildHistogram (true , true , true ); });
304
- collective::TestDistributedGlobal (kWorkers , [] { TestBuildHistogram (true , false , true ); });
303
+ Context ctx;
304
// The unsigned result of hardware_concurrency() is stored in a signed 32-bit
// integer so the division and std::max below operate on matching int types.
+ std::int32_t n_total_threads = std::thread::hardware_concurrency ();
305
// Integer division can yield 0 (fewer hardware threads than workers, or a
// reported count of 0); std::max clamps the per-worker count to at least 1.
// Presumably this avoids oversubscribing cores when kWorkers workers run
// concurrently -- confirm against TestDistributedGlobal.
+ auto n_threads = std::max (n_total_threads / kWorkers , 1 );
306
// NOTE(review): the key literal reads " nthread" with a leading space in this
// view. That looks like an extraction artifact of the diff rendering; verify
// the real file uses "nthread", since an unrecognised key would presumably be
// accepted silently by UpdateAllowUnknown.
+ ctx.UpdateAllowUnknown (Args{{" nthread" , std::to_string (n_threads)}});
307
// ctx outlives both calls below, so capturing it by reference is safe provided
// TestDistributedGlobal finishes running the callable before returning
// (TODO confirm; the helper is defined elsewhere).
+ collective::TestDistributedGlobal (kWorkers , [&] { TestBuildHistogram (&ctx, true , true , true ); });
308
+ collective::TestDistributedGlobal (kWorkers , [&] { TestBuildHistogram (&ctx, true , false , true ); });
305
309
}
306
310
307
311
namespace {
0 commit comments