Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/model_loader/detail/xgboost_json/delegated_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ bool RegTreeHandler::StartArray() {
return (push_key_handler<ArrayHandler<float>>("loss_changes", loss_changes)
|| push_key_handler<ArrayHandler<float>>("sum_hessian", sum_hessian)
|| push_key_handler<ArrayHandler<float>>("base_weights", base_weights)
|| push_key_handler<ArrayHandler<float>>("leaf_weights", leaf_weights)
|| push_key_handler<ArrayHandler<int>>("categories_segments", categories_segments)
|| push_key_handler<ArrayHandler<int>>("categories_sizes", categories_sizes)
|| push_key_handler<ArrayHandler<int>>("categories_nodes", categories_nodes)
Expand Down Expand Up @@ -390,10 +391,8 @@ bool RegTreeHandler::EndObject() {
if (output.size_leaf_vector == 0) {
output.size_leaf_vector = 1; // In XGBoost, size_leaf_vector=0 indicates a scalar output
}
if (num_nodes * output.size_leaf_vector != base_weights.size()) {
TREELITE_LOG(ERROR) << "Field base_weights has an incorrect dimension. Expected: "
<< (num_nodes * output.size_leaf_vector)
<< ", Actual: " << base_weights.size();
if (output.size_leaf_vector != 1 && leaf_weights.empty()) {
TREELITE_LOG(ERROR) << "Field leaf_weights must be provided for multi-target trees.";
return false;
}
if (static_cast<std::size_t>(num_nodes) != left_children.size()) {
Expand Down Expand Up @@ -440,9 +439,10 @@ bool RegTreeHandler::EndObject() {
if (size_leaf_vector > 1) {
// Vector output
std::vector<float> leafvec(size_leaf_vector);
std::transform(&base_weights[node_id * size_leaf_vector],
&base_weights[(node_id + 1) * size_leaf_vector], leafvec.begin(),
[](float e) { return static_cast<float>(e); });
auto leaf_id = right_children[node_id];
TREELITE_CHECK_NE(leaf_id, -1) << "Expected a leaf node at index " << node_id;
std::copy(&leaf_weights[leaf_id * size_leaf_vector],
&leaf_weights[(leaf_id + 1) * size_leaf_vector], leafvec.begin());
model_builder.LeafVector(leafvec);
} else {
// Scalar leaf output
Expand Down Expand Up @@ -487,7 +487,7 @@ bool RegTreeHandler::is_recognized_key(std::string const& key) {
|| key == "categories" || key == "leaf_child_counts" || key == "left_children"
|| key == "right_children" || key == "parents" || key == "split_indices"
|| key == "split_type" || key == "split_conditions" || key == "default_left"
|| key == "tree_param" || key == "id");
|| key == "tree_param" || key == "id" || key == "leaf_weights");
}

/******************************************************************************
Expand Down
1 change: 1 addition & 0 deletions src/model_loader/detail/xgboost_json/delegated_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ class RegTreeHandler : public OutputHandler<ParsedRegTreeParams> {
std::vector<float> loss_changes;
std::vector<float> sum_hessian;
std::vector<float> base_weights;
std::vector<float> leaf_weights;
std::vector<int> left_children;
std::vector<int> right_children;
std::vector<int> parents;
Expand Down
23 changes: 14 additions & 9 deletions tests/python/test_xgboost_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@
from .util import TemporaryDirectory, has_pandas, to_categorical


def _get_model_filename(name, model_format):
if model_format == "ubjson":
model_format = "ubj"
return f"{name}.{model_format}"


def generate_data_for_squared_log_error(n_targets: int = 1):
"""Generate data containing outliers."""
n_rows = 4096
Expand Down Expand Up @@ -108,7 +114,7 @@ def test_xgb_regressor(
num_boost_round=num_boost_round,
)
with TemporaryDirectory() as tmpdir:
model_path = pathlib.Path(tmpdir) / f"model.{model_format}"
model_path = pathlib.Path(tmpdir) / _get_model_filename("model", model_format)
xgb_model.save_model(model_path)
tl_model = treelite.frontend.load_xgboost_model(
model_path, format_choice=model_format
Expand Down Expand Up @@ -179,7 +185,7 @@ def test_xgb_multiclass_classifier(
)

with TemporaryDirectory() as tmpdir:
model_path = pathlib.Path(tmpdir) / f"iris.{model_format}"
model_path = pathlib.Path(tmpdir) / _get_model_filename("iris", model_format)
xgb_model.save_model(model_path)
tl_model = treelite.frontend.load_xgboost_model(
model_path, format_choice=model_format
Expand Down Expand Up @@ -262,10 +268,7 @@ def test_xgb_nonlinear_objective(
)

objective_tag = objective.replace(":", "_")
if model_format in ["json", "ubjson"]:
model_name = f"nonlinear_{objective_tag}.{model_format}"
else:
model_name = f"nonlinear_{objective_tag}.deprecated"
model_name = _get_model_filename(f"nonlinear_{objective_tag}", model_format)
with TemporaryDirectory() as tmpdir:
model_path = pathlib.Path(tmpdir) / model_name
xgb_model.save_model(model_path)
Expand Down Expand Up @@ -458,7 +461,9 @@ def test_xgb_multi_target_binary_classifier(
tl_model = treelite.frontend.from_xgboost(bst)
else:
with TemporaryDirectory() as tmpdir:
model_path = pathlib.Path(tmpdir) / f"multi_target.{model_format}"
model_path = pathlib.Path(tmpdir) / _get_model_filename(
"multi_target", model_format
)
bst.save_model(model_path)
tl_model = treelite.frontend.load_xgboost_model(
model_path, format_choice=model_format
Expand Down Expand Up @@ -533,7 +538,7 @@ def test_xgb_multi_target_regressor(
)

with TemporaryDirectory() as tmpdir:
model_path = pathlib.Path(tmpdir) / f"model.{model_format}"
model_path = pathlib.Path(tmpdir) / _get_model_filename("model", model_format)
xgb_model.save_model(model_path)
tl_model = treelite.frontend.load_xgboost_model(
model_path, format_choice=model_format
Expand Down Expand Up @@ -578,7 +583,7 @@ def test_xgb_detect_format(
expected_pred = xgb_model.predict(xgb.DMatrix(X)).reshape((X.shape[0], 1, -1))

with TemporaryDirectory() as tmpdir:
model_path = pathlib.Path(tmpdir) / f"model.{model_format}"
model_path = pathlib.Path(tmpdir) / _get_model_filename("model", model_format)
xgb_model.save_model(model_path)
detected_format = treelite.frontend._detect_xgboost_format(model_path)
assert detected_format == model_format
Expand Down
Loading