Skip to content

Commit 25d117c

Browse files
authored
Merge pull request #363 from s22s/feature/rf_mask_unit_test
Add rf_mask unit test in python, import geomesa types with module import
2 parents b7f58e9 + cd48803 commit 25d117c

File tree

3 files changed

+24
-1
lines changed

3 files changed

+24
-1
lines changed

pyrasterframes/src/main/python/docs/vector-data.pymd

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ RasterFrames provides a variety of ways to work with spatial vector data (points
1111
```python, setup, echo=False
1212
import pyrasterframes
1313
import pyrasterframes.rf_ipython
14-
import geomesa_pyspark.types
1514
import geopandas
1615
import folium
1716
spark = pyrasterframes.get_spark_session('local[2]')

pyrasterframes/src/main/python/pyrasterframes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from .rf_context import RFContext
3434
from .version import __version__
3535
from .rf_types import RasterFrameLayer, TileExploder, TileUDT, RasterSourceUDT
36+
import geomesa_pyspark.types # enable vector integrations
3637

3738
__all__ = ['RasterFrameLayer', 'TileExploder']
3839

pyrasterframes/src/main/python/tests/RasterFunctionsTests.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,29 @@ def test_mask_by_value(self):
231231
.collect()[0][0]
232232
self.assertTrue(result)
233233

234+
def test_mask(self):
235+
from pyspark.sql import Row
236+
from pyrasterframes.rf_types import Tile, CellType
237+
import numpy as np
238+
239+
np.random.seed(999)
240+
ma = np.ma.array(np.random.randint(0, 10, (5, 5), dtype='int8'), mask=np.random.rand(5, 5) > 0.7)
241+
expected_data_values = ma.compressed().size
242+
expected_no_data_values = ma.size - expected_data_values
243+
self.assertTrue(expected_data_values > 0, "Make sure random seed is cooperative ")
244+
self.assertTrue(expected_no_data_values > 0, "Make sure random seed is cooperative ")
245+
246+
df = self.spark.createDataFrame([
247+
Row(t=Tile(np.ones(ma.shape, ma.dtype)), m=Tile(ma))
248+
])
249+
250+
df = df.withColumn('masked_t', rf_mask('t', 'm'))
251+
result = df.select(rf_data_cells('masked_t')).first()[0]
252+
self.assertEqual(result, expected_data_values)
253+
254+
nd_result = df.select(rf_no_data_cells('masked_t')).first()[0]
255+
self.assertEqual(nd_result, expected_no_data_values)
256+
234257
def test_resample(self):
235258
from pyspark.sql.functions import lit
236259
result = self.rf.select(

0 commit comments

Comments
 (0)