Skip to content

Commit 41ab5dc

Browse files
akhvorov authored and krinart committed
Fixed missing feature names for XGBoost (#93)
1 parent d788566 commit 41ab5dc

File tree

2 files changed

+43
-2
lines changed

2 files changed

+43
-2
lines changed

m2cgen/assemblers/boosting.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ class XGBoostModelAssembler(BaseBoostingAssembler):
8585
def __init__(self, model):
8686
feature_names = model.get_booster().feature_names
8787
self._feature_name_to_idx = {
88-
name: idx for idx, name in enumerate(feature_names)
88+
name: idx for idx, name in enumerate(feature_names or [])
8989
}
9090

9191
model_dump = model.get_booster().get_dump(dump_format="json")
@@ -103,7 +103,8 @@ def _assemble_tree(self, tree):
103103
return ast.NumVal(tree["leaf"])
104104

105105
threshold = ast.NumVal(tree["split_condition"])
106-
feature_idx = self._feature_name_to_idx[tree["split"]]
106+
split = tree["split"]
107+
feature_idx = self._feature_name_to_idx.get(split, split)
107108
feature_ref = ast.FeatureRef(feature_idx)
108109

109110
# Since comparison with NaN (missing) value always returns false we

tests/assemblers/test_xgboost.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import xgboost
22
import numpy as np
3+
import os
34
from tests import utils
45
from m2cgen import assemblers, ast
56

@@ -228,3 +229,42 @@ def test_multi_class_best_ntree_limit():
228229
])
229230

230231
assert utils.cmp_exprs(actual, expected)
232+
233+
234+
def test_regression_saved_without_feature_names():
235+
base_score = 0.6
236+
estimator = xgboost.XGBRegressor(n_estimators=2, random_state=1,
237+
max_depth=1, base_score=base_score)
238+
utils.train_model_regression(estimator)
239+
240+
with utils.tmp_dir() as tmp_dirpath:
241+
filename = os.path.join(tmp_dirpath, "tmp.file")
242+
estimator.save_model(filename)
243+
estimator = xgboost.XGBRegressor(base_score=base_score)
244+
estimator.load_model(filename)
245+
246+
assembler = assemblers.XGBoostModelAssembler(estimator)
247+
actual = assembler.assemble()
248+
249+
expected = ast.SubroutineExpr(
250+
ast.BinNumExpr(
251+
ast.BinNumExpr(
252+
ast.NumVal(base_score),
253+
ast.IfExpr(
254+
ast.CompExpr(
255+
ast.FeatureRef(12),
256+
ast.NumVal(9.72500038),
257+
ast.CompOpType.GTE),
258+
ast.NumVal(1.67318344),
259+
ast.NumVal(2.92757893)),
260+
ast.BinNumOpType.ADD),
261+
ast.IfExpr(
262+
ast.CompExpr(
263+
ast.FeatureRef(5),
264+
ast.NumVal(6.94099998),
265+
ast.CompOpType.GTE),
266+
ast.NumVal(3.3400948),
267+
ast.NumVal(1.72118247)),
268+
ast.BinNumOpType.ADD))
269+
270+
assert utils.cmp_exprs(actual, expected)

0 commit comments

Comments
 (0)