Skip to content

Commit 7b675da

Browse files
authored
[bp] Fix rng for the column sampler. (dmlc#10998) (dmlc#11004)
1 parent 5973d60 commit 7b675da

File tree

3 files changed

+25
-9
lines changed

3 files changed

+25
-9
lines changed

src/common/random.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ class ColumnSampler {
230230
};
231231

232232
inline auto MakeColumnSampler(Context const* ctx) {
233-
std::uint32_t seed = common::GlobalRandomEngine()();
233+
std::uint32_t seed = common::GlobalRandom()();
234234
auto rc = collective::Broadcast(ctx, linalg::MakeVec(&seed, 1), 0);
235235
collective::SafeColl(rc);
236236
auto cs = std::make_shared<common::ColumnSampler>(seed);

src/tree/updater_gpu_hist.cu

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -867,12 +867,7 @@ class GPUHistMaker : public TreeUpdater {
867867
CHECK_GE(ctx_->Ordinal(), 0) << "Must have at least one device";
868868
info_ = &dmat->Info();
869869

870-
// Synchronise the column sampling seed
871-
uint32_t column_sampling_seed = common::GlobalRandom()();
872-
auto rc = collective::Broadcast(
873-
ctx_, linalg::MakeVec(&column_sampling_seed, sizeof(column_sampling_seed)), 0);
874-
SafeColl(rc);
875-
this->column_sampler_ = std::make_shared<common::ColumnSampler>(column_sampling_seed);
870+
this->column_sampler_ = common::MakeColumnSampler(ctx_);
876871

877872
auto batch_param = BatchParam{param->max_bin, TrainParam::DftSparseThreshold()};
878873
dh::safe_cuda(cudaSetDevice(ctx_->Ordinal()));
@@ -1012,8 +1007,7 @@ class GPUGlobalApproxMaker : public TreeUpdater {
10121007

10131008
monitor_.Start(__func__);
10141009
CHECK(ctx_->IsCUDA()) << error::InvalidCUDAOrdinal();
1015-
uint32_t column_sampling_seed = common::GlobalRandom()();
1016-
this->column_sampler_ = std::make_shared<common::ColumnSampler>(column_sampling_seed);
1010+
this->column_sampler_ = common::MakeColumnSampler(ctx_);
10171011

10181012
p_last_fmat_ = p_fmat;
10191013
initialised_ = true;

tests/python/test_updaters.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,28 @@ def test_exact_sample_by_node_error(self) -> None:
5454
num_boost_round=2,
5555
)
5656

57+
@pytest.mark.parametrize("tree_method", ["approx", "hist"])
58+
def test_colsample_rng(self, tree_method: str) -> None:
59+
"""Test rng has an effect on column sampling."""
60+
X, y, _ = tm.make_regression(128, 16, use_cupy=False)
61+
reg0 = xgb.XGBRegressor(
62+
n_estimators=2,
63+
colsample_bynode=0.5,
64+
random_state=42,
65+
tree_method=tree_method,
66+
)
67+
reg0.fit(X, y)
68+
69+
reg1 = xgb.XGBRegressor(
70+
n_estimators=2,
71+
colsample_bynode=0.5,
72+
random_state=43,
73+
tree_method=tree_method,
74+
)
75+
reg1.fit(X, y)
76+
77+
assert list(reg0.feature_importances_) != list(reg1.feature_importances_)
78+
5779
@given(
5880
exact_parameter_strategy,
5981
hist_parameter_strategy,

0 commit comments

Comments
 (0)