|
3 | 3 | import copy |
4 | 4 | import numbers |
5 | 5 | import numpy as np |
6 | | -import onnx |
7 | 6 | from collections import Counter |
8 | 7 | from ...common._apply_operation import ( |
9 | 8 | apply_div, apply_reshape, apply_sub, apply_cast, apply_identity, apply_clip) |
10 | 9 | from ...common._registration import register_converter |
11 | 10 | from ...common.tree_ensemble import get_default_tree_classifier_attribute_pairs |
12 | 11 | from ....proto import onnx_proto |
13 | | -from onnxconverter_common.container import ModelComponentContainer |
14 | 12 |
|
15 | 13 |
|
16 | 14 | def _translate_split_criterion(criterion): |
@@ -399,26 +397,18 @@ def convert_lightgbm(scope, operator, container): |
399 | 397 |
|
400 | 398 | apply_div(scope, [output_name, denominator_name], |
401 | 399 | operator.output_full_names, container, broadcast=1) |
| 400 | + elif post_transform: |
| 401 | + container.add_node( |
| 402 | + post_transform, |
| 403 | + output_name, |
| 404 | + operator.output_full_names, |
| 405 | + name=scope.get_unique_operator_name( |
| 406 | + post_transform), |
| 407 | + ) |
402 | 408 | else: |
403 | 409 | container.add_node('Identity', output_name, |
404 | 410 | operator.output_full_names, |
405 | 411 | name=scope.get_unique_operator_name('Identity')) |
406 | | - if post_transform: |
407 | | - _add_post_transform_node(container, post_transform) |
408 | | - |
409 | | - |
410 | | -def _add_post_transform_node(container: ModelComponentContainer, op_type: str): |
411 | | - """ |
412 | | - Add a post transform node to a ModelComponentContainer. |
413 | | -
|
414 | | - Useful for post transform functions that are not supported by the ONNX spec yet (e.g. 'Exp'). |
415 | | - """ |
416 | | - assert len(container.outputs) == 1, "Adding a post transform node is only possible for models with 1 output." |
417 | | - original_output_name = container.outputs[0].name |
418 | | - new_output_name = f"{op_type.lower()}_{original_output_name}" |
419 | | - post_transform_node = onnx.helper.make_node(op_type, inputs=[original_output_name], outputs=[new_output_name]) |
420 | | - container.nodes.append(post_transform_node) |
421 | | - container.outputs[0].name = new_output_name |
422 | 412 |
|
423 | 413 |
|
424 | 414 | def modify_tree_for_rule_in_set(gbm, use_float=False): |
|
0 commit comments