Skip to content

Commit 592c03f

Browse files
HyukjinKwon authored and sunchao committed
[SPARK-51316][PYTHON] Allow Arrow batches in bytes instead of number of rows
This PR allows Arrow batches to be limited in bytes instead of number of rows. We enabled `spark.sql.execution.pythonUDF.arrow.enabled` by default, and we should make sure users won't hit OOM. Yes. Now we will make the Arrow batches 256MB in bytes by default, and users can configure this. Tested by changing the default value to 1KB, and added a unit test. Also manually tested as below: ```python from pyspark.sql.functions import pandas_udf import pandas as pd pandas_udf("long") def func(s: pd.Series) -> pd.Series: return s a = spark.range(100000).select(func("id")).collect() ``` No. Closes #50080 from HyukjinKwon/bytes-arrow. Authored-by: Hyukjin Kwon <gurwls223@apache.org> Signed-off-by: Hyukjin Kwon <gurwls223@apache.org> (cherry picked from commit 53fc763)
1 parent 076bf80 commit 592c03f

File tree

8 files changed

+141
-24
lines changed

8 files changed

+141
-24
lines changed

python/pyspark/sql/tests/test_arrow_map.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,35 @@ def test_self_join(self):
146146
expected = df1.join(df1).collect()
147147
self.assertEqual(sorted(actual), sorted(expected))
148148

149+
def test_map_in_arrow_with_barrier_mode(self):
150+
df = self.spark.range(10)
151+
152+
def func1(iterator):
153+
from pyspark import TaskContext, BarrierTaskContext
154+
155+
tc = TaskContext.get()
156+
assert tc is not None
157+
assert not isinstance(tc, BarrierTaskContext)
158+
for batch in iterator:
159+
yield batch
160+
161+
df.mapInArrow(func1, "id long", False).collect()
162+
163+
def func2(iterator):
164+
from pyspark import TaskContext, BarrierTaskContext
165+
166+
tc = TaskContext.get()
167+
assert tc is not None
168+
assert isinstance(tc, BarrierTaskContext)
169+
for batch in iterator:
170+
yield batch
171+
172+
df.mapInArrow(func2, "id long", True).collect()
173+
174+
def test_negative_and_zero_batch_size(self):
175+
for batch_size in [0, -1]:
176+
with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": batch_size}):
177+
MapInArrowTests.test_map_in_arrow(self)
149178

150179
class MapInArrowTests(MapInArrowTestsMixin, ReusedSQLTestCase):
151180
@classmethod
@@ -170,6 +199,15 @@ def tearDownClass(cls):
170199
ReusedSQLTestCase.tearDownClass()
171200

172201

202+
class MapInArrowWithArrowBatchSlicingTestsAndReducedBatchSizeTests(MapInArrowTests):
203+
@classmethod
204+
def setUpClass(cls):
205+
MapInArrowTests.setUpClass()
206+
# Set it to a small odd value to exercise batching logic for all test cases
207+
cls.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "3")
208+
cls.spark.conf.set("spark.sql.execution.arrow.maxBytesPerBatch", "10")
209+
210+
173211
if __name__ == "__main__":
174212
from pyspark.sql.tests.test_arrow_map import * # noqa: F401
175213

sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,16 @@ class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) {
103103
count += 1
104104
}
105105

106+
def sizeInBytes(): Int = {
107+
var i = 0
108+
var bytes = 0
109+
while (i < fields.size) {
110+
bytes += fields(i).getSizeInBytes()
111+
i += 1
112+
}
113+
bytes
114+
}
115+
106116
def finish(): Unit = {
107117
root.setRowCount(count)
108118
fields.foreach(_.finish())
@@ -141,6 +151,13 @@ private[arrow] abstract class ArrowFieldWriter {
141151
valueVector.setValueCount(count)
142152
}
143153

154+
def getSizeInBytes(): Int = {
155+
valueVector.setValueCount(count)
156+
// Before calling getBufferSizeFor, we need to call
157+
// `setValueCount`, see https://github.com/apache/arrow/pull/9187#issuecomment-763362710
158+
valueVector.getBufferSizeFor(count)
159+
}
160+
144161
def reset(): Unit = {
145162
valueVector.reset()
146163
count = 0

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2908,11 +2908,31 @@ object SQLConf {
29082908
.doc("When using Apache Arrow, limit the maximum number of records that can be written " +
29092909
"to a single ArrowRecordBatch in memory. This configuration is not effective for the " +
29102910
"grouping API such as DataFrame(.cogroup).groupby.applyInPandas because each group " +
2911-
"becomes each ArrowRecordBatch. If set to zero or negative there is no limit.")
2911+
"becomes each ArrowRecordBatch. If set to zero or negative there is no limit. " +
2912+
"See also spark.sql.execution.arrow.maxBytesPerBatch. If both are set, each batch " +
2913+
"is created when any condition of both is met.")
29122914
.version("2.3.0")
29132915
.intConf
29142916
.createWithDefault(10000)
29152917

2918+
val ARROW_EXECUTION_MAX_BYTES_PER_BATCH =
2919+
buildConf("spark.sql.execution.arrow.maxBytesPerBatch")
2920+
.internal()
2921+
.doc("When using Apache Arrow, limit the maximum bytes in each batch that can be written " +
2922+
"to a single ArrowRecordBatch in memory. This configuration is not effective for the " +
2923+
"grouping API such as DataFrame(.cogroup).groupby.applyInPandas because each group " +
2924+
"becomes each ArrowRecordBatch. Unlike 'spark.sql.execution.arrow.maxRecordsPerBatch', " +
2925+
"this configuration does not work for createDataFrame/toPandas with Arrow/pandas " +
2926+
"instances. " +
2927+
"See also spark.sql.execution.arrow.maxRecordsPerBatch. If both are set, each batch " +
2928+
"is created when any condition of both is met.")
2929+
.version("4.0.0")
2930+
.bytesConf(ByteUnit.BYTE)
2931+
.checkValue(x => x > 0 && x <= Int.MaxValue,
2932+
errorMsg = "The value of " +
2933+
"spark.sql.execution.arrow.maxBytesPerBatch should be greater " +
2934+
"than zero and less than INT_MAX.")
2935+
.createWithDefaultString("256MB")
29162936
val ARROW_EXECUTION_USE_LARGE_VAR_TYPES =
29172937
buildConf("spark.sql.execution.arrow.useLargeVarTypes")
29182938
.doc("When using Apache Arrow, use large variable width vectors for string and binary " +
@@ -5073,6 +5093,7 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
50735093

50745094
def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH)
50755095

5096+
def arrowMaxBytesPerBatch: Long = getConf(ARROW_EXECUTION_MAX_BYTES_PER_BATCH)
50765097
def arrowUseLargeVarTypes: Boolean = getConf(ARROW_EXECUTION_USE_LARGE_VAR_TYPES)
50775098

50785099
def pandasUDFBufferSize: Int = getConf(PANDAS_UDF_BUFFER_SIZE)

sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
3636
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, FunctionUtils, LogicalGroupState}
3737
import org.apache.spark.sql.catalyst.plans.physical._
3838
import org.apache.spark.sql.catalyst.types.DataTypeUtils
39-
import org.apache.spark.sql.execution.python.BatchIterator
4039
import org.apache.spark.sql.execution.r.ArrowRRunner
4140
import org.apache.spark.sql.execution.streaming.GroupStateImpl
4241
import org.apache.spark.sql.internal.SQLConf
@@ -219,17 +218,13 @@ case class MapPartitionsInRWithArrowExec(
219218
child: SparkPlan) extends UnaryExecNode {
220219
override def producedAttributes: AttributeSet = AttributeSet(output)
221220

222-
private val batchSize = conf.arrowMaxRecordsPerBatch
223-
224221
override def outputPartitioning: Partitioning = child.outputPartitioning
225222

226223
override protected def doExecute(): RDD[InternalRow] = {
227224
child.execute().mapPartitionsInternal { inputIter =>
228225
val outputTypes = schema.map(_.dataType)
229226

230-
// DO NOT use iter.grouped(). See BatchIterator.
231-
val batchIter =
232-
if (batchSize > 0) new BatchIterator(inputIter, batchSize) else Iterator(inputIter)
227+
val batchIter = Iterator(inputIter)
233228

234229
val runner = new ArrowRRunner(func, packageNames, broadcastVars, inputSchema,
235230
SQLConf.get.sessionLocalTimeZone, RRunnerModes.DATAFRAME_DAPPLY)

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
6262
evalType: Int)
6363
extends EvalPythonExec with PythonSQLMetrics {
6464

65-
private val batchSize = conf.arrowMaxRecordsPerBatch
6665
private val sessionLocalTimeZone = conf.sessionLocalTimeZone
6766
private val largeVarTypes = conf.arrowUseLargeVarTypes
6867
private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
@@ -77,10 +76,9 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
7776

7877
val outputTypes = output.drop(child.output.length).map(_.dataType)
7978

80-
// DO NOT use iter.grouped(). See BatchIterator.
81-
val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter)
79+
val batchIter = Iterator(iter)
8280

83-
val columnarBatchIter = new ArrowPythonRunner(
81+
val pyRunner = new ArrowPythonRunner(
8482
funcs,
8583
evalType,
8684
argOffsets,
@@ -89,7 +87,8 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], resultAttrs: Seq[Attribute]
8987
largeVarTypes,
9088
pythonRunnerConf,
9189
pythonMetrics,
92-
jobArtifactUUID).compute(batchIter, context.partitionId(), context)
90+
jobArtifactUUID) with BatchedPythonArrowInput
91+
val columnarBatchIter = pyRunner.compute(batchIter, context.partitionId(), context)
9392

9493
columnarBatchIter.flatMap { batch =>
9594
val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType())

sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class ArrowPythonUDTFRunner(
4242
jobArtifactUUID: Option[String])
4343
extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](
4444
Seq(ChainedPythonFunctions(Seq(udtf.func))), evalType, Array(offsets), jobArtifactUUID)
45-
with BasicPythonArrowInput
45+
with BatchedPythonArrowInput
4646
with BasicPythonArrowOutput {
4747

4848
override protected def writeUDF(

sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,9 @@ class MapInBatchEvaluatorFactory(
5858
// as a DataFrame.
5959
val wrappedIter = contextAwareIterator.map(InternalRow(_))
6060

61-
// DO NOT use iter.grouped(). See BatchIterator.
62-
val batchIter =
63-
if (batchSize > 0) new BatchIterator(wrappedIter, batchSize) else Iterator(wrappedIter)
61+
val batchIter = Iterator(wrappedIter)
6462

65-
val columnarBatchIter = new ArrowPythonRunner(
63+
val pyRunner = new ArrowPythonRunner(
6664
chainedFunc,
6765
pythonEvalType,
6866
argOffsets,
@@ -71,7 +69,8 @@ class MapInBatchEvaluatorFactory(
7169
largeVarTypes,
7270
pythonRunnerConf,
7371
pythonMetrics,
74-
jobArtifactUUID).compute(batchIter, context.partitionId(), context)
72+
jobArtifactUUID) with BatchedPythonArrowInput
73+
val columnarBatchIter = pyRunner.compute(batchIter, context.partitionId(), context)
7574

7675
val unsafeProj = UnsafeProjection.create(output, output)
7776

sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, Py
2727
import org.apache.spark.sql.catalyst.InternalRow
2828
import org.apache.spark.sql.execution.arrow.ArrowWriter
2929
import org.apache.spark.sql.execution.metric.SQLMetric
30+
import org.apache.spark.sql.internal.SQLConf
3031
import org.apache.spark.sql.types.StructType
3132
import org.apache.spark.sql.util.ArrowUtils
3233
import org.apache.spark.util.Utils
@@ -93,12 +94,14 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
9394
val writer = new ArrowStreamWriter(root, null, dataOut)
9495
writer.start()
9596

96-
writeIteratorToArrowStream(root, writer, dataOut, inputIterator)
97-
98-
// end writes footer to the output stream and doesn't clean any resources.
99-
// It could throw exception if the output stream is closed, so it should be
100-
// in the try block.
101-
writer.end()
97+
Utils.tryWithSafeFinally {
98+
writeIteratorToArrowStream(root, writer, dataOut, inputIterator)
99+
} {
100+
// end writes footer to the output stream and doesn't clean any resources.
101+
// It could throw exception if the output stream is closed, so it should be
102+
// in the try block.
103+
writer.end()
104+
}
102105
} {
103106
// If we close root and allocator in TaskCompletionListener, there could be a race
104107
// condition where the writer thread keeps writing to the VectorSchemaRoot while
@@ -144,3 +147,48 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In
144147
}
145148
}
146149
}
150+
151+
private[python] trait BatchedPythonArrowInput extends BasicPythonArrowInput {
152+
self: BasePythonRunner[Iterator[InternalRow], _] =>
153+
154+
private val arrowMaxRecordsPerBatch = SQLConf.get.arrowMaxRecordsPerBatch
155+
private val maxBytesPerBatch = SQLConf.get.arrowMaxBytesPerBatch
156+
157+
// Marker inside the input iterator to indicate the start of the next batch.
158+
private var nextBatchStart: Iterator[InternalRow] = Iterator.empty
159+
160+
override protected def writeIteratorToArrowStream(
161+
root: VectorSchemaRoot,
162+
writer: ArrowStreamWriter,
163+
dataOut: DataOutputStream,
164+
inputIterator: Iterator[Iterator[InternalRow]]): Unit = {
165+
val arrowWriter = ArrowWriter.create(root)
166+
167+
while (nextBatchStart.hasNext || inputIterator.hasNext) {
168+
if (!nextBatchStart.hasNext) {
169+
nextBatchStart = inputIterator.next()
170+
}
171+
172+
val startData = dataOut.size()
173+
var numRowsInBatch = 0
174+
175+
def underBatchSizeLimit: Boolean =
176+
maxBytesPerBatch == Int.MaxValue || arrowWriter.sizeInBytes() < maxBytesPerBatch
177+
178+
while (nextBatchStart.hasNext &&
179+
(arrowMaxRecordsPerBatch <= 0 || numRowsInBatch < arrowMaxRecordsPerBatch) &&
180+
underBatchSizeLimit) {
181+
arrowWriter.write(nextBatchStart.next())
182+
numRowsInBatch += 1
183+
}
184+
185+
if (numRowsInBatch > 0) {
186+
arrowWriter.finish()
187+
writer.writeBatch()
188+
arrowWriter.reset()
189+
val deltaData = dataOut.size() - startData
190+
pythonMetrics("pythonDataSent") += deltaData
191+
}
192+
}
193+
}
194+
}

0 commit comments

Comments
 (0)