Skip to content

Commit 570a68a

Browse files
Dmitry Razdoburdinrazdoburdin
authored andcommitted
add more tests for sycl predictor; fixed the found bugs
1 parent 5452227 commit 570a68a

File tree

10 files changed

+126
-392
lines changed

10 files changed

+126
-392
lines changed

include/xgboost/objective.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,12 @@ class ObjFunction : public Configurable {
132132
* \param name Name of the objective.
133133
*/
134134
static ObjFunction* Create(const std::string& name, Context const* ctx);
135+
136+
/*!
137+
* \brief Return sycl specific implementation name if possible.
138+
* \param name Name of the objective.
139+
*/
140+
static std::string GetSyclImplementationName(const std::string& name);
135141
};
136142

137143
/*!

plugin/updater_oneapi/device_manager_oneapi.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ namespace xgboost {
1010

1111
sycl::device DeviceManagerOneAPI::GetDevice(const DeviceOrd& device_spec) const {
1212
if (!device_spec.IsSycl()) {
13-
LOG(WARNING) << "Sycl kernel is executed with non-sycl context. "
13+
LOG(WARNING) << "Sycl kernel is executed with non-sycl context: "
14+
<< device_spec.Name() << ". "
1415
<< "Default sycl device_selector will be used.";
1516
}
1617

@@ -45,7 +46,8 @@ sycl::device DeviceManagerOneAPI::GetDevice(const DeviceOrd& device_spec) const
4546

4647
sycl::queue DeviceManagerOneAPI::GetQueue(const DeviceOrd& device_spec) const {
4748
if (!device_spec.IsSycl()) {
48-
LOG(WARNING) << "Sycl kernel is executed with non-sycl context. "
49+
LOG(WARNING) << "Sycl kernel is executed with non-sycl context: "
50+
<< device_spec.Name() << ". "
4951
<< "Default sycl device_selector will be used.";
5052
}
5153

plugin/updater_oneapi/predictor_oneapi.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,9 @@ class DeviceModelOneAPI {
180180
int sum = 0;
181181
tree_segments_[0] = sum;
182182
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
183+
if (model.trees[tree_idx]->HasCategoricalSplit()) {
184+
LOG(FATAL) << "Categorical features are not yet supported by sycl";
185+
}
183186
sum += model.trees[tree_idx]->GetNodes().size();
184187
tree_segments_[tree_idx - tree_begin + 1] = sum;
185188
}

src/learner.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,9 @@ class LearnerConfiguration : public Learner {
528528
auto const& objective_fn = learner_parameters.at("objective");
529529
if (!obj_) {
530530
CHECK_EQ(get<String const>(objective_fn["name"]), tparam_.objective);
531+
if (ctx_.IsSycl()) {
532+
tparam_.objective = ObjFunction::GetSyclImplementationName(tparam_.objective);
533+
}
531534
obj_.reset(ObjFunction::Create(tparam_.objective, &ctx_));
532535
}
533536
obj_->LoadConfig(objective_fn);
@@ -791,6 +794,9 @@ class LearnerConfiguration : public Learner {
791794
// Rename one of them once binary IO is gone.
792795
cfg_["max_delta_step"] = kMaxDeltaStepDefaultValue;
793796
}
797+
if (ctx_.IsSycl()) {
798+
tparam_.objective = ObjFunction::GetSyclImplementationName(tparam_.objective);
799+
}
794800
if (obj_ == nullptr || tparam_.objective != old.objective) {
795801
obj_.reset(ObjFunction::Create(tparam_.objective, &ctx_));
796802
}
@@ -882,6 +888,9 @@ class LearnerIO : public LearnerConfiguration {
882888
auto const& objective_fn = learner.at("objective");
883889

884890
std::string name = get<String>(objective_fn["name"]);
891+
if (ctx_.IsSycl()) {
892+
name = ObjFunction::GetSyclImplementationName(name);
893+
}
885894
tparam_.UpdateAllowUnknown(Args{{"objective", name}});
886895
obj_.reset(ObjFunction::Create(name, &ctx_));
887896
obj_->LoadConfig(objective_fn);
@@ -1009,6 +1018,9 @@ class LearnerIO : public LearnerConfiguration {
10091018
CHECK(fi->Read(&tparam_.objective)) << "BoostLearner: wrong model format";
10101019
CHECK(fi->Read(&tparam_.booster)) << "BoostLearner: wrong model format";
10111020

1021+
if (ctx_.IsSycl()) {
1022+
tparam_.objective = ObjFunction::GetSyclImplementationName(tparam_.objective);
1023+
}
10121024
obj_.reset(ObjFunction::Create(tparam_.objective, &ctx_));
10131025
gbm_.reset(GradientBooster::Create(tparam_.booster, &ctx_, &learner_model_param_));
10141026
gbm_->Load(fi);

src/objective/objective.cc

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,27 +18,33 @@ DMLC_REGISTRY_ENABLE(::xgboost::ObjFunctionReg);
1818
namespace xgboost {
1919
// implement factory functions
2020
ObjFunction* ObjFunction::Create(const std::string& name, Context const* ctx) {
21-
std::string replaced_name = name;
22-
if (ctx->IsSycl()) {
23-
auto *e = ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->Find(name + "_oneapi");
24-
if (e != nullptr) {
25-
replaced_name += "_oneapi";
26-
}
27-
}
28-
auto *e = ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->Find(replaced_name);
21+
auto *e = ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->Find(name);
2922
if (e == nullptr) {
3023
std::stringstream ss;
3124
for (const auto& entry : ::dmlc::Registry< ::xgboost::ObjFunctionReg>::List()) {
3225
ss << "Objective candidate: " << entry->name << "\n";
3326
}
34-
LOG(FATAL) << "Unknown objective function: `" << replaced_name << "`\n"
27+
LOG(FATAL) << "Unknown objective function: `" << name << "`\n"
3528
<< ss.str();
3629
}
3730
auto pobj = (e->body)();
3831
pobj->ctx_ = ctx;
3932
return pobj;
4033
}
4134

35+
// Return sycl specific implementation name if possible.
36+
std::string ObjFunction::GetSyclImplementationName(const std::string& name) {
37+
const std::string sycl_postfix = "_oneapi";
38+
auto *e = ::dmlc::Registry< ::xgboost::ObjFunctionReg>::Get()->Find(name + sycl_postfix);
39+
if (e != nullptr) {
40+
// Function has specific sycl implementation
41+
return name + sycl_postfix;
42+
} else {
43+
// Function hasn't specific sycl implementation
44+
return name;
45+
}
46+
}
47+
4248
void ObjFunction::InitEstimation(MetaInfo const&, linalg::Tensor<float, 1>* base_score) const {
4349
CHECK(base_score);
4450
base_score->Reshape(1);

tests/cpp/plugin/test_predictor_oneapi.cc

Lines changed: 0 additions & 116 deletions
This file was deleted.

0 commit comments

Comments
 (0)