Skip to content

Commit 337b480

Browse files
committed
PR feedback
Signed-off-by: Jason T. Brown <[email protected]>
1 parent b33fdce commit 337b480

File tree

12 files changed

+120
-56
lines changed

12 files changed

+120
-56
lines changed

core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Resample.scala

Lines changed: 21 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ package org.locationtech.rasterframes.expressions.localops
2323

2424
import geotrellis.raster.Tile
2525
import geotrellis.raster.resample._
26-
import geotrellis.raster.resample.{Max RMax, Min RMin}
26+
import geotrellis.raster.resample.{ResampleMethod GTResampleMethod, Max RMax, Min RMin}
2727
import org.apache.spark.sql.Column
2828
import org.apache.spark.sql.catalyst.InternalRow
2929
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
@@ -34,6 +34,7 @@ import org.apache.spark.sql.functions.lit
3434
import org.apache.spark.sql.rf.TileUDT
3535
import org.apache.spark.sql.types.{DataType, StringType}
3636
import org.apache.spark.unsafe.types.UTF8String
37+
import org.locationtech.rasterframes.util.ResampleMethod
3738
import org.locationtech.rasterframes.encoders.CatalystSerializer._
3839
import org.locationtech.rasterframes.expressions.{fpTile, row}
3940
import org.locationtech.rasterframes.expressions.DynamicExtractors._
@@ -47,40 +48,21 @@ abstract class ResampleBase(left: Expression, right: Expression, method: Express
4748
override def dataType: DataType = left.dataType
4849
override def children: Seq[Expression] = Seq(left, right, method)
4950

50-
def targetFloatIfNeeded(t: Tile, method: ResampleMethod): Tile =
51+
def targetFloatIfNeeded(t: Tile, method: GTResampleMethod): Tile =
5152
method match {
5253
case NearestNeighbor | Mode | RMax | RMin | Sum t
5354
case _ fpTile(t)
5455
}
5556

56-
def stringToMethod(methodName: String): ResampleMethod =
57-
methodName.toLowerCase().trim().replaceAll("_", "") match {
58-
case "nearestneighbor" | "nearest" NearestNeighbor
59-
case "bilinear" Bilinear
60-
case "cubicconvolution" CubicConvolution
61-
case "cubicspline" CubicSpline
62-
case "lanczos" | "lanzos" Lanczos
63-
// aggregates
64-
case "average" Average
65-
case "mode" Mode
66-
case "median" Median
67-
case "max" RMax
68-
case "min" RMin
69-
case "sum" Sum
70-
}
71-
7257
// These methods define the core algorithms to be used.
73-
def op(left: Tile, right: Tile, method: String): Tile = {
74-
val m = stringToMethod(method)
75-
targetFloatIfNeeded(left, m)
76-
.resample(right.cols, right.rows, m)
77-
}
58+
def op(left: Tile, right: Tile, method: GTResampleMethod): Tile =
59+
op(left, right.cols, right.rows, method)
7860

79-
def op(left: Tile, right: Double, method: String): Tile = {
80-
val m = stringToMethod(method)
81-
targetFloatIfNeeded(left, m)
82-
.resample((left.cols * right).toInt, (left.rows * right).toInt, m)
83-
}
61+
def op(left: Tile, right: Double, method: GTResampleMethod): Tile =
62+
op(left, (left.cols * right).toInt, (left.rows * right).toInt, method)
63+
64+
def op(tile: Tile, newCols: Int, newRows: Int, method: GTResampleMethod): Tile =
65+
targetFloatIfNeeded(tile, method).resample(newCols, newRows, method)
8466

8567
override def checkInputDataTypes(): TypeCheckResult = {
8668
// copypasta from BinaryLocalRasterOp
@@ -102,11 +84,16 @@ abstract class ResampleBase(left: Expression, right: Expression, method: Express
10284
val (leftTile, leftCtx) = tileExtractor(left.dataType)(row(input1))
10385
val methodString = input3.asInstanceOf[UTF8String].toString
10486

87+
val resamplingMethod = methodString match {
88+
case ResampleMethod(mm) => mm
89+
case _ => throw new IllegalArgumentException("Unrecognized resampling method specified")
90+
}
91+
10592
val result: Tile = tileOrNumberExtractor(right.dataType)(input2) match {
10693
// in this case we expect the left and right contexts to vary. no warnings raised.
107-
case TileArg(rightTile, _) op(leftTile, rightTile, methodString)
108-
case DoubleArg(d) op(leftTile, d, methodString)
109-
case IntegerArg(i) op(leftTile, i.toDouble, methodString)
94+
case TileArg(rightTile, _) op(leftTile, rightTile, resamplingMethod)
95+
case DoubleArg(d) op(leftTile, d, resamplingMethod)
96+
case IntegerArg(i) op(leftTile, i.toDouble, resamplingMethod)
11097
}
11198

11299
// reassemble the leftTile with its context. Note that this operation does not change Extent and CRS
@@ -177,7 +164,9 @@ object Resample {
177164
> SELECT _FUNC_(tile1, tile2);
178165
...""")
179166
case class ResampleNearest(tile: Expression, target: Expression)
180-
extends ResampleBase(tile, target, Literal("nearest"))
167+
extends ResampleBase(tile, target, Literal("nearest")) {
168+
override val nodeName: String = "rf_resample_nearest"
169+
}
181170
object ResampleNearest {
182171
def apply(tile: Column, target: Column): Column =
183172
new Column(ResampleNearest(tile.expr, target.expr))

core/src/main/scala/org/locationtech/rasterframes/extensions/DataFrameMethods.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ package org.locationtech.rasterframes.extensions
2323

2424
import geotrellis.proj4.CRS
2525
import geotrellis.layer._
26+
import geotrellis.raster.resample.{NearestNeighbor, ResampleMethod => GTResampleMethod}
2627
import geotrellis.util.MethodExtensions
2728
import geotrellis.vector.Extent
2829
import org.apache.spark.sql.catalyst.expressions.Attribute
@@ -37,6 +38,7 @@ import org.locationtech.rasterframes.util._
3738
import org.locationtech.rasterframes.{MetadataKeys, RasterFrameLayer}
3839
import spray.json.JsonFormat
3940
import org.locationtech.rasterframes.util.JsonCodecs._
41+
4042
import scala.util.Try
4143

4244
/**
@@ -168,7 +170,7 @@ trait DataFrameMethods[DF <: DataFrame] extends MethodExtensions[DF] with Metada
168170
* @param resampleMethod string indicating method to use for resampling.
169171
* @return joined dataframe
170172
*/
171-
def rasterJoin(right: DataFrame, resampleMethod: String = "nearest"): DataFrame = RasterJoin(self, right, resampleMethod, None)
173+
def rasterJoin(right: DataFrame, resampleMethod: GTResampleMethod = NearestNeighbor): DataFrame = RasterJoin(self, right, resampleMethod, None)
172174

173175
/**
174176
* Performs a left join on the dataframe `right` to this one, reprojecting and merging tiles as necessary.
@@ -187,7 +189,7 @@ trait DataFrameMethods[DF <: DataFrame] extends MethodExtensions[DF] with Metada
187189
* @param resampleMethod string indicating method to use for resampling.
188190
* @return joined dataframe
189191
*/
190-
def rasterJoin(right: DataFrame, leftExtent: Column, leftCRS: Column, rightExtent: Column, rightCRS: Column, resampleMethod: String): DataFrame =
192+
def rasterJoin(right: DataFrame, leftExtent: Column, leftCRS: Column, rightExtent: Column, rightCRS: Column, resampleMethod: GTResampleMethod): DataFrame =
191193
RasterJoin(self, right, leftExtent, leftCRS, rightExtent, rightCRS, resampleMethod, None)
192194

193195
/**
@@ -205,7 +207,7 @@ trait DataFrameMethods[DF <: DataFrame] extends MethodExtensions[DF] with Metada
205207
* @param resampleMethod string indicating method to use for resampling.
206208
* @return joined dataframe
207209
*/
208-
def rasterJoin(right: DataFrame, joinExpr: Column, leftExtent: Column, leftCRS: Column, rightExtent: Column, rightCRS: Column, resampleMethod: String): DataFrame =
210+
def rasterJoin(right: DataFrame, joinExpr: Column, leftExtent: Column, leftCRS: Column, rightExtent: Column, rightCRS: Column, resampleMethod: GTResampleMethod): DataFrame =
209211
RasterJoin(self, right, joinExpr, leftExtent, leftCRS, rightExtent, rightCRS, resampleMethod, None)
210212

211213

core/src/main/scala/org/locationtech/rasterframes/extensions/RasterJoin.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
package org.locationtech.rasterframes.extensions
2323
import geotrellis.raster.Dimensions
24+
import geotrellis.raster.resample.{NearestNeighbor, ResampleMethod => GTResampleMethod}
2425
import org.apache.spark.sql._
2526
import org.apache.spark.sql.functions._
2627
import org.apache.spark.sql.types.DataType
@@ -36,7 +37,7 @@ import scala.util.Random
3637
object RasterJoin {
3738

3839
/** Perform a raster join on dataframes that each have proj_raster columns, or crs and extent explicitly included. */
39-
def apply(left: DataFrame, right: DataFrame, resampleMethod: String, fallbackDimensions: Option[Dimensions[Int]]): DataFrame = {
40+
def apply(left: DataFrame, right: DataFrame, resampleMethod: GTResampleMethod, fallbackDimensions: Option[Dimensions[Int]]): DataFrame = {
4041
def usePRT(d: DataFrame) =
4142
d.projRasterColumns.headOption
4243
.map(p => (rf_crs(p), rf_extent(p)))
@@ -53,7 +54,7 @@ object RasterJoin {
5354
apply(ldf, rdf, lextent, lcrs, rextent, rcrs, resampleMethod, fallbackDimensions)
5455
}
5556

56-
def apply(left: DataFrame, right: DataFrame, leftExtent: Column, leftCRS: Column, rightExtent: Column, rightCRS: Column, resampleMethod: String, fallbackDimensions: Option[Dimensions[Int]]): DataFrame = {
57+
def apply(left: DataFrame, right: DataFrame, leftExtent: Column, leftCRS: Column, rightExtent: Column, rightCRS: Column, resampleMethod: GTResampleMethod, fallbackDimensions: Option[Dimensions[Int]]): DataFrame = {
5758
val leftGeom = st_geometry(leftExtent)
5859
val rightGeomReproj = st_reproject(st_geometry(rightExtent), rightCRS, leftCRS)
5960
val joinExpr = new Column(SpatialRelation.Intersects(leftGeom.expr, rightGeomReproj.expr))
@@ -64,7 +65,7 @@ object RasterJoin {
6465
require(extractor.isDefinedAt(col.expr.dataType), s"Expected column ${col} to be of type $description, but was ${col.expr.dataType}.")
6566
}
6667

67-
def apply(left: DataFrame, right: DataFrame, joinExprs: Column, leftExtent: Column, leftCRS: Column, rightExtent: Column, rightCRS: Column, resampleMethod: String = "nearest", fallbackDimensions: Option[Dimensions[Int]] = None): DataFrame = {
68+
def apply(left: DataFrame, right: DataFrame, joinExprs: Column, leftExtent: Column, leftCRS: Column, rightExtent: Column, rightCRS: Column, resampleMethod: GTResampleMethod = NearestNeighbor, fallbackDimensions: Option[Dimensions[Int]] = None): DataFrame = {
6869
// Convert resolved column into a symbolic one.
6970
def unresolved(c: Column): Column = col(c.columnName)
7071

@@ -84,14 +85,13 @@ object RasterJoin {
8485
val rightExtent2 = id + "extent"
8586
// Post aggregation right crs. We create a new name.
8687
val rightCRS2 = id + "crs"
87-
val method = id + "method"
8888

8989
// Gathering up various expressions we'll use to construct the result.
9090
// After joining We will be doing a groupBy the LHS. We have to define the aggregations to perform after the groupBy.
9191
// On the LHS we just want the first thing (subsequent ones should be identical).
9292
val leftAggCols = left.columns.map(s => first(left(s), true) as s)
9393
// On the RHS we collect result as a list.
94-
val rightAggCtx = Seq(collect_list(rightExtent) as rightExtent2, collect_list(rf_crs(rightCRS)) as rightCRS2, lit(resampleMethod) as method)
94+
val rightAggCtx = Seq(collect_list(rightExtent) as rightExtent2, collect_list(rf_crs(rightCRS)) as rightCRS2)
9595
val rightAggTiles = right.tileColumns.map(c => collect_list(ExtractTile(c)) as c.columnName)
9696
val rightAggOther = right.notTileColumns
9797
.filter(n => n.columnName != rightExtent.columnName && n.columnName != rightCRS.columnName)
@@ -110,7 +110,7 @@ object RasterJoin {
110110

111111
val reprojCols = rightAggTiles.map(t => {
112112
reproject_and_merge(
113-
col(leftExtent2), col(leftCRS2), col(t.columnName), col(rightExtent2), col(rightCRS2), destDims, lit(resampleMethod)
113+
col(leftExtent2), col(leftCRS2), col(t.columnName), col(rightExtent2), col(rightCRS2), destDims, lit(ResampleMethod(resampleMethod))
114114
) as t.columnName
115115
})
116116

core/src/main/scala/org/locationtech/rasterframes/extensions/ReprojectToLayer.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
package org.locationtech.rasterframes.extensions
2323

2424
import geotrellis.layer._
25+
import geotrellis.raster.resample.{NearestNeighbor, ResampleMethod => GTResampleMethod}
2526
import org.apache.spark.sql._
2627
import org.apache.spark.sql.functions.broadcast
2728
import org.locationtech.rasterframes._
@@ -30,7 +31,7 @@ import org.locationtech.rasterframes.util._
3031

3132
/** Algorithm for projecting an arbitrary RasterFrame into a layer with consistent CRS and gridding. */
3233
object ReprojectToLayer {
33-
def apply(df: DataFrame, tlm: TileLayerMetadata[SpatialKey]): RasterFrameLayer = {
34+
def apply(df: DataFrame, tlm: TileLayerMetadata[SpatialKey], resampleMethod: Option[GTResampleMethod] = None): RasterFrameLayer = {
3435
// create a destination dataframe with crs and extent columns
3536
// use RasterJoin to do the rest.
3637
val gb = tlm.tileBounds
@@ -48,7 +49,7 @@ object ReprojectToLayer {
4849
// Create effectively a target RasterFrame, but with no tiles.
4950
val dest = gridItems.toSeq.toDF(SPATIAL_KEY_COLUMN.columnName, EXTENT_COLUMN.columnName, CRS_COLUMN.columnName)
5051

51-
val joined = RasterJoin(broadcast(dest), df, "nearest", Some(tlm.tileLayout.tileDimensions))
52+
val joined = RasterJoin(broadcast(dest), df, resampleMethod.getOrElse(NearestNeighbor), Some(tlm.tileLayout.tileDimensions))
5253

5354
joined.asLayer(SPATIAL_KEY_COLUMN, tlm)
5455
}

core/src/main/scala/org/locationtech/rasterframes/package.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ package org.locationtech
2323
import com.typesafe.config.ConfigFactory
2424
import com.typesafe.scalalogging.Logger
2525
import geotrellis.raster.{Dimensions, Tile, TileFeature, isData}
26+
import geotrellis.raster.resample._
2627
import geotrellis.layer._
2728
import geotrellis.spark.ContextRDD
2829
import org.apache.spark.rdd.RDD
@@ -147,4 +148,7 @@ package object rasterframes extends StandardColumns
147148
else isCellTrue(t.get(col, row))
148149

149150

151+
152+
153+
150154
}

core/src/main/scala/org/locationtech/rasterframes/util/package.scala

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,43 @@ package object util extends DataFrameRenderers {
186186
def apply() = mapping.keys.toSeq
187187
}
188188

189+
object ResampleMethod {
190+
import geotrellis.raster.resample.{ResampleMethod GTResampleMethod, _}
191+
def unapply(name: String): Option[GTResampleMethod] = {
192+
name.toLowerCase().trim().replaceAll("_", "") match {
193+
case "nearestneighbor" | "nearest" Some(NearestNeighbor)
194+
case "bilinear" Some(Bilinear)
195+
case "cubicconvolution" Some(CubicConvolution)
196+
case "cubicspline" Some(CubicSpline)
197+
case "lanczos" | "lanzos" Some(Lanczos)
198+
// aggregates
199+
case "average" Some(Average)
200+
case "mode" Some(Mode)
201+
case "median" Some(Median)
202+
case "max" Some(Max)
203+
case "min" Some(Min)
204+
case "sum" Some(Sum)
205+
case _ => None
206+
}
207+
}
208+
def apply(gtr: GTResampleMethod): String = {
209+
gtr match {
210+
case NearestNeighbor "nearest"
211+
case Bilinear "bilinear"
212+
case CubicConvolution "cubicconvolution"
213+
case CubicSpline "cubicspline"
214+
case Lanczos "lanczos"
215+
case Average "average"
216+
case Mode "mode"
217+
case Median "median"
218+
case Max "max"
219+
case Min "min"
220+
case Sum "sum"
221+
case _ throw new IllegalArgumentException(s"Unrecogized ResampleMethod ${gtr.toString()}")
222+
}
223+
}
224+
}
225+
189226
private[rasterframes]
190227
def toParquetFriendlyColumnName(name: String) = name.replaceAll("[ ,;{}()\n\t=]", "_")
191228

core/src/test/scala/org/locationtech/rasterframes/RasterJoinSpec.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
package org.locationtech.rasterframes
2323

24-
import geotrellis.raster.resample.Bilinear
24+
import geotrellis.raster.resample._
2525
import geotrellis.raster.testkit.RasterMatchers
2626
import geotrellis.raster.{Dimensions, IntConstantNoDataCellType, Raster, Tile}
2727
import org.apache.spark.SparkConf
@@ -175,8 +175,8 @@ class RasterJoinSpec extends TestEnvironment with TestData with RasterMatchers {
175175
it("should honor resampling options") {
176176
// test case. replicate existing test condition and check that resampling option results in different output
177177
val filterExpr = st_intersects(rf_geometry($"tile"), st_point(704940.0, 4251130.0))
178-
val result = b4nativeRf.rasterJoin(b4warpedRf.withColumnRenamed("tile2", "nearest"), "nearest")
179-
.rasterJoin(b4warpedRf.withColumnRenamed("tile2", "CubicSpline"), "cubicSpline")
178+
val result = b4nativeRf.rasterJoin(b4warpedRf.withColumnRenamed("tile2", "nearest"), NearestNeighbor)
179+
.rasterJoin(b4warpedRf.withColumnRenamed("tile2", "CubicSpline"), CubicSpline)
180180
.withColumn("diff", rf_local_subtract($"nearest", $"cubicSpline"))
181181
.agg(rf_agg_stats($"diff") as "stats")
182182
.select($"stats.min" as "min", $"stats.max" as "max")

docs/src/main/paradox/reference.md

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,14 @@ In __SQL__, three parameters are required for `rf_resample`.:
168168
Tile rf_resample_nearest(Tile tile, Tile shape_tile)
169169

170170

171-
Change the tile dimension by upsampling or downsampling. Passing a numeric `factor` will scale the number of columns and rows in the tile: 1.0 is the same number of columns and row; less than one downsamples the tile; and greater than one upsamples the tile. Passing a tile as the second argument resamples such that the output has the same dimension (number of columns and rows) as `shape_tile`. Resampling methods can be one of: nearest_neighbor, bilinear, cubic_convolution, cubic_spline, lanczos, average, mode, median, max, min, or sum.
171+
Change the tile dimension by upsampling or downsampling. Passing a numeric `factor` will scale the number of columns and rows in the tile: 1.0 is the same number of columns and rows; less than one downsamples the tile; and greater than one upsamples the tile. Passing a tile as the second argument resamples such that the output has the same dimension (number of columns and rows) as `shape_tile`.
172172

173-
Note the last six options apply aggregates when downsampling. For example a 0.25 factor and `max` method returns the maximum value in a 4x4 neighborhood.
173+
There are two categories: point resampling methods and aggregating resampling methods.
174+
The resampling method can be specified by one of the following strings, possibly in a column.
175+
The point resampling methods are: `"nearest_neighbor"`, `"bilinear"`, `"cubic_convolution"`, `"cubic_spline"`, and `"lanczos"`.
176+
The aggregating resampling methods are: `"average"`, `"mode"`, `"median"`, `"max"`, `"min"`, and `"sum"`.
177+
178+
Note the aggregating methods are intended for downsampling. For example a 0.25 factor and `max` method returns the maximum value in a 4x4 neighborhood.
174179

175180
If `tile` has an integer `CellType`, the returned tile will be coerced to a floating point with the following methods: bilinear, cubic_convolution, cubic_spline, lanczos, average, and median.
176181

docs/src/main/paradox/release-notes.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
* Added `method_name` parameter to the `rf_resample` method.
88
* __BREAKING__: In SQL, the function `rf_resample` now takes 3 arguments. You can use `rf_resample_nearest` with two arguments or refactor to `rf_resample(t, v, "nearest")`.
9-
* Added resample method parameter to SQL and Python APIs. This will affect the reprojection of right hand side tiles.
9+
* Added resample method parameter to SQL and Python APIs. @ref:[See updated docs](raster-join.md).
1010

1111

1212
### 0.9.0

pyrasterframes/src/main/python/docs/raster-join.pymd

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ The following optional arguments are allowed:
5353
* `right_extent` - the column on the right-hand DataFrame giving the [extent][extent] of the tile columns
5454
* `right_crs` - the column on the right-hand DataFrame giving the [CRS][CRS] of the tile columns
5555
* `join_exprs` - a single column expression as would be used in the [`on` parameter of `join`](https://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.DataFrame.join)
56-
* `resampling_method` - resampling algorithm to use in reprojection of right-hand tile columns. A string that is one of:nearest_neighbor, bilinear, cubic_convolution, cubic_spline, lanczos, average, mode, median, max, min, or sum.
56+
* `resampling_method` - resampling algorithm to use in reprojection of right-hand tile columns.
57+
5758

5859

5960
Note that the `join_exprs` will override the join behavior described above. By default the expression is equivalent to:
@@ -65,6 +66,10 @@ st_intersects(
6566
)
6667
```
6768

69+
The resampling method can be specified by passing one of the following strings as the `resampling_method` parameter.
70+
The point resampling methods are: `"nearest_neighbor"`, `"bilinear"`, `"cubic_convolution"`, `"cubic_spline"`, and `"lanczos"`.
71+
The aggregating resampling methods are: `"average"`, `"mode"`, `"median"`, `"max"`, `"min"`, and `"sum"`.
72+
Note the aggregating methods are intended for downsampling. For example a 0.25 factor and `max` method returns the maximum value in a 4x4 neighborhood.
6873

6974

7075
[CRS]: concepts.md#coordinate-reference-system--crs

0 commit comments

Comments
 (0)