@@ -9,10 +9,12 @@ class BaseBoostingAssembler(ModelAssembler):
99
1010 classifier_name = None
1111
12- def __init__ (self , model , trees , base_score = 0 , tree_limit = None ):
12+ def __init__ (self , model , trees , base_score = 0 , tree_limit = None ,
13+ leaves_cutoff_threshold = 3000 ):
1314 super ().__init__ (model )
1415 self .all_trees = trees
1516 self ._base_score = base_score
17+ self ._leaves_cutoff_threshold = leaves_cutoff_threshold
1618
1719 self ._output_size = 1
1820 self ._is_classification = False
@@ -41,10 +43,19 @@ def _assemble_single_output(self, trees, base_score=0):
4143 trees = trees [:self ._tree_limit ]
4244
4345 trees_ast = [self ._assemble_tree (t ) for t in trees ]
46+ to_sum = trees_ast
47+
48+ # In a large tree we need to generate multiple subroutines to avoid
49+ # java limitations https://github.com/BayesWitnesses/m2cgen/issues/103.
50+ trees_num_leaves = [self ._count_leaves (t ) for t in trees ]
51+ if sum (trees_num_leaves ) > self ._leaves_cutoff_threshold :
52+ to_sum = self ._split_into_subroutines (trees_ast , trees_num_leaves )
53+
4454 result_ast = utils .apply_op_to_expressions (
4555 ast .BinNumOpType .ADD ,
4656 ast .NumVal (base_score ),
47- * trees_ast )
57+ * to_sum )
58+
4859 return ast .SubroutineExpr (result_ast )
4960
5061 def _assemble_multi_class_output (self , trees ):
@@ -74,15 +85,47 @@ def _assemble_bin_class_output(self, trees):
7485 proba_expr
7586 ])
7687
88+ def _split_into_subroutines (self , trees_ast , trees_num_leaves ):
89+ result = []
90+ subroutine_trees = []
91+ subroutine_sum_leaves = 0
92+ for tree , num_leaves in zip (trees_ast , trees_num_leaves ):
93+ next_sum = subroutine_sum_leaves + num_leaves
94+ if subroutine_trees and next_sum > self ._leaves_cutoff_threshold :
95+ # Exceeded the max leaves in the current subroutine,
96+ # finalize this one and start a new one.
97+ partial_result = utils .apply_op_to_expressions (
98+ ast .BinNumOpType .ADD ,
99+ * subroutine_trees )
100+
101+ result .append (ast .SubroutineExpr (partial_result ))
102+
103+ subroutine_trees = []
104+ subroutine_sum_leaves = 0
105+
106+ subroutine_sum_leaves += num_leaves
107+ subroutine_trees .append (tree )
108+
109+ if subroutine_trees :
110+ partial_result = utils .apply_op_to_expressions (
111+ ast .BinNumOpType .ADD ,
112+ * subroutine_trees )
113+ result .append (ast .SubroutineExpr (partial_result ))
114+ return result
115+
77116 def _assemble_tree (self , tree ):
78117 raise NotImplementedError
79118
119+ @staticmethod
120+ def _count_leaves (trees ):
121+ raise NotImplementedError
122+
80123
81124class XGBoostModelAssembler (BaseBoostingAssembler ):
82125
83126 classifier_name = "XGBClassifier"
84127
85- def __init__ (self , model ):
128+ def __init__ (self , model , leaves_cutoff_threshold = 3000 ):
86129 feature_names = model .get_booster ().feature_names
87130 self ._feature_name_to_idx = {
88131 name : idx for idx , name in enumerate (feature_names or [])
@@ -96,7 +139,8 @@ def __init__(self, model):
96139 best_ntree_limit = getattr (model , "best_ntree_limit" , None )
97140
98141 super ().__init__ (model , trees , base_score = model .base_score ,
99- tree_limit = best_ntree_limit )
142+ tree_limit = best_ntree_limit ,
143+ leaves_cutoff_threshold = leaves_cutoff_threshold )
100144
101145 def _assemble_tree (self , tree ):
102146 if "leaf" in tree :
@@ -130,16 +174,31 @@ def _assemble_child_tree(self, tree, child_id):
130174 return self ._assemble_tree (child )
131175 assert False , "Unexpected child ID {}" .format (child_id )
132176
177+ @staticmethod
178+ def _count_leaves (tree ):
179+ queue = [tree ]
180+ num_leaves = 0
181+
182+ while queue :
183+ tree = queue .pop ()
184+ if "leaf" in tree :
185+ num_leaves += 1
186+ elif "children" in tree :
187+ for child in tree ["children" ]:
188+ queue .append (child )
189+ return num_leaves
190+
133191
134192class LightGBMModelAssembler (BaseBoostingAssembler ):
135193
136194 classifier_name = "LGBMClassifier"
137195
138- def __init__ (self , model ):
196+ def __init__ (self , model , leaves_cutoff_threshold = 3000 ):
139197 model_dump = model .booster_ .dump_model ()
140198 trees = [m ["tree_structure" ] for m in model_dump ["tree_info" ]]
141199
142- super ().__init__ (model , trees )
200+ super ().__init__ (model , trees ,
201+ leaves_cutoff_threshold = leaves_cutoff_threshold )
143202
144203 def _assemble_tree (self , tree ):
145204 if "leaf_value" in tree :
@@ -151,9 +210,9 @@ def _assemble_tree(self, tree):
151210 op = ast .CompOpType .from_str_op (tree ["decision_type" ])
152211 assert op == ast .CompOpType .LTE , "Unexpected comparison op"
153212
154- # Make sure that if the ' default_left' is true the left tree branch
213+ # Make sure that if the " default_left" is true the left tree branch
155214 # ends up in the "else" branch of the ast.IfExpr.
156- if tree [' default_left' ]:
215+ if tree [" default_left" ]:
157216 op = ast .CompOpType .GT
158217 true_child = tree ["right_child" ]
159218 false_child = tree ["left_child" ]
@@ -166,6 +225,20 @@ def _assemble_tree(self, tree):
166225 self ._assemble_tree (true_child ),
167226 self ._assemble_tree (false_child ))
168227
228+ @staticmethod
229+ def _count_leaves (tree ):
230+ queue = [tree ]
231+ num_leaves = 0
232+
233+ while queue :
234+ tree = queue .pop ()
235+ if "leaf_value" in tree :
236+ num_leaves += 1
237+ else :
238+ queue .append (tree ["left_child" ])
239+ queue .append (tree ["right_child" ])
240+ return num_leaves
241+
169242
170243def _split_trees_by_classes (trees , n_classes ):
171244 # Splits are computed based on a comment
0 commit comments