Skip to content

Commit 0214fa2

Browse files
committed
fix masking functions
Rewrote the masking functions to operate directly on cell values. This both fixes correctness issues and improves performance, since these versions no longer need to create intermediate mask tiles.
1 parent a3ac4cf commit 0214fa2

File tree

10 files changed

+491
-237
lines changed

10 files changed

+491
-237
lines changed

core/src/main/scala/org/locationtech/rasterframes/expressions/DynamicExtractors.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import geotrellis.raster.{CellGrid, Neighborhood, Raster, TargetCell, Tile}
2626
import geotrellis.vector.Extent
2727
import org.apache.spark.sql.Row
2828
import org.apache.spark.sql.catalyst.InternalRow
29+
import org.apache.spark.sql.catalyst.util.ArrayData
2930
import org.apache.spark.sql.jts.JTSTypes
3031
import org.apache.spark.sql.rf.{RasterSourceUDT, TileUDT}
3132
import org.apache.spark.sql.types._
@@ -106,6 +107,24 @@ object DynamicExtractors {
106107
(row: InternalRow) => row.as[ProjectedRasterTile]
107108
}
108109

110+
/** Extracts a Catalyst `ArrayData` value into an `Array[Int]`, selected by the
  * declared Spark SQL array element type. Wider numeric element types are
  * narrowed with `toInt`; booleans map to 1/0. Array types whose elements may
  * be null (`containsNull = true`) are rejected with an
  * `IllegalArgumentException`.
  */
lazy val intArrayExtractor: PartialFunction[DataType, ArrayData => Array[Int]] = {
  // NOTE(review): this case makes `isDefinedAt` true for nullable-element
  // arrays, yet applying the partial function to such a type throws — the
  // failure surfaces at extraction time rather than via a TypeCheckFailure.
  case ArrayType(t, true) =>
    throw new IllegalArgumentException(s"Can't turn array of $t to array<int>")
  case ArrayType(DoubleType, false) =>
    data => data.toDoubleArray.map(_.toInt)
  case ArrayType(FloatType, false) =>
    data => data.toFloatArray.map(_.toInt)
  case ArrayType(IntegerType, false) =>
    data => data.toIntArray
  case ArrayType(ShortType, false) =>
    data => data.toShortArray.map(_.toInt)
  case ArrayType(ByteType, false) =>
    data => data.toByteArray.map(_.toInt)
  case ArrayType(BooleanType, false) =>
    data => data.toBooleanArray().map(flag => if (flag) 1 else 0)
}
127+
109128
lazy val crsExtractor: PartialFunction[DataType, Any => CRS] = {
110129
val base: PartialFunction[DataType, Any => CRS] = {
111130
case _: StringType => (v: Any) => LazyCRS(v.asInstanceOf[UTF8String].toString)

core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@
2222
package org.locationtech.rasterframes
2323

2424
import geotrellis.raster.{DoubleConstantNoDataCellType, Tile}
25-
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistryBase}
25+
import org.apache.spark.sql.catalyst.analysis.FunctionRegistryBase
2626
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
2727
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ExpressionInfo, ScalaUDF}
2828
import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, ScalaReflection}
2929
import org.apache.spark.sql.types.DataType
30-
import org.apache.spark.sql.{SQLContext}
30+
import org.apache.spark.sql.SQLContext
3131
import org.locationtech.rasterframes.expressions.accessors._
3232
import org.locationtech.rasterframes.expressions.aggregates.CellCountAggregate.DataCells
3333
import org.locationtech.rasterframes.expressions.aggregates._
@@ -106,23 +106,23 @@ package object expressions {
106106
def register1[T <: Expression : ClassTag](
107107
name: String,
108108
builder: Expression => T
109-
): Unit = registerFunction[T](name, None){ case Seq(a) => builder(a)
109+
): Unit = registerFunction[T](name, None){ args => builder(args(0))
110110
}
111111

112112
def register2[T <: Expression : ClassTag](
113113
name: String,
114114
builder: (Expression, Expression) => T
115-
): Unit = registerFunction[T](name, None){ case Seq(a, b) => builder(a, b) }
115+
): Unit = registerFunction[T](name, None){ args => builder(args(0), args(1)) }
116116

117117
def register3[T <: Expression : ClassTag](
118118
name: String,
119119
builder: (Expression, Expression, Expression) => T
120-
): Unit = registerFunction[T](name, None){ case Seq(a, b, c) => builder(a, b, c) }
120+
): Unit = registerFunction[T](name, None){ args => builder(args(0), args(1), args(2)) }
121121

122122
def register5[T <: Expression : ClassTag](
123123
name: String,
124124
builder: (Expression, Expression, Expression, Expression, Expression) => T
125-
): Unit = registerFunction[T](name, None){ case Seq(a, b, c, d, e) => builder(a, b, c, d, e) }
125+
): Unit = registerFunction[T](name, None){ args => builder(args(0), args(1), args(2), args(3), args(4)) }
126126

127127
register2("rf_local_add", Add(_, _))
128128
register2("rf_local_subtract", Subtract(_, _))
@@ -207,11 +207,11 @@ package object expressions {
207207
register2(Aspect.name, Aspect(_, _))
208208
register5(Hillshade.name, Hillshade(_, _, _, _, _))
209209

210-
register2("rf_mask", Mask.MaskByDefined(_, _))
211-
register2("rf_inverse_mask", Mask.InverseMaskByDefined(_, _))
212-
register3("rf_mask_by_value", Mask.MaskByValue(_, _, _))
213-
register3("rf_inverse_mask_by_value", Mask.InverseMaskByValue(_, _, _))
214-
register2("rf_mask_by_values", Mask.MaskByValues(_, _))
210+
register2("rf_mask", MaskByDefined(_, _))
211+
register2("rf_inverse_mask", InverseMaskByDefined(_, _))
212+
register3("rf_mask_by_value", MaskByValue(_, _, _))
213+
register3("rf_inverse_mask_by_value", InverseMaskByValue(_, _, _))
214+
register3("rf_mask_by_values", MaskByValues(_, _, _))
215215

216216
register1("rf_render_ascii", DebugRender.RenderAscii(_))
217217
register1("rf_render_matrix", DebugRender.RenderMatrix(_))
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
/*
 * This software is licensed under the Apache 2 license, quoted below.
 *
 * Copyright 2019 Astraea, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * [http://www.apache.org/licenses/LICENSE-2.0]
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 */

package org.locationtech.rasterframes.expressions.transformers

import geotrellis.raster.{NODATA, Tile, doubleNODATA, isNoData}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.{Column, TypedColumn}
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, ExpressionDescription}
import org.apache.spark.sql.types.DataType
import org.locationtech.rasterframes.expressions.DynamicExtractors.tileExtractor
import org.locationtech.rasterframes.expressions.{RasterResult, row}
import org.locationtech.rasterframes.tileEncoder


@ExpressionDescription(
  usage = "_FUNC_(target, mask) - Generate a tile with the values from the data tile, but where cells in the masking tile DO NOT contain NODATA, replace the data value with NODATA",
  arguments = """
  Arguments:
    * target - tile to mask
    * mask - masking definition""",
  examples = """
  Examples:
    > SELECT _FUNC_(target, mask);
       ..."""
)
/** Catalyst expression implementing `rf_inverse_mask`: retains `targetTile`
  * cells only where the corresponding `maskTile` cell is NoData, writing
  * NoData everywhere the mask holds data.
  */
case class InverseMaskByDefined(targetTile: Expression, maskTile: Expression)
  extends BinaryExpression
  with CodegenFallback
  with RasterResult {
  override def nodeName: String = "rf_inverse_mask"

  // Result keeps the target's Catalyst representation (plain vs. projected tile).
  def dataType: DataType = targetTile.dataType
  def left: Expression = targetTile
  def right: Expression = maskTile

  protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression =
    InverseMaskByDefined(newLeft, newRight)

  /** Both children must be extractable as raster tiles. */
  override def checkInputDataTypes(): TypeCheckResult = {
    if (!tileExtractor.isDefinedAt(targetTile.dataType)) {
      TypeCheckFailure(s"Input type '${targetTile.dataType}' does not conform to a raster type.")
    } else if (!tileExtractor.isDefinedAt(maskTile.dataType)) {
      TypeCheckFailure(s"Input type '${maskTile.dataType}' does not conform to a raster type.")
    } else TypeCheckSuccess
  }

  private lazy val targetTileExtractor = tileExtractor(targetTile.dataType)
  private lazy val maskTileExtractor = tileExtractor(maskTile.dataType)

  override protected def nullSafeEval(targetInput: Any, maskInput: Any): Any = {
    val (targetTile, targetCtx) = targetTileExtractor(row(targetInput))
    // The mask's spatial context is irrelevant; the result carries the target's.
    val (mask, _) = maskTileExtractor(row(maskInput))

    // Keep the target cell only where the mask is NoData.
    // FIX: the double combiner must write doubleNODATA (NaN) — writing the Int
    // NODATA sentinel into a floating-point tile stores Int.MinValue as an
    // ordinary data value instead of NoData.
    val result = targetTile.dualCombine(mask)
      { (v, m) => if (isNoData(m)) v else NODATA }
      { (v, m) => if (isNoData(m)) v else doubleNODATA }
    toInternalRow(result, targetCtx)
  }
}

object InverseMaskByDefined {
  /** Column-API constructor for `rf_inverse_mask(srcTile, maskingTile)`. */
  def apply(srcTile: Column, maskingTile: Column): TypedColumn[Any, Tile] =
    new Column(InverseMaskByDefined(srcTile.expr, maskingTile.expr)).as[Tile]
}
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/*
 * This software is licensed under the Apache 2 license, quoted below.
 *
 * Copyright 2019 Astraea, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * [http://www.apache.org/licenses/LICENSE-2.0]
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 */

package org.locationtech.rasterframes.expressions.transformers

import geotrellis.raster.{NODATA, Tile, d2i, doubleNODATA}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.{Column, TypedColumn}
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, TernaryExpression}
import org.apache.spark.sql.types.DataType
import org.locationtech.rasterframes.expressions.DynamicExtractors.{intArgExtractor, tileExtractor}
import org.locationtech.rasterframes.expressions.{RasterResult, row}
import org.locationtech.rasterframes.tileEncoder


@ExpressionDescription(
  usage = "_FUNC_(target, mask, maskValue) - Generate a tile with the values from the data tile, but where cells in the masking tile DO NOT contain the masking value, replace the data value with NODATA.",
  arguments = """
  Arguments:
    * target - tile to mask
    * mask - masking definition
    * maskValue - value in the `mask` for which to mark `target` as data cells
  """,
  examples = """
  Examples:
    > SELECT _FUNC_(target, mask, maskValue);
       ..."""
)
/** Catalyst expression implementing `rf_inverse_mask_by_value`: retains
  * `targetTile` cells only where the corresponding `maskTile` cell equals
  * `maskValue`, writing NoData everywhere else.
  */
case class InverseMaskByValue(targetTile: Expression, maskTile: Expression, maskValue: Expression)
  extends TernaryExpression
  with CodegenFallback
  with RasterResult {
  override def nodeName: String = "rf_inverse_mask_by_value"

  // Result keeps the target's Catalyst representation (plain vs. projected tile).
  def dataType: DataType = targetTile.dataType
  def first: Expression = targetTile
  def second: Expression = maskTile
  def third: Expression = maskValue

  protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
    InverseMaskByValue(newFirst, newSecond, newThird)

  /** First two children must be tiles; the third an integral scalar. */
  override def checkInputDataTypes(): TypeCheckResult = {
    if (!tileExtractor.isDefinedAt(targetTile.dataType)) {
      TypeCheckFailure(s"Input type '${targetTile.dataType}' does not conform to a raster type.")
    } else if (!tileExtractor.isDefinedAt(maskTile.dataType)) {
      TypeCheckFailure(s"Input type '${maskTile.dataType}' does not conform to a raster type.")
    } else if (!intArgExtractor.isDefinedAt(maskValue.dataType)) {
      TypeCheckFailure(s"Input type '${maskValue.dataType}' isn't an integral type.")
    } else TypeCheckSuccess
  }

  private lazy val targetTileExtractor = tileExtractor(targetTile.dataType)
  private lazy val maskTileExtractor = tileExtractor(maskTile.dataType)
  private lazy val maskValueExtractor = intArgExtractor(maskValue.dataType)

  override protected def nullSafeEval(targetInput: Any, maskInput: Any, maskValueInput: Any): Any = {
    val (targetTile, targetCtx) = targetTileExtractor(row(targetInput))
    // The mask's spatial context is irrelevant; the result carries the target's.
    val (mask, _) = maskTileExtractor(row(maskInput))
    val maskValue = maskValueExtractor(maskValueInput).value

    // Keep the target cell only where the mask equals maskValue.
    // FIX: the double combiner must write doubleNODATA (NaN) — writing the Int
    // NODATA sentinel into a floating-point tile stores Int.MinValue as an
    // ordinary data value instead of NoData.
    val result = targetTile.dualCombine(mask)
      { (v, m) => if (m != maskValue) NODATA else v }
      { (v, m) => if (d2i(m) != maskValue) doubleNODATA else v }
    toInternalRow(result, targetCtx)
  }
}

object InverseMaskByValue {
  /** Column-API constructor for `rf_inverse_mask_by_value(srcTile, maskingTile, maskValue)`. */
  def apply(srcTile: Column, maskingTile: Column, maskValue: Column): TypedColumn[Any, Tile] =
    new Column(InverseMaskByValue(srcTile.expr, maskingTile.expr, maskValue.expr)).as[Tile]
}

0 commit comments

Comments
 (0)