Skip to content

Commit 6ad3aa3

Browse files
committed
Reworked TestEnvironment to make it easier to override SparkContext settings.
1 parent 9b196f1 commit 6ad3aa3

File tree

11 files changed

+51
-47
lines changed

11 files changed

+51
-47
lines changed

core/src/main/scala/org/locationtech/rasterframes/ref/RasterSource.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,12 @@ object RasterSource extends LazyLogging {
9393
val cacheTimeout: Duration = Duration.fromNanos(rfConfig.getDuration("raster-source-cache-timeout").toNanos)
9494

9595
private val rsCache = Scaffeine()
96+
.recordStats()
9697
.expireAfterAccess(RasterSource.cacheTimeout)
9798
.build[String, RasterSource]
9899

100+
def cacheStats = rsCache.stats()
101+
99102
implicit def rsEncoder: ExpressionEncoder[RasterSource] = {
100103
RasterSourceUDT // Makes sure UDT is registered first
101104
ExpressionEncoder()

core/src/main/scala/org/locationtech/rasterframes/ref/SimpleRasterInfo.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,10 @@ object SimpleRasterInfo {
7676
)
7777
}
7878

79+
private[rasterframes]
7980
lazy val cache = Scaffeine()
80-
//.recordStats()
81+
.recordStats()
8182
.build[String, SimpleRasterInfo]
83+
84+
def cacheStats = cache.stats()
8285
}

core/src/test/scala/org/locationtech/rasterframes/ExplodeSpec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import geotrellis.raster.resample.NearestNeighbor
3333
*/
3434
class ExplodeSpec extends TestEnvironment with TestData {
3535
describe("conversion to/from exploded representation of tiles") {
36-
import sqlContext.implicits._
36+
import spark.implicits._
3737

3838
it("should explode tiles") {
3939
val query = sql(

core/src/test/scala/org/locationtech/rasterframes/RasterFrameSpec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class RasterFrameSpec extends TestEnvironment with MetadataKeys
5656
}
5757
it("should provide Spark initialization methods") {
5858
assert(spark.withRasterFrames.isInstanceOf[SparkSession])
59-
assert(sqlContext.withRasterFrames.isInstanceOf[SQLContext])
59+
assert(spark.sqlContext.withRasterFrames.isInstanceOf[SQLContext])
6060
}
6161
}
6262

core/src/test/scala/org/locationtech/rasterframes/TestEnvironment.scala

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,46 +20,55 @@
2020
*/
2121
package org.locationtech.rasterframes
2222

23-
import java.nio.file.{Files, Paths}
23+
import java.io.File
24+
import java.nio.file.{Files, Path}
2425

2526
import com.typesafe.scalalogging.LazyLogging
26-
import geotrellis.spark.testkit.{TestEnvironment => GeoTrellisTestEnvironment}
27-
import org.apache.spark.SparkContext
27+
import geotrellis.raster.testkit.RasterMatchers
28+
import org.apache.hadoop.fs.FileUtil
2829
import org.apache.spark.sql._
2930
import org.apache.spark.sql.functions.col
3031
import org.apache.spark.sql.types.StructType
32+
import org.apache.spark.{SparkConf, SparkContext}
3133
import org.locationtech.jts.geom.Geometry
34+
import org.locationtech.rasterframes.util._
3235
import org.scalactic.Tolerance
3336
import org.scalatest._
3437
import org.scalatest.matchers.{MatchResult, Matcher}
35-
import org.locationtech.rasterframes.util._
3638

37-
trait TestEnvironment extends FunSpec with GeoTrellisTestEnvironment
38-
with Matchers with Inspectors with Tolerance with LazyLogging {
39+
trait TestEnvironment extends FunSpec
40+
with Matchers with Inspectors with Tolerance with RasterMatchers with LazyLogging {
3941

40-
override def sparkMaster: String = "local[*]"
42+
lazy val scratchDir: Path = {
43+
val outputDir = Files.createTempDirectory("rf-scratch-")
44+
outputDir.toFile.deleteOnExit()
45+
outputDir
46+
}
4147

42-
override implicit def sc: SparkContext = { _sc.setLogLevel("ERROR"); _sc }
48+
def sparkMaster: String = "local[*]"
4349

44-
lazy val sqlContext: SQLContext = {
50+
def additionalConf = new SparkConf(false)
51+
52+
implicit lazy val spark: SparkSession = {
4553
val session = SparkSession.builder
46-
.config(_sc.getConf)
47-
.config("spark.sql.crossJoin.enabled", true)
54+
.master(sparkMaster)
4855
.withKryoSerialization
56+
.config(additionalConf)
4957
.getOrCreate()
50-
session.sqlContext.withRasterFrames
58+
session.withRasterFrames
5159
}
5260

53-
lazy val sql: String => DataFrame = sqlContext.sql
54-
implicit lazy val spark: SparkSession = sqlContext.sparkSession
61+
implicit def sc: SparkContext = spark.sparkContext
62+
63+
lazy val sql: String => DataFrame = spark.sql
5564

5665
def isCI: Boolean = sys.env.get("CI").contains("true")
5766

5867
/** This is here so we can test writing UDF generated/modified GeoTrellis types to ensure they are Parquet compliant. */
5968
def write(df: Dataset[_]): Boolean = {
6069
val sanitized = df.select(df.columns.map(c => col(c).as(toParquetFriendlyColumnName(c))): _*)
6170
val inRows = sanitized.count()
62-
val dest = Files.createTempFile(Paths.get(outputLocalPath), "rf", ".parquet")
71+
val dest = Files.createTempFile("rf", ".parquet")
6372
logger.trace(s"Writing '${sanitized.columns.mkString(", ")}' to '$dest'...")
6473
sanitized.write.mode(SaveMode.Overwrite).parquet(dest.toString)
6574
val in = df.sparkSession.read.parquet(dest.toString)

core/src/test/scala/org/locationtech/rasterframes/TileAssemblerSpec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ import org.locationtech.rasterframes.ref.{InMemoryRasterSource, RasterSource}
3737
class TileAssemblerSpec extends TestEnvironment {
3838
import TileAssemblerSpec._
3939
describe("TileAssembler") {
40-
import sqlContext.implicits._
40+
import spark.implicits._
4141

4242
it("should reassemble a small scene") {
4343
val raster = TestData.l8Sample(8).projectedRaster

core/src/test/scala/org/locationtech/rasterframes/TileStatsSpec.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ import org.locationtech.rasterframes.stats.CellHistogram
3737
*/
3838
class TileStatsSpec extends TestEnvironment with TestData {
3939
import TestData.injectND
40-
import sqlContext.implicits._
40+
import spark.implicits._
4141

4242
describe("computing statistics over tiles") {
4343
//import org.apache.spark.sql.execution.debug._
@@ -97,7 +97,7 @@ class TileStatsSpec extends TestEnvironment with TestData {
9797
}
9898

9999
it("should support local min/max") {
100-
import sqlContext.implicits._
100+
import spark.implicits._
101101
val ds = Seq[Tile](byteArrayTile, byteConstantTile).toDF("tiles")
102102
ds.createOrReplaceTempView("tmp")
103103

@@ -124,7 +124,7 @@ class TileStatsSpec extends TestEnvironment with TestData {
124124
}
125125

126126
it("should compute tile statistics") {
127-
import sqlContext.implicits._
127+
import spark.implicits._
128128
withClue("mean") {
129129

130130
val ds = Seq.fill[Tile](3)(randomTile(5, 5, FloatConstantNoDataCellType)).toDS()
@@ -229,7 +229,7 @@ class TileStatsSpec extends TestEnvironment with TestData {
229229
}
230230

231231
it("should compute aggregate local stats") {
232-
import sqlContext.implicits._
232+
import spark.implicits._
233233
val ave = (nums: Array[Double]) => nums.sum / nums.length
234234

235235
val ds = (Seq

core/src/test/scala/org/locationtech/rasterframes/encoders/EncodingSpec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ import org.locationtech.rasterframes.tiles.ProjectedRasterTile
4343
*/
4444
class EncodingSpec extends TestEnvironment with TestData {
4545

46-
import sqlContext.implicits._
46+
import spark.implicits._
4747

4848
describe("Spark encoding on standard types") {
4949

datasource/src/test/scala/org/locationtech/rasterframes/datasource/geotrellis/GeoTrellisCatalogSpec.scala

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020
*/
2121
package org.locationtech.rasterframes.datasource.geotrellis
2222

23-
import java.io.File
24-
2523
import org.locationtech.rasterframes._
2624
import geotrellis.proj4.LatLng
2725
import geotrellis.spark._
@@ -39,29 +37,27 @@ class GeoTrellisCatalogSpec
3937

4038
lazy val testRdd = TestData.randomSpatioTemporalTileLayerRDD(10, 12, 5, 6)
4139

42-
import sqlContext.implicits._
40+
import spark.implicits._
4341

4442
before {
45-
val outputDir = new File(outputLocalPath)
46-
FileUtil.fullyDelete(outputDir)
47-
outputDir.deleteOnExit()
48-
lazy val writer = LayerWriter(outputDir.toURI)
43+
FileUtil.fullyDelete(scratchDir.toFile)
44+
lazy val writer = LayerWriter(scratchDir.toUri)
4945
val index = ZCurveKeyIndexMethod.byDay()
5046
writer.write(LayerId("layer-1", 0), testRdd, index)
5147
writer.write(LayerId("layer-2", 0), testRdd, index)
5248
}
5349

5450
describe("Catalog reading") {
5551
it("should show two zoom levels") {
56-
val cat = sqlContext.read
57-
.geotrellisCatalog(outputLocal.toUri)
52+
val cat = spark.read
53+
.geotrellisCatalog(scratchDir.toUri)
5854
assert(cat.schema.length > 4)
5955
assert(cat.count() === 2)
6056
}
6157

6258
it("should support loading a layer in a nice way") {
63-
val cat = sqlContext.read
64-
.geotrellisCatalog(outputLocal.toUri)
59+
val cat = spark.read
60+
.geotrellisCatalog(scratchDir.toUri)
6561

6662
// Select two layers.
6763
val layer = cat
@@ -70,7 +66,7 @@ class GeoTrellisCatalogSpec
7066
.collect
7167
assert(layer.length === 2)
7268

73-
val lots = layer.map(sqlContext.read.geotrellis.loadLayer).map(_.toDF).reduce(_ union _)
69+
val lots = layer.map(spark.read.geotrellis.loadLayer).map(_.toDF).reduce(_ union _)
7470
assert(lots.count === 60)
7571
}
7672
}

datasource/src/test/scala/org/locationtech/rasterframes/datasource/geotrellis/GeoTrellisDataSourceSpec.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@ class GeoTrellisDataSourceSpec
5656
import TestData._
5757

5858
val tileSize = 12
59-
lazy val layer = Layer(new File(outputLocalPath).toURI, LayerId("test-layer", 4))
60-
lazy val tfLayer = Layer(new File(outputLocalPath).toURI, LayerId("test-tf-layer", 4))
61-
lazy val sampleImageLayer = Layer(new File(outputLocalPath).toURI, LayerId("sample", 0))
59+
lazy val layer = Layer(scratchDir.toUri, LayerId("test-layer", 4))
60+
lazy val tfLayer = Layer(scratchDir.toUri, LayerId("test-tf-layer", 4))
61+
lazy val sampleImageLayer = Layer(scratchDir.toUri, LayerId("sample", 0))
6262
val now = ZonedDateTime.now()
6363
val tileCoordRange = 2 to 5
6464

0 commit comments

Comments
 (0)