Skip to content

Commit a3f4aec

Browse files
HyukjinKwon authored and sunchao committed
[SPARK-51316][PYTHON][FOLLOW-UP] Revert unrelated changes and mark mapInPandas/mapInArrow batched in byte size
This PR is a followup of #50096 that reverts unrelated changes and marks mapInPandas/mapInArrow as batched in byte size. Why: to make the original change self-contained, and to mark mapInPandas/mapInArrow batched in byte size for consistency. Does this introduce any user-facing change? No — the main change has not been released yet. How was this patch tested? Manually. Was generative AI tooling used? No. Closes #50111 from HyukjinKwon/SPARK-51316-followup. Authored-by: Hyukjin Kwon <gurwls223@apache.org> Signed-off-by: Hyukjin Kwon <gurwls223@apache.org> (cherry picked from commit 5b45671) Signed-off-by: Hyukjin Kwon <gurwls223@apache.org> (cherry picked from commit 1df6fc6)
1 parent 592c03f commit a3f4aec

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
3636
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, FunctionUtils, LogicalGroupState}
3737
import org.apache.spark.sql.catalyst.plans.physical._
3838
import org.apache.spark.sql.catalyst.types.DataTypeUtils
39+
import org.apache.spark.sql.execution.python.BatchIterator
3940
import org.apache.spark.sql.execution.r.ArrowRRunner
4041
import org.apache.spark.sql.execution.streaming.GroupStateImpl
4142
import org.apache.spark.sql.internal.SQLConf
@@ -218,13 +219,17 @@ case class MapPartitionsInRWithArrowExec(
218219
child: SparkPlan) extends UnaryExecNode {
219220
override def producedAttributes: AttributeSet = AttributeSet(output)
220221

222+
private val batchSize = conf.arrowMaxRecordsPerBatch
223+
221224
override def outputPartitioning: Partitioning = child.outputPartitioning
222225

223226
override protected def doExecute(): RDD[InternalRow] = {
224227
child.execute().mapPartitionsInternal { inputIter =>
225228
val outputTypes = schema.map(_.dataType)
226229

227-
val batchIter = Iterator(inputIter)
230+
// DO NOT use iter.grouped(). See BatchIterator.
231+
val batchIter =
232+
if (batchSize > 0) new BatchIterator(inputIter, batchSize) else Iterator(inputIter)
228233

229234
val runner = new ArrowRRunner(func, packageNames, broadcastVars, inputSchema,
230235
SQLConf.get.sessionLocalTimeZone, RRunnerModes.DATAFRAME_DAPPLY)

sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -150,8 +150,10 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In
150150

151151
private[python] trait BatchedPythonArrowInput extends BasicPythonArrowInput {
152152
self: BasePythonRunner[Iterator[InternalRow], _] =>
153-
154-
private val arrowMaxRecordsPerBatch = SQLConf.get.arrowMaxRecordsPerBatch
153+
private val arrowMaxRecordsPerBatch = {
154+
val v = SQLConf.get.arrowMaxRecordsPerBatch
155+
if (v > 0) v else Int.MaxValue
156+
}
155157
private val maxBytesPerBatch = SQLConf.get.arrowMaxBytesPerBatch
156158

157159
// Marker inside the input iterator to indicate the start of the next batch.
@@ -176,7 +178,7 @@ private[python] trait BatchedPythonArrowInput extends BasicPythonArrowInput {
176178
maxBytesPerBatch == Int.MaxValue || arrowWriter.sizeInBytes() < maxBytesPerBatch
177179

178180
while (nextBatchStart.hasNext &&
179-
(arrowMaxRecordsPerBatch <= 0 || numRowsInBatch < arrowMaxRecordsPerBatch) &&
181+
numRowsInBatch < arrowMaxRecordsPerBatch &&
180182
underBatchSizeLimit) {
181183
arrowWriter.write(nextBatchStart.next())
182184
numRowsInBatch += 1

0 commit comments

Comments (0)