
Commit 746db3d

[SPARK-53743][SS] Remove the usage of fetchWithArrow in ListState.put/appendList
### What changes were proposed in this pull request?

This PR proposes to remove the usage of fetchWithArrow in ListState.put/appendList. (We don't remove fetchWithArrow and its proto field, since removing them would not reduce complexity noticeably, and removing something from the proto may have unexpected side effects on compatibility.)

### Why are the changes needed?

We have observed a case where the Arrow path for sending the list has an issue while the normal path does not. The case is a `None` value in an IntegerType() column of a list state element - the column is set to nullable=True, so this should be allowed, but the following error is raised during the conversion:

```
  File "/databricks/spark/python/pyspark/sql/streaming/stateful_processor.py", line 147, in put
    self._listStateClient.put(self._stateName, newState)
  File "/databricks/spark/python/pyspark/sql/streaming/list_state_client.py", line 195, in put
    self._stateful_processor_api_client._send_arrow_state(self.schema, values)
  File "/spark/python/pyspark/sql/streaming/stateful_processor_api_client.py", line 604, in _send_arrow_state
    pandas_df = convert_pandas_using_numpy_type(
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/spark/python/pyspark/sql/pandas/types.py", line 1599, in convert_pandas_using_numpy_type
    df[field.name] = df[field.name].astype(np_type)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/python/lib/python3.12/site-packages/pandas/core/generic.py", line 6643, in astype
    new_data = self._mgr.astype(dtype=dtype, copy=copy, errors=errors)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/python/lib/python3.12/site-packages/pandas/core/internals/managers.py", line 430, in astype
    return self.apply(
           ^^^^^^^^^^^
  File "/python/lib/python3.12/site-packages/pandas/core/internals/managers.py", line 363, in apply
    applied = getattr(b, f)(**kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^
  File "/python/lib/python3.12/site-packages/pandas/core/internals/blocks.py", line 758, in astype
    new_values = astype_array_safe(values, dtype, copy=copy, errors=errors)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/python/lib/python3.12/site-packages/pandas/core/dtypes/astype.py", line 237, in astype_array_safe
    new_values = astype_array(values, dtype, copy=copy)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/python/lib/python3.12/site-packages/pandas/core/dtypes/astype.py", line 182, in astype_array
    values = _astype_nansafe(values, dtype, copy=copy)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/python/lib/python3.12/site-packages/pandas/core/dtypes/astype.py", line 133, in _astype_nansafe
    return arr.astype(dtype, copy=True)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: int() argument must be a string, a bytes-like object or a real number, not 'NoneType'
```

Since we don't know how useful the Arrow-based list sending is, it is better not to try to fix the issue in the Arrow code path at this point and simply stop using it.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Updated the existing test to cover the observed case.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #52479 from HeartSaVioR/SPARK-53743.

Authored-by: Jungtaek Lim <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
1 parent 1692b55 commit 746db3d
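
For context (not part of the commit), a minimal sketch of the conversion failure the Arrow path runs into, assuming only pandas and NumPy; the column name and object-typed Series are illustrative stand-ins for a nullable IntegerType() column whose element is None:

```python
import numpy as np
import pandas as pd

# An object-typed column containing None, standing in for a nullable IntegerType()
# column with a None element before the numpy-type conversion.
df = pd.DataFrame({"value": pd.Series([0, 1, None], dtype=object)})

# Casting the object column to a NumPy integer dtype raises the error from the
# traceback above: TypeError: int() argument must be ... not 'NoneType'.
df["value"] = df["value"].astype(np.int32)
```

The non-Arrow path serializes each row individually into the proto message and never performs this pandas/NumPy cast, which matches the observation that only the Arrow path fails.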

File tree

4 files changed: +55 -67 lines changed

python/pyspark/sql/streaming/list_state_client.py

Lines changed: 10 additions & 38 deletions
@@ -130,24 +130,12 @@ def append_value(self, state_name: str, value: Tuple) -> None:
     def append_list(self, state_name: str, values: List[Tuple]) -> None:
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
-        send_data_via_arrow = False
-
-        # To workaround mypy type assignment check.
-        values_as_bytes: Any = []
-        if len(values) == 100:
-            # TODO(SPARK-51907): Let's update this to be either flexible or more reasonable default
-            # value backed by various benchmarks.
-            # Arrow codepath
-            send_data_via_arrow = True
-        else:
-            values_as_bytes = map(
-                lambda x: self._stateful_processor_api_client._serialize_to_bytes(self.schema, x),
-                values,
-            )
-
-        append_list_call = stateMessage.AppendList(
-            value=values_as_bytes, fetchWithArrow=send_data_via_arrow
+        values_as_bytes = map(
+            lambda x: self._stateful_processor_api_client._serialize_to_bytes(self.schema, x),
+            values,
         )
+
+        append_list_call = stateMessage.AppendList(value=values_as_bytes, fetchWithArrow=False)
         list_state_call = stateMessage.ListStateCall(
             stateName=state_name, appendList=append_list_call
         )
@@ -156,9 +144,6 @@ def append_list(self, state_name: str, values: List[Tuple]) -> None:
 
         self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
 
-        if send_data_via_arrow:
-            self._stateful_processor_api_client._send_arrow_state(self.schema, values)
-
         response_message = self._stateful_processor_api_client._receive_proto_message()
         status = response_message[0]
         if status != 0:
@@ -168,32 +153,19 @@ def append_list(self, state_name: str, values: List[Tuple]) -> None:
     def put(self, state_name: str, values: List[Tuple]) -> None:
         import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
 
-        send_data_via_arrow = False
-        # To workaround mypy type assignment check.
-        values_as_bytes: Any = []
-        if len(values) == 100:
-            # TODO(SPARK-51907): Let's update this to be either flexible or more reasonable default
-            # value backed by various benchmarks.
-            send_data_via_arrow = True
-        else:
-            values_as_bytes = map(
-                lambda x: self._stateful_processor_api_client._serialize_to_bytes(self.schema, x),
-                values,
-            )
-
-        put_call = stateMessage.ListStatePut(
-            value=values_as_bytes, fetchWithArrow=send_data_via_arrow
+        values_as_bytes = map(
+            lambda x: self._stateful_processor_api_client._serialize_to_bytes(self.schema, x),
+            values,
         )
 
+        put_call = stateMessage.ListStatePut(value=values_as_bytes, fetchWithArrow=False)
+
         list_state_call = stateMessage.ListStateCall(stateName=state_name, listStatePut=put_call)
         state_variable_request = stateMessage.StateVariableRequest(listStateCall=list_state_call)
         message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
 
         self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
 
-        if send_data_via_arrow:
-            self._stateful_processor_api_client._send_arrow_state(self.schema, values)
-
        response_message = self._stateful_processor_api_client._receive_proto_message()
         status = response_message[0]
         if status != 0:
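
To make the simplified flow concrete, here is a hedged sketch (not code from this commit) of the request both put and append_list now build; it uses only the message and field names visible in the diff above, and the standalone helper function is hypothetical:

```python
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage


def build_list_state_put_request(client, state_name, schema, values):
    # Every row is serialized inline into the proto message; since fetchWithArrow is
    # always False, no Arrow batch follows the proto message on the wire.
    values_as_bytes = (client._serialize_to_bytes(schema, v) for v in values)
    put_call = stateMessage.ListStatePut(value=values_as_bytes, fetchWithArrow=False)
    list_state_call = stateMessage.ListStateCall(stateName=state_name, listStatePut=put_call)
    state_variable_request = stateMessage.StateVariableRequest(listStateCall=list_state_call)
    return stateMessage.StateRequest(stateVariableRequest=state_variable_request)
```

append_list differs only in that it wraps the serialized values in stateMessage.AppendList and sets the appendList field of ListStateCall instead of listStatePut.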

python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py

Lines changed: 34 additions & 26 deletions
@@ -942,7 +942,12 @@ def close(self) -> None:
 
 class PandasListStateLargeListProcessor(StatefulProcessor):
     def init(self, handle: StatefulProcessorHandle) -> None:
-        list_state_schema = StructType([StructField("value", IntegerType(), True)])
+        list_state_schema = StructType(
+            [
+                StructField("value", IntegerType(), True),
+                StructField("valueNull", IntegerType(), True),
+            ]
+        )
         value_state_schema = StructType([StructField("size", IntegerType(), True)])
         self.list_state = handle.getListState("listState", list_state_schema)
         self.list_size_state = handle.getValueState("listSizeState", value_state_schema)
@@ -952,18 +957,15 @@ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
         elements = list(elements_iter)
 
         # Use the magic number 100 to test with both inline proto case and Arrow case.
-        # TODO(SPARK-51907): Let's update this to be either flexible or more reasonable default
-        # value backed by various benchmarks.
-        # Put 90 elements per batch:
-        # 1st batch: read 0 element, and write 90 elements, read back 90 elements
-        # (both use inline proto)
-        # 2nd batch: read 90 elements, and write 90 elements, read back 180 elements
-        # (read uses both inline proto and Arrow, write uses Arrow)
+        # Now the magic number is not actually used, but this is to make this test be a regression
+        # test of SPARK-53743.
+        # Explicitly put 100 elements of list which triggered Arrow based list serialization before
+        # SPARK-53743.
 
         if len(elements) == 0:
             # should be the first batch
             assert self.list_size_state.get() is None
-            new_elements = [(i,) for i in range(90)]
+            new_elements = [(i, None) for i in range(100)]
             if key == ("0",):
                 self.list_state.put(new_elements)
             else:
@@ -978,18 +980,20 @@ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
                 elements
             ), f"list_size ({list_size}) != len(elements) ({len(elements)})"
 
-            expected_elements_in_state = [(i,) for i in range(list_size)]
-            assert elements == expected_elements_in_state
+            expected_elements_in_state = [(i, None) for i in range(list_size)]
+            assert (
+                elements == expected_elements_in_state
+            ), f"expected {expected_elements_in_state} but got {elements}"
 
             if key == ("0",):
                 # Use the operation `put`
-                new_elements = [(i,) for i in range(list_size + 90)]
+                new_elements = [(i, None) for i in range(list_size + 90)]
                 self.list_state.put(new_elements)
                 final_size = len(new_elements)
                 self.list_size_state.update((final_size,))
             else:
                 # Use the operation `appendList`
-                new_elements = [(i,) for i in range(list_size, list_size + 90)]
+                new_elements = [(i, None) for i in range(list_size, list_size + 90)]
                 self.list_state.appendList(new_elements)
                 final_size = len(new_elements) + list_size
                 self.list_size_state.update((final_size,))
@@ -1004,7 +1008,12 @@ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
 
 class RowListStateLargeListProcessor(StatefulProcessor):
     def init(self, handle: StatefulProcessorHandle) -> None:
-        list_state_schema = StructType([StructField("value", IntegerType(), True)])
+        list_state_schema = StructType(
+            [
+                StructField("value", IntegerType(), True),
+                StructField("valueNull", IntegerType(), True),
+            ]
+        )
         value_state_schema = StructType([StructField("size", IntegerType(), True)])
         self.list_state = handle.getListState("listState", list_state_schema)
         self.list_size_state = handle.getValueState("listSizeState", value_state_schema)
@@ -1015,18 +1024,15 @@ def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:
         elements = list(elements_iter)
 
         # Use the magic number 100 to test with both inline proto case and Arrow case.
-        # TODO(SPARK-51907): Let's update this to be either flexible or more reasonable default
-        # value backed by various benchmarks.
-        # Put 90 elements per batch:
-        # 1st batch: read 0 element, and write 90 elements, read back 90 elements
-        # (both use inline proto)
-        # 2nd batch: read 90 elements, and write 90 elements, read back 180 elements
-        # (read uses both inline proto and Arrow, write uses Arrow)
+        # Now the magic number is not actually used, but this is to make this test be a regression
+        # test of SPARK-53743.
+        # Explicitly put 100 elements of list which triggered Arrow based list serialization before
+        # SPARK-53743.
 
         if len(elements) == 0:
             # should be the first batch
             assert self.list_size_state.get() is None
-            new_elements = [(i,) for i in range(90)]
+            new_elements = [(i, None) for i in range(100)]
             if key == ("0",):
                 self.list_state.put(new_elements)
             else:
@@ -1041,18 +1047,20 @@ def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:
                 elements
             ), f"list_size ({list_size}) != len(elements) ({len(elements)})"
 
-            expected_elements_in_state = [(i,) for i in range(list_size)]
-            assert elements == expected_elements_in_state
+            expected_elements_in_state = [(i, None) for i in range(list_size)]
+            assert (
+                elements == expected_elements_in_state
+            ), f"expected {expected_elements_in_state} but got {elements}"
 
             if key == ("0",):
                 # Use the operation `put`
-                new_elements = [(i,) for i in range(list_size + 90)]
+                new_elements = [(i, None) for i in range(list_size + 90)]
                 self.list_state.put(new_elements)
                 final_size = len(new_elements)
                 self.list_size_state.update((final_size,))
             else:
                 # Use the operation `appendList`
-                new_elements = [(i,) for i in range(list_size, list_size + 90)]
+                new_elements = [(i, None) for i in range(list_size, list_size + 90)]
                 self.list_state.appendList(new_elements)
                 final_size = len(new_elements) + list_size
                 self.list_size_state.update((final_size,))

python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py

Lines changed: 3 additions & 3 deletions
@@ -312,11 +312,11 @@ def check_results(batch_df, batch_id):
             batch_df.collect()
             if batch_id == 0:
                 expected_prev_elements = ""
-                expected_updated_elements = ",".join(map(lambda x: str(x), range(90)))
+                expected_updated_elements = ",".join(map(lambda x: str(x), range(100)))
             else:
                 # batch_id == 1:
-                expected_prev_elements = ",".join(map(lambda x: str(x), range(90)))
-                expected_updated_elements = ",".join(map(lambda x: str(x), range(180)))
+                expected_prev_elements = ",".join(map(lambda x: str(x), range(100)))
+                expected_updated_elements = ",".join(map(lambda x: str(x), range(190)))
 
             assert set(batch_df.sort("id").collect()) == {
                 Row(
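
In other words, the first micro-batch now writes 100 elements (0 through 99) rather than 90, since 100 was the list length that previously flipped the client onto the Arrow path, and the second micro-batch adds 90 more, so the expected element lists grow from 90/180 to 100/190.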

sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkStateServer.scala

Lines changed: 8 additions & 0 deletions
@@ -490,6 +490,10 @@ class TransformWithStateInPySparkStateServer(
           sendResponse(2, s"state $stateName doesn't exist")
         }
       case ListStateCall.MethodCase.LISTSTATEPUT =>
+        // TODO: Check whether we can safely remove fetchWithArrow without breaking backward
+        //  compatibility (Spark Connect)
+        // TODO: Also check whether fetchWithArrow has a clear benefit to be retained (in terms
+        //  of performance)
         val rows = if (message.getListStatePut.getFetchWithArrow) {
           deserializer.readArrowBatches(inputStream)
         } else {
@@ -522,6 +526,10 @@ class TransformWithStateInPySparkStateServer(
         listStateInfo.listState.appendValue(newRow)
         sendResponse(0)
       case ListStateCall.MethodCase.APPENDLIST =>
+        // TODO: Check whether we can safely remove fetchWithArrow without breaking backward
+        //  compatibility (Spark Connect)
+        // TODO: Also check whether fetchWithArrow has a clear benefit to be retained (in terms
+        //  of performance)
         val rows = if (message.getAppendList.getFetchWithArrow) {
           deserializer.readArrowBatches(inputStream)
         } else {
