Skip to content

Commit 45ab397

Browse files
authored
Merge pull request #400 from s22s/docs/local_is_in
Add rf_local_is_in function
2 parents 549c308 + 73be68c commit 45ab397

File tree

11 files changed

+206
-51
lines changed

11 files changed

+206
-51
lines changed

core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,9 @@ trait RasterFunctions {
405405
/** Cellwise inequality comparison between a tile and a scalar. */
406406
def rf_local_unequal[T: Numeric](tileCol: Column, value: T): Column = Unequal(tileCol, value)
407407

408+
/** Cellwise test of whether each tile cell value is contained in the provided array column; yields 1 where it is, otherwise 0. */
409+
def rf_local_is_in(tileCol: Column, arrayCol: Column): Column = IsIn(tileCol, arrayCol)
410+
408411
/** Return a tile with ones where the input is NoData, otherwise zero */
409412
def rf_local_no_data(tileCol: Column): Column = Undefined(tileCol)
410413

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/*
2+
* This software is licensed under the Apache 2 license, quoted below.
3+
*
4+
* Copyright 2019 Astraea, Inc.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
7+
* use this file except in compliance with the License. You may obtain a copy of
8+
* the License at
9+
*
10+
* [http://www.apache.org/licenses/LICENSE-2.0]
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15+
* License for the specific language governing permissions and limitations under
16+
* the License.
17+
*
18+
* SPDX-License-Identifier: Apache-2.0
19+
*
20+
*/
21+
22+
package org.locationtech.rasterframes.expressions.localops
23+
24+
import geotrellis.raster.Tile
25+
import geotrellis.raster.mapalgebra.local.IfCell
26+
import org.apache.spark.sql.Column
27+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
28+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
29+
import org.apache.spark.sql.types.{ArrayType, DataType}
30+
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
31+
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, ExpressionDescription}
32+
import org.apache.spark.sql.catalyst.util.ArrayData
33+
import org.apache.spark.sql.rf.TileUDT
34+
import org.locationtech.rasterframes.encoders.CatalystSerializer._
35+
import org.locationtech.rasterframes.expressions.DynamicExtractors._
36+
import org.locationtech.rasterframes.expressions._
37+
38+
@ExpressionDescription(
39+
usage = "_FUNC_(tile, rhs) - In each cell of `tile`, return true if the value is in rhs.",
40+
arguments = """
41+
Arguments:
42+
* tile - tile column whose cell values are tested
43+
* rhs - array to test against
44+
""",
45+
examples = """
46+
Examples:
47+
> SELECT _FUNC_(tile, array(33, 66, 99));
48+
..."""
49+
)
50+
case class IsIn(left: Expression, right: Expression) extends BinaryExpression with CodegenFallback {
51+
override val nodeName: String = "rf_local_is_in"
52+
53+
override def dataType: DataType = left.dataType
54+
55+
@transient private lazy val elementType: DataType = right.dataType.asInstanceOf[ArrayType].elementType
56+
57+
override def checkInputDataTypes(): TypeCheckResult =
58+
if(!tileExtractor.isDefinedAt(left.dataType)) {
59+
TypeCheckFailure(s"Input type '${left.dataType}' does not conform to a raster type.")
60+
} else right.dataType match {
61+
case _: ArrayType => TypeCheckSuccess
62+
case _ => TypeCheckFailure(s"Input type '${right.dataType}' does not conform to ArrayType.")
63+
}
64+
65+
override protected def nullSafeEval(input1: Any, input2: Any): Any = {
66+
implicit val tileSer = TileUDT.tileSerializer
67+
val (childTile, childCtx) = tileExtractor(left.dataType)(row(input1))
68+
69+
val arr = input2.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
70+
71+
childCtx match {
72+
case Some(ctx) => ctx.toProjectRasterTile(op(childTile, arr)).toInternalRow
73+
case None => op(childTile, arr).toInternalRow
74+
}
75+
76+
}
77+
78+
protected def op(left: Tile, right: IndexedSeq[AnyRef]): Tile = {
79+
def fn(i: Int): Boolean = right.contains(i)
80+
IfCell(left, fn(_), 1, 0)
81+
}
82+
83+
}
84+
85+
object IsIn {
86+
def apply(left: Column, right: Column): Column =
87+
new Column(IsIn(left.expr, right.expr))
88+
}

core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ package object expressions {
8686
registry.registerExpression[GreaterEqual]("rf_local_greater_equal")
8787
registry.registerExpression[Equal]("rf_local_equal")
8888
registry.registerExpression[Unequal]("rf_local_unequal")
89+
registry.registerExpression[IsIn]("rf_local_is_in")
8990
registry.registerExpression[Undefined]("rf_local_no_data")
9091
registry.registerExpression[Defined]("rf_local_data")
9192
registry.registerExpression[Sum]("rf_tile_sum")

core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -972,4 +972,28 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers {
972972
val dResult = df.select($"ld").as[Tile].first()
973973
dResult should be (randNDPRT.localDefined())
974974
}
975+
976+
it("should check values isin"){
977+
checkDocs("rf_local_is_in")
978+
979+
// tile is 3 by 3, with cell values 1 to 9
980+
val df = Seq(byteArrayTile).toDF("t")
981+
.withColumn("one", lit(1))
982+
.withColumn("five", lit(5))
983+
.withColumn("ten", lit(10))
984+
.withColumn("in_expect_2", rf_local_is_in($"t", array($"one", $"five")))
985+
.withColumn("in_expect_1", rf_local_is_in($"t", array($"ten", $"five")))
986+
.withColumn("in_expect_0", rf_local_is_in($"t", array($"ten")))
987+
988+
val e2Result = df.select(rf_tile_sum($"in_expect_2")).as[Double].first()
989+
e2Result should be (2.0)
990+
991+
val e1Result = df.select(rf_tile_sum($"in_expect_1")).as[Double].first()
992+
e1Result should be (1.0)
993+
994+
val e0Result = df.select($"in_expect_0").as[Tile].first()
995+
e0Result.toArray() should contain only (0)
996+
997+
// lazy val invalid = df.select(rf_local_is_in($"t", lit("foobar"))).as[Tile].first()
998+
}
975999
}

pyrasterframes/src/main/python/docs/reference.pymd renamed to docs/src/main/paradox/reference.md

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ Parameters `tile_columns` and `tile_rows` are literals, not column expressions.
192192

193193
Tile rf_array_to_tile(Array arrayCol, Int numCols, Int numRows)
194194

195-
Python only. Create a `tile` from a Spark SQL [Array](http://spark.apache.org/docs/2.3.2/api/python/pyspark.sql.html#pyspark.sql.types.ArrayType), filling values in row-major order.
195+
Python only. Create a `tile` from a Spark SQL [Array][Array], filling values in row-major order.
196196

197197
### rf_assemble_tile
198198

@@ -383,6 +383,13 @@ Returns a `tile` column containing the element-wise equality of `tile1` and `rhs
383383

384384
Returns a `tile` column containing the element-wise inequality of `tile1` and `rhs`.
385385

386+
### rf_local_is_in
387+
388+
Tile rf_local_is_in(Tile tile, Array array)
389+
Tile rf_local_is_in(Tile tile, list l)
390+
391+
Returns a `tile` column with cell values of 1 where the `tile` cell value is in the provided array or list, and 0 otherwise. The `array` is a Spark SQL [Array][Array]. A Python `list` of numeric values can also be passed.
392+
386393
### rf_round
387394

388395
Tile rf_round(Tile tile)
@@ -630,13 +637,13 @@ Python only. As with @ref:[`rf_explode_tiles`](reference.md#rf-explode-tiles), b
630637

631638
Array rf_tile_to_array_int(Tile tile)
632639

633-
Convert Tile column to Spark SQL [Array](http://spark.apache.org/docs/2.3.2/api/python/pyspark.sql.html#pyspark.sql.types.ArrayType), in row-major order. Float cell types will be coerced to integral type by flooring.
640+
Convert Tile column to Spark SQL [Array][Array], in row-major order. Float cell types will be coerced to integral type by flooring.
634641

635642
### rf_tile_to_array_double
636643

637644
Array rf_tile_to_array_double(Tile tile)
638645

639-
Convert tile column to Spark [Array](http://spark.apache.org/docs/2.3.2/api/python/pyspark.sql.html#pyspark.sql.types.ArrayType), in row-major order. Integral cell types will be coerced to floats.
646+
Convert tile column to Spark [Array][Array], in row-major order. Integral cell types will be coerced to floats.
640647

641648
### rf_render_ascii
642649

@@ -666,3 +673,4 @@ Runs [`rf_rgb_composite`](reference.md#rf-rgb-composite) on the given tile colum
666673

667674
[RasterFunctions]: org.locationtech.rasterframes.RasterFunctions
668675
[scaladoc]: latest/api/index.html
676+
[Array]: http://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.types.ArrayType

docs/src/main/paradox/release-notes.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
* _Breaking_ (potentially): removed `GeoTiffCollectionRelation` due to usage limitation and overlap with `RasterSourceDataSource` functionality.
88
* Upgraded to Spark 2.4.4
9+
* Added `rf_local_is_in` raster function
910

1011
### 0.8.3
1112

pyrasterframes/src/main/python/docs/nodata-handling.pymd

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -105,32 +105,23 @@ Drawing on @ref:[local map algebra](local-algebra.md) techniques, we will create
105105
```python, def_mask
106106
from pyspark.sql.functions import lit
107107

108-
mask_part = unmasked.withColumn('nodata', rf_local_equal('scl', lit(0))) \
109-
.withColumn('defect', rf_local_equal('scl', lit(1))) \
110-
.withColumn('cloud8', rf_local_equal('scl', lit(8))) \
111-
.withColumn('cloud9', rf_local_equal('scl', lit(9))) \
112-
.withColumn('cirrus', rf_local_equal('scl', lit(10)))
113-
114-
one_mask = mask_part.withColumn('mask', rf_local_add('nodata', 'defect')) \
115-
.withColumn('mask', rf_local_add('mask', 'cloud8')) \
116-
.withColumn('mask', rf_local_add('mask', 'cloud9')) \
117-
.withColumn('mask', rf_local_add('mask', 'cirrus'))
118-
119-
cell_types = one_mask.select(rf_cell_type('mask')).distinct()
108+
mask = unmasked.withColumn('mask', rf_local_is_in('scl', [0, 1, 8, 9, 10]))
109+
110+
cell_types = mask.select(rf_cell_type('mask')).distinct()
120111
cell_types
121112
```
122113

123114
Because there is not a NoData already defined, we will choose one. In this particular example, the minimum value is greater than zero, so we can use 0 as the NoData value.
124115

125116
```python, pick_nd
126-
blue_min = one_mask.agg(rf_agg_stats('blue').min.alias('blue_min'))
117+
blue_min = mask.agg(rf_agg_stats('blue').min.alias('blue_min'))
127118
blue_min
128119
```
129120

130121
We can now construct the cell type string for our blue band's cell type, designating 0 as NoData.
131122

132123
```python, get_ct_string
133-
blue_ct = one_mask.select(rf_cell_type('blue')).distinct().first()[0][0]
124+
blue_ct = mask.select(rf_cell_type('blue')).distinct().first()[0][0]
134125
masked_blue_ct = CellType(blue_ct).with_no_data_value(0)
135126
masked_blue_ct.cell_type_name
136127
```
@@ -139,9 +130,8 @@ Now we will use the @ref:[`rf_mask_by_value`](reference.md#rf-mask-by-value) to
139130

140131
```python, mask_blu
141132
with_nd = rf_convert_cell_type('blue', masked_blue_ct)
142-
masked = one_mask.withColumn('blue_masked',
143-
rf_mask_by_value(with_nd, 'mask', lit(1))) \
144-
.drop('nodata', 'defect', 'cloud8', 'cloud9', 'cirrus', 'blue')
133+
masked = mask.withColumn('blue_masked',
134+
rf_mask_by_value(with_nd, 'mask', lit(1)))
145135
```
146136

147137
We can verify that the number of NoData cells in the resulting `blue_masked` column matches the total of the boolean `mask` _tile_ to ensure our logic is correct.

pyrasterframes/src/main/python/docs/supervised-learning.pymd

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ catalog_df = pd.DataFrame([
3232
{b: uri_base.format(b) for b in cols}
3333
])
3434

35-
df = spark.read.raster(catalog_df, catalog_col_names=cols, tile_dimensions=(128, 128)) \
35+
tile_size = 256
36+
df = spark.read.raster(catalog_df, catalog_col_names=cols, tile_dimensions=(tile_size, tile_size)) \
3637
.repartition(100)
3738

3839
df = df.select(
@@ -91,23 +92,12 @@ To filter only for good quality pixels, we follow roughly the same procedure as
9192
```python, make_mask
9293
from pyspark.sql.functions import lit
9394

94-
mask_part = df_labeled \
95-
.withColumn('nodata', rf_local_equal('scl', lit(0))) \
96-
.withColumn('defect', rf_local_equal('scl', lit(1))) \
97-
.withColumn('cloud8', rf_local_equal('scl', lit(8))) \
98-
.withColumn('cloud9', rf_local_equal('scl', lit(9))) \
99-
.withColumn('cirrus', rf_local_equal('scl', lit(10)))
100-
101-
df_mask_inv = mask_part \
102-
.withColumn('mask', rf_local_add('nodata', 'defect')) \
103-
.withColumn('mask', rf_local_add('mask', 'cloud8')) \
104-
.withColumn('mask', rf_local_add('mask', 'cloud9')) \
105-
.withColumn('mask', rf_local_add('mask', 'cirrus')) \
106-
.drop('nodata', 'defect', 'cloud8', 'cloud9', 'cirrus')
107-
95+
df_labeled = df_labeled \
96+
.withColumn('mask', rf_local_is_in('scl', [0, 1, 8, 9, 10]))
97+
10898
# at this point the mask contains 0 for good cells and 1 for defect, etc
10999
# convert cell type and set value 1 to NoData
110-
df_mask = df_mask_inv.withColumn('mask',
100+
df_mask = df_labeled.withColumn('mask',
111101
rf_with_no_data(rf_convert_cell_type('mask', 'uint8'), 1.0)
112102
)
113103

@@ -204,29 +194,35 @@ scored = model.transform(df_mask.drop('label'))
204194
retiled = scored \
205195
.groupBy('extent', 'crs') \
206196
.agg(
207-
rf_assemble_tile('column_index', 'row_index', 'prediction', 128, 128).alias('prediction'),
208-
rf_assemble_tile('column_index', 'row_index', 'B04', 128, 128).alias('red'),
209-
rf_assemble_tile('column_index', 'row_index', 'B03', 128, 128).alias('grn'),
210-
rf_assemble_tile('column_index', 'row_index', 'B02', 128, 128).alias('blu')
197+
rf_assemble_tile('column_index', 'row_index', 'prediction', tile_size, tile_size).alias('prediction'),
198+
rf_assemble_tile('column_index', 'row_index', 'B04', tile_size, tile_size).alias('red'),
199+
rf_assemble_tile('column_index', 'row_index', 'B03', tile_size, tile_size).alias('grn'),
200+
rf_assemble_tile('column_index', 'row_index', 'B02', tile_size, tile_size).alias('blu')
211201
)
212202
retiled.printSchema()
213203
```
214204

215205
Take a look at a sample of the resulting output and the corresponding area's red-green-blue composite image.
206+
Recall the label coding: 1 is forest (purple), 2 is cropland (green) and 3 is developed areas (yellow).
216207

217208
```python, display_rgb
218209
sample = retiled \
219-
.select('prediction', rf_rgb_composite('red', 'grn', 'blu').alias('rgb')) \
210+
.select('prediction', 'red', 'grn', 'blu') \
220211
.sort(-rf_tile_sum(rf_local_equal('prediction', lit(3.0)))) \
221212
.first()
222213

223-
sample_rgb = sample['rgb']
224-
mins = np.nanmin(sample_rgb.cells, axis=(0,1))
225-
plt.imshow((sample_rgb.cells - mins) / (np.nanmax(sample_rgb.cells, axis=(0,1)) - mins))
226-
```
214+
sample_rgb = np.concatenate([sample['red'].cells[:, :, None],
215+
sample['grn'].cells[ :, :, None],
216+
sample['blu'].cells[ :, :, None]], axis=2)
217+
# plot scaled RGB
218+
scaling_quantiles = np.nanpercentile(sample_rgb, [3.00, 97.00], axis=(0,1))
219+
scaled = np.clip(sample_rgb, scaling_quantiles[0, :], scaling_quantiles[1, :])
220+
scaled -= scaling_quantiles[0, :]
221+
scaled /= (scaling_quantiles[1, : ] - scaling_quantiles[0, :])
227222

228-
Recall the label coding: 1 is forest (purple), 2 is cropland (green) and 3 is developed areas(yellow).
223+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
224+
ax1.imshow(scaled)
229225

230-
```python, display_prediction
231-
display(sample['prediction'])
226+
# display prediction
227+
ax2.imshow(sample['prediction'].cells)
232228
```

pyrasterframes/src/main/python/pyrasterframes/rasterfunctions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,14 +260,24 @@ def rf_local_unequal_int(tile_col, scalar):
260260
"""Return a Tile with values equal 1 if the cell is not equal to a scalar, otherwise 0"""
261261
return _apply_scalar_to_tile('rf_local_unequal_int', tile_col, scalar)
262262

263+
263264
def rf_local_no_data(tile_col):
264265
"""Return a tile with ones where the input is NoData, otherwise zero."""
265266
return _apply_column_function('rf_local_no_data', tile_col)
266267

268+
267269
def rf_local_data(tile_col):
268270
"""Return a tile with zeros where the input is NoData, otherwise one."""
269271
return _apply_column_function('rf_local_data', tile_col)
270272

273+
def rf_local_is_in(tile_col, array):
274+
"""Return a tile with cell values of 1 where the `tile_col` cell is in the provided array."""
275+
from pyspark.sql.functions import array as sql_array, lit
276+
if isinstance(array, list):
277+
array = sql_array([lit(v) for v in array])
278+
279+
return _apply_column_function('rf_local_is_in', tile_col, array)
280+
271281
def _apply_column_function(name, *args):
272282
jfcn = RFContext.active().lookup(name)
273283
jcols = [_to_java_column(arg) for arg in args]

pyrasterframes/src/main/python/tests/PyRasterFramesTests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def test_tile_udt_serialization(self):
131131
cells[1][1] = nd
132132
a_tile = Tile(cells, ct.with_no_data_value(nd))
133133
round_trip = udt.fromInternal(udt.toInternal(a_tile))
134-
self.assertEquals(a_tile, round_trip, "round-trip serialization for " + str(ct))
134+
self.assertEqual(a_tile, round_trip, "round-trip serialization for " + str(ct))
135135

136136
schema = StructType([StructField("tile", TileUDT(), False)])
137137
df = self.spark.createDataFrame([{"tile": a_tile}], schema)

0 commit comments

Comments
 (0)