Commit f5bb11c

neilramaswamy authored and HeartSaVioR committed
[SPARK-50301][SS] Make TransformWithState metrics reflect their intuitive meanings
### What changes were proposed in this pull request?

This PR makes the following changes to the metrics in TWS:

- `allUpdatesTimeMs` now captures the time it takes to process all the new data with the user's stateful processor.
- `timerProcessingTimeMs` was added to capture the time it takes to process all the user's timers.
- `allRemovalsTimeMs` now captures the time it takes to do TTL cleanup at the end of a micro-batch.
- `commitTimeMs` now captures _only_ the time it takes to commit the state, not the TTL cleanup.

With these metrics, a user can have a fairly clear picture of where time is being spent in a micro-batch that uses TWS:

![image](https://github.com/user-attachments/assets/87a0dc9c-c71b-4d55-8623-8970ad83adf6)

### Why are the changes needed?

The metrics today misrepresent what they're actually measuring.

### Does this PR introduce _any_ user-facing change?

Yes. Metrics for TWS are changing. However, since TWS is `private[sql]`, this shouldn't impact any real users.

### How was this patch tested?

We don't have any way to test these metrics in _any_ stateful operator for streaming today.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #48862 from neilramaswamy/spark-50301.

Authored-by: Neil Ramaswamy <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
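For context, a hedged sketch of how a user could observe the reworked metrics from a query's progress. This snippet is not part of the commit; it assumes a started `StreamingQuery` named `query` that uses `transformWithState`, and exact accessors may vary slightly across Spark versions (the built-in durations are fields on `StateOperatorProgress`, while `timerProcessingTimeMs` is a custom metric):

```scala
// Minimal sketch (assumption: `query` is a started StreamingQuery using transformWithState).
val op = query.lastProgress.stateOperators(0)

// Time spent running the user's stateful processor over new input rows.
val updatesMs = op.allUpdatesTimeMs
// Time spent on TTL cleanup at the end of the micro-batch.
val removalsMs = op.allRemovalsTimeMs
// Time spent committing state (no longer includes TTL cleanup).
val commitMs = op.commitTimeMs
// Time spent processing the user's timers (new custom metric added by this change).
val timersMs = op.customMetrics.get("timerProcessingTimeMs")
```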
1 parent: ea222a3 · commit: f5bb11c

File tree

2 files changed: +97 -12 lines


sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala

Lines changed: 27 additions & 12 deletions
@@ -343,11 +343,20 @@ case class TransformWithStateExec(
       CompletionIterator[InternalRow, Iterator[InternalRow]] = {
     val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
     val commitTimeMs = longMetric("commitTimeMs")
-    val timeoutLatencyMs = longMetric("allRemovalsTimeMs")
+    val timerProcessingTimeMs = longMetric("timerProcessingTimeMs")
+    // In TWS, allRemovalsTimeMs is the time taken to remove state due to TTL.
+    // It does not measure any time taken by explicit calls from the user's state processor
+    // that clear()s state variables.
+    //
+    // allRemovalsTimeMs is not granular enough to distinguish between user-caused removals and
+    // TTL-caused removals. We could leave this empty and have two custom metrics, but leaving
+    // this as always 0 will be confusing for users. We could also time every call to clear(), but
+    // that could have performance penalties. So, we choose to capture TTL-only removals.
+    val allRemovalsTimeMs = longMetric("allRemovalsTimeMs")

     val currentTimeNs = System.nanoTime
     val updatesStartTimeNs = currentTimeNs
-    var timeoutProcessingStartTimeNs = currentTimeNs
+    var timerProcessingStartTimeNs = currentTimeNs

     // If timeout is based on event time, then filter late data based on watermark
     val filteredIter = watermarkPredicateForDataForLateEvents match {
@@ -360,9 +369,13 @@ case class TransformWithStateExec(
     val newDataProcessorIter =
       CompletionIterator[InternalRow, Iterator[InternalRow]](
         processNewData(filteredIter), {
-          // Once the input is processed, mark the start time for timeout processing to measure
+          // Note: Due to the iterator lazy execution, this metric also captures the time taken
+          // by the upstream (consumer) operators in addition to the processing in this operator.
+          allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
+
+          // Once the input is processed, mark the start time for timer processing to measure
           // it separately from the overall processing time.
-          timeoutProcessingStartTimeNs = System.nanoTime
+          timerProcessingStartTimeNs = System.nanoTime
           processorHandle.setHandleState(StatefulProcessorHandleState.DATA_PROCESSED)
         })

@@ -376,9 +389,10 @@ case class TransformWithStateExec(
     private def getIterator(): Iterator[InternalRow] =
       CompletionIterator[InternalRow, Iterator[InternalRow]](
         processTimers(timeMode, processorHandle), {
-          // Note: `timeoutLatencyMs` also includes the time the parent operator took for
-          // processing output returned through iterator.
-          timeoutLatencyMs += NANOSECONDS.toMillis(System.nanoTime - timeoutProcessingStartTimeNs)
+          // Note: `timerProcessingTimeMs` also includes the time the parent operators take for
+          // processing output returned from the timers that fire.
+          timerProcessingTimeMs +=
+            NANOSECONDS.toMillis(System.nanoTime - timerProcessingStartTimeNs)
           processorHandle.setHandleState(StatefulProcessorHandleState.TIMER_PROCESSED)
         })
     }
@@ -387,13 +401,12 @@ case class TransformWithStateExec(
     // Return an iterator of all the rows generated by all the keys, such that when fully
     // consumed, all the state updates will be committed by the state store
     CompletionIterator[InternalRow, Iterator[InternalRow]](outputIterator, {
-      // Note: Due to the iterator lazy execution, this metric also captures the time taken
-      // by the upstream (consumer) operators in addition to the processing in this operator.
-      allUpdatesTimeMs += NANOSECONDS.toMillis(System.nanoTime - updatesStartTimeNs)
+      allRemovalsTimeMs += timeTakenMs {
+        processorHandle.doTtlCleanup()
+      }
+
      commitTimeMs += timeTakenMs {
        if (isStreaming) {
-          // clean up any expired user state
-          processorHandle.doTtlCleanup()
          store.commit()
        } else {
          store.abort()
@@ -419,6 +432,8 @@ case class TransformWithStateExec(
       StatefulOperatorCustomSumMetric("numMapStateVars", "Number of map state variables"),
       StatefulOperatorCustomSumMetric("numDeletedStateVars", "Number of deleted state variables"),
       // metrics around timers
+      StatefulOperatorCustomSumMetric("timerProcessingTimeMs",
+        "Number of milliseconds taken to process all timers"),
       StatefulOperatorCustomSumMetric("numRegisteredTimers", "Number of registered timers"),
       StatefulOperatorCustomSumMetric("numDeletedTimers", "Number of deleted timers"),
       StatefulOperatorCustomSumMetric("numExpiredTimers", "Number of expired timers"),

sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala

Lines changed: 70 additions & 0 deletions
@@ -401,6 +401,30 @@ class StatefulProcessorWithCompositeTypes extends RunningCountStatefulProcessor
   }
 }

+// For each record, creates a timer to fire in 10 seconds that sleeps for 1 second.
+class SleepingTimerProcessor extends StatefulProcessor[String, String, String] {
+  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {}
+
+  override def handleInputRows(
+      key: String,
+      inputRows: Iterator[String],
+      timerValues: TimerValues): Iterator[String] = {
+    inputRows.flatMap { _ =>
+      val currentTime = timerValues.getCurrentProcessingTimeInMs()
+      getHandle.registerTimer(currentTime + 10000)
+      None
+    }
+  }
+
+  override def handleExpiredTimer(
+      key: String,
+      timerValues: TimerValues,
+      expiredTimerInfo: ExpiredTimerInfo): Iterator[String] = {
+    Thread.sleep(1000)
+    Iterator.single(key)
+  }
+}
+
 /**
  * Class that adds tests for transformWithState stateful streaming operator
  */
@@ -708,6 +732,52 @@ class TransformWithStateSuite extends StateStoreMetricsTest
     )
   }

+  test("transformWithState - timer duration should be reflected in metrics") {
+    val clock = new StreamManualClock
+    val inputData = MemoryStream[String]
+    val result = inputData.toDS()
+      .groupByKey(x => x)
+      .transformWithState(
+        new SleepingTimerProcessor, TimeMode.ProcessingTime(), OutputMode.Update())
+
+    testStream(result, OutputMode.Update())(
+      StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock),
+      AddData(inputData, "a"),
+      AdvanceManualClock(1 * 1000),
+      // Side effect: timer scheduled for t = 1 + 10 = 11.
+      CheckNewAnswer(),
+      Execute { q =>
+        val metrics = q.lastProgress.stateOperators(0).customMetrics
+        assert(metrics.get("numRegisteredTimers") === 1)
+        assert(metrics.get("timerProcessingTimeMs") < 2000)
+      },
+
+      AddData(inputData, "b"),
+      AdvanceManualClock(1 * 1000),
+      // Side effect: timer scheduled for t = 2 + 10 = 12.
+      CheckNewAnswer(),
+      Execute { q =>
+        val metrics = q.lastProgress.stateOperators(0).customMetrics
+        assert(metrics.get("numRegisteredTimers") === 1)
+        assert(metrics.get("timerProcessingTimeMs") < 2000)
+      },
+
+      AddData(inputData, "c"),
+      // Time is currently 2 and we need to advance past 12. So, advance by 11 seconds.
+      AdvanceManualClock(11 * 1000),
+      CheckNewAnswer("a", "b"),
+      Execute { q =>
+        val metrics = q.lastProgress.stateOperators(0).customMetrics
+        assert(metrics.get("numRegisteredTimers") === 1)
+
+        // Both timers should have fired and taken 1 second each to process.
+        assert(metrics.get("timerProcessingTimeMs") >= 2000)
+      },
+
+      StopStream
+    )
+  }
+
   test("Use statefulProcessor without transformWithState - handle should be absent") {
     val processor = new RunningCountStatefulProcessor()
     val ex = intercept[Exception] {

0 commit comments

Comments
 (0)