Skip to content

Commit a5ed5ed

Browse files
authored
Merge pull request #420 from s22s/fix/419
Added the ability to do a raster_join on proj_raster types.
2 parents b396400 + 9385136 commit a5ed5ed

File tree

2 files changed

+28
-6
lines changed

2 files changed

+28
-6
lines changed

core/src/main/scala/org/locationtech/rasterframes/extensions/RasterJoin.scala

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,37 @@ package org.locationtech.rasterframes.extensions
2323
import org.apache.spark.sql._
2424
import org.apache.spark.sql.functions._
2525
import org.locationtech.rasterframes._
26+
import org.locationtech.rasterframes.expressions.SpatialRelation
27+
import org.locationtech.rasterframes.expressions.accessors.ExtractTile
2628
import org.locationtech.rasterframes.functions.reproject_and_merge
2729
import org.locationtech.rasterframes.util._
2830

2931
import scala.util.Random
3032

3133
object RasterJoin {
3234

35+
/** Perform a raster join on dataframes that each have proj_raster columns, or crs and extent explicitly included. */
3336
def apply(left: DataFrame, right: DataFrame): DataFrame = {
34-
val df = apply(left, right, left("extent"), left("crs"), right("extent"), right("crs"))
35-
df.drop(right("extent")).drop(right("crs"))
37+
def usePRT(d: DataFrame) =
38+
d.projRasterColumns.headOption
39+
.map(p => (rf_crs(p), rf_extent(p)))
40+
.orElse(Some(col("crs"), col("extent")))
41+
.map { case (crs, extent) =>
42+
val d2 = d.withColumn("crs", crs).withColumn("extent", extent)
43+
(d2, d2("crs"), d2("extent"))
44+
}
45+
.get
46+
47+
val (ldf, lcrs, lextent) = usePRT(left)
48+
val (rdf, rcrs, rextent) = usePRT(right)
49+
50+
apply(ldf, rdf, lextent, lcrs, rextent, rcrs)
3651
}
3752

3853
def apply(left: DataFrame, right: DataFrame, leftExtent: Column, leftCRS: Column, rightExtent: Column, rightCRS: Column): DataFrame = {
3954
val leftGeom = st_geometry(leftExtent)
4055
val rightGeomReproj = st_reproject(st_geometry(rightExtent), rightCRS, leftCRS)
41-
val joinExpr = st_intersects(leftGeom, rightGeomReproj)
56+
val joinExpr = new Column(SpatialRelation.Intersects(leftGeom.expr, rightGeomReproj.expr))
4257
apply(left, right, joinExpr, leftExtent, leftCRS, rightExtent, rightCRS)
4358
}
4459

@@ -65,7 +80,7 @@ object RasterJoin {
6580
val leftAggCols = left.columns.map(s => first(left(s), true) as s)
6681
// On the RHS we collect result as a list.
6782
val rightAggCtx = Seq(collect_list(rightExtent) as rightExtent2, collect_list(rightCRS) as rightCRS2)
68-
val rightAggTiles = right.tileColumns.map(c => collect_list(c) as c.columnName)
83+
val rightAggTiles = right.tileColumns.map(c => collect_list(ExtractTile(c)) as c.columnName)
6984
val rightAggOther = right.notTileColumns
7085
.filter(n => n.columnName != rightExtent.columnName && n.columnName != rightCRS.columnName)
7186
.map(c => collect_list(c) as (c.columnName + "_agg"))

core/src/test/scala/org/locationtech/rasterframes/RasterJoinSpec.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,6 @@ class RasterJoinSpec extends TestEnvironment with TestData with RasterMatchers {
154154

155155
total18 should be > 0.0
156156
total18 should be < total17
157-
158-
159157
}
160158

161159
it("should pass through ancillary columns") {
@@ -164,5 +162,14 @@ class RasterJoinSpec extends TestEnvironment with TestData with RasterMatchers {
164162
val joined = left.rasterJoin(right)
165163
joined.columns should contain allElementsOf Seq("left_id", "right_id_agg")
166164
}
165+
166+
it("should handle proj_raster types") {
167+
val df1 = Seq(one).toDF("one")
168+
val df2 = Seq(two).toDF("two")
169+
noException shouldBe thrownBy {
170+
val joined1 = df1.rasterJoin(df2)
171+
val joined2 = df2.rasterJoin(df1)
172+
}
173+
}
167174
}
168175
}

0 commit comments

Comments
 (0)