Skip to content

Commit c2b6b59

Browse files
authored
Added support for remaining boosting types in lightgbm (#212)
1 parent 1eb1054 commit c2b6b59

File tree

1 file changed

+55
-5
lines changed
  • onnxmltools/convert/lightgbm/operator_converters

1 file changed

+55
-5
lines changed

onnxmltools/convert/lightgbm/operator_converters/LightGbm.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
import numpy as np
1010
from collections import Counter
1111
from lightgbm import LGBMClassifier, LGBMRegressor
12+
from ...common._apply_operation import apply_div, apply_reshape, apply_sub
1213
from ...common._registration import register_converter
1314
from ...common.tree_ensemble import get_default_tree_classifier_attribute_pairs
15+
from ....proto import onnx_proto
1416

1517

1618
def _translate_split_criterion(criterion):
@@ -114,8 +116,6 @@ def _parse_node(tree_id, class_id, node_id, node_id_pool, learning_rate, node, a
114116

115117
def convert_lightgbm(scope, operator, container):
116118
gbm_model = operator.raw_operator
117-
if gbm_model.boosting_type != 'gbdt':
118-
raise ValueError('Only support LightGBM classifier with boosting_type=gbdt')
119119
gbm_text = gbm_model.booster_.dump_model()
120120

121121
attrs = get_default_tree_classifier_attribute_pairs()
@@ -161,8 +161,10 @@ def convert_lightgbm(scope, operator, container):
161161
# Create ONNX object
162162
if isinstance(gbm_model, LGBMClassifier):
163163
# Prepare label information for both of TreeEnsembleClassifier and ZipMap
164+
class_type = onnx_proto.TensorProto.STRING
164165
zipmap_attrs = {'name': scope.get_unique_variable_name('ZipMap')}
165166
if all(isinstance(i, (numbers.Real, bool, np.bool_)) for i in gbm_model.classes_):
167+
class_type = onnx_proto.TensorProto.INT64
166168
class_labels = [int(i) for i in gbm_model.classes_]
167169
attrs['classlabels_int64s'] = class_labels
168170
zipmap_attrs['classlabels_int64s'] = class_labels
@@ -175,23 +177,71 @@ def convert_lightgbm(scope, operator, container):
175177

176178
# Create tree classifier
177179
probability_tensor_name = scope.get_unique_variable_name('probability_tensor')
180+
label_tensor_name = scope.get_unique_variable_name('label_tensor')
181+
178182
container.add_node('TreeEnsembleClassifier', operator.input_full_names,
179-
[operator.outputs[0].full_name, probability_tensor_name],
183+
[label_tensor_name, probability_tensor_name],
180184
op_domain='ai.onnx.ml', **attrs)
185+
prob_tensor = probability_tensor_name
186+
187+
if gbm_model.boosting_type == 'rf':
188+
col_index_name = scope.get_unique_variable_name('col_index')
189+
first_col_name = scope.get_unique_variable_name('first_col')
190+
zeroth_col_name = scope.get_unique_variable_name('zeroth_col')
191+
denominator_name = scope.get_unique_variable_name('denominator')
192+
modified_first_col_name = scope.get_unique_variable_name('modified_first_col')
193+
unit_float_tensor_name = scope.get_unique_variable_name('unit_float_tensor')
194+
merged_prob_name = scope.get_unique_variable_name('merged_prob')
195+
predicted_label_name = scope.get_unique_variable_name('predicted_label')
196+
classes_name = scope.get_unique_variable_name('classes')
197+
final_label_name = scope.get_unique_variable_name('final_label')
198+
199+
container.add_initializer(col_index_name, onnx_proto.TensorProto.INT64, [], [1])
200+
container.add_initializer(unit_float_tensor_name, onnx_proto.TensorProto.FLOAT, [], [1.0])
201+
container.add_initializer(denominator_name, onnx_proto.TensorProto.FLOAT, [], [100.0])
202+
container.add_initializer(classes_name, class_type,
203+
[len(class_labels)], class_labels)
204+
205+
container.add_node('ArrayFeatureExtractor', [probability_tensor_name, col_index_name],
206+
first_col_name, name=scope.get_unique_operator_name('ArrayFeatureExtractor'),
207+
op_domain='ai.onnx.ml')
208+
apply_div(scope, [first_col_name, denominator_name], modified_first_col_name, container, broadcast=1)
209+
apply_sub(scope, [unit_float_tensor_name, modified_first_col_name], zeroth_col_name, container, broadcast=1)
210+
container.add_node('Concat', [zeroth_col_name, modified_first_col_name],
211+
merged_prob_name, name=scope.get_unique_operator_name('Concat'), axis=1)
212+
container.add_node('ArgMax', merged_prob_name,
213+
predicted_label_name, name=scope.get_unique_operator_name('ArgMax'), axis=1)
214+
container.add_node('ArrayFeatureExtractor', [classes_name, predicted_label_name], final_label_name,
215+
name=scope.get_unique_operator_name('ArrayFeatureExtractor'), op_domain='ai.onnx.ml')
216+
apply_reshape(scope, final_label_name, operator.outputs[0].full_name, container, desired_shape=[-1,])
217+
prob_tensor = merged_prob_name
218+
else:
219+
container.add_node('Identity', label_tensor_name, operator.outputs[0].full_name)
181220

182221
# Convert probability tensor to probability map (keys are labels while values are the associated probabilities)
183-
container.add_node('ZipMap', probability_tensor_name, operator.outputs[1].full_name,
222+
container.add_node('ZipMap', prob_tensor, operator.outputs[1].full_name,
184223
op_domain='ai.onnx.ml', **zipmap_attrs)
185224
else:
186225
# Create tree regressor
226+
output_name = scope.get_unique_variable_name('output')
227+
187228
keys_to_be_renamed = list(k for k in attrs.keys() if k.startswith('class_'))
188229
for k in keys_to_be_renamed:
189230
# Rename class_* attribute to target_* because TreeEnsebmleClassifier and TreeEnsembleClassifier have
190231
# different ONNX attributes
191232
attrs['target' + k[5:]] = copy.deepcopy(attrs[k])
192233
del attrs[k]
193234
container.add_node('TreeEnsembleRegressor', operator.input_full_names,
194-
operator.output_full_names, op_domain='ai.onnx.ml', **attrs)
235+
output_name, op_domain='ai.onnx.ml', **attrs)
236+
237+
if gbm_model.boosting_type == 'rf':
238+
denominator_name = scope.get_unique_variable_name('denominator')
239+
240+
container.add_initializer(denominator_name, onnx_proto.TensorProto.FLOAT, [], [100.0])
241+
242+
apply_div(scope, [output_name, denominator_name], operator.output_full_names, container, broadcast=1)
243+
else:
244+
container.add_node('Identity', output_name, operator.output_full_names)
195245

196246

197247
register_converter('LgbmClassifier', convert_lightgbm)

0 commit comments

Comments
 (0)