Skip to content

Commit c021d2e

Browse files
committed
Unit test build-out.
1 parent 6a2132a commit c021d2e

File tree

5 files changed

+130
-8
lines changed

5 files changed

+130
-8
lines changed

core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,12 @@ trait RasterFunctions {
5959
/** Extracts the bounding box from a RasterSource or ProjectedRasterTile */
6060
def rf_extent(col: Column): TypedColumn[Any, Extent] = GetExtent(col)
6161

62-
/** Constructs a XZ2 index in WGS84 from either a Geometry, Extent, ProjectedRasterTile, or RasterSource */
62+
/** Constructs a XZ2 index in WGS84 from either a Geometry, Extent, ProjectedRasterTile, or RasterSource and its CRS */
6363
def rf_spatial_index(targetExtent: Column, targetCRS: Column) = XZ2Indexer(targetExtent, targetCRS)
6464

65+
/** Constructs a XZ2 index in WGS84 from either a ProjectedRasterTile or RasterSource */
66+
def rf_spatial_index(targetExtent: Column) = XZ2Indexer(targetExtent)
67+
6568
/** Extracts the CRS from a RasterSource or ProjectedRasterTile */
6669
def rf_crs(col: Column): TypedColumn[Any, CRS] = GetCRS(col)
6770

core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/XZ2Indexer.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,13 @@ import org.locationtech.rasterframes.jts.ReprojectionTransformer
4040
import org.locationtech.rasterframes.ref.{RasterRef, RasterSource}
4141
import org.locationtech.rasterframes.tiles.ProjectedRasterTile
4242
import org.apache.spark.sql.rf
43+
import org.locationtech.rasterframes.expressions.accessors.GetCRS
4344

4445
/**
45-
* This expression constructs a XZ2 index for a given JTS geometry.
46+
* Constructs a XZ2 index in WGS84 from either a Geometry, Extent, ProjectedRasterTile, or RasterSource
4647
*
48+
* @param left geometry-like column
49+
* @param right CRS column
4750
* @param indexResolution resolution level of the space filling curve -
4851
* i.e. how many times the space will be recursively quartered
4952
* 1-18 is typical.
@@ -71,7 +74,7 @@ case class XZ2Indexer(left: Expression, right: Expression, indexResolution: Shor
7174

7275
val coords = left.dataType match {
7376
case t if rf.WithTypeConformity(t).conformsTo(JTSTypes.GeometryTypeInstance) =>
74-
JTSTypes.GeometryTypeInstance.deserialize(left)
77+
JTSTypes.GeometryTypeInstance.deserialize(leftInput)
7578
case t if t.conformsTo[Extent] =>
7679
row(leftInput).to[Extent]
7780
case t if t.conformsTo[Envelope] =>
@@ -112,4 +115,6 @@ object XZ2Indexer {
112115
import org.locationtech.rasterframes.encoders.SparkBasicEncoders.longEnc
113116
def apply(targetExtent: Column, targetCRS: Column): TypedColumn[Any, Long] =
114117
new Column(new XZ2Indexer(targetExtent.expr, targetCRS.expr)).as[Long]
118+
def apply(targetExtent: Column): TypedColumn[Any, Long] =
119+
new Column(new XZ2Indexer(targetExtent.expr, GetCRS(targetExtent.expr))).as[Long]
115120
}

core/src/test/scala/org/locationtech/rasterframes/GeometryFunctionsSpec.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,10 +131,7 @@ class GeometryFunctionsSpec extends TestEnvironment with TestData with StandardC
131131
val wm4 = sql("SELECT st_reproject(ll, '+proj=longlat +ellps=WGS84 +datum=WGS84 +no_defs', 'EPSG:3857') AS wm4 from geom")
132132
.as[Geometry].first()
133133
wm4 should matchGeom(webMercator, 0.00001)
134-
135-
// TODO: See comment in `org.locationtech.rasterframes.expressions.register` for
136-
// TODO: what needs to happen to support this.
137-
//checkDocs("st_reproject")
134+
checkDocs("st_reproject")
138135
}
139136
}
140137

core/src/test/scala/org/locationtech/rasterframes/expressions/ProjectedLayerMetadataAggregateTest.scala renamed to core/src/test/scala/org/locationtech/rasterframes/expressions/ProjectedLayerMetadataAggregateSpec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ import org.locationtech.rasterframes.encoders.serialized_literal
3030
import org.locationtech.rasterframes.expressions.aggregates.ProjectedLayerMetadataAggregate
3131
import org.locationtech.rasterframes.model.TileDimensions
3232

33-
class ProjectedLayerMetadataAggregateTest extends TestEnvironment {
33+
class ProjectedLayerMetadataAggregateSpec extends TestEnvironment {
3434

3535
import spark.implicits._
3636

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
/*
2+
* This software is licensed under the Apache 2 license, quoted below.
3+
*
4+
* Copyright 2019 Astraea, Inc.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
7+
* use this file except in compliance with the License. You may obtain a copy of
8+
* the License at
9+
*
10+
* [http://www.apache.org/licenses/LICENSE-2.0]
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15+
* License for the specific language governing permissions and limitations under
16+
* the License.
17+
*
18+
* SPDX-License-Identifier: Apache-2.0
19+
*
20+
*/
21+
22+
package org.locationtech.rasterframes.expressions
23+
import geotrellis.proj4.{CRS, LatLng, WebMercator}
24+
import org.locationtech.rasterframes._
25+
import geotrellis.vector.Extent
26+
import org.locationtech.rasterframes.TestEnvironment
27+
import org.apache.spark.sql.functions.lit
28+
import org.locationtech.rasterframes._
29+
import encoders.serialized_literal
30+
import geotrellis.raster.CellType
31+
import org.apache.spark.sql.Encoders
32+
import org.locationtech.geomesa.curve.XZ2SFC
33+
import org.locationtech.rasterframes.ref.{InMemoryRasterSource, RasterSource}
34+
import org.locationtech.rasterframes.tiles.ProjectedRasterTile
35+
import org.scalatest.Inspectors
36+
37+
class XZ2IndexerSpec extends TestEnvironment with Inspectors {
38+
val testExtents = Seq(
39+
Extent(10, 10, 12, 12),
40+
Extent(9.0, 9.0, 13.0, 13.0),
41+
Extent(-180.0, -90.0, 180.0, 90.0),
42+
Extent(0.0, 0.0, 180.0, 90.0),
43+
Extent(0.0, 0.0, 20.0, 20.0),
44+
Extent(11.0, 11.0, 13.0, 13.0),
45+
Extent(9.0, 9.0, 11.0, 11.0),
46+
Extent(10.5, 10.5, 11.5, 11.5),
47+
Extent(11.0, 11.0, 11.0, 11.0),
48+
Extent(-180.0, -90.0, 8.0, 8.0),
49+
Extent(0.0, 0.0, 8.0, 8.0),
50+
Extent(9.0, 9.0, 9.5, 9.5),
51+
Extent(20.0, 20.0, 180.0, 90.0)
52+
)
53+
val sfc = XZ2SFC(18)
54+
val expected = testExtents.map(e => sfc.index(e.xmin, e.ymin, e.xmax, e.ymax))
55+
56+
def reproject(dst: CRS)(e: Extent): Extent = e.reproject(LatLng, dst)
57+
58+
describe("Spatial index generation") {
59+
import spark.implicits._
60+
it("should be SQL registered with docs") {
61+
checkDocs("rf_spatial_index")
62+
}
63+
it("should create index from Extent") {
64+
val crs: CRS = WebMercator
65+
val df = testExtents.map(reproject(crs)).map(Tuple1.apply).toDF("extent")
66+
val indexes = df.select(rf_spatial_index($"extent", serialized_literal(crs))).collect()
67+
68+
forEvery(indexes.zip(expected)) { case (i, e) =>
69+
i should be (e)
70+
}
71+
}
72+
it("should create index from Geometry") {
73+
val crs: CRS = LatLng
74+
val df = testExtents.map(_.jtsGeom).map(Tuple1.apply).toDF("extent")
75+
val indexes = df.select(rf_spatial_index($"extent", serialized_literal(crs))).collect()
76+
77+
forEvery(indexes.zip(expected)) { case (i, e) =>
78+
i should be (e)
79+
}
80+
}
81+
it("should create index from ProjectedRasterTile") {
82+
val crs: CRS = WebMercator
83+
val tile = TestData.randomTile(2, 2, CellType.fromName("uint8"))
84+
val prts = testExtents.map(reproject(crs)).map(ProjectedRasterTile(tile, _, crs))
85+
86+
implicit val enc = Encoders.tuple(ProjectedRasterTile.prtEncoder, Encoders.scalaInt)
87+
// The `id` here is to deal with Spark auto projecting single columns dataframes and needing to provide an encoder
88+
val df = prts.zipWithIndex.toDF("proj_raster", "id")
89+
val indexes = df.select(rf_spatial_index($"proj_raster")).collect()
90+
91+
forEvery(indexes.zip(expected)) { case (i, e) =>
92+
i should be (e)
93+
}
94+
}
95+
it("should create index from RasterSource") {
96+
val crs: CRS = WebMercator
97+
val tile = TestData.randomTile(2, 2, CellType.fromName("uint8"))
98+
val srcs = testExtents.map(reproject(crs)).map(InMemoryRasterSource(tile, _, crs): RasterSource).toDF("src")
99+
val indexes = srcs.select(rf_spatial_index($"src")).collect()
100+
101+
forEvery(indexes.zip(expected)) { case (i, e) =>
102+
i should be (e)
103+
}
104+
105+
}
106+
it("should work when CRS is LatLng") {
107+
108+
val df = testExtents.map(Tuple1.apply).toDF("extent")
109+
val crs: CRS = LatLng
110+
val indexes = df.select(rf_spatial_index($"extent", serialized_literal(crs))).collect()
111+
112+
forEvery(indexes.zip(expected)) { case (i, e) =>
113+
i should be (e)
114+
}
115+
}
116+
}
117+
}

0 commit comments

Comments
 (0)