Skip to content

Commit 46a2bf1

Browse files
authored
feat: support inputCols and outputCols interfaces for converting Spark StringIndexerModel (#568)
Signed-off-by: Jason Wang <[email protected]>
1 parent 622a3ba commit 46a2bf1

File tree

3 files changed

+141
-82
lines changed

3 files changed

+141
-82
lines changed
Lines changed: 42 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,61 @@
11
# SPDX-License-Identifier: Apache-2.0
22

33
import copy
4+
from typing import List
5+
6+
from pyspark import SparkContext
7+
from pyspark.ml.feature import StringIndexerModel
8+
49
from ...common._registration import register_converter, register_shape_calculator
10+
from ...common._topology import ModelComponentContainer, Operator, Scope, Variable
511
from ...common.data_types import Int64TensorType, StringTensorType
6-
from ...common.utils import check_input_and_output_numbers, check_input_and_output_types
12+
from ...common.utils import check_input_and_output_types
13+
714

15+
def convert_sparkml_string_indexer(scope: Scope, operator: Operator, container: ModelComponentContainer):
16+
op: StringIndexerModel = operator.raw_operator
17+
op_domain = "ai.onnx.ml"
18+
op_version = 2
19+
op_type = "LabelEncoder"
820

9-
def convert_sparkml_string_indexer(scope, operator, container):
10-
op = operator.raw_operator
11-
op_type = 'LabelEncoder'
12-
attrs = {
13-
'name': scope.get_unique_operator_name(op_type),
14-
'classes_strings': [str(c) for c in op.labels]
15-
}
21+
labelsArray: List[List[str]]
1622

17-
if isinstance(operator.inputs[0].type, Int64TensorType):
18-
attrs['default_int64'] = -1
19-
elif isinstance(operator.inputs[0].type, StringTensorType):
20-
attrs['default_string'] = '__unknown__'
23+
if SparkContext._active_spark_context.version.startswith("2."):
24+
labelsArray = [op.labels]
2125
else:
22-
raise RuntimeError('Unsupported input type: %s' % type(operator.inputs[0].type))
26+
labelsArray = op.labelsArray
2327

24-
container.add_node(op_type, operator.input_full_names, operator.output_full_names, op_domain='ai.onnx.ml', **attrs)
28+
for i in range(0, len(labelsArray)):
29+
encoder_attrs = {
30+
"name": scope.get_unique_operator_name("StringIndexer_" + str(i)),
31+
"keys_strings": labelsArray[i],
32+
"values_int64s": list(range(0, len(labelsArray[i]))),
33+
}
2534

35+
container.add_node(
36+
op_type,
37+
[operator.inputs[i].full_name],
38+
[operator.outputs[i].full_name],
39+
op_domain=op_domain,
40+
op_version=op_version,
41+
**encoder_attrs,
42+
)
2643

27-
register_converter('pyspark.ml.feature.StringIndexerModel', convert_sparkml_string_indexer)
2844

45+
register_converter("pyspark.ml.feature.StringIndexerModel", convert_sparkml_string_indexer)
2946

30-
def calculate_sparkml_string_indexer_output_shapes(operator):
31-
'''
47+
48+
def calculate_sparkml_string_indexer_output_shapes(operator: Operator):
49+
"""
3250
This function just copies the input shape to the output because the label encoder only alters input features' values, not
3351
their shape.
34-
'''
35-
check_input_and_output_numbers(operator, output_count_range=1)
52+
"""
3653
check_input_and_output_types(operator, good_input_types=[Int64TensorType, StringTensorType])
37-
38-
input_shape = copy.deepcopy(operator.inputs[0].type.shape)
39-
operator.outputs[0].type = Int64TensorType(input_shape)
54+
input: Variable
55+
output: Variable
56+
for (input, output) in zip(operator.inputs, operator.outputs):
57+
input_shape = copy.deepcopy(input.type.shape)
58+
output.type = Int64TensorType(input_shape)
4059

4160

42-
register_shape_calculator('pyspark.ml.feature.StringIndexerModel', calculate_sparkml_string_indexer_output_shapes)
61+
register_shape_calculator("pyspark.ml.feature.StringIndexerModel", calculate_sparkml_string_indexer_output_shapes)
Lines changed: 52 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,157 +1,157 @@
11
# SPDX-License-Identifier: Apache-2.0
22

3-
'''
3+
"""
44
Mapping and utilities for the names of Params (properties) that various Spark ML models
55
have for their input and output columns
6-
'''
6+
"""
77
from .ops_names import get_sparkml_operator_name
88

99

1010
def build_io_name_map():
11-
'''
11+
"""
1212
map of spark models to input-output tuples
1313
Each lambda gets the corresponding input or output column name from the model
14-
'''
14+
"""
1515
map = {
1616
"pyspark.ml.feature.BucketedRandomProjectionLSHModel": (
1717
lambda model: [model.getOrDefault("inputCol")],
18-
lambda model: [model.getOrDefault("outputCol")]
18+
lambda model: [model.getOrDefault("outputCol")],
1919
),
2020
"pyspark.ml.regression.AFTSurvivalRegressionModel": (
2121
lambda model: [model.getOrDefault("featuresCol")],
22-
lambda model: [model.getOrDefault("predictionCol")]
22+
lambda model: [model.getOrDefault("predictionCol")],
2323
),
2424
"pyspark.ml.feature.ElementwiseProduct": (
2525
lambda model: [model.getOrDefault("inputCol")],
26-
lambda model: [model.getOrDefault("outputCol")]
26+
lambda model: [model.getOrDefault("outputCol")],
2727
),
2828
"pyspark.ml.feature.MinHashLSHModel": (
2929
lambda model: [model.getOrDefault("inputCol")],
30-
lambda model: [model.getOrDefault("outputCol")]
30+
lambda model: [model.getOrDefault("outputCol")],
3131
),
3232
"pyspark.ml.feature.Word2VecModel": (
3333
lambda model: [model.getOrDefault("inputCol")],
34-
lambda model: [model.getOrDefault("outputCol")]
34+
lambda model: [model.getOrDefault("outputCol")],
3535
),
3636
"pyspark.ml.feature.IndexToString": (
3737
lambda model: [model.getOrDefault("inputCol")],
38-
lambda model: [model.getOrDefault("outputCol")]
38+
lambda model: [model.getOrDefault("outputCol")],
3939
),
4040
"pyspark.ml.feature.ChiSqSelectorModel": (
4141
lambda model: [model.getOrDefault("featuresCol")],
42-
lambda model: [model.getOrDefault("outputCol")]
42+
lambda model: [model.getOrDefault("outputCol")],
4343
),
4444
"pyspark.ml.classification.OneVsRestModel": (
4545
lambda model: [model.getOrDefault("featuresCol")],
46-
lambda model: [model.getOrDefault("predictionCol")]
46+
lambda model: [model.getOrDefault("predictionCol")],
4747
),
4848
"pyspark.ml.regression.GBTRegressionModel": (
4949
lambda model: [model.getOrDefault("featuresCol")],
50-
lambda model: [model.getOrDefault("predictionCol")]
50+
lambda model: [model.getOrDefault("predictionCol")],
5151
),
5252
"pyspark.ml.classification.GBTClassificationModel": (
5353
lambda model: [model.getOrDefault("featuresCol")],
54-
lambda model: [model.getOrDefault("predictionCol"), 'probability']
54+
lambda model: [model.getOrDefault("predictionCol"), "probability"],
5555
),
5656
"pyspark.ml.feature.DCT": (
5757
lambda model: [model.getOrDefault("inputCol")],
58-
lambda model: [model.getOrDefault("outputCol")]
58+
lambda model: [model.getOrDefault("outputCol")],
5959
),
6060
"pyspark.ml.feature.PCAModel": (
6161
lambda model: [model.getOrDefault("inputCol")],
62-
lambda model: [model.getOrDefault("outputCol")]
62+
lambda model: [model.getOrDefault("outputCol")],
6363
),
6464
"pyspark.ml.feature.PolynomialExpansion": (
6565
lambda model: [model.getOrDefault("inputCol")],
66-
lambda model: [model.getOrDefault("outputCol")]
66+
lambda model: [model.getOrDefault("outputCol")],
6767
),
6868
"pyspark.ml.feature.Tokenizer": (
6969
lambda model: [model.getOrDefault("inputCol")],
70-
lambda model: [model.getOrDefault("outputCol")]
70+
lambda model: [model.getOrDefault("outputCol")],
7171
),
7272
"pyspark.ml.classification.NaiveBayesModel": (
7373
lambda model: [model.getOrDefault("featuresCol")],
74-
lambda model: [model.getOrDefault("predictionCol"), model.getOrDefault("probabilityCol")]
74+
lambda model: [model.getOrDefault("predictionCol"), model.getOrDefault("probabilityCol")],
7575
),
7676
"pyspark.ml.feature.VectorSlicer": (
7777
lambda model: [model.getOrDefault("inputCol")],
78-
lambda model: [model.getOrDefault("outputCol")]
78+
lambda model: [model.getOrDefault("outputCol")],
7979
),
8080
"pyspark.ml.feature.StopWordsRemover": (
8181
lambda model: [model.getOrDefault("inputCol")],
82-
lambda model: [model.getOrDefault("outputCol")]
82+
lambda model: [model.getOrDefault("outputCol")],
8383
),
8484
"pyspark.ml.feature.NGram": (
8585
lambda model: [model.getOrDefault("inputCol")],
86-
lambda model: [model.getOrDefault("outputCol")]
86+
lambda model: [model.getOrDefault("outputCol")],
8787
),
8888
"pyspark.ml.feature.Bucketizer": (
8989
lambda model: [model.getOrDefault("inputCol")],
90-
lambda model: [model.getOrDefault("outputCol")]
90+
lambda model: [model.getOrDefault("outputCol")],
9191
),
9292
"pyspark.ml.regression.RandomForestRegressionModel": (
9393
lambda model: [model.getOrDefault("featuresCol")],
94-
lambda model: [model.getOrDefault("predictionCol")]
94+
lambda model: [model.getOrDefault("predictionCol")],
9595
),
9696
"pyspark.ml.classification.RandomForestClassificationModel": (
9797
lambda model: [model.getOrDefault("featuresCol")],
98-
lambda model: [model.getOrDefault("predictionCol"), model.getOrDefault("probabilityCol")]
98+
lambda model: [model.getOrDefault("predictionCol"), model.getOrDefault("probabilityCol")],
9999
),
100100
"pyspark.ml.regression.DecisionTreeRegressionModel": (
101101
lambda model: [model.getOrDefault("featuresCol")],
102-
lambda model: [model.getOrDefault("predictionCol")]
102+
lambda model: [model.getOrDefault("predictionCol")],
103103
),
104104
"pyspark.ml.classification.DecisionTreeClassificationModel": (
105105
lambda model: [model.getOrDefault("featuresCol")],
106-
lambda model: [model.getOrDefault("predictionCol"), model.getOrDefault("probabilityCol")]
106+
lambda model: [model.getOrDefault("predictionCol"), model.getOrDefault("probabilityCol")],
107107
),
108108
"pyspark.ml.feature.VectorIndexerModel": (
109109
lambda model: [model.getOrDefault("inputCol")],
110-
lambda model: [model.getOrDefault("outputCol")]
110+
lambda model: [model.getOrDefault("outputCol")],
111111
),
112112
"pyspark.ml.regression.GeneralizedLinearRegressionModel": (
113113
lambda model: [model.getOrDefault("featuresCol")],
114-
lambda model: [model.getOrDefault("predictionCol")]
114+
lambda model: [model.getOrDefault("predictionCol")],
115115
),
116116
"pyspark.ml.regression.LinearRegressionModel": (
117117
lambda model: [model.getOrDefault("featuresCol")],
118-
lambda model: [model.getOrDefault("predictionCol")]
118+
lambda model: [model.getOrDefault("predictionCol")],
119119
),
120120
"pyspark.ml.feature.ImputerModel": (
121121
lambda model: model.getOrDefault("inputCols"),
122-
lambda model: model.getOrDefault("outputCols")
122+
lambda model: model.getOrDefault("outputCols"),
123123
),
124124
"pyspark.ml.feature.MaxAbsScalerModel": (
125125
lambda model: [model.getOrDefault("inputCol")],
126-
lambda model: [model.getOrDefault("outputCol")]
126+
lambda model: [model.getOrDefault("outputCol")],
127127
),
128128
"pyspark.ml.feature.MinMaxScalerModel": (
129129
lambda model: [model.getOrDefault("inputCol")],
130-
lambda model: [model.getOrDefault("outputCol")]
130+
lambda model: [model.getOrDefault("outputCol")],
131131
),
132132
"pyspark.ml.feature.StandardScalerModel": (
133133
lambda model: [model.getOrDefault("inputCol")],
134-
lambda model: [model.getOrDefault("outputCol")]
134+
lambda model: [model.getOrDefault("outputCol")],
135135
),
136136
"pyspark.ml.feature.Normalizer": (
137137
lambda model: [model.getOrDefault("inputCol")],
138-
lambda model: [model.getOrDefault("outputCol")]
138+
lambda model: [model.getOrDefault("outputCol")],
139139
),
140140
"pyspark.ml.feature.Binarizer": (
141141
lambda model: [model.getOrDefault("inputCol")],
142-
lambda model: [model.getOrDefault("outputCol")]
142+
lambda model: [model.getOrDefault("outputCol")],
143143
),
144144
"pyspark.ml.feature.CountVectorizerModel": (
145145
lambda model: [model.getOrDefault("inputCol")],
146-
lambda model: [model.getOrDefault("outputCol")]
146+
lambda model: [model.getOrDefault("outputCol")],
147147
),
148148
"pyspark.ml.classification.LinearSVCModel": (
149149
lambda model: [model.getOrDefault("featuresCol")],
150-
lambda model: [model.getOrDefault("predictionCol")]
150+
lambda model: [model.getOrDefault("predictionCol")],
151151
),
152152
"pyspark.ml.classification.LogisticRegressionModel": (
153153
lambda model: [model.getOrDefault("featuresCol")],
154-
lambda model: [model.getOrDefault("predictionCol"), model.getOrDefault("probabilityCol")]
154+
lambda model: [model.getOrDefault("predictionCol"), model.getOrDefault("probabilityCol")],
155155
),
156156
"pyspark.ml.feature.OneHotEncoderModel": (
157157
lambda model: model.getOrDefault("inputCols")
@@ -162,17 +162,21 @@ def build_io_name_map():
162162
else [model.getOrDefault("outputCol")],
163163
),
164164
"pyspark.ml.feature.StringIndexerModel": (
165-
lambda model: [model.getOrDefault("inputCol")],
166-
lambda model: [model.getOrDefault("outputCol")]
165+
lambda model: model.getOrDefault("inputCols")
166+
if model.isSet("inputCols")
167+
else [model.getOrDefault("inputCol")],
168+
lambda model: model.getOrDefault("outputCols")
169+
if model.isSet("outputCols")
170+
else [model.getOrDefault("outputCol")],
167171
),
168172
"pyspark.ml.feature.VectorAssembler": (
169173
lambda model: model.getOrDefault("inputCols"),
170-
lambda model: [model.getOrDefault("outputCol")]
174+
lambda model: [model.getOrDefault("outputCol")],
171175
),
172176
"pyspark.ml.clustering.KMeansModel": (
173177
lambda model: [model.getOrDefault("featuresCol")],
174-
lambda model: [model.getOrDefault("predictionCol")]
175-
)
178+
lambda model: [model.getOrDefault("predictionCol")],
179+
),
176180
}
177181
return map
178182

@@ -181,18 +185,18 @@ def build_io_name_map():
181185

182186

183187
def get_input_names(model):
184-
'''
188+
"""
185189
Returns the name(s) of the input(s) for a SparkML operator
186190
:param model: SparkML Model
187191
:return: list of input names
188-
'''
192+
"""
189193
return io_name_map[get_sparkml_operator_name(type(model))][0](model)
190194

191195

192196
def get_output_names(model):
193-
'''
197+
"""
194198
Returns the name(s) of the output(s) for a SparkML operator
195199
:param model: SparkML Model
196200
:return: list of output names
197-
'''
201+
"""
198202
return io_name_map[get_sparkml_operator_name(type(model))][1](model)

0 commit comments

Comments
 (0)