Skip to content

Commit 65c07eb

Browse files
authored
Limit the number of leaves in each subroutine for gradient boosted trees (#123)
Fixes #103
1 parent 5ae584d commit 65c07eb

File tree

5 files changed

+269
-11
lines changed

5 files changed

+269
-11
lines changed

m2cgen/assemblers/boosting.py

Lines changed: 81 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@ class BaseBoostingAssembler(ModelAssembler):
99

1010
classifier_name = None
1111

12-
def __init__(self, model, trees, base_score=0, tree_limit=None):
12+
def __init__(self, model, trees, base_score=0, tree_limit=None,
13+
leaves_cutoff_threshold=3000):
1314
super().__init__(model)
1415
self.all_trees = trees
1516
self._base_score = base_score
17+
self._leaves_cutoff_threshold = leaves_cutoff_threshold
1618

1719
self._output_size = 1
1820
self._is_classification = False
@@ -41,10 +43,19 @@ def _assemble_single_output(self, trees, base_score=0):
4143
trees = trees[:self._tree_limit]
4244

4345
trees_ast = [self._assemble_tree(t) for t in trees]
46+
to_sum = trees_ast
47+
48+
# In a large tree we need to generate multiple subroutines to avoid
49+
# java limitations https://github.com/BayesWitnesses/m2cgen/issues/103.
50+
trees_num_leaves = [self._count_leaves(t) for t in trees]
51+
if sum(trees_num_leaves) > self._leaves_cutoff_threshold:
52+
to_sum = self._split_into_subroutines(trees_ast, trees_num_leaves)
53+
4454
result_ast = utils.apply_op_to_expressions(
4555
ast.BinNumOpType.ADD,
4656
ast.NumVal(base_score),
47-
*trees_ast)
57+
*to_sum)
58+
4859
return ast.SubroutineExpr(result_ast)
4960

5061
def _assemble_multi_class_output(self, trees):
@@ -74,15 +85,47 @@ def _assemble_bin_class_output(self, trees):
7485
proba_expr
7586
])
7687

88+
def _split_into_subroutines(self, trees_ast, trees_num_leaves):
89+
result = []
90+
subroutine_trees = []
91+
subroutine_sum_leaves = 0
92+
for tree, num_leaves in zip(trees_ast, trees_num_leaves):
93+
next_sum = subroutine_sum_leaves + num_leaves
94+
if subroutine_trees and next_sum > self._leaves_cutoff_threshold:
95+
# Exceeded the max leaves in the current subroutine,
96+
# finalize this one and start a new one.
97+
partial_result = utils.apply_op_to_expressions(
98+
ast.BinNumOpType.ADD,
99+
*subroutine_trees)
100+
101+
result.append(ast.SubroutineExpr(partial_result))
102+
103+
subroutine_trees = []
104+
subroutine_sum_leaves = 0
105+
106+
subroutine_sum_leaves += num_leaves
107+
subroutine_trees.append(tree)
108+
109+
if subroutine_trees:
110+
partial_result = utils.apply_op_to_expressions(
111+
ast.BinNumOpType.ADD,
112+
*subroutine_trees)
113+
result.append(ast.SubroutineExpr(partial_result))
114+
return result
115+
77116
def _assemble_tree(self, tree):
78117
raise NotImplementedError
79118

119+
@staticmethod
120+
def _count_leaves(trees):
121+
raise NotImplementedError
122+
80123

81124
class XGBoostModelAssembler(BaseBoostingAssembler):
82125

83126
classifier_name = "XGBClassifier"
84127

85-
def __init__(self, model):
128+
def __init__(self, model, leaves_cutoff_threshold=3000):
86129
feature_names = model.get_booster().feature_names
87130
self._feature_name_to_idx = {
88131
name: idx for idx, name in enumerate(feature_names or [])
@@ -96,7 +139,8 @@ def __init__(self, model):
96139
best_ntree_limit = getattr(model, "best_ntree_limit", None)
97140

98141
super().__init__(model, trees, base_score=model.base_score,
99-
tree_limit=best_ntree_limit)
142+
tree_limit=best_ntree_limit,
143+
leaves_cutoff_threshold=leaves_cutoff_threshold)
100144

101145
def _assemble_tree(self, tree):
102146
if "leaf" in tree:
@@ -130,16 +174,31 @@ def _assemble_child_tree(self, tree, child_id):
130174
return self._assemble_tree(child)
131175
assert False, "Unexpected child ID {}".format(child_id)
132176

177+
@staticmethod
178+
def _count_leaves(tree):
179+
queue = [tree]
180+
num_leaves = 0
181+
182+
while queue:
183+
tree = queue.pop()
184+
if "leaf" in tree:
185+
num_leaves += 1
186+
elif "children" in tree:
187+
for child in tree["children"]:
188+
queue.append(child)
189+
return num_leaves
190+
133191

134192
class LightGBMModelAssembler(BaseBoostingAssembler):
135193

136194
classifier_name = "LGBMClassifier"
137195

138-
def __init__(self, model):
196+
def __init__(self, model, leaves_cutoff_threshold=3000):
139197
model_dump = model.booster_.dump_model()
140198
trees = [m["tree_structure"] for m in model_dump["tree_info"]]
141199

142-
super().__init__(model, trees)
200+
super().__init__(model, trees,
201+
leaves_cutoff_threshold=leaves_cutoff_threshold)
143202

144203
def _assemble_tree(self, tree):
145204
if "leaf_value" in tree:
@@ -151,9 +210,9 @@ def _assemble_tree(self, tree):
151210
op = ast.CompOpType.from_str_op(tree["decision_type"])
152211
assert op == ast.CompOpType.LTE, "Unexpected comparison op"
153212

154-
# Make sure that if the 'default_left' is true the left tree branch
213+
# Make sure that if the "default_left" is true the left tree branch
155214
# ends up in the "else" branch of the ast.IfExpr.
156-
if tree['default_left']:
215+
if tree["default_left"]:
157216
op = ast.CompOpType.GT
158217
true_child = tree["right_child"]
159218
false_child = tree["left_child"]
@@ -166,6 +225,20 @@ def _assemble_tree(self, tree):
166225
self._assemble_tree(true_child),
167226
self._assemble_tree(false_child))
168227

228+
@staticmethod
229+
def _count_leaves(tree):
230+
queue = [tree]
231+
num_leaves = 0
232+
233+
while queue:
234+
tree = queue.pop()
235+
if "leaf_value" in tree:
236+
num_leaves += 1
237+
else:
238+
queue.append(tree["left_child"])
239+
queue.append(tree["right_child"])
240+
return num_leaves
241+
169242

170243
def _split_trees_by_classes(trees, n_classes):
171244
# Splits are computed based on a comment

tests/assemblers/test_lightgbm.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,53 @@ def test_regression():
110110
ast.BinNumOpType.ADD))
111111

112112
assert utils.cmp_exprs(actual, expected)
113+
114+
115+
def test_leaves_cutoff_threshold():
116+
estimator = lightgbm.LGBMClassifier(n_estimators=2, random_state=1,
117+
max_depth=1)
118+
utils.train_model_classification_binary(estimator)
119+
120+
assembler = assemblers.LightGBMModelAssembler(estimator,
121+
leaves_cutoff_threshold=1)
122+
actual = assembler.assemble()
123+
124+
sigmoid = ast.BinNumExpr(
125+
ast.NumVal(1),
126+
ast.BinNumExpr(
127+
ast.NumVal(1),
128+
ast.ExpExpr(
129+
ast.BinNumExpr(
130+
ast.NumVal(0),
131+
ast.SubroutineExpr(
132+
ast.BinNumExpr(
133+
ast.BinNumExpr(
134+
ast.NumVal(0),
135+
ast.SubroutineExpr(
136+
ast.IfExpr(
137+
ast.CompExpr(
138+
ast.FeatureRef(23),
139+
ast.NumVal(868.2000000000002),
140+
ast.CompOpType.GT),
141+
ast.NumVal(0.2762557140263451),
142+
ast.NumVal(0.6399134166614473))),
143+
ast.BinNumOpType.ADD),
144+
ast.SubroutineExpr(
145+
ast.IfExpr(
146+
ast.CompExpr(
147+
ast.FeatureRef(27),
148+
ast.NumVal(0.14205000000000004),
149+
ast.CompOpType.GT),
150+
ast.NumVal(-0.2139321843285849),
151+
ast.NumVal(0.1151466338793227))),
152+
ast.BinNumOpType.ADD)),
153+
ast.BinNumOpType.SUB)),
154+
ast.BinNumOpType.ADD),
155+
ast.BinNumOpType.DIV,
156+
to_reuse=True)
157+
158+
expected = ast.VectorVal([
159+
ast.BinNumExpr(ast.NumVal(1), sigmoid, ast.BinNumOpType.SUB),
160+
sigmoid])
161+
162+
assert utils.cmp_exprs(actual, expected)

tests/assemblers/test_xgboost.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,3 +268,53 @@ def test_regression_saved_without_feature_names():
268268
ast.BinNumOpType.ADD))
269269

270270
assert utils.cmp_exprs(actual, expected)
271+
272+
273+
def test_leaves_cutoff_threshold():
274+
estimator = xgboost.XGBClassifier(n_estimators=2, random_state=1,
275+
max_depth=1)
276+
utils.train_model_classification_binary(estimator)
277+
278+
assembler = assemblers.XGBoostModelAssembler(estimator,
279+
leaves_cutoff_threshold=1)
280+
actual = assembler.assemble()
281+
282+
sigmoid = ast.BinNumExpr(
283+
ast.NumVal(1),
284+
ast.BinNumExpr(
285+
ast.NumVal(1),
286+
ast.ExpExpr(
287+
ast.BinNumExpr(
288+
ast.NumVal(0),
289+
ast.SubroutineExpr(
290+
ast.BinNumExpr(
291+
ast.BinNumExpr(
292+
ast.NumVal(-0.0),
293+
ast.SubroutineExpr(
294+
ast.IfExpr(
295+
ast.CompExpr(
296+
ast.FeatureRef(20),
297+
ast.NumVal(16.7950001),
298+
ast.CompOpType.GTE),
299+
ast.NumVal(-0.17062147),
300+
ast.NumVal(0.1638484))),
301+
ast.BinNumOpType.ADD),
302+
ast.SubroutineExpr(
303+
ast.IfExpr(
304+
ast.CompExpr(
305+
ast.FeatureRef(27),
306+
ast.NumVal(0.142349988),
307+
ast.CompOpType.GTE),
308+
ast.NumVal(-0.16087772),
309+
ast.NumVal(0.149866998))),
310+
ast.BinNumOpType.ADD)),
311+
ast.BinNumOpType.SUB)),
312+
ast.BinNumOpType.ADD),
313+
ast.BinNumOpType.DIV,
314+
to_reuse=True)
315+
316+
expected = ast.VectorVal([
317+
ast.BinNumExpr(ast.NumVal(1), sigmoid, ast.BinNumOpType.SUB),
318+
sigmoid])
319+
320+
assert utils.cmp_exprs(actual, expected)

tests/e2e/test_e2e.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,30 @@ def classification_binary(model):
5151
)
5252

5353

54+
def regression_random(model):
55+
return (
56+
model,
57+
utils.train_model_regression_random_data,
58+
REGRESSION,
59+
)
60+
61+
62+
def classification_random(model):
63+
return (
64+
model,
65+
utils.train_model_classification_random_data,
66+
CLASSIFICATION,
67+
)
68+
69+
70+
def classification_binary_random(model):
71+
return (
72+
model,
73+
utils.train_model_classification_binary_random_data,
74+
CLASSIFICATION,
75+
)
76+
77+
5478
# Absolute tolerance. Used in np.isclose to compare 2 values.
5579
# We compare 6 decimal digits.
5680
ATOL = 1.e-6
@@ -63,6 +87,11 @@ def classification_binary(model):
6387
LIGHT_GBM_PARAMS = dict(n_estimators=10, random_state=RANDOM_SEED)
6488
SVC_PARAMS = dict(random_state=RANDOM_SEED, decision_function_shape="ovo")
6589

90+
XGBOOST_PARAMS_LARGE = dict(base_score=0.6, n_estimators=100, max_depth=12,
91+
random_state=RANDOM_SEED)
92+
LIGHT_GBM_PARAMS_LARGE = dict(n_estimators=100, num_leaves=100, max_depth=64,
93+
random_state=RANDOM_SEED)
94+
6695

6796
@utils.cartesian_e2e_params(
6897
# These are the languages which support all models specified in the
@@ -85,11 +114,27 @@ def classification_binary(model):
85114
classification(lightgbm.LGBMClassifier(**LIGHT_GBM_PARAMS)),
86115
classification_binary(lightgbm.LGBMClassifier(**LIGHT_GBM_PARAMS)),
87116
117+
# LightGBM (Large Trees)
118+
regression_random(
119+
lightgbm.LGBMRegressor(**LIGHT_GBM_PARAMS_LARGE)),
120+
classification_random(
121+
lightgbm.LGBMClassifier(**LIGHT_GBM_PARAMS_LARGE)),
122+
classification_binary_random(
123+
lightgbm.LGBMClassifier(**LIGHT_GBM_PARAMS_LARGE)),
124+
88125
# XGBoost
89126
regression(xgboost.XGBRegressor(**XGBOOST_PARAMS)),
90127
classification(xgboost.XGBClassifier(**XGBOOST_PARAMS)),
91128
classification_binary(xgboost.XGBClassifier(**XGBOOST_PARAMS)),
92129
130+
# XGBoost (Large Trees)
131+
regression_random(
132+
xgboost.XGBRegressor(**XGBOOST_PARAMS_LARGE)),
133+
classification_random(
134+
xgboost.XGBClassifier(**XGBOOST_PARAMS_LARGE)),
135+
classification_binary_random(
136+
xgboost.XGBClassifier(**XGBOOST_PARAMS_LARGE)),
137+
93138
# Linear SVM
94139
regression(svm.LinearSVR(random_state=RANDOM_SEED)),
95140
classification(svm.LinearSVC(random_state=RANDOM_SEED)),

0 commit comments

Comments (0)