Skip to content

Commit 103dc6c

Browse files
committed
Update mask functions with inverse arg, add rf_mask_by_values function
Signed-off-by: Jason T. Brown <[email protected]>
1 parent 458c027 commit 103dc6c

File tree

5 files changed

+126
-40
lines changed

5 files changed

+126
-40
lines changed

core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -292,12 +292,26 @@ trait RasterFunctions {
292292
}
293293

294294
/** Where the rf_mask tile contains NODATA, replace values in the source tile with NODATA */
295-
def rf_mask(sourceTile: Column, maskTile: Column): TypedColumn[Any, Tile] =
296-
Mask.MaskByDefined(sourceTile, maskTile)
295+
def rf_mask(sourceTile: Column, maskTile: Column): TypedColumn[Any, Tile] =
  rf_mask(sourceTile, maskTile, inverse = false)
296+
297+
/** Where the `maskTile` contains NODATA, replace values in the source tile with NODATA.
  * When `inverse` is true, `Mask.InverseMaskByDefined` is applied instead of `Mask.MaskByDefined`. */
def rf_mask(sourceTile: Column, maskTile: Column, inverse: Boolean=false): TypedColumn[Any, Tile] =
  if (inverse) Mask.InverseMaskByDefined(sourceTile, maskTile)
  else Mask.MaskByDefined(sourceTile, maskTile)
297301

298302
/** Where the `maskTile` equals `maskValue`, replace values in the source tile with `NoData` */
299-
def rf_mask_by_value(sourceTile: Column, maskTile: Column, maskValue: Column): TypedColumn[Any, Tile] =
300-
Mask.MaskByValue(sourceTile, maskTile, maskValue)
303+
def rf_mask_by_value(sourceTile: Column, maskTile: Column, maskValue: Column, inverse: Boolean=false): TypedColumn[Any, Tile] =
  // `inverse` selects the inverse expression; the three column arguments are passed through unchanged
  if (inverse) Mask.InverseMaskByValue(sourceTile, maskTile, maskValue)
  else Mask.MaskByValue(sourceTile, maskTile, maskValue)
306+
307+
/** Generate a tile with the values from `sourceTile`, but where cells in the `maskTile` are in the `maskValues`
  * array, replace the value with NODATA.
  * If `inverse` is true, the cells in `maskTile` that are not in `maskValues` become NODATA instead. */
def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Column, inverse: Boolean=false): TypedColumn[Any, Tile] = {
  // Membership tile: cell value 1 where the `maskTile` cell is in `maskValues`.
  val isIn = rf_local_is_in(maskTile, maskValues)
  // Both branches compare against 1. The inverse expression masks cells that do NOT equal the value,
  // so comparing against 0 there would mask the listed cells again instead of the unlisted ones.
  if (inverse) Mask.InverseMaskByValue(sourceTile, isIn, lit(1))
  else Mask.MaskByValue(sourceTile, isIn, lit(1))
}
301315

302316
/** Where the `maskTile` does **not** contain `NoData`, replace values in the source tile with `NoData` */
303317
def rf_inverse_mask(sourceTile: Column, maskTile: Column): TypedColumn[Any, Tile] =

docs/src/main/paradox/reference.md

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -215,27 +215,45 @@ Masking is a raster operation that sets specific cells to NoData based on the va
215215

216216
### rf_mask
217217

218-
Tile rf_mask(Tile tile, Tile mask)
218+
Tile rf_mask(Tile tile, Tile mask, bool inverse)
219219

220220
Where the `mask` contains NoData, replace values in the `tile` with NoData.
221221

222222
Returned `tile` cell type will be coerced to one supporting NoData if it does not already.
223223

224+
`inverse` is a literal, not a Column. If `inverse` is true, return the `tile` with NoData in locations where the `mask` _does not_ contain NoData. Equivalent to @ref:[`rf_inverse_mask`](reference.md#rf-inverse-mask).
225+
224226
See also @ref:[`rf_rasterize`](reference.md#rf-rasterize).
225227

228+
### rf_mask_by_value
229+
230+
Tile rf_mask_by_value(Tile data_tile, Tile mask_tile, Int mask_value, bool inverse)
231+
232+
Generate a `tile` with the values from `data_tile`, with NoData in cells where the `mask_tile` is equal to `mask_value`.
233+
234+
`inverse` is a literal, not a Column. If `inverse` is true, return the `data_tile` with NoData in locations where the `mask_tile` value is _not equal_ to `mask_value`. Equivalent to @ref:[`rf_inverse_mask_by_value`](reference.md#rf-inverse-mask-by-value).
235+
236+
### rf_mask_by_values
237+
238+
Tile rf_mask_by_values(Tile data_tile, Tile mask_tile, Array mask_values, bool inverse)
239+
Tile rf_mask_by_values(Tile data_tile, Tile mask_tile, list mask_values, bool inverse)
240+
241+
Generate a `tile` with the values from `data_tile`, with NoData in cells where the `mask_tile` is in the `mask_values` Array or list. `mask_values` can be a [`pyspark.sql.ArrayType`][Array] or a `list`.
242+
243+
`inverse` is a literal, not a Column. If `inverse` is true, the `data_tile` cells are set to NoData where the `mask_tile` cells are __not__ in `mask_values`.
226244

227245
### rf_inverse_mask
228246

229247
Tile rf_inverse_mask(Tile tile, Tile mask)
230248

231249
Where the `mask` _does not_ contain NoData, replace values in `tile` with NoData.
232250

233-
### rf_mask_by_value
234251

235-
Tile rf_mask_by_value(Tile data_tile, Tile mask_tile, Int mask_value)
252+
### rf_inverse_mask_by_value
236253

237-
Generate a `tile` with the values from `data_tile`, with NoData in cells where the `mask_tile` is equal to `mask_value`.
254+
Tile rf_inverse_mask_by_value(Tile data_tile, Tile mask_tile, Int mask_value)
238255

256+
Generate a `tile` with the values from `data_tile`, with NoData in cells where the `mask_tile` is not equal to `mask_value`. In other words, only keep `data_tile` cells in locations where the `mask_tile` is equal to `mask_value`.
239257

240258
### rf_is_no_data_tile
241259

docs/src/main/paradox/release-notes.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
* _Breaking_ (potentially): removed `GeoTiffCollectionRelation` due to usage limitation and overlap with `RasterSourceDataSource` functionality.
88
* Upgraded to Spark 2.4.4
9-
* Add `rf_local_is_in` raster function
9+
* Add `rf_mask_by_values` and `rf_local_is_in` raster functions; add optional `inverse` argument to `rf_mask` functions
1010

1111
### 0.8.3
1212

pyrasterframes/src/main/python/pyrasterframes/rasterfunctions.py

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
"""
2626
from __future__ import absolute_import
2727
from pyspark.sql.column import Column, _to_java_column
28+
from pyspark.sql.functions import lit
2829
from .rf_context import RFContext
2930
from .rf_types import CellType
3031

@@ -137,20 +138,6 @@ def rf_explode_tiles_sample(sample_frac, seed, *tile_cols):
137138
return Column(jfcn(sample_frac, seed, RFContext.active().list_to_seq(jcols)))
138139

139140

140-
def rf_mask_by_value(data_tile, mask_tile, mask_value):
141-
"""Generate a tile with the values from the data tile, but where cells in the masking tile contain the masking
142-
value, replace the data value with NODATA. """
143-
jfcn = RFContext.active().lookup('rf_mask_by_value')
144-
return Column(jfcn(_to_java_column(data_tile), _to_java_column(mask_tile), _to_java_column(mask_value)))
145-
146-
147-
def rf_inverse_mask_by_value(data_tile, mask_tile, mask_value):
148-
"""Generate a tile with the values from the data tile, but where cells in the masking tile do not contain the
149-
masking value, replace the data value with NODATA. """
150-
jfcn = RFContext.active().lookup('rf_inverse_mask_by_value')
151-
return Column(jfcn(_to_java_column(data_tile), _to_java_column(mask_tile), _to_java_column(mask_value)))
152-
153-
154141
def _apply_scalar_to_tile(name, tile_col, scalar):
    """Call the function registered as `name` with `tile_col` (as a Java column) and the raw `scalar`."""
    scalar_fn = RFContext.active().lookup(name)
    java_tile = _to_java_column(tile_col)
    return Column(scalar_fn(java_tile, scalar))
@@ -270,14 +257,16 @@ def rf_local_data(tile_col):
270257
"""Return a tile with zeros where the input is NoData, otherwise one."""
271258
return _apply_column_function('rf_local_data', tile_col)
272259

260+
273261
def rf_local_is_in(tile_col, array):
    """Return a tile with cell values of 1 where the `tile_col` cell is in the provided array.

    A Python `list` passed as `array` is first turned into an array column of literals.
    """
    from pyspark.sql.functions import array as sql_array

    values = sql_array([lit(v) for v in array]) if isinstance(array, list) else array
    return _apply_column_function('rf_local_is_in', tile_col, values)
280268

269+
281270
def _apply_column_function(name, *args):
282271
jfcn = RFContext.active().lookup(name)
283272
jcols = [_to_java_column(arg) for arg in args]
@@ -459,16 +448,54 @@ def rf_agg_local_stats(tile_col):
459448
return _apply_column_function('rf_agg_local_stats', tile_col)
460449

461450

462-
def rf_mask(src_tile_col, mask_tile_col):
463-
"""Where the rf_mask (second) tile contains NODATA, replace values in the source (first) tile with NODATA."""
464-
return _apply_column_function('rf_mask', src_tile_col, mask_tile_col)
451+
def rf_mask(src_tile_col, mask_tile_col, inverse=False):
    """Where the mask (second) tile contains NODATA, replace values in the source (first) tile with NODATA.

    If `inverse` is true, replaces values in the source tile with NODATA where the mask tile contains valid data;
    this delegates to `rf_inverse_mask`.
    """
    if inverse:
        # The result must be returned: a bare call here made `rf_mask(..., inverse=True)` evaluate to None.
        return rf_inverse_mask(src_tile_col, mask_tile_col)
    return _apply_column_function('rf_mask', src_tile_col, mask_tile_col)
465459

466460

467461
def rf_inverse_mask(src_tile_col, mask_tile_col):
    """Replace values in the source (first) tile with NODATA wherever the mask (second) tile
    DOES NOT contain NODATA."""
    masked = _apply_column_function('rf_inverse_mask', src_tile_col, mask_tile_col)
    return masked
470465

471466

467+
def rf_mask_by_value(data_tile, mask_tile, mask_value, inverse=False):
    """Generate a tile with the values from the data tile, but where cells in the masking tile contain the masking
    value, replace the data value with NODATA.

    A plain `int` or `float` `mask_value` is wrapped in a literal column. `inverse` is passed through
    unchanged as the fourth argument of the looked-up `rf_mask_by_value` function.
    """
    value_col = lit(mask_value) if isinstance(mask_value, (int, float)) else mask_value
    mask_fn = RFContext.active().lookup('rf_mask_by_value')

    java_args = [_to_java_column(c) for c in (data_tile, mask_tile, value_col)]
    java_args.append(inverse)
    return Column(mask_fn(*java_args))
475+
476+
477+
def rf_mask_by_values(data_tile, mask_tile, mask_values, inverse=False):
    """Generate a tile with the values from `data_tile`, with NODATA wherever the `mask_tile` cell is in
    `mask_values`.

    `mask_values` may be a Python `list`, which is converted to an array column of literals.
    If `inverse` is True, the cells in `mask_tile` that are not in the `mask_values` list become NODATA.
    """
    from pyspark.sql.functions import array as sql_array

    values_col = mask_values
    if isinstance(values_col, list):
        values_col = sql_array([lit(v) for v in values_col])

    mask_fn = RFContext.active().lookup('rf_mask_by_values')
    return Column(mask_fn(
        _to_java_column(data_tile),
        _to_java_column(mask_tile),
        _to_java_column(values_col),
        inverse,
    ))
489+
490+
491+
def rf_inverse_mask_by_value(data_tile, mask_tile, mask_value):
    """Generate a tile with the values from the data tile, with NODATA wherever the masking tile does not
    contain the masking value.

    A plain `int` or `float` `mask_value` is wrapped in a literal column first.
    """
    value_col = lit(mask_value) if isinstance(mask_value, (int, float)) else mask_value
    return _apply_column_function('rf_inverse_mask_by_value', data_tile, mask_tile, value_col)
497+
498+
472499
def rf_local_less(left_tile_col, right_tile_col):
    """Cellwise less than comparison between two tiles"""
    comparison = _apply_column_function('rf_local_less', left_tile_col, right_tile_col)
    return comparison

pyrasterframes/src/main/python/tests/RasterFunctionsTests.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from pyspark import Row
2525
from pyspark.sql.functions import *
2626

27+
import numpy as np
28+
from numpy.testing import assert_equal
2729

2830
from . import TestEnvironment
2931

@@ -103,7 +105,6 @@ def test_agg_mean(self):
103105
def test_agg_local_mean(self):
104106
from pyspark.sql import Row
105107
from pyrasterframes.rf_types import Tile
106-
import numpy as np
107108

108109
# this is really testing the nodata propagation in the agg local summation
109110
ct = CellType.int8().with_no_data_value(4)
@@ -221,20 +222,51 @@ def test_mask_by_value(self):
221222
rf_local_greater_int(self.rf.tile, 25000),
222223
"uint8"),
223224
lit(mask_value)).alias('mask'))
224-
rf2 = rf1.select(rf1.tile, rf_mask_by_value(rf1.tile, rf1.mask, lit(mask_value)).alias('masked'))
225+
rf2 = rf1.select(rf1.tile, rf_mask_by_value(rf1.tile, rf1.mask, lit(mask_value), False).alias('masked'))
225226
result = rf2.agg(rf_agg_no_data_cells(rf2.tile) < rf_agg_no_data_cells(rf2.masked)) \
226227
.collect()[0][0]
227228
self.assertTrue(result)
228229

229-
rf3 = rf1.select(rf1.tile, rf_inverse_mask_by_value(rf1.tile, rf1.mask, lit(mask_value)).alias('masked'))
230-
result = rf3.agg(rf_agg_no_data_cells(rf3.tile) < rf_agg_no_data_cells(rf3.masked)) \
231-
.collect()[0][0]
232-
self.assertTrue(result)
230+
# note supplying a `int` here, not a column to mask value
231+
rf3 = rf1.select(
232+
rf1.tile,
233+
rf_inverse_mask_by_value(rf1.tile, rf1.mask, mask_value).alias('masked'),
234+
rf_mask_by_value(rf1.tile, rf1.mask, mask_value, True).alias('masked2'),
235+
)
236+
result = rf3.agg(
237+
rf_agg_no_data_cells(rf3.tile) < rf_agg_no_data_cells(rf3.masked),
238+
rf_agg_no_data_cells(rf3.tile) < rf_agg_no_data_cells(rf3.masked2),
239+
) \
240+
.first()
241+
self.assertTrue(result[0])
242+
self.assertTrue(result[1]) # inverse mask arg gives equivalent result
243+
244+
result_equiv_tiles = rf3.select(rf_for_all(rf_local_equal(rf3.masked, rf3.masked2))).first()[0]
245+
self.assertTrue(result_equiv_tiles) # inverse fn and inverse arg produce same Tile
246+
247+
def test_mask_by_values(self):
    """Check `rf_mask_by_values` with a list of mask values, with and without `inverse`."""
    from pyrasterframes.rf_types import Tile, CellType

    tile = Tile(np.random.randint(1, 100, (5, 5)), CellType.uint8())
    # Mask cells hold 1..25 in row-major order, so the main diagonal holds 1, 7, 13, 19, 25.
    mask_tile = Tile(np.array(range(1, 26), 'uint8').reshape(5, 5))
    diagonal_values = [1, 7, 13, 19, 25]
    expected_diag_nd = Tile(np.ma.masked_array(tile.cells, mask=np.eye(5)))
    expected_off_diag_nd = Tile(np.ma.masked_array(tile.cells, mask=1 - np.eye(5)))

    df = self.spark.createDataFrame([Row(t=tile, m=mask_tile)])

    # mask the values on the diagonal
    result0 = df.select(rf_mask_by_values('t', 'm', diagonal_values)).first()
    self.assertTrue(result0[0] == expected_diag_nd)

    # mask values off the diagonal! (inverse=True)
    result1 = df.select(rf_mask_by_values('t', 'm', diagonal_values, True)).first()
    self.assertTrue(result1[0] == expected_off_diag_nd)
233266

234267
def test_mask(self):
235268
from pyspark.sql import Row
236269
from pyrasterframes.rf_types import Tile, CellType
237-
import numpy as np
238270

239271
np.random.seed(999)
240272
ma = np.ma.array(np.random.randint(0, 10, (5, 5), dtype='int8'), mask=np.random.rand(5, 5) > 0.7)
@@ -326,7 +358,6 @@ def test_render_composite(self):
326358
def test_rf_interpret_cell_type_as(self):
327359
from pyspark.sql import Row
328360
from pyrasterframes.rf_types import Tile
329-
import numpy as np
330361

331362
df = self.spark.createDataFrame([
332363
Row(t=Tile(np.array([[1, 3, 4], [5, 0, 3]]), CellType.uint8().with_no_data_value(5)))
@@ -341,8 +372,6 @@ def test_rf_interpret_cell_type_as(self):
341372
def test_rf_local_data_and_no_data(self):
342373
from pyspark.sql import Row
343374
from pyrasterframes.rf_types import Tile
344-
import numpy as np
345-
from numpy.testing import assert_equal
346375

347376
nd = 5
348377
t = Tile(
@@ -363,8 +392,6 @@ def test_rf_local_data_and_no_data(self):
363392
def test_rf_local_is_in(self):
364393
from pyspark.sql.functions import lit, array, col
365394
from pyspark.sql import Row
366-
import numpy as np
367-
from numpy.testing import assert_equal
368395

369396
nd = 5
370397
t = Tile(

0 commit comments

Comments
 (0)