Skip to content

Commit 109dd45

Browse files
authored
Merge pull request #277 from s22s/fix/242
Fixes #242.
2 parents cb205cf + cc8b096 commit 109dd45

File tree

9 files changed

+391
-51
lines changed

9 files changed

+391
-51
lines changed
Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
/*
2+
* Copyright 2016 Azavea
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.locationtech.rasterframes.model
17+
18+
19+
import geotrellis.raster._
20+
import geotrellis.vector._
21+
22+
import scala.math.{ceil, max, min}
23+
24+
/**
25+
* This class is a copy of the GeoTrellis 2.x `RasterExtent`,
26+
* with [GT 3.0 fixes](https://github.com/locationtech/geotrellis/pull/2953/files) incorporated into the
27+
* new `GridExtent[T]` class. This class should be removed after RasterFrames is upgraded to GT 3.x.
28+
*/
29+
case class FixedRasterExtent(
30+
override val extent: Extent,
31+
override val cellwidth: Double,
32+
override val cellheight: Double,
33+
cols: Int,
34+
rows: Int
35+
) extends GridExtent(extent, cellwidth, cellheight) with Grid {
36+
import FixedRasterExtent._
37+
38+
if (cols <= 0) throw GeoAttrsError(s"invalid cols: $cols")
39+
if (rows <= 0) throw GeoAttrsError(s"invalid rows: $rows")
40+
41+
/**
42+
* Convert map coordinates (x, y) to grid coordinates (col, row).
43+
*/
44+
final def mapToGrid(x: Double, y: Double): (Int, Int) = {
45+
val col = floorWithTolerance((x - extent.xmin) / cellwidth).toInt
46+
val row = floorWithTolerance((extent.ymax - y) / cellheight).toInt
47+
(col, row)
48+
}
49+
50+
/**
51+
* Convert map coordinate x to grid coordinate column.
52+
*/
53+
final def mapXToGrid(x: Double): Int = floorWithTolerance(mapXToGridDouble(x)).toInt
54+
55+
/**
56+
* Convert map coordinate x to grid coordinate column.
57+
*/
58+
final def mapXToGridDouble(x: Double): Double = (x - extent.xmin) / cellwidth
59+
60+
/**
61+
* Convert map coordinate y to grid coordinate row.
62+
*/
63+
final def mapYToGrid(y: Double): Int = floorWithTolerance(mapYToGridDouble(y)).toInt
64+
65+
/**
66+
* Convert map coordinate y to grid coordinate row.
67+
*/
68+
final def mapYToGridDouble(y: Double): Double = (extent.ymax - y ) / cellheight
69+
70+
/**
71+
* Convert map coordinate tuple (x, y) to grid coordinates (col, row).
72+
*/
73+
final def mapToGrid(mapCoord: (Double, Double)): (Int, Int) = {
74+
val (x, y) = mapCoord
75+
mapToGrid(x, y)
76+
}
77+
78+
/**
79+
* Convert a point to grid coordinates (col, row).
80+
*/
81+
final def mapToGrid(p: Point): (Int, Int) =
82+
mapToGrid(p.x, p.y)
83+
84+
/**
85+
* The map coordinate of a grid cell is the center point.
86+
*/
87+
final def gridToMap(col: Int, row: Int): (Double, Double) = {
88+
val x = col * cellwidth + extent.xmin + (cellwidth / 2)
89+
val y = extent.ymax - (row * cellheight) - (cellheight / 2)
90+
91+
(x, y)
92+
}
93+
94+
/**
95+
* For a give column, find the corresponding x-coordinate in the
96+
* grid of the present [[FixedRasterExtent]].
97+
*/
98+
final def gridColToMap(col: Int): Double = {
99+
col * cellwidth + extent.xmin + (cellwidth / 2)
100+
}
101+
102+
/**
103+
* For a give row, find the corresponding y-coordinate in the grid
104+
* of the present [[FixedRasterExtent]].
105+
*/
106+
final def gridRowToMap(row: Int): Double = {
107+
extent.ymax - (row * cellheight) - (cellheight / 2)
108+
}
109+
110+
/**
111+
* Gets the GridBounds aligned with this FixedRasterExtent that is the
112+
* smallest subgrid of containing all points within the extent. The
113+
* extent is considered inclusive on it's north and west borders,
114+
* exclusive on it's east and south borders. See [[FixedRasterExtent]]
115+
* for a discussion of grid and extent boundary concepts.
116+
*
117+
* The 'clamp' flag determines whether or not to clamp the
118+
* GridBounds to the FixedRasterExtent; defaults to true. If false,
119+
* GridBounds can contain negative values, or values outside of
120+
* this FixedRasterExtent's boundaries.
121+
*
122+
* @param subExtent The extent to get the grid bounds for
123+
* @param clamp A boolean
124+
*/
125+
def gridBoundsFor(subExtent: Extent, clamp: Boolean = true): GridBounds = {
126+
// West and North boundaries are a simple mapToGrid call.
127+
val (colMin, rowMin) = mapToGrid(subExtent.xmin, subExtent.ymax)
128+
129+
// If South East corner is on grid border lines, we want to still only include
130+
// what is to the West and\or North of the point. However if the border point
131+
// is not directly on a grid division, include the whole row and/or column that
132+
// contains the point.
133+
val colMax = {
134+
val colMaxDouble = mapXToGridDouble(subExtent.xmax)
135+
if(math.abs(colMaxDouble - floorWithTolerance(colMaxDouble)) < FixedRasterExtent.epsilon) colMaxDouble.toInt - 1
136+
else colMaxDouble.toInt
137+
}
138+
139+
val rowMax = {
140+
val rowMaxDouble = mapYToGridDouble(subExtent.ymin)
141+
if(math.abs(rowMaxDouble - floorWithTolerance(rowMaxDouble)) < FixedRasterExtent.epsilon) rowMaxDouble.toInt - 1
142+
else rowMaxDouble.toInt
143+
}
144+
145+
if(clamp) {
146+
GridBounds(math.min(math.max(colMin, 0), cols - 1),
147+
math.min(math.max(rowMin, 0), rows - 1),
148+
math.min(math.max(colMax, 0), cols - 1),
149+
math.min(math.max(rowMax, 0), rows - 1))
150+
} else {
151+
GridBounds(colMin, rowMin, colMax, rowMax)
152+
}
153+
}
154+
155+
/**
156+
* Combine two different [[FixedRasterExtent]]s (which must have the
157+
* same cellsizes). The result is a new extent at the same
158+
* resolution.
159+
*/
160+
def combine (that: FixedRasterExtent): FixedRasterExtent = {
161+
if (cellwidth != that.cellwidth)
162+
throw GeoAttrsError(s"illegal cellwidths: $cellwidth and ${that.cellwidth}")
163+
if (cellheight != that.cellheight)
164+
throw GeoAttrsError(s"illegal cellheights: $cellheight and ${that.cellheight}")
165+
166+
val newExtent = extent.combine(that.extent)
167+
val newRows = ceil(newExtent.height / cellheight).toInt
168+
val newCols = ceil(newExtent.width / cellwidth).toInt
169+
170+
FixedRasterExtent(newExtent, cellwidth, cellheight, newCols, newRows)
171+
}
172+
173+
/**
174+
* Returns a [[RasterExtent]] with the same extent, but a modified
175+
* number of columns and rows based on the given cell height and
176+
* width.
177+
*/
178+
def withResolution(targetCellWidth: Double, targetCellHeight: Double): FixedRasterExtent = {
179+
val newCols = math.ceil((extent.xmax - extent.xmin) / targetCellWidth).toInt
180+
val newRows = math.ceil((extent.ymax - extent.ymin) / targetCellHeight).toInt
181+
FixedRasterExtent(extent, targetCellWidth, targetCellHeight, newCols, newRows)
182+
}
183+
184+
/**
185+
* Returns a [[FixedRasterExtent]] with the same extent, but a modified
186+
* number of columns and rows based on the given cell height and
187+
* width.
188+
*/
189+
def withResolution(cellSize: CellSize): FixedRasterExtent =
190+
withResolution(cellSize.width, cellSize.height)
191+
192+
/**
193+
* Returns a [[FixedRasterExtent]] with the same extent and the given
194+
* number of columns and rows.
195+
*/
196+
def withDimensions(targetCols: Int, targetRows: Int): FixedRasterExtent =
197+
FixedRasterExtent(extent, targetCols, targetRows)
198+
199+
/**
200+
* Adjusts a raster extent so that it can encompass the tile
201+
* layout. Will resample the extent, but keep the resolution, and
202+
* preserve north and west borders
203+
*/
204+
def adjustTo(tileLayout: TileLayout): FixedRasterExtent = {
205+
val totalCols = tileLayout.tileCols * tileLayout.layoutCols
206+
val totalRows = tileLayout.tileRows * tileLayout.layoutRows
207+
208+
val resampledExtent = Extent(extent.xmin, extent.ymax - (cellheight*totalRows),
209+
extent.xmin + (cellwidth*totalCols), extent.ymax)
210+
211+
FixedRasterExtent(resampledExtent, cellwidth, cellheight, totalCols, totalRows)
212+
}
213+
214+
/**
215+
* Returns a new [[FixedRasterExtent]] which represents the GridBounds
216+
* in relation to this FixedRasterExtent.
217+
*/
218+
def rasterExtentFor(gridBounds: GridBounds): FixedRasterExtent = {
219+
val (xminCenter, ymaxCenter) = gridToMap(gridBounds.colMin, gridBounds.rowMin)
220+
val (xmaxCenter, yminCenter) = gridToMap(gridBounds.colMax, gridBounds.rowMax)
221+
val (hcw, hch) = (cellwidth / 2, cellheight / 2)
222+
val e = Extent(xminCenter - hcw, yminCenter - hch, xmaxCenter + hcw, ymaxCenter + hch)
223+
FixedRasterExtent(e, cellwidth, cellheight, gridBounds.width, gridBounds.height)
224+
}
225+
}
226+
227+
/**
228+
* The companion object for the [[FixedRasterExtent]] type.
229+
*/
230+
object FixedRasterExtent {
231+
final val epsilon = 0.0000001
232+
233+
/**
234+
* Create a new [[FixedRasterExtent]] from an Extent, a column, and a
235+
* row.
236+
*/
237+
def apply(extent: Extent, cols: Int, rows: Int): FixedRasterExtent = {
238+
val cw = extent.width / cols
239+
val ch = extent.height / rows
240+
FixedRasterExtent(extent, cw, ch, cols, rows)
241+
}
242+
243+
/**
244+
* Create a new [[FixedRasterExtent]] from an Extent and a [[CellSize]].
245+
*/
246+
def apply(extent: Extent, cellSize: CellSize): FixedRasterExtent = {
247+
val cols = (extent.width / cellSize.width).toInt
248+
val rows = (extent.height / cellSize.height).toInt
249+
FixedRasterExtent(extent, cellSize.width, cellSize.height, cols, rows)
250+
}
251+
252+
/**
253+
* Create a new [[FixedRasterExtent]] from a [[CellGrid]] and an Extent.
254+
*/
255+
def apply(tile: CellGrid, extent: Extent): FixedRasterExtent =
256+
apply(extent, tile.cols, tile.rows)
257+
258+
/**
259+
* Create a new [[FixedRasterExtent]] from an Extent and a [[CellGrid]].
260+
*/
261+
def apply(extent: Extent, tile: CellGrid): FixedRasterExtent =
262+
apply(extent, tile.cols, tile.rows)
263+
264+
265+
/**
266+
* The same logic is used in QGIS: https://github.com/qgis/QGIS/blob/607664c5a6b47c559ed39892e736322b64b3faa4/src/analysis/raster/qgsalignraster.cpp#L38
267+
* The search query: https://github.com/qgis/QGIS/search?p=2&q=floor&type=&utf8=%E2%9C%93
268+
*
269+
* GDAL uses smth like that, however it was a bit hard to track it down:
270+
* https://github.com/OSGeo/gdal/blob/7601a637dfd204948d00f4691c08f02eb7584de5/gdal/frmts/vrt/vrtsources.cpp#L215
271+
* */
272+
def floorWithTolerance(value: Double): Double = {
273+
val roundedValue = math.round(value)
274+
if (math.abs(value - roundedValue) < epsilon) roundedValue
275+
else math.floor(value)
276+
}
277+
}
278+

core/src/main/scala/org/locationtech/rasterframes/ref/RasterRef.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ case class RasterRef(source: RasterSource, bandIndex: Int, subextent: Option[Ext
4848
def cellType: CellType = source.cellType
4949
def tile: ProjectedRasterTile = ProjectedRasterTile(RasterRefTile(this), extent, crs)
5050

51-
protected lazy val grid: GridBounds = source.rasterExtent.gridBoundsFor(extent)
51+
protected lazy val grid: GridBounds = source.rasterExtent.gridBoundsFor(extent, true)
5252
protected def srcExtent: Extent = extent
5353

5454
protected lazy val realizedTile: Tile = {

core/src/main/scala/org/locationtech/rasterframes/ref/RasterSource.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import org.apache.hadoop.conf.Configuration
3333
import org.apache.spark.annotation.Experimental
3434
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
3535
import org.apache.spark.sql.rf.RasterSourceUDT
36-
import org.locationtech.rasterframes.model.{TileContext, TileDimensions}
36+
import org.locationtech.rasterframes.model.{FixedRasterExtent, TileContext, TileDimensions}
3737
import org.locationtech.rasterframes.{NOMINAL_TILE_DIMS, rfConfig}
3838

3939
import scala.concurrent.duration.Duration
@@ -68,7 +68,7 @@ trait RasterSource extends ProjectedRasterLike with Serializable {
6868

6969
protected def readBounds(bounds: Traversable[GridBounds], bands: Seq[Int]): Iterator[Raster[MultibandTile]]
7070

71-
def rasterExtent = RasterExtent(extent, cols, rows)
71+
def rasterExtent = FixedRasterExtent(extent, cols, rows)
7272

7373
def cellSize = CellSize(extent, cols, rows)
7474

core/src/test/scala/org/locationtech/rasterframes/ref/RasterRefSpec.scala

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -161,28 +161,52 @@ class RasterRefSpec extends TestEnvironment with TestData {
161161
val data = buf.toByteArray
162162
val in = new ObjectInputStream(new ByteArrayInputStream(data))
163163
val recovered = in.readObject()
164-
assert(subRaster === recovered)
164+
subRaster should be (recovered)
165165
}
166166
}
167167
}
168168

169-
describe("CreateRasterRefs") {
170-
it("should convert and expand RasterSource") {
171-
new Fixture {
172-
import spark.implicits._
173-
val df = Seq(src).toDF("src")
174-
val refs = df.select(RasterSourceToRasterRefs(Some(NOMINAL_TILE_DIMS), Seq(0), $"src"))
175-
assert(refs.count() > 1)
169+
describe("RasterRef creation") {
170+
it("should realize subiles of proper size") {
171+
val src = RasterSource(remoteMODIS)
172+
val dims = src
173+
.layoutExtents(NOMINAL_TILE_DIMS)
174+
.map(e => RasterRef(src, 0, Some(e)))
175+
.map(_.dimensions)
176+
.distinct
177+
178+
forEvery(dims) { d =>
179+
d._1 should be <= NOMINAL_TILE_SIZE
180+
d._2 should be <= NOMINAL_TILE_SIZE
176181
}
177182
}
183+
}
178184

179-
it("should work with tile realization") {
180-
new Fixture {
181-
import spark.implicits._
182-
val df = Seq(src).toDF("src")
183-
val refs = df.select(RasterSourceToRasterRefs(Some(NOMINAL_TILE_DIMS), Seq(0), $"src"))
184-
assert(refs.count() > 1)
185+
describe("RasterSourceToRasterRefs") {
186+
it("should convert and expand RasterSource") {
187+
val src = RasterSource(remoteMODIS)
188+
import spark.implicits._
189+
val df = Seq(src).toDF("src")
190+
val refs = df.select(RasterSourceToRasterRefs(None, Seq(0), $"src"))
191+
refs.count() should be (1)
192+
}
193+
194+
it("should properly realize subtiles") {
195+
val src = RasterSource(remoteMODIS)
196+
import spark.implicits._
197+
val df = Seq(src).toDF("src")
198+
val refs = df.select(RasterSourceToRasterRefs(Some(NOMINAL_TILE_DIMS), Seq(0), $"src") as "proj_raster")
199+
200+
refs.count() shouldBe > (1L)
201+
202+
203+
val dims = refs.select(rf_dimensions($"proj_raster")).distinct().collect()
204+
forEvery(dims) { r =>
205+
r.cols should be <=NOMINAL_TILE_SIZE
206+
r.rows should be <=NOMINAL_TILE_SIZE
185207
}
208+
209+
dims.foreach(println)
186210
}
187211
}
188212
}

0 commit comments

Comments
 (0)