Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/model_loader/detail/xgboost.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
#include <treelite/logging.h>
#include <treelite/model_loader.h>

#include <rapidjson/document.h>

#include "./string_utils.h"

namespace treelite::model_loader {

namespace detail::xgboost {
Expand Down Expand Up @@ -55,6 +59,25 @@ double TransformBaseScoreToMargin(std::string const& postprocessor, double base_
}
}

std::vector<float> ParseBaseScore(std::string const& str) {
std::vector<float> parsed_base_score;
if (StringStartsWith(str, "[")) {
// Vector base_score (from XGBoost 3.1+)
rapidjson::Document doc;
doc.Parse<rapidjson::ParseFlag::kParseNanAndInfFlag>(str);
TREELITE_CHECK(doc.IsArray()) << "Expected an array for base_score";
parsed_base_score.clear();
for (auto const& e : doc.GetArray()) {
TREELITE_CHECK(e.IsFloat()) << "Expected a float array for base_score";
parsed_base_score.push_back(e.GetFloat());
}
} else {
// Scalar base_score (from XGBoost <3.1)
parsed_base_score = {std::stof(str)};
}
return parsed_base_score;
}

} // namespace detail::xgboost

std::string DetectXGBoostFormat(std::string const& filename) {
Expand Down
3 changes: 3 additions & 0 deletions src/model_loader/detail/xgboost.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ std::string GetPostProcessor(std::string const& objective_name);
// Transform base score from probability into margin score
double TransformBaseScoreToMargin(std::string const& postprocessor, double base_score);

// Parse base score
std::vector<float> ParseBaseScore(std::string const& str);

enum FeatureType { kNumerical = 0, kCategorical = 1 };

} // namespace treelite::model_loader::detail::xgboost
Expand Down
184 changes: 162 additions & 22 deletions src/model_loader/detail/xgboost_json/delegated_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "./delegated_handler.h"

#include <algorithm>
#include <string>

#include <treelite/logging.h>
Expand Down Expand Up @@ -499,11 +500,10 @@ RegTreeArrayHandler::RegTreeArrayHandler(std::weak_ptr<Delegator> parent_delegat

bool RegTreeArrayHandler::StartObject() {
if (this->should_ignore_upcoming_value()) {
return this->template push_handler<IgnoreHandler>();
return push_handler<IgnoreHandler>();
}
this->output.emplace_back();
return this->template push_handler<RegTreeHandler, ParsedRegTreeParams>(
this->output.back(), model_builder);
output.emplace_back();
return push_handler<RegTreeHandler, ParsedRegTreeParams>(output.back(), model_builder);
}

/******************************************************************************
Expand All @@ -525,7 +525,9 @@ bool GBTreeModelHandler::StartObject() {
if (this->should_ignore_upcoming_value()) {
return push_handler<IgnoreHandler>();
}
return push_key_handler<IgnoreHandler>("gbtree_model_param");
return push_key_handler<IgnoreHandler>("gbtree_model_param")
|| push_key_handler<CategoryContainerHandler, ParsedCategoryContainer>(
"cats", output.category_container);
}

bool GBTreeModelHandler::EndObject() {
Expand All @@ -544,8 +546,124 @@ bool GBTreeModelHandler::EndObject() {
}

bool GBTreeModelHandler::is_recognized_key(std::string const& key) {
return (key == "trees" || key == "tree_info" || key == "gbtree_model_param"
|| key == "iteration_indptr");
return key == "trees" || key == "tree_info" || key == "gbtree_model_param"
|| key == "iteration_indptr" || key == "cats";
}

/******************************************************************************
* CategoryContainerHandler
* ***************************************************************************/

bool CategoryContainerHandler::StartArray() {
if (this->should_ignore_upcoming_value()) {
return push_handler<IgnoreHandler>();
}
return push_key_handler<CategoryInfoArrayHandler>("enc", output.enc)
|| push_key_handler<ArrayHandler<std::int32_t>, std::vector<std::int32_t>>(
"feature_segments", output.feature_segments)
|| push_key_handler<ArrayHandler<std::int32_t>, std::vector<std::int32_t>>(
"sorted_idx", output.sorted_idx);
}

bool CategoryContainerHandler::is_recognized_key(std::string const& key) {
return key == "enc" || key == "feature_segments" || key == "sorted_idx";
}

/******************************************************************************
* CategoryInfoArrayHandler
* ***************************************************************************/
bool CategoryInfoArrayHandler::StartObject() {
if (this->should_ignore_upcoming_value()) {
return push_handler<IgnoreHandler>();
}
output.emplace_back();
return push_handler<CategoryInfoHandler, ParsedCategoryInfo>(output.back());
}

/******************************************************************************
* CategoryInfoHandler
* ***************************************************************************/
bool CategoryInfoHandler::Int64(std::int64_t i) {
if (this->should_ignore_upcoming_value()) {
return push_handler<IgnoreHandler>();
}
bool got_type = check_cur_key("type");
if (got_type) {
output.type = i;
}
return got_type;
}

bool CategoryInfoHandler::Uint64(std::uint64_t u) {
// The "type" field can be int64 or uint64
// Just defer to the int64 handler
return Int64(static_cast<std::int64_t>(u));
}

bool CategoryInfoHandler::StartArray() {
if (this->should_ignore_upcoming_value()) {
return push_handler<IgnoreHandler>();
}

bool got_offsets = check_cur_key("offsets");
if (got_offsets) {
output.offsets = std::vector<std::int32_t>{};
push_handler<ArrayHandler<std::int32_t>, std::vector<std::int32_t>>(output.offsets.value());
}

// Assumption: Either "offsets" or "type" fields have been given before "values" field.
// Only with this assumption can we infer the type of the "values" field.
bool got_values = check_cur_key("values");
if (got_values) {
if (output.offsets.has_value()) {
// String categories
output.values = std::vector<std::int8_t>{};
push_handler<ArrayHandler<std::int8_t>, std::vector<std::int8_t>>(
std::get<std::vector<std::int8_t>>(output.values));
} else if (output.type.has_value()) {
// Numerical categories
switch (static_cast<ValueKind>(output.type.value())) {
case ValueKind::kU8Array:
case ValueKind::kU16Array:
case ValueKind::kU32Array:
case ValueKind::kU64Array: {
output.values = std::vector<std::uint64_t>{};
push_handler<ArrayHandler<std::uint64_t>, std::vector<std::uint64_t>>(
std::get<std::vector<std::uint64_t>>(output.values));
break;
}
case ValueKind::kI8Array:
case ValueKind::kI16Array:
case ValueKind::kI32Array:
case ValueKind::kI64Array: {
output.values = std::vector<std::int64_t>{};
push_handler<ArrayHandler<std::int64_t>, std::vector<std::int64_t>>(
std::get<std::vector<std::int64_t>>(output.values));
break;
}
case ValueKind::kF32Array:
case ValueKind::kF64Array: {
output.values = std::vector<double>{};
push_handler<ArrayHandler<double>, std::vector<double>>(
std::get<std::vector<double>>(output.values));
break;
}
default:
TREELITE_LOG(ERROR) << "Got invalid type for `values` array";
return false;
}
} else {
TREELITE_LOG(ERROR) << "Cannot determine the type of `values` array, since neither"
<< "`type` or `offsets` fields are present";
return false;
}
}

return got_values || got_offsets;
}

bool CategoryInfoHandler::is_recognized_key(std::string const& key) {
return key == "type" || key == "offsets" || key == "values";
}

/******************************************************************************
Expand Down Expand Up @@ -647,14 +765,17 @@ bool LearnerParamHandler::String(std::string const& str) {
if (this->should_ignore_upcoming_value()) {
return true;
}
// For now, XGBoost always outputs a scalar base_score
return (
assign_value("base_score", static_cast<float>(std::stof(str)), output.base_score)
|| assign_value("num_class", std::max(std::stoi(str), 1), output.num_class)
|| assign_value("num_target", static_cast<std::int32_t>(std::stoi(str)), output.num_target)
|| assign_value("num_feature", std::stoi(str), output.num_feature)
|| assign_value(
"boost_from_average", static_cast<bool>(std::stoi(str)), output.boost_from_average));

// Special handling logic for base_score
bool got_base_score = check_cur_key("base_score");
if (got_base_score) {
output.base_score = ParseBaseScore(str);
}
return got_base_score || assign_value("num_class", std::max(std::stoi(str), 1), output.num_class)
|| assign_value("num_target", static_cast<std::int32_t>(std::stoi(str)), output.num_target)
|| assign_value("num_feature", std::stoi(str), output.num_feature)
|| assign_value(
"boost_from_average", static_cast<bool>(std::stoi(str)), output.boost_from_average);
}

bool LearnerParamHandler::is_recognized_key(std::string const& key) {
Expand All @@ -680,6 +801,13 @@ bool LearnerHandler::StartObject() {
}

bool LearnerHandler::EndObject() {
/* Throw an exception if category encoding is required.
* TODO(hcho3): Implement categorical encoding */
TREELITE_CHECK(output.category_container.enc.empty()
&& output.category_container.feature_segments.empty()
&& output.category_container.sorted_idx.empty())
<< "Treelite does not yet support XGBoost models with categorical encoder";

/* Set metadata */
auto const num_tree = output.num_tree;
auto const num_feature = learner_params.num_feature;
Expand Down Expand Up @@ -743,18 +871,30 @@ bool LearnerHandler::EndObject() {
leaf_vector_shape[1] = 1;
}
}
// Set base scores. For now, XGBoost only supports a scalar base score for all targets / classes.
auto base_score = static_cast<double>(learner_params.base_score);
// Set base scores
// Assume: Either num_target or num_class must be 1
TREELITE_CHECK(learner_params.num_target == 1 || learner_params.num_class == 1);
std::vector<double> base_scores(learner_params.num_target * learner_params.num_class);
if (learner_params.base_score.size() == 1) {
// Scalar base_score (XGBoost <3.1)
// Starting from 3.1, the base score is a vector.
std::fill(base_scores.begin(), base_scores.end(),
static_cast<double>(learner_params.base_score.at(0)));
} else {
// Vector base_score (XGBoost 3.1+)
// Assume: If base_score is a vector, then its length should be num_target * num_class.
TREELITE_CHECK(base_scores.size() == learner_params.base_score.size());
std::transform(learner_params.base_score.begin(), learner_params.base_score.end(),
base_scores.begin(), [](float e) { return static_cast<double>(e); });
}

// Before XGBoost 1.0.0, the base score saved in model is a transformed value. After
// 1.0 it's the original value provided by user.
bool const need_transform_to_margin = output.version.empty() || output.version[0] >= 1;
if (need_transform_to_margin) {
base_score = xgboost::TransformBaseScoreToMargin(postprocessor.name, base_score);
std::for_each(base_scores.begin(), base_scores.end(),
[&](auto& e) { e = xgboost::TransformBaseScoreToMargin(postprocessor.name, e); });
}
// For now, XGBoost produces a scalar base_score
// Assume: Either num_target or num_class must be 1
TREELITE_CHECK(learner_params.num_target == 1 || learner_params.num_class == 1);
std::vector<double> base_scores(learner_params.num_target * learner_params.num_class, base_score);

model_builder::Metadata metadata{
num_feature, task_type, average_tree_output, num_target, num_class, leaf_vector_shape};
Expand Down
Loading
Loading