Skip to content

Commit fd044d7

Browse files
authored
Allow to add custom post transform functions that are not supported by the ONNX spec yet (#463)
* add _add_post_transform_node function to allow for custom post transform nodes that are not supported by the ONNX spec yet Signed-off-by: Jan-Benedikt Jagusch <[email protected]> * add 'poisson' support also to WrappedBooster Signed-off-by: Jan-Benedikt Jagusch <[email protected]>
1 parent 9542999 commit fd044d7

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

onnxmltools/convert/lightgbm/_parse.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ def __init__(self, booster):
2727
if (_model_dict['objective'].startswith('binary') or
2828
_model_dict['objective'].startswith('multiclass')):
2929
self.operator_name = 'LgbmClassifier'
30-
elif _model_dict['objective'].startswith('regression'):
30+
elif (_model_dict['objective'].startswith('regression') or
31+
_model_dict['objective'].startswith('poisson')):
3132
self.operator_name = 'LgbmRegressor'
3233
else:
3334
# Other objectives are not supported.
@@ -170,4 +171,4 @@ def parse_lightgbm(model, initial_types=None, target_opset=None,
170171
for variable in outputs:
171172
raw_model_container.add_output(variable)
172173

173-
return topology
174+
return topology

onnxmltools/convert/lightgbm/operator_converters/LightGbm.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
import copy
44
import numbers
55
import numpy as np
6+
import onnx
67
from collections import Counter
78
from ...common._apply_operation import (
89
apply_div, apply_reshape, apply_sub, apply_cast, apply_identity, apply_clip)
910
from ...common._registration import register_converter
1011
from ...common.tree_ensemble import get_default_tree_classifier_attribute_pairs
1112
from ....proto import onnx_proto
13+
from onnxconverter_common.container import ModelComponentContainer
1214

1315

1416
def _translate_split_criterion(criterion):
@@ -222,6 +224,7 @@ def convert_lightgbm(scope, operator, container):
222224

223225
# Create different attributes for classifier and
224226
# regressor, respectively
227+
post_transform = None
225228
if gbm_text['objective'].startswith('binary'):
226229
n_classes = 1
227230
attrs['post_transform'] = 'LOGISTIC'
@@ -232,6 +235,13 @@ def convert_lightgbm(scope, operator, container):
232235
n_classes = 1 # Regressor has only one output variable
233236
attrs['post_transform'] = 'NONE'
234237
attrs['n_targets'] = n_classes
238+
elif gbm_text['objective'].startswith('poisson'):
239+
n_classes = 1 # Regressor has only one output variable
240+
attrs['n_targets'] = n_classes
241+
# 'Exp' is not a supported post_transform value in the ONNX spec yet,
242+
# so we need to add an 'Exp' post transform node to the model
243+
attrs['post_transform'] = 'NONE'
244+
post_transform = "Exp"
235245
else:
236246
raise RuntimeError(
237247
"LightGBM objective should be cleaned already not '{}'.".format(
@@ -392,6 +402,22 @@ def convert_lightgbm(scope, operator, container):
392402
container.add_node('Identity', output_name,
393403
operator.output_full_names,
394404
name=scope.get_unique_operator_name('Identity'))
405+
if post_transform:
406+
_add_post_transform_node(container, post_transform)
407+
408+
409+
def _add_post_transform_node(container: ModelComponentContainer, op_type: str):
410+
"""
411+
Add a post transform node to a ModelComponentContainer.
412+
413+
Useful for post transform functions that are not supported by the ONNX spec yet (e.g. 'Exp').
414+
"""
415+
assert len(container.outputs) == 1, "Adding a post transform node is only possible for models with 1 output."
416+
original_output_name = container.outputs[0].name
417+
new_output_name = f"{op_type.lower()}_{original_output_name}"
418+
post_transform_node = onnx.helper.make_node(op_type, inputs=[original_output_name], outputs=[new_output_name])
419+
container.nodes.append(post_transform_node)
420+
container.outputs[0].name = new_output_name
395421

396422

397423
def modify_tree_for_rule_in_set(gbm, use_float=False):

0 commit comments

Comments
 (0)