Commit d0130f2

feat: add support for SparkML CountVectorizer conversion (#560)
Signed-off-by: Jason Wang <[email protected]>
1 parent f0fdf12 commit d0130f2

4 files changed: +177 -1 lines changed

onnxmltools/convert/sparkml/operator_converters/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -33,4 +33,4 @@
 from . import onehot_encoder
 from . import vector_assembler
 from . import k_means
-
+from . import count_vectorizer

onnxmltools/convert/sparkml/operator_converters/count_vectorizer.py

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
# SPDX-License-Identifier: Apache-2.0

from ...common._registration import register_converter, register_shape_calculator
from ...common.data_types import StringTensorType, FloatTensorType
from ...common.utils import check_input_and_output_numbers, check_input_and_output_types
from ...common._topology import Operator, Scope, ModelComponentContainer
from pyspark.ml.feature import CountVectorizerModel


def convert_count_vectorizer(scope: Scope, operator: Operator, container: ModelComponentContainer):
    op: CountVectorizerModel = operator.raw_operator
    vocab, minTF, binary = op.vocabulary, op.getOrDefault("minTF"), op.getOrDefault("binary")

    if minTF < 1.0:
        raise NotImplementedError("Converting to ONNX for CountVectorizerModel is not supported when minTF < 1.0")

    min_opset = 9
    if not binary:
        # If binary is False, then we need the ThresholdedRelu operator which is only available since opset 10.
        min_opset = 10

    if container.target_opset < min_opset:
        raise NotImplementedError(
            f"Converting to ONNX for CountVectorizerModel is not supported in opset < {min_opset}"
        )

    # Create a TfIdfVectorizer node with gram length set to 1 and mode set to "TF".
    vectorizer_output_variable_name = scope.get_unique_variable_name("vectorizer_output")
    tfIdfVectorizer_attrs = {
        "name": scope.get_unique_operator_name("tfIdfVectorizer"),
        "min_gram_length": 1,
        "max_gram_length": 1,
        "max_skip_count": 0,
        "mode": "TF",
        "ngram_counts": [0],
        "ngram_indexes": [*range(len(vocab))],
        "pool_strings": vocab,
    }

    container.add_node(
        op_type="TfIdfVectorizer",
        inputs=[operator.inputs[0].full_name],
        outputs=[vectorizer_output_variable_name],
        op_version=9,
        **tfIdfVectorizer_attrs,
    )

    # In Spark's CountVectorizerModel, the comparison with minTF is inclusive,
    # but in the ThresholdedRelu (or Binarizer) node, the comparison with `alpha` (or `threshold`) is exclusive.
    # So, we need to subtract epsilon from minTF to make the comparison with `alpha` (or `threshold`) effectively inclusive.
    epsilon = 1e-6
    if binary:
        # Create a Binarizer node with threshold set to minTF - epsilon.
        container.add_node(
            op_type="Binarizer",
            inputs=[vectorizer_output_variable_name],
            outputs=[operator.outputs[0].full_name],
            op_version=1,
            op_domain="ai.onnx.ml",
            threshold=minTF - epsilon,
        )
    else:
        # Create a ThresholdedRelu node with alpha set to minTF - epsilon.
        container.add_node(
            op_type="ThresholdedRelu",
            inputs=[vectorizer_output_variable_name],
            outputs=[operator.outputs[0].full_name],
            op_version=10,
            alpha=minTF - epsilon,
        )


register_converter("pyspark.ml.feature.CountVectorizerModel", convert_count_vectorizer)


def calculate_count_vectorizer_output_shapes(operator):
    check_input_and_output_numbers(operator, output_count_range=1)
    check_input_and_output_types(operator, good_input_types=[StringTensorType])

    N = operator.inputs[0].type.shape[0]
    C = len(operator.raw_operator.vocabulary)
    operator.outputs[0].type = FloatTensorType([N, C])


register_shape_calculator("pyspark.ml.feature.CountVectorizerModel", calculate_count_vectorizer_output_shapes)
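
For context, a minimal end-to-end sketch of how the newly registered converter would typically be exercised outside the test harness. This is not part of the commit; the Spark session setup, column names, sample tokens, and target_opset value are illustrative assumptions that mirror the tests further below.

# Hedged usage sketch (not part of this commit): convert a fitted Spark CountVectorizerModel
# to ONNX and score it with onnxruntime. Requires pyspark, onnxmltools and onnxruntime.
import numpy
import onnxruntime
from pyspark.sql import SparkSession
from pyspark.ml.feature import CountVectorizer
from onnxmltools import convert_sparkml
from onnxmltools.convert.common.data_types import StringTensorType

spark = SparkSession.builder.master("local[1]").getOrCreate()
data = spark.createDataFrame([("A B C".split(" "),), ("A B B C A".split(" "),)], ["text"])

# minTF >= 1.0 is required by the converter; binary=False needs opset >= 10.
model = CountVectorizer(inputCol="text", outputCol="result", minTF=1.0, binary=False).fit(data)

onnx_model = convert_sparkml(
    model, "Sparkml CountVectorizer",
    [("text", StringTensorType([None, None]))],
    target_opset=12)  # illustrative choice; any opset >= 10 satisfies the non-binary case

# The ONNX graph expects a rectangular 2-D string tensor of tokens per row;
# tokens not in the learned vocabulary (e.g. padding) are ignored by TfIdfVectorizer.
tokens = numpy.array([["A", "B", "C", "", ""], ["A", "B", "B", "C", "A"]])
sess = onnxruntime.InferenceSession(
    onnx_model.SerializeToString(), providers=["CPUExecutionProvider"])
print(sess.run(None, {"text": tokens})[0])  # term-frequency rows over the vocabulary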

onnxmltools/convert/sparkml/ops_input_output.py

Lines changed: 4 additions & 0 deletions
@@ -141,6 +141,10 @@ def build_io_name_map():
        lambda model: [model.getOrDefault("inputCol")],
        lambda model: [model.getOrDefault("outputCol")]
    ),
+   "pyspark.ml.feature.CountVectorizerModel": (
+       lambda model: [model.getOrDefault("inputCol")],
+       lambda model: [model.getOrDefault("outputCol")]
+   ),
    "pyspark.ml.classification.LinearSVCModel": (
        lambda model: [model.getOrDefault("featuresCol")],
        lambda model: [model.getOrDefault("predictionCol")]

Lines changed: 87 additions & 0 deletions
# SPDX-License-Identifier: Apache-2.0

import sys
import unittest
import numpy
import pandas
from pyspark.ml.feature import CountVectorizer, CountVectorizerModel
from onnx.defs import onnx_opset_version
from onnxconverter_common.onnx_ex import DEFAULT_OPSET_NUMBER
from onnxmltools import convert_sparkml
from onnxmltools.convert.common.data_types import StringTensorType
from tests.sparkml.sparkml_test_utils import save_data_models, run_onnx_model, compare_results
from tests.sparkml import SparkMlTestCase

TARGET_OPSET = min(DEFAULT_OPSET_NUMBER, onnx_opset_version())


class TestSparkmlCountVectorizer(SparkMlTestCase):

    @unittest.skipIf(sys.version_info < (3, 8),
                     reason="pickle fails on python 3.7")
    def test_count_vectorizer_default(self):
        data = self.spark.createDataFrame([
            ("A B C".split(" "), ),
            ("A B B C A".split(" "), ),
        ], ["text"])
        count_vec = CountVectorizer(inputCol="text", outputCol="result", minTF=1.0, binary=False)
        model: CountVectorizerModel = count_vec.fit(data)
        result = model.transform(data)

        model_onnx = convert_sparkml(
            model, 'Sparkml CountVectorizer', [('text', StringTensorType([None, None]))],
            target_opset=TARGET_OPSET)
        self.assertTrue(model_onnx is not None)

        data_pd = data.toPandas()
        data_np = {
            "text": data_pd.text.apply(lambda x: pandas.Series(x)).values.astype(str),
        }

        expected = {
            "prediction_result": numpy.asarray(
                result.toPandas().result.apply(lambda x: pandas.Series(x.toArray())).values.astype(numpy.float32)),
        }

        paths = save_data_models(data_np, expected, model, model_onnx, basename="SparkmlCountVectorizerModel_Default")
        onnx_model_path = paths[-1]

        output_names = ['result']
        output, output_shapes = run_onnx_model(output_names, data_np, onnx_model_path)
        actual_output = dict(zip(output_names, output))

        assert output_shapes[0] == [None, 3]
        compare_results(expected["prediction_result"], actual_output["result"], decimal=5)

    @unittest.skipIf(sys.version_info < (3, 8),
                     reason="pickle fails on python 3.7")
    def test_count_vectorizer_binary(self):
        data = self.spark.createDataFrame([
            ("A B C".split(" "), ),
            ("A B B C A".split(" "), ),
            ("B B B D".split(" "), ),
        ], ["text"])
        count_vec = CountVectorizer(inputCol="text", outputCol="result", minTF=2.0, binary=True)
        model: CountVectorizerModel = count_vec.fit(data)
        result = model.transform(data)

        model_onnx = convert_sparkml(
            model, 'Sparkml CountVectorizer', [('text', StringTensorType([None, None]))],
            target_opset=TARGET_OPSET)
        self.assertTrue(model_onnx is not None)

        data_pd = data.toPandas()
        data_np = {
            "text": data_pd.text.apply(lambda x: pandas.Series(x)).values.astype(str),
        }

        expected = {
            "prediction_result": numpy.asarray(
                result.toPandas().result.apply(lambda x: pandas.Series(x.toArray())).values.astype(numpy.float32)),
        }

        paths = save_data_models(data_np, expected, model, model_onnx, basename="SparkmlCountVectorizerModel_Binary")
        onnx_model_path = paths[-1]

        output_names = ['result']
        output, output_shapes = run_onnx_model(output_names, data_np, onnx_model_path)
        actual_output = dict(zip(output_names, output))

        assert output_shapes[0] == [None, 4]
        compare_results(expected["prediction_result"], actual_output["result"], decimal=5)


if __name__ == "__main__":
    unittest.main()
