Skip to content

Commit 3c1b728

Browse files
committed
more docs and unit tests around local algebra and nodata
Signed-off-by: Jason T. Brown <[email protected]>
1 parent 438cfa1 commit 3c1b728

File tree

5 files changed

+84
-25
lines changed

5 files changed

+84
-25
lines changed

pyrasterframes/src/main/python/docs/reference.pymd

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,8 @@ If input `tile` had a NoData value already, the behaviour depends on if its cell
238238

239239
[Local map algebra](https://gisgeography.com/map-algebra-global-zonal-focal-local/) raster operations are element-wise operations on a single tile (unary), between a `tile` and a scalar, between two `tile`s, or across many `tile`s.
240240

241+
When these operations encounter a NoData value in either operand, the cell in the resulting `tile` will have a NoData.
242+
241243
The binary local map algebra functions have similar variations in the Python API depending on the left hand side type:
242244

243245
- `rf_local_op`: applies `op` to two columns; the right hand side can be a `tile` or a numeric column.
@@ -536,6 +538,8 @@ Aggregates over all of the rows in DataFrame of `tile` and returns a count of ea
536538

537539
Local statistics compute the element-wise statistics across a DataFrame or group of `tile`s, resulting in a `tile` that has the same dimension.
538540

541+
When these functions encounter a NoData in a cell location, it will be ignored.
542+
539543
### rf_agg_local_max
540544

541545
Tile rf_agg_local_max(Tile tile)

pyrasterframes/src/main/python/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def initialize_options(self):
144144
'setuptools>=0.8',
145145
'ipython==6.2.1',
146146
"ipykernel==4.8.0",
147-
'Pweave==0.30.3'
147+
'Pweave==0.30.3',
148148
],
149149
tests_require=[
150150
'pytest==3.4.2',

pyrasterframes/src/main/python/tests/GeoTiffWriterTests.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_identity_write(self):
3737
dest = self._tmpfile()
3838
rf.write.geotiff(dest)
3939

40-
rf2 = self.spark.read.geotiff(dest)
40+
rf2 = self.spark.read.geotiff('file://' + dest)
4141
self.assertEqual(rf2.count(), rf.count())
4242

4343
os.remove(dest)
@@ -47,7 +47,7 @@ def test_unstructured_write(self):
4747
dest = self._tmpfile()
4848
rf.write.geotiff(dest, crs='EPSG:32616')
4949

50-
rf2 = self.spark.read.raster(dest)
50+
rf2 = self.spark.read.raster('file://' + dest)
5151
self.assertEqual(rf2.count(), rf.count())
5252

5353
os.remove(dest)

pyrasterframes/src/main/python/tests/PyRasterFramesTests.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from . import TestEnvironment
2929

3030

31-
3231
class CellTypeHandling(unittest.TestCase):
3332

3433
def test_is_raw(self):
@@ -237,11 +236,16 @@ def less_pi(t):
237236
class TileOps(TestEnvironment):
238237

239238
def setUp(self):
239+
from pyspark.sql import Row
240240
# convenience so we can assert around Tile() == Tile()
241241
self.t1 = Tile(np.array([[1, 2],
242242
[3, 4]]), CellType.int8().with_no_data_value(3))
243243
self.t2 = Tile(np.array([[1, 2],
244244
[3, 4]]), CellType.int8().with_no_data_value(1))
245+
self.t3 = Tile(np.array([[1, 2],
246+
[-3, 4]]), CellType.int8().with_no_data_value(3))
247+
248+
self.df = self.spark.createDataFrame([Row(t1=self.t1, t2=self.t2, t3=self.t3)])
245249

246250
def test_addition(self):
247251
e1 = np.ma.masked_equal(np.array([[5, 6],
@@ -253,6 +257,9 @@ def test_addition(self):
253257
r2 = (self.t1 + self.t2).cells
254258
self.assertTrue(np.ma.allequal(r2, e2))
255259

260+
col_result = self.df.select(rf_local_add('t1', 't3').alias('sum')).first()
261+
self.assertEqual(col_result.sum, self.t1 + self.t3)
262+
256263
def test_multiplication(self):
257264
e1 = np.ma.masked_equal(np.array([[4, 8],
258265
[12, 16]]), 12)
@@ -263,6 +270,9 @@ def test_multiplication(self):
263270
r2 = (self.t1 * self.t2).cells
264271
self.assertTrue(np.ma.allequal(r2, e2))
265272

273+
r3 = self.df.select(rf_local_multiply('t1', 't3').alias('r3')).first().r3
274+
self.assertEqual(r3, self.t1 * self.t3)
275+
266276
def test_subtraction(self):
267277
t3 = self.t1 * 4
268278
r1 = t3 - self.t1
@@ -541,7 +551,6 @@ def path(scene, band):
541551
self.assertTrue(df2.select('b1_path').distinct().count() == 3)
542552

543553

544-
545554
def suite():
546555
function_tests = unittest.TestSuite()
547556
return function_tests

pyrasterframes/src/main/python/tests/RasterFunctionsTests.py

Lines changed: 66 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,21 @@ def test_agg_mean(self):
9797
mean = self.rf.agg(rf_agg_mean('tile')).first()['rf_agg_mean(tile)']
9898
self.assertTrue(self.rounded_compare(mean, 10160))
9999

100+
def test_agg_local_mean(self):
101+
from pyspark.sql import Row
102+
103+
# this is really testing the nodata propagation in the agg local summation
104+
ct = CellType.int8().with_no_data_value(4)
105+
df = self.spark.createDataFrame([
106+
Row(tile=Tile(np.array([[1, 2, 3, 4, 5, 6]]), ct)),
107+
Row(tile=Tile(np.array([[1, 2, 4, 3, 5, 6]]), ct)),
108+
])
109+
110+
result = df.agg(rf_agg_local_mean('tile').alias('mean')).first().mean
111+
112+
expected = Tile(np.array([[1.0, 2.0, 3.0, 3.0, 5.0, 6.0]]), CellType.float64())
113+
self.assertEqual(result, expected)
114+
100115
def test_aggregations(self):
101116
aggs = self.rf.agg(
102117
rf_agg_data_cells('tile'),
@@ -112,28 +127,59 @@ def test_aggregations(self):
112127
self.assertEqual(row['rf_agg_stats(tile)'].data_cells, row['rf_agg_data_cells(tile)'])
113128

114129
def test_sql(self):
130+
115131
self.rf.createOrReplaceTempView("rf_test_sql")
116132

117-
self.spark.sql("""SELECT tile,
118-
rf_local_add(tile, 1) AS and_one,
119-
rf_local_subtract(tile, 1) AS less_one,
120-
rf_local_multiply(tile, 2) AS times_two,
121-
rf_local_divide(tile, 2) AS over_two
122-
FROM rf_test_sql""").createOrReplaceTempView('rf_test_sql_1')
123-
124-
statsRow = self.spark.sql("""
125-
SELECT rf_tile_mean(tile) as base,
126-
rf_tile_mean(and_one) as plus_one,
127-
rf_tile_mean(less_one) as minus_one,
128-
rf_tile_mean(times_two) as double,
129-
rf_tile_mean(over_two) as half
130-
FROM rf_test_sql_1
131-
""").first()
132-
133-
self.assertTrue(self.rounded_compare(statsRow.base, statsRow.plus_one - 1))
134-
self.assertTrue(self.rounded_compare(statsRow.base, statsRow.minus_one + 1))
135-
self.assertTrue(self.rounded_compare(statsRow.base, statsRow.double / 2))
136-
self.assertTrue(self.rounded_compare(statsRow.base, statsRow.half * 2))
133+
arith = self.spark.sql("""SELECT tile,
134+
rf_local_add(tile, 1) AS add_one,
135+
rf_local_subtract(tile, 1) AS less_one,
136+
rf_local_multiply(tile, 2) AS times_two,
137+
rf_local_divide(
138+
rf_convert_cell_type(tile, "float32"),
139+
2) AS over_two
140+
FROM rf_test_sql""")
141+
142+
arith.createOrReplaceTempView('rf_test_sql_1')
143+
arith.show(truncate=False)
144+
stats = self.spark.sql("""
145+
SELECT rf_tile_mean(tile) as base,
146+
rf_tile_mean(add_one) as plus_one,
147+
rf_tile_mean(less_one) as minus_one,
148+
rf_tile_mean(times_two) as double,
149+
rf_tile_mean(over_two) as half,
150+
rf_no_data_cells(tile) as nd
151+
152+
FROM rf_test_sql_1
153+
ORDER BY rf_no_data_cells(tile)
154+
""")
155+
stats.show(truncate=False)
156+
stats.createOrReplaceTempView('rf_test_sql_stats')
157+
158+
compare = self.spark.sql("""
159+
SELECT
160+
plus_one - 1.0 = base as add,
161+
minus_one + 1.0 = base as subtract,
162+
double / 2.0 = base as multiply,
163+
half * 2.0 = base as divide,
164+
nd
165+
FROM rf_test_sql_stats
166+
""")
167+
168+
expect_row1 = compare.orderBy('nd').first()
169+
170+
self.assertTrue(expect_row1.subtract)
171+
self.assertTrue(expect_row1.multiply)
172+
self.assertTrue(expect_row1.divide)
173+
self.assertEqual(expect_row1.nd, 0)
174+
self.assertTrue(expect_row1.add)
175+
176+
expect_row2 = compare.orderBy('nd', ascending=False).first()
177+
178+
self.assertTrue(expect_row2.subtract)
179+
self.assertTrue(expect_row2.multiply)
180+
self.assertTrue(expect_row2.divide)
181+
self.assertTrue(expect_row2.nd > 0)
182+
self.assertTrue(expect_row2.add) # <-- Would fail in a case where ND + 1 = 1
137183

138184
def test_explode(self):
139185
import pyspark.sql.functions as F

0 commit comments

Comments
 (0)