99import numpy as np
1010from collections import Counter
1111from lightgbm import LGBMClassifier , LGBMRegressor
12+ from ...common ._apply_operation import apply_div , apply_reshape , apply_sub
1213from ...common ._registration import register_converter
1314from ...common .tree_ensemble import get_default_tree_classifier_attribute_pairs
15+ from ....proto import onnx_proto
1416
1517
1618def _translate_split_criterion (criterion ):
@@ -114,8 +116,6 @@ def _parse_node(tree_id, class_id, node_id, node_id_pool, learning_rate, node, a
114116
115117def convert_lightgbm (scope , operator , container ):
116118 gbm_model = operator .raw_operator
117- if gbm_model .boosting_type != 'gbdt' :
118- raise ValueError ('Only support LightGBM classifier with boosting_type=gbdt' )
119119 gbm_text = gbm_model .booster_ .dump_model ()
120120
121121 attrs = get_default_tree_classifier_attribute_pairs ()
@@ -161,8 +161,10 @@ def convert_lightgbm(scope, operator, container):
161161 # Create ONNX object
162162 if isinstance (gbm_model , LGBMClassifier ):
163163 # Prepare label information for both of TreeEnsembleClassifier and ZipMap
164+ class_type = onnx_proto .TensorProto .STRING
164165 zipmap_attrs = {'name' : scope .get_unique_variable_name ('ZipMap' )}
165166 if all (isinstance (i , (numbers .Real , bool , np .bool_ )) for i in gbm_model .classes_ ):
167+ class_type = onnx_proto .TensorProto .INT64
166168 class_labels = [int (i ) for i in gbm_model .classes_ ]
167169 attrs ['classlabels_int64s' ] = class_labels
168170 zipmap_attrs ['classlabels_int64s' ] = class_labels
@@ -175,23 +177,71 @@ def convert_lightgbm(scope, operator, container):
175177
176178 # Create tree classifier
177179 probability_tensor_name = scope .get_unique_variable_name ('probability_tensor' )
180+ label_tensor_name = scope .get_unique_variable_name ('label_tensor' )
181+
178182 container .add_node ('TreeEnsembleClassifier' , operator .input_full_names ,
179- [operator . outputs [ 0 ]. full_name , probability_tensor_name ],
183+ [label_tensor_name , probability_tensor_name ],
180184 op_domain = 'ai.onnx.ml' , ** attrs )
185+ prob_tensor = probability_tensor_name
186+
187+ if gbm_model .boosting_type == 'rf' :
188+ col_index_name = scope .get_unique_variable_name ('col_index' )
189+ first_col_name = scope .get_unique_variable_name ('first_col' )
190+ zeroth_col_name = scope .get_unique_variable_name ('zeroth_col' )
191+ denominator_name = scope .get_unique_variable_name ('denominator' )
192+ modified_first_col_name = scope .get_unique_variable_name ('modified_first_col' )
193+ unit_float_tensor_name = scope .get_unique_variable_name ('unit_float_tensor' )
194+ merged_prob_name = scope .get_unique_variable_name ('merged_prob' )
195+ predicted_label_name = scope .get_unique_variable_name ('predicted_label' )
196+ classes_name = scope .get_unique_variable_name ('classes' )
197+ final_label_name = scope .get_unique_variable_name ('final_label' )
198+
199+ container .add_initializer (col_index_name , onnx_proto .TensorProto .INT64 , [], [1 ])
200+ container .add_initializer (unit_float_tensor_name , onnx_proto .TensorProto .FLOAT , [], [1.0 ])
201+ container .add_initializer (denominator_name , onnx_proto .TensorProto .FLOAT , [], [100.0 ])
202+ container .add_initializer (classes_name , class_type ,
203+ [len (class_labels )], class_labels )
204+
205+ container .add_node ('ArrayFeatureExtractor' , [probability_tensor_name , col_index_name ],
206+ first_col_name , name = scope .get_unique_operator_name ('ArrayFeatureExtractor' ),
207+ op_domain = 'ai.onnx.ml' )
208+ apply_div (scope , [first_col_name , denominator_name ], modified_first_col_name , container , broadcast = 1 )
209+ apply_sub (scope , [unit_float_tensor_name , modified_first_col_name ], zeroth_col_name , container , broadcast = 1 )
210+ container .add_node ('Concat' , [zeroth_col_name , modified_first_col_name ],
211+ merged_prob_name , name = scope .get_unique_operator_name ('Concat' ), axis = 1 )
212+ container .add_node ('ArgMax' , merged_prob_name ,
213+ predicted_label_name , name = scope .get_unique_operator_name ('ArgMax' ), axis = 1 )
214+ container .add_node ('ArrayFeatureExtractor' , [classes_name , predicted_label_name ], final_label_name ,
215+ name = scope .get_unique_operator_name ('ArrayFeatureExtractor' ), op_domain = 'ai.onnx.ml' )
216+ apply_reshape (scope , final_label_name , operator .outputs [0 ].full_name , container , desired_shape = [- 1 ,])
217+ prob_tensor = merged_prob_name
218+ else :
219+ container .add_node ('Identity' , label_tensor_name , operator .outputs [0 ].full_name )
181220
182221 # Convert probability tensor to probability map (keys are labels while values are the associated probabilities)
183- container .add_node ('ZipMap' , probability_tensor_name , operator .outputs [1 ].full_name ,
222+ container .add_node ('ZipMap' , prob_tensor , operator .outputs [1 ].full_name ,
184223 op_domain = 'ai.onnx.ml' , ** zipmap_attrs )
185224 else :
186225 # Create tree regressor
226+ output_name = scope .get_unique_variable_name ('output' )
227+
187228 keys_to_be_renamed = list (k for k in attrs .keys () if k .startswith ('class_' ))
188229 for k in keys_to_be_renamed :
189230 # Rename class_* attribute to target_* because TreeEnsembleRegressor and TreeEnsembleClassifier have
190231 # different ONNX attributes
191232 attrs ['target' + k [5 :]] = copy .deepcopy (attrs [k ])
192233 del attrs [k ]
193234 container .add_node ('TreeEnsembleRegressor' , operator .input_full_names ,
194- operator .output_full_names , op_domain = 'ai.onnx.ml' , ** attrs )
235+ output_name , op_domain = 'ai.onnx.ml' , ** attrs )
236+
237+ if gbm_model .boosting_type == 'rf' :
238+ denominator_name = scope .get_unique_variable_name ('denominator' )
239+
240+ container .add_initializer (denominator_name , onnx_proto .TensorProto .FLOAT , [], [100.0 ])
241+
242+ apply_div (scope , [output_name , denominator_name ], operator .output_full_names , container , broadcast = 1 )
243+ else :
244+ container .add_node ('Identity' , output_name , operator .output_full_names )
195245
196246
197247register_converter ('LgbmClassifier' , convert_lightgbm )
0 commit comments