Skip to content

Commit 67108de

Browse files
committed
Additional fixes for issue #242 that cropped back up.
1 parent 90dea60 commit 67108de

File tree

12 files changed

+60
-24
lines changed

12 files changed

+60
-24
lines changed

bench/src/main/scala/org/locationtech/rasterframes/bench/TileEncodeBench.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ class TileEncodeBench extends SparkEnv {
5555
cellTypeName match {
5656
case "rasterRef"
5757
val baseCOG = "https://s3-us-west-2.amazonaws.com/landsat-pds/c1/L8/149/039/LC08_L1TP_149039_20170411_20170415_01_T1/LC08_L1TP_149039_20170411_20170415_01_T1_B1.TIF"
58-
tile = RasterRefTile(RasterRef(RasterSource(URI.create(baseCOG)), 0, Some(Extent(253785.0, 3235185.0, 485115.0, 3471015.0))))
58+
val extent = Extent(253785.0, 3235185.0, 485115.0, 3471015.0)
59+
tile = RasterRefTile(RasterRef(RasterSource(URI.create(baseCOG)), 0, Some(extent), None))
5960
case _
6061
tile = randomTile(tileSize, tileSize, cellTypeName)
6162
}

core/src/main/scala/org/locationtech/rasterframes/encoders/StandardSerializers.scala

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,30 @@ trait StandardSerializers {
6363
t.xmin, t.ymin, t.xmax, t.ymax
6464
)
6565
override def from[R](row: R, io: CatalystIO[R]): Extent = Extent(
66-
io.getDouble(row, 0), io.getDouble(row, 1), io.getDouble(row, 2), io.getDouble(row, 3)
66+
io.getDouble(row, 0),
67+
io.getDouble(row, 1),
68+
io.getDouble(row, 2),
69+
io.getDouble(row, 3)
70+
)
71+
}
72+
73+
implicit val gridBoundsSerializer: CatalystSerializer[GridBounds] = new CatalystSerializer[GridBounds] {
74+
override def schema: StructType = StructType(Seq(
75+
StructField("colMin", IntegerType, false),
76+
StructField("rowlMin", IntegerType, false),
77+
StructField("colMax", IntegerType, false),
78+
StructField("rowMax", IntegerType, false)
79+
))
80+
81+
override protected def to[R](t: GridBounds, io: CatalystIO[R]): R = io.create(
82+
t.colMin, t.rowMin, t.colMax, t.rowMax
83+
)
84+
85+
override protected def from[R](t: R, io: CatalystIO[R]): GridBounds = GridBounds(
86+
io.getInt(t, 0),
87+
io.getInt(t, 1),
88+
io.getInt(t, 2),
89+
io.getInt(t, 3)
6790
)
6891
}
6992

core/src/main/scala/org/locationtech/rasterframes/expressions/generators/RasterSourceToRasterRefs.scala

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
package org.locationtech.rasterframes.expressions.generators
2323

2424
import com.typesafe.scalalogging.LazyLogging
25+
import geotrellis.raster.GridBounds
2526
import geotrellis.vector.Extent
2627
import org.apache.spark.sql.catalyst.InternalRow
2728
import org.apache.spark.sql.catalyst.expressions._
@@ -55,18 +56,22 @@ case class RasterSourceToRasterRefs(children: Seq[Expression], bandIndexes: Seq[
5556
name <- bandNames(basename, bandIndexes)
5657
} yield StructField(name, schemaOf[RasterRef], true))
5758

58-
private def band2ref(src: RasterSource, e: Option[Extent])(b: Int): RasterRef =
59-
if (b < src.bandCount) RasterRef(src, b, e) else null
59+
private def band2ref(src: RasterSource, e: Option[(GridBounds, Extent)])(b: Int): RasterRef =
60+
if (b < src.bandCount) RasterRef(src, b, e.map(_._2), e.map(_._1)) else null
61+
6062

6163
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
6264
try {
6365
val refs = children.map { child
6466
val src = RasterSourceType.deserialize(child.eval(input))
67+
val srcRE = src.rasterExtent
6568
subtileDims.map(dims => {
66-
val extents = src.layoutExtents(dims)
67-
extents.map(e bandIndexes.map(band2ref(src, Some(e))))
69+
val subGB = src.layoutBounds(dims)
70+
val subs = subGB.map(gb => (gb, srcRE.extentFor(gb, clamp = true)))
71+
72+
subs.map(p => bandIndexes.map(band2ref(src, Some(p))))
6873
})
69-
.getOrElse(Seq(bandIndexes.map(band2ref(src, None))))
74+
.getOrElse(Seq(bandIndexes.map(band2ref(src, None))))
7075
}
7176
refs.transpose.map(ts InternalRow(ts.flatMap(_.map(_.toInternalRow)): _*))
7277
}

core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/SetCellType.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ import org.locationtech.rasterframes.expressions.row
5050
* rhs - a cell type definition""",
5151
examples = """
5252
Examples:
53-
> SELECT _FUNC_(tile, 1.5);
53+
> SELECT _FUNC_(tile, 'int16ud0');
5454
..."""
5555
)
5656
case class SetCellType(tile: Expression, cellType: Expression)

core/src/main/scala/org/locationtech/rasterframes/ref/RasterRef.scala

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ import org.locationtech.rasterframes.tiles.ProjectedRasterTile
3838
*
3939
* @since 8/21/18
4040
*/
41-
case class RasterRef(source: RasterSource, bandIndex: Int, subextent: Option[Extent])
41+
case class RasterRef(source: RasterSource, bandIndex: Int, subextent: Option[Extent], subgrid: Option[GridBounds])
4242
extends ProjectedRasterLike {
4343
def crs: CRS = source.crs
4444
def extent: Extent = subextent.getOrElse(source.extent)
@@ -48,12 +48,12 @@ case class RasterRef(source: RasterSource, bandIndex: Int, subextent: Option[Ext
4848
def cellType: CellType = source.cellType
4949
def tile: ProjectedRasterTile = ProjectedRasterTile(RasterRefTile(this), extent, crs)
5050

51-
protected lazy val grid: GridBounds = source.rasterExtent.gridBoundsFor(extent, true)
52-
protected def srcExtent: Extent = extent
51+
protected lazy val grid: GridBounds =
52+
subgrid.getOrElse(source.rasterExtent.gridBoundsFor(extent, true))
5353

5454
protected lazy val realizedTile: Tile = {
55-
RasterRef.log.trace(s"Fetching $srcExtent from band $bandIndex of $source")
56-
source.read(srcExtent, Seq(bandIndex)).tile.band(0)
55+
RasterRef.log.trace(s"Fetching $extent ($grid) from band $bandIndex of $source")
56+
source.read(grid, Seq(bandIndex)).tile.band(0)
5757
}
5858
}
5959

@@ -79,20 +79,24 @@ object RasterRef extends LazyLogging {
7979
override def schema: StructType = StructType(Seq(
8080
StructField("source", rsType, false),
8181
StructField("bandIndex", IntegerType, false),
82-
StructField("subextent", schemaOf[Extent], true)
82+
StructField("subextent", schemaOf[Extent], true),
83+
StructField("subgrid", schemaOf[GridBounds], true)
8384
))
8485

8586
override def to[R](t: RasterRef, io: CatalystIO[R]): R = io.create(
8687
io.to(t.source)(RasterSourceUDT.rasterSourceSerializer),
8788
t.bandIndex,
88-
t.subextent.map(io.to[Extent]).orNull
89+
t.subextent.map(io.to[Extent]).orNull,
90+
t.subgrid.map(io.to[GridBounds]).orNull
8991
)
9092

9193
override def from[R](row: R, io: CatalystIO[R]): RasterRef = RasterRef(
9294
io.get[RasterSource](row, 0)(RasterSourceUDT.rasterSourceSerializer),
9395
io.getInt(row, 1),
9496
if (io.isNullAt(row, 2)) None
95-
else Option(io.get[Extent](row, 2))
97+
else Option(io.get[Extent](row, 2)),
98+
if (io.isNullAt(row, 3)) None
99+
else Option(io.get[GridBounds](row, 3))
96100
)
97101
}
98102

core/src/main/scala/org/locationtech/rasterframes/ref/RasterSource.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ trait RasterSource extends ProjectedRasterLike with Serializable {
7878

7979
def layoutExtents(dims: TileDimensions): Seq[Extent] = {
8080
val re = rasterExtent
81-
layoutBounds(dims).map(re.rasterExtentFor).map(_.extent)
81+
layoutBounds(dims).map(re.extentFor(_, clamp = true))
8282
}
8383

8484
def layoutBounds(dims: TileDimensions): Seq[GridBounds] = {

core/src/main/scala/org/locationtech/rasterframes/tiles/ProjectedRasterTile.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ object ProjectedRasterTile {
6060
extends ProjectedRasterTile {
6161
def delegate: Tile = t
6262

63+
// NB: Don't be tempted to move this into the parent trait. Will get stack overflow.
6364
override def convert(cellType: CellType): Tile =
6465
ConcreteProjectedRasterTile(t.convert(cellType), extent, crs)
6566

core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers {
6464
TestData.randomTile(cols, rows, UByteConstantNoDataCellType)
6565
)).map(ProjectedRasterTile(_, extent, crs)) :+ null
6666

67-
def lazyPRT = RasterRef(RasterSource(TestData.l8samplePath), 0, None).tile
67+
def lazyPRT = RasterRef(RasterSource(TestData.l8samplePath), 0, None, None).tile
6868

6969
implicit val pairEnc = Encoders.tuple(ProjectedRasterTile.prtEncoder, ProjectedRasterTile.prtEncoder)
7070
implicit val tripEnc = Encoders.tuple(ProjectedRasterTile.prtEncoder, ProjectedRasterTile.prtEncoder, ProjectedRasterTile.prtEncoder)

core/src/test/scala/org/locationtech/rasterframes/encoders/CatalystSerializerSpec.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,8 @@ class CatalystSerializerSpec extends TestEnvironment with TestData {
105105
it("should serialize RasterRef") {
106106
// TODO: Decide if RasterRef should be encoded 'flat', non-'flat', or depends
107107
val src = RasterSource(remoteCOGSingleband1)
108-
val value = RasterRef(src, 0, Some(src.extent.buffer(-3.0)))
108+
val ext = src.extent.buffer(-3.0)
109+
val value = RasterRef(src, 0, Some(ext), Some(src.rasterExtent.gridBoundsFor(ext)))
109110
assertConsistent(value)
110111
assertInvertable(value)
111112
}

core/src/test/scala/org/locationtech/rasterframes/ref/RasterRefSpec.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ class RasterRefSpec extends TestEnvironment with TestData {
4747

4848
trait Fixture {
4949
val src = RasterSource(remoteCOGSingleband1)
50-
val fullRaster = RasterRef(src, 0, None)
50+
val fullRaster = RasterRef(src, 0, None, None)
5151
val subExtent = sub(src.extent)
52-
val subRaster = RasterRef(src, 0, Some(subExtent))
52+
val subRaster = RasterRef(src, 0, Some(subExtent), Some(src.rasterExtent.gridBoundsFor(subExtent)))
5353
}
5454

5555
import spark.implicits._
@@ -171,7 +171,7 @@ class RasterRefSpec extends TestEnvironment with TestData {
171171
val src = RasterSource(remoteMODIS)
172172
val dims = src
173173
.layoutExtents(NOMINAL_TILE_DIMS)
174-
.map(e => RasterRef(src, 0, Some(e)))
174+
.map(e => RasterRef(src, 0, Some(e), None))
175175
.map(_.dimensions)
176176
.distinct
177177

0 commit comments

Comments
 (0)