@@ -25,17 +25,19 @@ import scala.collection.JavaConverters._
 
 import org.apache.hadoop.fs.Path
 
+import org.apache.spark.SparkConf
 import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
 import org.apache.spark.sql.{AnalysisException, DataFrame}
 import org.apache.spark.sql.execution.DataSourceScanExec
 import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, FileScan, FileTable}
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
 import org.apache.spark.util.Utils
 
-class FileStreamSinkSuite extends StreamTest {
+abstract class FileStreamSinkSuite extends StreamTest {
   import testImplicits._
 
   override def beforeAll(): Unit = {
@@ -51,6 +53,8 @@ class FileStreamSinkSuite extends StreamTest {
     }
   }
 
+  protected def checkQueryExecution(df: DataFrame): Unit
+
   test("unpartitioned writing and batch reading") {
     val inputData = MemoryStream[Int]
     val df = inputData.toDF()
@@ -121,78 +125,36 @@ class FileStreamSinkSuite extends StreamTest {
 
     var query: StreamingQuery = null
 
-    // TODO: test file source V2 as well.
-    withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> "parquet") {
-      try {
-        query =
-          ds.map(i => (i, i * 1000))
-            .toDF("id", "value")
-            .writeStream
-            .partitionBy("id")
-            .option("checkpointLocation", checkpointDir)
-            .format("parquet")
-            .start(outputDir)
-
-        inputData.addData(1, 2, 3)
-        failAfter(streamingTimeout) {
-          query.processAllAvailable()
-        }
+    try {
+      query =
+        ds.map(i => (i, i * 1000))
+          .toDF("id", "value")
+          .writeStream
+          .partitionBy("id")
+          .option("checkpointLocation", checkpointDir)
+          .format("parquet")
+          .start(outputDir)
 
-        val outputDf = spark.read.parquet(outputDir)
-        val expectedSchema = new StructType()
-          .add(StructField("value", IntegerType, nullable = false))
-          .add(StructField("id", IntegerType))
-        assert(outputDf.schema === expectedSchema)
-
-        // Verify that MetadataLogFileIndex is being used and the correct partitioning schema has
-        // been inferred
-        val hadoopdFsRelations = outputDf.queryExecution.analyzed.collect {
-          case LogicalRelation(baseRelation: HadoopFsRelation, _, _, _) => baseRelation
-        }
-        assert(hadoopdFsRelations.size === 1)
-        assert(hadoopdFsRelations.head.location.isInstanceOf[MetadataLogFileIndex])
-        assert(hadoopdFsRelations.head.partitionSchema.exists(_.name == "id"))
-        assert(hadoopdFsRelations.head.dataSchema.exists(_.name == "value"))
-
-        // Verify the data is correctly read
-        checkDatasetUnorderly(
-          outputDf.as[(Int, Int)],
-          (1000, 1), (2000, 2), (3000, 3))
-
-        /** Check some condition on the partitions of the FileScanRDD generated by a DF */
-        def checkFileScanPartitions(df: DataFrame)(func: Seq[FilePartition] => Unit): Unit = {
-          val getFileScanRDD = df.queryExecution.executedPlan.collect {
-            case scan: DataSourceScanExec if scan.inputRDDs().head.isInstanceOf[FileScanRDD] =>
-              scan.inputRDDs().head.asInstanceOf[FileScanRDD]
-          }.headOption.getOrElse {
-            fail(s"No FileScan in query\n${df.queryExecution}")
-          }
-          func(getFileScanRDD.filePartitions)
-        }
+      inputData.addData(1, 2, 3)
+      failAfter(streamingTimeout) {
+        query.processAllAvailable()
+      }
 
-        // Read without pruning
-        checkFileScanPartitions(outputDf) { partitions =>
-          // There should be as many distinct partition values as there are distinct ids
-          assert(partitions.flatMap(_.files.map(_.partitionValues)).distinct.size === 3)
-        }
+      val outputDf = spark.read.parquet(outputDir)
+      val expectedSchema = new StructType()
+        .add(StructField("value", IntegerType, nullable = false))
+        .add(StructField("id", IntegerType))
+      assert(outputDf.schema === expectedSchema)
 
-        // Read with pruning, should read only files in partition dir id=1
-        checkFileScanPartitions(outputDf.filter("id = 1")) { partitions =>
-          val filesToBeRead = partitions.flatMap(_.files)
-          assert(filesToBeRead.map(_.filePath).forall(_.contains("/id=1/")))
-          assert(filesToBeRead.map(_.partitionValues).distinct.size === 1)
-        }
+      // Verify the data is correctly read
+      checkDatasetUnorderly(
+        outputDf.as[(Int, Int)],
+        (1000, 1), (2000, 2), (3000, 3))
 
-        // Read with pruning, should read only files in partition dir id=1 and id=2
-        checkFileScanPartitions(outputDf.filter("id in (1,2)")) { partitions =>
-          val filesToBeRead = partitions.flatMap(_.files)
-          assert(!filesToBeRead.map(_.filePath).exists(_.contains("/id=3/")))
-          assert(filesToBeRead.map(_.partitionValues).distinct.size === 2)
-        }
-      } finally {
-        if (query != null) {
-          query.stop()
-        }
+      checkQueryExecution(outputDf)
+    } finally {
+      if (query != null) {
+        query.stop()
       }
     }
   }
@@ -512,3 +474,92 @@ class FileStreamSinkSuite extends StreamTest {
     }
   }
 }
+
+class FileStreamSinkV1Suite extends FileStreamSinkSuite {
+  override protected def sparkConf: SparkConf =
+    super
+      .sparkConf
+      .set(SQLConf.USE_V1_SOURCE_READER_LIST, "csv,json,orc,text,parquet")
+      .set(SQLConf.USE_V1_SOURCE_WRITER_LIST, "csv,json,orc,text,parquet")
+
+  override def checkQueryExecution(df: DataFrame): Unit = {
+    // Verify that MetadataLogFileIndex is being used and the correct partitioning schema has
+    // been inferred
+    val hadoopdFsRelations = df.queryExecution.analyzed.collect {
+      case LogicalRelation(baseRelation: HadoopFsRelation, _, _, _) => baseRelation
+    }
+    assert(hadoopdFsRelations.size === 1)
+    assert(hadoopdFsRelations.head.location.isInstanceOf[MetadataLogFileIndex])
+    assert(hadoopdFsRelations.head.partitionSchema.exists(_.name == "id"))
+    assert(hadoopdFsRelations.head.dataSchema.exists(_.name == "value"))
+
+    /** Check some condition on the partitions of the FileScanRDD generated by a DF */
+    def checkFileScanPartitions(df: DataFrame)(func: Seq[FilePartition] => Unit): Unit = {
+      val getFileScanRDD = df.queryExecution.executedPlan.collect {
+        case scan: DataSourceScanExec if scan.inputRDDs().head.isInstanceOf[FileScanRDD] =>
+          scan.inputRDDs().head.asInstanceOf[FileScanRDD]
+      }.headOption.getOrElse {
+        fail(s"No FileScan in query\n${df.queryExecution}")
+      }
+      func(getFileScanRDD.filePartitions)
+    }
+
+    // Read without pruning
+    checkFileScanPartitions(df) { partitions =>
+      // There should be as many distinct partition values as there are distinct ids
+      assert(partitions.flatMap(_.files.map(_.partitionValues)).distinct.size === 3)
+    }
+
+    // Read with pruning, should read only files in partition dir id=1
+    checkFileScanPartitions(df.filter("id = 1")) { partitions =>
+      val filesToBeRead = partitions.flatMap(_.files)
+      assert(filesToBeRead.map(_.filePath).forall(_.contains("/id=1/")))
+      assert(filesToBeRead.map(_.partitionValues).distinct.size === 1)
+    }
+
+    // Read with pruning, should read only files in partition dir id=1 and id=2
+    checkFileScanPartitions(df.filter("id in (1,2)")) { partitions =>
+      val filesToBeRead = partitions.flatMap(_.files)
+      assert(!filesToBeRead.map(_.filePath).exists(_.contains("/id=3/")))
+      assert(filesToBeRead.map(_.partitionValues).distinct.size === 2)
+    }
+  }
+}
+
+class FileStreamSinkV2Suite extends FileStreamSinkSuite {
+  override protected def sparkConf: SparkConf =
+    super
+      .sparkConf
+      .set(SQLConf.USE_V1_SOURCE_READER_LIST, "")
+      .set(SQLConf.USE_V1_SOURCE_WRITER_LIST, "")
+
+  override def checkQueryExecution(df: DataFrame): Unit = {
+    // Verify that MetadataLogFileIndex is being used and the correct partitioning schema has
+    // been inferred
+    val table = df.queryExecution.analyzed.collect {
+      case DataSourceV2Relation(table: FileTable, _, _) => table
+    }
+    assert(table.size === 1)
+    assert(table.head.fileIndex.isInstanceOf[MetadataLogFileIndex])
+    assert(table.head.fileIndex.partitionSchema.exists(_.name == "id"))
+    assert(table.head.dataSchema.exists(_.name == "value"))
+
+    /** Check some condition on the partitions of the FileScan generated by a DF */
+    def checkFileScanPartitions(df: DataFrame)(func: Seq[FilePartition] => Unit): Unit = {
+      val fileScan = df.queryExecution.executedPlan.collect {
+        case batch: BatchScanExec if batch.scan.isInstanceOf[FileScan] =>
+          batch.scan.asInstanceOf[FileScan]
+      }.headOption.getOrElse {
+        fail(s"No FileScan in query\n${df.queryExecution}")
+      }
+      func(fileScan.planInputPartitions().map(_.asInstanceOf[FilePartition]))
+    }
+
+    // Read without pruning
+    checkFileScanPartitions(df) { partitions =>
+      // There should be as many distinct partition values as there are distinct ids
+      assert(partitions.flatMap(_.files.map(_.partitionValues)).distinct.size === 3)
+    }
+    // TODO: test partition pruning when file source V2 supports it.
+  }
+}
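
For context on what the two new suites toggle, here is a minimal standalone sketch of steering Spark's built-in file formats onto the V1 or V2 code path with the same `SQLConf.USE_V1_SOURCE_READER_LIST` / `SQLConf.USE_V1_SOURCE_WRITER_LIST` entries this patch uses. It assumes a Spark build that has these config entries; the session builder, master, and app name are illustrative and not part of the patch.

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

object V1FileSourceDemo {
  def main(args: Array[String]): Unit = {
    // Mirror FileStreamSinkV1Suite: route all built-in file formats through the
    // V1 read/write paths. Setting both lists to "" instead would mirror
    // FileStreamSinkV2Suite and leave the V2 implementations in charge.
    val conf = new SparkConf()
      .set(SQLConf.USE_V1_SOURCE_READER_LIST.key, "csv,json,orc,text,parquet")
      .set(SQLConf.USE_V1_SOURCE_WRITER_LIST.key, "csv,json,orc,text,parquet")

    val spark = SparkSession.builder()
      .master("local[*]")          // local master chosen purely for the example
      .appName("v1-file-source-demo")
      .config(conf)
      .getOrCreate()

    // Any parquet read issued from this session now goes through the V1 source,
    // which is what makes the LogicalRelation/HadoopFsRelation checks above apply.
    spark.stop()
  }
}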