Skip to content

Commit ac5fce4

Browse files
zeruibao authored and HeartSaVioR committed
[SPARK-51920][SS][PYTHON] Fix type handling of namedTuple for transformWithState
### What changes were proposed in this pull request? Fix type handling of namedTuple for transformWithState ### Why are the changes needed? We hit the issue when using a namedTuple as the value of a StructType like ``` class Person(NamedTuple): age: Integer name: String def handleInputRows( self, key: Any, rows: Iterator[Row], timerValues: TimerValues ) -> Iterator[Row]: person: Person = Person(age = 1, name= "peter") person_list = [] person_list.append(person) self.person_list.update((person_list,)) ``` The `_serialize_to_bytes` cannot construct the namedTuple correctly and hit ``` File "/databricks/spark/python/pyspark/sql/streaming/stateful_processor_api_client.py", line 575, in normalize_value return type(v)(normalize_value(e) for e in v) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ TypeError: Person.__new__() missing 2 required positional arguments: 'age' and 'name' ``` It's because a NamedTuple cannot accept a generator as a parameter. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? UT ### Was this patch authored or co-authored using generative AI tooling? No Closes #53314 from zeruibao/zeruibao/SPARK-5192-fix-namedtuple-type. Lead-authored-by: zeruibao <[email protected]> Co-authored-by: Zerui Bao <[email protected]> Signed-off-by: Jungtaek Lim <[email protected]>
1 parent ca17514 commit ac5fce4

File tree

2 files changed

+59
-13
lines changed

2 files changed

+59
-13
lines changed

python/pyspark/sql/streaming/stateful_processor_api_client.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,11 @@ def normalize_value(v: Any) -> Any:
501501
# Convert NumPy types to Python primitive types.
502502
if isinstance(v, np.generic):
503503
return v.tolist()
504+
# Named tuples (collections.namedtuple or typing.NamedTuple) and Row both
505+
# require positional arguments and cannot be instantiated
506+
# with a generator expression.
507+
if isinstance(v, Row) or (isinstance(v, tuple) and hasattr(v, "_fields")):
508+
return type(v)(*[normalize_value(e) for e in v])
504509
# List / tuple: recursively normalize each element
505510
if isinstance(v, (list, tuple)):
506511
return type(v)(normalize_value(e) for e in v)

python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717

1818
from abc import abstractmethod
1919
import sys
20-
from typing import Iterator
20+
from typing import (
21+
Iterator,
22+
NamedTuple,
23+
)
2124
import unittest
2225
from pyspark.errors import PySparkRuntimeError
2326
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
@@ -1663,10 +1666,15 @@ def close(self) -> None:
16631666

16641667
# A stateful processor that contains composite python type inside Value, List and Map state variable
16651668
class PandasStatefulProcessorCompositeType(StatefulProcessor):
1669+
class Address(NamedTuple):
1670+
road_id: int
1671+
city: str
1672+
16661673
TAGS = [["dummy1", "dummy2"], ["dummy3"]]
16671674
METADATA = [{"key": "env", "value": "prod"}, {"key": "region", "value": "us-west"}]
16681675
ATTRIBUTES_MAP = {"key1": [1], "key2": [10]}
16691676
CONFS_MAP = {"e1": {"e2": 5, "e3": 10}}
1677+
ADDRESS = [Address(1, "Seattle"), Address(3, "SF")]
16701678

16711679
def init(self, handle: StatefulProcessorHandle) -> None:
16721680
obj_schema = StructType(
@@ -1681,6 +1689,17 @@ def init(self, handle: StatefulProcessorHandle) -> None:
16811689
)
16821690
),
16831691
),
1692+
StructField(
1693+
"address",
1694+
ArrayType(
1695+
StructType(
1696+
[
1697+
StructField("road_id", IntegerType()),
1698+
StructField("city", StringType()),
1699+
]
1700+
)
1701+
),
1702+
),
16841703
]
16851704
)
16861705

@@ -1700,25 +1719,28 @@ def init(self, handle: StatefulProcessorHandle) -> None:
17001719

17011720
def _update_obj_state(self, total_temperature):
17021721
if self.obj_state.exists():
1703-
ids, tags, metadata = self.obj_state.get()
1722+
ids, tags, metadata, address = self.obj_state.get()
17041723
assert tags == self.TAGS, f"Tag mismatch: {tags}"
17051724
assert metadata == [Row(**m) for m in self.METADATA], f"Metadata mismatch: {metadata}"
1725+
assert address == [
1726+
Row(**e._asdict()) for e in self.ADDRESS
1727+
], f"Address mismatch: {address}"
17061728
ids = [int(x + total_temperature) for x in ids]
17071729
else:
17081730
ids = [0]
1709-
self.obj_state.update((ids, self.TAGS, self.METADATA))
1731+
self.obj_state.update((ids, self.TAGS, self.METADATA, self.ADDRESS))
17101732
return ids
17111733

17121734
def _update_list_state(self, total_temperature, initial_obj):
17131735
existing_list = self.list_state.get()
17141736
updated_list = []
1715-
for ids, tags, metadata in existing_list:
1737+
for ids, tags, metadata, address in existing_list:
17161738
ids.append(total_temperature)
1717-
updated_list.append((ids, tags, [row.asDict() for row in metadata]))
1739+
updated_list.append((ids, tags, [row.asDict() for row in metadata], address))
17181740
if not updated_list:
17191741
updated_list.append(initial_obj)
17201742
self.list_state.put(updated_list)
1721-
return [id_val for ids, _, _ in updated_list for id_val in ids]
1743+
return [id_val for ids, _, _, _ in updated_list for id_val in ids]
17221744

17231745
def _update_map_state(self, key, total_temperature):
17241746
if not self.map_state.containsKey(key):
@@ -1736,7 +1758,7 @@ def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
17361758

17371759
updated_ids = self._update_obj_state(total_temperature)
17381760
flattened_ids = self._update_list_state(
1739-
total_temperature, (updated_ids, self.TAGS, self.METADATA)
1761+
total_temperature, (updated_ids, self.TAGS, self.METADATA, self.ADDRESS)
17401762
)
17411763
attributes_map, confs_map = self._update_map_state(key, total_temperature)
17421764

@@ -1767,10 +1789,15 @@ def close(self) -> None:
17671789

17681790

17691791
class RowStatefulProcessorCompositeType(StatefulProcessor):
1792+
class Address(NamedTuple):
1793+
road_id: int
1794+
city: str
1795+
17701796
TAGS = [["dummy1", "dummy2"], ["dummy3"]]
17711797
METADATA = [{"key": "env", "value": "prod"}, {"key": "region", "value": "us-west"}]
17721798
ATTRIBUTES_MAP = {"key1": [1], "key2": [10]}
17731799
CONFS_MAP = {"e1": {"e2": 5, "e3": 10}}
1800+
ADDRESS = [Address(1, "Seattle"), Address(3, "SF")]
17741801

17751802
def init(self, handle: StatefulProcessorHandle) -> None:
17761803
obj_schema = StructType(
@@ -1785,6 +1812,17 @@ def init(self, handle: StatefulProcessorHandle) -> None:
17851812
)
17861813
),
17871814
),
1815+
StructField(
1816+
"address",
1817+
ArrayType(
1818+
StructType(
1819+
[
1820+
StructField("road_id", IntegerType()),
1821+
StructField("city", StringType()),
1822+
]
1823+
)
1824+
),
1825+
),
17881826
]
17891827
)
17901828

@@ -1804,25 +1842,28 @@ def init(self, handle: StatefulProcessorHandle) -> None:
18041842

18051843
def _update_obj_state(self, total_temperature):
18061844
if self.obj_state.exists():
1807-
ids, tags, metadata = self.obj_state.get()
1845+
ids, tags, metadata, address = self.obj_state.get()
18081846
assert tags == self.TAGS, f"Tag mismatch: {tags}"
18091847
assert metadata == [Row(**m) for m in self.METADATA], f"Metadata mismatch: {metadata}"
1848+
assert address == [
1849+
Row(**e._asdict()) for e in self.ADDRESS
1850+
], f"Address mismatch: {address}"
18101851
ids = [int(x + total_temperature) for x in ids]
18111852
else:
18121853
ids = [0]
1813-
self.obj_state.update((ids, self.TAGS, self.METADATA))
1854+
self.obj_state.update((ids, self.TAGS, self.METADATA, self.ADDRESS))
18141855
return ids
18151856

18161857
def _update_list_state(self, total_temperature, initial_obj):
18171858
existing_list = self.list_state.get()
18181859
updated_list = []
1819-
for ids, tags, metadata in existing_list:
1860+
for ids, tags, metadata, address in existing_list:
18201861
ids.append(total_temperature)
1821-
updated_list.append((ids, tags, [row.asDict() for row in metadata]))
1862+
updated_list.append((ids, tags, [row.asDict() for row in metadata], address))
18221863
if not updated_list:
18231864
updated_list.append(initial_obj)
18241865
self.list_state.put(updated_list)
1825-
return [id_val for ids, _, _ in updated_list for id_val in ids]
1866+
return [id_val for ids, _, _, _ in updated_list for id_val in ids]
18261867

18271868
def _update_map_state(self, key, total_temperature):
18281869
if not self.map_state.containsKey(key):
@@ -1840,7 +1881,7 @@ def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:
18401881

18411882
updated_ids = self._update_obj_state(total_temperature)
18421883
flattened_ids = self._update_list_state(
1843-
total_temperature, (updated_ids, self.TAGS, self.METADATA)
1884+
total_temperature, (updated_ids, self.TAGS, self.METADATA, self.ADDRESS)
18441885
)
18451886
attributes_map, confs_map = self._update_map_state(key, total_temperature)
18461887

0 commit comments

Comments (0)