diff --git a/onnxmltools/convert/xgboost/operator_converters/XGBoost.py b/onnxmltools/convert/xgboost/operator_converters/XGBoost.py
index aeec4364..cc43cf1e 100644
--- a/onnxmltools/convert/xgboost/operator_converters/XGBoost.py
+++ b/onnxmltools/convert/xgboost/operator_converters/XGBoost.py
@@ -4,6 +4,8 @@
 import numpy as np
 from onnx import TensorProto
 from xgboost import XGBClassifier
+from typing import Any, Dict, List, Union
+from copy import deepcopy
 
 try:
     from xgboost import XGBRFClassifier
@@ -13,6 +15,9 @@
 from ..common import get_xgb_params, get_n_estimators_classifier
 
+Node = Dict[str, Any]
+TreeLike = Union[Node, List[Node]]
+
 class XGBConverter:
     """
     Base class for converting XGBoost models to ONNX format.
     """
@@ -58,8 +63,149 @@ def common_members(xgb_node, inputs):
         # The json format was available in October 2017.
         # XGBoost 0.7 was the first version released with it.
         js_tree_list = booster.get_dump(with_stats=True, dump_format="json")
-        js_trees = [json.loads(s) for s in js_tree_list]
+        js_trees: TreeLike = [json.loads(s) for s in js_tree_list]
+        js_trees = XGBConverter._process_categorical_features(js_trees)
         return objective, base_score, js_trees, best_ntree_limit
 
+    @staticmethod
+    def _is_bracketed_json_list_string(s: str) -> bool:
+        s = s.strip()
+        return len(s) >= 2 and s[0] == "[" and s[-1] == "]"
+
+    @staticmethod
+    def _clone_node_skeleton(source_node: Node) -> Node:
+        # Copy the split-related metadata but neither 'children' nor 'split_condition'.
+        new_node: Node = {
+            k: v
+            for k, v in source_node.items()
+            if k not in ("children", "split_condition")
+        }
+        new_node["decision_type"] = "BRANCH_EQ"
+        return new_node
+
+    @staticmethod
+    def _maybe_transform_categorical(node: Node) -> bool:
+        """
+        If the node encodes a categorical split (its 'split_condition' is a list of
+        categories, possibly serialized as a JSON list string), transform it in place
+        into a chain of BRANCH_EQ nodes and return True. Otherwise return False.
+        """
+        split_condition = node.get("split_condition")
+
+        # Some dumps serialize the category list as a string such as "[1, 3, 5]".
+        if isinstance(split_condition, str) and XGBConverter._is_bracketed_json_list_string(
+            split_condition
+        ):
+            split_condition = json.loads(split_condition)
+
+        if not isinstance(split_condition, list):
+            return False  # not a categorical split
+
+        if len(split_condition) == 0:
+            raise ValueError("split_condition is an empty array.")
+
+        # Validate that this is a split node with exactly two children.
+        children = node.get("children")
+        if not (isinstance(children, list) and len(children) == 2):
+            raise ValueError(
+                "Expected a split node with two children before the categorical transform."
+            )
+
+        orig_left, orig_right = children
+
+        # The first category is tested on the original node.
+        node["decision_type"] = "BRANCH_EQ"
+        node["split_condition"] = split_condition[0]
+
+        yes_left = orig_left["nodeid"] == node["yes"]
+
+        current_node = node
+        for cat in split_condition[1:]:
+            new_node = XGBConverter._clone_node_skeleton(current_node)
+            new_node["split_condition"] = cat
+
+            if yes_left:
+                current_node["children"] = [deepcopy(orig_left), new_node]
+            else:
+                current_node["children"] = [new_node, deepcopy(orig_right)]
+            current_node = new_node
+
+        # The last node of the chain keeps both original subtrees: a match follows
+        # the original "yes" subtree, a mismatch the original "no" subtree.
+        current_node["children"] = [orig_left, orig_right]
+        return True
+
+    @staticmethod
+    def _process_node(node: Node) -> bool:
+        # Leaves have no children and never need a transform.
+        if "children" not in node or not isinstance(node["children"], list):
+            return False
+
+        any_child_categorical = False
+        for child in node["children"]:
+            any_child_categorical = (
+                XGBConverter._process_node(child) or any_child_categorical
+            )
+
+        transformed = XGBConverter._maybe_transform_categorical(node)
+        if not transformed:
+            # Non-categorical split node: enforce BRANCH_LT as the default.
+            node["decision_type"] = "BRANCH_LT"
+
+        return any_child_categorical or transformed
+
+    @staticmethod
+    def _update_node_ids(node: Node, node_counter: int) -> int:
+        # Renumber node ids depth-first while keeping the 'yes'/'no'/'missing'
+        # references consistent; returns the next free id.
+        node["nodeid"] = node_counter
+        node_counter += 1
+        children = node.get("children")
+
+        # If this is a leaf, end the recursion.
+        if not (isinstance(children, list) and len(children) == 2):
+            return node_counter
+
+        left, right = children
+        missing_yes = node.get("missing", -1) == node["yes"]
+        yes_left = left["nodeid"] == node["yes"]
+
+        first = "yes" if yes_left else "no"
+        node[first] = node_counter
+        if (missing_yes and yes_left) or (not missing_yes and not yes_left):
+            node["missing"] = node_counter
+        node_counter = XGBConverter._update_node_ids(left, node_counter)
+
+        second = "no" if yes_left else "yes"
+        node[second] = node_counter
+        if (not missing_yes and yes_left) or (missing_yes and not yes_left):
+            node["missing"] = node_counter
+        node_counter = XGBConverter._update_node_ids(right, node_counter)
+
+        return node_counter
+
+    @staticmethod
+    def _process_root(root: Node) -> None:
+        any_categorical = XGBConverter._process_node(root)
+        if any_categorical:
+            # If any node was categorical, renumber the tree to keep the ids unique.
+            XGBConverter._update_node_ids(root, node_counter=0)
+
+    @staticmethod
+    def _process_categorical_features(js_tree: TreeLike) -> TreeLike:
+        """
+        Rewrites XGBoost's native categorical splits into the equality checks
+        supported by ONNX tree ensembles.
+
+        - A split node whose 'split_condition' lists categories is expanded into a
+          chain of BRANCH_EQ nodes.
+        - Any other split node gets its 'decision_type' set to 'BRANCH_LT'.
+        - If at least one categorical split was found, the node ids are renumbered;
+          the 'depth' field is left untouched because the conversion does not use it.
+
+        For example, a split with split_condition == [1.0, 3.0] becomes two chained
+        BRANCH_EQ tests (feature == 1.0, then feature == 3.0); a match follows the
+        original "yes" subtree and the final mismatch the original "no" subtree.
+
+        Returns the processed tree model.
+        """
+        if isinstance(js_tree, list):
+            for root in js_tree:
+                if isinstance(root, dict):
+                    XGBConverter._process_root(root)
+        elif isinstance(js_tree, dict):
+            XGBConverter._process_root(js_tree)
+        else:
+            raise TypeError(
+                "js_tree must be a dict (single tree) or a list of dicts (forest)."
+            )
+        return js_tree
+
     @staticmethod
     def _get_default_tree_attribute_pairs(is_classifier):
@@ -168,7 +314,9 @@ def _fill_node_attributes(
                 value=jsnode["split_condition"],
                 node_id=remap[jsnode["nodeid"]],
                 feature_id=jsnode["split"],
-                mode="BRANCH_LT",  # 'BRANCH_LEQ' is for sklearn
+                mode=jsnode["decision_type"],  # 'BRANCH_LEQ' is for sklearn
+                # 'BRANCH_LT' is for xgboost numerical features
+                # 'BRANCH_EQ' is for xgboost categorical features
                 true_child_id=remap[jsnode["yes"]],  # ['children'][0]['nodeid'],
                 false_child_id=remap[jsnode["no"]],  # ['children'][1]['nodeid'],
                 weights=None,
diff --git a/tests/xgboost/data_categorical.csv b/tests/xgboost/data_categorical.csv
new file mode 100644
index 00000000..b5d081f4
--- /dev/null
+++ b/tests/xgboost/data_categorical.csv
@@ -0,0 +1,51 @@
+f0,f1,y
+B,0.03260774,1.9839939
+A,0.028049896,0.091320634
+A,0.028272122,-0.006381493
+A,0.055345863,-0.052565098
+A,-0.48156285,-0.15852909
+C,-0.5834075,3.8429787
+C,-0.8621605,3.7296321
+A,-1.4881746,-0.33317634
+C,0.21630684,4.1076612
+C,0.9843764,4.381877
+B,-0.54308414,1.906369
+A,-0.558615,-0.25187373
+C,-0.31648284,3.886167
+A,-0.46063974,-0.2746162
+C,-1.4362698,3.5367992
+B,1.365108,2.4652877
+C,0.4389999,4.0895395
+A,-0.711695,-0.24534294
+C,0.29717177,4.1054583
+B,-0.43845728,1.9078286
+A,-0.21163744,-0.081306905
+A,0.36396384,0.09639855
+B,0.9529645,2.3263385
+B,1.5195241,2.4685602
+A,1.7039094,0.4969085
+A,-0.2488587,-0.05893244
+B,-0.4997486,1.8538816
+C,0.0995975,4.0426106
+C,0.12834321,4.1385646
+B,-0.7342219,1.7644731
+C,-0.6204753,3.7868693
+C,0.8132737,4.314664
+A,1.641801,0.4572553
+C,-0.22650085,4.018044
+C,-0.6479652,3.7959006
+A,-0.2833712,-0.081319384
+C,-0.9951314,3.738142
+C,-0.27287176,3.978794
+A,0.42244413,0.17658317
+B,-0.08134296,1.9654154
+A,1.2345777,0.35205185
+B,0.15088803,2.06264
+A,0.4811195,0.15107182
+C,-0.14875753,3.9232593
+A,1.3156657,0.41523612
+B,-1.2223456,1.6830022
+A,-0.30359134,-0.082752064
+A,-1.1736887,-0.27390662
+B,0.8262735,2.2683973
+B,0.8503223,2.247306
diff --git a/tests/xgboost/test_xgboost_converters.py b/tests/xgboost/test_xgboost_converters.py
index c3f7a696..3798cdab 100644
--- a/tests/xgboost/test_xgboost_converters.py
+++ b/tests/xgboost/test_xgboost_converters.py
@@ -27,6 +27,7 @@
         Booster,
         train as train_xgb,
     )
+    import xgboost
 except Exception:
     XGBRegressor = None
 import sklearn
@@ -832,6 +833,73 @@ def test_xgb_classifier_13_2(self):
         assert_almost_equal(expected[1], got[1])
         assert_almost_equal(expected[0], got[0])
 
+    @unittest.skipIf(XGBRegressor is None, "xgboost is not available")
+    @unittest.skipIf(
+        XGBRegressor is not None
+        and pv.Version(xgboost.__version__) < pv.Version("1.6.0"),
+        "xgboost version < 1.6.0 lacks stable categorical support, skipping test.",
+    )
+    def test_xgb_regressor_categorical_hist(self):
+        this = os.path.dirname(__file__)
+        df = pandas.read_csv(os.path.join(this, "data_categorical.csv"))
+        df["f0"] = df["f0"].astype("category")
+        X, y = df.drop("y", axis=1), df["y"]
+
+        models = [
+            XGBRegressor(
+                objective="reg:squarederror",
+                n_estimators=30,
+                learning_rate=0.3,
+                tree_method="hist",
+                enable_categorical=True,  # turn on native categorical handling
+                random_state=0,
+            ),
+            XGBRegressor(
+                n_estimators=30,
+                enable_categorical=True,  # turn on native categorical handling
+                max_cat_to_onehot=8,  # use native one-hot encoding
+            ),
+            XGBRegressor(
+                n_estimators=100,
+                max_depth=10,
+                enable_categorical=True,  # turn on native categorical handling
+            ),
+        ]
+
+        for idx, model in enumerate(models):
+            model.fit(X, y)
+
+            # Convert to ONNX.
+            # The input has 2 columns: category codes (as numeric) + a numeric feature.
+            onnx_model = convert_xgboost(
+                model,
+                initial_types=[("float_input", FloatTensorType([None, 2]))],
+                target_opset=TARGET_OPSET,
+            )
+
+            # Build the ONNX input:
+            # - first column: category codes cast to float32
+            # - second column: numeric feature
+            cat_codes = X["f0"].cat.codes.to_numpy().astype(np.float32).reshape(-1, 1)
+            num_col = X["f1"].to_numpy().astype(np.float32).reshape(-1, 1)
+            X_onnx = np.concatenate([cat_codes, num_col], axis=1)
+
+            # Compare the XGBoost and ONNX results.
+            expected = model.predict(X).astype(np.float32)
+            sess = InferenceSession(
+                onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]
+            )
+            input_name = sess.get_inputs()[0].name
+            got = sess.run(None, {input_name: X_onnx})[0].ravel().astype(np.float32)
+            assert_almost_equal(expected, got, decimal=4)
+
+            # Test the ONNX backend.
+            dump_data_and_model(
+                X_onnx.astype("float32"),
+                model,
+                onnx_model,
+                basename=f"XGBRegressorCategoricalFeatures{idx}",
+            )
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
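
For review convenience, a minimal sketch (not part of the patch, assuming it is applied) of what the new `_process_categorical_features` helper does to a categorical split. The toy tree below is hypothetical but follows the JSON layout consumed from `Booster.get_dump(dump_format="json")`; it exercises a private helper purely for illustration.

import json

from onnxmltools.convert.xgboost.operator_converters.XGBoost import XGBConverter

# Hypothetical single-tree dump with one categorical split on feature "f0":
# categories 1.0 and 3.0 go to the "yes" leaf, everything else to the "no" leaf.
toy_tree = {
    "nodeid": 0,
    "split": "f0",
    "split_condition": [1.0, 3.0],
    "yes": 1,
    "no": 2,
    "missing": 1,
    "children": [
        {"nodeid": 1, "leaf": 0.5},
        {"nodeid": 2, "leaf": -0.5},
    ],
}

processed = XGBConverter._process_categorical_features([toy_tree])
print(json.dumps(processed, indent=2))
# The categorical split is now a chain of two BRANCH_EQ nodes (f0 == 1.0, then
# f0 == 3.0); each match routes to the "yes" leaf (0.5), the final mismatch routes
# to the "no" leaf (-0.5), and node ids have been renumbered depth-first.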