Skip to content

Commit 6b4d703

Browse files
authored
[BP] Fix potential race in feature constraint. (dmlc#10719) (dmlc#10900)
1 parent a4c6cde commit 6b4d703

File tree

3 files changed

+22
-20
lines changed

3 files changed

+22
-20
lines changed

src/common/bitfield.h

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,11 @@ struct BitFieldContainer {
108108
#if defined(__CUDA_ARCH__)
109109
__device__ BitFieldContainer& operator|=(BitFieldContainer const& rhs) {
110110
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
111-
size_t min_size = min(NumValues(), rhs.NumValues());
111+
std::size_t min_size = std::min(this->Capacity(), rhs.Capacity());
112112
if (tid < min_size) {
113-
Data()[tid] |= rhs.Data()[tid];
113+
if (this->Check(tid) || rhs.Check(tid)) {
114+
this->Set(tid);
115+
}
114116
}
115117
return *this;
116118
}
@@ -126,16 +128,20 @@ struct BitFieldContainer {
126128

127129
#if defined(__CUDA_ARCH__)
128130
__device__ BitFieldContainer& operator&=(BitFieldContainer const& rhs) {
129-
size_t min_size = min(NumValues(), rhs.NumValues());
130131
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
132+
std::size_t min_size = std::min(this->Capacity(), rhs.Capacity());
131133
if (tid < min_size) {
132-
Data()[tid] &= rhs.Data()[tid];
134+
if (this->Check(tid) && rhs.Check(tid)) {
135+
this->Set(tid);
136+
} else {
137+
this->Clear(tid);
138+
}
133139
}
134140
return *this;
135141
}
136142
#else
137143
BitFieldContainer& operator&=(BitFieldContainer const& rhs) {
138-
size_t min_size = std::min(NumValues(), rhs.NumValues());
144+
std::size_t min_size = std::min(NumValues(), rhs.NumValues());
139145
for (size_t i = 0; i < min_size; ++i) {
140146
Data()[i] &= rhs.Data()[i];
141147
}

src/tree/constraints.cu

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
#include <thrust/execution_policy.h>
77
#include <thrust/iterator/counting_iterator.h>
88

9-
#include <algorithm>
109
#include <string>
1110
#include <set>
1211

@@ -279,10 +278,6 @@ __global__ void InteractionConstraintSplitKernel(LBitField64 feature,
279278
}
280279
// enable constraints from feature
281280
node |= feature;
282-
// clear the buffer after use
283-
if (tid < feature.Capacity()) {
284-
feature.Clear(tid);
285-
}
286281

287282
// enable constraints from parent
288283
left |= node;
@@ -304,7 +299,7 @@ void FeatureInteractionConstraintDevice::Split(
304299
<< " Split node: " << node_id << " and its left child: "
305300
<< left_id << " cannot be the same.";
306301
CHECK_NE(node_id, right_id)
307-
<< " Split node: " << node_id << " and its left child: "
302+
<< " Split node: " << node_id << " and its right child: "
308303
<< right_id << " cannot be the same.";
309304
CHECK_LT(right_id, s_node_constraints_.size());
310305
CHECK_NE(s_node_constraints_.size(), 0);
@@ -330,6 +325,8 @@ void FeatureInteractionConstraintDevice::Split(
330325
feature_buffer_,
331326
feature_id,
332327
node, left, right);
333-
}
334328

329+
// clear the buffer after use
330+
thrust::fill_n(thrust::device, feature_buffer_.Data(), feature_buffer_.NumValues(), 0);
331+
}
335332
} // namespace xgboost

tests/cpp/tree/test_constraints.cu

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
/**
2-
* Copyright 2019-2023, XGBoost contributors
2+
* Copyright 2019-2024, XGBoost contributors
33
*/
44
#include <gtest/gtest.h>
55
#include <thrust/copy.h>
66
#include <thrust/device_vector.h>
7-
#include <cinttypes>
8-
#include <string>
9-
#include <bitset>
7+
8+
#include <cstdint>
109
#include <set>
10+
#include <string>
11+
12+
#include "../../../src/common/device_helpers.cuh"
1113
#include "../../../src/tree/constraints.cuh"
1214
#include "../../../src/tree/param.h"
13-
#include "../../../src/common/device_helpers.cuh"
1415

1516
namespace xgboost {
1617
namespace {
@@ -36,9 +37,7 @@ std::string GetConstraintsStr() {
3637
}
3738

3839
tree::TrainParam GetParameter() {
39-
std::vector<std::pair<std::string, std::string>> args{
40-
{"interaction_constraints", GetConstraintsStr()}
41-
};
40+
Args args{{"interaction_constraints", GetConstraintsStr()}};
4241
tree::TrainParam param;
4342
param.Init(args);
4443
return param;

0 commit comments

Comments
 (0)