Skip to content

Commit e3af5c8

Browse files
committed
Fixed regression in handling of nulls in RasterJoin utility UDF.
1 parent 606a977 commit e3af5c8

File tree

2 files changed

+93
-26
lines changed

2 files changed

+93
-26
lines changed

core/src/main/scala/org/locationtech/rasterframes/functions/package.scala

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -99,38 +99,38 @@ package object functions {
9999
private[rasterframes] val tileOnes: (Int, Int, String) => Tile = (cols, rows, cellTypeName) =>
100100
makeConstantTile(1, cols, rows, cellTypeName)
101101

102-
val reproject_and_merge_f: (Row, CRS, Seq[Tile], Seq[Row], Seq[CRS], Row, String) => Tile = (leftExtentEnc: Row, leftCRSEnc: CRS, tiles: Seq[Tile], rightExtentEnc: Seq[Row], rightCRSEnc: Seq[CRS], leftDimsEnc: Row, resampleMethod: String) => {
103-
if (tiles.isEmpty) null
102+
val reproject_and_merge_f: (Row, CRS, Seq[Tile], Seq[Row], Seq[CRS], Row, String) => Option[Tile] = (leftExtentEnc: Row, leftCRS: CRS, tiles: Seq[Tile], rightExtentEnc: Seq[Row], rightCRSs: Seq[CRS], leftDimsEnc: Row, resampleMethod: String) => {
103+
if (tiles.isEmpty) None
104104
else {
105-
require(tiles.length == rightExtentEnc.length && tiles.length == rightCRSEnc.length, "size mismatch")
105+
require(tiles.length == rightExtentEnc.length && tiles.length == rightCRSs.length, "size mismatch")
106106

107-
val leftExtent: Extent = leftExtentEnc.as[Extent]
108-
val leftDims: Dimensions[Int] = leftDimsEnc.as[Dimensions[Int]]
109-
val leftCRS: CRS = leftCRSEnc
110-
lazy val rightExtents: Seq[Extent] = rightExtentEnc.map(_.as[Extent])
111-
lazy val rightCRSs: Seq[CRS] = rightCRSEnc
107+
val leftExtent = Option(leftExtentEnc).map(_.as[Extent])
108+
val leftDims = Option(leftDimsEnc).map(_.as[Dimensions[Int]])
109+
lazy val rightExtents = rightExtentEnc.map(_.as[Extent])
112110
lazy val resample = resampleMethod match {
113111
case ResampleMethod(mm) => mm
114112
case _ => throw new IllegalArgumentException(s"Unable to parse ResampleMethod for ${resampleMethod}.")
115113
}
116-
117-
if (leftExtent == null || leftDims == null || leftCRS == null) null
118-
else {
119-
120-
val cellType = tiles.map(_.cellType).reduceOption(_ union _).getOrElse(tiles.head.cellType)
121-
122-
// TODO: how to allow control over... expression?
123-
val projOpts = Reproject.Options(resample)
124-
val dest: Tile = ArrayTile.empty(cellType, leftDims.cols, leftDims.rows)
125-
//is there a GT function to do all this?
126-
tiles.zip(rightExtents).zip(rightCRSs).map {
127-
case ((tile, extent), crs) =>
128-
tile.reproject(extent, crs, leftCRS, projOpts)
129-
}.foldLeft(dest)((d, t) =>
130-
d.merge(leftExtent, t.extent, t.tile, projOpts.method)
131-
)
132-
}
133-
}
114+
(leftExtent, leftDims, Option(leftCRS))
115+
.zipped
116+
.map((leftExtent, leftDims, leftCRS) => {
117+
val cellType = tiles
118+
.map(_.cellType)
119+
.reduceOption(_ union _)
120+
.getOrElse(tiles.head.cellType)
121+
122+
// TODO: how to allow control over... expression?
123+
val projOpts = Reproject.Options(resample)
124+
val dest: Tile = ArrayTile.empty(cellType, leftDims.cols, leftDims.rows)
125+
//is there a GT function to do all this?
126+
tiles.zip(rightExtents).zip(rightCRSs).map {
127+
case ((tile, extent), crs) =>
128+
tile.reproject(extent, crs, leftCRS, projOpts)
129+
}.foldLeft(dest)((d, t) =>
130+
d.merge(leftExtent, t.extent, t.tile, projOpts.method)
131+
)
132+
})
133+
}.headOption
134134
}
135135

136136
// NB: Don't be tempted to make this a `val`. Spark will barf if `withRasterFrames` hasn't been called first.

core/src/test/scala/org/locationtech/rasterframes/RasterJoinSpec.scala

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import geotrellis.proj4.CRS
2525
import geotrellis.raster.resample._
2626
import geotrellis.raster.testkit.RasterMatchers
2727
import geotrellis.raster.{Dimensions, IntConstantNoDataCellType, Raster, Tile}
28+
import geotrellis.vector.Extent
2829
import org.apache.spark.SparkConf
2930
import org.apache.spark.sql.functions._
3031
import org.locationtech.rasterframes.expressions.aggregates.TileRasterizerAggregate
@@ -195,6 +196,72 @@ class RasterJoinSpec extends TestEnvironment with TestData with RasterMatchers {
195196
// This just tests that the tiles are not identical
196197
result.getAs[Double]("min") should be > (0.0)
197198
}
199+
200+
// Failed to execute user defined function(package$$$Lambda$4417/0x00000008019e2840: (struct<xmax:double,xmin:double,ymax:double,ymin:double>, string, array<struct<cellType:string,cols:int,rows:int,cells:binary,ref:struct<source:struct<raster_source_kryo:binary>,bandIndex:int,subextent:struct<xmin:double,ymin:double,xmax:double,ymax:double>,subgrid:struct<colMin:int,rowMin:int,colMax:int,rowMax:int>>>>, array<struct<xmax:double,xmin:double,ymax:double,ymin:double>>, array<string>, struct<cols:int,rows:int>, string) => struct<cellType:string,cols:int,rows:int,cells:binary,ref:struct<source:struct<raster_source_kryo:binary>,bandIndex:int,subextent:struct<xmin:double,ymin:double,xmax:double,ymax:double>,subgrid:struct<colMin:int,rowMin:int,colMax:int,rowMax:int>>>)
201+
202+
it("should raster join with null left head") {
203+
// https://github.com/locationtech/rasterframes/issues/462
204+
val prt = TestData.projectedRasterTile(
205+
10, 10, 1,
206+
Extent(0.0, 0.0, 40.0, 40.0),
207+
CRS.fromEpsgCode(32611),
208+
)
209+
210+
val left = Seq(
211+
(1, "a", prt.tile, prt.tile, prt.extent, prt.crs),
212+
(1, "b", null, prt.tile, prt.extent, prt.crs)
213+
).toDF("i", "j", "t", "u", "e", "c")
214+
215+
val right = Seq(
216+
(1, prt.tile, prt.extent, prt.crs)
217+
).toDF("i", "r", "e", "c")
218+
219+
val joined = left.rasterJoin(right,
220+
left("i") === right("i"),
221+
left("e"), left("c"),
222+
right("e"), right("c"),
223+
NearestNeighbor
224+
)
225+
joined.count() should be (2)
226+
227+
// In the case where the head column is null, it will be passed through
228+
val t1 = joined
229+
.select(isnull($"t"))
230+
.filter($"j" === "b")
231+
.first()
232+
233+
t1.getBoolean(0) should be(true)
234+
235+
// However, the right-hand side tile should still get its dimensions from col `u`
236+
val collected = joined.select(rf_dimensions($"r")).collect()
237+
collected.headOption should be (Some(Dimensions(10, 10)))
238+
239+
// If there is no non-null tile on the LHS then the RHS is ill defined
240+
val joinedNoLeftTile = left
241+
.drop($"u")
242+
.rasterJoin(right,
243+
left("i") === right("i"),
244+
left("e"), left("c"),
245+
right("e"), right("c"),
246+
NearestNeighbor
247+
)
248+
joinedNoLeftTile.count() should be (2)
249+
250+
// The null left-hand tile in row "b" should still be passed through as null
251+
val t2 = joinedNoLeftTile
252+
.select(isnull($"t"))
253+
.filter($"j" === "b")
254+
.first()
255+
t2.getBoolean(0) should be(true)
256+
257+
// Because there is no non-null tile column on the left side, the right side is null too
258+
val t3 = joinedNoLeftTile
259+
.select(isnull($"r"))
260+
.filter($"j" === "b")
261+
.first()
262+
t3.getBoolean(0) should be(true)
263+
}
264+
198265
}
199266

200267
override def additionalConf: SparkConf = super.additionalConf.set("spark.sql.codegen.comments", "true")

0 commit comments

Comments
 (0)