Skip to content

Commit e83583e

Browse files
jaceklaskowski authored and srowen committed
[MINOR][SQL] Clean up ObjectProducerExec operators
## What changes were proposed in this pull request?

Cleaned up (removed) code duplication in `ObjectProducerExec` operators so they use the trait's methods.

## How was this patch tested?

Local build. Waiting for Jenkins.

Closes apache#25065 from jaceklaskowski/ObjectProducerExec-operators-cleanup.

Authored-by: Jacek Laskowski <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
1 parent 8dff711 commit e83583e

File tree

3 files changed

+11
-14
lines changed

3 files changed

+11
-14
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -69,9 +69,8 @@ case class ExternalRDDScanExec[T](
6969

7070
protected override def doExecute(): RDD[InternalRow] = {
7171
val numOutputRows = longMetric("numOutputRows")
72-
val outputDataType = outputObjAttr.dataType
7372
rdd.mapPartitionsInternal { iter =>
74-
val outputObject = ObjectOperator.wrapObjectToRow(outputDataType)
73+
val outputObject = ObjectOperator.wrapObjectToRow(outputObjectType)
7574
iter.map { value =>
7675
numOutputRows += 1
7776
outputObject(value)

sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala

Lines changed: 7 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -191,7 +191,7 @@ case class MapPartitionsExec(
191191
override protected def doExecute(): RDD[InternalRow] = {
192192
child.execute().mapPartitionsInternal { iter =>
193193
val getObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType)
194-
val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
194+
val outputObject = ObjectOperator.wrapObjectToRow(outputObjectType)
195195
func(iter.map(getObject)).map(outputObject)
196196
}
197197
}
@@ -278,10 +278,10 @@ case class MapElementsExec(
278278
override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
279279
val (funcClass, methodName) = func match {
280280
case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call"
281-
case _ => FunctionUtils.getFunctionOneName(outputObjAttr.dataType, child.output(0).dataType)
281+
case _ => FunctionUtils.getFunctionOneName(outputObjectType, child.output(0).dataType)
282282
}
283283
val funcObj = Literal.create(func, ObjectType(funcClass))
284-
val callFunc = Invoke(funcObj, methodName, outputObjAttr.dataType, child.output)
284+
val callFunc = Invoke(funcObj, methodName, outputObjectType, child.output)
285285

286286
val result = BindReferences.bindReference(callFunc, child.output).genCode(ctx)
287287

@@ -296,7 +296,7 @@ case class MapElementsExec(
296296

297297
child.execute().mapPartitionsInternal { iter =>
298298
val getObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType)
299-
val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
299+
val outputObject = ObjectOperator.wrapObjectToRow(outputObjectType)
300300
iter.map(row => outputObject(callFunc(getObject(row))))
301301
}
302302
}
@@ -395,7 +395,7 @@ case class MapGroupsExec(
395395

396396
val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
397397
val getValue = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
398-
val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
398+
val outputObject = ObjectOperator.wrapObjectToRow(outputObjectType)
399399

400400
grouped.flatMap { case (key, rowIter) =>
401401
val result = func(
@@ -447,8 +447,6 @@ case class FlatMapGroupsInRExec(
447447
outputObjAttr: Attribute,
448448
child: SparkPlan) extends UnaryExecNode with ObjectProducerExec {
449449

450-
override def output: Seq[Attribute] = outputObjAttr :: Nil
451-
452450
override def outputPartitioning: Partitioning = child.outputPartitioning
453451

454452
override def requiredChildDistribution: Seq[Distribution] =
@@ -473,7 +471,7 @@ case class FlatMapGroupsInRExec(
473471
val grouped = GroupedIterator(iter, groupingAttributes, child.output)
474472
val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
475473
val getValue = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
476-
val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
474+
val outputObject = ObjectOperator.wrapObjectToRow(outputObjectType)
477475
val runner = new RRunner[(Array[Byte], Iterator[Array[Byte]]), Array[Byte]](
478476
func, SerializationFormats.ROW, serializerForR, packageNames, broadcastVars,
479477
isDataFrame = true, colNames = inputSchema.fieldNames,
@@ -606,7 +604,7 @@ case class CoGroupExec(
606604
val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, leftGroup)
607605
val getLeft = ObjectOperator.deserializeRowToObject(leftDeserializer, leftAttr)
608606
val getRight = ObjectOperator.deserializeRowToObject(rightDeserializer, rightAttr)
609-
val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
607+
val outputObject = ObjectOperator.wrapObjectToRow(outputObjectType)
610608

611609
new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap {
612610
case (key, leftResult, rightResult) =>

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -28,14 +28,14 @@ import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
2828
import org.apache.spark.util.CompletionIterator
2929

3030
/**
31-
* Physical operator for executing `FlatMapGroupsWithState.`
31+
* Physical operator for executing `FlatMapGroupsWithState`
3232
*
3333
* @param func function called on each group
3434
* @param keyDeserializer used to extract the key object for each group.
3535
* @param valueDeserializer used to extract the items in the iterator from an input row.
3636
* @param groupingAttributes used to group the data
3737
* @param dataAttributes used to read the data
38-
* @param outputObjAttr used to define the output object
38+
* @param outputObjAttr Defines the output object
3939
* @param stateEncoder used to serialize/deserialize state before calling `func`
4040
* @param outputMode the output mode of `func`
4141
* @param timeoutConf used to timeout groups that have not received data in a while
@@ -154,7 +154,7 @@ case class FlatMapGroupsWithStateExec(
154154
ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
155155
private val getValueObj =
156156
ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
157-
private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
157+
private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjectType)
158158

159159
// Metrics
160160
private val numUpdatedStateRows = longMetric("numUpdatedStateRows")

0 commit comments

Comments
 (0)