@@ -23,22 +23,37 @@ package org.locationtech.rasterframes.extensions
2323import org .apache .spark .sql ._
2424import org .apache .spark .sql .functions ._
2525import org .locationtech .rasterframes ._
26+ import org .locationtech .rasterframes .expressions .SpatialRelation
27+ import org .locationtech .rasterframes .expressions .accessors .ExtractTile
2628import org .locationtech .rasterframes .functions .reproject_and_merge
2729import org .locationtech .rasterframes .util ._
2830
2931import scala .util .Random
3032
3133object RasterJoin {
3234
35+ /** Perform a raster join on dataframes that each have proj_raster columns, or crs and extent explicitly included. */
3336 def apply (left : DataFrame , right : DataFrame ): DataFrame = {
34- val df = apply(left, right, left(" extent" ), left(" crs" ), right(" extent" ), right(" crs" ))
35- df.drop(right(" extent" )).drop(right(" crs" ))
37+ def usePRT (d : DataFrame ) =
38+ d.projRasterColumns.headOption
39+ .map(p => (rf_crs(p), rf_extent(p)))
40+ .orElse(Some (col(" crs" ), col(" extent" )))
41+ .map { case (crs, extent) =>
42+ val d2 = d.withColumn(" crs" , crs).withColumn(" extent" , extent)
43+ (d2, d2(" crs" ), d2(" extent" ))
44+ }
45+ .get
46+
47+ val (ldf, lcrs, lextent) = usePRT(left)
48+ val (rdf, rcrs, rextent) = usePRT(right)
49+
50+ apply(ldf, rdf, lextent, lcrs, rextent, rcrs)
3651 }
3752
3853 def apply (left : DataFrame , right : DataFrame , leftExtent : Column , leftCRS : Column , rightExtent : Column , rightCRS : Column ): DataFrame = {
3954 val leftGeom = st_geometry(leftExtent)
4055 val rightGeomReproj = st_reproject(st_geometry(rightExtent), rightCRS, leftCRS)
41- val joinExpr = st_intersects( leftGeom, rightGeomReproj)
56+ val joinExpr = new Column ( SpatialRelation . Intersects ( leftGeom.expr , rightGeomReproj.expr) )
4257 apply(left, right, joinExpr, leftExtent, leftCRS, rightExtent, rightCRS)
4358 }
4459
@@ -65,7 +80,7 @@ object RasterJoin {
6580 val leftAggCols = left.columns.map(s => first(left(s), true ) as s )
6681 // On the RHS we collect result as a list.
6782 val rightAggCtx = Seq (collect_list(rightExtent) as rightExtent2 , collect_list(rightCRS) as rightCRS2 )
68- val rightAggTiles = right.tileColumns.map(c => collect_list(c ) as c .columnName)
83+ val rightAggTiles = right.tileColumns.map(c => collect_list(ExtractTile (c) ) as c .columnName)
6984 val rightAggOther = right.notTileColumns
7085 .filter(n => n.columnName != rightExtent.columnName && n.columnName != rightCRS.columnName)
7186 .map(c => collect_list(c) as (c.columnName + " _agg" ))
0 commit comments