Skip to content

Commit fa8f6b6

Browse files
committed
Fixes to rf_mask_by_values to compose over IfIn and have SQL docs.
Removed inverse flag from rf_mask_by_values before committing to an approach with it.
1 parent 4e12f2e commit fa8f6b6

File tree

7 files changed

+58
-78
lines changed

7 files changed

+58
-78
lines changed

core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala

Lines changed: 4 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -312,32 +312,19 @@ trait RasterFunctions {
312312
def rf_mask_by_value(sourceTile: Column, maskTile: Column, maskValue: Int): TypedColumn[Any, Tile] =
313313
rf_mask_by_value(sourceTile, maskTile, maskValue, false)
314314

315-
/** Generate a tile with the values from `data_tile`, but where cells in the `mask_tile` are in the `mask_values`
316-
list, replace the value with NODATA.
317-
If `inverse` is True, the cells in `mask_tile` that are not in `mask_values` list become NODATA */
318-
def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Column, inverse: Boolean): TypedColumn[Any, Tile] =
319-
Mask.MaskByValues(sourceTile, maskTile, maskValues, lit(inverse))
320-
321315
/** Generate a tile with the values from `data_tile`, but where cells in the `mask_tile` are in the `mask_values`
322316
list, replace the value with NODATA. */
323317
def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Column): TypedColumn[Any, Tile] =
324-
rf_mask_by_values(sourceTile, maskTile, maskValues, false)
318+
Mask.MaskByValues(sourceTile, maskTile, maskValues)
325319

326320
/** Generate a tile with the values from `data_tile`, but where cells in the `mask_tile` are in the `mask_values`
327-
list, replace the value with NODATA.
328-
If `inverse` is True, the cells in `mask_tile` that are not in `mask_values` list become NODATA */
329-
def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Iterable[Int], inverse: Boolean): TypedColumn[Any, Tile] = {
321+
list, replace the value with NODATA. */
322+
def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Seq[Int]): TypedColumn[Any, Tile] = {
330323
import org.apache.spark.sql.functions.array
331324
val valuesCol: Column = array(maskValues.map(lit).toSeq: _*)
332-
rf_mask_by_values(sourceTile, maskTile, valuesCol, inverse)
325+
rf_mask_by_values(sourceTile, maskTile, valuesCol)
333326
}
334327

335-
/** Generate a tile with the values from `data_tile`, but where cells in the `mask_tile` are in the `mask_values`
336-
list, replace the value with NODATA.
337-
If `inverse` is True, the cells in `mask_tile` that are not in `mask_values` list become NODATA */
338-
def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Iterable[Int]): TypedColumn[Any, Tile] =
339-
rf_mask_by_values(sourceTile, maskTile, maskValues, false)
340-
341328
/** Where the `maskTile` does **not** contain `NoData`, replace values in the source tile with `NoData` */
342329
def rf_inverse_mask(sourceTile: Column, maskTile: Column): TypedColumn[Any, Tile] =
343330
Mask.InverseMaskByDefined(sourceTile, maskTile)

core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -125,10 +125,10 @@ package object expressions {
125125
registry.registerExpression[LocalMeanAggregate]("rf_agg_local_mean")
126126

127127
registry.registerExpression[Mask.MaskByDefined]("rf_mask")
128+
registry.registerExpression[Mask.InverseMaskByDefined]("rf_inverse_mask")
128129
registry.registerExpression[Mask.MaskByValue]("rf_mask_by_value")
129-
registry.registerExpression[Mask.MaskByValues]("rf_mask_by_values")
130130
registry.registerExpression[Mask.InverseMaskByValue]("rf_inverse_mask_by_value")
131-
registry.registerExpression[Mask.InverseMaskByDefined]("rf_inverse_mask")
131+
registry.registerExpression[Mask.MaskByValues]("rf_mask_by_values")
132132

133133
registry.registerExpression[DebugRender.RenderAscii]("rf_render_ascii")
134134
registry.registerExpression[DebugRender.RenderMatrix]("rf_render_matrix")

core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Mask.scala

Lines changed: 40 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -24,11 +24,11 @@ package org.locationtech.rasterframes.expressions.transformers
2424
import com.typesafe.scalalogging.Logger
2525
import geotrellis.raster
2626
import geotrellis.raster.Tile
27-
import geotrellis.raster.mapalgebra.local.{Defined, InverseMask ⇒ gtInverseMask, Mask ⇒ gtMask}
27+
import geotrellis.raster.mapalgebra.local.{Defined, InverseMask => gtInverseMask, Mask => gtMask}
2828
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2929
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
3030
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
31-
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, Literal, TernaryExpression}
31+
import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression, ExpressionDescription, Literal, TernaryExpression}
3232
import org.apache.spark.sql.functions.lit
3333
import org.apache.spark.sql.rf.TileUDT
3434
import org.apache.spark.sql.types.DataType
@@ -41,47 +41,51 @@ import org.slf4j.LoggerFactory
4141

4242
abstract class Mask(val left: Expression, val middle: Expression, val right: Expression, inverse: Boolean)
4343
extends TernaryExpression with CodegenFallback with Serializable {
44+
def targetExp = left
45+
def maskExp = middle
46+
def maskValueExp = right
4447

4548
@transient protected lazy val logger = Logger(LoggerFactory.getLogger(getClass.getName))
4649

47-
4850
override def children: Seq[Expression] = Seq(left, middle, right)
4951

5052
override def checkInputDataTypes(): TypeCheckResult = {
51-
if (!tileExtractor.isDefinedAt(left.dataType)) {
52-
TypeCheckFailure(s"Input type '${left.dataType}' does not conform to a raster type.")
53-
} else if (!tileExtractor.isDefinedAt(middle.dataType)) {
54-
TypeCheckFailure(s"Input type '${middle.dataType}' does not conform to a raster type.")
55-
} else if (!intArgExtractor.isDefinedAt(right.dataType)) {
56-
TypeCheckFailure(s"Input type '${right.dataType}' isn't an integral type.")
53+
if (!tileExtractor.isDefinedAt(targetExp.dataType)) {
54+
TypeCheckFailure(s"Input type '${targetExp.dataType}' does not conform to a raster type.")
55+
} else if (!tileExtractor.isDefinedAt(maskExp.dataType)) {
56+
TypeCheckFailure(s"Input type '${maskExp.dataType}' does not conform to a raster type.")
57+
} else if (!intArgExtractor.isDefinedAt(maskValueExp.dataType)) {
58+
TypeCheckFailure(s"Input type '${maskValueExp.dataType}' isn't an integral type.")
5759
} else TypeCheckSuccess
5860
}
5961
override def dataType: DataType = left.dataType
6062

61-
override protected def nullSafeEval(leftInput: Any, middleInput: Any, rightInput: Any): Any = {
63+
override def makeCopy(newArgs: Array[AnyRef]): Expression = super.makeCopy(newArgs)
64+
65+
override protected def nullSafeEval(targetInput: Any, maskInput: Any, maskValueInput: Any): Any = {
6266
implicit val tileSer = TileUDT.tileSerializer
63-
val (leftTile, leftCtx) = tileExtractor(left.dataType)(row(leftInput))
64-
val (rightTile, rightCtx) = tileExtractor(middle.dataType)(row(middleInput))
67+
val (targetTile, targetCtx) = tileExtractor(targetExp.dataType)(row(targetInput))
68+
val (maskTile, maskCtx) = tileExtractor(maskExp.dataType)(row(maskInput))
6569

66-
if (leftCtx.isEmpty && rightCtx.isDefined)
70+
if (targetCtx.isEmpty && maskCtx.isDefined)
6771
logger.warn(
6872
s"Right-hand parameter '${middle}' provided an extent and CRS, but the left-hand parameter " +
6973
s"'${left}' didn't have any. Because the left-hand side defines output type, the right-hand context will be lost.")
7074

71-
if (leftCtx.isDefined && rightCtx.isDefined && leftCtx != rightCtx)
75+
if (targetCtx.isDefined && maskCtx.isDefined && targetCtx != maskCtx)
7276
logger.warn(s"Both '${left}' and '${middle}' provided an extent and CRS, but they are different. Left-hand side will be used.")
7377

74-
val maskValue = intArgExtractor(right.dataType)(rightInput)
78+
val maskValue = intArgExtractor(maskValueExp.dataType)(maskValueInput)
7579

76-
val masking = if (maskValue.value == 0) Defined(rightTile)
77-
else rightTile
80+
val masking = if (maskValue.value == 0) Defined(maskTile)
81+
else maskTile
7882

7983
val result = if (inverse)
80-
gtInverseMask(leftTile, masking, maskValue.value, raster.NODATA)
84+
gtInverseMask(targetTile, masking, maskValue.value, raster.NODATA)
8185
else
82-
gtMask(leftTile, masking, maskValue.value, raster.NODATA)
86+
gtMask(targetTile, masking, maskValue.value, raster.NODATA)
8387

84-
leftCtx match {
88+
targetCtx match {
8589
case Some(ctx) => ctx.toProjectRasterTile(result).toInternalRow
8690
case None => result.toInternalRow
8791
}
@@ -173,24 +177,26 @@ object Mask {
173177
}
174178

175179
@ExpressionDescription(
176-
usage = "_FUNC_(data, mask, maskValues, inverse) - Generate a tile with the values from `data` tile but where cells in the `mask` tile are in the `maskValues` list, replace the value with NODATA. If `inverse` is true, the cells in `mask` that are not in `maskValues` list become NODATA",
177-
arguments =
178-
"""
179-
180+
usage = "_FUNC_(data, mask, maskValues) - Generate a tile with the values from `data` tile but where cells in the `mask` tile are in the `maskValues` list, replace the value with NODATA.",
181+
arguments = """
182+
Arguments:
183+
* target - tile to mask
184+
* mask - masking definition
185+
* maskValues - sequence of values to consider as masks candidates
180186
""",
181-
examples =
182-
"""
183-
> SELECT _FUNC_(data, mask, array(1, 2, 3), false)
184-
185-
"""
187+
examples = """
188+
Examples:
189+
> SELECT _FUNC_(data, mask, array(1, 2, 3))
190+
..."""
186191
)
187-
case class MaskByValues(dataTile: Expression, maskTile: Expression, maskValues: Expression, inverse: Boolean)
188-
extends Mask(dataTile, IsIn(maskTile, maskValues), inverse, false) {
192+
case class MaskByValues(dataTile: Expression, maskTile: Expression)
193+
extends Mask(dataTile, maskTile, Literal(1), inverse = false) {
194+
def this(dataTile: Expression, maskTile: Expression, maskValues: Expression) =
195+
this(dataTile, IsIn(maskTile, maskValues))
189196
override def nodeName: String = "rf_mask_by_values"
190197
}
191198
object MaskByValues {
192-
def apply(dataTile: Column, maskTile: Column, maskValues: Column, inverse: Column): TypedColumn[Any, Tile] =
193-
new Column(MaskByValues(dataTile.expr, maskTile.expr, maskValues.expr, inverse.expr)).as[Tile]
199+
def apply(dataTile: Column, maskTile: Column, maskValues: Column): TypedColumn[Any, Tile] =
200+
new Column(MaskByValues(dataTile.expr, IsIn(maskTile, maskValues).expr)).as[Tile]
194201
}
195-
196202
}

core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala

Lines changed: 7 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ import geotrellis.raster._
2828
import geotrellis.raster.render.ColorRamps
2929
import geotrellis.raster.testkit.RasterMatchers
3030
import javax.imageio.ImageIO
31-
import org.apache.spark.sql.Encoders
31+
import org.apache.spark.sql.{Column, Encoders, TypedColumn}
3232
import org.apache.spark.sql.functions._
3333
import org.locationtech.rasterframes.expressions.accessors.ExtractTile
3434
import org.locationtech.rasterframes.model.TileDimensions
@@ -701,7 +701,7 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers {
701701
val withMasked = withMask.withColumn("masked",
702702
rf_inverse_mask_by_value($"tile", $"mask", mask_value))
703703
.withColumn("masked2", rf_mask_by_value($"tile", $"mask", lit(mask_value), true))
704-
704+
withMasked.explain(true)
705705
val result = withMasked.agg(rf_agg_no_data_cells($"tile") < rf_agg_no_data_cells($"masked")).as[Boolean]
706706

707707
result.first() should be(true)
@@ -712,30 +712,28 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers {
712712
checkDocs("rf_inverse_mask_by_value")
713713
}
714714

715-
it("should mask tile by another identified by specified values") {
716-
checkDocs("rf_mask_by_values")
717-
715+
it("should mask tile by another identified by sequence of specified values") {
718716
val squareIncrementingPRT = ProjectedRasterTile(squareIncrementingTile(six.rows), six.extent, six.crs)
719717
val df = Seq((six, squareIncrementingPRT))
720718
.toDF("tile", "mask")
719+
721720
val mask_values = Seq(4, 5, 6, 12)
722721

723722
val withMasked = df.withColumn("masked",
724723
rf_mask_by_values($"tile", $"mask", mask_values))
725724

726-
val expected = squareIncrementingPRT.toArray()
727-
.filter(v ⇒ mask_values.contains(v))
728-
.length
725+
val expected = squareIncrementingPRT.toArray().count(v ⇒ mask_values.contains(v))
729726

730727
val result = withMasked.agg(rf_agg_no_data_cells($"masked") as "masked_nd")
731728
.first()
732729

733730
result.getAs[BigInt](0) should be (expected)
734731

735-
val withMaskedSql = df.selectExpr("rf_mask_by_values(tile, mask, array(4, 5, 6, 12), false) AS masked")
732+
val withMaskedSql = df.selectExpr("rf_mask_by_values(tile, mask, array(4, 5, 6, 12)) AS masked")
736733
val resultSql = withMaskedSql.agg(rf_agg_no_data_cells($"masked")).as[Long]
737734
resultSql.first() should be (expected)
738735

736+
checkDocs("rf_mask_by_values")
739737
}
740738

741739
it("should render ascii art") {

docs/src/main/paradox/reference.md

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -235,12 +235,10 @@ Generate a `tile` with the values from `data_tile`, with NoData in cells where t
235235

236236
### rf_mask_by_values
237237

238-
Tile rf_mask_by_values(Tile data_tile, Tile mask_tile, Array mask_values, bool inverse)
239-
Tile rf_mask_by_values(Tile data_tile, Tile mask_tile, seq mask_values, bool inverse)
238+
Tile rf_mask_by_values(Tile data_tile, Tile mask_tile, Array mask_values)
239+
Tile rf_mask_by_values(Tile data_tile, Tile mask_tile, seq mask_values)
240240

241-
Generate a `tile` with the values from `data_tile`, with NoData in cells where the `mask_tile` is in the `mask_values` Array or list. `mask_values` can be a [`pyspark.sql.ArrayType`][Array] or a `list`.
242-
243-
`inverse` is a literal not a Column. If it is True, the `data_tile` cells are set to NoData where the `mask_tile` cells are __not__ in `mask_values`.
241+
Generate a `tile` with the values from `data_tile`, with NoData in cells where the `mask_tile` is in the `mask_values` Array or list. `mask_values` can be a [`pyspark.sql.ArrayType`][Array] or a `list`.
244242

245243
### rf_inverse_mask
246244

pyrasterframes/src/main/python/pyrasterframes/rasterfunctions.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -474,18 +474,17 @@ def rf_mask_by_value(data_tile, mask_tile, mask_value, inverse=False):
474474
return Column(jfcn(_to_java_column(data_tile), _to_java_column(mask_tile), _to_java_column(mask_value), inverse))
475475

476476

477-
def rf_mask_by_values(data_tile, mask_tile, mask_values, inverse=False):
477+
def rf_mask_by_values(data_tile, mask_tile, mask_values):
478478
"""Generate a tile with the values from `data_tile`, but where cells in the `mask_tile` are in the `mask_values`
479479
list, replace the value with NODATA.
480-
If `inverse` is True, the cells in `mask_tile` that are not in `mask_values` list become NODATA
481480
"""
482481
from pyspark.sql.functions import array as sql_array
483482
if isinstance(mask_values, list):
484483
mask_values = sql_array([lit(v) for v in mask_values])
485484

486485
jfcn = RFContext.active().lookup('rf_mask_by_values')
487486
col_args = [_to_java_column(c) for c in [data_tile, mask_tile, mask_values]]
488-
return Column(jfcn(*col_args, inverse))
487+
return Column(jfcn(*col_args))
489488

490489

491490
def rf_inverse_mask_by_value(data_tile, mask_tile, mask_value):

pyrasterframes/src/main/python/tests/RasterFunctionsTests.py

Lines changed: 0 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -249,21 +249,13 @@ def test_mask_by_values(self):
249249
tile = Tile(np.random.randint(1, 100, (5, 5)), CellType.uint8())
250250
mask_tile = Tile(np.array(range(1, 26), 'uint8').reshape(5, 5))
251251
expected_diag_nd = Tile(np.ma.masked_array(tile.cells, mask=np.eye(5)))
252-
expected_off_diag_nd = Tile(np.ma.masked_array(tile.cells, mask=1 - np.eye(5)))
253252

254253
df = self.spark.createDataFrame([Row(t=tile, m=mask_tile)]) \
255254
.select(rf_mask_by_values('t', 'm', [0, 6, 12, 18, 24])) # values on the diagonal
256255
result0 = df.first()
257256
# assert_equal(result0[0].cells, expected_diag_nd)
258257
self.assertTrue(result0[0] == expected_diag_nd)
259258

260-
# mask values off the diagonal! (inverse=True)
261-
result1 = self.spark.createDataFrame([Row(t=tile, m=mask_tile)]) \
262-
.select(rf_mask_by_values('t', 'm', [0, 6, 12, 18, 24], True)) \
263-
.first()
264-
# assert_equal(result1[0].cells, expected_off_diag_nd)
265-
self.assertTrue(result1[0] == expected_off_diag_nd)
266-
267259
def test_mask(self):
268260
from pyspark.sql import Row
269261
from pyrasterframes.rf_types import Tile, CellType

0 commit comments

Comments
 (0)