Skip to content

Commit eb899de

Browse files
committed
rf_local_is_in python implementation
Signed-off-by: Jason T. Brown <[email protected]>
1 parent 4b63f01 commit eb899de

File tree

6 files changed

+60
-7
lines changed

6 files changed

+60
-7
lines changed

core/src/main/scala/org/locationtech/rasterframes/expressions/localops/IsIn.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ import org.locationtech.rasterframes.expressions._
4444
""",
4545
examples = """
4646
Examples:
47-
> SELECT _FUNC_(tile, array);
47+
> SELECT _FUNC_(tile, array(lit(33), lit(66), lit(99)));
4848
..."""
4949
)
5050
case class IsIn(left: Expression, right: Expression) extends BinaryExpression with CodegenFallback {

docs/src/main/paradox/release-notes.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
* _Breaking_ (potentially): removed `GeoTiffCollectionRelation` due to usage limitation and overlap with `RasterSourceDataSource` functionality.
88
* Upgraded to Spark 2.4.4
9+
* Added `rf_local_is_in` raster function
910

1011
### 0.8.3
1112

pyrasterframes/src/main/python/docs/reference.pymd

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ Parameters `tile_columns` and `tile_rows` are literals, not column expressions.
183183

184184
Tile rf_array_to_tile(Array arrayCol, Int numCols, Int numRows)
185185

186-
Python only. Create a `tile` from a Spark SQL [Array](http://spark.apache.org/docs/2.3.2/api/python/pyspark.sql.html#pyspark.sql.types.ArrayType), filling values in row-major order.
186+
Python only. Create a `tile` from a Spark SQL [Array][Array], filling values in row-major order.
187187

188188
### rf_assemble_tile
189189

@@ -374,6 +374,13 @@ Returns a `tile` column containing the element-wise equality of `tile1` and `rhs
374374

375375
Returns a `tile` column containing the element-wise inequality of `tile1` and `rhs`.
376376

377+
### rf_local_is_in
378+
379+
Tile rf_local_is_in(Tile tile, Array array)
380+
Tile rf_local_is_in(Tile tile, list l)
381+
382+
Returns a `tile` column with cell values of 1 where the `tile` cell value is in the provided array or list, and 0 otherwise. The `array` is a Spark SQL [Array][Array]. A Python `list` of numeric values may also be passed.
383+
377384
### rf_round
378385

379386
Tile rf_round(Tile tile)
@@ -621,13 +628,13 @@ Python only. As with @ref:[`rf_explode_tiles`](reference.md#rf-explode-tiles), b
621628

622629
Array rf_tile_to_array_int(Tile tile)
623630

624-
Convert Tile column to Spark SQL [Array](http://spark.apache.org/docs/2.3.2/api/python/pyspark.sql.html#pyspark.sql.types.ArrayType), in row-major order. Float cell types will be coerced to integral type by flooring.
631+
Convert Tile column to Spark SQL [Array][Array], in row-major order. Float cell types will be coerced to integral type by flooring.
625632

626633
### rf_tile_to_array_double
627634

628635
Array rf_tile_to_array_double(Tile tile)
629636

630-
Convert tile column to Spark [Array](http://spark.apache.org/docs/2.3.2/api/python/pyspark.sql.html#pyspark.sql.types.ArrayType), in row-major order. Integral cell types will be coerced to floats.
637+
Convert tile column to Spark [Array][Array], in row-major order. Integral cell types will be coerced to floats.
631638

632639
### rf_render_ascii
633640

@@ -657,3 +664,4 @@ Runs [`rf_rgb_composite`](reference.md#rf-rgb-composite) on the given tile colum
657664

658665
[RasterFunctions]: org.locationtech.rasterframes.RasterFunctions
659666
[scaladoc]: latest/api/index.html
667+
[Array]: http://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.types.ArrayType

pyrasterframes/src/main/python/pyrasterframes/rasterfunctions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,14 +260,24 @@ def rf_local_unequal_int(tile_col, scalar):
260260
"""Return a Tile with values equal 1 if the cell is not equal to a scalar, otherwise 0"""
261261
return _apply_scalar_to_tile('rf_local_unequal_int', tile_col, scalar)
262262

263+
263264
def rf_local_no_data(tile_col):
264265
"""Return a tile with ones where the input is NoData, otherwise zero."""
265266
return _apply_column_function('rf_local_no_data', tile_col)
266267

268+
267269
def rf_local_data(tile_col):
268270
"""Return a tile with zeros where the input is NoData, otherwise one."""
269271
return _apply_column_function('rf_local_data', tile_col)
270272

273+
def rf_local_is_in(tile_col, array):
    """Return a tile with cell values of 1 where the `tile_col` cell is in the provided array.

    Args:
        tile_col: a column of tiles (or the name of one) whose cells are tested
            for membership
        array: a Spark SQL array column (or the name of one), or a plain Python
            ``list``/``tuple`` of numeric values which is lifted into an array
            of literals

    Returns:
        a tile column with cell value 1 where the input cell is a member of
        `array`, otherwise 0
    """
    from pyspark.sql.functions import array as sql_array, lit
    # Convenience path: accept a plain Python sequence of scalars and convert
    # it to a Spark SQL array of literal values before dispatch.
    if isinstance(array, (list, tuple)):
        array = sql_array([lit(v) for v in array])

    return _apply_column_function('rf_local_is_in', tile_col, array)
280+
271281
def _apply_column_function(name, *args):
272282
jfcn = RFContext.active().lookup(name)
273283
jcols = [_to_java_column(arg) for arg in args]

pyrasterframes/src/main/python/tests/PyRasterFramesTests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def test_tile_udt_serialization(self):
131131
cells[1][1] = nd
132132
a_tile = Tile(cells, ct.with_no_data_value(nd))
133133
round_trip = udt.fromInternal(udt.toInternal(a_tile))
134-
self.assertEquals(a_tile, round_trip, "round-trip serialization for " + str(ct))
134+
self.assertEqual(a_tile, round_trip, "round-trip serialization for " + str(ct))
135135

136136
schema = StructType([StructField("tile", TileUDT(), False)])
137137
df = self.spark.createDataFrame([{"tile": a_tile}], schema)

pyrasterframes/src/main/python/tests/RasterFunctionsTests.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -347,8 +347,11 @@ def test_rf_local_data_and_no_data(self):
347347
import numpy as np
348348
from numpy.testing import assert_equal
349349

350-
t = Tile(np.array([[1, 3, 4], [5, 0, 3]]), CellType.uint8().with_no_data_value(5))
351-
#note the convert is due to issue #188
350+
nd = 5
351+
t = Tile(
352+
np.array([[1, 3, 4], [nd, 0, 3]]),
353+
CellType.uint8().with_no_data_value(nd))
354+
# note the convert is due to issue #188
352355
df = self.spark.createDataFrame([Row(t=t)])\
353356
.withColumn('lnd', rf_convert_cell_type(rf_local_no_data('t'), 'uint8')) \
354357
.withColumn('ld', rf_convert_cell_type(rf_local_data('t'), 'uint8'))
@@ -359,3 +362,34 @@ def test_rf_local_data_and_no_data(self):
359362

360363
result_d = result['ld']
361364
assert_equal(result_d.cells, np.invert(t.cells.mask))
365+
366+
def test_rf_local_is_in(self):
    """Exercise rf_local_is_in with SQL array columns, column names, and Python lists."""
    from pyspark.sql.functions import lit, array, col
    from pyspark.sql import Row
    import numpy as np
    from numpy.testing import assert_equal

    no_data = 5
    tile = Tile(
        np.array([[1, 3, 4], [no_data, 0, 3]]),
        CellType.uint8().with_no_data_value(no_data))

    # The rf_convert_cell_type wrapping works around issue #188.
    df = self.spark.createDataFrame([Row(t=tile)])
    df = df.withColumn('a', array(lit(3), lit(4)))
    df = df.withColumn(
        'in2',
        rf_convert_cell_type(rf_local_is_in(col('t'), array(lit(0), lit(4))), 'uint8'))
    df = df.withColumn(
        'in3',
        rf_convert_cell_type(rf_local_is_in('t', 'a'), 'uint8'))
    df = df.withColumn(
        'in4',
        rf_convert_cell_type(rf_local_is_in('t', array(lit(0), lit(4), lit(3))), 'uint8'))
    df = df.withColumn(
        'in_list',
        rf_convert_cell_type(rf_local_is_in(col('t'), [4, 1]), 'uint8'))

    row = df.first()
    self.assertEqual(row['in2'].cells.sum(), 2)
    assert_equal(row['in2'].cells, np.isin(tile.cells, np.array([0, 4])))
    self.assertEqual(row['in3'].cells.sum(), 3)
    self.assertEqual(row['in4'].cells.sum(), 4)
    self.assertEqual(
        row['in_list'].cells.sum(), 2,
        "Tile value {} should contain two 1s as: [[1, 0, 1],[0, 0, 0]]"
        .format(row['in_list'].cells))

0 commit comments

Comments
 (0)