33import copy
44import numbers
55import numpy as np
6+ import onnx
67from collections import Counter
78from ...common ._apply_operation import (
89 apply_div , apply_reshape , apply_sub , apply_cast , apply_identity , apply_clip )
910from ...common ._registration import register_converter
1011from ...common .tree_ensemble import get_default_tree_classifier_attribute_pairs
1112from ....proto import onnx_proto
13+ from onnxconverter_common .container import ModelComponentContainer
1214
1315
1416def _translate_split_criterion (criterion ):
@@ -222,6 +224,7 @@ def convert_lightgbm(scope, operator, container):
222224
223225 # Create different attributes for classifier and
224226 # regressor, respectively
227+ post_transform = None
225228 if gbm_text ['objective' ].startswith ('binary' ):
226229 n_classes = 1
227230 attrs ['post_transform' ] = 'LOGISTIC'
@@ -232,6 +235,13 @@ def convert_lightgbm(scope, operator, container):
232235 n_classes = 1 # Regressor has only one output variable
233236 attrs ['post_transform' ] = 'NONE'
234237 attrs ['n_targets' ] = n_classes
238+ elif gbm_text ['objective' ].startswith ('poisson' ):
239+ n_classes = 1 # Regressor has only one output variable
240+ attrs ['n_targets' ] = n_classes
241+ # 'Exp' is not a supported post_transform value in the ONNX spec yet,
242+ # so we need to add an 'Exp' post transform node to the model
243+ attrs ['post_transform' ] = 'NONE'
244+ post_transform = "Exp"
235245 else :
236246 raise RuntimeError (
237247 "LightGBM objective should be cleaned already not '{}'." .format (
@@ -392,6 +402,22 @@ def convert_lightgbm(scope, operator, container):
392402 container .add_node ('Identity' , output_name ,
393403 operator .output_full_names ,
394404 name = scope .get_unique_operator_name ('Identity' ))
405+ if post_transform :
406+ _add_post_transform_node (container , post_transform )
407+
408+
409+ def _add_post_transform_node (container : ModelComponentContainer , op_type : str ):
410+ """
411+ Add a post transform node to a ModelComponentContainer.
412+
413+ Useful for post transform functions that are not supported by the ONNX spec yet (e.g. 'Exp').
414+ """
415+ assert len (container .outputs ) == 1 , "Adding a post transform node is only possible for models with 1 output."
416+ original_output_name = container .outputs [0 ].name
417+ new_output_name = f"{ op_type .lower ()} _{ original_output_name } "
418+ post_transform_node = onnx .helper .make_node (op_type , inputs = [original_output_name ], outputs = [new_output_name ])
419+ container .nodes .append (post_transform_node )
420+ container .outputs [0 ].name = new_output_name
395421
396422
397423def modify_tree_for_rule_in_set (gbm , use_float = False ):
0 commit comments