Skip to content

Commit d20a738

Browse files
prabhat00155wenbingl
authored andcommitted
Prroy/converter feature union (#183)
Added feature union converter
1 parent c40305a commit d20a738

File tree

6 files changed

+60
-3
lines changed

6 files changed

+60
-3
lines changed

onnxmltools/convert/sklearn/_parse.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,29 @@ def _parse_sklearn_pipeline(scope, model, inputs):
150150
return inputs
151151

152152

153+
def _parse_sklearn_feature_union(scope, model, inputs):
    '''
    Parse a scikit-learn FeatureUnion. Every transformer of the union is parsed against the same inputs, and the
    first output of each one is fed, in transformer order, into a single 'SklearnConcat' operator.

    :param scope: Scope object
    :param model: A scikit-learn FeatureUnion object
    :param inputs: A list of Variable objects
    :return: A list of output variables produced by feature union
    '''
    # First output of each transform, in the order given by transformer_list. Dispatching through _parse_sklearn
    # (rather than calling _parse_sklearn_simple_model directly) lets a transform itself be a Pipeline or a nested
    # FeatureUnion; any other model is still handed to _parse_sklearn_simple_model by that dispatcher.
    transformed_results = [_parse_sklearn(scope, transform, inputs)[0] for _, transform in model.transformer_list]

    # Create a Concat ONNX node
    concat_operator = scope.declare_local_operator('SklearnConcat')
    concat_operator.inputs = transformed_results

    # Declare output variable of scikit-learn FeatureUnion
    union_variable = scope.declare_local_variable('union', FloatTensorType())
    concat_operator.outputs.append(union_variable)

    return concat_operator.outputs
174+
175+
153176
def _parse_sklearn(scope, model, inputs):
154177
'''
155178
This is a delegate function. It does nothing but invoke the correct parsing function according to the input
@@ -161,6 +184,8 @@ def _parse_sklearn(scope, model, inputs):
161184
'''
162185
if isinstance(model, pipeline.Pipeline):
163186
return _parse_sklearn_pipeline(scope, model, inputs)
187+
elif isinstance(model, pipeline.FeatureUnion):
188+
return _parse_sklearn_feature_union(scope, model, inputs)
164189
else:
165190
return _parse_sklearn_simple_model(scope, model, inputs)
166191

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# -------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See License.txt in the project root for
4+
# license information.
5+
# --------------------------------------------------------------------------
6+
7+
from ...common._registration import register_converter
8+
9+
10+
def convert_sklearn_concat(scope, operator, container):
    '''
    Add one ONNX 'Concat' node (axis=1) to the container. The node reads all of the operator's inputs and writes
    the operator's first output.

    :param scope: Scope object, used here to obtain a unique name for the new node
    :param operator: Operator object whose input_full_names and outputs[0] are wired to the node
    :param container: Container object that receives the node through add_node
    '''
    source_names = list(operator.input_full_names)
    target_name = operator.outputs[0].full_name
    node_name = scope.get_unique_operator_name('Concat')
    container.add_node('Concat', source_names, target_name, name=node_name, axis=1)
13+
14+
15+
# Associate the 'SklearnConcat' operator name with the converter function defined in this module.
register_converter('SklearnConcat', convert_sklearn_concat)

onnxmltools/convert/sklearn/operator_converters/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# To register converter for scikit-learn operators, import associated modules here.
88
from . import Binarizer
9+
from . import Concat
910
from . import DictVectorizer
1011
from . import DecisionTree
1112
from . import GradientBoosting
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# -------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See License.txt in the project root for
4+
# license information.
5+
# --------------------------------------------------------------------------
6+
7+
from ...common._registration import register_shape_calculator
8+
from ...common.utils import check_input_and_output_numbers
9+
10+
11+
def calculate_sklearn_concat(operator):
    '''
    Shape calculator registered for 'SklearnConcat' operators. It only validates the operator's output count and
    does not assign a type or shape to the output variable.

    :param operator: Operator object to validate
    '''
    # NOTE(review): presumably requires exactly one output and leaves the input count unchecked -- confirm
    # against check_input_and_output_numbers.
    check_input_and_output_numbers(operator, output_count_range=1)
13+
14+
15+
# Associate the 'SklearnConcat' operator name with the shape calculator defined in this module.
register_shape_calculator('SklearnConcat', calculate_sklearn_concat)

onnxmltools/convert/sklearn/shape_calculators/LabelEncoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from ...common.utils import check_input_and_output_numbers, check_input_and_output_types
1111

1212

13-
def calculate_sklearn_lebel_encoder_output_shapes(operator):
13+
def calculate_sklearn_label_encoder_output_shapes(operator):
1414
'''
1515
This function just copies the input shape to the output because label encoder only alters input features' values, not
1616
their shape.
@@ -22,4 +22,4 @@ def calculate_sklearn_lebel_encoder_output_shapes(operator):
2222
operator.outputs[0].type = Int64TensorType(copy.deepcopy(input_shape))
2323

2424

25-
register_shape_calculator('SklearnLabelEncoder', calculate_sklearn_lebel_encoder_output_shapes)
25+
register_shape_calculator('SklearnLabelEncoder', calculate_sklearn_label_encoder_output_shapes)

onnxmltools/convert/sklearn/shape_calculators/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# --------------------------------------------------------------------------
66

77
# To register shape calculators for scikit-learn operators, import associated modules here.
8-
from . import TextVectorizer
8+
from . import Concat
99
from . import DictVectorizer
1010
from . import Imputer
1111
from . import LabelEncoder
@@ -15,3 +15,4 @@
1515
from . import Scaler
1616
from . import SVM
1717
from . import SVD
18+
from . import TextVectorizer

0 commit comments

Comments
 (0)