
Commit ffb7996

Author: dmitry.razdoburdin
Message: Fixing compile errors. Oneapi tests passed
1 parent 9d4edbe · commit ffb7996

File tree: 6 files changed (+311, −17 lines)


plugin/updater_oneapi/updater_quantile_hist_oneapi.cc

Lines changed: 14 additions & 17 deletions
@@ -73,8 +73,7 @@ void GPUQuantileHistMakerOneAPI::Configure(const Args& args) {
   if (param.device_id != GenericParameter::kDefaultId) {
     qu_ = sycl::queue(devices[param.device_id]);
   } else {
-    sycl::default_selector selector;
-    qu_ = sycl::queue(selector);
+    qu_ = sycl::queue(sycl::default_selector_v);
   }
 
   // initialize pruner
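Note: this hunk is the compile-error fix named in the commit message. SYCL 2020 replaced the selector classes of SYCL 1.2.1 (such as sycl::default_selector) with selector values like sycl::default_selector_v, and recent DPC++ releases no longer ship the old classes. A minimal sketch of the two styles, assuming a SYCL 2020 toolchain (e.g. icpx -fsycl):

#include <sycl/sycl.hpp>

int main() {
  // Pre-SYCL 2020 style, removed from newer DPC++ releases:
  //   sycl::default_selector selector;
  //   sycl::queue q{selector};

  // SYCL 2020 style, as adopted in the hunk above:
  sycl::queue q{sycl::default_selector_v};
  return 0;
}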
@@ -107,9 +106,10 @@ template<typename GradientSumT>
 void GPUQuantileHistMakerOneAPI::CallBuilderUpdate(const std::unique_ptr<Builder<GradientSumT>>& builder,
                                                    HostDeviceVector<GradientPair> *gpair,
                                                    DMatrix *dmat,
+                                                   common::Span<HostDeviceVector<bst_node_t>> out_position,
                                                    const std::vector<RegTree *> &trees) {
   for (auto tree : trees) {
-    builder->Update(gmat_, gpair, dmat, tree);
+    builder->Update(gmat_, gpair, dmat, out_position, tree);
   }
 }
 void GPUQuantileHistMakerOneAPI::Update(HostDeviceVector<GradientPair> *gpair,
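Note: the new out_position parameter appears to keep the plugin in sync with the core TreeUpdater::Update interface in upstream XGBoost, which gained an out_position span (one HostDeviceVector<bst_node_t> per output tree, recording each row's final leaf position). The plugin threads it through CallBuilderUpdate to each builder and, further down, to the pruner.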
@@ -136,12 +136,12 @@ void GPUQuantileHistMakerOneAPI::Update(HostDeviceVector<GradientPair> *gpair,
     if (!float_builder_) {
       SetBuilder(&float_builder_, dmat);
     }
-    CallBuilderUpdate(float_builder_, gpair, dmat, trees);
+    CallBuilderUpdate(float_builder_, gpair, dmat, out_position, trees);
   } else {
     if (!double_builder_) {
       SetBuilder(&double_builder_, dmat);
     }
-    CallBuilderUpdate(double_builder_, gpair, dmat, trees);
+    CallBuilderUpdate(double_builder_, gpair, dmat, out_position, trees);
   }
 
   param_.learning_rate = lr;
@@ -245,7 +245,9 @@ void GPUQuantileHistMakerOneAPI::Builder<GradientSumT>::ReduceHists(std::vector<
     const GradientPairT* psrc = reinterpret_cast<const GradientPairT*>(this_hist.DataConst());
     std::copy(psrc, psrc + nbins, reduce_buffer.begin() + i * nbins);
   }
-  collective::Allreduce<collective::Operation::kSum>(reduce_buffer.data(), nbins * sync_ids.size());
+  collective::Allreduce<collective::Operation::kSum>(
+      reinterpret_cast<GradientSumT*>(reduce_buffer.data()),
+      2 * nbins * sync_ids.size());
   // histred_.Allreduce(reduce_buffer.data(), nbins * sync_ids.size());
   for (size_t i = 0; i < sync_ids.size(); i++) {
     auto this_hist = hist_[sync_ids[i]];
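Note: collective::Allreduce sums flat arrays of arithmetic types, so the buffer of gradient pairs is reinterpreted as 2 * n contiguous GradientSumT scalars (gradient and hessian interleaved). A self-contained sketch of the layout assumption, with a local stand-in for the collective (the names below are illustrative, not XGBoost's own):

#include <cstddef>
#include <iostream>
#include <vector>

// Illustrative pair type: two scalars with no padding in between,
// so n pairs are bit-identical to 2*n consecutive scalars.
template <typename T>
struct GradientPairT {
  T grad;
  T hess;
};

// Stand-in for an element-wise sum reduction over a flat scalar array.
template <typename T>
void SumInPlace(T* acc, const T* other, std::size_t count) {
  for (std::size_t i = 0; i < count; ++i) acc[i] += other[i];
}

int main() {
  std::vector<GradientPairT<float>> mine{{1.f, 2.f}, {3.f, 4.f}};
  std::vector<GradientPairT<float>> peer{{10.f, 20.f}, {30.f, 40.f}};
  // Same trick as the hunk above: view n pairs as 2*n floats.
  SumInPlace(reinterpret_cast<float*>(mine.data()),
             reinterpret_cast<const float*>(peer.data()),
             2 * mine.size());
  for (const auto& p : mine) std::cout << p.grad << " " << p.hess << "\n";
  // Prints: 11 22 / 33 44
  return 0;
}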
@@ -604,6 +606,7 @@ void GPUQuantileHistMakerOneAPI::Builder<GradientSumT>::Update(
     const GHistIndexMatrixOneAPI &gmat,
     HostDeviceVector<GradientPair> *gpair,
     DMatrix *p_fmat,
+    common::Span<HostDeviceVector<bst_node_t>> out_position,
     RegTree *p_tree) {
   builder_monitor_.Start("Update");
 
@@ -626,7 +629,7 @@ void GPUQuantileHistMakerOneAPI::Builder<GradientSumT>::Update(
     p_tree->Stat(nid).base_weight = snode_[nid].weight;
     p_tree->Stat(nid).sum_hess = static_cast<float>(snode_[nid].stats.GetHess());
   }
-  pruner_->Update(gpair, p_fmat, std::vector<RegTree*>{p_tree});
+  pruner_->Update(gpair, p_fmat, out_position, std::vector<RegTree*>{p_tree});
 
   builder_monitor_.Stop("Update");
 }
@@ -852,15 +855,9 @@ void GPUQuantileHistMakerOneAPI::Builder<GradientSumT>::InitData(const GHistInde
   }
   // store a pointer to the tree
   p_last_tree_ = &tree;
-  if (data_layout_ == kDenseDataOneBased) {
-    column_sampler_.Init(info.num_col_, info.feature_weights.ConstHostVector(),
-                         param_.colsample_bynode, param_.colsample_bylevel,
-                         param_.colsample_bytree, true);
-  } else {
-    column_sampler_.Init(info.num_col_, info.feature_weights.ConstHostVector(),
-                         param_.colsample_bynode, param_.colsample_bylevel,
-                         param_.colsample_bytree, false);
-  }
+  column_sampler_.Init(info.num_col_, info.feature_weights.ConstHostVector(),
+                       param_.colsample_bynode, param_.colsample_bylevel,
+                       param_.colsample_bytree);
   if (data_layout_ == kDenseDataZeroBased || data_layout_ == kDenseDataOneBased) {
     /* specialized code for dense data:
        choose the column that has a least positive number of discrete bins.
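Note: the two branches collapse into one call because ColumnSampler::Init in upstream XGBoost apparently no longer takes the trailing bool that told the sampler to skip feature index 0 for one-based dense layouts; with that parameter gone, both layouts share the same initialization.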
@@ -1352,7 +1349,7 @@ void GPUQuantileHistMakerOneAPI::Builder<GradientSumT>::InitNewNode(int nid,
       grad_stat.Add(gpair[*it].GetGrad(), gpair[*it].GetHess());
     }
   }
-  collective::Allreduce<collective::Operation::kSum>(&grad_stat, 1);
+  collective::Allreduce<collective::Operation::kSum>(reinterpret_cast<GradientSumT*>(&grad_stat), 2);
   // histred_.Allreduce(&grad_stat, 1);
   snode_[nid].stats = GradStatsOneAPI<GradientSumT>(grad_stat.GetGrad(), grad_stat.GetHess());
 } else {
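Note: same flattening as in ReduceHists above; a single gradient-statistics pair is reduced as two GradientSumT scalars rather than one struct, since the collective operates on arrays of primitive values.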

plugin/updater_oneapi/updater_quantile_hist_oneapi.h

Lines changed: 2 additions & 0 deletions
@@ -254,6 +254,7 @@ class GPUQuantileHistMakerOneAPI: public TreeUpdater {
   virtual void Update(const GHistIndexMatrixOneAPI& gmat,
                       HostDeviceVector<GradientPair>* gpair,
                       DMatrix* p_fmat,
+                      common::Span<HostDeviceVector<bst_node_t>> out_position,
                       RegTree* p_tree);
 
   inline void BuildHist(const std::vector<GradientPair>& gpair,
@@ -507,6 +508,7 @@ class GPUQuantileHistMakerOneAPI: public TreeUpdater {
   void CallBuilderUpdate(const std::unique_ptr<Builder<GradientSumT>>& builder,
                          HostDeviceVector<GradientPair> *gpair,
                          DMatrix *dmat,
+                         common::Span<HostDeviceVector<bst_node_t>> out_position,
                          const std::vector<RegTree *> &trees);
 
 protected:

Lines changed: 153 additions & 0 deletions

@@ -0,0 +1,153 @@
import sys
import unittest
import pytest

import numpy as np
import xgboost as xgb
from hypothesis import given, strategies, assume, settings, note

sys.path.append("tests/python")
import testing as tm

rng = np.random.RandomState(1994)

shap_parameter_strategy = strategies.fixed_dictionaries({
    'max_depth': strategies.integers(1, 11),
    'max_leaves': strategies.integers(0, 256),
    'num_parallel_tree': strategies.sampled_from([1, 10]),
}).filter(lambda x: x['max_depth'] > 0 or x['max_leaves'] > 0)


class TestOneAPIPredict(unittest.TestCase):
    def test_predict(self):
        iterations = 10
        np.random.seed(1)
        test_num_rows = [10, 1000, 5000]
        test_num_cols = [10, 50, 500]
        for num_rows in test_num_rows:
            for num_cols in test_num_cols:
                dtrain = xgb.DMatrix(np.random.randn(num_rows, num_cols),
                                     label=[0, 1] * int(num_rows / 2))
                dval = xgb.DMatrix(np.random.randn(num_rows, num_cols),
                                   label=[0, 1] * int(num_rows / 2))
                dtest = xgb.DMatrix(np.random.randn(num_rows, num_cols),
                                    label=[0, 1] * int(num_rows / 2))
                watchlist = [(dtrain, 'train'), (dval, 'validation')]
                res = {}
                param = {
                    "objective": "binary:logistic_oneapi",
                    "predictor": "oneapi_predictor",
                    'eval_metric': 'logloss',
                    'tree_method': 'hist',
                    'updater': 'grow_quantile_histmaker_oneapi',
                    'max_depth': 1
                }
                bst = xgb.train(param, dtrain, iterations, evals=watchlist,
                                evals_result=res)
                assert self.non_increasing(res["train"]["logloss"])
                oneapi_pred_train = bst.predict(dtrain, output_margin=True)
                oneapi_pred_test = bst.predict(dtest, output_margin=True)
                oneapi_pred_val = bst.predict(dval, output_margin=True)

                param["predictor"] = "cpu_predictor"
                bst_cpu = xgb.train(param, dtrain, iterations, evals=watchlist)
                cpu_pred_train = bst_cpu.predict(dtrain, output_margin=True)
                cpu_pred_test = bst_cpu.predict(dtest, output_margin=True)
                cpu_pred_val = bst_cpu.predict(dval, output_margin=True)

                np.testing.assert_allclose(cpu_pred_train, oneapi_pred_train,
                                           rtol=1e-6)
                np.testing.assert_allclose(cpu_pred_val, oneapi_pred_val,
                                           rtol=1e-6)
                np.testing.assert_allclose(cpu_pred_test, oneapi_pred_test,
                                           rtol=1e-6)

    def non_increasing(self, L):
        return all((y - x) < 0.001 for x, y in zip(L, L[1:]))

    @pytest.mark.skipif(**tm.no_sklearn())
    def test_multi_predict(self):
        from sklearn.datasets import make_regression
        from sklearn.model_selection import train_test_split

        n = 1000
        X, y = make_regression(n, random_state=rng)
        X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                            random_state=123)
        dtrain = xgb.DMatrix(X_train, label=y_train)
        dtest = xgb.DMatrix(X_test)

        params = {}
        params["tree_method"] = "hist"
        params["updater"] = "grow_quantile_histmaker_oneapi"

        params['predictor'] = "oneapi_predictor"
        bst_oneapi_predict = xgb.train(params, dtrain)

        params['predictor'] = "cpu_predictor"
        bst_cpu_predict = xgb.train(params, dtrain)

        predict0 = bst_oneapi_predict.predict(dtest)
        predict1 = bst_oneapi_predict.predict(dtest)
        cpu_predict = bst_cpu_predict.predict(dtest)

        assert np.allclose(predict0, predict1)
        assert np.allclose(predict0, cpu_predict)

    @pytest.mark.skipif(**tm.no_sklearn())
    def test_sklearn(self):
        m, n = 15000, 14
        tr_size = 2500
        X = np.random.rand(m, n)
        y = 200 * np.matmul(X, np.arange(-3, -3 + n))
        X_train, y_train = X[:tr_size, :], y[:tr_size]
        X_test, y_test = X[tr_size:, :], y[tr_size:]

        # First with cpu_predictor
        params = {'tree_method': 'hist',
                  'predictor': 'cpu_predictor',
                  'n_jobs': -1,
                  'seed': 123}
        m = xgb.XGBRegressor(**params).fit(X_train, y_train)
        cpu_train_score = m.score(X_train, y_train)
        cpu_test_score = m.score(X_test, y_test)

        # Now with oneapi_predictor
        params['predictor'] = 'oneapi_predictor'

        m = xgb.XGBRegressor(**params).fit(X_train, y_train)
        oneapi_train_score = m.score(X_train, y_train)
        m = xgb.XGBRegressor(**params).fit(X_train, y_train)
        oneapi_test_score = m.score(X_test, y_test)

        assert np.allclose(cpu_train_score, oneapi_train_score)
        assert np.allclose(cpu_test_score, oneapi_test_score)

    @given(strategies.integers(1, 10),
           tm.dataset_strategy.filter(lambda x: x.name != "empty"), shap_parameter_strategy)
    @settings(deadline=None)
    def test_shap(self, num_rounds, dataset, param):
        param.update({"predictor": "oneapi_predictor"})
        param = dataset.set_params(param)
        dmat = dataset.get_dmat()
        bst = xgb.train(param, dmat, num_rounds)
        test_dmat = xgb.DMatrix(dataset.X, dataset.y, dataset.w, dataset.margin)
        shap = bst.predict(test_dmat, pred_contribs=True)
        margin = bst.predict(test_dmat, output_margin=True)
        assume(len(dataset.y) > 0)
        assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-3, 1e-3)

    @given(strategies.integers(1, 10),
           tm.dataset_strategy.filter(lambda x: x.name != "empty"), shap_parameter_strategy)
    @settings(deadline=None, max_examples=20)
    def test_shap_interactions(self, num_rounds, dataset, param):
        param.update({"predictor": "oneapi_predictor"})
        param = dataset.set_params(param)
        dmat = dataset.get_dmat()
        bst = xgb.train(param, dmat, num_rounds)
        test_dmat = xgb.DMatrix(dataset.X, dataset.y, dataset.w, dataset.margin)
        shap = bst.predict(test_dmat, pred_interactions=True)
        margin = bst.predict(test_dmat, output_margin=True)
        assume(len(dataset.y) > 0)
        assert np.allclose(np.sum(shap, axis=(len(shap.shape) - 1, len(shap.shape) - 2)), margin,
                           1e-3, 1e-3)
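Note: each prediction test trains with the oneAPI updater and validates oneapi_predictor against the cpu_predictor reference rather than fixed expected values: margins must agree to rtol=1e-6, and SHAP contribution/interaction sums must reproduce the margin with rtol=atol=1e-3.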

Lines changed: 56 additions & 0 deletions

@@ -0,0 +1,56 @@
import numpy as np
import xgboost as xgb
import json

rng = np.random.RandomState(1994)


class TestOneAPITrainingContinuation:
    def run_training_continuation(self, use_json):
        kRows = 64
        kCols = 32
        X = np.random.randn(kRows, kCols)
        y = np.random.randn(kRows)
        dtrain = xgb.DMatrix(X, y)
        params = {'updater': 'grow_quantile_histmaker_oneapi', 'max_depth': '2',
                  'gamma': '0.1', 'alpha': '0.01',
                  'enable_experimental_json_serialization': use_json}
        bst_0 = xgb.train(params, dtrain, num_boost_round=64)
        dump_0 = bst_0.get_dump(dump_format='json')

        bst_1 = xgb.train(params, dtrain, num_boost_round=32)
        bst_1 = xgb.train(params, dtrain, num_boost_round=32, xgb_model=bst_1)
        dump_1 = bst_1.get_dump(dump_format='json')

        def recursive_compare(obj_0, obj_1):
            if isinstance(obj_0, float):
                assert np.isclose(obj_0, obj_1, atol=1e-6)
            elif isinstance(obj_0, str):
                assert obj_0 == obj_1
            elif isinstance(obj_0, int):
                assert obj_0 == obj_1
            elif isinstance(obj_0, dict):
                keys_0 = list(obj_0.keys())
                keys_1 = list(obj_1.keys())
                values_0 = list(obj_0.values())
                values_1 = list(obj_1.values())
                for i in range(len(obj_0.items())):
                    assert keys_0[i] == keys_1[i]
                    if list(obj_0.keys())[i] != 'missing':
                        recursive_compare(values_0[i],
                                          values_1[i])
            else:
                for i in range(len(obj_0)):
                    recursive_compare(obj_0[i], obj_1[i])

        assert len(dump_0) == len(dump_1)
        for i in range(len(dump_0)):
            obj_0 = json.loads(dump_0[i])
            obj_1 = json.loads(dump_1[i])
            recursive_compare(obj_0, obj_1)

    def test_oneapi_training_continuation_binary(self):
        self.run_training_continuation(False)

    def test_oneapi_training_continuation_json(self):
        self.run_training_continuation(True)
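Note: this continuation test relies on training being deterministic: 64 boosting rounds in one call must produce the same per-tree JSON dump as 32 rounds continued with another 32 via xgb_model, with floats compared at atol=1e-6 and the 'missing' field excluded, presumably because its serialized representation may differ between dumps.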

Lines changed: 51 additions & 0 deletions

@@ -0,0 +1,51 @@
import numpy as np
import gc
import pytest
import xgboost as xgb
from hypothesis import given, strategies, assume, settings, note

import sys
sys.path.append("tests/python")
import testing as tm

parameter_strategy = strategies.fixed_dictionaries({
    'max_depth': strategies.integers(0, 11),
    'max_leaves': strategies.integers(0, 256),
    'max_bin': strategies.integers(2, 1024),
    'grow_policy': strategies.sampled_from(['lossguide', 'depthwise']),
    'single_precision_histogram': strategies.booleans(),
    'min_child_weight': strategies.floats(0.5, 2.0),
    'seed': strategies.integers(0, 10),
    # We cannot enable subsampling as the training loss can increase
    # 'subsample': strategies.floats(0.5, 1.0),
    'colsample_bytree': strategies.floats(0.5, 1.0),
    'colsample_bylevel': strategies.floats(0.5, 1.0),
}).filter(lambda x: (x['max_depth'] > 0 or x['max_leaves'] > 0) and (
    x['max_depth'] > 0 or x['grow_policy'] == 'lossguide'))


def train_result(param, dmat, num_rounds):
    result = {}
    xgb.train(param, dmat, num_rounds, [(dmat, 'train')], verbose_eval=False,
              evals_result=result)
    return result


class TestOneAPIUpdaters:
    @given(parameter_strategy, strategies.integers(1, 5),
           tm.dataset_strategy.filter(lambda x: x.name != "empty"))
    @settings(deadline=None)
    def test_oneapi_hist(self, param, num_rounds, dataset):
        param['updater'] = 'grow_quantile_histmaker_oneapi'
        param = dataset.set_params(param)
        result = train_result(param, dataset.get_dmat(), num_rounds)
        note(result)
        assert tm.non_increasing(result['train'][dataset.metric])

    @given(tm.dataset_strategy.filter(lambda x: x.name != "empty"), strategies.integers(0, 1))
    @settings(deadline=None)
    def test_specified_device_id_oneapi_update(self, dataset, device_id):
        param = {'updater': 'grow_quantile_histmaker_oneapi', 'device_id': device_id}
        param = dataset.set_params(param)
        result = train_result(param, dataset.get_dmat(), 10)
        assert tm.non_increasing(result['train'][dataset.metric])
