Skip to content

Commit f295db6

Browse files
authored
add node to container instead of changing output node of container (#478)
* add node to container instead of changing output node of container * remove unused imports Signed-off-by: Jan-Benedikt Jagusch <[email protected]>
1 parent 93e95db commit f295db6

File tree

1 file changed

+8
-18
lines changed
  • onnxmltools/convert/lightgbm/operator_converters

1 file changed

+8
-18
lines changed

onnxmltools/convert/lightgbm/operator_converters/LightGbm.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,12 @@
33
import copy
44
import numbers
55
import numpy as np
6-
import onnx
76
from collections import Counter
87
from ...common._apply_operation import (
98
apply_div, apply_reshape, apply_sub, apply_cast, apply_identity, apply_clip)
109
from ...common._registration import register_converter
1110
from ...common.tree_ensemble import get_default_tree_classifier_attribute_pairs
1211
from ....proto import onnx_proto
13-
from onnxconverter_common.container import ModelComponentContainer
1412

1513

1614
def _translate_split_criterion(criterion):
@@ -399,26 +397,18 @@ def convert_lightgbm(scope, operator, container):
399397

400398
apply_div(scope, [output_name, denominator_name],
401399
operator.output_full_names, container, broadcast=1)
400+
elif post_transform:
401+
container.add_node(
402+
post_transform,
403+
output_name,
404+
operator.output_full_names,
405+
name=scope.get_unique_operator_name(
406+
post_transform),
407+
)
402408
else:
403409
container.add_node('Identity', output_name,
404410
operator.output_full_names,
405411
name=scope.get_unique_operator_name('Identity'))
406-
if post_transform:
407-
_add_post_transform_node(container, post_transform)
408-
409-
410-
def _add_post_transform_node(container: ModelComponentContainer, op_type: str):
411-
"""
412-
Add a post transform node to a ModelComponentContainer.
413-
414-
Useful for post transform functions that are not supported by the ONNX spec yet (e.g. 'Exp').
415-
"""
416-
assert len(container.outputs) == 1, "Adding a post transform node is only possible for models with 1 output."
417-
original_output_name = container.outputs[0].name
418-
new_output_name = f"{op_type.lower()}_{original_output_name}"
419-
post_transform_node = onnx.helper.make_node(op_type, inputs=[original_output_name], outputs=[new_output_name])
420-
container.nodes.append(post_transform_node)
421-
container.outputs[0].name = new_output_name
422412

423413

424414
def modify_tree_for_rule_in_set(gbm, use_float=False):

0 commit comments

Comments
 (0)