This repository was archived by the owner on Jan 9, 2020. It is now read-only.

Commit 8b08fd0

hvanhovell authored and cloud-fan committed
[SPARK-21258][SQL] Fix WindowExec complex object aggregation with spilling
## What changes were proposed in this pull request?

`WindowExec` currently improperly stores complex objects (UnsafeRow, UnsafeArrayData, UnsafeMapData, UTF8String) during aggregation by keeping a reference in the buffer used by `GeneratedMutableProjections` to the actual input data. Things go wrong when the input object (or its backing bytes) is reused for other things. This can happen in window functions when they start spilling to disk: when reading back the spill files, the `UnsafeSorterSpillReader` reuses the buffer to which the `UnsafeRow` points, leading to weird corruption scenarios. Note that this only happens for aggregate functions that preserve (parts of) their input, for example `FIRST`, `LAST`, `MIN` & `MAX`.

This was not seen before because the spilling logic rarely performed actual spills and instead used an in-memory page. That page was not cleaned up during window processing, which ensured that unsafe objects pointed to their own dedicated memory locations. This changed with apache#16909; after that PR Spark spills more eagerly.

This PR provides a surgical fix because we are close to releasing Spark 2.2. The change just makes sure that there cannot be any object reuse, at the expense of a little performance. We will follow up with a more subtle solution at a later point.

## How was this patch tested?

Added a regression test to `DataFrameWindowFunctionsSuite`.

Author: Herman van Hovell <[email protected]>

Closes apache#18470 from hvanhovell/SPARK-21258.

(cherry picked from commit e2f32ee)
Signed-off-by: Wenchen Fan <[email protected]>
1 parent d16e262 commit 8b08fd0
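The failure mode described above is an aliasing problem: a retained unsafe object keeps pointing at bytes that are later rewritten by someone else. Below is a minimal sketch of that problem and of why a defensive `copy()` (the approach taken in this patch) avoids it. It is illustrative only, assuming a Spark 2.x `spark-catalyst` dependency on the classpath; the object name `UnsafeReuseSketch` and the one-column schema are made up for the example, and it uses `UnsafeProjection` (whose generated code reuses its output row across calls) as a stand-in for the spill reader reusing its read buffer.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String

// Illustrative sketch: the projection reuses its output UnsafeRow (and its
// backing bytes) across calls, so a retained reference is silently overwritten
// by the next projection -- the same kind of reuse that corrupts buffered
// window aggregates when spilling. A copy() owns its own bytes and is safe.
object UnsafeReuseSketch {
  def main(args: Array[String]): Unit = {
    val schema  = StructType(Seq(StructField("s", StringType)))
    val project = UnsafeProjection.create(schema)

    val retained = project(InternalRow(UTF8String.fromString("first")))
    val copied   = retained.copy()                          // what the patch does

    project(InternalRow(UTF8String.fromString("second")))   // reuses the buffer

    println(retained.getUTF8String(0))  // "second" -- the reference was clobbered
    println(copied.getUTF8String(0))    // "first"  -- the copy keeps its value
  }
}
```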

File tree

2 files changed: +51 additions, -3 deletions


sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala

Lines changed: 5 additions & 2 deletions
@@ -145,10 +145,13 @@ private[window] final class AggregateProcessor(
 
   /** Update the buffer. */
   def update(input: InternalRow): Unit = {
-    updateProjection(join(buffer, input))
+    // TODO(hvanhovell) this sacrifices performance for correctness. We should make sure that
+    // MutableProjection makes copies of the complex input objects it buffer.
+    val copy = input.copy()
+    updateProjection(join(buffer, copy))
     var i = 0
     while (i < numImperatives) {
-      imperatives(i).update(buffer, input)
+      imperatives(i).update(buffer, copy)
      i += 1
    }
  }

sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala

Lines changed: 46 additions & 1 deletion
@@ -19,8 +19,9 @@ package org.apache.spark.sql
 
 import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window}
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{DataType, LongType, StructType}
+import org.apache.spark.sql.types._
 
 /**
  * Window function testing for DataFrame API.
@@ -423,4 +424,48 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext
       df.select(selectList: _*).where($"value" < 2),
       Seq(Row(3, "1", null, 3.0, 4.0, 3.0), Row(5, "1", false, 4.0, 5.0, 5.0)))
   }
+
+  test("SPARK-21258: complex object in combination with spilling") {
+    // Make sure we trigger the spilling path.
+    withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "17") {
+      val sampleSchema = new StructType().
+        add("f0", StringType).
+        add("f1", LongType).
+        add("f2", ArrayType(new StructType().
+          add("f20", StringType))).
+        add("f3", ArrayType(new StructType().
+          add("f30", StringType)))
+
+      val w0 = Window.partitionBy("f0").orderBy("f1")
+      val w1 = w0.rowsBetween(Long.MinValue, Long.MaxValue)
+
+      val c0 = first(struct($"f2", $"f3")).over(w0) as "c0"
+      val c1 = last(struct($"f2", $"f3")).over(w1) as "c1"
+
+      val input =
+        """{"f1":1497820153720,"f2":[{"f20":"x","f21":0}],"f3":[{"f30":"x","f31":0}]}
+          |{"f1":1497802179638}
+          |{"f1":1497802189347}
+          |{"f1":1497802189593}
+          |{"f1":1497802189597}
+          |{"f1":1497802189599}
+          |{"f1":1497802192103}
+          |{"f1":1497802193414}
+          |{"f1":1497802193577}
+          |{"f1":1497802193709}
+          |{"f1":1497802202883}
+          |{"f1":1497802203006}
+          |{"f1":1497802203743}
+          |{"f1":1497802203834}
+          |{"f1":1497802203887}
+          |{"f1":1497802203893}
+          |{"f1":1497802203976}
+          |{"f1":1497820168098}
+          |""".stripMargin.split("\n").toSeq
+
+      import testImplicits._
+
+      spark.read.schema(sampleSchema).json(input.toDS()).select(c0, c1).foreach { _ => () }
+    }
+  }
 }
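The regression test forces the spill path by lowering the window spill threshold that apache#16909 introduced. Outside the test suite the same knob can be set on a session. The sketch below is a hedged, standalone approximation of the scenario the test covers, not part of the patch: the config key is the one I understand backs `SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD`, the value `17` simply forces spilling on tiny inputs as the test does, and the data and column names are invented for illustration.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{first, struct}

// Sketch: run a FIRST-over-struct window aggregate with a very low spill
// threshold, the same shape of query the regression test exercises.
val spark = SparkSession.builder().master("local[1]").appName("spark-21258-sketch").getOrCreate()
spark.conf.set("spark.sql.windowExec.buffer.spill.threshold", "17")

import spark.implicits._
val df = (1L to 32L).map(i => ("a", i, Seq(s"v$i"))).toDF("f0", "f1", "f2")
val w  = Window.partitionBy("f0").orderBy("f1")

// With the fix, the buffered struct survives spilling; before it, the values
// read back from the spill files could come back corrupted.
df.select(first(struct($"f2")).over(w) as "c0").collect()
```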
