
Commit b5840e1

heyihong authored and zhengruifeng committed
[SPARK-52930][CONNECT] Use DataType.Array/Map for Array/Map Literals
### What changes were proposed in this pull request?

This PR transitions the Spark Connect codebase to use `DataType.Array` and `DataType.Map` for array and map literals. While the Spark Connect server supports both the new and the old data type fields, in this change the new fields are set only in ColumnNodeToProtoConverter for the Spark Connect Scala client. All other components (e.g., ML, Python) still use the old data type fields, because literal values are used not only in requests but also in responses, which makes compatibility difficult to maintain: clients on older versions may not recognize the new fields in a response. Deprecating the old fields and transitioning to the new ones therefore requires a gradual migration.

The key changes include:

**Protocol Buffer Updates:**
- Modified `expressions.proto` to add new `data_type` fields to the `Array` and `Map` messages
- Deprecated the existing `element_type`, `key_type`, and `value_type` fields in favor of the unified `data_type` approach
- Updated the generated protocol buffer files (`expressions_pb2.py`, `expressions_pb2.pyi`) to reflect these changes

**Core Implementation Changes:**
- Enhanced `LiteralValueProtoConverter.scala` with a new internal method `toLiteralProtoBuilderInternal` that accepts `ToLiteralProtoOptions`
- Updated `LiteralExpressionProtoConverter.scala` to support inference of array and map data types
- Modified `columnNodeSupport.scala` to use the new `toLiteralProtoBuilderWithOptions` method with `useDeprecatedDataTypeFields` set to `false`

### Why are the changes needed?

The changes improve Spark's data type handling for array and map literals:

- **The nullability of array/map literals is now included in `DataType.Array`/`DataType.Map`**: nullability information is properly captured and handled within the data type structure itself.
- **Type inference works better when all type information lives in one field**: consolidating type information into a single field makes it easier to infer data types for complex data structures.

### Does this PR introduce _any_ user-facing change?

Yes. Previously, the nullability of arrays and of map values created with `typedlit` was not preserved (which I believe was a bug); it is now preserved. See the changes in ClientE2ETestSuite for details.

### How was this patch tested?

`build/sbt "connect/testOnly *LiteralExpressionProtoConverterSuite"`
`build/sbt "connect-client-jvm/testOnly *ClientE2ETestSuite -- -z SPARK-52930"`

### Was this patch authored or co-authored using generative AI tooling?

Generated-by: Cursor 1.4.5

Closes #51653 from heyihong/SPARK-52930.

Authored-by: Yihong He <heyihong.cn@gmail.com>

Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
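As a quick illustration of the user-facing change, here is a minimal Scala sketch modeled on the new ClientE2ETestSuite cases; it assumes an existing Spark Connect `SparkSession` named `spark`, and the temp-view round trip mirrors what the test does to read back the server-side schema:

```scala
import org.apache.spark.sql.functions.typedlit
import org.apache.spark.sql.types.ArrayType

// Array[Integer] can hold nulls, so the literal's ArrayType should report
// containsNull = true; before this change that nullability was lost.
val df = spark.sql("select 1").select(typedlit(Array[Integer](1)))
df.createOrReplaceTempView("nullable_array_literal")

val arrayType = spark
  .sql("select * from nullable_array_literal")
  .schema
  .fields
  .head
  .dataType
  .asInstanceOf[ArrayType]

assert(arrayType.containsNull) // typedlit(Array[Int](1)) would instead yield containsNull = false
```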
1 parent 79a0ca7 commit b5840e1

File tree

14 files changed (+1102, -570 lines)


python/pyspark/sql/connect/proto/expressions_pb2.py

Lines changed: 70 additions & 60 deletions
Large diffs are not rendered by default.

python/pyspark/sql/connect/proto/expressions_pb2.pyi

Lines changed: 57 additions & 9 deletions
@@ -474,27 +474,51 @@ class Expression(google.protobuf.message.Message):
 
             ELEMENT_TYPE_FIELD_NUMBER: builtins.int
             ELEMENTS_FIELD_NUMBER: builtins.int
+            DATA_TYPE_FIELD_NUMBER: builtins.int
             @property
-            def element_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
+            def element_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
+                """(Deprecated) The element type of the array.
+
+                This field is deprecated since Spark 4.1+ and should only be set
+                if the data_type field is not set. Use data_type field instead.
+                """
             @property
             def elements(
                 self,
             ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
                 global___Expression.Literal
-            ]: ...
+            ]:
+                """The literal values that make up the array elements."""
+            @property
+            def data_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType.Array:
+                """The type of the array.
+
+                If the element type can be inferred from the first element of the elements field,
+                then you don't need to set data_type.element_type to save space. On the other hand,
+                redundant type information is also acceptable.
+                """
             def __init__(
                 self,
                 *,
                 element_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
                 elements: collections.abc.Iterable[global___Expression.Literal] | None = ...,
+                data_type: pyspark.sql.connect.proto.types_pb2.DataType.Array | None = ...,
             ) -> None: ...
             def HasField(
-                self, field_name: typing_extensions.Literal["element_type", b"element_type"]
+                self,
+                field_name: typing_extensions.Literal[
+                    "data_type", b"data_type", "element_type", b"element_type"
+                ],
             ) -> builtins.bool: ...
             def ClearField(
                 self,
                 field_name: typing_extensions.Literal[
-                    "element_type", b"element_type", "elements", b"elements"
+                    "data_type",
+                    b"data_type",
+                    "element_type",
+                    b"element_type",
+                    "elements",
+                    b"elements",
                 ],
             ) -> None: ...
 
@@ -505,39 +529,63 @@ class Expression(google.protobuf.message.Message):
             VALUE_TYPE_FIELD_NUMBER: builtins.int
             KEYS_FIELD_NUMBER: builtins.int
             VALUES_FIELD_NUMBER: builtins.int
+            DATA_TYPE_FIELD_NUMBER: builtins.int
             @property
-            def key_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
+            def key_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
+                """(Deprecated) The key type of the map.
+
+                This field is deprecated since Spark 4.1+ and should only be set
+                if the data_type field is not set. Use data_type field instead.
+                """
             @property
-            def value_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
+            def value_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType:
+                """(Deprecated) The value type of the map.
+
+                This field is deprecated since Spark 4.1+ and should only be set
+                if the data_type field is not set. Use data_type field instead.
+                """
             @property
             def keys(
                 self,
             ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
                 global___Expression.Literal
-            ]: ...
+            ]:
+                """The literal keys that make up the map."""
             @property
             def values(
                 self,
             ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
                 global___Expression.Literal
-            ]: ...
+            ]:
+                """The literal values that make up the map."""
+            @property
+            def data_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType.Map:
+                """The type of the map.
+
+                If the key/value types can be inferred from the first element of the keys/values fields,
+                then you don't need to set data_type.key_type/data_type.value_type to save space.
+                On the other hand, redundant type information is also acceptable.
+                """
             def __init__(
                 self,
                 *,
                 key_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
                 value_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
                 keys: collections.abc.Iterable[global___Expression.Literal] | None = ...,
                 values: collections.abc.Iterable[global___Expression.Literal] | None = ...,
+                data_type: pyspark.sql.connect.proto.types_pb2.DataType.Map | None = ...,
             ) -> None: ...
             def HasField(
                 self,
                 field_name: typing_extensions.Literal[
-                    "key_type", b"key_type", "value_type", b"value_type"
+                    "data_type", b"data_type", "key_type", b"key_type", "value_type", b"value_type"
                 ],
             ) -> builtins.bool: ...
             def ClearField(
                 self,
                 field_name: typing_extensions.Literal[
+                    "data_type",
+                    b"data_type",
                     "key_type",
                     b"key_type",
                     "keys",

sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/ClientE2ETestSuite.scala

Lines changed: 29 additions & 0 deletions
@@ -1687,6 +1687,35 @@ class ClientE2ETestSuite
       assert(df.count() == 100)
     }
   }
+
+  test("SPARK-52930: the nullability of arrays should be preserved using typedlit") {
+    val arrays = Seq(
+      (typedlit(Array[Int]()), false),
+      (typedlit(Array[Int](1)), false),
+      (typedlit(Array[Integer]()), true),
+      (typedlit(Array[Integer](1)), true))
+    for ((array, containsNull) <- arrays) {
+      val df = spark.sql("select 1").select(array)
+      df.createOrReplaceTempView("test_array_nullability")
+      val schema = spark.sql("select * from test_array_nullability").schema
+      assert(schema.fields.head.dataType.asInstanceOf[ArrayType].containsNull === containsNull)
+    }
+  }
+
+  test("SPARK-52930: the nullability of map values should be preserved using typedlit") {
+    val maps = Seq(
+      (typedlit(Map[String, Int]()), false),
+      (typedlit(Map[String, Int]("a" -> 1)), false),
+      (typedlit(Map[String, Integer]()), true),
+      (typedlit(Map[String, Integer]("a" -> 1)), true))
+    for ((map, valueContainsNull) <- maps) {
+      val df = spark.sql("select 1").select(map)
+      df.createOrReplaceTempView("test_map_nullability")
+      val schema = spark.sql("select * from test_map_nullability").schema
+      assert(
+        schema.fields.head.dataType.asInstanceOf[MapType].valueContainsNull === valueContainsNull)
+    }
+  }
 }
 
 private[sql] case class ClassData(a: String, b: Int)

sql/connect/common/src/main/protobuf/spark/connect/expressions.proto

Lines changed: 36 additions & 3 deletions
@@ -215,15 +215,48 @@ message Expression {
     }
 
     message Array {
-      DataType element_type = 1;
+      // (Deprecated) The element type of the array.
+      //
+      // This field is deprecated since Spark 4.1+ and should only be set
+      // if the data_type field is not set. Use data_type field instead.
+      DataType element_type = 1 [deprecated = true];
+
+      // The literal values that make up the array elements.
       repeated Literal elements = 2;
+
+      // The type of the array.
+      //
+      // If the element type can be inferred from the first element of the elements field,
+      // then you don't need to set data_type.element_type to save space. On the other hand,
+      // redundant type information is also acceptable.
+      DataType.Array data_type = 3;
     }
 
     message Map {
-      DataType key_type = 1;
-      DataType value_type = 2;
+      // (Deprecated) The key type of the map.
+      //
+      // This field is deprecated since Spark 4.1+ and should only be set
+      // if the data_type field is not set. Use data_type field instead.
+      DataType key_type = 1 [deprecated = true];
+
+      // (Deprecated) The value type of the map.
+      //
+      // This field is deprecated since Spark 4.1+ and should only be set
+      // if the data_type field is not set. Use data_type field instead.
+      DataType value_type = 2 [deprecated = true];
+
+      // The literal keys that make up the map.
       repeated Literal keys = 3;
+
+      // The literal values that make up the map.
       repeated Literal values = 4;
+
+      // The type of the map.
+      //
+      // If the key/value types can be inferred from the first element of the keys/values fields,
+      // then you don't need to set data_type.key_type/data_type.value_type to save space.
+      // On the other hand, redundant type information is also acceptable.
+      DataType.Map data_type = 5;
     }
 
     message Struct {
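As a rough sketch of how a newer client might populate the new field, the snippet below builds an array literal with the protobuf-java builders generated into `org.apache.spark.connect.proto`; it is illustrative only, not code from this PR:

```scala
import org.apache.spark.connect.proto

// Array literal [1, 2] typed via the new data_type field (field 3); the
// deprecated element_type field (field 1) is left unset. Since the element
// type is inferable from the first element, data_type.element_type could even
// be omitted per the comment above, but redundant type info is also accepted.
val elementType = proto.DataType
  .newBuilder()
  .setInteger(proto.DataType.Integer.newBuilder())

val arrayLiteral = proto.Expression.Literal.Array
  .newBuilder()
  .addElements(proto.Expression.Literal.newBuilder().setInteger(1))
  .addElements(proto.Expression.Literal.newBuilder().setInteger(2))
  .setDataType(
    proto.DataType.Array
      .newBuilder()
      .setElementType(elementType)
      .setContainsNull(false))
  .build()
```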

sql/connect/common/src/main/scala/org/apache/spark/sql/connect/columnNodeSupport.scala

Lines changed: 7 additions & 6 deletions
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
 import org.apache.spark.sql.connect.ConnectConversions._
 import org.apache.spark.sql.connect.common.DataTypeProtoConverter
-import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProtoBuilder
+import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.{toLiteralProtoBuilderWithOptions, ToLiteralProtoOptions}
 import org.apache.spark.sql.expressions.{Aggregator, UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction}
 import org.apache.spark.sql.internal.{Alias, CaseWhenOtherwise, Cast, ColumnNode, ColumnNodeLike, InvokeInlineUserDefinedFunction, LambdaFunction, LazyExpression, Literal, SortOrder, SqlExpression, SubqueryExpression, SubqueryType, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedNamedLambdaVariable, UnresolvedRegex, UnresolvedStar, UpdateFields, Window, WindowFrame}
 
@@ -65,11 +65,12 @@ object ColumnNodeToProtoConverter extends (ColumnNode => proto.Expression) {
     val builder = proto.Expression.newBuilder()
     val n = additionalTransformation.map(_(node)).getOrElse(node)
     n match {
-      case Literal(value, None, _) =>
-        builder.setLiteral(toLiteralProtoBuilder(value))
-
-      case Literal(value, Some(dataType), _) =>
-        builder.setLiteral(toLiteralProtoBuilder(value, dataType))
+      case Literal(value, dataTypeOpt, _) =>
+        builder.setLiteral(
+          toLiteralProtoBuilderWithOptions(
+            value,
+            dataTypeOpt,
+            ToLiteralProtoOptions(useDeprecatedDataTypeFields = false)))
 
       case u @ UnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn, _) =>
         val escapedName = u.sql
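For intuition about the `useDeprecatedDataTypeFields` switch, here is a hand-written sketch of how such an option can steer an array literal between the deprecated `element_type` field and the new `data_type` field. It is not the actual `LiteralValueProtoConverter` implementation: the `LiteralProtoOptions` case class and `buildArrayLiteral` helper are hypothetical stand-ins, and the proto/`DataTypeProtoConverter` calls assume the Spark Connect common classes:

```scala
import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.types.ArrayType

// Illustrative stand-in for ToLiteralProtoOptions.
case class LiteralProtoOptions(useDeprecatedDataTypeFields: Boolean)

// Hypothetical helper showing the two wire shapes for an array literal.
def buildArrayLiteral(
    elements: Seq[proto.Expression.Literal],
    arrayType: ArrayType,
    options: LiteralProtoOptions): proto.Expression.Literal.Array = {
  val builder = proto.Expression.Literal.Array.newBuilder()
  elements.foreach(e => builder.addElements(e))
  if (options.useDeprecatedDataTypeFields) {
    // Old wire shape: only the element type is sent, so containsNull is lost.
    builder.setElementType(DataTypeProtoConverter.toConnectProtoType(arrayType.elementType))
  } else {
    // New wire shape: the full DataType.Array, including nullability.
    builder.setDataType(
      proto.DataType.Array
        .newBuilder()
        .setElementType(DataTypeProtoConverter.toConnectProtoType(arrayType.elementType))
        .setContainsNull(arrayType.containsNull))
  }
  builder.build()
}
```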
