Skip to content

Commit 8dc93f7

Browse files
committed
Merge commit 'b39640034767c16a9b5f422d0fbdf9b0300707e5' into fix/357
* commit 'b39640034767c16a9b5f422d0fbdf9b0300707e5': break out commented assert into skipped unit test around masking and deserialization register rf_local_extract_bit with SQL functions Add landsat masking section to masking docs page Add mask bits python api and unit test Extract bits should throw on non-integral cell types Fix for both masking by def and value; expand code comments; update tests Masking improvements and unit tests. Add failing unit test for mask by value on 0 Update docs and scala function api WIP: mask bits Initial implementation of extracting bit values, e.g. for a quality band
2 parents f4d3074 + b396400 commit 8dc93f7

File tree

11 files changed

+736
-102
lines changed

11 files changed

+736
-102
lines changed

core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ trait RasterFunctions {
319319

320320
/** Generate a tile with the values from `data_tile`, but where cells in the `mask_tile` are in the `mask_values`
321321
list, replace the value with NODATA. */
322-
def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Seq[Int]): TypedColumn[Any, Tile] = {
322+
def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Int*): TypedColumn[Any, Tile] = {
323323
import org.apache.spark.sql.functions.array
324324
val valuesCol: Column = array(maskValues.map(lit).toSeq: _*)
325325
rf_mask_by_values(sourceTile, maskTile, valuesCol)
@@ -337,6 +337,52 @@ trait RasterFunctions {
337337
def rf_inverse_mask_by_value(sourceTile: Column, maskTile: Column, maskValue: Int): TypedColumn[Any, Tile] = {
  // Wrap the scalar mask value with `lit` and hand off to `Mask.InverseMaskByValue`.
  val maskValueCol = lit(maskValue)
  Mask.InverseMaskByValue(sourceTile, maskTile, maskValueCol)
}
339339

340+
/** Masks `dataTile` using a single bit of each `maskTile` cell.
  *
  * The bit at `bitPosition`, counted from the right, is extracted from `maskTile`.
  * Wherever that bit equals `valueToMask` the result is NoData; everywhere else the
  * original `dataTile` cell value is kept.
  *
  * @param dataTile    tile column holding the data values
  * @param maskTile    tile column holding the bit-packed mask
  * @param bitPosition position of the bit to test, working from the right
  * @param valueToMask bit state to mask: `true` is passed on as 1, `false` as 0
  */
def rf_mask_by_bit(dataTile: Column, maskTile: Column, bitPosition: Int, valueToMask: Boolean): TypedColumn[Any, Tile] = {
  val bitValue = if (valueToMask) 1 else 0
  rf_mask_by_bit(dataTile, maskTile, lit(bitPosition), lit(bitValue))
}
343+
344+
/** Column-argument form of the single-bit mask.
  *
  * Delegates to `rf_mask_by_bits` with a bit width of one and a one-element array
  * holding `valueToMask`.
  *
  * @param dataTile    tile column holding the data values
  * @param maskTile    tile column holding the bit-packed mask
  * @param bitPosition column giving the bit to test, working from the right
  * @param valueToMask column giving the bit value that marks cells to set to NoData
  */
def rf_mask_by_bit(dataTile: Column, maskTile: Column, bitPosition: Column, valueToMask: Column): TypedColumn[Any, Tile] = {
  import org.apache.spark.sql.functions.array
  val singleBit = lit(1)
  val maskedValues = array(valueToMask)
  rf_mask_by_bits(dataTile, maskTile, bitPosition, singleBit, maskedValues)
}
349+
350+
/** Masks `dataTile` wherever a bit field of `maskTile` holds a listed value.
  *
  * The bit field is obtained with `rf_local_extract_bits(maskTile, startBit, numBits)`,
  * and the resulting tile is passed to `rf_mask_by_values` together with `valuesToMask`.
  *
  * @param dataTile     tile column holding the data values
  * @param maskTile     tile column holding the bit-packed mask
  * @param startBit     column giving the first bit of the field, working from the right
  * @param numBits      column giving the width of the field in bits
  * @param valuesToMask column of field values that mark cells to set to NoData
  */
def rf_mask_by_bits(dataTile: Column, maskTile: Column, startBit: Column, numBits: Column, valuesToMask: Column): TypedColumn[Any, Tile] =
  rf_mask_by_values(dataTile, rf_local_extract_bits(maskTile, startBit, numBits), valuesToMask)
355+
356+
357+
/** Scalar-argument form of the bit-field mask.
  *
  * Each argument is wrapped with `lit` (the values to mask are collected into an
  * array column) and the call is forwarded to the all-`Column` overload.
  *
  * @param dataTile     tile column holding the data values
  * @param maskTile     tile column holding the bit-packed mask
  * @param startBit     first bit of the field, working from the right
  * @param numBits      width of the field in bits
  * @param valuesToMask field values that mark cells to set to NoData
  */
def rf_mask_by_bits(dataTile: Column, maskTile: Column, startBit: Int, numBits: Int, valuesToMask: Int*): TypedColumn[Any, Tile] = {
  import org.apache.spark.sql.functions.array
  val literals = valuesToMask.map(v => lit(v))
  rf_mask_by_bits(dataTile, maskTile, lit(startBit), lit(numBits), array(literals: _*))
}
363+
364+
/** Pulls a bit field out of the binary representation of every cell in `tile`.
  *
  * @param tile     tile column to read from
  * @param startBit zero-indexed position of the first bit, working from the right
  * @param numBits  number of bits to take, moving further to the left
  */
def rf_local_extract_bits(tile: Column, startBit: Column, numBits: Column): Column = {
  ExtractBits(tile, startBit, numBits)
}
369+
370+
/** Pulls a single bit out of the binary representation of every cell in `tile`.
  *
  * @param tile        tile column to read from
  * @param bitPosition zero-indexed position of the bit, working from the right
  */
def rf_local_extract_bits(tile: Column, bitPosition: Column): Column = {
  val oneBit = lit(1)
  rf_local_extract_bits(tile, bitPosition, oneBit)
}
374+
375+
/** Pulls a bit field out of the binary representation of every cell in `tile`,
  * taking the field bounds as plain integers.
  *
  * @param tile     tile column to read from
  * @param startBit zero-indexed position of the first bit, working from the right
  * @param numBits  number of bits to take, moving further to the left
  */
def rf_local_extract_bits(tile: Column, startBit: Int, numBits: Int): Column = {
  val startCol = lit(startBit)
  val widthCol = lit(numBits)
  rf_local_extract_bits(tile, startCol, widthCol)
}
380+
381+
/** Pulls a single bit out of the binary representation of every cell in `tile`,
  * taking the bit position as a plain integer.
  *
  * @param tile        tile column to read from
  * @param bitPosition zero-indexed position of the bit, working from the right
  */
def rf_local_extract_bits(tile: Column, bitPosition: Int): Column = {
  val positionCol = lit(bitPosition)
  rf_local_extract_bits(tile, positionCol)
}
385+
340386
/** Create a tile where cells in the grid defined by cols, rows, and bounds are filled with the given value. */
341387
def rf_rasterize(geometry: Column, bounds: Column, value: Column, cols: Int, rows: Int): TypedColumn[Any, Tile] =
342388
withTypedAlias("rf_rasterize", geometry)(

core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,5 +138,8 @@ package object expressions {
138138
registry.registerExpression[XZ2Indexer]("rf_spatial_index")
139139

140140
registry.registerExpression[transformers.ReprojectGeometry]("st_reproject")
141+
142+
registry.registerExpression[ExtractBits]("rf_local_extract_bits")
143+
registry.registerExpression[ExtractBits]("rf_local_extract_bit")
141144
}
142145
}
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/*
2+
* This software is licensed under the Apache 2 license, quoted below.
3+
*
4+
* Copyright 2019 Astraea, Inc.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
7+
* use this file except in compliance with the License. You may obtain a copy of
8+
* the License at
9+
*
10+
* [http://www.apache.org/licenses/LICENSE-2.0]
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15+
* License for the specific language governing permissions and limitations under
16+
* the License.
17+
*
18+
* SPDX-License-Identifier: Apache-2.0
19+
*
20+
*/
21+
22+
package org.locationtech.rasterframes.expressions.transformers
23+
24+
import geotrellis.raster.Tile
25+
import org.apache.spark.sql.{Column, TypedColumn}
26+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
27+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
28+
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
29+
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, TernaryExpression}
30+
import org.apache.spark.sql.rf.TileUDT
31+
import org.apache.spark.sql.types.DataType
32+
import org.locationtech.rasterframes.encoders.CatalystSerializer._
33+
import org.locationtech.rasterframes.expressions.DynamicExtractors._
34+
import org.locationtech.rasterframes.expressions._
35+
36+
@ExpressionDescription(
  usage = "_FUNC_(tile, start_bit, num_bits) - In each cell of `tile`, extract `num_bits` from the cell value, starting at `start_bit` from the right.",
  arguments = """
  Arguments:
    * tile - tile column to extract values
    * start_bit - zero-indexed position of the first bit to consider, working from the right
    * num_bits - number of bits to take, moving further to the left
  """,
  examples = """
  Examples:
    > SELECT _FUNC_(tile, lit(4), lit(2))
       ..."""
)
/** Catalyst expression that extracts a bit field from every cell of a tile.
  *
  * @param child1 expression producing the tile (any type accepted by `tileExtractor`)
  * @param child2 expression producing the start bit (any type accepted by `intArgExtractor`)
  * @param child3 expression producing the number of bits (any type accepted by `intArgExtractor`)
  */
case class ExtractBits(child1: Expression, child2: Expression, child3: Expression) extends TernaryExpression with CodegenFallback with Serializable {
  override val nodeName: String = "rf_local_extract_bits"

  override def children: Seq[Expression] = Seq(child1, child2, child3)

  // The result has the same Catalyst type as the incoming tile expression.
  override def dataType: DataType = child1.dataType

  /** Accepts a raster-typed first argument and integer-extractable second and third arguments. */
  override def checkInputDataTypes(): TypeCheckResult =
    if (!tileExtractor.isDefinedAt(child1.dataType)) {
      TypeCheckFailure(s"Input type '${child1.dataType}' does not conform to a raster type.")
    } else if (!intArgExtractor.isDefinedAt(child2.dataType)) {
      TypeCheckFailure(s"Input type '${child2.dataType}' isn't an integral type.")
    } else if (!intArgExtractor.isDefinedAt(child3.dataType)) {
      TypeCheckFailure(s"Input type '${child3.dataType}' isn't an integral type.")
    } else TypeCheckSuccess

  /** Decodes the three inputs, applies `op`, and re-encodes the resulting tile.
    *
    * When the input tile carried a context, the result is wrapped back into a
    * project raster tile before serialization.
    */
  override protected def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = {
    implicit val tileSer = TileUDT.tileSerializer
    val (childTile, childCtx) = tileExtractor(child1.dataType)(row(input1))

    val startBit = intArgExtractor(child2.dataType)(input2).value

    // Use the extractor matching the third child's own type; the second and third
    // arguments are type-checked independently and need not share a data type.
    val numBits = intArgExtractor(child3.dataType)(input3).value

    val extracted = op(childTile, startBit, numBits)

    childCtx match {
      case Some(ctx) => ctx.toProjectRasterTile(extracted).toInternalRow
      case None => extracted.toInternalRow
    }
  }

  /** Tile-level operation; delegates to the companion's `apply(Tile, Int, Int)`. */
  protected def op(tile: Tile, startBit: Int, numBits: Int): Tile = ExtractBits(tile, startBit, numBits)

}
83+
84+
object ExtractBits {
  /** Builds a `Column` wrapping an `ExtractBits` expression over the given columns. */
  def apply(tile: Column, startBit: Column, numBits: Column): Column =
    new Column(ExtractBits(tile.expr, startBit.expr, numBits.expr))

  /** Extracts `numBits` bits, starting at `startBit` counted from the right, from each set cell.
    *
    * @param tile     tile with a non-floating-point cell type
    * @param startBit zero-indexed position of the first bit, working from the right
    * @param numBits  width of the field; only meaningful in the range 0..31 because
    *                 the JVM uses just the low five bits of an `Int` shift count
    * @return a tile of the extracted values, produced with `mapIfSet`
    */
  def apply(tile: Tile, startBit: Int, numBits: Int): Tile = {
    assert(!tile.cellType.isFloatingPoint, "ExtractBits operation requires integral CellType")
    // Mask with the lowest `numBits` bits set. This is bit-for-bit equal to the former
    // `Int.MaxValue >> (63 - numBits)`, which only worked because Int shift counts wrap at 32.
    val widthMask = (1 << numBits) - 1
    // map preserving the nodata structure
    tile.mapIfSet(x => (x >> startBit) & widthMask)
  }

}

core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Mask.scala

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ package org.locationtech.rasterframes.expressions.transformers
2424
import com.typesafe.scalalogging.Logger
2525
import geotrellis.raster
2626
import geotrellis.raster.Tile
27-
import geotrellis.raster.mapalgebra.local.{Defined, InverseMask => gtInverseMask, Mask => gtMask}
27+
import geotrellis.raster.mapalgebra.local.{Undefined, InverseMask => gtInverseMask, Mask => gtMask}
2828
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2929
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
3030
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
@@ -38,7 +38,15 @@ import org.locationtech.rasterframes.expressions.localops.IsIn
3838
import org.locationtech.rasterframes.expressions.row
3939
import org.slf4j.LoggerFactory
4040

41-
abstract class Mask(val left: Expression, val middle: Expression, val right: Expression, inverse: Boolean)
41+
/** Convert cells in the `left` to NoData based on another tile's contents
42+
*
43+
* @param left a tile of data values, with valid nodata cell type
44+
* @param middle a tile indicating locations to set to nodata
45+
* @param right optional, cell values in the `middle` tile indicating locations to set NoData
46+
* @param undefined if true, consider NoData in the `middle` as the locations to mask; else use `right` valued cells
47+
* @param inverse if true, and `undefined` is true, set `left` to NoData where `middle` is NOT nodata
48+
*/
49+
abstract class Mask(val left: Expression, val middle: Expression, val right: Expression, undefined: Boolean, inverse: Boolean)
4250
extends TernaryExpression with CodegenFallback with Serializable {
4351
// aliases.
4452
def targetExp = left
@@ -77,13 +85,16 @@ abstract class Mask(val left: Expression, val middle: Expression, val right: Exp
7785

7886
val maskValue = intArgExtractor(maskValueExp.dataType)(maskValueInput)
7987

80-
val masking = if (maskValue.value == 0) Defined(maskTile)
81-
else maskTile
88+
// Get a tile where values of 1 indicate locations to set to ND in the target tile
89+
// When `undefined` is true, setting targetTile locations to ND for ND locations of the `maskTile`
90+
val masking = if (undefined) Undefined(maskTile)
91+
else maskTile.localEqual(maskValue.value) // Otherwise if `maskTile` locations equal `maskValue`, set location to ND
8292

93+
// apply the `masking` where values are 1 set to ND (possibly inverted!)
8394
val result = if (inverse)
84-
gtInverseMask(targetTile, masking, maskValue.value, raster.NODATA)
95+
gtInverseMask(targetTile, masking, 1, raster.NODATA)
8596
else
86-
gtMask(targetTile, masking, maskValue.value, raster.NODATA)
97+
gtMask(targetTile, masking, 1, raster.NODATA)
8798

8899
targetCtx match {
89100
case Some(ctx) => ctx.toProjectRasterTile(result).toInternalRow
@@ -106,7 +117,7 @@ object Mask {
106117
..."""
107118
)
108119
case class MaskByDefined(target: Expression, mask: Expression)
  extends Mask(target, mask, Literal(0), undefined = true, inverse = false) {
  // Name under which this expression node is reported.
  override def nodeName: String = "rf_mask"
}
112123
object MaskByDefined {
@@ -126,7 +137,7 @@ object Mask {
126137
..."""
127138
)
128139
case class InverseMaskByDefined(leftTile: Expression, rightTile: Expression)
  extends Mask(leftTile, rightTile, Literal(0), undefined = true, inverse = true) {
  // Name under which this expression node is reported.
  override def nodeName: String = "rf_inverse_mask"
}
132143
object InverseMaskByDefined {
@@ -146,7 +157,7 @@ object Mask {
146157
..."""
147158
)
148159
case class MaskByValue(leftTile: Expression, rightTile: Expression, maskValue: Expression)
  extends Mask(leftTile, rightTile, maskValue, undefined = false, inverse = false) {
  // Name under which this expression node is reported.
  override def nodeName: String = "rf_mask_by_value"
}
152163
object MaskByValue {
@@ -168,7 +179,7 @@ object Mask {
168179
..."""
169180
)
170181
case class InverseMaskByValue(leftTile: Expression, rightTile: Expression, maskValue: Expression)
  extends Mask(leftTile, rightTile, maskValue, undefined = false, inverse = true) {
  // Name under which this expression node is reported.
  override def nodeName: String = "rf_inverse_mask_by_value"
}
174185
object InverseMaskByValue {
@@ -190,7 +201,7 @@ object Mask {
190201
..."""
191202
)
192203
case class MaskByValues(dataTile: Expression, maskTile: Expression)
193-
extends Mask(dataTile, maskTile, Literal(1), inverse = false) {
204+
extends Mask(dataTile, maskTile, Literal(1), false, false) {
194205
def this(dataTile: Expression, maskTile: Expression, maskValues: Expression) =
195206
this(dataTile, IsIn(maskTile, maskValues))
196207
override def nodeName: String = "rf_mask_by_values"

0 commit comments

Comments
 (0)