
Commit 38b8c2d

Merge pull request #236 from s22s/fix/62
Fixed `Add` expression to use GT behavior
2 parents: b307809 + 3c1b728

8 files changed: +168 -35 lines changed

core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/LocalMeanAggregate.scala

Lines changed: 5 additions & 5 deletions

@@ -22,7 +22,7 @@
 package org.locationtech.rasterframes.expressions.aggregates

 import org.locationtech.rasterframes.expressions.UnaryRasterAggregate
-import org.locationtech.rasterframes.expressions.localops.{Add => AddTiles, Divide => DivideTiles}
+import org.locationtech.rasterframes.expressions.localops.{BiasedAdd, Divide => DivideTiles}
 import org.locationtech.rasterframes.expressions.transformers.SetCellType
 import geotrellis.raster.Tile
 import geotrellis.raster.mapalgebra.local
@@ -59,16 +59,16 @@ case class LocalMeanAggregate(child: Expression) extends UnaryRasterAggregate {
   override lazy val updateExpressions: Seq[Expression] = Seq(
     If(IsNull(count),
       SetCellType(Defined(child), Literal("int32")),
-      If(IsNull(child), count, AddTiles(count, Defined(child)))
+      If(IsNull(child), count, BiasedAdd(count, Defined(child)))
     ),
     If(IsNull(sum),
       SetCellType(child, Literal("float64")),
-      If(IsNull(child), sum, AddTiles(sum, child))
+      If(IsNull(child), sum, BiasedAdd(sum, child))
     )
   )
   override val mergeExpressions: Seq[Expression] = Seq(
-    AddTiles(count.left, count.right),
-    AddTiles(sum.left, sum.right)
+    BiasedAdd(count.left, count.right),
+    BiasedAdd(sum.left, sum.right)
   )
   override lazy val evaluateExpression: Expression = DivideTiles(sum, count)
}
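Why the aggregate needs the data-biased add: the running `sum` tile must skip a NoData cell rather than be poisoned by it, otherwise any cell that was NoData in a single input tile would be NoData in the final mean. Below is a minimal numpy sketch of the same update/merge/evaluate flow — an illustration only, not the Catalyst implementation — using masked cells for NoData and the tile values from the `test_agg_local_mean` test added later in this diff.

import numpy as np
import numpy.ma as ma

def biased_add(acc, t):
    # data-biased rule: keep whichever operand has data; NoData only where both are NoData
    return ma.masked_array(acc.filled(0) + t.filled(0),
                           mask=ma.getmaskarray(acc) & ma.getmaskarray(t))

# Two tiles with NoData = 4, matching test_agg_local_mean.
tiles = [
    ma.masked_equal(np.array([[1, 2, 3, 4, 5, 6]]), 4),
    ma.masked_equal(np.array([[1, 2, 4, 3, 5, 6]]), 4),
]

# updateExpressions: count accumulates Defined(child) (1 where data, 0 where NoData),
# sum accumulates the cell values; both are folded in with the biased add.
count = ma.masked_array(np.zeros((1, 6)), mask=True)
total = ma.masked_array(np.zeros((1, 6)), mask=True)
for t in tiles:
    defined = ma.masked_array((~ma.getmaskarray(t)).astype(int))
    count = biased_add(count, defined)
    total = biased_add(total, t)

# evaluateExpression: DivideTiles(sum, count)
print(total / count)  # [[1.0 2.0 3.0 3.0 5.0 6.0]] -- each cell averages only its defined inputs

With the previous NoData-propagating `Add`, the running sum would turn NoData at any cell that was NoData in any input tile, so the final mean would drop those cells instead of averaging the remaining values.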

core/src/main/scala/org/locationtech/rasterframes/expressions/localops/Add.scala

Lines changed: 3 additions & 4 deletions

@@ -30,7 +30,6 @@ import org.apache.spark.sql.{Column, TypedColumn}
 import org.locationtech.rasterframes._
 import org.locationtech.rasterframes.expressions.BinaryLocalRasterOp
 import org.locationtech.rasterframes.expressions.DynamicExtractors.tileExtractor
-import org.locationtech.rasterframes.util.DataBiasedOp.BiasedAdd

 @ExpressionDescription(
   usage = "_FUNC_(tile, rhs) - Performs cell-wise addition between two tiles or a tile and a scalar.",
@@ -48,9 +47,9 @@ import org.locationtech.rasterframes.util.DataBiasedOp.BiasedAdd
 case class Add(left: Expression, right: Expression) extends BinaryLocalRasterOp
   with CodegenFallback {
   override val nodeName: String = "rf_local_add"
-  override protected def op(left: Tile, right: Tile): Tile = BiasedAdd(left, right)
-  override protected def op(left: Tile, right: Double): Tile = BiasedAdd(left, right)
-  override protected def op(left: Tile, right: Int): Tile = BiasedAdd(left, right)
+  override protected def op(left: Tile, right: Tile): Tile = left.localAdd(right)
+  override protected def op(left: Tile, right: Double): Tile = left.localAdd(right)
+  override protected def op(left: Tile, right: Int): Tile = left.localAdd(right)

   override def eval(input: InternalRow): Any = {
     if(input == null) null
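With this change `rf_local_add` delegates to GeoTrellis's `localAdd`, which propagates NoData — the 'GT behavior' in the commit title, and the rule now documented in reference.pymd below. A small masked-array sketch of the rule (an illustration, not the GeoTrellis code), using the `t1`/`t3` fixtures from the updated `TileOps` tests:

import numpy as np
import numpy.ma as ma

# t1 marks the value 3 as NoData; t3 declares NoData = 3 but never contains it,
# so only t1 contributes NoData cells (as in the Python test fixtures).
t1 = ma.masked_equal(np.array([[1, 2], [3, 4]]), 3)
t3 = np.array([[1, 2], [-3, 4]])

print(t1 + t3)  # [[2 4] [-- 8]] -- a cell that is NoData in either operand is NoData in the result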

core/src/main/scala/org/locationtech/rasterframes/expressions/localops/BiasedAdd.scala (new file)

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+/*
+ * This software is licensed under the Apache 2 license, quoted below.
+ *
+ * Copyright 2019 Astraea, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * [http://www.apache.org/licenses/LICENSE-2.0]
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package org.locationtech.rasterframes.expressions.localops
+import geotrellis.raster.Tile
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription}
+import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.{Column, TypedColumn}
+import org.locationtech.rasterframes._
+import org.locationtech.rasterframes.expressions.BinaryLocalRasterOp
+import org.locationtech.rasterframes.expressions.DynamicExtractors.tileExtractor
+import org.locationtech.rasterframes.util.DataBiasedOp
+
+@ExpressionDescription(
+  usage = "_FUNC_(tile, rhs) - Performs cell-wise addition between two tiles or a tile and a scalar. " +
+    "Unlike a regular 'add', this considers `<data> + <nodata> = <data>.",
+  arguments = """
+  Arguments:
+    * tile - left-hand-side tile
+    * rhs - a tile or scalar value to add to each cell""",
+  examples = """
+  Examples:
+    > SELECT _FUNC_(tile, 1.5);
+       ...
+    > SELECT _FUNC_(tile1, tile2);
+       ..."""
+)
+case class BiasedAdd(left: Expression, right: Expression) extends BinaryLocalRasterOp
+  with CodegenFallback {
+  override val nodeName: String = "rf_local_biased_add"
+  override protected def op(left: Tile, right: Tile): Tile = DataBiasedOp.BiasedAdd(left, right)
+  override protected def op(left: Tile, right: Double): Tile = DataBiasedOp.BiasedAdd(left, right)
+  override protected def op(left: Tile, right: Int): Tile = DataBiasedOp.BiasedAdd(left, right)
+
+  override def eval(input: InternalRow): Any = {
+    if(input == null) null
+    else {
+      val l = left.eval(input)
+      val r = right.eval(input)
+      if (l == null && r == null) null
+      else if (l == null) r
+      else if (r == null && tileExtractor.isDefinedAt(right.dataType)) l
+      else if (r == null) null
+      else nullSafeEval(l, r)
+    }
+  }
+}
+object BiasedAdd {
+  def apply(left: Column, right: Column): TypedColumn[Any, Tile] =
+    new Column(BiasedAdd(left.expr, right.expr)).as[Tile]
+
+  def apply[N: Numeric](tile: Column, value: N): TypedColumn[Any, Tile] =
+    new Column(BiasedAdd(tile.expr, lit(value).expr)).as[Tile]
+}
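The new `BiasedAdd` expression preserves the old behavior for internal use (notably the aggregates above) by delegating to the existing `DataBiasedOp.BiasedAdd`: per its usage string, `<data> + <nodata> = <data>`. A numpy sketch of that rule as I read it — assuming a result cell is NoData only when it is NoData in both inputs, which is not spelled out in this diff:

import numpy as np
import numpy.ma as ma

def biased_add(a, b):
    # treat NoData as the additive identity; mask only where both inputs are NoData
    return ma.masked_array(a.filled(0) + b.filled(0),
                           mask=ma.getmaskarray(a) & ma.getmaskarray(b))

x = ma.masked_equal(np.array([[1, 2, 9]]), 9)   # NoData in the last cell
y = ma.masked_equal(np.array([[9, 20, 9]]), 9)  # NoData in the first and last cells

print(x + y)             # [[-- 22 --]] -- ordinary masked addition propagates NoData
print(biased_add(x, y))  # [[1 22 --]]  -- data wins over NoData; NoData only where both are NoData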

pyrasterframes/src/main/python/docs/reference.pymd

Lines changed: 4 additions & 0 deletions

@@ -238,6 +238,8 @@ If input `tile` had a NoData value already, the behaviour depends on if its cell

 [Local map algebra](https://gisgeography.com/map-algebra-global-zonal-focal-local/) raster operations are element-wise operations on a single tile (unary), between a `tile` and a scalar, between two `tile`s, or across many `tile`s.

+When these operations encounter a NoData value in either operand, the cell in the resulting `tile` will have a NoData.
+
 The binary local map algebra functions have similar variations in the Python API depending on the left hand side type:

 - `rf_local_op`: applies `op` to two columns; the right hand side can be a `tile` or a numeric column.
@@ -536,6 +538,8 @@ Aggregates over all of the rows in DataFrame of `tile` and returns a count of ea

 Local statistics compute the element-wise statistics across a DataFrame or group of `tile`s, resulting in a `tile` that has the same dimension.

+When these functions encounter a NoData in a cell location, it will be ignored.
+
 ### rf_agg_local_max

     Tile rf_agg_local_max(Tile tile)

pyrasterframes/src/main/python/setup.py

Lines changed: 3 additions & 2 deletions

@@ -144,15 +144,16 @@ def initialize_options(self):
         'setuptools>=0.8',
         'ipython==6.2.1',
         "ipykernel==4.8.0",
-        'Pweave==0.30.3'
+        'Pweave==0.30.3',
     ],
     tests_require=[
         'pytest==3.4.2',
         'pypandoc',
         'numpy>=1.7',
         'shapely',
         'pandas',
-        'rasterio'
+        'rasterio',
+        'boto3'
     ],
     packages=[
         'pyrasterframes',

pyrasterframes/src/main/python/tests/GeoTiffWriterTests.py

Lines changed: 2 additions & 2 deletions

@@ -37,7 +37,7 @@ def test_identity_write(self):
         dest = self._tmpfile()
         rf.write.geotiff(dest)

-        rf2 = self.spark.read.geotiff(dest)
+        rf2 = self.spark.read.geotiff('file://' + dest)
         self.assertEqual(rf2.count(), rf.count())

         os.remove(dest)
@@ -47,7 +47,7 @@ def test_unstructured_write(self):
         dest = self._tmpfile()
         rf.write.geotiff(dest, crs='EPSG:32616')

-        rf2 = self.spark.read.raster(dest)
+        rf2 = self.spark.read.raster('file://' + dest)
         self.assertEqual(rf2.count(), rf.count())

         os.remove(dest)

pyrasterframes/src/main/python/tests/PyRasterFramesTests.py

Lines changed: 11 additions & 2 deletions

@@ -28,7 +28,6 @@
 from . import TestEnvironment


-
 class CellTypeHandling(unittest.TestCase):

     def test_is_raw(self):
@@ -237,11 +236,16 @@ def less_pi(t):
 class TileOps(TestEnvironment):

     def setUp(self):
+        from pyspark.sql import Row
         # convenience so we can assert around Tile() == Tile()
         self.t1 = Tile(np.array([[1, 2],
                                  [3, 4]]), CellType.int8().with_no_data_value(3))
         self.t2 = Tile(np.array([[1, 2],
                                  [3, 4]]), CellType.int8().with_no_data_value(1))
+        self.t3 = Tile(np.array([[1, 2],
+                                 [-3, 4]]), CellType.int8().with_no_data_value(3))
+
+        self.df = self.spark.createDataFrame([Row(t1=self.t1, t2=self.t2, t3=self.t3)])

     def test_addition(self):
         e1 = np.ma.masked_equal(np.array([[5, 6],
@@ -253,6 +257,9 @@ def test_addition(self):
         r2 = (self.t1 + self.t2).cells
         self.assertTrue(np.ma.allequal(r2, e2))

+        col_result = self.df.select(rf_local_add('t1', 't3').alias('sum')).first()
+        self.assertEqual(col_result.sum, self.t1 + self.t3)
+
     def test_multiplication(self):
         e1 = np.ma.masked_equal(np.array([[4, 8],
                                           [12, 16]]), 12)
@@ -263,6 +270,9 @@ def test_multiplication(self):
         r2 = (self.t1 * self.t2).cells
         self.assertTrue(np.ma.allequal(r2, e2))

+        r3 = self.df.select(rf_local_multiply('t1', 't3').alias('r3')).first().r3
+        self.assertEqual(r3, self.t1 * self.t3)
+
     def test_subtraction(self):
         t3 = self.t1 * 4
         r1 = t3 - self.t1
@@ -541,7 +551,6 @@ def path(scene, band):
         self.assertTrue(df2.select('b1_path').distinct().count() == 3)


-
 def suite():
     function_tests = unittest.TestSuite()
     return function_tests

pyrasterframes/src/main/python/tests/RasterFunctionsTests.py

Lines changed: 66 additions & 20 deletions

@@ -97,6 +97,21 @@ def test_agg_mean(self):
         mean = self.rf.agg(rf_agg_mean('tile')).first()['rf_agg_mean(tile)']
         self.assertTrue(self.rounded_compare(mean, 10160))

+    def test_agg_local_mean(self):
+        from pyspark.sql import Row
+
+        # this is really testing the nodata propagation in the agg local summation
+        ct = CellType.int8().with_no_data_value(4)
+        df = self.spark.createDataFrame([
+            Row(tile=Tile(np.array([[1, 2, 3, 4, 5, 6]]), ct)),
+            Row(tile=Tile(np.array([[1, 2, 4, 3, 5, 6]]), ct)),
+        ])
+
+        result = df.agg(rf_agg_local_mean('tile').alias('mean')).first().mean
+
+        expected = Tile(np.array([[1.0, 2.0, 3.0, 3.0, 5.0, 6.0]]), CellType.float64())
+        self.assertEqual(result, expected)
+
     def test_aggregations(self):
         aggs = self.rf.agg(
             rf_agg_data_cells('tile'),
@@ -112,28 +127,59 @@ def test_aggregations(self):
         self.assertEqual(row['rf_agg_stats(tile)'].data_cells, row['rf_agg_data_cells(tile)'])

     def test_sql(self):
+
         self.rf.createOrReplaceTempView("rf_test_sql")

-        self.spark.sql("""SELECT tile,
-            rf_local_add(tile, 1) AS and_one,
-            rf_local_subtract(tile, 1) AS less_one,
-            rf_local_multiply(tile, 2) AS times_two,
-            rf_local_divide(tile, 2) AS over_two
-            FROM rf_test_sql""").createOrReplaceTempView('rf_test_sql_1')
-
-        statsRow = self.spark.sql("""
-            SELECT rf_tile_mean(tile) as base,
-            rf_tile_mean(and_one) as plus_one,
-            rf_tile_mean(less_one) as minus_one,
-            rf_tile_mean(times_two) as double,
-            rf_tile_mean(over_two) as half
-            FROM rf_test_sql_1
-            """).first()
-
-        self.assertTrue(self.rounded_compare(statsRow.base, statsRow.plus_one - 1))
-        self.assertTrue(self.rounded_compare(statsRow.base, statsRow.minus_one + 1))
-        self.assertTrue(self.rounded_compare(statsRow.base, statsRow.double / 2))
-        self.assertTrue(self.rounded_compare(statsRow.base, statsRow.half * 2))
+        arith = self.spark.sql("""SELECT tile,
+            rf_local_add(tile, 1) AS add_one,
+            rf_local_subtract(tile, 1) AS less_one,
+            rf_local_multiply(tile, 2) AS times_two,
+            rf_local_divide(
+                rf_convert_cell_type(tile, "float32"),
+                2) AS over_two
+            FROM rf_test_sql""")
+
+        arith.createOrReplaceTempView('rf_test_sql_1')
+        arith.show(truncate=False)
+        stats = self.spark.sql("""
+            SELECT rf_tile_mean(tile) as base,
+            rf_tile_mean(add_one) as plus_one,
+            rf_tile_mean(less_one) as minus_one,
+            rf_tile_mean(times_two) as double,
+            rf_tile_mean(over_two) as half,
+            rf_no_data_cells(tile) as nd
+
+            FROM rf_test_sql_1
+            ORDER BY rf_no_data_cells(tile)
+            """)
+        stats.show(truncate=False)
+        stats.createOrReplaceTempView('rf_test_sql_stats')
+
+        compare = self.spark.sql("""
+            SELECT
+                plus_one - 1.0 = base as add,
+                minus_one + 1.0 = base as subtract,
+                double / 2.0 = base as multiply,
+                half * 2.0 = base as divide,
+                nd
+            FROM rf_test_sql_stats
+            """)
+
+        expect_row1 = compare.orderBy('nd').first()
+
+        self.assertTrue(expect_row1.subtract)
+        self.assertTrue(expect_row1.multiply)
+        self.assertTrue(expect_row1.divide)
+        self.assertEqual(expect_row1.nd, 0)
+        self.assertTrue(expect_row1.add)
+
+        expect_row2 = compare.orderBy('nd', ascending=False).first()
+
+        self.assertTrue(expect_row2.subtract)
+        self.assertTrue(expect_row2.multiply)
+        self.assertTrue(expect_row2.divide)
+        self.assertTrue(expect_row2.nd > 0)
+        self.assertTrue(expect_row2.add)  # <-- Would fail in a case where ND + 1 = 1

     def test_explode(self):
         import pyspark.sql.functions as F
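One detail in the rewritten `test_sql`: the tile is converted to `float32` before `rf_local_divide`, presumably so the `half * 2.0 = base` comparison is not broken by integer truncation. A plain numpy illustration of that pitfall (not the Spark test itself):

import numpy as np

tile = np.array([[7, 10, 13]], dtype=np.int32)

half_int = tile // 2                     # integer cells truncate: [[3 5 6]]
half_f32 = tile.astype(np.float32) / 2   # float32 cells: [[3.5 5. 6.5]]

print(np.array_equal(half_int * 2, tile))  # False -- the round trip is lossy
print(np.array_equal(half_f32 * 2, tile))  # True  -- matches the divide check above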
