Commit 5d6bcb9
[SPARK-50851][ML][CONNECT][FOLLOW-UP] Simplify the message by applying the UDT
### What changes were proposed in this pull request?

To simplify the message, directly apply the UDT instead of the underlying schema.

### Why are the changes needed?

When debugging a test, I noticed that the message field for the datatype is too verbose.

Before:

```
struct {
  struct_type {
    struct {
      fields {
        name: "type"
        data_type { byte { } }
      }
      fields {
        name: "size"
        data_type { integer { } }
        nullable: true
      }
      fields {
        name: "indices"
        data_type { array { element_type { integer { } } } }
        nullable: true
      }
      fields {
        name: "values"
        data_type { array { element_type { double { } } } }
        nullable: true
      }
    }
  }
  elements { ... }
}
```

After:

```
struct {
  struct_type {
    udt {
      type: "udt"
      jvm_class: "org.apache.spark.ml.linalg.VectorUDT"
    }
  }
  elements { ... }
}
```

### Does this PR introduce _any_ user-facing change?

No, internal change.

### How was this patch tested?

Existing tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #49643 from zhengruifeng/ml_simplify_msg.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent e887f05 commit 5d6bcb9
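
For reference, the compact `udt` message shown above is produced by a small helper added on the Python side of this change; a minimal sketch mirroring `build_proto_udt` from the diff below:

```python
import pyspark.sql.connect.proto as pb2


def build_proto_udt(jvm_class: str) -> pb2.DataType:
    # Encode the UDT by its JVM class name instead of expanding its
    # underlying struct schema (type/size/indices/values).
    ret = pb2.DataType()
    ret.udt.type = "udt"
    ret.udt.jvm_class = jvm_class
    return ret


# Shared messages, built once at import time and reused for every param.
proto_vector_udt = build_proto_udt("org.apache.spark.ml.linalg.VectorUDT")
proto_matrix_udt = build_proto_udt("org.apache.spark.ml.linalg.MatrixUDT")
```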

File tree

3 files changed: +37 −32 lines changed


python/pyspark/ml/connect/serialize.py

Lines changed: 19 additions & 12 deletions
```diff
@@ -18,8 +18,6 @@
 
 import pyspark.sql.connect.proto as pb2
 from pyspark.ml.linalg import (
-    VectorUDT,
-    MatrixUDT,
     DenseVector,
     SparseVector,
     DenseMatrix,
@@ -49,13 +47,23 @@ def build_float_list(value: List[float]) -> pb2.Expression.Literal:
     return p
 
 
+def build_proto_udt(jvm_class: str) -> pb2.DataType:
+    ret = pb2.DataType()
+    ret.udt.type = "udt"
+    ret.udt.jvm_class = jvm_class
+    return ret
+
+
+proto_vector_udt = build_proto_udt("org.apache.spark.ml.linalg.VectorUDT")
+proto_matrix_udt = build_proto_udt("org.apache.spark.ml.linalg.MatrixUDT")
+
+
 def serialize_param(value: Any, client: "SparkConnectClient") -> pb2.Expression.Literal:
-    from pyspark.sql.connect.types import pyspark_types_to_proto_types
     from pyspark.sql.connect.expressions import LiteralExpression
 
     if isinstance(value, SparseVector):
         p = pb2.Expression.Literal()
-        p.struct.struct_type.CopyFrom(pyspark_types_to_proto_types(VectorUDT.sqlType()))
+        p.struct.struct_type.CopyFrom(proto_vector_udt)
         # type = 0
         p.struct.elements.append(pb2.Expression.Literal(byte=0))
         # size
@@ -68,7 +76,7 @@ def serialize_param(value: Any, client: "SparkConnectClient") -> pb2.Expression.
 
     elif isinstance(value, DenseVector):
         p = pb2.Expression.Literal()
-        p.struct.struct_type.CopyFrom(pyspark_types_to_proto_types(VectorUDT.sqlType()))
+        p.struct.struct_type.CopyFrom(proto_vector_udt)
         # type = 1
         p.struct.elements.append(pb2.Expression.Literal(byte=1))
         # size = null
@@ -81,7 +89,7 @@ def serialize_param(value: Any, client: "SparkConnectClient") -> pb2.Expression.
 
     elif isinstance(value, SparseMatrix):
         p = pb2.Expression.Literal()
-        p.struct.struct_type.CopyFrom(pyspark_types_to_proto_types(MatrixUDT.sqlType()))
+        p.struct.struct_type.CopyFrom(proto_matrix_udt)
         # type = 0
         p.struct.elements.append(pb2.Expression.Literal(byte=0))
         # numRows
@@ -100,7 +108,7 @@ def serialize_param(value: Any, client: "SparkConnectClient") -> pb2.Expression.
 
     elif isinstance(value, DenseMatrix):
         p = pb2.Expression.Literal()
-        p.struct.struct_type.CopyFrom(pyspark_types_to_proto_types(MatrixUDT.sqlType()))
+        p.struct.struct_type.CopyFrom(proto_matrix_udt)
         # type = 1
         p.struct.elements.append(pb2.Expression.Literal(byte=1))
         # numRows
@@ -134,14 +142,13 @@ def serialize(client: "SparkConnectClient", *args: Any) -> List[Any]:
 
 
 def deserialize_param(literal: pb2.Expression.Literal) -> Any:
-    from pyspark.sql.connect.types import proto_schema_to_pyspark_data_type
     from pyspark.sql.connect.expressions import LiteralExpression
 
     if literal.HasField("struct"):
         s = literal.struct
-        schema = proto_schema_to_pyspark_data_type(s.struct_type)
+        jvm_class = s.struct_type.udt.jvm_class
 
-        if schema == VectorUDT.sqlType():
+        if jvm_class == "org.apache.spark.ml.linalg.VectorUDT":
             assert len(s.elements) == 4
             tpe = s.elements[0].byte
             if tpe == 0:
@@ -155,7 +162,7 @@ def deserialize_param(literal: pb2.Expression.Literal) -> Any:
             else:
                 raise ValueError(f"Unknown Vector type {tpe}")
 
-        elif schema == MatrixUDT.sqlType():
+        elif jvm_class == "org.apache.spark.ml.linalg.MatrixUDT":
             assert len(s.elements) == 7
             tpe = s.elements[0].byte
             if tpe == 0:
@@ -175,7 +182,7 @@ def deserialize_param(literal: pb2.Expression.Literal) -> Any:
             else:
                 raise ValueError(f"Unknown Matrix type {tpe}")
         else:
-            raise ValueError(f"Unsupported parameter struct {schema}")
+            raise ValueError(f"Unknown UDT {jvm_class}")
     else:
         return LiteralExpression._to_value(literal)
```
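
For context, a hedged sketch of how the updated helpers round-trip a vector param. It assumes an active `SparkConnectClient` (not constructed here), and `roundtrip` is an illustrative name, not part of the module:

```python
from pyspark.ml.connect.serialize import deserialize_param, serialize_param
from pyspark.ml.linalg import SparseVector


def roundtrip(client):
    # client: an active SparkConnectClient (assumed to exist in the session).
    v = SparseVector(3, {0: 1.0, 2: 5.0})
    literal = serialize_param(v, client)
    # After this change, struct_type carries only the compact UDT message.
    assert literal.struct.struct_type.udt.jvm_class == "org.apache.spark.ml.linalg.VectorUDT"
    return deserialize_param(literal)  # reconstructs a SparseVector equal to v
```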

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala

Lines changed: 6 additions & 8 deletions
```diff
@@ -38,7 +38,7 @@ import org.apache.spark.ml.regression._
 import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
 import org.apache.spark.ml.util.{HasTrainingSummary, Identifiable, MLWritable}
 import org.apache.spark.sql.{DataFrame, Dataset}
-import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, LiteralValueProtoConverter}
+import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
 import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
 import org.apache.spark.sql.connect.service.SessionHolder
@@ -147,13 +147,11 @@ private[ml] object MLUtils {
     val value = literal.getLiteralTypeCase match {
       case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
         val s = literal.getStruct
-        val schema = DataTypeProtoConverter.toCatalystType(s.getStructType)
-        if (schema == VectorUDT.sqlType) {
-          deserializeVector(s)
-        } else if (schema == MatrixUDT.sqlType) {
-          deserializeMatrix(s)
-        } else {
-          throw MlUnsupportedException(s"Unsupported parameter struct ${schema} for ${name}")
+        s.getStructType.getUdt.getJvmClass match {
+          case "org.apache.spark.ml.linalg.VectorUDT" => deserializeVector(s)
+          case "org.apache.spark.ml.linalg.MatrixUDT" => deserializeMatrix(s)
+          case _ =>
+            throw MlUnsupportedException(s"Unsupported struct ${literal.getStruct} for ${name}")
         }
 
       case _ =>
```

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala

Lines changed: 12 additions & 12 deletions
```diff
@@ -21,7 +21,7 @@ import org.apache.spark.connect.proto
 import org.apache.spark.ml.linalg._
 import org.apache.spark.ml.param.Params
 import org.apache.spark.sql.Dataset
-import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, LiteralValueProtoConverter, ProtoDataTypes}
+import org.apache.spark.sql.connect.common.{LiteralValueProtoConverter, ProtoDataTypes}
 import org.apache.spark.sql.connect.service.SessionHolder
 
 private[ml] object Serializer {
@@ -37,7 +37,7 @@ private[ml] object Serializer {
     data match {
       case v: SparseVector =>
         val builder = proto.Expression.Literal.Struct.newBuilder()
-        builder.setStructType(DataTypeProtoConverter.toConnectProtoType(VectorUDT.sqlType))
+        builder.setStructType(ProtoDataTypes.VectorUDT)
         // type = 0
         builder.addElements(proto.Expression.Literal.newBuilder().setByte(0))
         // size
@@ -50,7 +50,7 @@ private[ml] object Serializer {
 
       case v: DenseVector =>
         val builder = proto.Expression.Literal.Struct.newBuilder()
-        builder.setStructType(DataTypeProtoConverter.toConnectProtoType(VectorUDT.sqlType))
+        builder.setStructType(ProtoDataTypes.VectorUDT)
         // type = 1
         builder.addElements(proto.Expression.Literal.newBuilder().setByte(1))
         // size = null
@@ -65,7 +65,7 @@ private[ml] object Serializer {
 
       case m: SparseMatrix =>
         val builder = proto.Expression.Literal.Struct.newBuilder()
-        builder.setStructType(DataTypeProtoConverter.toConnectProtoType(MatrixUDT.sqlType))
+        builder.setStructType(ProtoDataTypes.MatrixUDT)
         // type = 0
         builder.addElements(proto.Expression.Literal.newBuilder().setByte(0))
         // numRows
@@ -84,7 +84,7 @@ private[ml] object Serializer {
 
      case m: DenseMatrix =>
         val builder = proto.Expression.Literal.Struct.newBuilder()
-        builder.setStructType(DataTypeProtoConverter.toConnectProtoType(MatrixUDT.sqlType))
+        builder.setStructType(ProtoDataTypes.MatrixUDT)
         // type = 1
         builder.addElements(proto.Expression.Literal.newBuilder().setByte(1))
         // numRows
@@ -146,13 +146,13 @@ private[ml] object Serializer {
     literal.getLiteralTypeCase match {
       case proto.Expression.Literal.LiteralTypeCase.STRUCT =>
         val struct = literal.getStruct
-        val schema = DataTypeProtoConverter.toCatalystType(struct.getStructType)
-        if (schema == VectorUDT.sqlType) {
-          (MLUtils.deserializeVector(struct), classOf[Vector])
-        } else if (schema == MatrixUDT.sqlType) {
-          (MLUtils.deserializeMatrix(struct), classOf[Matrix])
-        } else {
-          throw MlUnsupportedException(s"$schema not supported")
+        struct.getStructType.getUdt.getJvmClass match {
+          case "org.apache.spark.ml.linalg.VectorUDT" =>
+            (MLUtils.deserializeVector(struct), classOf[Vector])
+          case "org.apache.spark.ml.linalg.MatrixUDT" =>
+            (MLUtils.deserializeMatrix(struct), classOf[Matrix])
+          case _ =>
+            throw MlUnsupportedException(s"Unsupported struct ${literal.getStruct}")
         }
       case proto.Expression.Literal.LiteralTypeCase.INTEGER =>
         (literal.getInteger.asInstanceOf[Object], classOf[Int])
```
