Skip to content

Commit db71727

Browse files
authored
fix: getTensorTypeFromSpark fails for Spark 3.3.0+ (#607)
Signed-off-by: Jason Wang <[email protected]>
1 parent c29abfd commit db71727

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

onnxmltools/convert/sparkml/utils.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,16 @@ def buildInitialTypesSimple(dataframe):
1414

1515

1616
def getTensorTypeFromSpark(sparktype):
17-
if sparktype == 'StringType':
17+
if sparktype == 'StringType' or sparktype == 'StringType()':
1818
return StringTensorType([1, 1])
19-
elif sparktype == 'DecimalType' \
20-
or sparktype == 'DoubleType' \
21-
or sparktype == 'FloatType' \
22-
or sparktype == 'LongType' \
23-
or sparktype == 'IntegerType' \
24-
or sparktype == 'ShortType' \
25-
or sparktype == 'ByteType' \
26-
or sparktype == 'BooleanType':
19+
elif sparktype == 'DecimalType' or sparktype == 'DecimalType()' \
20+
or sparktype == 'DoubleType' or sparktype == 'DoubleType()' \
21+
or sparktype == 'FloatType' or sparktype == 'FloatType()' \
22+
or sparktype == 'LongType' or sparktype == 'LongType()' \
23+
or sparktype == 'IntegerType' or sparktype == 'IntegerType()' \
24+
or sparktype == 'ShortType' or sparktype == 'ShortType()' \
25+
or sparktype == 'ByteType' or sparktype == 'ByteType()' \
26+
or sparktype == 'BooleanType' or sparktype == 'BooleanType()':
2727
return FloatTensorType([1, 1])
2828
else:
2929
raise TypeError("Cannot map this type to Onnx types: " + sparktype)

0 commit comments

Comments
 (0)