18
18
19
19
import pyspark .sql .connect .proto as pb2
20
20
from pyspark .ml .linalg import (
21
- VectorUDT ,
22
- MatrixUDT ,
23
21
DenseVector ,
24
22
SparseVector ,
25
23
DenseMatrix ,
@@ -49,13 +47,23 @@ def build_float_list(value: List[float]) -> pb2.Expression.Literal:
49
47
return p
50
48
51
49
50
+ def build_proto_udt (jvm_class : str ) -> pb2 .DataType :
51
+ ret = pb2 .DataType ()
52
+ ret .udt .type = "udt"
53
+ ret .udt .jvm_class = jvm_class
54
+ return ret
55
+
56
+
57
+ proto_vector_udt = build_proto_udt ("org.apache.spark.ml.linalg.VectorUDT" )
58
+ proto_matrix_udt = build_proto_udt ("org.apache.spark.ml.linalg.MatrixUDT" )
59
+
60
+
52
61
def serialize_param (value : Any , client : "SparkConnectClient" ) -> pb2 .Expression .Literal :
53
- from pyspark .sql .connect .types import pyspark_types_to_proto_types
54
62
from pyspark .sql .connect .expressions import LiteralExpression
55
63
56
64
if isinstance (value , SparseVector ):
57
65
p = pb2 .Expression .Literal ()
58
- p .struct .struct_type .CopyFrom (pyspark_types_to_proto_types ( VectorUDT . sqlType ()) )
66
+ p .struct .struct_type .CopyFrom (proto_vector_udt )
59
67
# type = 0
60
68
p .struct .elements .append (pb2 .Expression .Literal (byte = 0 ))
61
69
# size
@@ -68,7 +76,7 @@ def serialize_param(value: Any, client: "SparkConnectClient") -> pb2.Expression.
68
76
69
77
elif isinstance (value , DenseVector ):
70
78
p = pb2 .Expression .Literal ()
71
- p .struct .struct_type .CopyFrom (pyspark_types_to_proto_types ( VectorUDT . sqlType ()) )
79
+ p .struct .struct_type .CopyFrom (proto_vector_udt )
72
80
# type = 1
73
81
p .struct .elements .append (pb2 .Expression .Literal (byte = 1 ))
74
82
# size = null
@@ -81,7 +89,7 @@ def serialize_param(value: Any, client: "SparkConnectClient") -> pb2.Expression.
81
89
82
90
elif isinstance (value , SparseMatrix ):
83
91
p = pb2 .Expression .Literal ()
84
- p .struct .struct_type .CopyFrom (pyspark_types_to_proto_types ( MatrixUDT . sqlType ()) )
92
+ p .struct .struct_type .CopyFrom (proto_matrix_udt )
85
93
# type = 0
86
94
p .struct .elements .append (pb2 .Expression .Literal (byte = 0 ))
87
95
# numRows
@@ -100,7 +108,7 @@ def serialize_param(value: Any, client: "SparkConnectClient") -> pb2.Expression.
100
108
101
109
elif isinstance (value , DenseMatrix ):
102
110
p = pb2 .Expression .Literal ()
103
- p .struct .struct_type .CopyFrom (pyspark_types_to_proto_types ( MatrixUDT . sqlType ()) )
111
+ p .struct .struct_type .CopyFrom (proto_matrix_udt )
104
112
# type = 1
105
113
p .struct .elements .append (pb2 .Expression .Literal (byte = 1 ))
106
114
# numRows
@@ -134,14 +142,13 @@ def serialize(client: "SparkConnectClient", *args: Any) -> List[Any]:
134
142
135
143
136
144
def deserialize_param (literal : pb2 .Expression .Literal ) -> Any :
137
- from pyspark .sql .connect .types import proto_schema_to_pyspark_data_type
138
145
from pyspark .sql .connect .expressions import LiteralExpression
139
146
140
147
if literal .HasField ("struct" ):
141
148
s = literal .struct
142
- schema = proto_schema_to_pyspark_data_type ( s .struct_type )
149
+ jvm_class = s .struct_type . udt . jvm_class
143
150
144
- if schema == VectorUDT . sqlType () :
151
+ if jvm_class == "org.apache.spark.ml.linalg.VectorUDT" :
145
152
assert len (s .elements ) == 4
146
153
tpe = s .elements [0 ].byte
147
154
if tpe == 0 :
@@ -155,7 +162,7 @@ def deserialize_param(literal: pb2.Expression.Literal) -> Any:
155
162
else :
156
163
raise ValueError (f"Unknown Vector type { tpe } " )
157
164
158
- elif schema == MatrixUDT . sqlType () :
165
+ elif jvm_class == "org.apache.spark.ml.linalg.MatrixUDT" :
159
166
assert len (s .elements ) == 7
160
167
tpe = s .elements [0 ].byte
161
168
if tpe == 0 :
@@ -175,7 +182,7 @@ def deserialize_param(literal: pb2.Expression.Literal) -> Any:
175
182
else :
176
183
raise ValueError (f"Unknown Matrix type { tpe } " )
177
184
else :
178
- raise ValueError (f"Unsupported parameter struct { schema } " )
185
+ raise ValueError (f"Unknown UDT { jvm_class } " )
179
186
else :
180
187
return LiteralExpression ._to_value (literal )
181
188
0 commit comments