Commit ac527b5
[SPARK-24991][SQL] use InternalRow in DataSourceWriter
## What changes were proposed in this pull request?

A follow-up of apache#21118. Since we use `InternalRow` in the read API of data source v2, we should do the same thing for the write API.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <[email protected]>

Closes apache#21948 from cloud-fan/row-write.
1 parent: 327bb30

17 files changed (+73 lines, -230 lines)

external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala

Lines changed: 2 additions & 2 deletions
@@ -42,11 +42,11 @@ case object KafkaWriterCommitMessage extends WriterCommitMessage
  */
 class KafkaStreamWriter(
     topic: Option[String], producerParams: Map[String, String], schema: StructType)
-  extends StreamWriter with SupportsWriteInternalRow {
+  extends StreamWriter {
 
   validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic)
 
-  override def createInternalRowWriterFactory(): KafkaStreamWriterFactory =
+  override def createWriterFactory(): KafkaStreamWriterFactory =
     KafkaStreamWriterFactory(topic, producerParams, schema)
 
   override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
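The Kafka sink illustrates the migration for streaming writers: the `SupportsWriteInternalRow` mix-in is gone, and `createWriterFactory` itself now returns an `InternalRow`-typed factory. A minimal sketch of that shape against the updated interfaces (every name below is hypothetical, invented for illustration):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter

// Hypothetical streaming sink writer; only the InternalRow element type of the
// factory is dictated by this commit.
class ExampleStreamWriter extends StreamWriter {
  override def createWriterFactory(): DataWriterFactory[InternalRow] = new ExampleWriterFactory

  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
}

// Shipped to executors; DataWriterFactory already extends Serializable.
class ExampleWriterFactory extends DataWriterFactory[InternalRow] {
  override def createDataWriter(
      partitionId: Int,
      taskId: Long,
      epochId: Long): DataWriter[InternalRow] = new ExampleDataWriter
}

case class ExampleCommitMessage() extends WriterCommitMessage

class ExampleDataWriter extends DataWriter[InternalRow] {
  override def write(record: InternalRow): Unit = {}  // send one row to the external sink
  override def commit(): WriterCommitMessage = ExampleCommitMessage()
  override def abort(): Unit = {}
}
```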

sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java

Lines changed: 2 additions & 2 deletions
@@ -18,8 +18,8 @@
 package org.apache.spark.sql.sources.v2.writer;
 
 import org.apache.spark.annotation.InterfaceStability;
-import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SaveMode;
+import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.sources.v2.DataSourceOptions;
 import org.apache.spark.sql.sources.v2.StreamWriteSupport;
 import org.apache.spark.sql.sources.v2.WriteSupport;
@@ -61,7 +61,7 @@ public interface DataSourceWriter {
    * If this method fails (by throwing an exception), the action will fail and no Spark job will be
    * submitted.
    */
-  DataWriterFactory<Row> createWriterFactory();
+  DataWriterFactory<InternalRow> createWriterFactory();
 
   /**
    * Returns whether Spark should use the commit coordinator to ensure that at most one task for
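This is the central API change of the commit: a batch `DataSourceWriter` now always hands Spark a `DataWriterFactory[InternalRow]`, with no Row-based variant left. A hedged sketch of an implementation against the updated interface, reusing the hypothetical `ExampleWriterFactory` sketched above:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, WriterCommitMessage}

// Hypothetical batch writer; class names are invented for illustration.
class ExampleBatchWriter extends DataSourceWriter {

  // Before this commit this returned DataWriterFactory[Row] unless the writer
  // also mixed in the now-deleted SupportsWriteInternalRow interface.
  override def createWriterFactory(): DataWriterFactory[InternalRow] = new ExampleWriterFactory

  override def commit(messages: Array[WriterCommitMessage]): Unit = {
    // make the job's output visible atomically, e.g. commit a transaction (sketch only)
  }

  override def abort(messages: Array[WriterCommitMessage]): Unit = {
    // best-effort cleanup of whatever the tasks already wrote (sketch only)
  }
}
```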

sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java

Lines changed: 1 addition & 3 deletions
@@ -53,9 +53,7 @@
  * successfully, and have a way to revert committed data writers without the commit message, because
  * Spark only accepts the commit message that arrives first and ignore others.
  *
- * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data
- * source writers, or {@link org.apache.spark.sql.catalyst.InternalRow} for data source writers
- * that mix in {@link SupportsWriteInternalRow}.
+ * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.catalyst.InternalRow}.
  */
 @InterfaceStability.Evolving
 public interface DataWriter<T> {

sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java

Lines changed: 4 additions & 1 deletion
@@ -33,7 +33,10 @@
 public interface DataWriterFactory<T> extends Serializable {
 
   /**
-   * Returns a data writer to do the actual writing work.
+   * Returns a data writer to do the actual writing work. Note that, Spark will reuse the same data
+   * object instance when sending data to the data writer, for better performance. Data writers
+   * are responsible for defensive copies if necessary, e.g. copy the data before buffer it in a
+   * list.
    *
    * If this method fails (by throwing an exception), the action will fail and no Spark job will be
    * submitted.
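The paragraph added to this Javadoc is the behavioral contract that matters most now that writers receive `InternalRow`: Spark may reuse the same mutable row object across `write()` calls, so anything that buffers rows must copy them first. A small illustrative sketch (the buffering writer below is hypothetical, not part of the patch):

```scala
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage}

// Hypothetical buffering writer demonstrating the defensive-copy requirement.
class BufferingDataWriter extends DataWriter[InternalRow] {
  private val buffer = new ArrayBuffer[InternalRow]()

  override def write(record: InternalRow): Unit = {
    // `record` may be the same object on every call, mutated in place between
    // calls, so copy it before keeping a reference.
    buffer += record.copy()
  }

  override def commit(): WriterCommitMessage = {
    // flush `buffer` to the external system here (omitted in this sketch)
    new WriterCommitMessage {}
  }

  override def abort(): Unit = buffer.clear()
}
```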

sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java

Lines changed: 0 additions & 41 deletions
This file was deleted.

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala

Lines changed: 1 addition & 29 deletions
@@ -50,11 +50,7 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e
   override def output: Seq[Attribute] = Nil
 
   override protected def doExecute(): RDD[InternalRow] = {
-    val writeTask = writer match {
-      case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory()
-      case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema)
-    }
-
+    val writeTask = writer.createWriterFactory()
     val useCommitCoordinator = writer.useCommitCoordinator
     val rdd = query.execute()
     val messages = new Array[WriterCommitMessage](rdd.partitions.length)
@@ -155,27 +151,3 @@ object DataWritingSparkTask extends Logging {
     })
   }
 }
-
-class InternalRowDataWriterFactory(
-    rowWriterFactory: DataWriterFactory[Row],
-    schema: StructType) extends DataWriterFactory[InternalRow] {
-
-  override def createDataWriter(
-      partitionId: Int,
-      taskId: Long,
-      epochId: Long): DataWriter[InternalRow] = {
-    new InternalRowDataWriter(
-      rowWriterFactory.createDataWriter(partitionId, taskId, epochId),
-      RowEncoder.apply(schema).resolveAndBind())
-  }
-}
-
-class InternalRowDataWriter(rowWriter: DataWriter[Row], encoder: ExpressionEncoder[Row])
-  extends DataWriter[InternalRow] {
-
-  override def write(record: InternalRow): Unit = rowWriter.write(encoder.fromRow(record))
-
-  override def commit(): WriterCommitMessage = rowWriter.commit()
-
-  override def abort(): Unit = rowWriter.abort()
-}
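The `InternalRowDataWriterFactory` and `InternalRowDataWriter` deleted here were Spark's internal bridge from `InternalRow` back to external `Row`. With them gone, a data source that still wants to expose a Row-based writer would have to do the equivalent conversion on its own side; a sketch mirroring the deleted code (the adapter name is hypothetical):

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage}
import org.apache.spark.sql.types.StructType

// Hypothetical adapter equivalent to the deleted InternalRowDataWriter, but now
// owned by the data source instead of by Spark.
class RowAdapterDataWriter(rowWriter: DataWriter[Row], schema: StructType)
  extends DataWriter[InternalRow] {

  // Deserializes a catalyst InternalRow back into an external Row, as the
  // deleted bridge did.
  private val encoder = RowEncoder(schema).resolveAndBind()

  override def write(record: InternalRow): Unit = rowWriter.write(encoder.fromRow(record))
  override def commit(): WriterCommitMessage = rowWriter.commit()
  override def abort(): Unit = rowWriter.abort()
}
```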

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala

Lines changed: 2 additions & 8 deletions
@@ -28,10 +28,9 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp,
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
 import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2}
-import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter}
+import org.apache.spark.sql.execution.streaming.sources.MicroBatchWriter
 import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport}
 import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
-import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow
 import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger}
 import org.apache.spark.util.{Clock, Utils}
 
@@ -498,12 +497,7 @@ class MicroBatchExecution(
           newAttributePlan.schema,
           outputMode,
           new DataSourceOptions(extraOptions.asJava))
-        if (writer.isInstanceOf[SupportsWriteInternalRow]) {
-          WriteToDataSourceV2(
-            new InternalRowMicroBatchWriter(currentBatchId, writer), newAttributePlan)
-        } else {
-          WriteToDataSourceV2(new MicroBatchWriter(currentBatchId, writer), newAttributePlan)
-        }
+        WriteToDataSourceV2(new MicroBatchWriter(currentBatchId, writer), newAttributePlan)
       case _ => throw new IllegalArgumentException(s"unknown sink type for $sink")
     }
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousWriteRDD.scala

Lines changed: 1 addition & 5 deletions
@@ -17,13 +17,10 @@
 
 package org.apache.spark.sql.execution.streaming.continuous
 
-import java.util.concurrent.atomic.AtomicLong
-
 import org.apache.spark.{Partition, SparkEnv, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo}
-import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}
+import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory}
 import org.apache.spark.util.Utils
 
 /**
@@ -47,7 +44,6 @@ class ContinuousWriteRDD(var prev: RDD[InternalRow], writeTask: DataWriterFactor
         SparkEnv.get)
       EpochTracker.initializeCurrentEpoch(
         context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong)
-
       while (!context.isInterrupted() && !context.isCompleted()) {
         var dataWriter: DataWriter[InternalRow] = null
         // write the data and commit this writer.

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala

Lines changed: 2 additions & 10 deletions
@@ -19,18 +19,14 @@ package org.apache.spark.sql.execution.streaming.continuous
 
 import scala.util.control.NonFatal
 
-import org.apache.spark.{SparkEnv, SparkException, TaskContext}
+import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.datasources.v2.{DataWritingSparkTask, InternalRowDataWriterFactory}
-import org.apache.spark.sql.execution.datasources.v2.DataWritingSparkTask.{logError, logInfo}
 import org.apache.spark.sql.execution.streaming.StreamExecution
-import org.apache.spark.sql.sources.v2.writer._
 import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
-import org.apache.spark.util.Utils
 
 /**
  * The physical plan for writing data into a continuous processing [[StreamWriter]].
@@ -41,11 +37,7 @@ case class WriteToContinuousDataSourceExec(writer: StreamWriter, query: SparkPla
   override def output: Seq[Attribute] = Nil
 
   override protected def doExecute(): RDD[InternalRow] = {
-    val writerFactory = writer match {
-      case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory()
-      case _ => new InternalRowDataWriterFactory(writer.createWriterFactory(), query.schema)
-    }
-
+    val writerFactory = writer.createWriterFactory()
     val rdd = new ContinuousWriteRDD(query.execute(), writerFactory)
 
     logInfo(s"Start processing data source writer: $writer. " +

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriter.scala

Lines changed: 5 additions & 6 deletions
@@ -17,10 +17,10 @@
 
 package org.apache.spark.sql.execution.streaming.sources
 
-import scala.collection.JavaConverters._
-
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.sources.v2.DataSourceOptions
 import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage}
 import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
@@ -39,7 +39,7 @@ class ConsoleWriter(schema: StructType, options: DataSourceOptions)
   assert(SparkSession.getActiveSession.isDefined)
   protected val spark = SparkSession.getActiveSession.get
 
-  def createWriterFactory(): DataWriterFactory[Row] = PackedRowWriterFactory
+  def createWriterFactory(): DataWriterFactory[InternalRow] = PackedRowWriterFactory
 
   override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
     // We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2
@@ -62,8 +62,7 @@ class ConsoleWriter(schema: StructType, options: DataSourceOptions)
     println(printMessage)
     println("-------------------------------------------")
     // scalastyle:off println
-    spark
-      .createDataFrame(rows.toList.asJava, schema)
+    Dataset.ofRows(spark, LocalRelation(schema.toAttributes, rows))
       .show(numRowsToShow, isTruncated)
   }
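The last hunk drops the Row round-trip (`spark.createDataFrame` on converted rows) in favor of planning a DataFrame directly over the buffered `InternalRow`s. The pattern, pulled out into a standalone sketch (the helper object is hypothetical; note that `Dataset.ofRows` is `private[sql]`, so this compiles only inside the `org.apache.spark.sql` package tree, as ConsoleWriter does):

```scala
package org.apache.spark.sql.execution.streaming.sources

import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.StructType

// Hypothetical helper showing the ConsoleWriter pattern above: wrap the
// already-serialized InternalRows in a LocalRelation and plan it as a
// DataFrame, instead of converting them back to external Rows first.
object InternalRowDisplay {
  def toDataFrame(spark: SparkSession, schema: StructType, rows: Seq[InternalRow]): DataFrame =
    Dataset.ofRows(spark, LocalRelation(schema.toAttributes, rows))
}
```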
