Skip to content

Commit d7b3a67

Browse files
authored
Merge branch 'develop' into feature/raster-spatial-index
2 parents 047a63d + a5ed5ed commit d7b3a67

File tree

13 files changed

+761
-108
lines changed

13 files changed

+761
-108
lines changed

core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ trait RasterFunctions {
339339

340340
/** Generate a tile with the values from `data_tile`, but where cells in the `mask_tile` are in the `mask_values`
341341
list, replace the value with NODATA. */
342-
def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Seq[Int]): TypedColumn[Any, Tile] = {
342+
def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Int*): TypedColumn[Any, Tile] = {
343343
import org.apache.spark.sql.functions.array
344344
val valuesCol: Column = array(maskValues.map(lit).toSeq: _*)
345345
rf_mask_by_values(sourceTile, maskTile, valuesCol)
@@ -357,6 +357,52 @@ trait RasterFunctions {
357357
def rf_inverse_mask_by_value(sourceTile: Column, maskTile: Column, maskValue: Int): TypedColumn[Any, Tile] =
358358
Mask.InverseMaskByValue(sourceTile, maskTile, lit(maskValue))
359359

360+
/**
 * Applies a mask using a single bit of each cell in `maskTile`.
 *
 * Working from the right, the bit at `bitPosition` is extracted from the `maskTile`.
 * Wherever that bit equals `valueToMask`, the returned tile is set to NoData;
 * elsewhere the original `dataTile` cell value is kept.
 */
def rf_mask_by_bit(dataTile: Column, maskTile: Column, bitPosition: Int, valueToMask: Boolean): TypedColumn[Any, Tile] = {
  // The column-based overload compares against an integer bit value, so encode the flag as 1 or 0.
  val bitValue = if (valueToMask) 1 else 0
  rf_mask_by_bit(dataTile, maskTile, lit(bitPosition), lit(bitValue))
}
363+
364+
/**
 * Applies a mask using a single bit of each cell in `maskTile` (column-argument form).
 *
 * Working from the right, the bit at `bitPosition` is extracted from the `maskTile`.
 * Wherever that bit equals `valueToMask`, the returned tile is set to NoData;
 * elsewhere the original `dataTile` cell value is kept.
 */
def rf_mask_by_bit(dataTile: Column, maskTile: Column, bitPosition: Column, valueToMask: Column): TypedColumn[Any, Tile] = {
  import org.apache.spark.sql.functions.array
  // A single-bit mask is a one-bit-wide extraction with a one-element list of values to mask.
  val oneBit = lit(1)
  val maskedValues = array(valueToMask)
  rf_mask_by_bits(dataTile, maskTile, bitPosition, oneBit, maskedValues)
}
369+
370+
/**
 * Applies a mask from blacklisted bit values in the `maskTile`.
 *
 * Working from the right, `numBits` bits beginning at `startBit` are extracted from each
 * cell of the `maskTile` (see `rf_local_extract_bits`). Wherever the extracted value is one
 * of `valuesToMask`, the returned tile is set to NoData; otherwise the original `dataTile`
 * cell value is returned.
 */
def rf_mask_by_bits(dataTile: Column, maskTile: Column, startBit: Column, numBits: Column, valuesToMask: Column): TypedColumn[Any, Tile] =
  rf_mask_by_values(dataTile, rf_local_extract_bits(maskTile, startBit, numBits), valuesToMask)
375+
376+
377+
/**
 * Applies a mask from blacklisted bit values in the `maskTile` (literal-argument form).
 *
 * Working from the right, `numBits` bits beginning at `startBit` are extracted from each
 * cell of the `maskTile` (see `rf_local_extract_bits`). Wherever the extracted value is one
 * of `valuesToMask`, the returned tile is set to NoData; otherwise the original `dataTile`
 * cell value is returned.
 */
def rf_mask_by_bits(dataTile: Column, maskTile: Column, startBit: Int, numBits: Int, valuesToMask: Int*): TypedColumn[Any, Tile] = {
  import org.apache.spark.sql.functions.array
  // Lift every plain Int into a literal column, then pack them into one array column.
  val literalValues = valuesToMask.map(v => lit(v))
  rf_mask_by_bits(dataTile, maskTile, lit(startBit), lit(numBits), array(literalValues: _*))
}
383+
384+
/**
 * Extract a value from the specified bits of each cell's underlying binary data.
 *
 * @param tile     tile column to read cell values from
 * @param startBit first bit to consider, working from the right; zero indexed
 * @param numBits  number of bits to take, moving further to the left
 */
def rf_local_extract_bits(tile: Column, startBit: Column, numBits: Column): Column = {
  val extracted: Column = ExtractBits(tile, startBit, numBits)
  extracted
}
389+
390+
/**
 * Extract a single bit from each cell's underlying binary data.
 *
 * @param tile        tile column to read cell values from
 * @param bitPosition the bit to consider, working from the right; zero indexed
 */
def rf_local_extract_bits(tile: Column, bitPosition: Column): Column = {
  // One bit wide: delegate to the general form with a width of 1.
  val oneBit = lit(1)
  rf_local_extract_bits(tile, bitPosition, oneBit)
}
394+
395+
/**
 * Extract a value from the specified bits of each cell's underlying binary data
 * (literal-argument form).
 *
 * @param tile     tile column to read cell values from
 * @param startBit first bit to consider, working from the right; zero indexed
 * @param numBits  number of bits to take, moving further to the left
 */
def rf_local_extract_bits(tile: Column, startBit: Int, numBits: Int): Column = {
  val startCol = lit(startBit)
  val widthCol = lit(numBits)
  rf_local_extract_bits(tile, startCol, widthCol)
}
400+
401+
/**
 * Extract a single bit from each cell's underlying binary data (literal-argument form).
 *
 * @param tile        tile column to read cell values from
 * @param bitPosition the bit to consider, working from the right; zero indexed
 */
def rf_local_extract_bits(tile: Column, bitPosition: Int): Column = {
  val positionCol = lit(bitPosition)
  rf_local_extract_bits(tile, positionCol)
}
405+
360406
/** Create a tile where cells in the grid defined by cols, rows, and bounds are filled with the given value. */
361407
def rf_rasterize(geometry: Column, bounds: Column, value: Column, cols: Int, rows: Int): TypedColumn[Any, Tile] =
362408
withTypedAlias("rf_rasterize", geometry)(

core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,5 +139,8 @@ package object expressions {
139139
registry.registerExpression[Z2Indexer]("rf_z2_index")
140140

141141
registry.registerExpression[transformers.ReprojectGeometry]("st_reproject")
142+
143+
registry.registerExpression[ExtractBits]("rf_local_extract_bits")
144+
registry.registerExpression[ExtractBits]("rf_local_extract_bit")
142145
}
143146
}
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/*
2+
* This software is licensed under the Apache 2 license, quoted below.
3+
*
4+
* Copyright 2019 Astraea, Inc.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
7+
* use this file except in compliance with the License. You may obtain a copy of
8+
* the License at
9+
*
10+
* [http://www.apache.org/licenses/LICENSE-2.0]
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15+
* License for the specific language governing permissions and limitations under
16+
* the License.
17+
*
18+
* SPDX-License-Identifier: Apache-2.0
19+
*
20+
*/
21+
22+
package org.locationtech.rasterframes.expressions.transformers
23+
24+
import geotrellis.raster.Tile
25+
import org.apache.spark.sql.{Column, TypedColumn}
26+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
27+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
28+
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
29+
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, TernaryExpression}
30+
import org.apache.spark.sql.rf.TileUDT
31+
import org.apache.spark.sql.types.DataType
32+
import org.locationtech.rasterframes.encoders.CatalystSerializer._
33+
import org.locationtech.rasterframes.expressions.DynamicExtractors._
34+
import org.locationtech.rasterframes.expressions._
35+
36+
/**
 * Catalyst expression that, for every cell of a tile, extracts `num_bits` bits of the
 * cell value beginning at `start_bit` (zero indexed, counted from the right).
 *
 * @param child1 expression producing the tile (any type accepted by `tileExtractor`)
 * @param child2 expression producing the start bit (any type accepted by `intArgExtractor`)
 * @param child3 expression producing the number of bits (any type accepted by `intArgExtractor`)
 */
@ExpressionDescription(
  usage = "_FUNC_(tile, start_bit, num_bits) - In each cell of `tile`, extract `num_bits` from the cell value, starting at `start_bit` from the right.",
  arguments = """
  Arguments:
    * tile - tile column to extract values
    * start_bit - zero-indexed position of the first bit to consider, counting from the right
    * num_bits - number of bits to take, moving further to the left
  """,
  examples = """
  Examples:
    > SELECT _FUNC_(tile, lit(4), lit(2))
       ..."""
)
case class ExtractBits(child1: Expression, child2: Expression, child3: Expression) extends TernaryExpression with CodegenFallback with Serializable {
  override val nodeName: String = "rf_local_extract_bits"

  override def children: Seq[Expression] = Seq(child1, child2, child3)

  // The result has the same Catalyst type as the incoming tile expression.
  override def dataType: DataType = child1.dataType

  /** Requires a raster-typed first argument and integer-extractable second and third arguments. */
  override def checkInputDataTypes(): TypeCheckResult =
    if (!tileExtractor.isDefinedAt(child1.dataType)) {
      TypeCheckFailure(s"Input type '${child1.dataType}' does not conform to a raster type.")
    } else if (!intArgExtractor.isDefinedAt(child2.dataType)) {
      TypeCheckFailure(s"Input type '${child2.dataType}' isn't an integral type.")
    } else if (!intArgExtractor.isDefinedAt(child3.dataType)) {
      TypeCheckFailure(s"Input type '${child3.dataType}' isn't an integral type.")
    } else TypeCheckSuccess

  /**
   * Decodes the three (non-null) inputs, applies the bit extraction and re-encodes the result.
   *
   * @param input1 encoded tile value
   * @param input2 start bit value
   * @param input3 number-of-bits value
   */
  override protected def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = {
    implicit val tileSer = TileUDT.tileSerializer
    val (childTile, childCtx) = tileExtractor(child1.dataType)(row(input1))

    val startBit = intArgExtractor(child2.dataType)(input2).value

    // Fix: decode the third input with the third child's type. The original used
    // `child2.dataType`, which is wrong whenever the two integer arguments differ in type.
    val numBits = intArgExtractor(child3.dataType)(input3).value

    // When the tile extractor supplied a context, re-attach it via `toProjectRasterTile`.
    childCtx match {
      case Some(ctx) => ctx.toProjectRasterTile(op(childTile, startBit, numBits)).toInternalRow
      case None => op(childTile, startBit, numBits).toInternalRow
    }
  }

  /** Tile-level operation; delegates to the companion's `apply(Tile, Int, Int)`. */
  protected def op(tile: Tile, startBit: Int, numBits: Int): Tile = ExtractBits(tile, startBit, numBits)
}
83+
84+
/** Column-level and tile-level entry points for the `ExtractBits` expression. */
object ExtractBits{
  /** Wraps the three columns' expressions in an `ExtractBits` expression column. */
  def apply(tile: Column, startBit: Column, numBits: Column): Column =
    new Column(ExtractBits(tile.expr, startBit.expr, numBits.expr))

  /**
   * For each cell handled by `mapIfSet`, shifts the value right by `startBit` and keeps
   * the lowest `numBits` bits.
   *
   * @param tile     tile with an integral cell type; a floating point cell type fails the assertion
   * @param startBit zero-indexed position of the first bit to keep, counting from the right
   * @param numBits  number of bits to keep, moving further to the left
   */
  def apply(tile: Tile, startBit: Int, numBits: Int): Tile = {
    assert(!tile.cellType.isFloatingPoint, "ExtractBits operation requires integral CellType")
    // A run of `numBits` one-bits in the lowest positions. This yields the same value as the
    // previous `Int.MaxValue >> (63 - numBits)`, which only worked because Int shift counts
    // are taken modulo 32.
    val widthMask = (1 << numBits) - 1
    // map preserving the nodata structure
    tile.mapIfSet(x => (x >> startBit) & widthMask)
  }
}

core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Mask.scala

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ package org.locationtech.rasterframes.expressions.transformers
2424
import com.typesafe.scalalogging.Logger
2525
import geotrellis.raster
2626
import geotrellis.raster.Tile
27-
import geotrellis.raster.mapalgebra.local.{Defined, InverseMask => gtInverseMask, Mask => gtMask}
27+
import geotrellis.raster.mapalgebra.local.{Undefined, InverseMask => gtInverseMask, Mask => gtMask}
2828
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2929
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
3030
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
@@ -38,7 +38,15 @@ import org.locationtech.rasterframes.expressions.localops.IsIn
3838
import org.locationtech.rasterframes.expressions.row
3939
import org.slf4j.LoggerFactory
4040

41-
abstract class Mask(val left: Expression, val middle: Expression, val right: Expression, inverse: Boolean)
41+
/** Convert cells in the `left` to NoData based on another tile's contents
42+
*
43+
* @param left a tile of data values, with valid nodata cell type
44+
* @param middle a tile indicating locations to set to nodata
45+
* @param right optional, cell values in the `middle` tile indicating locations to set NoData
46+
* @param undefined if true, consider NoData in the `middle` as the locations to mask; else use `right` valued cells
47+
* @param inverse if true, invert the selection: set `left` to NoData where `middle` is NOT NoData (when `undefined` is true) or does NOT equal `right` (when `undefined` is false)
48+
*/
49+
abstract class Mask(val left: Expression, val middle: Expression, val right: Expression, undefined: Boolean, inverse: Boolean)
4250
extends TernaryExpression with CodegenFallback with Serializable {
4351
// aliases.
4452
def targetExp = left
@@ -77,13 +85,16 @@ abstract class Mask(val left: Expression, val middle: Expression, val right: Exp
7785

7886
val maskValue = intArgExtractor(maskValueExp.dataType)(maskValueInput)
7987

80-
val masking = if (maskValue.value == 0) Defined(maskTile)
81-
else maskTile
88+
// Get a tile where values of 1 indicate locations to set to ND in the target tile
89+
// When `undefined` is true, setting targetTile locations to ND for ND locations of the `maskTile`
90+
val masking = if (undefined) Undefined(maskTile)
91+
else maskTile.localEqual(maskValue.value) // Otherwise if `maskTile` locations equal `maskValue`, set location to ND
8292

93+
// apply the `masking` where values are 1 set to ND (possibly inverted!)
8394
val result = if (inverse)
84-
gtInverseMask(targetTile, masking, maskValue.value, raster.NODATA)
95+
gtInverseMask(targetTile, masking, 1, raster.NODATA)
8596
else
86-
gtMask(targetTile, masking, maskValue.value, raster.NODATA)
97+
gtMask(targetTile, masking, 1, raster.NODATA)
8798

8899
targetCtx match {
89100
case Some(ctx) => ctx.toProjectRasterTile(result).toInternalRow
@@ -106,7 +117,7 @@ object Mask {
106117
..."""
107118
)
108119
// Sets cells of `target` to NoData wherever `mask` is NoData. The `Literal(0)` only
// fills the mask-value slot; the value is not consulted when `undefined` is true.
case class MaskByDefined(target: Expression, mask: Expression)
  extends Mask(target, mask, Literal(0), undefined = true, inverse = false) {
  override def nodeName: String = "rf_mask"
}
112123
object MaskByDefined {
@@ -126,7 +137,7 @@ object Mask {
126137
..."""
127138
)
128139
// Sets cells of `leftTile` to NoData wherever `rightTile` is NOT NoData. The `Literal(0)`
// only fills the mask-value slot; the value is not consulted when `undefined` is true.
case class InverseMaskByDefined(leftTile: Expression, rightTile: Expression)
  extends Mask(leftTile, rightTile, Literal(0), undefined = true, inverse = true) {
  override def nodeName: String = "rf_inverse_mask"
}
132143
object InverseMaskByDefined {
@@ -146,7 +157,7 @@ object Mask {
146157
..."""
147158
)
148159
// Sets cells of `leftTile` to NoData wherever `rightTile` equals `maskValue`.
case class MaskByValue(leftTile: Expression, rightTile: Expression, maskValue: Expression)
  extends Mask(leftTile, rightTile, maskValue, undefined = false, inverse = false) {
  override def nodeName: String = "rf_mask_by_value"
}
152163
object MaskByValue {
@@ -168,7 +179,7 @@ object Mask {
168179
..."""
169180
)
170181
// Sets cells of `leftTile` to NoData wherever `rightTile` does NOT equal `maskValue`.
case class InverseMaskByValue(leftTile: Expression, rightTile: Expression, maskValue: Expression)
  extends Mask(leftTile, rightTile, maskValue, undefined = false, inverse = true) {
  override def nodeName: String = "rf_inverse_mask_by_value"
}
174185
object InverseMaskByValue {
@@ -190,7 +201,7 @@ object Mask {
190201
..."""
191202
)
192203
case class MaskByValues(dataTile: Expression, maskTile: Expression)
193-
extends Mask(dataTile, maskTile, Literal(1), inverse = false) {
204+
extends Mask(dataTile, maskTile, Literal(1), false, false) {
194205
def this(dataTile: Expression, maskTile: Expression, maskValues: Expression) =
195206
this(dataTile, IsIn(maskTile, maskValues))
196207
override def nodeName: String = "rf_mask_by_values"

core/src/main/scala/org/locationtech/rasterframes/extensions/RasterJoin.scala

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,37 @@ package org.locationtech.rasterframes.extensions
2323
import org.apache.spark.sql._
2424
import org.apache.spark.sql.functions._
2525
import org.locationtech.rasterframes._
26+
import org.locationtech.rasterframes.expressions.SpatialRelation
27+
import org.locationtech.rasterframes.expressions.accessors.ExtractTile
2628
import org.locationtech.rasterframes.functions.reproject_and_merge
2729
import org.locationtech.rasterframes.util._
2830

2931
import scala.util.Random
3032

3133
object RasterJoin {
3234

35+
/**
 * Perform a raster join on dataframes that each have proj_raster columns, or crs and
 * extent explicitly included.
 *
 * For each side, the CRS and extent are taken from the first proj_raster column when one
 * exists; otherwise the columns named "crs" and "extent" are used. Either way the chosen
 * values are written to "crs" and "extent" columns on that side before joining.
 */
def apply(left: DataFrame, right: DataFrame): DataFrame = {
  // Resolves one side's CRS/extent and returns the augmented frame plus its two columns.
  def usePRT(d: DataFrame): (DataFrame, Column, Column) = {
    val (crs, extent): (Column, Column) = d.projRasterColumns.headOption
      .map(p => (rf_crs(p), rf_extent(p)))
      // Explicit tuple and `getOrElse` replace the auto-tupled `Some(..)` plus `.get`.
      .getOrElse((col("crs"), col("extent")))
    val d2 = d.withColumn("crs", crs).withColumn("extent", extent)
    (d2, d2("crs"), d2("extent"))
  }

  val (ldf, lcrs, lextent) = usePRT(left)
  val (rdf, rcrs, rextent) = usePRT(right)

  apply(ldf, rdf, lextent, lcrs, rextent, rcrs)
}
3752

3853
/**
 * Raster join using explicitly supplied extent and CRS columns.
 *
 * The join condition is that the left extent's geometry intersects the right extent's
 * geometry after the latter is reprojected from `rightCRS` to `leftCRS`.
 */
def apply(left: DataFrame, right: DataFrame, leftExtent: Column, leftCRS: Column, rightExtent: Column, rightCRS: Column): DataFrame = {
  val leftBounds = st_geometry(leftExtent)
  val rightBoundsInLeftCRS = st_reproject(st_geometry(rightExtent), rightCRS, leftCRS)
  // Build the intersects predicate directly from the Catalyst expression.
  val intersects = SpatialRelation.Intersects(leftBounds.expr, rightBoundsInLeftCRS.expr)
  apply(left, right, new Column(intersects), leftExtent, leftCRS, rightExtent, rightCRS)
}
4459

@@ -65,7 +80,7 @@ object RasterJoin {
6580
val leftAggCols = left.columns.map(s => first(left(s), true) as s)
6681
// On the RHS we collect result as a list.
6782
val rightAggCtx = Seq(collect_list(rightExtent) as rightExtent2, collect_list(rightCRS) as rightCRS2)
68-
val rightAggTiles = right.tileColumns.map(c => collect_list(c) as c.columnName)
83+
val rightAggTiles = right.tileColumns.map(c => collect_list(ExtractTile(c)) as c.columnName)
6984
val rightAggOther = right.notTileColumns
7085
.filter(n => n.columnName != rightExtent.columnName && n.columnName != rightCRS.columnName)
7186
.map(c => collect_list(c) as (c.columnName + "_agg"))

0 commit comments

Comments
 (0)