Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 150 additions & 2 deletions onnxmltools/convert/xgboost/operator_converters/XGBoost.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import numpy as np
from onnx import TensorProto
from xgboost import XGBClassifier
from typing import Any, Dict, List, Union
from copy import deepcopy

try:
from xgboost import XGBRFClassifier
Expand All @@ -13,6 +15,9 @@
from ..common import get_xgb_params, get_n_estimators_classifier


# One node of a tree, i.e. a dict obtained by ``json.loads`` on an entry of the
# booster's JSON dump (keys used in this file include "nodeid", "children",
# "split_condition", "yes", "no" and "missing").
Node = Dict[str, Any]
# Either a single tree (its root node) or a forest (a list of root nodes).
TreeLike = Union[Node, List[Node]]

class XGBConverter:
"""
Base class for converting XGBoost models to ONNX format.
Expand Down Expand Up @@ -58,8 +63,149 @@ def common_members(xgb_node, inputs):
# The json format was available in October 2017.
# XGBoost 0.7 was the first version released with it.
js_tree_list = booster.get_dump(with_stats=True, dump_format="json")
js_trees = [json.loads(s) for s in js_tree_list]
js_trees: TreeLike = [json.loads(s) for s in js_tree_list]
js_trees = XGBConverter._process_categorical_features(js_trees)
return objective, base_score, js_trees, best_ntree_limit


@staticmethod
def _is_bracketed_json_list_string(s: str) -> bool:
s = s.strip()
return len(s) >= 2 and s[0] == '[' and s[-1] == ']'

@staticmethod
def _clone_node_skeleton(source_node: Node) -> Node:
    """
    Build a new node carrying every entry of ``source_node`` except
    'children' and 'split_condition', with 'decision_type' forced to
    'BRANCH_EQ'.

    The copy is shallow: values are shared with ``source_node``.
    """
    skeleton: Node = {
        key: value
        for key, value in source_node.items()
        if key not in ("children", "split_condition")
    }
    skeleton["decision_type"] = "BRANCH_EQ"
    return skeleton

@staticmethod
def _maybe_transform_categorical(node: Node) -> bool:
    """
    Expand a categorical split into a chain of BRANCH_EQ nodes, in place.

    A node is treated as categorical when its 'split_condition' is a list
    (one entry per category). The node itself keeps the first category and
    one extra node is created for each remaining category. In every node of
    the chain the child on the "yes" side (the side whose 'nodeid' equals
    ``node['yes']``) is the original "yes" subtree - a deep copy for all
    nodes but the last - and the other child is the next node of the
    chain. The last node of the chain gets the two original subtrees.

    The created nodes reuse the ids of the node they were cloned from, so
    node ids are no longer unique after a transformation.

    Returns True when the node was transformed, False when its
    'split_condition' is not a list.

    Raises ValueError if the list of categories is empty or if the node
    does not have exactly two children.
    """
    categories = node.get("split_condition")
    if not isinstance(categories, list):
        # Not a categorical split, nothing to do.
        return False

    if not categories:
        raise ValueError("split_condition is an empty array. ")

    kids = node.get("children")
    if not isinstance(kids, list) or len(kids) != 2:
        raise ValueError("Expected a split node with two children before categorical transform.")

    left_subtree, right_subtree = kids
    first_category, *other_categories = categories

    # The original node tests the first category.
    node["decision_type"] = "BRANCH_EQ"
    node["split_condition"] = first_category

    yes_is_left = left_subtree["nodeid"] == node["yes"]

    tail = node
    for category in other_categories:
        link = XGBConverter._clone_node_skeleton(tail)
        link["split_condition"] = category

        # Keep the "yes" subtree on its original side, chain on the other.
        if yes_is_left:
            tail["children"] = [deepcopy(left_subtree), link]
        else:
            tail["children"] = [link, deepcopy(right_subtree)]
        tail = link

    # The last node of the chain owns the two original (uncopied) subtrees.
    tail["children"] = [left_subtree, right_subtree]
    return True

@staticmethod
def _process_node(node: Node) -> bool:
    """
    Recursively process ``node`` and its subtree, children first.

    Categorical split nodes are expanded through
    ``_maybe_transform_categorical``; every other split node gets
    'decision_type' set to 'BRANCH_LT'. Leaves (nodes without a
    'children' list) are left untouched.

    Returns True when at least one node of the subtree (``node`` included)
    was a categorical split, False otherwise.
    """
    children = node.get("children")
    # If this is a leaf, nothing to do.
    if not isinstance(children, list):
        return False

    # Accumulate over *all* children: keeping only the last child's result
    # would hide a categorical split located in an earlier subtree.
    any_child_categorical = False
    for child in children:
        # Every child must be processed, so the recursive call is made
        # unconditionally (no short-circuit on the accumulated flag).
        if XGBConverter._process_node(child):
            any_child_categorical = True

    transformed = XGBConverter._maybe_transform_categorical(node)
    if not transformed:
        # Non-categorical split node: enforce BRANCH_LT as default.
        node["decision_type"] = "BRANCH_LT"

    return any_child_categorical or transformed

@staticmethod
def _update_node_ids(node: Node, node_counter: int) -> int:
    """
    Renumber ``node`` and its subtree in pre-order, starting at
    ``node_counter``, and rewrite the 'yes', 'no' and 'missing' references
    of every split node so that they point to the new child ids.

    Which child is the "yes" child is decided by comparing the left
    child's current 'nodeid' with ``node['yes']``; 'missing' keeps
    following the branch ("yes" or "no") it followed before renumbering.
    When a split node has no 'missing' entry, one is created that follows
    the "no" branch.

    Returns the first id that has not been assigned yet.
    """
    node["nodeid"] = node_counter
    node_counter += 1
    children = node.get("children")

    # If this is a leaf, end recursion.
    if not (isinstance(children, list) and len(children) == 2):
        return node_counter

    left, right = children
    # Both flags must be read before any id is rewritten: they compare the
    # old ids stored on this node with the old id still held by the child.
    missing_yes = node.get("missing", -1) == node["yes"]
    yes_left = left["nodeid"] == node["yes"]

    left_key, right_key = ("yes", "no") if yes_left else ("no", "yes")

    node[left_key] = node_counter
    node_counter = XGBConverter._update_node_ids(left, node_counter)

    node[right_key] = node_counter
    node_counter = XGBConverter._update_node_ids(right, node_counter)

    # 'missing' follows the same branch as before, whichever side that
    # branch is on; deriving it from the freshly written 'yes'/'no' ids
    # covers all four (missing_yes, yes_left) combinations.
    node["missing"] = node["yes"] if missing_yes else node["no"]

    return node_counter


@staticmethod
def _process_root(root: Node) -> None:
    """
    Process one tree in place: handle every node through ``_process_node``
    and, when that reports a categorical split, renumber the whole tree
    from 0 so that node ids are unique again.
    """
    has_categorical = XGBConverter._process_node(root)
    if not has_categorical:
        return
    XGBConverter._update_node_ids(root, node_counter=0)


@staticmethod
def _process_categorical_features(js_tree: TreeLike) -> TreeLike:
    """
    Rewrite natively handled categorical splits as equality checks that
    are supported in ONNX. The trees are modified in place.

    - A split node whose 'split_condition' is a list of categories is
      expanded into a chain of BRANCH_EQ nodes.
    - Any other split node gets its 'decision_type' set to 'BRANCH_LT'.
    - A tree containing at least one categorical split has its node ids
      renumbered; 'depth' values are not updated.

    ``js_tree`` is either a single tree (dict) or a forest (list of
    dicts); entries of a list that are not dicts are skipped.

    Returns the same ``js_tree`` object.

    Raises TypeError when ``js_tree`` is neither a dict nor a list.
    """
    if isinstance(js_tree, dict):
        # Single tree.
        XGBConverter._process_root(js_tree)
        return js_tree

    if not isinstance(js_tree, list):
        raise TypeError("js_tree must be a dict (single tree) or list of dicts (forest).")

    # Forest: process every dict entry.
    for candidate in js_tree:
        if isinstance(candidate, dict):
            XGBConverter._process_root(candidate)
    return js_tree


@staticmethod
def _get_default_tree_attribute_pairs(is_classifier):
Expand Down Expand Up @@ -168,7 +314,9 @@ def _fill_node_attributes(
value=jsnode["split_condition"],
node_id=remap[jsnode["nodeid"]],
feature_id=jsnode["split"],
mode="BRANCH_LT", # 'BRANCH_LEQ' --> is for sklearn
mode=jsnode["decision_type"], # 'BRANCH_LEQ' --> is for sklearn
# 'BRANCH_LT' --> is for xgboost numerical features
# 'BRANCH_EQ' --> is for xgboost categorical features
true_child_id=remap[jsnode["yes"]], # ['children'][0]['nodeid'],
false_child_id=remap[jsnode["no"]], # ['children'][1]['nodeid'],
weights=None,
Expand Down
51 changes: 51 additions & 0 deletions tests/xgboost/data_categorical.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
f0,f1,y
B,0.03260774,1.9839939
A,0.028049896,0.091320634
A,0.028272122,-0.006381493
A,0.055345863,-0.052565098
A,-0.48156285,-0.15852909
C,-0.5834075,3.8429787
C,-0.8621605,3.7296321
A,-1.4881746,-0.33317634
C,0.21630684,4.1076612
C,0.9843764,4.381877
B,-0.54308414,1.906369
A,-0.558615,-0.25187373
C,-0.31648284,3.886167
A,-0.46063974,-0.2746162
C,-1.4362698,3.5367992
B,1.365108,2.4652877
C,0.4389999,4.0895395
A,-0.711695,-0.24534294
C,0.29717177,4.1054583
B,-0.43845728,1.9078286
A,-0.21163744,-0.081306905
A,0.36396384,0.09639855
B,0.9529645,2.3263385
B,1.5195241,2.4685602
A,1.7039094,0.4969085
A,-0.2488587,-0.05893244
B,-0.4997486,1.8538816
C,0.0995975,4.0426106
C,0.12834321,4.1385646
B,-0.7342219,1.7644731
C,-0.6204753,3.7868693
C,0.8132737,4.314664
A,1.641801,0.4572553
C,-0.22650085,4.018044
C,-0.6479652,3.7959006
A,-0.2833712,-0.081319384
C,-0.9951314,3.738142
C,-0.27287176,3.978794
A,0.42244413,0.17658317
B,-0.08134296,1.9654154
A,1.2345777,0.35205185
B,0.15088803,2.06264
A,0.4811195,0.15107182
C,-0.14875753,3.9232593
A,1.3156657,0.41523612
B,-1.2223456,1.6830022
A,-0.30359134,-0.082752064
A,-1.1736887,-0.27390662
B,0.8262735,2.2683973
B,0.8503223,2.247306
68 changes: 68 additions & 0 deletions tests/xgboost/test_xgboost_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Booster,
train as train_xgb,
)
import xgboost
except Exception:
XGBRegressor = None
import sklearn
Expand Down Expand Up @@ -832,6 +833,73 @@ def test_xgb_classifier_13_2(self):
assert_almost_equal(expected[1], got[1])
assert_almost_equal(expected[0], got[0])


@unittest.skipIf(XGBRegressor is None, "xgboost is not available")
@unittest.skipIf(
    # Decorator arguments are evaluated when the class body runs, and the
    # name `xgboost` is only bound when the guarded import succeeded, so the
    # version check must be short-circuited when xgboost is unavailable.
    XGBRegressor is not None
    and pv.Version(xgboost.__version__) < pv.Version("1.6.0"),
    "xgboost version< 1.6.0 lacks stable categorical support, skipping test.",
)
def test_xgb_regressor_categorical_hist(self):
    """
    Train several XGBRegressor models with native categorical handling on
    data_categorical.csv, convert each of them to ONNX and check that the
    ONNX predictions match the XGBoost predictions.
    """
    this = os.path.dirname(__file__)
    df = pandas.read_csv(os.path.join(this, "data_categorical.csv"))
    df["f0"] = df["f0"].astype("category")
    X, y = df.drop("y", axis=1), df["y"]

    # Build the ONNX input once, it does not depend on the model:
    # - first column: category codes (int codes) cast to float32
    # - second column: numeric feature
    cat_codes = X["f0"].cat.codes.to_numpy().astype(np.float32).reshape(-1, 1)
    num_col = X["f1"].to_numpy().astype(np.float32).reshape(-1, 1)
    X_onnx = np.concatenate([cat_codes, num_col], axis=1)

    models = [
        XGBRegressor(
            objective="reg:squarederror",
            n_estimators=30,
            learning_rate=0.3,
            tree_method="hist",
            enable_categorical=True,  # turn on native categorical handling
            random_state=0,
        ),
        XGBRegressor(
            n_estimators=30,
            enable_categorical=True,  # turn on native categorical handling
            max_cat_to_onehot=8,  # use native one hot encoding
        ),
        XGBRegressor(
            n_estimators=100,
            max_depth=10,
            enable_categorical=True,  # turn on native categorical handling
        ),
    ]

    for idx, model in enumerate(models):
        model.fit(X, y)

        # Convert to ONNX.
        # Input has 2 columns: cat codes (as numeric) + numeric feature.
        onnx_model = convert_xgboost(
            model,
            initial_types=[("float_input", FloatTensorType([None, 2]))],
            target_opset=TARGET_OPSET,
        )

        # Compare XGBoost and ONNX results.
        expected = model.predict(X).astype(np.float32)
        sess = InferenceSession(
            onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        input_name = sess.get_inputs()[0].name
        got = sess.run(None, {input_name: X_onnx})[0].ravel().astype(np.float32)
        assert_almost_equal(expected, got, decimal=4)

        # Test onnx backend
        dump_data_and_model(
            X_onnx.astype("float32"),
            model,
            onnx_model,
            basename=f"XGBRegressorCategoricalFeatures{idx}",
        )


# Run the tests of this file with verbose output when it is executed directly.
if __name__ == "__main__":
    unittest.main(verbosity=2)