Skip to content

Commit 622a3ba

Browse files
authored
fix: showing better error message when a model of wrong type is being passed. (#569)
Signed-off-by: Jason Wang <[email protected]>
1 parent 24d1d99 commit 622a3ba

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

onnxmltools/convert/sparkml/ops_names.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
Mapping and utility functions for Name to Spark ML operators
55
'''
66

7+
from pyspark.ml import Transformer, Estimator
78
from pyspark.ml.feature import Binarizer
89
from pyspark.ml.feature import BucketedRandomProjectionLSHModel
910
from pyspark.ml.feature import Bucketizer
@@ -86,6 +87,12 @@ def get_sparkml_operator_name(model_type):
8687
:param model_type: A spark-ml object (LinearRegression, StringIndexer, ...)
8788
:return: A string which stands for the type of the input model in our conversion framework
8889
'''
90+
if not issubclass(model_type, Transformer):
91+
if issubclass(model_type, Estimator):
92+
raise ValueError("Estimator must be fitted before being converted to ONNX")
93+
else:
94+
raise ValueError("Unknown model type: {}".format(model_type))
95+
8996
if model_type not in sparkml_operator_name_map:
9097
raise ValueError("No proper operator name found for '%s'" % model_type)
9198
return sparkml_operator_name_map[model_type]

0 commit comments

Comments
 (0)