Skip to content

Commit d8c04cf

Browse files
[SPARK-49836][SQL][SS] Fix possibly broken query when window is provided to window/session_window fn
### What changes were proposed in this pull request? This PR fixes the correctness issue about losing operators during analysis - it happens when window is provided to window()/session_window() function. The rule `TimeWindowing` and `SessionWindowing` are responsible to resolve the time window functions. When the window function has `window` as parameter (time column) (in other words, building time window from time window), the rule wraps window with WindowTime function so that the rule ResolveWindowTime will further resolve this. (And TimeWindowing/SessionWindowing will resolve this again against the result of ResolveWindowTime.) The issue is that the rule uses "return" for the above, which intends to have "early return" as the other branch is too long compared to this branch. This unfortunately does not work as intended - the intention is just to go out of current local scope (mostly end of curly brace), but it seems to break the loop of execution in "outer" side. (I haven't debugged further but it's simply clear that it doesn't work as intended.) Quoting from Scala doc: > Nonlocal returns are implemented by throwing and catching scala.runtime.NonLocalReturnException-s. It's not super clear where NonLocalReturnException is caught in the call stack; it might exit the execution for much broader scope (context) than expected. And it's finally deprecated in Scala 3.2 and likely be removed in future. https://dotty.epfl.ch/docs/reference/dropped-features/nonlocal-returns.html Interestingly it does not break every query for chained time window aggregations. Spark already has several tests with DataFrame API and they haven't failed. The reproducer in community report is using SQL statement - where each aggregation is considered as subquery. This PR fixes the rule to NOT use early return and instead have a huge if else. ### Why are the changes needed? Described in above. ### Does this PR introduce _any_ user-facing change? Yes, this fixes the possible query breakage. The impacted workloads may not be very huge as chained time window aggregations is an advanced usage, and it does not break every query for the usage. ### How was this patch tested? New UTs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48309 from HeartSaVioR/SPARK-49836. Lead-authored-by: Jungtaek Lim <[email protected]> Co-authored-by: Andrzej Zera <[email protected]> Signed-off-by: Jungtaek Lim <[email protected]>
1 parent 0c653db commit d8c04cf

File tree

3 files changed

+232
-127
lines changed

3 files changed

+232
-127
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTimeWindows.scala

Lines changed: 128 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -87,85 +87,86 @@ object TimeWindowing extends Rule[LogicalPlan] {
8787

8888
val window = windowExpressions.head
8989

90+
// time window is provided as time column of window function, replace it with WindowTime
9091
if (StructType.acceptsType(window.timeColumn.dataType)) {
91-
return p.transformExpressions {
92+
p.transformExpressions {
9293
case t: TimeWindow => t.copy(timeColumn = WindowTime(window.timeColumn))
9394
}
94-
}
95-
96-
val metadata = window.timeColumn match {
97-
case a: Attribute => a.metadata
98-
case _ => Metadata.empty
99-
}
100-
101-
val newMetadata = new MetadataBuilder()
102-
.withMetadata(metadata)
103-
.putBoolean(TimeWindow.marker, true)
104-
.build()
95+
} else {
96+
val metadata = window.timeColumn match {
97+
case a: Attribute => a.metadata
98+
case _ => Metadata.empty
99+
}
105100

106-
def getWindow(i: Int, dataType: DataType): Expression = {
107-
val timestamp = PreciseTimestampConversion(window.timeColumn, dataType, LongType)
108-
val remainder = (timestamp - window.startTime) % window.slideDuration
109-
val lastStart = timestamp - CaseWhen(Seq((LessThan(remainder, 0),
110-
remainder + window.slideDuration)), Some(remainder))
111-
val windowStart = lastStart - i * window.slideDuration
112-
val windowEnd = windowStart + window.windowDuration
101+
val newMetadata = new MetadataBuilder()
102+
.withMetadata(metadata)
103+
.putBoolean(TimeWindow.marker, true)
104+
.build()
113105

114-
// We make sure value fields are nullable since the dataType of TimeWindow defines them
115-
// as nullable.
116-
CreateNamedStruct(
117-
Literal(WINDOW_START) ::
118-
PreciseTimestampConversion(windowStart, LongType, dataType).castNullable() ::
119-
Literal(WINDOW_END) ::
120-
PreciseTimestampConversion(windowEnd, LongType, dataType).castNullable() ::
121-
Nil)
122-
}
106+
def getWindow(i: Int, dataType: DataType): Expression = {
107+
val timestamp = PreciseTimestampConversion(window.timeColumn, dataType, LongType)
108+
val remainder = (timestamp - window.startTime) % window.slideDuration
109+
val lastStart = timestamp - CaseWhen(Seq((LessThan(remainder, 0),
110+
remainder + window.slideDuration)), Some(remainder))
111+
val windowStart = lastStart - i * window.slideDuration
112+
val windowEnd = windowStart + window.windowDuration
113+
114+
// We make sure value fields are nullable since the dataType of TimeWindow defines them
115+
// as nullable.
116+
CreateNamedStruct(
117+
Literal(WINDOW_START) ::
118+
PreciseTimestampConversion(windowStart, LongType, dataType).castNullable() ::
119+
Literal(WINDOW_END) ::
120+
PreciseTimestampConversion(windowEnd, LongType, dataType).castNullable() ::
121+
Nil)
122+
}
123123

124-
val windowAttr = AttributeReference(
125-
WINDOW_COL_NAME, window.dataType, metadata = newMetadata)()
124+
val windowAttr = AttributeReference(
125+
WINDOW_COL_NAME, window.dataType, metadata = newMetadata)()
126126

127-
if (window.windowDuration == window.slideDuration) {
128-
val windowStruct = Alias(getWindow(0, window.timeColumn.dataType), WINDOW_COL_NAME)(
129-
exprId = windowAttr.exprId, explicitMetadata = Some(newMetadata))
127+
if (window.windowDuration == window.slideDuration) {
128+
val windowStruct = Alias(getWindow(0, window.timeColumn.dataType), WINDOW_COL_NAME)(
129+
exprId = windowAttr.exprId, explicitMetadata = Some(newMetadata))
130130

131-
val replacedPlan = p transformExpressions {
132-
case t: TimeWindow => windowAttr
133-
}
131+
val replacedPlan = p transformExpressions {
132+
case t: TimeWindow => windowAttr
133+
}
134134

135-
// For backwards compatibility we add a filter to filter out nulls
136-
val filterExpr = IsNotNull(window.timeColumn)
135+
// For backwards compatibility we add a filter to filter out nulls
136+
val filterExpr = IsNotNull(window.timeColumn)
137137

138-
replacedPlan.withNewChildren(
139-
Project(windowStruct +: child.output,
140-
Filter(filterExpr, child)) :: Nil)
141-
} else {
142-
val overlappingWindows =
143-
math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt
144-
val windows =
145-
Seq.tabulate(overlappingWindows)(i =>
146-
getWindow(i, window.timeColumn.dataType))
147-
148-
val projections = windows.map(_ +: child.output)
149-
150-
// When the condition windowDuration % slideDuration = 0 is fulfilled,
151-
// the estimation of the number of windows becomes exact one,
152-
// which means all produced windows are valid.
153-
val filterExpr =
154-
if (window.windowDuration % window.slideDuration == 0) {
155-
IsNotNull(window.timeColumn)
138+
replacedPlan.withNewChildren(
139+
Project(windowStruct +: child.output,
140+
Filter(filterExpr, child)) :: Nil)
156141
} else {
157-
window.timeColumn >= windowAttr.getField(WINDOW_START) &&
158-
window.timeColumn < windowAttr.getField(WINDOW_END)
142+
val overlappingWindows =
143+
math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt
144+
val windows =
145+
Seq.tabulate(overlappingWindows)(i =>
146+
getWindow(i, window.timeColumn.dataType))
147+
148+
val projections = windows.map(_ +: child.output)
149+
150+
// When the condition windowDuration % slideDuration = 0 is fulfilled,
151+
// the estimation of the number of windows becomes exact one,
152+
// which means all produced windows are valid.
153+
val filterExpr =
154+
if (window.windowDuration % window.slideDuration == 0) {
155+
IsNotNull(window.timeColumn)
156+
} else {
157+
window.timeColumn >= windowAttr.getField(WINDOW_START) &&
158+
window.timeColumn < windowAttr.getField(WINDOW_END)
159+
}
160+
161+
val substitutedPlan = Filter(filterExpr,
162+
Expand(projections, windowAttr +: child.output, child))
163+
164+
val renamedPlan = p transformExpressions {
165+
case t: TimeWindow => windowAttr
166+
}
167+
168+
renamedPlan.withNewChildren(substitutedPlan :: Nil)
159169
}
160-
161-
val substitutedPlan = Filter(filterExpr,
162-
Expand(projections, windowAttr +: child.output, child))
163-
164-
val renamedPlan = p transformExpressions {
165-
case t: TimeWindow => windowAttr
166-
}
167-
168-
renamedPlan.withNewChildren(substitutedPlan :: Nil)
169170
}
170171
} else if (numWindowExpr > 1) {
171172
throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p)
@@ -210,74 +211,74 @@ object SessionWindowing extends Rule[LogicalPlan] {
210211
val session = sessionExpressions.head
211212

212213
if (StructType.acceptsType(session.timeColumn.dataType)) {
213-
return p transformExpressions {
214+
p transformExpressions {
214215
case t: SessionWindow => t.copy(timeColumn = WindowTime(session.timeColumn))
215216
}
216-
}
217+
} else {
218+
val metadata = session.timeColumn match {
219+
case a: Attribute => a.metadata
220+
case _ => Metadata.empty
221+
}
217222

218-
val metadata = session.timeColumn match {
219-
case a: Attribute => a.metadata
220-
case _ => Metadata.empty
221-
}
223+
val newMetadata = new MetadataBuilder()
224+
.withMetadata(metadata)
225+
.putBoolean(SessionWindow.marker, true)
226+
.build()
222227

223-
val newMetadata = new MetadataBuilder()
224-
.withMetadata(metadata)
225-
.putBoolean(SessionWindow.marker, true)
226-
.build()
227-
228-
val sessionAttr = AttributeReference(
229-
SESSION_COL_NAME, session.dataType, metadata = newMetadata)()
230-
231-
val sessionStart =
232-
PreciseTimestampConversion(session.timeColumn, session.timeColumn.dataType, LongType)
233-
val gapDuration = session.gapDuration match {
234-
case expr if expr.dataType == CalendarIntervalType =>
235-
expr
236-
case expr if Cast.canCast(expr.dataType, CalendarIntervalType) =>
237-
Cast(expr, CalendarIntervalType)
238-
case other =>
239-
throw QueryCompilationErrors.sessionWindowGapDurationDataTypeError(other.dataType)
240-
}
241-
val sessionEnd = PreciseTimestampConversion(session.timeColumn + gapDuration,
242-
session.timeColumn.dataType, LongType)
243-
244-
// We make sure value fields are nullable since the dataType of SessionWindow defines them
245-
// as nullable.
246-
val literalSessionStruct = CreateNamedStruct(
247-
Literal(SESSION_START) ::
248-
PreciseTimestampConversion(sessionStart, LongType, session.timeColumn.dataType)
249-
.castNullable() ::
250-
Literal(SESSION_END) ::
251-
PreciseTimestampConversion(sessionEnd, LongType, session.timeColumn.dataType)
252-
.castNullable() ::
253-
Nil)
254-
255-
val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)(
256-
exprId = sessionAttr.exprId, explicitMetadata = Some(newMetadata))
228+
val sessionAttr = AttributeReference(
229+
SESSION_COL_NAME, session.dataType, metadata = newMetadata)()
230+
231+
val sessionStart =
232+
PreciseTimestampConversion(session.timeColumn, session.timeColumn.dataType, LongType)
233+
val gapDuration = session.gapDuration match {
234+
case expr if expr.dataType == CalendarIntervalType =>
235+
expr
236+
case expr if Cast.canCast(expr.dataType, CalendarIntervalType) =>
237+
Cast(expr, CalendarIntervalType)
238+
case other =>
239+
throw QueryCompilationErrors.sessionWindowGapDurationDataTypeError(other.dataType)
240+
}
241+
val sessionEnd = PreciseTimestampConversion(session.timeColumn + gapDuration,
242+
session.timeColumn.dataType, LongType)
257243

258-
val replacedPlan = p transformExpressions {
259-
case s: SessionWindow => sessionAttr
260-
}
244+
// We make sure value fields are nullable since the dataType of SessionWindow defines them
245+
// as nullable.
246+
val literalSessionStruct = CreateNamedStruct(
247+
Literal(SESSION_START) ::
248+
PreciseTimestampConversion(sessionStart, LongType, session.timeColumn.dataType)
249+
.castNullable() ::
250+
Literal(SESSION_END) ::
251+
PreciseTimestampConversion(sessionEnd, LongType, session.timeColumn.dataType)
252+
.castNullable() ::
253+
Nil)
261254

262-
val filterByTimeRange = if (gapDuration.foldable) {
263-
val interval = gapDuration.eval().asInstanceOf[CalendarInterval]
264-
interval == null || interval.months + interval.days + interval.microseconds <= 0
265-
} else {
266-
true
267-
}
255+
val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)(
256+
exprId = sessionAttr.exprId, explicitMetadata = Some(newMetadata))
268257

269-
// As same as tumbling window, we add a filter to filter out nulls.
270-
// And we also filter out events with negative or zero or invalid gap duration.
271-
val filterExpr = if (filterByTimeRange) {
272-
IsNotNull(session.timeColumn) &&
273-
(sessionAttr.getField(SESSION_END) > sessionAttr.getField(SESSION_START))
274-
} else {
275-
IsNotNull(session.timeColumn)
276-
}
258+
val replacedPlan = p transformExpressions {
259+
case s: SessionWindow => sessionAttr
260+
}
277261

278-
replacedPlan.withNewChildren(
279-
Filter(filterExpr,
280-
Project(sessionStruct +: child.output, child)) :: Nil)
262+
val filterByTimeRange = if (gapDuration.foldable) {
263+
val interval = gapDuration.eval().asInstanceOf[CalendarInterval]
264+
interval == null || interval.months + interval.days + interval.microseconds <= 0
265+
} else {
266+
true
267+
}
268+
269+
// As same as tumbling window, we add a filter to filter out nulls.
270+
// And we also filter out events with negative or zero or invalid gap duration.
271+
val filterExpr = if (filterByTimeRange) {
272+
IsNotNull(session.timeColumn) &&
273+
(sessionAttr.getField(SESSION_END) > sessionAttr.getField(SESSION_START))
274+
} else {
275+
IsNotNull(session.timeColumn)
276+
}
277+
278+
replacedPlan.withNewChildren(
279+
Filter(filterExpr,
280+
Project(sessionStruct +: child.output, child)) :: Nil)
281+
}
281282
} else if (numWindowExpr > 1) {
282283
throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p)
283284
} else {

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSessionWindowingSuite.scala

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,4 +547,55 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession
547547
}
548548
}
549549
}
550+
551+
test("SPARK-49836 using window fn with window as parameter should preserve parent operator") {
552+
withTempView("clicks") {
553+
val df = Seq(
554+
// small window: [00:00, 01:00), user1, 2
555+
("2024-09-30 00:00:00", "user1"), ("2024-09-30 00:00:30", "user1"),
556+
// small window: [01:00, 02:00), user2, 2
557+
("2024-09-30 00:01:00", "user2"), ("2024-09-30 00:01:30", "user2"),
558+
// small window: [03:00, 04:00), user1, 1
559+
("2024-09-30 00:03:30", "user1"),
560+
// small window: [11:00, 12:00), user1, 3
561+
("2024-09-30 00:11:00", "user1"), ("2024-09-30 00:11:30", "user1"),
562+
("2024-09-30 00:11:45", "user1")
563+
).toDF("eventTime", "userId")
564+
565+
// session window: (01:00, 09:00), user1, 3 / (02:00, 07:00), user2, 2 /
566+
// (12:00, 12:05), user1, 3
567+
568+
df.createOrReplaceTempView("clicks")
569+
570+
val aggregatedData = spark.sql(
571+
"""
572+
|SELECT
573+
| userId,
574+
| avg(cpu_large.numClicks) AS clicksPerSession
575+
|FROM
576+
|(
577+
| SELECT
578+
| session_window(small_window, '5 minutes') AS session,
579+
| userId,
580+
| sum(numClicks) AS numClicks
581+
| FROM
582+
| (
583+
| SELECT
584+
| window(eventTime, '1 minute') AS small_window,
585+
| userId,
586+
| count(*) AS numClicks
587+
| FROM clicks
588+
| GROUP BY window, userId
589+
| ) cpu_small
590+
| GROUP BY session_window, userId
591+
|) cpu_large
592+
|GROUP BY userId
593+
|""".stripMargin)
594+
595+
checkAnswer(
596+
aggregatedData,
597+
Seq(Row("user1", 3), Row("user2", 2))
598+
)
599+
}
600+
}
550601
}

0 commit comments

Comments
 (0)