Skip to content

Commit 5529d48

Browse files
committed
Add mask bits python api and unit test
Signed-off-by: Jason T. Brown <[email protected]>
1 parent b67049a commit 5529d48

File tree

4 files changed

+97
-2
lines changed

4 files changed

+97
-2
lines changed

core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1264,7 +1264,7 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers {
12641264
checker("cloud_no", cirrus, false)
12651265
}
12661266

1267-
it("should have SQL equivalent"){
1267+
it("mask bits should have SQL equivalent"){
12681268

12691269
df.createOrReplaceTempView("df_maskbits")
12701270

pyrasterframes/src/main/python/pyrasterframes/rasterfunctions.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,39 @@ def rf_inverse_mask_by_value(data_tile, mask_tile, mask_value):
495495
return _apply_column_function('rf_inverse_mask_by_value', data_tile, mask_tile, mask_value)
496496

497497

498+
def rf_mask_by_bit(data_tile, mask_tile, bit_position, value_to_mask):
499+
"""Applies a mask using bit values in the `mask_tile`. Working from the right, extract the bit at `bitPosition` from the `maskTile`. In all locations where these are equal to the `valueToMask`, the returned tile is set to NoData, else the original `dataTile` cell value."""
500+
if isinstance(bit_position, int):
501+
bit_position = lit(bit_position)
502+
if isinstance(value_to_mask, (int, float, bool)):
503+
value_to_mask = lit(bool(value_to_mask))
504+
return _apply_column_function('rf_mask_by_bit', data_tile, mask_tile, bit_position, value_to_mask)
505+
506+
507+
def rf_mask_by_bits(data_tile, mask_tile, start_bit, num_bits, values_to_mask):
508+
"""Applies a mask from blacklisted bit values in the `mask_tile`. Working from the right, the bits from `start_bit` to `start_bit + num_bits` are @ref:[extracted](reference.md#rf_local_extract_bits) from cell values of the `mask_tile`. In all locations where these are in the `mask_values`, the returned tile is set to NoData; otherwise the original `tile` cell value is returned."""
509+
if isinstance(start_bit, int):
510+
start_bit = lit(start_bit)
511+
if isinstance(num_bits, int):
512+
num_bits = lit(num_bits)
513+
if isinstance(values_to_mask, (tuple, list)):
514+
from pyspark.sql.functions import array
515+
values_to_mask = array([lit(v) for v in values_to_mask])
516+
517+
return _apply_column_function('rf_mask_by_bits', data_tile, mask_tile, start_bit, num_bits, values_to_mask)
518+
519+
520+
def rf_local_extract_bits(tile, start_bit, num_bits=1):
521+
"""Extract value from specified bits of the cells' underlying binary data.
522+
* `startBit` is the first bit to consider, working from the right. It is zero indexed.
523+
* `numBits` is the number of bits to take moving further to the left. """
524+
if isinstance(start_bit, int):
525+
start_bit = lit(bit_position)
526+
if isinstance(num_bits, int):
527+
num_bits = lit(num_bits)
528+
return _apply_column_function('rf_local_extract_bits', tile, start_bit, num_bits)
529+
530+
498531
def rf_local_less(left_tile_col, right_tile_col):
499532
"""Cellwise less than comparison between two tiles"""
500533
return _apply_column_function('rf_local_less', left_tile_col, right_tile_col)

pyrasterframes/src/main/python/tests/PyRasterFramesTests.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
from pyrasterframes.rf_types import *
2626
from pyspark.sql import SQLContext
2727
from pyspark.sql.functions import *
28+
from pyspark.sql import Row
29+
2830
from . import TestEnvironment
2931

3032

@@ -139,6 +141,22 @@ def test_tile_udt_serialization(self):
139141
long_trip = df.first()["tile"]
140142
self.assertEqual(long_trip, a_tile)
141143

144+
def test_masked_deser(self):
145+
t = Tile(np.array([[1, 2, 3,], [4, 5, 6], [7, 8, 9]]),
146+
CellType('uint8'))
147+
148+
df = self.spark.createDataFrame([Row(t=t)])
149+
roundtrip = df.select(rf_mask_by_value('t',
150+
rf_local_greater('t', lit(6)),
151+
1)) \
152+
.first()[0]
153+
self.assertEqual(
154+
roundtrip.cells.mask.sum(),
155+
3,
156+
f"Expected {3} nodata values but found Tile"
157+
f"{roundtrip}"
158+
)
159+
142160
def test_udf_on_tile_type_input(self):
143161
import numpy.testing
144162
df = self.spark.read.raster(self.img_uri)
@@ -248,7 +266,6 @@ def less_pi(t):
248266
class TileOps(TestEnvironment):
249267

250268
def setUp(self):
251-
from pyspark.sql import Row
252269
# convenience so we can assert around Tile() == Tile()
253270
self.t1 = Tile(np.array([[1, 2],
254271
[3, 4]]), CellType.int8().with_no_data_value(3))

pyrasterframes/src/main/python/tests/RasterFunctionsTests.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,44 @@ def test_mask_by_values(self):
256256
# assert_equal(result0[0].cells, expected_diag_nd)
257257
self.assertTrue(result0[0] == expected_diag_nd)
258258

259+
def test_mask_bits(self):
260+
t = Tile(42 * np.ones((4, 4), 'uint16'), CellType.uint16())
261+
# with a varitey of known values
262+
mask = Tile(np.array([
263+
[1, 1, 2720, 2720],
264+
[1, 6816, 6816, 2756],
265+
[2720, 2720, 6900, 2720],
266+
[2720, 6900, 6816, 1]
267+
]), CellType('uint16raw'))
268+
269+
df = self.spark.createDataFrame([Row(t=t, mask=mask)])
270+
271+
# removes fill value 1
272+
mask_fill_df = df.select(rf_mask_by_bit('t', 'mask', 0, True).alias('mbb'))
273+
mask_fill_tile = mask_fill_df.first()['mbb']
274+
275+
self.assertTrue(mask_fill_tile.cell_type.has_no_data())
276+
277+
self.assertTrue(
278+
mask_fill_df.select(rf_data_cells('mbb')).first()[0],
279+
16 - 4
280+
)
281+
# Unsure why this fails. mask_fill_tile.cells is all 42 unmasked.
282+
# self.assertEqual(mask_fill_tile.cells.mask.sum(), 4,
283+
# f'Expected {16 - 4} data values but got the masked tile:'
284+
# f'{mask_fill_tile}'
285+
# )
286+
#
287+
# mask out 6816, 6900
288+
mask_med_hi_cir = df.withColumn('mask_cir_mh',
289+
rf_mask_by_bits('t', 'mask', 11, 2, [2, 3])) \
290+
.first()['mask_cir_mh'].cells
291+
292+
self.assertEqual(
293+
mask_med_hi_cir.mask.sum(),
294+
5
295+
)
296+
259297
def test_mask(self):
260298
from pyspark.sql import Row
261299
from pyrasterframes.rf_types import Tile, CellType
@@ -282,6 +320,13 @@ def test_mask(self):
282320
nd_result = df.select(rf_no_data_cells('masked_t')).first()[0]
283321
self.assertEqual(nd_result, expected_no_data_values)
284322

323+
# deser of tile is correct
324+
self.assertEqual(
325+
df.select('masked_t').first()[0].cells.compressed().size,
326+
expected_data_values
327+
)
328+
329+
285330
def test_resample(self):
286331
from pyspark.sql.functions import lit
287332
result = self.rf.select(

0 commit comments

Comments
 (0)