Commit 6339c8c

viirya authored and cloud-fan committed
[SPARK-24762][SQL] Enable Option of Product encoders
## What changes were proposed in this pull request?

Spark SQL doesn't support encoding `Option[Product]` as a top-level row, because in Spark SQL an entire top-level row can't be null. However, for use cases like `Aggregator`, it is reasonable to use `Option[Product]` as the buffer and output column types. Due to the above limitation, this has not been supported so far.

This patch proposes to encode a top-level `Option[Product]` as a single struct column, which works around the restriction that an entire top-level row can't be null.

To summarize the encoding of `Product` and `Option[Product]`:

For `Product`:

1. At root level, all fields are flattened into multiple columns. The `Product` can't be null, otherwise it throws an exception.

```scala
val df = Seq((1 -> "a"), (2 -> "b")).toDF()
df.printSchema()

root
 |-- _1: integer (nullable = false)
 |-- _2: string (nullable = true)
```

2. At non-root level, `Product` is a struct type column.

```scala
val df = Seq((1, (1 -> "a")), (2, (2 -> "b")), (3, null)).toDF()
df.printSchema()

root
 |-- _1: integer (nullable = false)
 |-- _2: struct (nullable = true)
 |    |-- _1: integer (nullable = false)
 |    |-- _2: string (nullable = true)
```

For `Option[Product]`:

1. It was not supported at root level. After this change, it is a struct type column.

```scala
val df = Seq(Some(1 -> "a"), Some(2 -> "b"), None).toDF()
df.printSchema

root
 |-- value: struct (nullable = true)
 |    |-- _1: integer (nullable = false)
 |    |-- _2: string (nullable = true)
```

2. At non-root level, it is also a struct type column.

```scala
val df = Seq((1, Some(1 -> "a")), (2, Some(2 -> "b")), (3, None)).toDF()
df.printSchema

root
 |-- _1: integer (nullable = false)
 |-- _2: struct (nullable = true)
 |    |-- _1: integer (nullable = false)
 |    |-- _2: string (nullable = true)
```

3. For use cases like `Aggregator`, it was not supported either. After this change, `Option[Product]` can be used as the buffer/output column type.

```scala
val df = Seq(
  OptionBooleanIntData("bob", Some((true, 1))),
  OptionBooleanIntData("bob", Some((false, 2))),
  OptionBooleanIntData("bob", None)).toDF()
val group = df
  .groupBy("name")
  .agg(OptionBooleanIntAggregator("isGood").toColumn.alias("isGood"))
group.printSchema

root
 |-- name: string (nullable = true)
 |-- isGood: struct (nullable = true)
 |    |-- _1: boolean (nullable = false)
 |    |-- _2: integer (nullable = false)
```

The buffer and output types of `OptionBooleanIntAggregator` are both `Option[(Boolean, Int)]`.

## How was this patch tested?

Added tests.

Closes apache#21732 from viirya/SPARK-24762.

Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 9414578 commit 6339c8c

6 files changed: +163 −40 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala

Lines changed: 20 additions & 12 deletions
```diff
@@ -49,15 +49,6 @@ object ExpressionEncoder {
     val mirror = ScalaReflection.mirror
     val tpe = typeTag[T].in(mirror).tpe

-    if (ScalaReflection.optionOfProductType(tpe)) {
-      throw new UnsupportedOperationException(
-        "Cannot create encoder for Option of Product type, because Product type is represented " +
-          "as a row, and the entire row can not be null in Spark SQL like normal databases. " +
-          "You can wrap your type with Tuple1 if you do want top level null Product objects, " +
-          "e.g. instead of creating `Dataset[Option[MyClass]]`, you can do something like " +
-          "`val ds: Dataset[Tuple1[MyClass]] = Seq(Tuple1(MyClass(...)), Tuple1(null)).toDS`")
-    }
-
     val cls = mirror.runtimeClass(tpe)
     val serializer = ScalaReflection.serializerForType(tpe)
     val deserializer = ScalaReflection.deserializerForType(tpe)
@@ -198,7 +189,7 @@
   val serializer: Seq[NamedExpression] = {
     val clsName = Utils.getSimpleName(clsTag.runtimeClass)

-    if (isSerializedAsStruct) {
+    if (isSerializedAsStructForTopLevel) {
       val nullSafeSerializer = objSerializer.transformUp {
         case r: BoundReference =>
           // For input object of Product type, we can't encode it to row if it's null, as Spark SQL
@@ -213,6 +204,9 @@
     } else {
       // For other input objects like primitive, array, map, etc., we construct a struct to wrap
       // the serializer which is a column of an row.
+      //
+      // Note: Because Spark SQL doesn't allow top-level row to be null, to encode
+      // top-level Option[Product] type, we make it as a top-level struct column.
       CreateNamedStruct(Literal("value") :: objSerializer :: Nil)
     }
   }.flatten
@@ -226,7 +220,7 @@
   * `GetColumnByOrdinal` with corresponding ordinal.
   */
  val deserializer: Expression = {
-    if (isSerializedAsStruct) {
+    if (isSerializedAsStructForTopLevel) {
      // We serialized this kind of objects to root-level row. The input of general deserializer
      // is a `GetColumnByOrdinal(0)` expression to extract first column of a row. We need to
      // transform attributes accessors.
@@ -253,10 +247,24 @@
     })

   /**
-   * Returns true if the type `T` is serialized as a struct.
+   * Returns true if the type `T` is serialized as a struct by `objSerializer`.
    */
   def isSerializedAsStruct: Boolean = objSerializer.dataType.isInstanceOf[StructType]

+  /**
+   * Returns true if the type `T` is an `Option` type.
+   */
+  def isOptionType: Boolean = classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass)
+
+  /**
+   * If the type `T` is serialized as a struct, when it is encoded to a Spark SQL row, fields in
+   * the struct are naturally mapped to top-level columns in a row. In other words, the serialized
+   * struct is flattened to row. But in case of the `T` is also an `Option` type, it can't be
+   * flattened to top-level row, because in Spark SQL top-level row can't be null. This method
+   * returns true if `T` is serialized as struct and is not `Option` type.
+   */
+  def isSerializedAsStructForTopLevel: Boolean = isSerializedAsStruct && !isOptionType
+
   // serializer expressions are used to encode an object to a row, while the object is usually an
   // intermediate value produced inside an operator, not from the output of the child operator. This
   // is quite different from normal expressions, and `AttributeReference` doesn't work here
```
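For orientation, a minimal sketch (not part of the patch) of how the new predicate distinguishes `Product` from `Option[Product]` encoders; it assumes the internal catalyst API `ExpressionEncoder.apply[T]()` shown in this file is available on the classpath (e.g. in a Spark test or REPL):

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// Both encoders serialize their object to a struct...
val tupleEnc = ExpressionEncoder[(Int, String)]()
val optEnc = ExpressionEncoder[Option[(Int, String)]]()
assert(tupleEnc.isSerializedAsStruct && optEnc.isSerializedAsStruct)

// ...but only the plain Product is flattened into top-level columns;
// Option[Product] is kept as a single nullable struct column named "value".
assert(tupleEnc.isSerializedAsStructForTopLevel)
assert(!optEnc.isSerializedAsStructForTopLevel)
```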

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 5 additions & 5 deletions
```diff
@@ -1093,7 +1093,7 @@ class Dataset[T] private[sql](
     // Note that we do this before joining them, to enable the join operator to return null for one
     // side, in cases like outer-join.
     val left = {
-      val combined = if (!this.exprEnc.isSerializedAsStruct) {
+      val combined = if (!this.exprEnc.isSerializedAsStructForTopLevel) {
         assert(joined.left.output.length == 1)
         Alias(joined.left.output.head, "_1")()
       } else {
@@ -1103,7 +1103,7 @@
     }

     val right = {
-      val combined = if (!other.exprEnc.isSerializedAsStruct) {
+      val combined = if (!other.exprEnc.isSerializedAsStructForTopLevel) {
         assert(joined.right.output.length == 1)
         Alias(joined.right.output.head, "_2")()
       } else {
@@ -1116,14 +1116,14 @@
     // combine the outputs of each join side.
     val conditionExpr = joined.condition.get transformUp {
       case a: Attribute if joined.left.outputSet.contains(a) =>
-        if (!this.exprEnc.isSerializedAsStruct) {
+        if (!this.exprEnc.isSerializedAsStructForTopLevel) {
           left.output.head
         } else {
           val index = joined.left.output.indexWhere(_.exprId == a.exprId)
           GetStructField(left.output.head, index)
         }
       case a: Attribute if joined.right.outputSet.contains(a) =>
-        if (!other.exprEnc.isSerializedAsStruct) {
+        if (!other.exprEnc.isSerializedAsStructForTopLevel) {
           right.output.head
         } else {
           val index = joined.right.output.indexWhere(_.exprId == a.exprId)
@@ -1396,7 +1396,7 @@
     implicit val encoder = c1.encoder
     val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan)

-    if (!encoder.isSerializedAsStruct) {
+    if (!encoder.isSerializedAsStructForTopLevel) {
       new Dataset[U1](sparkSession, project, encoder)
     } else {
       // Flattens inner fields of U1
```
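A usage sketch of the effect on `joinWith`, mirroring the new `DatasetSuite` test added below (assumes `spark.implicits._` is in scope): because a top-level `Option[Product]` is kept as one struct column named `value`, join conditions reference its fields through that column.

```scala
val ds1 = Seq(Some((1, 2)), Some((2, 3)), None).toDS().as("a")
val ds2 = Seq(Some((1, 2)), Some((2, 3)), None).toDS().as("b")

// The single struct column is named "value", so fields are reached as value._1 / value._2.
val joined = ds1.joinWith(ds2, $"a.value._1" === $"b.value._2", "inner")
// Result: (Some((2, 3)), Some((1, 2)))
```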

sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -458,7 +458,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
     val encoders = columns.map(_.encoder)
     val namedColumns =
       columns.map(_.withInputType(vExprEnc, dataAttributes).named)
-    val keyColumn = if (!kExprEnc.isSerializedAsStruct) {
+    val keyColumn = if (!kExprEnc.isSerializedAsStructForTopLevel) {
       assert(groupingAttributes.length == 1)
       if (SQLConf.get.nameNonStructGroupingKeyAsValue) {
         groupingAttributes.head
```
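A usage sketch of the grouping-key change, mirroring the `typed agg on Option[Product] type` test added below (assumes `spark.implicits._` is in scope): an `Option[Product]` grouping key is now serialized as a single struct key column rather than flattened.

```scala
val ds = Seq(Some((1, 2)), Some((2, 3)), Some((1, 3))).toDS()

// Grouping by the whole Option[(Int, Int)] value now works.
val counts = ds.groupByKey(x => x).count().collect()
// Seq((Some((1, 2)), 1), (Some((2, 3)), 1), (Some((1, 3)), 1))
```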

sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala

Lines changed: 6 additions & 12 deletions
```diff
@@ -40,9 +40,9 @@ object TypedAggregateExpression {
     val outputEncoder = encoderFor[OUT]
     val outputType = outputEncoder.objSerializer.dataType

-    // Checks if the buffer object is simple, i.e. the buffer encoder is flat and the serializer
-    // expression is an alias of `BoundReference`, which means the buffer object doesn't need
-    // serialization.
+    // Checks if the buffer object is simple, i.e. the `BUF` type is not serialized as struct
+    // and the serializer expression is an alias of `BoundReference`, which means the buffer
+    // object doesn't need serialization.
     val isSimpleBuffer = {
       bufferSerializer.head match {
         case Alias(_: BoundReference, _) if !bufferEncoder.isSerializedAsStruct => true
@@ -76,7 +76,7 @@
         None,
         bufferSerializer,
         bufferEncoder.resolveAndBind().deserializer,
-        outputEncoder.serializer,
+        outputEncoder.objSerializer,
         outputType,
         outputEncoder.objSerializer.nullable)
     }
@@ -213,7 +213,7 @@ case class ComplexTypedAggregateExpression(
     inputSchema: Option[StructType],
     bufferSerializer: Seq[NamedExpression],
     bufferDeserializer: Expression,
-    outputSerializer: Seq[Expression],
+    outputSerializer: Expression,
     dataType: DataType,
     nullable: Boolean,
     mutableAggBufferOffset: Int = 0,
@@ -245,13 +245,7 @@ case class ComplexTypedAggregateExpression(
     aggregator.merge(buffer, input)
   }

-  private lazy val resultObjToRow = dataType match {
-    case _: StructType =>
-      UnsafeProjection.create(CreateStruct(outputSerializer))
-    case _ =>
-      assert(outputSerializer.length == 1)
-      UnsafeProjection.create(outputSerializer.head)
-  }
+  private lazy val resultObjToRow = UnsafeProjection.create(outputSerializer)

   override def eval(buffer: Any): Any = {
     val resultObj = aggregator.finish(buffer)
```

sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala

Lines changed: 63 additions & 1 deletion
```diff
@@ -22,7 +22,7 @@ import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.expressions.scalalang.typed
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType, StructType}


 object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] {
@@ -149,6 +149,7 @@ object VeryComplexResultAgg extends Aggregator[Row, String, ComplexAggData] {


 case class OptionBooleanData(name: String, isGood: Option[Boolean])
+case class OptionBooleanIntData(name: String, isGood: Option[(Boolean, Int)])

 case class OptionBooleanAggregator(colName: String)
   extends Aggregator[Row, Option[Boolean], Option[Boolean]] {
@@ -183,6 +184,43 @@ case class OptionBooleanAggregator(colName: String)
   def OptionalBoolEncoder: Encoder[Option[Boolean]] = ExpressionEncoder()
 }

+case class OptionBooleanIntAggregator(colName: String)
+  extends Aggregator[Row, Option[(Boolean, Int)], Option[(Boolean, Int)]] {
+
+  override def zero: Option[(Boolean, Int)] = None
+
+  override def reduce(buffer: Option[(Boolean, Int)], row: Row): Option[(Boolean, Int)] = {
+    val index = row.fieldIndex(colName)
+    val value = if (row.isNullAt(index)) {
+      Option.empty[(Boolean, Int)]
+    } else {
+      val nestedRow = row.getStruct(index)
+      Some((nestedRow.getBoolean(0), nestedRow.getInt(1)))
+    }
+    merge(buffer, value)
+  }
+
+  override def merge(
+      b1: Option[(Boolean, Int)],
+      b2: Option[(Boolean, Int)]): Option[(Boolean, Int)] = {
+    if ((b1.isDefined && b1.get._1) || (b2.isDefined && b2.get._1)) {
+      val newInt = b1.map(_._2).getOrElse(0) + b2.map(_._2).getOrElse(0)
+      Some((true, newInt))
+    } else if (b1.isDefined) {
+      b1
+    } else {
+      b2
+    }
+  }
+
+  override def finish(reduction: Option[(Boolean, Int)]): Option[(Boolean, Int)] = reduction
+
+  override def bufferEncoder: Encoder[Option[(Boolean, Int)]] = OptionalBoolIntEncoder
+  override def outputEncoder: Encoder[Option[(Boolean, Int)]] = OptionalBoolIntEncoder
+
+  def OptionalBoolIntEncoder: Encoder[Option[(Boolean, Int)]] = ExpressionEncoder()
+}
+
 class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
   import testImplicits._

@@ -393,4 +431,28 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
     assert(grouped.schema == df.schema)
     checkDataset(grouped.as[OptionBooleanData], OptionBooleanData("bob", Some(true)))
   }
+
+  test("SPARK-24762: Aggregator should be able to use Option of Product encoder") {
+    val df = Seq(
+      OptionBooleanIntData("bob", Some((true, 1))),
+      OptionBooleanIntData("bob", Some((false, 2))),
+      OptionBooleanIntData("bob", None)).toDF()
+
+    val group = df
+      .groupBy("name")
+      .agg(OptionBooleanIntAggregator("isGood").toColumn.alias("isGood"))
+
+    val expectedSchema = new StructType()
+      .add("name", StringType, nullable = true)
+      .add("isGood",
+        new StructType()
+          .add("_1", BooleanType, nullable = false)
+          .add("_2", IntegerType, nullable = false),
+        nullable = true)
+
+    assert(df.schema == expectedSchema)
+    assert(group.schema == expectedSchema)
+    checkAnswer(group, Row("bob", Row(true, 3)) :: Nil)
+    checkDataset(group.as[OptionBooleanIntData], OptionBooleanIntData("bob", Some((true, 3))))
+  }
 }
```

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 68 additions & 9 deletions
```diff
@@ -1312,15 +1312,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     checkDataset(dsString, arrayString)
   }

-  test("SPARK-18251: the type of Dataset can't be Option of Product type") {
-    checkDataset(Seq(Some(1), None).toDS(), Some(1), None)
-
-    val e = intercept[UnsupportedOperationException] {
-      Seq(Some(1 -> "a"), None).toDS()
-    }
-    assert(e.getMessage.contains("Cannot create encoder for Option of Product type"))
-  }
-
   test ("SPARK-17460: the sizeInBytes in Statistics shouldn't overflow to a negative number") {
     // Since the sizeInBytes in Statistics could exceed the limit of an Int, we should use BigInt
     // instead of Int for avoiding possible overflow.
@@ -1558,6 +1549,74 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       Seq(Row("Amsterdam")))
   }

+  test("SPARK-24762: Enable top-level Option of Product encoders") {
+    val data = Seq(Some((1, "a")), Some((2, "b")), None)
+    val ds = data.toDS()
+
+    checkDataset(
+      ds,
+      data: _*)
+
+    val schema = new StructType().add(
+      "value",
+      new StructType()
+        .add("_1", IntegerType, nullable = false)
+        .add("_2", StringType, nullable = true),
+      nullable = true)
+
+    assert(ds.schema == schema)
+
+    val nestedOptData = Seq(Some((Some((1, "a")), 2.0)), Some((Some((2, "b")), 3.0)))
+    val nestedDs = nestedOptData.toDS()
+
+    checkDataset(
+      nestedDs,
+      nestedOptData: _*)
+
+    val nestedSchema = StructType(Seq(
+      StructField("value", StructType(Seq(
+        StructField("_1", StructType(Seq(
+          StructField("_1", IntegerType, nullable = false),
+          StructField("_2", StringType, nullable = true)))),
+        StructField("_2", DoubleType, nullable = false)
+      )), nullable = true)
+    ))
+    assert(nestedDs.schema == nestedSchema)
+  }
+
+  test("SPARK-24762: Resolving Option[Product] field") {
+    val ds = Seq((1, ("a", 1.0)), (2, ("b", 2.0)), (3, null)).toDS()
+      .as[(Int, Option[(String, Double)])]
+    checkDataset(ds,
+      (1, Some(("a", 1.0))), (2, Some(("b", 2.0))), (3, None))
+  }
+
+  test("SPARK-24762: select Option[Product] field") {
+    val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
+    val ds1 = ds.select(expr("struct(_2, _2 + 1)").as[Option[(Int, Int)]])
+    checkDataset(ds1,
+      Some((1, 2)), Some((2, 3)), Some((3, 4)))
+
+    val ds2 = ds.select(expr("if(_2 > 2, struct(_2, _2 + 1), null)").as[Option[(Int, Int)]])
+    checkDataset(ds2,
+      None, None, Some((3, 4)))
+  }
+
+  test("SPARK-24762: joinWith on Option[Product]") {
+    val ds1 = Seq(Some((1, 2)), Some((2, 3)), None).toDS().as("a")
+    val ds2 = Seq(Some((1, 2)), Some((2, 3)), None).toDS().as("b")
+    val joined = ds1.joinWith(ds2, $"a.value._1" === $"b.value._2", "inner")
+    checkDataset(joined, (Some((2, 3)), Some((1, 2))))
+  }
+
+  test("SPARK-24762: typed agg on Option[Product] type") {
+    val ds = Seq(Some((1, 2)), Some((2, 3)), Some((1, 3))).toDS()
+    assert(ds.groupByKey(_.get._1).count().collect() === Seq((1, 2), (2, 1)))
+
+    assert(ds.groupByKey(x => x).count().collect() ===
+      Seq((Some((1, 2)), 1), (Some((2, 3)), 1), (Some((1, 3)), 1)))
+  }
+
   test("SPARK-25942: typed aggregation on primitive type") {
     val ds = Seq(1, 2, 3).toDS()
```
