Skip to content

Commit f4f4e9a

Browse files
committed
Masking improvements and unit tests.
Mask by value of 0 fixed. Masking cell type mutations in rf_mask_by_values is resolved. Masking by extraction of bits implemented. Signed-off-by: Jason T. Brown <[email protected]>
1 parent b2f8d3c commit f4f4e9a

File tree

4 files changed

+290
-147
lines changed

4 files changed

+290
-147
lines changed

core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ trait RasterFunctions {
319319

320320
/** Generate a tile with the values from `data_tile`, but where cells in the `mask_tile` are in the `mask_values`
321321
list, replace the value with NODATA. */
322-
def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Seq[Int]): TypedColumn[Any, Tile] = {
322+
def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Int*): TypedColumn[Any, Tile] = {
323323
import org.apache.spark.sql.functions.array
324324
val valuesCol: Column = array(maskValues.map(lit).toSeq: _*)
325325
rf_mask_by_values(sourceTile, maskTile, valuesCol)
@@ -338,22 +338,24 @@ trait RasterFunctions {
338338
Mask.InverseMaskByValue(sourceTile, maskTile, lit(maskValue))
339339

340340
/** Applies a mask using bit values in the `mask_tile`. Working from the right, extract the bit at `bitPosition` from the `maskTile`. In all locations where these are equal to the `valueToMask`, the returned tile is set to NoData, else the original `dataTile` cell value. */
341-
def rf_mask_by_bit(dataTile: Column, maskTile: Column, bitPosition: Int, valueToMask: Boolean): Column =
342-
rf_mask_by_bit(dataTile, maskTile, lit(bitPosition), lit(valueToMask))
341+
def rf_mask_by_bit(dataTile: Column, maskTile: Column, bitPosition: Int, valueToMask: Boolean): TypedColumn[Any, Tile] =
342+
rf_mask_by_bit(dataTile, maskTile, lit(bitPosition), lit(if (valueToMask) 1 else 0))
343343

344344
/** Applies a mask using bit values in the `mask_tile`. Working from the right, extract the bit at `bitPosition` from the `maskTile`. In all locations where these are equal to the `valueToMask`, the returned tile is set to NoData, else the original `dataTile` cell value. */
345-
def rf_mask_by_bit(dataTile: Column, maskTile: Column, bitPosition: Column, valueToMask: Column): Column =
346-
rf_mask_by_bits(dataTile, maskTile, bitPosition, lit(1), valueToMask)
345+
def rf_mask_by_bit(dataTile: Column, maskTile: Column, bitPosition: Column, valueToMask: Column): TypedColumn[Any, Tile] = {
346+
import org.apache.spark.sql.functions.array
347+
rf_mask_by_bits(dataTile, maskTile, bitPosition, lit(1), array(valueToMask))
348+
}
347349

348350
/** Applies a mask from blacklisted bit values in the `mask_tile`. Working from the right, the bits from `start_bit` to `start_bit + num_bits` are @ref:[extracted](reference.md#rf_local_extract_bits) from cell values of the `mask_tile`. In all locations where these are in the `mask_values`, the returned tile is set to NoData; otherwise the original `tile` cell value is returned. */
349-
def rf_mask_by_bits(dataTile: Column, maskTile: Column, startBit: Column, numBits: Column, valuesToMask: Column): Column = {
351+
def rf_mask_by_bits(dataTile: Column, maskTile: Column, startBit: Column, numBits: Column, valuesToMask: Column): TypedColumn[Any, Tile] = {
350352
val bitMask = rf_local_extract_bits(maskTile, startBit, numBits)
351353
rf_mask_by_values(dataTile, bitMask, valuesToMask)
352354
}
353355

354356

355357
/** Applies a mask from blacklisted bit values in the `mask_tile`. Working from the right, the bits from `start_bit` to `start_bit + num_bits` are @ref:[extracted](reference.md#rf_local_extract_bits) from cell values of the `mask_tile`. In all locations where these are in the `mask_values`, the returned tile is set to NoData; otherwise the original `tile` cell value is returned. */
356-
def rf_mask_by_bits(dataTile: Column, maskTile: Column, startBit: Int, numBits: Int, valuesToMask: Int*): Column = {
358+
def rf_mask_by_bits(dataTile: Column, maskTile: Column, startBit: Int, numBits: Int, valuesToMask: Int*): TypedColumn[Any, Tile] = {
357359
import org.apache.spark.sql.functions.array
358360
val values = array(valuesToMask.map(lit):_*)
359361
rf_mask_by_bits(dataTile, maskTile, lit(startBit), lit(numBits), values)

core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Mask.scala

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,29 @@ package org.locationtech.rasterframes.expressions.transformers
2424
import com.typesafe.scalalogging.Logger
2525
import geotrellis.raster
2626
import geotrellis.raster.Tile
27-
import geotrellis.raster.mapalgebra.local.{Defined, InverseMask => gtInverseMask, Mask => gtMask}
27+
import geotrellis.raster.mapalgebra.local.{Defined, InverseMask => gtInverseMask, Mask => gtMask}
2828
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2929
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
3030
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
3131
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, Literal, TernaryExpression}
3232
import org.apache.spark.sql.rf.TileUDT
33-
import org.apache.spark.sql.types.DataType
33+
import org.apache.spark.sql.types.{DataType, NullType}
3434
import org.apache.spark.sql.{Column, TypedColumn}
3535
import org.locationtech.rasterframes.encoders.CatalystSerializer._
3636
import org.locationtech.rasterframes.expressions.DynamicExtractors._
3737
import org.locationtech.rasterframes.expressions.localops.IsIn
3838
import org.locationtech.rasterframes.expressions.row
3939
import org.slf4j.LoggerFactory
4040

41-
abstract class Mask(val left: Expression, val middle: Expression, val right: Expression, inverse: Boolean)
41+
/** Convert cells in the `left` to NoData based on another tile's contents
42+
*
43+
* @param left a tile of data values, with valid nodata cell type
44+
* @param middle a tile indicating locations to set to nodata
45+
* @param right optional, cell values in the `middle` tile indicating locations to set NoData
46+
* @param defined if true, consider NoData in the `middle` as the locations to mask; else use `right` valued cells
47+
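* @param inverse if true, invert the mask: when `defined` is true, set `left` to NoData where `middle` is NOT NoData; when `defined` is false, set `left` to NoData where `middle` does NOT equal the `right` value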
48+
*/
49+
abstract class Mask(val left: Expression, val middle: Expression, val right: Expression, defined: Boolean, inverse: Boolean)
4250
extends TernaryExpression with CodegenFallback with Serializable {
4351
// aliases.
4452
def targetExp = left
@@ -77,13 +85,13 @@ abstract class Mask(val left: Expression, val middle: Expression, val right: Exp
7785

7886
val maskValue = intArgExtractor(maskValueExp.dataType)(maskValueInput)
7987

80-
val masking = if (maskValue.value == 0) Defined(maskTile)
81-
else maskTile
88+
val masking = if (defined) Defined(maskTile)
89+
else maskTile.localEqual(maskValue.value)
8290

8391
val result = if (inverse)
84-
gtInverseMask(targetTile, masking, maskValue.value, raster.NODATA)
92+
gtInverseMask(targetTile, masking, 1, raster.NODATA)
8593
else
86-
gtMask(targetTile, masking, maskValue.value, raster.NODATA)
94+
gtMask(targetTile, masking, 1, raster.NODATA)
8795

8896
targetCtx match {
8997
case Some(ctx) => ctx.toProjectRasterTile(result).toInternalRow
@@ -106,7 +114,7 @@ object Mask {
106114
..."""
107115
)
108116
case class MaskByDefined(target: Expression, mask: Expression)
109-
extends Mask(target, mask, Literal(0), false) {
117+
extends Mask(target, mask, Literal(0), true, false) {
110118
override def nodeName: String = "rf_mask"
111119
}
112120
object MaskByDefined {
@@ -126,7 +134,7 @@ object Mask {
126134
..."""
127135
)
128136
case class InverseMaskByDefined(leftTile: Expression, rightTile: Expression)
129-
extends Mask(leftTile, rightTile, Literal(0), true) {
137+
extends Mask(leftTile, rightTile, Literal(0), true, true) {
130138
override def nodeName: String = "rf_inverse_mask"
131139
}
132140
object InverseMaskByDefined {
@@ -146,7 +154,7 @@ object Mask {
146154
..."""
147155
)
148156
case class MaskByValue(leftTile: Expression, rightTile: Expression, maskValue: Expression)
149-
extends Mask(leftTile, rightTile, maskValue, false) {
157+
extends Mask(leftTile, rightTile, maskValue, false, false) {
150158
override def nodeName: String = "rf_mask_by_value"
151159
}
152160
object MaskByValue {
@@ -168,7 +176,7 @@ object Mask {
168176
..."""
169177
)
170178
case class InverseMaskByValue(leftTile: Expression, rightTile: Expression, maskValue: Expression)
171-
extends Mask(leftTile, rightTile, maskValue, true) {
179+
extends Mask(leftTile, rightTile, maskValue, false, true) {
172180
override def nodeName: String = "rf_inverse_mask_by_value"
173181
}
174182
object InverseMaskByValue {
@@ -190,7 +198,7 @@ object Mask {
190198
..."""
191199
)
192200
case class MaskByValues(dataTile: Expression, maskTile: Expression)
193-
extends Mask(dataTile, maskTile, Literal(1), inverse = false) {
201+
extends Mask(dataTile, maskTile, Literal(1), false, false) {
194202
def this(dataTile: Expression, maskTile: Expression, maskValues: Expression) =
195203
this(dataTile, IsIn(maskTile, maskValues))
196204
override def nodeName: String = "rf_mask_by_values"

0 commit comments

Comments
 (0)