Skip to content

Commit 3f6e1e0

Browse files
committed
Expand rf_mask_by_value[s] signatures to match python api, expand unit tests
Signed-off-by: Jason T. Brown <[email protected]>
1 parent bf6326a commit 3f6e1e0

File tree

2 files changed

+55
-3
lines changed

2 files changed

+55
-3
lines changed

core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,15 +304,43 @@ trait RasterFunctions {
304304
if (!inverse) Mask.MaskByValue(sourceTile, maskTile, maskValue)
305305
else Mask.InverseMaskByValue(sourceTile, maskTile, maskValue)
306306

307+
/** Where the `maskTile` equals `maskValue`, replace values in the source tile with `NoData` */
308+
def rf_mask_by_value(sourceTile: Column, maskTile: Column, maskValue: Int, inverse: Boolean): TypedColumn[Any, Tile] =
309+
rf_mask_by_value(sourceTile, maskTile, lit(maskValue), inverse)
310+
311+
/** Where the `maskTile` equals `maskValue`, replace values in the source tile with `NoData` */
312+
def rf_mask_by_value(sourceTile: Column, maskTile: Column, maskValue: Int): TypedColumn[Any, Tile] =
313+
rf_mask_by_value(sourceTile, maskTile, maskValue, false)
314+
307315
/** Generate a tile with the values from `data_tile`, but where cells in the `mask_tile` are in the `mask_values`
308316
list, replace the value with NODATA.
309317
If `inverse` is True, the cells in `mask_tile` that are not in `mask_values` list become NODATA */
310-
def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Column, inverse: Boolean=false): TypedColumn[Any, Tile] =
318+
def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Column, inverse: Boolean): TypedColumn[Any, Tile] =
311319
if (!inverse)
312320
Mask.MaskByValue(sourceTile, rf_local_is_in(maskTile, maskValues), lit(1))
313321
else
314322
Mask.InverseMaskByValue(sourceTile, rf_local_is_in(maskTile, maskValues), lit(0))
315323

324+
/** Generate a tile with the values from `data_tile`, but where cells in the `mask_tile` are in the `mask_values`
325+
list, replace the value with NODATA. */
326+
def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Column): TypedColumn[Any, Tile] =
327+
rf_mask_by_values(sourceTile, maskTile, maskValues, false)
328+
329+
/** Generate a tile with the values from `data_tile`, but where cells in the `mask_tile` are in the `mask_values`
330+
list, replace the value with NODATA.
331+
If `inverse` is True, the cells in `mask_tile` that are not in `mask_values` list become NODATA */
332+
def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Iterable[Int], inverse: Boolean): TypedColumn[Any, Tile] = {
333+
import org.apache.spark.sql.functions.array
334+
val valuesCol: Column = array(maskValues.map(lit).toSeq: _*)
335+
rf_mask_by_values(sourceTile, maskTile, valuesCol, inverse)
336+
}
337+
338+
/** Generate a tile with the values from `data_tile`, but where cells in the `mask_tile` are in the `mask_values`
339+
list, replace the value with NODATA.
340+
If `inverse` is True, the cells in `mask_tile` that are not in `mask_values` list become NODATA */
341+
def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Iterable[Int]): TypedColumn[Any, Tile] =
342+
rf_mask_by_values(sourceTile, maskTile, maskValues, false)
343+
316344
/** Where the `maskTile` does **not** contain `NoData`, replace values in the source tile with `NoData` */
317345
def rf_inverse_mask(sourceTile: Column, maskTile: Column): TypedColumn[Any, Tile] =
318346
Mask.InverseMaskByDefined(sourceTile, maskTile)
@@ -321,6 +349,10 @@ trait RasterFunctions {
321349
def rf_inverse_mask_by_value(sourceTile: Column, maskTile: Column, maskValue: Column): TypedColumn[Any, Tile] =
322350
Mask.InverseMaskByValue(sourceTile, maskTile, maskValue)
323351

352+
/** Where the `maskTile` does **not** equal `maskValue`, replace values in the source tile with `NoData` */
353+
def rf_inverse_mask_by_value(sourceTile: Column, maskTile: Column, maskValue: Int): TypedColumn[Any, Tile] =
354+
Mask.InverseMaskByValue(sourceTile, maskTile, lit(maskValue))
355+
324356
/** Create a tile where cells in the grid defined by cols, rows, and bounds are filled with the given value. */
325357
def rf_rasterize(geometry: Column, bounds: Column, value: Column, cols: Int, rows: Int): TypedColumn[Any, Tile] =
326358
withTypedAlias("rf_rasterize", geometry)(

core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -694,19 +694,39 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers {
694694
rf_local_multiply(rf_convert_cell_type(
695695
rf_local_greater($"tile", 50),
696696
"uint8"),
697-
lit(mask_value)
697+
mask_value
698698
)
699699
)
700700

701701
val withMasked = withMask.withColumn("masked",
702-
rf_inverse_mask_by_value($"tile", $"mask", lit(mask_value)))
702+
rf_inverse_mask_by_value($"tile", $"mask", mask_value))
703+
.withColumn("masked2", rf_mask_by_value($"tile", $"mask", lit(mask_value), true))
703704

704705
val result = withMasked.agg(rf_agg_no_data_cells($"tile") < rf_agg_no_data_cells($"masked")).as[Boolean]
705706

706707
result.first() should be(true)
708+
709+
val result2 = withMasked.agg(rf_agg_no_data_cells($"tile") < rf_agg_no_data_cells($"masked2")).as[Boolean]
710+
result2.first() should be(true)
711+
707712
checkDocs("rf_inverse_mask_by_value")
708713
}
709714

715+
it("should mask tile by another identified by specified values") {
716+
val squareIncrementingPRT = ProjectedRasterTile(squareIncrementingTile(randPRT.rows), randPRT.extent, randPRT.crs)
717+
val df = Seq((randPRT, squareIncrementingPRT))
718+
.toDF("tile", "mask")
719+
val mask_values = Seq(4, 5, 6, 12)
720+
721+
val withMasked = df.withColumn("masked",
722+
rf_mask_by_values($"tile", $"mask", mask_values))
723+
724+
val result = withMasked.agg(rf_agg_no_data_cells($"masked") as "nd").as[Long]
725+
726+
result.first() should be(mask_values.length)
727+
checkDocs("rf_mask_by_values")
728+
}
729+
710730
it("should render ascii art") {
711731
val df = Seq[Tile](ProjectedRasterTile(TestData.l8Labels)).toDF("tile")
712732
val r1 = df.select(rf_render_ascii($"tile"))

0 commit comments

Comments
 (0)