Skip to content

Commit 06a8f6b

Browse files
bogao007HeartSaVioR
authored andcommitted
[SPARK-49744][SS][PYTHON] Implement TTL support for ListState in TransformWithStateInPandas
### What changes were proposed in this pull request?
Implement TTL support for ListState in TransformWithStateInPandas.

### Why are the changes needed?
Allow users to add TTL to specific list state.

### Does this PR introduce _any_ user-facing change?
Yes

### How was this patch tested?
Added unit tests.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #48253 from bogao007/ttl-list-state.

Authored-by: bogao007 <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
1 parent 06c70ba commit 06a8f6b

File tree

5 files changed

+97
-11
lines changed

5 files changed

+97
-11
lines changed

python/pyspark/sql/streaming/stateful_processor.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def get(self) -> Optional[Tuple]:
5656
"""
5757
return self._value_state_client.get(self._state_name)
5858

59-
def update(self, new_value: Any) -> None:
59+
def update(self, new_value: Tuple) -> None:
6060
"""
6161
Update the value of the state.
6262
"""
@@ -156,7 +156,9 @@ def getValueState(
156156
self.stateful_processor_api_client.get_value_state(state_name, schema, ttl_duration_ms)
157157
return ValueState(ValueStateClient(self.stateful_processor_api_client), state_name, schema)
158158

159-
def getListState(self, state_name: str, schema: Union[StructType, str]) -> ListState:
159+
def getListState(
160+
self, state_name: str, schema: Union[StructType, str], ttl_duration_ms: Optional[int] = None
161+
) -> ListState:
160162
"""
161163
Function to create new or return existing list state variable of given type.
162164
The user must ensure to call this function only within the `init()` method of the
@@ -169,8 +171,13 @@ def getListState(self, state_name: str, schema: Union[StructType, str]) -> ListS
169171
schema : :class:`pyspark.sql.types.DataType` or str
170172
The schema of the state variable. The value can be either a
171173
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
174+
ttl_duration_ms : int, optional
175+
Time to live duration of the state in milliseconds. State values will not be returned
176+
past ttlDuration and will be eventually removed from the state store. Any state update
177+
resets the expiration time to current processing time plus ttlDuration.
178+
If ttl is not specified the state will never expire.
172179
"""
173-
self.stateful_processor_api_client.get_list_state(state_name, schema)
180+
self.stateful_processor_api_client.get_list_state(state_name, schema, ttl_duration_ms)
174181
return ListState(ListStateClient(self.stateful_processor_api_client), state_name, schema)
175182

176183

python/pyspark/sql/streaming/stateful_processor_api_client.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,9 @@ def get_value_state(
131131
# TODO(SPARK-49233): Classify user facing errors.
132132
raise PySparkRuntimeError(f"Error initializing value state: " f"{response_message[1]}")
133133

134-
def get_list_state(self, state_name: str, schema: Union[StructType, str]) -> None:
134+
def get_list_state(
135+
self, state_name: str, schema: Union[StructType, str], ttl_duration_ms: Optional[int]
136+
) -> None:
135137
import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
136138

137139
if isinstance(schema, str):
@@ -140,6 +142,8 @@ def get_list_state(self, state_name: str, schema: Union[StructType, str]) -> Non
140142
state_call_command = stateMessage.StateCallCommand()
141143
state_call_command.stateName = state_name
142144
state_call_command.schema = schema.json()
145+
if ttl_duration_ms is not None:
146+
state_call_command.ttl.durationMs = ttl_duration_ms
143147
call = stateMessage.StatefulProcessorCall(getListState=state_call_command)
144148
message = stateMessage.StateRequest(statefulProcessorCall=call)
145149

python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,18 @@ def check_results(batch_df, _):
221221

222222
self._test_transform_with_state_in_pandas_basic(ListStateProcessor(), check_results, True)
223223

224+
# test list state with ttl has the same behavior as list state when state doesn't expire.
225+
def test_transform_with_state_in_pandas_list_state_large_ttl(self):
226+
def check_results(batch_df, _):
227+
assert set(batch_df.sort("id").collect()) == {
228+
Row(id="0", countAsString="2"),
229+
Row(id="1", countAsString="2"),
230+
}
231+
232+
self._test_transform_with_state_in_pandas_basic(
233+
ListStateLargeTTLProcessor(), check_results, True, "processingTime"
234+
)
235+
224236
# test value state with ttl has the same behavior as value state when
225237
# state doesn't expire.
226238
def test_value_state_ttl_basic(self):
@@ -248,8 +260,10 @@ def check_results(batch_df, batch_id):
248260
[
249261
Row(id="ttl-count-0", count=1),
250262
Row(id="count-0", count=1),
263+
Row(id="ttl-list-state-count-0", count=1),
251264
Row(id="ttl-count-1", count=1),
252265
Row(id="count-1", count=1),
266+
Row(id="ttl-list-state-count-1", count=1),
253267
],
254268
)
255269
elif batch_id == 1:
@@ -258,21 +272,29 @@ def check_results(batch_df, batch_id):
258272
[
259273
Row(id="ttl-count-0", count=2),
260274
Row(id="count-0", count=2),
275+
Row(id="ttl-list-state-count-0", count=3),
261276
Row(id="ttl-count-1", count=2),
262277
Row(id="count-1", count=2),
278+
Row(id="ttl-list-state-count-1", count=3),
263279
],
264280
)
265281
elif batch_id == 2:
266282
# ttl-count-0 expires and restarts from count 0.
267-
# ttl-count-1 get reset in batch 1 and keep the state
283+
# The TTL for value state ttl_count_state gets reset in batch 1 because of the
284+
# update operation and ttl-count-1 keeps the state.
285+
# ttl-list-state-count-0 expires and restarts from count 0.
286+
# The TTL for list state ttl_list_state gets reset in batch 1 because of the
287+
# put operation and ttl-list-state-count-1 keeps the state.
268288
# non-ttl state never expires
269289
assertDataFrameEqual(
270290
batch_df,
271291
[
272292
Row(id="ttl-count-0", count=1),
273293
Row(id="count-0", count=3),
294+
Row(id="ttl-list-state-count-0", count=1),
274295
Row(id="ttl-count-1", count=3),
275296
Row(id="count-1", count=3),
297+
Row(id="ttl-list-state-count-1", count=7),
276298
],
277299
)
278300
if batch_id == 0 or batch_id == 1:
@@ -362,25 +384,38 @@ def init(self, handle: StatefulProcessorHandle) -> None:
362384
state_schema = StructType([StructField("value", IntegerType(), True)])
363385
self.ttl_count_state = handle.getValueState("ttl-state", state_schema, 10000)
364386
self.count_state = handle.getValueState("state", state_schema)
387+
self.ttl_list_state = handle.getListState("ttl-list-state", state_schema, 10000)
365388

366389
def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]:
367390
count = 0
368391
ttl_count = 0
392+
ttl_list_state_count = 0
369393
id = key[0]
370394
if self.count_state.exists():
371395
count = self.count_state.get()[0]
372396
if self.ttl_count_state.exists():
373397
ttl_count = self.ttl_count_state.get()[0]
398+
if self.ttl_list_state.exists():
399+
iter = self.ttl_list_state.get()
400+
for s in iter:
401+
ttl_list_state_count += s[0]
374402
for pdf in rows:
375403
pdf_count = pdf.count().get("temperature")
376404
count += pdf_count
377405
ttl_count += pdf_count
406+
ttl_list_state_count += pdf_count
378407

379408
self.count_state.update((count,))
380409
# skip updating state for the 2nd batch so that the ttl state expires
381410
if not (ttl_count == 2 and id == "0"):
382411
self.ttl_count_state.update((ttl_count,))
383-
yield pd.DataFrame({"id": [f"ttl-count-{id}", f"count-{id}"], "count": [ttl_count, count]})
412+
self.ttl_list_state.put([(ttl_list_state_count,), (ttl_list_state_count,)])
413+
yield pd.DataFrame(
414+
{
415+
"id": [f"ttl-count-{id}", f"count-{id}", f"ttl-list-state-count-{id}"],
416+
"count": [ttl_count, count, ttl_list_state_count],
417+
}
418+
)
384419

385420
def close(self) -> None:
386421
pass
@@ -457,6 +492,15 @@ def close(self) -> None:
457492
pass
458493

459494

495+
# A stateful processor that inherits all behavior of ListStateProcessor except that it uses
496+
# ttl state with a large timeout.
497+
class ListStateLargeTTLProcessor(ListStateProcessor):
498+
def init(self, handle: StatefulProcessorHandle) -> None:
499+
state_schema = StructType([StructField("temperature", IntegerType(), True)])
500+
self.list_state1 = handle.getListState("listState1", state_schema, 30000)
501+
self.list_state2 = handle.getListState("listState2", state_schema, 30000)
502+
503+
460504
class TransformWithStateInPandasTests(TransformWithStateInPandasTestsMixin, ReusedSQLTestCase):
461505
pass
462506

sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,12 @@ class TransformWithStateInPandasStateServer(
189189
case StatefulProcessorCall.MethodCase.GETLISTSTATE =>
190190
val stateName = message.getGetListState.getStateName
191191
val schema = message.getGetListState.getSchema
192-
// TODO(SPARK-49744): Add ttl support for list state.
193-
initializeStateVariable(stateName, schema, StateVariableType.ListState, None)
192+
val ttlDurationMs = if (message.getGetListState.hasTtl) {
193+
Some(message.getGetListState.getTtl.getDurationMs)
194+
} else {
195+
None
196+
}
197+
initializeStateVariable(stateName, schema, StateVariableType.ListState, ttlDurationMs)
194198
case _ =>
195199
throw new IllegalArgumentException("Invalid method call")
196200
}
@@ -372,10 +376,14 @@ class TransformWithStateInPandasStateServer(
372376
sendResponse(1, s"Value state $stateName already exists")
373377
}
374378
case StateVariableType.ListState => if (!listStates.contains(stateName)) {
375-
// TODO(SPARK-49744): Add ttl support for list state.
379+
val state = if (ttlDurationMs.isEmpty) {
380+
statefulProcessorHandle.getListState[Row](stateName, Encoders.row(schema))
381+
} else {
382+
statefulProcessorHandle.getListState(
383+
stateName, Encoders.row(schema), TTLConfig(Duration.ofMillis(ttlDurationMs.get)))
384+
}
376385
listStates.put(stateName,
377-
ListStateInfo(statefulProcessorHandle.getListState[Row](stateName,
378-
Encoders.row(schema)), schema, expressionEncoder.createDeserializer(),
386+
ListStateInfo(state, schema, expressionEncoder.createDeserializer(),
379387
expressionEncoder.createSerializer()))
380388
sendResponse(0)
381389
} else {

sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,29 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
118118
}
119119
}
120120

121+
Seq(true, false).foreach { useTTL =>
122+
test(s"get list state, useTTL=$useTTL") {
123+
val stateCallCommandBuilder = StateCallCommand.newBuilder()
124+
.setStateName("newName")
125+
.setSchema("StructType(List(StructField(value,IntegerType,true)))")
126+
if (useTTL) {
127+
stateCallCommandBuilder.setTtl(StateMessage.TTLConfig.newBuilder().setDurationMs(1000))
128+
}
129+
val message = StatefulProcessorCall
130+
.newBuilder()
131+
.setGetListState(stateCallCommandBuilder.build())
132+
.build()
133+
stateServer.handleStatefulProcessorCall(message)
134+
if (useTTL) {
135+
verify(statefulProcessorHandle)
136+
.getListState[Row](any[String], any[Encoder[Row]], any[TTLConfig])
137+
} else {
138+
verify(statefulProcessorHandle).getListState[Row](any[String], any[Encoder[Row]])
139+
}
140+
verify(outputStream).writeInt(0)
141+
}
142+
}
143+
121144
test("value state exists") {
122145
val message = ValueStateCall.newBuilder().setStateName(stateName)
123146
.setExists(Exists.newBuilder().build()).build()

0 commit comments

Comments
 (0)