@@ -30,6 +30,7 @@ import org.apache.spark.sql.test.SharedSQLContext
30
30
import org .apache .spark .sql .test .SQLTestData ._
31
31
import org .apache .spark .sql .types ._
32
32
import org .apache .spark .storage .StorageLevel ._
33
+ import org .apache .spark .util .Utils
33
34
34
35
class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
35
36
import testImplicits ._
@@ -40,7 +41,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
40
41
data.createOrReplaceTempView(s " testData $dataType" )
41
42
val storageLevel = MEMORY_ONLY
42
43
val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan
43
- val inMemoryRelation = InMemoryRelation (useCompression = true , 5 , storageLevel, plan, None )
44
+ val inMemoryRelation = InMemoryRelation (useCompression = true , 5 , storageLevel, plan, None ,
45
+ data.logicalPlan.stats)
44
46
45
47
assert(inMemoryRelation.cachedColumnBuffers.getStorageLevel == storageLevel)
46
48
inMemoryRelation.cachedColumnBuffers.collect().head match {
@@ -116,7 +118,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
116
118
117
119
// Test: builds the physical plan (`sparkPlan`) for `testData.logicalPlan`, wraps it in an
// InMemoryRelation with compression enabled and MEMORY_ONLY storage, and passes the relation
// to `checkAnswer` together with the rows collected directly from `testData`.
// The removed (`-`) call has five arguments; the added (`+`) call passes
// `testData.logicalPlan.stats` as an extra trailing argument.
// NOTE(review): the positional `5` is presumably the batch size -- confirm against
// InMemoryRelation.apply.
test(" simple columnar query" ) {
118
120
val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan
119
- val scan = InMemoryRelation (useCompression = true , 5 , MEMORY_ONLY , plan, None )
121
+ val scan = InMemoryRelation (useCompression = true , 5 , MEMORY_ONLY , plan, None ,
122
+ testData.logicalPlan.stats)
120
123
121
124
checkAnswer(scan, testData.collect().toSeq)
122
125
}
@@ -132,8 +135,10 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
132
135
}
133
136
134
137
test(" projection" ) {
135
- val plan = spark.sessionState.executePlan(testData.select(' value , ' key ).logicalPlan).sparkPlan
136
- val scan = InMemoryRelation (useCompression = true , 5 , MEMORY_ONLY , plan, None )
138
+ val logicalPlan = testData.select(' value , ' key ).logicalPlan
139
+ val plan = spark.sessionState.executePlan(logicalPlan).sparkPlan
140
+ val scan = InMemoryRelation (useCompression = true , 5 , MEMORY_ONLY , plan, None ,
141
+ logicalPlan.stats)
137
142
138
143
checkAnswer(scan, testData.collect().map {
139
144
case Row (key : Int , value : String ) => value -> key
@@ -149,7 +154,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
149
154
150
155
test(" SPARK-1436 regression: in-memory columns must be able to be accessed multiple times" ) {
151
156
val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan
152
- val scan = InMemoryRelation (useCompression = true , 5 , MEMORY_ONLY , plan, None )
157
+ val scan = InMemoryRelation (useCompression = true , 5 , MEMORY_ONLY , plan, None ,
158
+ testData.logicalPlan.stats)
153
159
154
160
checkAnswer(scan, testData.collect().toSeq)
155
161
checkAnswer(scan, testData.collect().toSeq)
@@ -323,7 +329,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
323
329
test(" SPARK-17549: cached table size should be correctly calculated" ) {
324
330
val data = spark.sparkContext.parallelize(1 to 10 , 5 ).toDF()
325
331
val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan
326
- val cached = InMemoryRelation (true , 5 , MEMORY_ONLY , plan, None )
332
+ val cached = InMemoryRelation (true , 5 , MEMORY_ONLY , plan, None , data.logicalPlan.stats )
327
333
328
334
// Materialize the data.
329
335
val expectedAnswer = data.collect()
@@ -448,8 +454,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
448
454
449
455
test(" SPARK-22249: buildFilter should not throw exception when In contains an empty list" ) {
450
456
val attribute = AttributeReference (" a" , IntegerType )()
451
- val testRelation = InMemoryRelation ( false , 1 , MEMORY_ONLY ,
452
- LocalTableScanExec ( Seq (attribute), Nil ), None )
457
+ val localTableScanExec = LocalTableScanExec ( Seq (attribute), Nil )
458
+ val testRelation = InMemoryRelation ( false , 1 , MEMORY_ONLY , localTableScanExec, None , null )
453
459
val tableScanExec = InMemoryTableScanExec (Seq (attribute),
454
460
Seq (In (attribute, Nil )), testRelation)
455
461
assert(tableScanExec.partitionFilters.isEmpty)
@@ -479,4 +485,43 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
479
485
}
480
486
}
481
487
}
488
+
489
// Checks which size estimate an InMemoryRelation reports for a cached parquet-backed
// DataFrame and for a cached catalog table over the same files: first the statistics of the
// plan being cached, then updated values once the cache is materialized or the table has
// been analyzed.
// NOTE(review): the byte counts 740, 16 and 48 are hard-coded; presumably they depend on the
// parquet files written below and on the cached batch layout -- confirm they are stable
// across environments.
test("SPARK-22673: InMemoryRelation should utilize existing stats of the plan to be cached") {
  withSQLConf("spark.sql.cbo.enabled" -> "true") {
    withTempPath { workDir =>
      withTable("table1") {
        val workDirPath = workDir.getAbsolutePath
        val data = Seq(100, 200, 300, 400).toDF("count")
        data.write.parquet(workDirPath)
        val dfFromFile = spark.read.parquet(workDirPath).cache()
        try {
          val inMemoryRelation = dfFromFile.queryExecution.optimizedPlan.collect {
            case plan: InMemoryRelation => plan
          }.head
          // Before the underlying RDD is materialized, the stats are the file size.
          assert(inMemoryRelation.computeStats().sizeInBytes === 740)

          // After materializing the RDD, the stats are updated.
          dfFromFile.collect()
          assert(inMemoryRelation.computeStats().sizeInBytes === 16)

          // Same check for a catalog table created over the same files.
          val dfFromTable = spark.catalog.createTable("table1", workDirPath).cache()
          try {
            val inMemoryRelation2 = dfFromTable.queryExecution.optimizedPlan.
              collect { case plan: InMemoryRelation => plan }.head

            // Even with CBO enabled, the stats stay at the file size until the table's
            // statistics have been calculated.
            assert(inMemoryRelation2.computeStats().sizeInBytes === 740)
          } finally {
            // Clear the cache to simulate a fresh environment before analyzing the table;
            // done in `finally` so a failed assertion does not leave the entry cached.
            dfFromTable.unpersist(blocking = true)
          }

          // Once the table's statistics are calculated, a newly cached relation uses them.
          spark.sql("ANALYZE TABLE table1 COMPUTE STATISTICS")
          val dfFromAnalyzedTable = spark.read.table("table1").cache()
          try {
            val inMemoryRelation3 = dfFromAnalyzedTable.queryExecution.optimizedPlan.
              collect { case plan: InMemoryRelation => plan }.head
            assert(inMemoryRelation3.computeStats().sizeInBytes === 48)
          } finally {
            dfFromAnalyzedTable.unpersist(blocking = true)
          }
        } finally {
          // Release the file-backed cache entry so it does not outlive this test.
          dfFromFile.unpersist(blocking = true)
        }
      }
    }
  }
}
482
527
}
0 commit comments