Skip to content

Commit ecf5748

Browse files
authored
Merge pull request #404 from s22s/feature/local_is_in
Mask by list of values
2 parents ade36ab + 961aa2d commit ecf5748

File tree

12 files changed

+368
-145
lines changed

12 files changed

+368
-145
lines changed

core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -292,12 +292,38 @@ trait RasterFunctions {
292292
}
293293

294294
/** Where the rf_mask tile contains NODATA, replace values in the source tile with NODATA */
295-
def rf_mask(sourceTile: Column, maskTile: Column): TypedColumn[Any, Tile] =
296-
Mask.MaskByDefined(sourceTile, maskTile)
295+
def rf_mask(sourceTile: Column, maskTile: Column): TypedColumn[Any, Tile] = rf_mask(sourceTile, maskTile, false)
296+
297+
/** Where the rf_mask tile contains NODATA, replace values in the source tile with NODATA */
298+
def rf_mask(sourceTile: Column, maskTile: Column, inverse: Boolean=false): TypedColumn[Any, Tile] =
299+
if(!inverse) Mask.MaskByDefined(sourceTile, maskTile)
300+
else Mask.InverseMaskByDefined(sourceTile, maskTile)
301+
302+
/** Where the `maskTile` equals `maskValue`, replace values in the source tile with `NoData` */
303+
def rf_mask_by_value(sourceTile: Column, maskTile: Column, maskValue: Column, inverse: Boolean=false): TypedColumn[Any, Tile] =
304+
if (!inverse) Mask.MaskByValue(sourceTile, maskTile, maskValue)
305+
else Mask.InverseMaskByValue(sourceTile, maskTile, maskValue)
297306

298307
/** Where the `maskTile` equals `maskValue`, replace values in the source tile with `NoData` */
299-
def rf_mask_by_value(sourceTile: Column, maskTile: Column, maskValue: Column): TypedColumn[Any, Tile] =
300-
Mask.MaskByValue(sourceTile, maskTile, maskValue)
308+
def rf_mask_by_value(sourceTile: Column, maskTile: Column, maskValue: Int, inverse: Boolean): TypedColumn[Any, Tile] =
309+
rf_mask_by_value(sourceTile, maskTile, lit(maskValue), inverse)
310+
311+
/** Where the `maskTile` equals `maskValue`, replace values in the source tile with `NoData` */
312+
def rf_mask_by_value(sourceTile: Column, maskTile: Column, maskValue: Int): TypedColumn[Any, Tile] =
313+
rf_mask_by_value(sourceTile, maskTile, maskValue, false)
314+
315+
/** Generate a tile with the values from `data_tile`, but where cells in the `mask_tile` are in the `mask_values`
316+
list, replace the value with NODATA. */
317+
def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Column): TypedColumn[Any, Tile] =
318+
Mask.MaskByValues(sourceTile, maskTile, maskValues)
319+
320+
/** Generate a tile with the values from `data_tile`, but where cells in the `mask_tile` are in the `mask_values`
321+
list, replace the value with NODATA. */
322+
def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Seq[Int]): TypedColumn[Any, Tile] = {
323+
import org.apache.spark.sql.functions.array
324+
val valuesCol: Column = array(maskValues.map(lit).toSeq: _*)
325+
rf_mask_by_values(sourceTile, maskTile, valuesCol)
326+
}
301327

302328
/** Where the `maskTile` does **not** contain `NoData`, replace values in the source tile with `NoData` */
303329
def rf_inverse_mask(sourceTile: Column, maskTile: Column): TypedColumn[Any, Tile] =
@@ -307,6 +333,10 @@ trait RasterFunctions {
307333
def rf_inverse_mask_by_value(sourceTile: Column, maskTile: Column, maskValue: Column): TypedColumn[Any, Tile] =
308334
Mask.InverseMaskByValue(sourceTile, maskTile, maskValue)
309335

336+
/** Where the `maskTile` does **not** equal `maskValue`, replace values in the source tile with `NoData` */
337+
def rf_inverse_mask_by_value(sourceTile: Column, maskTile: Column, maskValue: Int): TypedColumn[Any, Tile] =
338+
Mask.InverseMaskByValue(sourceTile, maskTile, lit(maskValue))
339+
310340
/** Create a tile where cells in the grid defined by cols, rows, and bounds are filled with the given value. */
311341
def rf_rasterize(geometry: Column, bounds: Column, value: Column, cols: Int, rows: Int): TypedColumn[Any, Tile] =
312342
withTypedAlias("rf_rasterize", geometry)(
@@ -408,6 +438,9 @@ trait RasterFunctions {
408438
/** Test if each cell value is in provided array */
409439
def rf_local_is_in(tileCol: Column, arrayCol: Column) = IsIn(tileCol, arrayCol)
410440

441+
/** Test if each cell value is in provided array */
442+
def rf_local_is_in(tileCol: Column, array: Array[Int]) = IsIn(tileCol, array)
443+
411444
/** Return a tile with ones where the input is NoData, otherwise zero */
412445
def rf_local_no_data(tileCol: Column): Column = Undefined(tileCol)
413446

core/src/main/scala/org/locationtech/rasterframes/expressions/localops/IsIn.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,4 +85,12 @@ case class IsIn(left: Expression, right: Expression) extends BinaryExpression wi
8585
object IsIn {
8686
def apply(left: Column, right: Column): Column =
8787
new Column(IsIn(left.expr, right.expr))
88+
89+
def apply(left: Column, right: Array[Int]): Column = {
90+
import org.apache.spark.sql.functions.lit
91+
import org.apache.spark.sql.functions.array
92+
val arrayExpr = array(right.map(lit):_*).expr
93+
new Column(IsIn(left.expr, arrayExpr))
94+
}
95+
8896
}

core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,10 @@ package object expressions {
125125
registry.registerExpression[LocalMeanAggregate]("rf_agg_local_mean")
126126

127127
registry.registerExpression[Mask.MaskByDefined]("rf_mask")
128+
registry.registerExpression[Mask.InverseMaskByDefined]("rf_inverse_mask")
128129
registry.registerExpression[Mask.MaskByValue]("rf_mask_by_value")
129130
registry.registerExpression[Mask.InverseMaskByValue]("rf_inverse_mask_by_value")
130-
registry.registerExpression[Mask.InverseMaskByDefined]("rf_inverse_mask")
131+
registry.registerExpression[Mask.MaskByValues]("rf_mask_by_values")
131132

132133
registry.registerExpression[DebugRender.RenderAscii]("rf_render_ascii")
133134
registry.registerExpression[DebugRender.RenderMatrix]("rf_render_matrix")

core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Mask.scala

Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -28,58 +28,64 @@ import geotrellis.raster.mapalgebra.local.{Defined, InverseMask => gtInverseMask
2828
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2929
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
3030
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
31-
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, Literal, TernaryExpression}
31+
import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression, ExpressionDescription, Literal, TernaryExpression}
32+
import org.apache.spark.sql.functions.lit
3233
import org.apache.spark.sql.rf.TileUDT
3334
import org.apache.spark.sql.types.DataType
3435
import org.apache.spark.sql.{Column, TypedColumn}
3536
import org.locationtech.rasterframes.encoders.CatalystSerializer._
3637
import org.locationtech.rasterframes.expressions.DynamicExtractors._
38+
import org.locationtech.rasterframes.expressions.localops.IsIn
3739
import org.locationtech.rasterframes.expressions.row
3840
import org.slf4j.LoggerFactory
3941

4042
abstract class Mask(val left: Expression, val middle: Expression, val right: Expression, inverse: Boolean)
4143
extends TernaryExpression with CodegenFallback with Serializable {
44+
def targetExp = left
45+
def maskExp = middle
46+
def maskValueExp = right
4247

4348
@transient protected lazy val logger = Logger(LoggerFactory.getLogger(getClass.getName))
4449

45-
4650
override def children: Seq[Expression] = Seq(left, middle, right)
4751

4852
override def checkInputDataTypes(): TypeCheckResult = {
49-
if (!tileExtractor.isDefinedAt(left.dataType)) {
50-
TypeCheckFailure(s"Input type '${left.dataType}' does not conform to a raster type.")
51-
} else if (!tileExtractor.isDefinedAt(middle.dataType)) {
52-
TypeCheckFailure(s"Input type '${middle.dataType}' does not conform to a raster type.")
53-
} else if (!intArgExtractor.isDefinedAt(right.dataType)) {
54-
TypeCheckFailure(s"Input type '${right.dataType}' isn't an integral type.")
53+
if (!tileExtractor.isDefinedAt(targetExp.dataType)) {
54+
TypeCheckFailure(s"Input type '${targetExp.dataType}' does not conform to a raster type.")
55+
} else if (!tileExtractor.isDefinedAt(maskExp.dataType)) {
56+
TypeCheckFailure(s"Input type '${maskExp.dataType}' does not conform to a raster type.")
57+
} else if (!intArgExtractor.isDefinedAt(maskValueExp.dataType)) {
58+
TypeCheckFailure(s"Input type '${maskValueExp.dataType}' isn't an integral type.")
5559
} else TypeCheckSuccess
5660
}
5761
override def dataType: DataType = left.dataType
5862

59-
override protected def nullSafeEval(leftInput: Any, middleInput: Any, rightInput: Any): Any = {
63+
override def makeCopy(newArgs: Array[AnyRef]): Expression = super.makeCopy(newArgs)
64+
65+
override protected def nullSafeEval(targetInput: Any, maskInput: Any, maskValueInput: Any): Any = {
6066
implicit val tileSer = TileUDT.tileSerializer
61-
val (leftTile, leftCtx) = tileExtractor(left.dataType)(row(leftInput))
62-
val (rightTile, rightCtx) = tileExtractor(middle.dataType)(row(middleInput))
67+
val (targetTile, targetCtx) = tileExtractor(targetExp.dataType)(row(targetInput))
68+
val (maskTile, maskCtx) = tileExtractor(maskExp.dataType)(row(maskInput))
6369

64-
if (leftCtx.isEmpty && rightCtx.isDefined)
70+
if (targetCtx.isEmpty && maskCtx.isDefined)
6571
logger.warn(
6672
s"Right-hand parameter '${middle}' provided an extent and CRS, but the left-hand parameter " +
6773
s"'${left}' didn't have any. Because the left-hand side defines output type, the right-hand context will be lost.")
6874

69-
if (leftCtx.isDefined && rightCtx.isDefined && leftCtx != rightCtx)
75+
if (targetCtx.isDefined && maskCtx.isDefined && targetCtx != maskCtx)
7076
logger.warn(s"Both '${left}' and '${middle}' provided an extent and CRS, but they are different. Left-hand side will be used.")
7177

72-
val maskValue = intArgExtractor(right.dataType)(rightInput)
78+
val maskValue = intArgExtractor(maskValueExp.dataType)(maskValueInput)
7379

74-
val masking = if (maskValue.value == 0) Defined(rightTile)
75-
else rightTile
80+
val masking = if (maskValue.value == 0) Defined(maskTile)
81+
else maskTile
7682

7783
val result = if (inverse)
78-
gtInverseMask(leftTile, masking, maskValue.value, raster.NODATA)
84+
gtInverseMask(targetTile, masking, maskValue.value, raster.NODATA)
7985
else
80-
gtMask(leftTile, masking, maskValue.value, raster.NODATA)
86+
gtMask(targetTile, masking, maskValue.value, raster.NODATA)
8187

82-
leftCtx match {
88+
targetCtx match {
8389
case Some(ctx) => ctx.toProjectRasterTile(result).toInternalRow
8490
case None => result.toInternalRow
8591
}
@@ -169,4 +175,28 @@ object Mask {
169175
def apply(srcTile: Column, maskingTile: Column, maskValue: Column): TypedColumn[Any, Tile] =
170176
new Column(InverseMaskByValue(srcTile.expr, maskingTile.expr, maskValue.expr)).as[Tile]
171177
}
178+
179+
@ExpressionDescription(
180+
usage = "_FUNC_(data, mask, maskValues) - Generate a tile with the values from `data` tile but where cells in the `mask` tile are in the `maskValues` list, replace the value with NODATA.",
181+
arguments = """
182+
Arguments:
183+
* target - tile to mask
184+
* mask - masking definition
185+
* maskValues - sequence of values to consider as masks candidates
186+
""",
187+
examples = """
188+
Examples:
189+
> SELECT _FUNC_(data, mask, array(1, 2, 3))
190+
..."""
191+
)
192+
case class MaskByValues(dataTile: Expression, maskTile: Expression)
193+
extends Mask(dataTile, maskTile, Literal(1), inverse = false) {
194+
def this(dataTile: Expression, maskTile: Expression, maskValues: Expression) =
195+
this(dataTile, IsIn(maskTile, maskValues))
196+
override def nodeName: String = "rf_mask_by_values"
197+
}
198+
object MaskByValues {
199+
def apply(dataTile: Column, maskTile: Column, maskValues: Column): TypedColumn[Any, Tile] =
200+
new Column(MaskByValues(dataTile.expr, IsIn(maskTile, maskValues).expr)).as[Tile]
201+
}
172202
}

core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import geotrellis.raster._
2828
import geotrellis.raster.render.ColorRamps
2929
import geotrellis.raster.testkit.RasterMatchers
3030
import javax.imageio.ImageIO
31-
import org.apache.spark.sql.Encoders
31+
import org.apache.spark.sql.{Column, Encoders, TypedColumn}
3232
import org.apache.spark.sql.functions._
3333
import org.locationtech.rasterframes.expressions.accessors.ExtractTile
3434
import org.locationtech.rasterframes.model.TileDimensions
@@ -694,19 +694,48 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers {
694694
rf_local_multiply(rf_convert_cell_type(
695695
rf_local_greater($"tile", 50),
696696
"uint8"),
697-
lit(mask_value)
697+
mask_value
698698
)
699699
)
700700

701701
val withMasked = withMask.withColumn("masked",
702-
rf_inverse_mask_by_value($"tile", $"mask", lit(mask_value)))
703-
702+
rf_inverse_mask_by_value($"tile", $"mask", mask_value))
703+
.withColumn("masked2", rf_mask_by_value($"tile", $"mask", lit(mask_value), true))
704+
withMasked.explain(true)
704705
val result = withMasked.agg(rf_agg_no_data_cells($"tile") < rf_agg_no_data_cells($"masked")).as[Boolean]
705706

706707
result.first() should be(true)
708+
709+
val result2 = withMasked.agg(rf_agg_no_data_cells($"tile") < rf_agg_no_data_cells($"masked2")).as[Boolean]
710+
result2.first() should be(true)
711+
707712
checkDocs("rf_inverse_mask_by_value")
708713
}
709714

715+
it("should mask tile by another identified by sequence of specified values") {
716+
val squareIncrementingPRT = ProjectedRasterTile(squareIncrementingTile(six.rows), six.extent, six.crs)
717+
val df = Seq((six, squareIncrementingPRT))
718+
.toDF("tile", "mask")
719+
720+
val mask_values = Seq(4, 5, 6, 12)
721+
722+
val withMasked = df.withColumn("masked",
723+
rf_mask_by_values($"tile", $"mask", mask_values))
724+
725+
val expected = squareIncrementingPRT.toArray().count(v mask_values.contains(v))
726+
727+
val result = withMasked.agg(rf_agg_no_data_cells($"masked") as "masked_nd")
728+
.first()
729+
730+
result.getAs[BigInt](0) should be (expected)
731+
732+
val withMaskedSql = df.selectExpr("rf_mask_by_values(tile, mask, array(4, 5, 6, 12)) AS masked")
733+
val resultSql = withMaskedSql.agg(rf_agg_no_data_cells($"masked")).as[Long]
734+
resultSql.first() should be (expected)
735+
736+
checkDocs("rf_mask_by_values")
737+
}
738+
710739
it("should render ascii art") {
711740
val df = Seq[Tile](ProjectedRasterTile(TestData.l8Labels)).toDF("tile")
712741
val r1 = df.select(rf_render_ascii($"tile"))
@@ -983,6 +1012,7 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers {
9831012
.withColumn("ten", lit(10))
9841013
.withColumn("in_expect_2", rf_local_is_in($"t", array($"one", $"five")))
9851014
.withColumn("in_expect_1", rf_local_is_in($"t", array($"ten", $"five")))
1015+
.withColumn("in_expect_1a", rf_local_is_in($"t", Array(10, 5)))
9861016
.withColumn("in_expect_0", rf_local_is_in($"t", array($"ten")))
9871017

9881018
val e2Result = df.select(rf_tile_sum($"in_expect_2")).as[Double].first()
@@ -991,9 +1021,11 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers {
9911021
val e1Result = df.select(rf_tile_sum($"in_expect_1")).as[Double].first()
9921022
e1Result should be (1.0)
9931023

1024+
val e1aResult = df.select(rf_tile_sum($"in_expect_1a")).as[Double].first()
1025+
e1aResult should be (1.0)
1026+
9941027
val e0Result = df.select($"in_expect_0").as[Tile].first()
9951028
e0Result.toArray() should contain only (0)
9961029

997-
// lazy val invalid = df.select(rf_local_is_in($"t", lit("foobar"))).as[Tile].first()
9981030
}
9991031
}

docs/src/main/paradox/reference.md

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -207,35 +207,51 @@ SQL implementation does not accept a cell_type argument. It returns a float64 ce
207207

208208
## Masking and NoData
209209

210-
See @ref:[NoData handling](nodata-handling.md) for conceptual discussion of cell types and NoData.
210+
See the @ref:[masking](masking.md) page for conceptual discussion of masking operations.
211211

212212
There are statistical functions of the count of data and NoData values per `tile` and aggregate over a `tile` column: @ref:[`rf_data_cells`](reference.md#rf-data-cells), @ref:[`rf_no_data_cells`](reference.md#rf-no-data-cells), @ref:[`rf_agg_data_cells`](reference.md#rf-agg-data-cells), and @ref:[`rf_agg_no_data_cells`](reference.md#rf-agg-no-data-cells).
213213

214214
Masking is a raster operation that sets specific cells to NoData based on the values in another raster.
215215

216216
### rf_mask
217217

218-
Tile rf_mask(Tile tile, Tile mask)
218+
Tile rf_mask(Tile tile, Tile mask, bool inverse)
219219

220220
Where the `mask` contains NoData, replace values in the `tile` with NoData.
221221

222222
Returned `tile` cell type will be coerced to one supporting NoData if it does not already.
223223

224+
`inverse` is a literal not a Column. If `inverse` is true, return the `tile` with NoData in locations where the `mask` _does not_ contain NoData. Equivalent to @ref:[`rf_inverse_mask`](reference.md#rf-inverse-mask).
225+
224226
See also @ref:[`rf_rasterize`](reference.md#rf-rasterize).
225227

228+
### rf_mask_by_value
229+
230+
Tile rf_mask_by_value(Tile data_tile, Tile mask_tile, Int mask_value, bool inverse)
231+
232+
Generate a `tile` with the values from `data_tile`, with NoData in cells where the `mask_tile` is equal to `mask_value`.
233+
234+
`inverse` is a literal not a Column. If `inverse` is true, return the `data_tile` with NoData in locations where the `mask_tile` value is _not equal_ to `mask_value`. Equivalent to @ref:[`rf_inverse_mask_by_value`](reference.md#rf-inverse-mask-by-value).
235+
236+
### rf_mask_by_values
237+
238+
Tile rf_mask_by_values(Tile data_tile, Tile mask_tile, Array mask_values)
239+
Tile rf_mask_by_values(Tile data_tile, Tile mask_tile, seq mask_values)
240+
241+
Generate a `tile` with the values from `data_tile`, with NoData in cells where the `mask_tile` is in the `mask_values` Array or list. `mask_values` can be a [`pyspark.sql.ArrayType`][Array] or a `list`.
226242

227243
### rf_inverse_mask
228244

229245
Tile rf_inverse_mask(Tile tile, Tile mask)
230246

231247
Where the `mask` _does not_ contain NoData, replace values in `tile` with NoData.
232248

233-
### rf_mask_by_value
234249

235-
Tile rf_mask_by_value(Tile data_tile, Tile mask_tile, Int mask_value)
250+
### rf_inverse_mask_by_value
236251

237-
Generate a `tile` with the values from `data_tile`, with NoData in cells where the `mask_tile` is equal to `mask_value`.
252+
Tile rf_inverse_mask_by_value(Tile data_tile, Tile mask_tile, Int mask_value)
238253

254+
Generate a `tile` with the values from `data_tile`, with NoData in cells where the `mask_tile` is not equal to `mask_value`. In other words, only keep `data_tile` cells in locations where the `mask_tile` is equal to `mask_value`.
239255

240256
### rf_is_no_data_tile
241257

docs/src/main/paradox/release-notes.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
### 0.8.4
66

77
* Upgraded to Spark 2.4.4
8+
* Add `rf_mask_by_values` and `rf_local_is_in` raster functions; added optional `inverse` argument to `rf_mask` functions. ([#403](https://github.com/locationtech/rasterframes/pull/403), [#384](https://github.com/locationtech/rasterframes/issues/384))
89
* Added forced truncation of WKT types in Markdown/HTML rendering. ([#408](https://github.com/locationtech/rasterframes/pull/408))
910
* Add `rf_local_is_in` raster function. ([#400](https://github.com/locationtech/rasterframes/pull/400))
1011
* Added partitioning to catalogs before processing in RasterSourceDataSource ([#397](https://github.com/locationtech/rasterframes/pull/397))

0 commit comments

Comments
 (0)