Skip to content

Commit acab757

Browse files
committed
Fix test for rf_mask_by_values and add SQL test
Signed-off-by: Jason T. Brown <[email protected]>
1 parent 3f6e1e0 commit acab757

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ trait RasterFunctions {
319319
if (!inverse)
320320
Mask.MaskByValue(sourceTile, rf_local_is_in(maskTile, maskValues), lit(1))
321321
else
322-
Mask.InverseMaskByValue(sourceTile, rf_local_is_in(maskTile, maskValues), lit(0))
322+
Mask.MaskByValue(sourceTile, rf_local_is_in(maskTile, maskValues), lit(0))
323323

324324
/** Generate a tile with the values from `data_tile`, but where cells in the `mask_tile` are in the `mask_values`
325325
list, replace the value with NODATA. */

core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -713,18 +713,27 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers {
713713
}
714714

715715
it("should mask tile by another identified by specified values") {
716-
val squareIncrementingPRT = ProjectedRasterTile(squareIncrementingTile(randPRT.rows), randPRT.extent, randPRT.crs)
717-
val df = Seq((randPRT, squareIncrementingPRT))
716+
val squareIncrementingPRT = ProjectedRasterTile(squareIncrementingTile(six.rows), six.extent, six.crs)
717+
val df = Seq((six, squareIncrementingPRT))
718718
.toDF("tile", "mask")
719719
val mask_values = Seq(4, 5, 6, 12)
720720

721721
val withMasked = df.withColumn("masked",
722722
rf_mask_by_values($"tile", $"mask", mask_values))
723723

724-
val result = withMasked.agg(rf_agg_no_data_cells($"masked") as "nd").as[Long]
724+
val expected = squareIncrementingPRT.toArray()
725+
.filter(v mask_values.contains(v))
726+
.length
727+
728+
val result = withMasked.agg(rf_agg_no_data_cells($"masked") as "masked_nd")
729+
.first()
730+
731+
result.getAs[BigInt](0) should be (expected)
732+
733+
val withMaskedSql = df.selectExpr("rf_mask_by_values(tile, mask, array(4, 5, 6, 12), false) AS masked")
734+
val resultSql = withMaskedSql.agg(rf_agg_no_data_cells($"masked")).as[Long]
735+
resultSql.first() should be (expected)
725736

726-
result.first() should be(mask_values.length)
727-
checkDocs("rf_mask_by_values")
728737
}
729738

730739
it("should render ascii art") {
@@ -1018,6 +1027,5 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers {
10181027
val e0Result = df.select($"in_expect_0").as[Tile].first()
10191028
e0Result.toArray() should contain only (0)
10201029

1021-
// lazy val invalid = df.select(rf_local_is_in($"t", lit("foobar"))).as[Tile].first()
10221030
}
10231031
}

0 commit comments

Comments
 (0)