
Commit 8bf6640

zeruibao authored and zhengruifeng committed
[SPARK-53638][SS][PYTHON] Limit the byte size of arrow batch for TWS to avoid OOM
### What changes were proposed in this pull request?

Limit the byte size of Arrow batch for TWS to avoid OOM.

### Why are the changes needed?

On the Python worker side, when using the Pandas execution path, Arrow batches must be converted into Pandas DataFrames in memory. If an Arrow batch is too large, this conversion can lead to OOM errors in the Python worker. To mitigate this risk, we need to enforce a limit on the byte size of each Arrow batch. Similarly, processing the Pandas DataFrame inside `handleInputRows` also occurs entirely in memory, so applying a size limit to the DataFrame itself further helps prevent OOM issues.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

UT

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #52391 from zeruibao/zeruibao/SPARK-53638-limit-the-byte-size-of-arrow-batch.

Lead-authored-by: zeruibao <[email protected]>
Co-authored-by: Zerui Bao <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent e0cb512 commit 8bf6640
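
As a rough, standalone illustration of the mitigation described above (a sketch, not the PR's exact code path), the snippet below estimates a PyArrow `RecordBatch`'s in-memory footprint by summing its underlying buffer sizes, and skips that work entirely when the byte limit is left at the "unlimited" sentinel of `2**31 - 1`. The `UNLIMITED` constant and `estimated_batch_bytes` helper are names made up for this example.

```python
import pyarrow as pa

UNLIMITED = 2**31 - 1  # sentinel meaning "no byte limit configured"

def estimated_batch_bytes(batch: pa.RecordBatch, max_bytes: int = UNLIMITED) -> int:
    # Skip the buffer walk when no limit is set, since it has a cost of its own.
    if max_bytes == UNLIMITED or batch.num_rows == 0:
        return 0
    # Sum validity + data buffers across all columns, ignoring absent buffers.
    return sum(
        buf.size
        for col in batch.columns
        for buf in col.buffers()
        if buf is not None
    )

batch = pa.RecordBatch.from_pydict({"id": ["0", "1"], "value": [789, 987]})
print(estimated_batch_bytes(batch, max_bytes=1024))  # prints a non-zero buffer total
```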

File tree

9 files changed: +251, -13 lines changed


python/pyspark/sql/pandas/serializers.py

Lines changed: 30 additions & 2 deletions
@@ -1672,6 +1672,7 @@ def __init__(
         safecheck,
         assign_cols_by_name,
         arrow_max_records_per_batch,
+        arrow_max_bytes_per_batch,
         int_to_decimal_coercion_enabled,
     ):
         super(TransformWithStateInPandasSerializer, self).__init__(
@@ -1682,7 +1683,11 @@ def __init__(
             arrow_cast=True,
         )
         self.arrow_max_records_per_batch = arrow_max_records_per_batch
+        self.arrow_max_bytes_per_batch = arrow_max_bytes_per_batch
         self.key_offsets = None
+        self.average_arrow_row_size = 0
+        self.total_bytes = 0
+        self.total_rows = 0

     def load_stream(self, stream):
         """
@@ -1711,6 +1716,18 @@ def generate_data_batches(batches):

             def row_stream():
                 for batch in batches:
+                    # Short circuit batch size calculation if the batch size is
+                    # unlimited as computing batch size is computationally expensive.
+                    if self.arrow_max_bytes_per_batch != 2**31 - 1 and batch.num_rows > 0:
+                        batch_bytes = sum(
+                            buf.size
+                            for col in batch.columns
+                            for buf in col.buffers()
+                            if buf is not None
+                        )
+                        self.total_bytes += batch_bytes
+                        self.total_rows += batch.num_rows
+                        self.average_arrow_row_size = self.total_bytes / self.total_rows
                     data_pandas = [
                         self.arrow_to_pandas(c, i)
                         for i, c in enumerate(pa.Table.from_batches([batch]).itercolumns())
@@ -1720,8 +1737,17 @@ def row_stream():
                         yield (batch_key, row)

             for batch_key, group_rows in groupby(row_stream(), key=lambda x: x[0]):
-                df = pd.DataFrame([row for _, row in group_rows])
-                yield (batch_key, df)
+                rows = []
+                for _, row in group_rows:
+                    rows.append(row)
+                    if (
+                        len(rows) >= self.arrow_max_records_per_batch
+                        or len(rows) * self.average_arrow_row_size >= self.arrow_max_bytes_per_batch
+                    ):
+                        yield (batch_key, pd.DataFrame(rows))
+                        rows = []
+                if rows:
+                    yield (batch_key, pd.DataFrame(rows))

         _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
         data_batches = generate_data_batches(_batches)
@@ -1766,13 +1792,15 @@ def __init__(
         safecheck,
         assign_cols_by_name,
         arrow_max_records_per_batch,
+        arrow_max_bytes_per_batch,
         int_to_decimal_coercion_enabled,
     ):
         super(TransformWithStateInPandasInitStateSerializer, self).__init__(
             timezone,
             safecheck,
             assign_cols_by_name,
             arrow_max_records_per_batch,
+            arrow_max_bytes_per_batch,
             int_to_decimal_coercion_enabled,
         )
         self.init_key_offsets = None
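
For readers less familiar with the serializer internals, here is a small, self-contained sketch of the chunking pattern introduced above (hypothetical names, simplified row type): rows for the same grouping key are accumulated and emitted as a pandas DataFrame whenever either the record limit or the estimated byte budget is crossed, instead of materializing the whole group at once.

```python
from itertools import groupby

import pandas as pd

def chunked_groups(keyed_rows, max_records, max_bytes, avg_row_size):
    # keyed_rows: iterable of (key, row_dict) pairs, already ordered by key.
    for key, group in groupby(keyed_rows, key=lambda kv: kv[0]):
        rows = []
        for _, row in group:
            rows.append(row)
            if len(rows) >= max_records or len(rows) * avg_row_size >= max_bytes:
                yield key, pd.DataFrame(rows)
                rows = []
        if rows:  # flush the tail chunk for this key
            yield key, pd.DataFrame(rows)

# With max_records=2, the three rows for key "a" arrive as two chunks (2 rows, then 1).
pairs = [("a", {"v": 1}), ("a", {"v": 2}), ("a", {"v": 3}), ("b", {"v": 4})]
for k, df in chunked_groups(pairs, max_records=2, max_bytes=float("inf"), avg_row_size=0):
    print(k, len(df))
```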

python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py

Lines changed: 43 additions & 0 deletions
@@ -237,6 +237,16 @@ def row(self):
         return RowStatefulProcessorCompositeType()


+class ChunkCountProcessorFactory(StatefulProcessorFactory):
+    def pandas(self):
+        return PandasChunkCountProcessor()
+
+
+class ChunkCountProcessorWithInitialStateFactory(StatefulProcessorFactory):
+    def pandas(self):
+        return PandasChunkCountWithInitialStateProcessor()
+
+
 # StatefulProcessor implementations


@@ -1830,3 +1840,36 @@ def handleInputRows(self, key, rows, timerValues) -> Iterator[Row]:

     def close(self) -> None:
         pass
+
+
+class PandasChunkCountProcessor(StatefulProcessor):
+    def init(self, handle: StatefulProcessorHandle) -> None:
+        pass
+
+    def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
+        chunk_count = 0
+        for _ in rows:
+            chunk_count += 1
+        yield pd.DataFrame({"id": [key[0]], "chunkCount": [chunk_count]})
+
+    def close(self) -> None:
+        pass
+
+
+class PandasChunkCountWithInitialStateProcessor(StatefulProcessor):
+    def init(self, handle: StatefulProcessorHandle) -> None:
+        state_schema = StructType([StructField("value", IntegerType(), True)])
+        self.value_state = handle.getValueState("value_state", state_schema)
+
+    def handleInputRows(self, key, rows, timerValues) -> Iterator[pd.DataFrame]:
+        chunk_count = 0
+        for _ in rows:
+            chunk_count += 1
+        yield pd.DataFrame({"id": [key[0]], "chunkCount": [chunk_count]})
+
+    def handleInitialState(self, key, initialState, timerValues) -> None:
+        init_val = initialState.at[0, "initVal"]
+        self.value_state.update((init_val,))
+
+    def close(self) -> None:
+        pass

python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py

Lines changed: 95 additions & 0 deletions
@@ -70,6 +70,8 @@
     UpcastProcessorFactory,
     MinEventTimeStatefulProcessorFactory,
     StatefulProcessorCompositeTypeFactory,
+    ChunkCountProcessorFactory,
+    ChunkCountProcessorWithInitialStateFactory,
 )


@@ -1864,6 +1866,99 @@ def close(self):
             .collect()
         )

+    def test_transform_with_state_with_bytes_limit(self):
+        if not self.use_pandas():
+            return
+
+        def make_check_results(expected_per_batch):
+            def check_results(batch_df, batch_id):
+                batch_df.collect()
+                if batch_id == 0:
+                    assert set(batch_df.sort("id").collect()) == expected_per_batch[0]
+                else:
+                    assert set(batch_df.sort("id").collect()) == expected_per_batch[1]
+
+            return check_results
+
+        result_with_small_limit = [
+            {
+                Row(id="0", chunkCount=2),
+                Row(id="1", chunkCount=2),
+            },
+            {
+                Row(id="0", chunkCount=3),
+                Row(id="1", chunkCount=2),
+            },
+        ]
+
+        result_with_large_limit = [
+            {
+                Row(id="0", chunkCount=1),
+                Row(id="1", chunkCount=1),
+            },
+            {
+                Row(id="0", chunkCount=1),
+                Row(id="1", chunkCount=1),
+            },
+        ]
+
+        data = [("0", 789), ("3", 987)]
+        initial_state = self.spark.createDataFrame(data, "id string, initVal int").groupBy("id")
+
+        with self.sql_conf(
+            # Set it to a very small number so that every row would be a separate pandas df
+            {"spark.sql.execution.arrow.maxBytesPerBatch": "2"}
+        ):
+            self._test_transform_with_state_basic(
+                ChunkCountProcessorFactory(),
+                make_check_results(result_with_small_limit),
+                output_schema=StructType(
+                    [
+                        StructField("id", StringType(), True),
+                        StructField("chunkCount", IntegerType(), True),
+                    ]
+                ),
+            )
+
+            self._test_transform_with_state_basic(
+                ChunkCountProcessorWithInitialStateFactory(),
+                make_check_results(result_with_small_limit),
+                initial_state=initial_state,
+                output_schema=StructType(
+                    [
+                        StructField("id", StringType(), True),
+                        StructField("chunkCount", IntegerType(), True),
+                    ]
+                ),
+            )
+
+        with self.sql_conf(
+            # Set it to a very large number so that every row would be in the same pandas df
+            {"spark.sql.execution.arrow.maxBytesPerBatch": "100000"}
+        ):
+            self._test_transform_with_state_basic(
+                ChunkCountProcessorFactory(),
+                make_check_results(result_with_large_limit),
+                output_schema=StructType(
+                    [
+                        StructField("id", StringType(), True),
+                        StructField("chunkCount", IntegerType(), True),
+                    ]
+                ),
+            )
+
+            self._test_transform_with_state_basic(
+                ChunkCountProcessorWithInitialStateFactory(),
+                make_check_results(result_with_large_limit),
+                initial_state=initial_state,
+                output_schema=StructType(
+                    [
+                        StructField("id", StringType(), True),
+                        StructField("chunkCount", IntegerType(), True),
+                    ]
+                ),
+            )
+

     @unittest.skipIf(
         not have_pyarrow or os.environ.get("PYTHON_GIL", "?") == "0",
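
For context, the tests above drive the new behavior purely through the SQL conf. A user-level sketch of enabling the limit on a session could look roughly like this (the 64 MiB value is illustrative, not a recommended default):

```python
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("tws-byte-limit-demo")
    # Cap the estimated byte size of each Arrow batch handed to the Python worker.
    .config("spark.sql.execution.arrow.maxBytesPerBatch", str(64 * 1024 * 1024))
    # The existing record-count limit still applies alongside the byte limit.
    .config("spark.sql.execution.arrow.maxRecordsPerBatch", "10000")
    .getOrCreate()
)
```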

python/pyspark/worker.py

Lines changed: 12 additions & 0 deletions
@@ -2646,11 +2646,17 @@ def read_udfs(pickleSer, infile, eval_type):
         )
         arrow_max_records_per_batch = int(arrow_max_records_per_batch)

+        arrow_max_bytes_per_batch = runner_conf.get(
+            "spark.sql.execution.arrow.maxBytesPerBatch", 2**31 - 1
+        )
+        arrow_max_bytes_per_batch = int(arrow_max_bytes_per_batch)
+
         ser = TransformWithStateInPandasSerializer(
             timezone,
             safecheck,
             _assign_cols_by_name,
             arrow_max_records_per_batch,
+            arrow_max_bytes_per_batch,
             int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
         )
     elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF:
@@ -2659,11 +2665,17 @@ def read_udfs(pickleSer, infile, eval_type):
         )
         arrow_max_records_per_batch = int(arrow_max_records_per_batch)

+        arrow_max_bytes_per_batch = runner_conf.get(
+            "spark.sql.execution.arrow.maxBytesPerBatch", 2**31 - 1
+        )
+        arrow_max_bytes_per_batch = int(arrow_max_bytes_per_batch)
+
         ser = TransformWithStateInPandasInitStateSerializer(
             timezone,
             safecheck,
             _assign_cols_by_name,
             arrow_max_records_per_batch,
+            arrow_max_bytes_per_batch,
             int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
         )
     elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE_PYTHON_ROW_UDF:

sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala

Lines changed: 8 additions & 2 deletions
@@ -106,12 +106,14 @@ class ApplyInPandasWithStatePythonRunner(
   }

   private val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
+  private val arrowMaxBytesPerBatch = sqlConf.arrowMaxBytesPerBatch

   // applyInPandasWithState has its own mechanism to construct the Arrow RecordBatch instance.
   // Configurations are both applied to executor and Python worker, set them to the worker conf
   // to let Python worker read the config properly.
   override protected val workerConf: Map[String, String] = initialWorkerConf +
-    (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString)
+    (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString) +
+    (SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key -> arrowMaxBytesPerBatch.toString)

   private val stateRowDeserializer = stateEncoder.createDeserializer()

@@ -142,7 +144,11 @@
       dataOut: DataOutputStream,
       inputIterator: Iterator[InType]): Boolean = {
     if (pandasWriter == null) {
-      pandasWriter = new ApplyInPandasWithStateWriter(root, writer, arrowMaxRecordsPerBatch)
+      pandasWriter = new ApplyInPandasWithStateWriter(
+        root,
+        writer,
+        arrowMaxRecordsPerBatch,
+        arrowMaxBytesPerBatch)
     }
     if (inputIterator.hasNext) {
       val startData = dataOut.size()

sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStateWriter.scala

Lines changed: 4 additions & 3 deletions
@@ -50,8 +50,9 @@ import org.apache.spark.unsafe.types.UTF8String
 class ApplyInPandasWithStateWriter(
     root: VectorSchemaRoot,
     writer: ArrowStreamWriter,
-    arrowMaxRecordsPerBatch: Int)
-  extends BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch) {
+    arrowMaxRecordsPerBatch: Int,
+    arrowMaxBytesPerBatch: Long)
+  extends BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch, arrowMaxBytesPerBatch) {

   import ApplyInPandasWithStateWriter._

@@ -144,7 +145,7 @@

     // If it exceeds the condition of batch (number of records) once the all data is received for
     // same group, finalize and construct a new batch.
-    if (totalNumRowsForBatch >= arrowMaxRecordsPerBatch) {
+    if (isBatchSizeLimitReached) {
       finalizeCurrentArrowBatch()
     }
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/BaseStreamingArrowWriter.scala

Lines changed: 11 additions & 1 deletion
@@ -32,6 +32,7 @@ class BaseStreamingArrowWriter(
     root: VectorSchemaRoot,
     writer: ArrowStreamWriter,
     arrowMaxRecordsPerBatch: Int,
+    arrowMaxBytesPerBatch: Long,
     arrowWriterForTest: ArrowWriter = null) {
   protected val arrowWriterForData: ArrowWriter = if (arrowWriterForTest == null) {
     ArrowWriter.create(root)
@@ -54,7 +55,7 @@
     // If it exceeds the condition of batch (number of records) and there is more data for the
     // same group, finalize and construct a new batch.

-    val isCurrentBatchFull = totalNumRowsForBatch >= arrowMaxRecordsPerBatch
+    val isCurrentBatchFull = isBatchSizeLimitReached
     if (isCurrentBatchFull) {
       finalizeCurrentChunk(isLastChunkForGroup = false)
       finalizeCurrentArrowBatch()
@@ -84,4 +85,13 @@
   protected def finalizeCurrentChunk(isLastChunkForGroup: Boolean): Unit = {
     numRowsForCurrentChunk = 0
   }
+
+  protected def isBatchSizeLimitReached: Boolean = {
+    // If we have either reached the records or bytes limit
+    totalNumRowsForBatch >= arrowMaxRecordsPerBatch ||
+      // Short circuit batch size calculation if the batch size is unlimited as computing batch
+      // size is computationally expensive.
+      ((arrowMaxBytesPerBatch != Int.MaxValue)
+        && (arrowWriterForData.sizeInBytes() >= arrowMaxBytesPerBatch))
+  }
 }
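
The limit check added to `BaseStreamingArrowWriter` combines both thresholds. A Python rendering of the same predicate (the Scala above is authoritative; names here are illustrative) may help make the short-circuit explicit:

```python
INT_MAX = 2**31 - 1  # mirrors Int.MaxValue, the "unlimited" sentinel

def batch_size_limit_reached(num_rows, size_in_bytes, max_records, max_bytes):
    # The record limit always applies; the byte check is skipped when unlimited,
    # because measuring the in-flight batch size has a cost of its own.
    if num_rows >= max_records:
        return True
    return max_bytes != INT_MAX and size_in_bytes >= max_bytes

assert batch_size_limit_reached(10, 0, max_records=10, max_bytes=INT_MAX)
assert not batch_size_limit_reached(1, 10**9, max_records=10, max_bytes=INT_MAX)
assert batch_size_limit_reached(1, 2048, max_records=10, max_bytes=1024)
```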

sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkPythonRunner.scala

Lines changed: 15 additions & 3 deletions
@@ -75,7 +75,12 @@
       dataOut: DataOutputStream,
       inputIterator: Iterator[InType]): Boolean = {
     if (pandasWriter == null) {
-      pandasWriter = new BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch)
+      pandasWriter = new BaseStreamingArrowWriter(
+        root,
+        writer,
+        arrowMaxRecordsPerBatch,
+        arrowMaxBytesPerBatch
+      )
     }

     // If we don't have data left for the current group, move to the next group.
@@ -145,7 +150,12 @@
       dataOut: DataOutputStream,
       inputIterator: Iterator[GroupedInType]): Boolean = {
     if (pandasWriter == null) {
-      pandasWriter = new BaseStreamingArrowWriter(root, writer, arrowMaxRecordsPerBatch)
+      pandasWriter = new BaseStreamingArrowWriter(
+        root,
+        writer,
+        arrowMaxRecordsPerBatch,
+        arrowMaxBytesPerBatch
+      )
     }

     if (inputIterator.hasNext) {
@@ -200,9 +210,11 @@ abstract class TransformWithStateInPySparkPythonBaseRunner[I](

   protected val sqlConf = SQLConf.get
   protected val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
+  protected val arrowMaxBytesPerBatch = sqlConf.arrowMaxBytesPerBatch

   override protected val workerConf: Map[String, String] = initialWorkerConf +
-    (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString)
+    (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString) +
+    (SQLConf.ARROW_EXECUTION_MAX_BYTES_PER_BATCH.key -> arrowMaxBytesPerBatch.toString)

   // Use lazy val to initialize the fields before these are accessed in [[PythonArrowInput]]'s
   // constructor.
