Skip to content

Commit b063698

Browse files
committed
Reworked rf_tile to accept RasterRef as argument.
Wrote unit tests for rf_tile.
1 parent ba3d30d commit b063698

File tree

7 files changed

+103
-50
lines changed

7 files changed

+103
-50
lines changed

core/src/main/scala/org/locationtech/rasterframes/expressions/DynamicExtractors.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ import org.locationtech.rasterframes.tiles.ProjectedRasterTile
3535

3636
private[rasterframes]
3737
object DynamicExtractors {
38-
/** Partial function for pulling a tile and its contesxt from an input row. */
38+
/** Partial function for pulling a tile and its context from an input row. */
3939
lazy val tileExtractor: PartialFunction[DataType, InternalRow => (Tile, Option[TileContext])] = {
4040
case _: TileUDT =>
4141
(row: InternalRow) =>
@@ -47,6 +47,14 @@ object DynamicExtractors {
4747
}
4848
}
4949

50+
lazy val rasterRefExtractor: PartialFunction[DataType, InternalRow => RasterRef] = {
51+
case t if t.conformsTo[RasterRef] =>
52+
(row: InternalRow) => row.to[RasterRef]
53+
}
54+
55+
lazy val tileableExtractor: PartialFunction[DataType, InternalRow => Tile] =
56+
tileExtractor.andThen(_.andThen(_._1)).orElse(rasterRefExtractor.andThen(_.andThen(_.tile)))
57+
5058
lazy val rowTileExtractor: PartialFunction[DataType, Row => (Tile, Option[TileContext])] = {
5159
case _: TileUDT =>
5260
(row: Row) => (row.to[Tile](TileUDT.tileSerializer), None)

core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/RealizeTile.scala

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,17 @@
2222
package org.locationtech.rasterframes.expressions.accessors
2323

2424
import geotrellis.raster.Tile
25+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
26+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
2527
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
26-
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription}
28+
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, UnaryExpression}
2729
import org.apache.spark.sql.rf.TileUDT
2830
import org.apache.spark.sql.types.DataType
2931
import org.apache.spark.sql.{Column, TypedColumn}
3032
import org.locationtech.rasterframes._
3133
import org.locationtech.rasterframes.encoders.CatalystSerializer._
32-
import org.locationtech.rasterframes.expressions.UnaryRasterOp
33-
import org.locationtech.rasterframes.model.TileContext
34+
import org.locationtech.rasterframes.expressions.DynamicExtractors._
35+
import org.locationtech.rasterframes.expressions._
3436

3537
@ExpressionDescription(
3638
usage = "_FUNC_(raster) - Extracts the Tile component of a RasterSource, ProjectedRasterTile (or Tile) and ensures the cells are fully fetched.",
@@ -39,14 +41,22 @@ import org.locationtech.rasterframes.model.TileContext
3941
> SELECT _FUNC_(raster);
4042
....
4143
""")
42-
case class RealizeTile(child: Expression) extends UnaryRasterOp with CodegenFallback {
44+
case class RealizeTile(child: Expression) extends UnaryExpression with CodegenFallback {
4345
override def dataType: DataType = TileType
4446

4547
override def nodeName: String = "rf_tile"
46-
implicit val tileSer = TileUDT.tileSerializer
4748

48-
override protected def eval(tile: Tile, ctx: Option[TileContext]): Any =
49+
override def checkInputDataTypes(): TypeCheckResult = {
50+
if (!tileableExtractor.isDefinedAt(child.dataType)) {
51+
TypeCheckFailure(s"Input type '${child.dataType}' does not conform to a tiled raster type.")
52+
} else TypeCheckSuccess
53+
}
54+
implicit val tileSer = TileUDT.tileSerializer
55+
override protected def nullSafeEval(input: Any): Any = {
56+
val in = row(input)
57+
val tile = tileableExtractor(child.dataType)(in)
4958
(tile.toArrayTile(): Tile).toInternalRow
59+
}
5060
}
5161

5262
object RealizeTile {

core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -23,53 +23,22 @@ package org.locationtech.rasterframes
2323

2424
import java.io.ByteArrayInputStream
2525

26-
import geotrellis.proj4.LatLng
2726
import geotrellis.raster
2827
import geotrellis.raster._
2928
import geotrellis.raster.render.ColorRamps
3029
import geotrellis.raster.testkit.RasterMatchers
31-
import geotrellis.vector.Extent
3230
import javax.imageio.ImageIO
3331
import org.apache.spark.sql.Encoders
3432
import org.apache.spark.sql.functions._
3533
import org.locationtech.rasterframes.expressions.accessors.ExtractTile
3634
import org.locationtech.rasterframes.model.TileDimensions
37-
import org.locationtech.rasterframes.ref.{RasterRef, RasterSource}
3835
import org.locationtech.rasterframes.stats._
3936
import org.locationtech.rasterframes.tiles.ProjectedRasterTile
4037

4138
class RasterFunctionsSpec extends TestEnvironment with RasterMatchers {
39+
import TestData._
4240
import spark.implicits._
4341

44-
val extent = Extent(10, 20, 30, 40)
45-
val crs = LatLng
46-
val ct = ByteUserDefinedNoDataCellType(-2)
47-
val cols = 10
48-
val rows = cols
49-
val tileSize = cols * rows
50-
val tileCount = 10
51-
val numND = 4
52-
lazy val zero = TestData.projectedRasterTile(cols, rows, 0, extent, crs, ct)
53-
lazy val one = TestData.projectedRasterTile(cols, rows, 1, extent, crs, ct)
54-
lazy val two = TestData.projectedRasterTile(cols, rows, 2, extent, crs, ct)
55-
lazy val three = TestData.projectedRasterTile(cols, rows, 3, extent, crs, ct)
56-
lazy val six = ProjectedRasterTile(three * two, three.extent, three.crs)
57-
lazy val nd = TestData.projectedRasterTile(cols, rows, -2, extent, crs, ct)
58-
lazy val randPRT = TestData.projectedRasterTile(cols, rows, scala.util.Random.nextInt(), extent, crs, ct)
59-
lazy val randNDPRT: Tile = TestData.injectND(numND)(randPRT)
60-
61-
lazy val randDoubleTile = TestData.projectedRasterTile(cols, rows, scala.util.Random.nextGaussian(), extent, crs, DoubleConstantNoDataCellType)
62-
lazy val randDoubleNDTile = TestData.injectND(numND)(randDoubleTile)
63-
lazy val randPositiveDoubleTile = TestData.projectedRasterTile(cols, rows, scala.util.Random.nextDouble() + 1e-6, extent, crs, DoubleConstantNoDataCellType)
64-
65-
val expectedRandNoData: Long = numND * tileCount.toLong
66-
val expectedRandData: Long = cols * rows * tileCount - expectedRandNoData
67-
lazy val randNDTilesWithNull = Seq.fill[Tile](tileCount)(TestData.injectND(numND)(
68-
TestData.randomTile(cols, rows, UByteConstantNoDataCellType)
69-
)).map(ProjectedRasterTile(_, extent, crs)) :+ null
70-
71-
def lazyPRT = RasterRef(RasterSource(TestData.l8samplePath), 0, None, None).tile
72-
7342
implicit val pairEnc = Encoders.tuple(ProjectedRasterTile.prtEncoder, ProjectedRasterTile.prtEncoder)
7443
implicit val tripEnc = Encoders.tuple(ProjectedRasterTile.prtEncoder, ProjectedRasterTile.prtEncoder, ProjectedRasterTile.prtEncoder)
7544

core/src/test/scala/org/locationtech/rasterframes/TestData.scala

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ import org.apache.spark.SparkContext
3838
import org.apache.spark.sql.SparkSession
3939
import org.locationtech.jts.geom.{Coordinate, GeometryFactory}
4040
import org.locationtech.rasterframes.expressions.tilestats.NoDataCells
41+
import org.locationtech.rasterframes.ref.{RasterRef, RasterSource}
4142
import org.locationtech.rasterframes.tiles.ProjectedRasterTile
42-
import spray.json.JsObject
4343

4444
import scala.reflect.ClassTag
4545

@@ -49,8 +49,15 @@ import scala.reflect.ClassTag
4949
* @since 4/3/17
5050
*/
5151
trait TestData {
52+
val extent = Extent(10, 20, 30, 40)
53+
val crs = LatLng
54+
val ct = ByteUserDefinedNoDataCellType(-2)
55+
val cols = 10
56+
val rows = cols
57+
val tileSize = cols * rows
58+
val tileCount = 10
59+
val numND = 4
5260
val instant = ZonedDateTime.now()
53-
val extent = Extent(1, 2, 3, 4)
5461
val sk = SpatialKey(37, 41)
5562
val stk = SpaceTimeKey(sk, instant)
5663
val pe = ProjectedExtent(extent, LatLng)
@@ -153,6 +160,29 @@ trait TestData {
153160
lazy val l8samplePath: URI = getClass.getResource("/L8-B1-Elkton-VA.tiff").toURI
154161
lazy val modisConvertedMrfPath: URI = getClass.getResource("/MCD43A4.A2019111.h30v06.006.2019120033434_01.mrf").toURI
155162

163+
164+
165+
lazy val zero = TestData.projectedRasterTile(cols, rows, 0, extent, crs, ct)
166+
lazy val one = TestData.projectedRasterTile(cols, rows, 1, extent, crs, ct)
167+
lazy val two = TestData.projectedRasterTile(cols, rows, 2, extent, crs, ct)
168+
lazy val three = TestData.projectedRasterTile(cols, rows, 3, extent, crs, ct)
169+
lazy val six = ProjectedRasterTile(three * two, three.extent, three.crs)
170+
lazy val nd = TestData.projectedRasterTile(cols, rows, -2, extent, crs, ct)
171+
lazy val randPRT = TestData.projectedRasterTile(cols, rows, scala.util.Random.nextInt(), extent, crs, ct)
172+
lazy val randNDPRT: Tile = TestData.injectND(numND)(randPRT)
173+
174+
lazy val randDoubleTile = TestData.projectedRasterTile(cols, rows, scala.util.Random.nextGaussian(), extent, crs, DoubleConstantNoDataCellType)
175+
lazy val randDoubleNDTile = TestData.injectND(numND)(randDoubleTile)
176+
lazy val randPositiveDoubleTile = TestData.projectedRasterTile(cols, rows, scala.util.Random.nextDouble() + 1e-6, extent, crs, DoubleConstantNoDataCellType)
177+
178+
val expectedRandNoData: Long = numND * tileCount.toLong
179+
val expectedRandData: Long = cols * rows * tileCount - expectedRandNoData
180+
lazy val randNDTilesWithNull = Seq.fill[Tile](tileCount)(TestData.injectND(numND)(
181+
TestData.randomTile(cols, rows, UByteConstantNoDataCellType)
182+
)).map(ProjectedRasterTile(_, extent, crs)) :+ null
183+
184+
def lazyPRT = RasterRef(RasterSource(TestData.l8samplePath), 0, None, None).tile
185+
156186
object GeomData {
157187
val fact = new GeometryFactory()
158188
val c1 = new Coordinate(1, 2)

core/src/test/scala/org/locationtech/rasterframes/encoders/CatalystSerializerSpec.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ import org.locationtech.rasterframes.model.{CellContext, TileContext, TileDataCo
3535
import org.locationtech.rasterframes.ref.{RasterRef, RasterSource}
3636
import org.scalatest.Assertion
3737

38-
class CatalystSerializerSpec extends TestEnvironment with TestData {
38+
class CatalystSerializerSpec extends TestEnvironment {
39+
import TestData._
40+
3941
val dc = TileDataContext(UShortUserDefinedNoDataCellType(3), TileDimensions(12, 23))
4042
val tc = TileContext(Extent(1, 2, 3, 4), WebMercator)
4143
val cc = CellContext(tc, dc, 34, 45)

core/src/test/scala/org/locationtech/rasterframes/ref/RasterRefSpec.scala

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@
2121

2222
package org.locationtech.rasterframes.ref
2323

24-
import org.locationtech.rasterframes._
25-
import org.locationtech.rasterframes.expressions.accessors._
26-
import org.locationtech.rasterframes.expressions.generators._
27-
import RasterRef.RasterRefTile
28-
import geotrellis.raster.Tile
24+
import geotrellis.raster.{ByteConstantNoDataCellType, Tile}
2925
import geotrellis.vector.Extent
3026
import org.apache.spark.sql.Encoders
31-
import org.locationtech.rasterframes.TestEnvironment
27+
import org.locationtech.rasterframes.{TestEnvironment, _}
28+
import org.locationtech.rasterframes.expressions.accessors._
29+
import org.locationtech.rasterframes.expressions.generators._
30+
import org.locationtech.rasterframes.ref.RasterRef.RasterRefTile
31+
import org.locationtech.rasterframes.tiles.ProjectedRasterTile
3232

3333
/**
3434
*
@@ -199,12 +199,45 @@ class RasterRefSpec extends TestEnvironment with TestData {
199199

200200
refs.count() shouldBe > (1L)
201201

202-
203202
val dims = refs.select(rf_dimensions($"proj_raster")).distinct().collect()
204203
forEvery(dims) { r =>
205204
r.cols should be <= NOMINAL_TILE_SIZE
206205
r.rows should be <= NOMINAL_TILE_SIZE
207206
}
208207
}
209208
}
209+
210+
describe("RealizeTile") {
211+
it("should pass through basic Tile") {
212+
val t = TestData.randomTile(5, 5, ByteConstantNoDataCellType)
213+
val result = Seq(t).toDF("tile").select(rf_tile($"tile")).first()
214+
assertEqual(result, t)
215+
}
216+
217+
it("should simplify ProjectedRasterTile") {
218+
val t = TestData.randNDPRT
219+
val result = Seq(t).toDF("tile").select(rf_tile($"tile")).first()
220+
result.isInstanceOf[ProjectedRasterLike] should be (false)
221+
assertEqual(result, t.toArrayTile())
222+
}
223+
224+
it("should resolve a RasterRef") {
225+
new Fixture {
226+
import RasterRef.rrEncoder // This shouldn't be required, but product encoder gets choosen.
227+
val r: RasterRef = subRaster
228+
val result = Seq(r).toDF("ref").select(rf_tile($"ref")).first()
229+
result.isInstanceOf[RasterRefTile] should be(false)
230+
assertEqual(r.tile.toArrayTile(), result)
231+
}
232+
}
233+
234+
it("should resolve a RasterRefTile") {
235+
new Fixture {
236+
val t: ProjectedRasterTile = RasterRefTile(subRaster)
237+
val result = Seq(t).toDF("tile").select(rf_tile($"tile")).first()
238+
result.isInstanceOf[RasterRefTile] should be(false)
239+
assertEqual(t.toArrayTile(), result)
240+
}
241+
}
242+
}
210243
}

datasource/src/test/scala/org/locationtech/rasterframes/datasource/geotrellis/GeoTrellisDataSourceSpec.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,9 @@ import org.scalatest.{BeforeAndAfterAll, Inspectors}
5151
import scala.math.{max, min}
5252

5353
class GeoTrellisDataSourceSpec
54-
extends TestEnvironment with TestData with BeforeAndAfterAll with Inspectors
54+
extends TestEnvironment with BeforeAndAfterAll with Inspectors
5555
with RasterMatchers with DataSourceOptions {
56+
import TestData._
5657

5758
val tileSize = 12
5859
lazy val layer = Layer(new File(outputLocalPath).toURI, LayerId("test-layer", 4))

0 commit comments

Comments
 (0)