Skip to content

Commit abb7add

Browse files
committed
Fix for both masking by def and value; expand code comments; update tests
Signed-off-by: Jason T. Brown <[email protected]>
1 parent f4f4e9a commit abb7add

File tree

3 files changed

+31
-12
lines changed

3 files changed

+31
-12
lines changed

core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Mask.scala

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@ package org.locationtech.rasterframes.expressions.transformers
2424
import com.typesafe.scalalogging.Logger
2525
import geotrellis.raster
2626
import geotrellis.raster.Tile
27-
import geotrellis.raster.mapalgebra.local.{Defined, InverseMask gtInverseMask, Mask gtMask}
27+
import geotrellis.raster.mapalgebra.local.{Undefined, InverseMask gtInverseMask, Mask gtMask}
2828
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2929
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
3030
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
3131
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, Literal, TernaryExpression}
3232
import org.apache.spark.sql.rf.TileUDT
33-
import org.apache.spark.sql.types.{DataType, NullType}
33+
import org.apache.spark.sql.types.DataType
3434
import org.apache.spark.sql.{Column, TypedColumn}
3535
import org.locationtech.rasterframes.encoders.CatalystSerializer._
3636
import org.locationtech.rasterframes.expressions.DynamicExtractors._
@@ -43,10 +43,10 @@ import org.slf4j.LoggerFactory
4343
* @param left a tile of data values, with valid nodata cell type
4444
* @param middle a tile indicating locations to set to nodata
4545
* @param right optional, cell values in the `middle` tile indicating locations to set NoData
46-
* @param defined if true, consider NoData in the `middle` as the locations to mask; else use `right` valued cells
46+
* @param undefined if true, consider NoData in the `middle` as the locations to mask; else use `right` valued cells
4747
* @param inverse if true, and defined is true, set `left` to NoData where `middle` is NOT nodata
4848
*/
49-
abstract class Mask(val left: Expression, val middle: Expression, val right: Expression, defined: Boolean, inverse: Boolean)
49+
abstract class Mask(val left: Expression, val middle: Expression, val right: Expression, undefined: Boolean, inverse: Boolean)
5050
extends TernaryExpression with CodegenFallback with Serializable {
5151
// aliases.
5252
def targetExp = left
@@ -85,9 +85,12 @@ abstract class Mask(val left: Expression, val middle: Expression, val right: Exp
8585

8686
val maskValue = intArgExtractor(maskValueExp.dataType)(maskValueInput)
8787

88-
val masking = if (defined) Defined(maskTile)
89-
else maskTile.localEqual(maskValue.value)
88+
// Get a tile where values of 1 indicate locations to set to ND in the target tile
89+
// When `undefined` is true, setting targetTile locations to ND for ND locations of the `maskTile`
90+
val masking = if (undefined) Undefined(maskTile)
91+
else maskTile.localEqual(maskValue.value) // Otherwise if `maskTile` locations equal `maskValue`, set location to ND
9092

93+
// apply the `masking` where values are 1 set to ND (possibly inverted!)
9194
val result = if (inverse)
9295
gtInverseMask(targetTile, masking, 1, raster.NODATA)
9396
else

core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,17 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers {
642642
checkDocs("rf_mask")
643643
}
644644

645+
it("should mask with expected results") {
646+
val df = Seq((byteArrayTile, maskingTile)).toDF("tile", "mask")
647+
648+
val withMasked = df.withColumn("masked",
649+
rf_mask($"tile", $"mask"))
650+
651+
val result: Tile = withMasked.select($"masked").as[Tile].first()
652+
653+
result.localUndefined().toArray() should be (maskingTile.localUndefined().toArray())
654+
}
655+
645656
it("should mask without mutating cell type") {
646657
val result = Seq((byteArrayTile, maskingTile))
647658
.toDF("tile", "mask")
@@ -1227,6 +1238,7 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers {
12271238
df.createOrReplaceTempView("df_maskbits")
12281239

12291240
val maskedCol = "cloud_conf_med"
1241+
// this is the example in the docs
12301242
val result = spark.sql(
12311243
s"""
12321244
|SELECT rf_mask_by_values(

pyrasterframes/src/main/python/tests/RasterFunctionsTests.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -261,19 +261,23 @@ def test_mask(self):
261261
from pyrasterframes.rf_types import Tile, CellType
262262

263263
np.random.seed(999)
264-
ma = np.ma.array(np.random.randint(0, 10, (5, 5), dtype='int8'), mask=np.random.rand(5, 5) > 0.7)
264+
# importantly exclude 0 from teh range because that's the nodata value for the `data_tile`'s cell type
265+
ma = np.ma.array(np.random.randint(1, 10, (5, 5), dtype='int8'), mask=np.random.rand(5, 5) > 0.7)
265266
expected_data_values = ma.compressed().size
266267
expected_no_data_values = ma.size - expected_data_values
267268
self.assertTrue(expected_data_values > 0, "Make sure random seed is cooperative ")
268269
self.assertTrue(expected_no_data_values > 0, "Make sure random seed is cooperative ")
269270

270-
df = self.spark.createDataFrame([
271-
Row(t=Tile(np.ones(ma.shape, ma.dtype)), m=Tile(ma))
272-
])
271+
data_tile = Tile(np.ones(ma.shape, ma.dtype), CellType.uint8())
272+
273+
df = self.spark.createDataFrame([Row(t=data_tile, m=Tile(ma))]) \
274+
.withColumn('masked_t', rf_mask('t', 'm'))
273275

274-
df = df.withColumn('masked_t', rf_mask('t', 'm'))
275276
result = df.select(rf_data_cells('masked_t')).first()[0]
276-
self.assertEqual(result, expected_data_values)
277+
self.assertEqual(result, expected_data_values,
278+
f"Masked tile should have {expected_data_values} data values but found: {df.select('masked_t').first()[0].cells}."
279+
f"Original data: {data_tile.cells}"
280+
f"Masked by {ma}")
277281

278282
nd_result = df.select(rf_no_data_cells('masked_t')).first()[0]
279283
self.assertEqual(nd_result, expected_no_data_values)

0 commit comments

Comments
 (0)