Skip to content

Commit 2cef64a

Browse files
authored
Merge pull request #495 from s22s/feature/resampling_method
Choose ResampleMethod in rf_resample and RasterJoin
2 parents 143e98f + 5f214c7 commit 2cef64a

File tree

19 files changed

+421
-76
lines changed

19 files changed

+421
-76
lines changed

core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Resample.scala

Lines changed: 136 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -22,53 +22,157 @@
2222
package org.locationtech.rasterframes.expressions.localops
2323

2424
import geotrellis.raster.Tile
25-
import geotrellis.raster.resample.NearestNeighbor
25+
import geotrellis.raster.resample._
26+
import geotrellis.raster.resample.{ResampleMethod ⇒ GTResampleMethod, Max ⇒ RMax, Min ⇒ RMin}
2627
import org.apache.spark.sql.Column
2728
import org.apache.spark.sql.catalyst.InternalRow
29+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
30+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
2831
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
29-
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription}
32+
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, Literal, TernaryExpression}
3033
import org.apache.spark.sql.functions.lit
31-
import org.locationtech.rasterframes.expressions.BinaryLocalRasterOp
32-
import org.locationtech.rasterframes.expressions.DynamicExtractors.tileExtractor
34+
import org.apache.spark.sql.rf.TileUDT
35+
import org.apache.spark.sql.types.{DataType, StringType}
36+
import org.apache.spark.unsafe.types.UTF8String
37+
import org.locationtech.rasterframes.util.ResampleMethod
38+
import org.locationtech.rasterframes.encoders.CatalystSerializer._
39+
import org.locationtech.rasterframes.expressions.{fpTile, row}
40+
import org.locationtech.rasterframes.expressions.DynamicExtractors._
41+
42+
43+
abstract class ResampleBase(left: Expression, right: Expression, method: Expression)
44+
extends TernaryExpression
45+
with CodegenFallback with Serializable {
3346

34-
@ExpressionDescription(
35-
usage = "_FUNC_(tile, factor) - Resample tile to different size based on scalar factor or tile whose dimension to match. Scalar less than one will downsample tile; greater than one will upsample. Uses nearest-neighbor value.",
36-
arguments = """
37-
Arguments:
38-
* tile - tile
39-
* rhs - scalar or tile to match dimension""",
40-
examples = """
41-
Examples:
42-
> SELECT _FUNC_(tile, 2.0);
43-
...
44-
> SELECT _FUNC_(tile1, tile2);
45-
..."""
46-
)
47-
case class Resample(left: Expression, right: Expression) extends BinaryLocalRasterOp
48-
with CodegenFallback {
4947
override val nodeName: String = "rf_resample"
50-
override protected def op(left: Tile, right: Tile): Tile = left.resample(right.cols, right.rows, NearestNeighbor)
51-
override protected def op(left: Tile, right: Double): Tile = left.resample((left.cols * right).toInt,
52-
(left.rows * right).toInt, NearestNeighbor)
53-
override protected def op(left: Tile, right: Int): Tile = op(left, right.toDouble)
48+
override def dataType: DataType = left.dataType
49+
override def children: Seq[Expression] = Seq(left, right, method)
50+
51+
def targetFloatIfNeeded(t: Tile, method: GTResampleMethod): Tile =
52+
method match {
53+
case NearestNeighbor | Mode | RMax | RMin | Sum ⇒ t
54+
case _ ⇒ fpTile(t)
55+
}
56+
57+
// These methods define the core algorithms to be used.
58+
def op(left: Tile, right: Tile, method: GTResampleMethod): Tile =
59+
op(left, right.cols, right.rows, method)
60+
61+
def op(left: Tile, right: Double, method: GTResampleMethod): Tile =
62+
op(left, (left.cols * right).toInt, (left.rows * right).toInt, method)
63+
64+
def op(tile: Tile, newCols: Int, newRows: Int, method: GTResampleMethod): Tile =
65+
targetFloatIfNeeded(tile, method).resample(newCols, newRows, method)
66+
67+
override def checkInputDataTypes(): TypeCheckResult = {
68+
// copypasta from BinaryLocalRasterOp
69+
if (!tileExtractor.isDefinedAt(left.dataType)) {
70+
TypeCheckFailure(s"Input type '${left.dataType}' does not conform to a raster type.")
71+
}
72+
else if (!tileOrNumberExtractor.isDefinedAt(right.dataType)) {
73+
TypeCheckFailure(s"Input type '${right.dataType}' does not conform to a compatible type.")
74+
} else method.dataType match {
75+
case StringType ⇒ TypeCheckSuccess
76+
case _ ⇒ TypeCheckFailure(s"Cannot interpret value of type `${method.dataType.simpleString}` for resampling method; please provide a String method name.")
77+
}
78+
}
79+
80+
override def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = {
81+
// more copypasta from BinaryLocalRasterOp
82+
implicit val tileSer = TileUDT.tileSerializer
83+
84+
val (leftTile, leftCtx) = tileExtractor(left.dataType)(row(input1))
85+
val methodString = input3.asInstanceOf[UTF8String].toString
86+
87+
val resamplingMethod = methodString match {
88+
case ResampleMethod(mm) => mm
89+
case _ => throw new IllegalArgumentException("Unrecognized resampling method specified")
90+
}
91+
92+
val result: Tile = tileOrNumberExtractor(right.dataType)(input2) match {
93+
// in this case we expect the left and right contexts to vary. no warnings raised.
94+
case TileArg(rightTile, _) ⇒ op(leftTile, rightTile, resamplingMethod)
95+
case DoubleArg(d) ⇒ op(leftTile, d, resamplingMethod)
96+
case IntegerArg(i) ⇒ op(leftTile, i.toDouble, resamplingMethod)
97+
}
98+
99+
// reassemble the leftTile with its context. Note that this operation does not change Extent and CRS
100+
leftCtx match {
101+
case Some(ctx) ⇒ ctx.toProjectRasterTile(result).toInternalRow
102+
case None ⇒ result.toInternalRow
103+
}
104+
}
54105

55106
override def eval(input: InternalRow): Any = {
56107
if(input == null) null
57108
else {
58109
val l = left.eval(input)
59110
val r = right.eval(input)
60-
if (l == null && r == null) null
61-
else if (l == null) r
62-
else if (r == null && tileExtractor.isDefinedAt(right.dataType)) l
63-
else if (r == null) null
64-
else nullSafeEval(l, r)
111+
val m = method.eval(input)
112+
if (m == null) null // no method, return null
113+
else if (l == null) null // no l tile, return null
114+
else if (r == null) l // no target tile or factor, return l without changing it
115+
else nullSafeEval(l, r, m)
65116
}
66117
}
118+
67119
}
68-
object Resample{
69-
def apply(left: Column, right: Column): Column =
70-
new Column(Resample(left.expr, right.expr))
120+
121+
@ExpressionDescription(
122+
usage = "_FUNC_(tile, factor, method_name) - Resample tile to different dimension based on scalar `factor` or a tile whose dimension to match. Scalar less than one will downsample tile; greater than one will upsample. Uses resampling method named in the `method_name`." +
123+
"Methods average, mode, median, max, min, and sum aggregate over cells when downsampling",
124+
arguments = """
125+
Arguments:
126+
* tile - tile
127+
* factor - scalar or tile to match dimension
128+
* method_name - one the following options: nearest_neighbor, bilinear, cubic_convolution, cubic_spline, lanczos, average, mode, median, max, min, sum
129+
This option can be CamelCase as well
130+
""",
131+
examples = """
132+
Examples:
133+
> SELECT _FUNC_(tile, 0.2, median);
134+
...
135+
> SELECT _FUNC_(tile1, tile2, lit("cubic_spline"));
136+
..."""
137+
)
138+
case class Resample(left: Expression, factor: Expression, method: Expression)
139+
extends ResampleBase(left, factor, method)
140+
141+
object Resample {
142+
def apply(left: Column, right: Column, methodName: String): Column =
143+
new Column(Resample(left.expr, right.expr, lit(methodName).expr))
144+
145+
def apply(left: Column, right: Column, method: Column): Column =
146+
new Column(Resample(left.expr, right.expr, method.expr))
147+
148+
def apply[N: Numeric](left: Column, right: N, method: String): Column = new Column(Resample(left.expr, lit(right).expr, lit(method).expr))
149+
150+
def apply[N: Numeric](left: Column, right: N, method: Column): Column = new Column(Resample(left.expr, lit(right).expr, method.expr))
151+
152+
}
153+
154+
@ExpressionDescription(
155+
usage = "_FUNC_(tile, factor) - Resample tile to different size based on scalar factor or tile whose dimension to match. Scalar less than one will downsample tile; greater than one will upsample. Uses nearest-neighbor value.",
156+
arguments = """
157+
Arguments:
158+
* tile - tile
159+
* rhs - scalar or tile to match dimension""",
160+
examples = """
161+
Examples:
162+
> SELECT _FUNC_(tile, 2.0);
163+
...
164+
> SELECT _FUNC_(tile1, tile2);
165+
...""")
166+
case class ResampleNearest(tile: Expression, target: Expression)
167+
extends ResampleBase(tile, target, Literal("nearest")) {
168+
override val nodeName: String = "rf_resample_nearest"
169+
}
170+
object ResampleNearest {
171+
def apply(tile: Column, target: Column): Column =
172+
new Column(ResampleNearest(tile.expr, target.expr))
71173

72174
def apply[N: Numeric](tile: Column, value: N): Column =
73-
new Column(Resample(tile.expr, lit(value).expr))
175+
new Column(ResampleNearest(tile.expr, lit(value).expr))
74176
}
177+
178+

core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ package object expressions {
109109
registry.registerExpression[ExpM1]("rf_expm1")
110110
registry.registerExpression[Sqrt]("rf_sqrt")
111111
registry.registerExpression[Resample]("rf_resample")
112+
registry.registerExpression[ResampleNearest]("rf_resample_nearest")
112113
registry.registerExpression[TileToArrayDouble]("rf_tile_to_array_double")
113114
registry.registerExpression[TileToArrayInt]("rf_tile_to_array_int")
114115
registry.registerExpression[DataCells]("rf_data_cells")

core/src/main/scala/org/locationtech/rasterframes/extensions/DataFrameMethods.scala

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ package org.locationtech.rasterframes.extensions
2323

2424
import geotrellis.proj4.CRS
2525
import geotrellis.layer._
26+
import geotrellis.raster.resample.{NearestNeighbor, ResampleMethod => GTResampleMethod}
2627
import geotrellis.util.MethodExtensions
2728
import geotrellis.vector.Extent
2829
import org.apache.spark.sql.catalyst.expressions.Attribute
@@ -37,6 +38,7 @@ import org.locationtech.rasterframes.util._
3738
import org.locationtech.rasterframes.{MetadataKeys, RasterFrameLayer}
3839
import spray.json.JsonFormat
3940
import org.locationtech.rasterframes.util.JsonCodecs._
41+
4042
import scala.util.Try
4143

4244
/**
@@ -165,9 +167,10 @@ trait DataFrameMethods[DF <: DataFrame] extends MethodExtensions[DF] with Metada
165167
* }}}
166168
*
167169
* @param right Right side of the join.
170+
* @param resampleMethod GeoTrellis resampling method to use (a `ResampleMethod`, not a string); defaults to `NearestNeighbor`.
168171
* @return joined dataframe
169172
*/
170-
def rasterJoin(right: DataFrame): DataFrame = RasterJoin(self, right, None)
173+
def rasterJoin(right: DataFrame, resampleMethod: GTResampleMethod = NearestNeighbor): DataFrame = RasterJoin(self, right, resampleMethod, None)
171174

172175
/**
173176
* Performs a left join on the dataframe `right` to this one, reprojecting and merging tiles as necessary.
@@ -183,10 +186,11 @@ trait DataFrameMethods[DF <: DataFrame] extends MethodExtensions[DF] with Metada
183186
* @param leftCRS this (left) dataframe's CRS column
184187
* @param rightExtent right dataframe's CRS extent
185188
* @param rightCRS right dataframe's CRS column
189+
* @param resampleMethod GeoTrellis resampling method to use (a `ResampleMethod`, not a string).
186190
* @return joined dataframe
187191
*/
188-
def rasterJoin(right: DataFrame, leftExtent: Column, leftCRS: Column, rightExtent: Column, rightCRS: Column): DataFrame =
189-
RasterJoin(self, right, leftExtent, leftCRS, rightExtent, rightCRS, None)
192+
def rasterJoin(right: DataFrame, leftExtent: Column, leftCRS: Column, rightExtent: Column, rightCRS: Column, resampleMethod: GTResampleMethod): DataFrame =
193+
RasterJoin(self, right, leftExtent, leftCRS, rightExtent, rightCRS, resampleMethod, None)
190194

191195
/**
192196
* Performs a left join on the dataframe `right` to this one, reprojecting and merging tiles as necessary.
@@ -200,10 +204,11 @@ trait DataFrameMethods[DF <: DataFrame] extends MethodExtensions[DF] with Metada
200204
* @param leftCRS this (left) dataframe's CRS column
201205
* @param rightExtent right dataframe's CRS extent
202206
* @param rightCRS right dataframe's CRS column
207+
* @param resampleMethod GeoTrellis resampling method to use (a `ResampleMethod`, not a string).
203208
* @return joined dataframe
204209
*/
205-
def rasterJoin(right: DataFrame, joinExpr: Column, leftExtent: Column, leftCRS: Column, rightExtent: Column, rightCRS: Column): DataFrame =
206-
RasterJoin(self, right, joinExpr, leftExtent, leftCRS, rightExtent, rightCRS, None)
210+
def rasterJoin(right: DataFrame, joinExpr: Column, leftExtent: Column, leftCRS: Column, rightExtent: Column, rightCRS: Column, resampleMethod: GTResampleMethod): DataFrame =
211+
RasterJoin(self, right, joinExpr, leftExtent, leftCRS, rightExtent, rightCRS, resampleMethod, None)
207212

208213

209214
/** Layout contents of RasterFrame to a layer. Assumes CRS and extent columns exist. */

core/src/main/scala/org/locationtech/rasterframes/extensions/RasterJoin.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
package org.locationtech.rasterframes.extensions
2323
import geotrellis.raster.Dimensions
24+
import geotrellis.raster.resample.{NearestNeighbor, ResampleMethod => GTResampleMethod}
2425
import org.apache.spark.sql._
2526
import org.apache.spark.sql.functions._
2627
import org.apache.spark.sql.types.DataType
@@ -36,7 +37,7 @@ import scala.util.Random
3637
object RasterJoin {
3738

3839
/** Perform a raster join on dataframes that each have proj_raster columns, or crs and extent explicitly included. */
39-
def apply(left: DataFrame, right: DataFrame, fallbackDimensions: Option[Dimensions[Int]]): DataFrame = {
40+
def apply(left: DataFrame, right: DataFrame, resampleMethod: GTResampleMethod, fallbackDimensions: Option[Dimensions[Int]]): DataFrame = {
4041
def usePRT(d: DataFrame) =
4142
d.projRasterColumns.headOption
4243
.map(p => (rf_crs(p), rf_extent(p)))
@@ -50,21 +51,21 @@ object RasterJoin {
5051
val (ldf, lcrs, lextent) = usePRT(left)
5152
val (rdf, rcrs, rextent) = usePRT(right)
5253

53-
apply(ldf, rdf, lextent, lcrs, rextent, rcrs, fallbackDimensions)
54+
apply(ldf, rdf, lextent, lcrs, rextent, rcrs, resampleMethod, fallbackDimensions)
5455
}
5556

56-
def apply(left: DataFrame, right: DataFrame, leftExtent: Column, leftCRS: Column, rightExtent: Column, rightCRS: Column, fallbackDimensions: Option[Dimensions[Int]]): DataFrame = {
57+
def apply(left: DataFrame, right: DataFrame, leftExtent: Column, leftCRS: Column, rightExtent: Column, rightCRS: Column, resampleMethod: GTResampleMethod, fallbackDimensions: Option[Dimensions[Int]]): DataFrame = {
5758
val leftGeom = st_geometry(leftExtent)
5859
val rightGeomReproj = st_reproject(st_geometry(rightExtent), rightCRS, leftCRS)
5960
val joinExpr = new Column(SpatialRelation.Intersects(leftGeom.expr, rightGeomReproj.expr))
60-
apply(left, right, joinExpr, leftExtent, leftCRS, rightExtent, rightCRS, fallbackDimensions)
61+
apply(left, right, joinExpr, leftExtent, leftCRS, rightExtent, rightCRS, resampleMethod, fallbackDimensions)
6162
}
6263

6364
private def checkType[T](col: Column, description: String, extractor: PartialFunction[DataType, Any => T]): Unit = {
6465
require(extractor.isDefinedAt(col.expr.dataType), s"Expected column ${col} to be of type $description, but was ${col.expr.dataType}.")
6566
}
6667

67-
def apply(left: DataFrame, right: DataFrame, joinExprs: Column, leftExtent: Column, leftCRS: Column, rightExtent: Column, rightCRS: Column, fallbackDimensions: Option[Dimensions[Int]]): DataFrame = {
68+
def apply(left: DataFrame, right: DataFrame, joinExprs: Column, leftExtent: Column, leftCRS: Column, rightExtent: Column, rightCRS: Column, resampleMethod: GTResampleMethod = NearestNeighbor, fallbackDimensions: Option[Dimensions[Int]] = None): DataFrame = {
6869
// Convert resolved column into a symbolic one.
6970
def unresolved(c: Column): Column = col(c.columnName)
7071

@@ -85,7 +86,6 @@ object RasterJoin {
8586
// Post aggregation right crs. We create a new name.
8687
val rightCRS2 = id + "crs"
8788

88-
8989
// Gathering up various expressions we'll use to construct the result.
9090
// After joining We will be doing a groupBy the LHS. We have to define the aggregations to perform after the groupBy.
9191
// On the LHS we just want the first thing (subsequent ones should be identical).
@@ -110,7 +110,7 @@ object RasterJoin {
110110

111111
val reprojCols = rightAggTiles.map(t => {
112112
reproject_and_merge(
113-
col(leftExtent2), col(leftCRS2), col(t.columnName), col(rightExtent2), col(rightCRS2), destDims
113+
col(leftExtent2), col(leftCRS2), col(t.columnName), col(rightExtent2), col(rightCRS2), destDims, lit(ResampleMethod(resampleMethod))
114114
) as t.columnName
115115
})
116116

core/src/main/scala/org/locationtech/rasterframes/extensions/ReprojectToLayer.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
package org.locationtech.rasterframes.extensions
2323

2424
import geotrellis.layer._
25+
import geotrellis.raster.resample.{NearestNeighbor, ResampleMethod => GTResampleMethod}
2526
import org.apache.spark.sql._
2627
import org.apache.spark.sql.functions.broadcast
2728
import org.locationtech.rasterframes._
@@ -30,7 +31,7 @@ import org.locationtech.rasterframes.util._
3031

3132
/** Algorithm for projecting an arbitrary RasterFrame into a layer with consistent CRS and gridding. */
3233
object ReprojectToLayer {
33-
def apply(df: DataFrame, tlm: TileLayerMetadata[SpatialKey]): RasterFrameLayer = {
34+
def apply(df: DataFrame, tlm: TileLayerMetadata[SpatialKey], resampleMethod: Option[GTResampleMethod] = None): RasterFrameLayer = {
3435
// create a destination dataframe with crs and extent columns
3536
// use RasterJoin to do the rest.
3637
val gb = tlm.tileBounds
@@ -48,7 +49,7 @@ object ReprojectToLayer {
4849
// Create effectively a target RasterFrame, but with no tiles.
4950
val dest = gridItems.toSeq.toDF(SPATIAL_KEY_COLUMN.columnName, EXTENT_COLUMN.columnName, CRS_COLUMN.columnName)
5051

51-
val joined = RasterJoin(broadcast(dest), df, Some(tlm.tileLayout.tileDimensions))
52+
val joined = RasterJoin(broadcast(dest), df, resampleMethod.getOrElse(NearestNeighbor), Some(tlm.tileLayout.tileDimensions))
5253

5354
joined.asLayer(SPATIAL_KEY_COLUMN, tlm)
5455
}

core/src/main/scala/org/locationtech/rasterframes/functions/TileFunctions.scala

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ import org.locationtech.rasterframes.expressions.transformers.RenderPNG.{RenderC
3434
import org.locationtech.rasterframes.expressions.transformers._
3535
import org.locationtech.rasterframes.stats._
3636
import org.locationtech.rasterframes.tiles.ProjectedRasterTile
37-
import org.locationtech.rasterframes.util.{withTypedAlias, ColorRampNames, _}
38-
import org.locationtech.rasterframes.{encoders, singlebandTileEncoder, functions => F}
37+
import org.locationtech.rasterframes.util.{ColorRampNames, withTypedAlias, _}
38+
import org.locationtech.rasterframes.{encoders, singlebandTileEncoder, functions ⇒ F}
3939

4040
/** Functions associated with creating and transforming tiles, including tile-wise statistics and rendering. */
4141
trait TileFunctions {
@@ -104,11 +104,21 @@ trait TileFunctions {
104104

105105
/** Resample tile to different size based on scalar factor or tile whose dimension to match. Scalar less
106106
* than one will downsample tile; greater than one will upsample. Uses nearest-neighbor. */
107-
def rf_resample[T: Numeric](tileCol: Column, factorValue: T) = Resample(tileCol, factorValue)
107+
def rf_resample[T: Numeric](tileCol: Column, factorValue: T) = ResampleNearest(tileCol, factorValue)
108108

109109
/** Resample tile to different size based on scalar factor or tile whose dimension to match. Scalar less
110-
* than one will downsample tile; greater than one will upsample. Uses nearest-neighbor. */
111-
def rf_resample(tileCol: Column, factorCol: Column) = Resample(tileCol, factorCol)
110+
* than one will downsample tile; greater than one will upsample. Uses nearest-neighbor. */
111+
def rf_resample(tileCol: Column, factorCol: Column) = ResampleNearest(tileCol, factorCol)
112+
113+
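/** Resample tile to different size based on scalar factor or tile whose dimension to match, using the resampling method named by `methodName`. */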
114+
def rf_resample[T: Numeric](tileCol: Column, factorVal: T, methodName: Column) = Resample(tileCol, factorVal, methodName)
115+
116+
def rf_resample[T: Numeric](tileCol: Column, factorVal: T, methodName: String) = Resample(tileCol, factorVal, methodName)
117+
118+
def rf_resample(tileCol: Column, factorCol: Column, methodName: Column) = Resample(tileCol, factorCol, methodName)
119+
120+
def rf_resample(tileCol: Column, factorCol: Column, methodName: String) = Resample(tileCol, factorCol, lit(methodName))
121+
112122

113123
/** Assign a `NoData` value to the tile column. */
114124
def rf_with_no_data(col: Column, nodata: Double): Column = SetNoDataValue(col, nodata)

0 commit comments

Comments
 (0)