Skip to content

Commit a3ac4cf

Browse files
committed
Fix Resample and ResampleNearest
Also untangle the super weird inheritance relationship between the two
1 parent aab5486 commit a3ac4cf

File tree

2 files changed

+146
-109
lines changed

2 files changed

+146
-109
lines changed

core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Resample.scala

Lines changed: 62 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -22,96 +22,20 @@
2222
package org.locationtech.rasterframes.expressions.localops
2323

2424
import geotrellis.raster.Tile
25-
import geotrellis.raster.resample._
26-
import geotrellis.raster.resample.{Max => RMax, Min => RMin, ResampleMethod => GTResampleMethod}
25+
import geotrellis.raster.resample.{Mode, NearestNeighbor, Sum, Max => RMax, Min => RMin, ResampleMethod => GTResampleMethod}
2726
import org.apache.spark.sql.Column
28-
import org.apache.spark.sql.catalyst.InternalRow
2927
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
3028
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
3129
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
32-
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, Literal, TernaryExpression}
30+
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, TernaryExpression}
3331
import org.apache.spark.sql.functions.lit
3432
import org.apache.spark.sql.types.{DataType, StringType}
3533
import org.apache.spark.unsafe.types.UTF8String
36-
import org.locationtech.rasterframes.util.ResampleMethod
37-
import org.locationtech.rasterframes.expressions.{RasterResult, fpTile, row}
3834
import org.locationtech.rasterframes.expressions.DynamicExtractors._
35+
import org.locationtech.rasterframes.expressions.{RasterResult, fpTile, row}
36+
import org.locationtech.rasterframes.util.ResampleMethod
3937

4038

41-
abstract class ResampleBase(left: Expression, right: Expression, method: Expression) extends TernaryExpression with RasterResult with CodegenFallback with Serializable {
42-
43-
override val nodeName: String = "rf_resample"
44-
def first: Expression = left
45-
def second: Expression = right
46-
def third: Expression = method
47-
def dataType: DataType = left.dataType
48-
49-
def targetFloatIfNeeded(t: Tile, method: GTResampleMethod): Tile =
50-
method match {
51-
case NearestNeighbor | Mode | RMax | RMin | Sum => t
52-
case _ => fpTile(t)
53-
}
54-
55-
// These methods define the core algorithms to be used.
56-
def op(left: Tile, right: Tile, method: GTResampleMethod): Tile =
57-
op(left, right.cols, right.rows, method)
58-
59-
def op(left: Tile, right: Double, method: GTResampleMethod): Tile =
60-
op(left, (left.cols * right).toInt, (left.rows * right).toInt, method)
61-
62-
def op(tile: Tile, newCols: Int, newRows: Int, method: GTResampleMethod): Tile =
63-
targetFloatIfNeeded(tile, method).resample(newCols, newRows, method)
64-
65-
override def checkInputDataTypes(): TypeCheckResult = {
66-
// copypasta from BinaryLocalRasterOp
67-
if (!tileExtractor.isDefinedAt(left.dataType)) {
68-
TypeCheckFailure(s"Input type '${left.dataType}' does not conform to a raster type.")
69-
}
70-
else if (!tileOrNumberExtractor.isDefinedAt(right.dataType)) {
71-
TypeCheckFailure(s"Input type '${right.dataType}' does not conform to a compatible type.")
72-
} else method.dataType match {
73-
case StringType => TypeCheckSuccess
74-
case _ => TypeCheckFailure(s"Cannot interpret value of type `${method.dataType.simpleString}` for resampling method; please provide a String method name.")
75-
}
76-
}
77-
78-
override def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = {
79-
// more copypasta from BinaryLocalRasterOp
80-
81-
val (leftTile, leftCtx) = tileExtractor(left.dataType)(row(input1))
82-
val methodString = input3.asInstanceOf[UTF8String].toString
83-
84-
val resamplingMethod = methodString match {
85-
case ResampleMethod(mm) => mm
86-
case _ => throw new IllegalArgumentException("Unrecognized resampling method specified")
87-
}
88-
89-
val result: Tile = tileOrNumberExtractor(right.dataType)(input2) match {
90-
// in this case we expect the left and right contexts to vary. no warnings raised.
91-
case TileArg(rightTile, _) => op(leftTile, rightTile, resamplingMethod)
92-
case DoubleArg(d) => op(leftTile, d, resamplingMethod)
93-
case IntegerArg(i) => op(leftTile, i.toDouble, resamplingMethod)
94-
}
95-
96-
// reassemble the leftTile with its context. Note that this operation does not change Extent and CRS
97-
toInternalRow(result, leftCtx)
98-
}
99-
100-
override def eval(input: InternalRow): Any = {
101-
if(input == null) null
102-
else {
103-
val l = left.eval(input)
104-
val r = right.eval(input)
105-
val m = method.eval(input)
106-
if (m == null) null // no method, return null
107-
else if (l == null) null // no l tile, return null
108-
else if (r == null) l // no target tile or factor, return l without changin it
109-
else nullSafeEval(l, r, m)
110-
}
111-
}
112-
113-
}
114-
11539
@ExpressionDescription(
11640
usage = "_FUNC_(tile, factor, method_name) - Resample tile to different dimension based on scalar `factor` or a tile whose dimension to match. Scalar less than one will downsample tile; greater than one will upsample. Uses resampling method named in the `method_name`." +
11741
"Methods average, mode, median, max, min, and sum aggregate over cells when downsampling",
@@ -129,45 +53,74 @@ Examples:
12953
> SELECT _FUNC_(tile1, tile2, lit("cubic_spline"));
13054
..."""
13155
)
132-
case class Resample(left: Expression, factor: Expression, method: Expression) extends ResampleBase(left, factor, method) {
56+
case class Resample(tile: Expression, factor: Expression, method: Expression) extends TernaryExpression with RasterResult with CodegenFallback {
57+
override val nodeName: String = "rf_resample"
58+
def dataType: DataType = tile.dataType
59+
def first: Expression = tile
60+
def second: Expression = factor
61+
def third: Expression = method
62+
63+
override def checkInputDataTypes(): TypeCheckResult = {
64+
if (!tileExtractor.isDefinedAt(tile.dataType)) {
65+
TypeCheckFailure(s"Input type '${tile.dataType}' does not conform to a raster type.")
66+
} else if (!tileOrNumberExtractor.isDefinedAt(factor.dataType)) {
67+
TypeCheckFailure(s"Input type '${factor.dataType}' does not conform to a compatible type.")
68+
} else
69+
method.dataType match {
70+
case StringType => TypeCheckSuccess
71+
case _ =>
72+
TypeCheckFailure(
73+
s"Cannot interpret value of type `${method.dataType.simpleString}` for resampling method; please provide a String method name."
74+
)
75+
}
76+
}
77+
override def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = {
78+
val (leftTile, leftCtx) = tileExtractor(tile.dataType)(row(input1))
79+
val ton = tileOrNumberExtractor(factor.dataType)(input2)
80+
val methodString = input3.asInstanceOf[UTF8String].toString
81+
val resamplingMethod = methodString match {
82+
case ResampleMethod(mm) => mm
83+
case _ => throw new IllegalArgumentException("Unrecognized resampling method specified")
84+
}
85+
86+
val result: Tile = Resample.op(leftTile, ton, resamplingMethod)
87+
toInternalRow(result, leftCtx)
88+
}
89+
13390
override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = copy(newFirst, newSecond, newThird)
13491
}
13592

13693
object Resample {
137-
def apply(left: Column, right: Column, methodName: String): Column =
138-
new Column(Resample(left.expr, right.expr, lit(methodName).expr))
94+
def op(tile: Tile, target: TileOrNumberArg, method: GTResampleMethod): Tile = {
95+
val sourceTile = method match {
96+
case NearestNeighbor | Mode | RMax | RMin | Sum => tile
97+
case _ => fpTile(tile)
98+
}
99+
target match {
100+
case TileArg(targetTile, _) =>
101+
sourceTile.resample(targetTile.cols, targetTile.rows, method)
102+
case DoubleArg(d) =>
103+
sourceTile.resample((tile.cols * d).toInt, (tile.rows * d).toInt, method)
104+
case IntegerArg(i) =>
105+
sourceTile.resample(tile.cols * i,tile.rows * i, method)
106+
}
107+
}
139108

140-
def apply(left: Column, right: Column, method: Column): Column =
141-
new Column(Resample(left.expr, right.expr, method.expr))
109+
def apply(tile: Column, factor: Column, methodName: String): Column =
110+
new Column(Resample(tile.expr, factor.expr, lit(methodName).expr))
142111

143-
def apply[N: Numeric](left: Column, right: N, method: String): Column = new Column(Resample(left.expr, lit(right).expr, lit(method).expr))
112+
def apply(tile: Column, factor: Column, method: Column): Column =
113+
new Column(Resample(tile.expr, factor.expr, method.expr))
144114

145-
def apply[N: Numeric](left: Column, right: N, method: Column): Column = new Column(Resample(left.expr, lit(right).expr, method.expr))
115+
def apply[N: Numeric](tile: Column, factor: N, method: String): Column =
116+
new Column(Resample(tile.expr, lit(factor).expr, lit(method).expr))
146117

118+
def apply[N: Numeric](tile: Column, factor: N, method: Column): Column =
119+
new Column(Resample(tile.expr, lit(factor).expr, method.expr))
147120
}
148121

149-
@ExpressionDescription(
150-
usage = "_FUNC_(tile, factor) - Resample tile to different size based on scalar factor or tile whose dimension to match. Scalar less than one will downsample tile; greater than one will upsample. Uses nearest-neighbor value.",
151-
arguments = """
152-
Arguments:
153-
* tile - tile
154-
* rhs - scalar or tile to match dimension""",
155-
examples = """
156-
Examples:
157-
> SELECT _FUNC_(tile, 2.0);
158-
...
159-
> SELECT _FUNC_(tile1, tile2);
160-
...""")
161-
case class ResampleNearest(tile: Expression, target: Expression) extends ResampleBase(tile, target, Literal("nearest")) {
162-
override val nodeName: String = "rf_resample_nearest"
163-
164-
override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
165-
ResampleNearest(tile, target)
166-
}
167-
object ResampleNearest {
168-
def apply(tile: Column, target: Column): Column = new Column(ResampleNearest(tile.expr, target.expr))
169122

170-
def apply[N: Numeric](tile: Column, value: N): Column = new Column(ResampleNearest(tile.expr, lit(value).expr))
171-
}
123+
124+
172125

173126

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
/*
2+
* This software is licensed under the Apache 2 license, quoted below.
3+
*
4+
* Copyright 2019 Astraea, Inc.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
7+
* use this file except in compliance with the License. You may obtain a copy of
8+
* the License at
9+
*
10+
* [http://www.apache.org/licenses/LICENSE-2.0]
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15+
* License for the specific language governing permissions and limitations under
16+
* the License.
17+
*
18+
* SPDX-License-Identifier: Apache-2.0
19+
*
20+
*/
21+
22+
package org.locationtech.rasterframes.expressions.localops
23+
24+
import geotrellis.raster.Tile
25+
import geotrellis.raster.resample._
26+
import org.apache.spark.sql.Column
27+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
28+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
29+
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
30+
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, ExpressionDescription}
31+
import org.apache.spark.sql.functions.lit
32+
import org.apache.spark.sql.types.DataType
33+
import org.locationtech.rasterframes.expressions.{RasterResult, row}
34+
import org.locationtech.rasterframes.expressions.DynamicExtractors._
35+
36+
37+
@ExpressionDescription(
38+
usage =
39+
"_FUNC_(tile, factor) - Resample tile to different size based on scalar factor or tile whose dimension to match. Scalar less than one will downsample tile; greater than one will upsample. Uses nearest-neighbor value.",
40+
arguments = """
41+
Arguments:
42+
* tile - tile
43+
* rhs - scalar or tile to match dimension""",
44+
examples = """
45+
Examples:
46+
> SELECT _FUNC_(tile, 2.0);
47+
...
48+
> SELECT _FUNC_(tile1, tile2);
49+
..."""
50+
)
51+
case class ResampleNearest(tile: Expression, factor: Expression) extends BinaryExpression with RasterResult with CodegenFallback {
52+
override val nodeName: String = "rf_resample_nearest"
53+
def dataType: DataType = tile.dataType
54+
def left: Expression = tile
55+
def right: Expression = factor
56+
57+
override def checkInputDataTypes(): TypeCheckResult = {
58+
if (!tileExtractor.isDefinedAt(tile.dataType))
59+
TypeCheckFailure(s"Input type '${tile.dataType}' does not conform to a raster type.")
60+
else if (!tileOrNumberExtractor.isDefinedAt(factor.dataType))
61+
TypeCheckFailure(s"Input type '${factor.dataType}' does not conform to a compatible type.")
62+
else
63+
TypeCheckSuccess
64+
}
65+
66+
override def nullSafeEval(input1: Any, input2: Any): Any = {
67+
val (leftTile, leftCtx) = tileExtractor(tile.dataType)(row(input1))
68+
val ton = tileOrNumberExtractor(factor.dataType)(input2)
69+
70+
val result: Tile = Resample.op(leftTile, ton, NearestNeighbor)
71+
toInternalRow(result, leftCtx)
72+
}
73+
74+
override def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression =
75+
ResampleNearest(newLeft, newRight)
76+
}
77+
78+
object ResampleNearest {
79+
def apply(tile: Column, target: Column): Column =
80+
new Column(ResampleNearest(tile.expr, target.expr))
81+
82+
def apply[N: Numeric](tile: Column, value: N): Column =
83+
new Column(ResampleNearest(tile.expr, lit(value).expr))
84+
}

0 commit comments

Comments
 (0)