Skip to content

Commit 61081b7

Browse files
committed
Fix: Mask operations preserver the target tile cell type
1 parent 9be3cb6 commit 61081b7

File tree

6 files changed

+114
-87
lines changed

6 files changed

+114
-87
lines changed

core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/InverseMaskByDefined.scala

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,9 @@ package org.locationtech.rasterframes.expressions.transformers
2323

2424
import geotrellis.raster.{NODATA, Tile, isNoData}
2525
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
26-
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
2726
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
2827
import org.apache.spark.sql.{Column, TypedColumn}
2928
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, ExpressionDescription}
30-
import org.apache.spark.sql.types.DataType
31-
import org.locationtech.rasterframes.expressions.DynamicExtractors.tileExtractor
3229
import org.locationtech.rasterframes.expressions.{RasterResult, row}
3330
import org.locationtech.rasterframes.tileEncoder
3431

@@ -45,36 +42,26 @@ import org.locationtech.rasterframes.tileEncoder
4542
..."""
4643
)
4744
case class InverseMaskByDefined(targetTile: Expression, maskTile: Expression)
48-
extends BinaryExpression
45+
extends BinaryExpression with MaskExpression
4946
with CodegenFallback
5047
with RasterResult {
5148
override def nodeName: String = "rf_inverse_mask"
5249

53-
def dataType: DataType = targetTile.dataType
5450
def left: Expression = targetTile
5551
def right: Expression = maskTile
5652

5753
protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression =
5854
InverseMaskByDefined(newLeft, newRight)
5955

60-
override def checkInputDataTypes(): TypeCheckResult = {
61-
if (!tileExtractor.isDefinedAt(targetTile.dataType)) {
62-
TypeCheckFailure(s"Input type '${targetTile.dataType}' does not conform to a raster type.")
63-
} else if (!tileExtractor.isDefinedAt(maskTile.dataType)) {
64-
TypeCheckFailure(s"Input type '${maskTile.dataType}' does not conform to a raster type.")
65-
} else TypeCheckSuccess
66-
}
67-
68-
private lazy val targetTileExtractor = tileExtractor(targetTile.dataType)
69-
private lazy val maskTileExtractor = tileExtractor(maskTile.dataType)
56+
override def checkInputDataTypes(): TypeCheckResult = checkTileDataTypes()
7057

7158
override protected def nullSafeEval(targetInput: Any, maskInput: Any): Any = {
7259
val (targetTile, targetCtx) = targetTileExtractor(row(targetInput))
7360
val (mask, maskCtx) = maskTileExtractor(row(maskInput))
74-
75-
val result = targetTile.dualCombine(mask)
76-
{ (v, m) => if (isNoData(m)) v else NODATA }
61+
val result = maskEval(targetTile, mask,
62+
{ (v, m) => if (isNoData(m)) v else NODATA },
7763
{ (v, m) => if (isNoData(m)) v else NODATA }
64+
)
7865
toInternalRow(result, targetCtx)
7966
}
8067
}

core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/InverseMaskByValue.scala

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,13 @@
2121

2222
package org.locationtech.rasterframes.expressions.transformers
2323

24-
import geotrellis.raster.{NODATA, Tile, d2i}
24+
import geotrellis.raster.{NODATA, Tile}
2525
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
26-
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
26+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
2727
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
2828
import org.apache.spark.sql.{Column, TypedColumn}
2929
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, TernaryExpression}
30-
import org.apache.spark.sql.types.DataType
31-
import org.locationtech.rasterframes.expressions.DynamicExtractors.{intArgExtractor, tileExtractor}
30+
import org.locationtech.rasterframes.expressions.DynamicExtractors.intArgExtractor
3231
import org.locationtech.rasterframes.expressions.{RasterResult, row}
3332
import org.locationtech.rasterframes.tileEncoder
3433

@@ -47,12 +46,11 @@ import org.locationtech.rasterframes.tileEncoder
4746
..."""
4847
)
4948
case class InverseMaskByValue(targetTile: Expression, maskTile: Expression, maskValue: Expression)
50-
extends TernaryExpression
49+
extends TernaryExpression with MaskExpression
5150
with CodegenFallback
5251
with RasterResult {
5352
override def nodeName: String = "rf_inverse_mask_by_value"
5453

55-
def dataType: DataType = targetTile.dataType
5654
def first: Expression = targetTile
5755
def second: Expression = maskTile
5856
def third: Expression = maskValue
@@ -61,27 +59,22 @@ case class InverseMaskByValue(targetTile: Expression, maskTile: Expression, mask
6159
InverseMaskByValue(newFirst, newSecond, newThird)
6260

6361
override def checkInputDataTypes(): TypeCheckResult = {
64-
if (!tileExtractor.isDefinedAt(targetTile.dataType)) {
65-
TypeCheckFailure(s"Input type '${targetTile.dataType}' does not conform to a raster type.")
66-
} else if (!tileExtractor.isDefinedAt(maskTile.dataType)) {
67-
TypeCheckFailure(s"Input type '${maskTile.dataType}' does not conform to a raster type.")
68-
} else if (!intArgExtractor.isDefinedAt(maskValue.dataType)) {
62+
if (!intArgExtractor.isDefinedAt(maskValue.dataType)) {
6963
TypeCheckFailure(s"Input type '${maskValue.dataType}' isn't an integral type.")
70-
} else TypeCheckSuccess
64+
} else checkTileDataTypes()
7165
}
7266

73-
private lazy val targetTileExtractor = tileExtractor(targetTile.dataType)
74-
private lazy val maskTileExtractor = tileExtractor(maskTile.dataType)
7567
private lazy val maskValueExtractor = intArgExtractor(maskValue.dataType)
7668

7769
override protected def nullSafeEval(targetInput: Any, maskInput: Any, maskValueInput: Any): Any = {
7870
val (targetTile, targetCtx) = targetTileExtractor(row(targetInput))
7971
val (mask, maskCtx) = maskTileExtractor(row(maskInput))
8072
val maskValue = maskValueExtractor(maskValueInput).value
8173

82-
val result = targetTile.dualCombine(mask)
74+
val result = maskEval(targetTile, mask,
75+
{ (v, m) => if (m != maskValue) NODATA else v },
8376
{ (v, m) => if (m != maskValue) NODATA else v }
84-
{ (v, m) => if (d2i(m) != maskValue) NODATA else v }
77+
)
8578
toInternalRow(result, targetCtx)
8679
}
8780
}

core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskByDefined.scala

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,9 @@
2222
package org.locationtech.rasterframes.expressions.transformers
2323
import geotrellis.raster.{NODATA, Tile, isNoData}
2424
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
25-
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
2625
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
2726
import org.apache.spark.sql.{Column, TypedColumn}
2827
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, ExpressionDescription}
29-
import org.apache.spark.sql.types.DataType
30-
import org.locationtech.rasterframes.expressions.DynamicExtractors.{tileExtractor}
3128
import org.locationtech.rasterframes.expressions.{RasterResult, row}
3229
import org.locationtech.rasterframes.tileEncoder
3330

@@ -44,36 +41,26 @@ import org.locationtech.rasterframes.tileEncoder
4441
..."""
4542
)
4643
case class MaskByDefined(targetTile: Expression, maskTile: Expression)
47-
extends BinaryExpression
44+
extends BinaryExpression with MaskExpression
4845
with CodegenFallback
4946
with RasterResult {
5047
override def nodeName: String = "rf_mask"
5148

52-
def dataType: DataType = targetTile.dataType
5349
def left: Expression = targetTile
5450
def right: Expression = maskTile
5551

5652
protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression =
5753
MaskByDefined(newLeft, newRight)
5854

59-
override def checkInputDataTypes(): TypeCheckResult = {
60-
if (!tileExtractor.isDefinedAt(targetTile.dataType)) {
61-
TypeCheckFailure(s"Input type '${targetTile.dataType}' does not conform to a raster type.")
62-
} else if (!tileExtractor.isDefinedAt(maskTile.dataType)) {
63-
TypeCheckFailure(s"Input type '${maskTile.dataType}' does not conform to a raster type.")
64-
} else TypeCheckSuccess
65-
}
66-
67-
private lazy val targetTileExtractor = tileExtractor(targetTile.dataType)
68-
private lazy val maskTileExtractor = tileExtractor(maskTile.dataType)
55+
override def checkInputDataTypes(): TypeCheckResult = checkTileDataTypes()
6956

7057
override protected def nullSafeEval(targetInput: Any, maskInput: Any): Any = {
7158
val (targetTile, targetCtx) = targetTileExtractor(row(targetInput))
7259
val (mask, maskCtx) = maskTileExtractor(row(maskInput))
73-
74-
val result = targetTile.dualCombine(mask)
75-
{ (v, m) => if (isNoData(m)) NODATA else v }
60+
val result = maskEval(targetTile, mask,
61+
{ (v, m) => if (isNoData(m)) NODATA else v },
7662
{ (v, m) => if (isNoData(m)) NODATA else v }
63+
)
7764
toInternalRow(result, targetCtx)
7865
}
7966
}

core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskByValue.scala

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,13 @@
2121

2222
package org.locationtech.rasterframes.expressions.transformers
2323

24-
import geotrellis.raster.{NODATA, Tile, d2i}
24+
import geotrellis.raster.{NODATA, Tile}
2525
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
26-
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
26+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
2727
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
2828
import org.apache.spark.sql.{Column, TypedColumn}
2929
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, TernaryExpression}
30-
import org.apache.spark.sql.types.{DataType}
31-
import org.locationtech.rasterframes.expressions.DynamicExtractors.{intArgExtractor, tileExtractor}
30+
import org.locationtech.rasterframes.expressions.DynamicExtractors.intArgExtractor
3231
import org.locationtech.rasterframes.expressions.{RasterResult, row}
3332
import org.locationtech.rasterframes.tileEncoder
3433

@@ -46,42 +45,36 @@ import org.locationtech.rasterframes.tileEncoder
4645
> SELECT _FUNC_(target, mask, maskValue);
4746
..."""
4847
)
49-
case class MaskByValue(dataTile: Expression, maskTile: Expression, maskValue: Expression)
50-
extends TernaryExpression
48+
case class MaskByValue(targetTile: Expression, maskTile: Expression, maskValue: Expression)
49+
extends TernaryExpression with MaskExpression
5150
with CodegenFallback
5251
with RasterResult {
5352
override def nodeName: String = "rf_mask_by_value"
5453

55-
def dataType: DataType = dataTile.dataType
56-
def first: Expression = dataTile
54+
def first: Expression = targetTile
5755
def second: Expression = maskTile
5856
def third: Expression = maskValue
5957

6058
protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
6159
MaskByValue(newFirst, newSecond, newThird)
6260

6361
override def checkInputDataTypes(): TypeCheckResult = {
64-
if (!tileExtractor.isDefinedAt(dataTile.dataType)) {
65-
TypeCheckFailure(s"Input type '${dataTile.dataType}' does not conform to a raster type.")
66-
} else if (!tileExtractor.isDefinedAt(maskTile.dataType)) {
67-
TypeCheckFailure(s"Input type '${maskTile.dataType}' does not conform to a raster type.")
68-
} else if (!intArgExtractor.isDefinedAt(maskValue.dataType)) {
62+
if (!intArgExtractor.isDefinedAt(maskValue.dataType)) {
6963
TypeCheckFailure(s"Input type '${maskValue.dataType}' isn't an integral type.")
70-
} else TypeCheckSuccess
64+
} else checkTileDataTypes()
7165
}
7266

73-
private lazy val dataTileExtractor = tileExtractor(dataTile.dataType)
74-
private lazy val maskTileExtractor = tileExtractor(maskTile.dataType)
7567
private lazy val maskValueExtractor = intArgExtractor(maskValue.dataType)
7668

7769
override protected def nullSafeEval(targetInput: Any, maskInput: Any, maskValueInput: Any): Any = {
78-
val (targetTile, targetCtx) = dataTileExtractor(row(targetInput))
70+
val (targetTile, targetCtx) = targetTileExtractor(row(targetInput))
7971
val (mask, maskCtx) = maskTileExtractor(row(maskInput))
8072
val maskValue = maskValueExtractor(maskValueInput).value
8173

82-
val result = targetTile.dualCombine(mask)
74+
val result = maskEval(targetTile, mask,
75+
{ (v, m) => if (m == maskValue) NODATA else v },
8376
{ (v, m) => if (m == maskValue) NODATA else v }
84-
{ (v, m) => if (d2i(m) == maskValue) NODATA else v }
77+
)
8578
toInternalRow(result, targetCtx)
8679
}
8780
}

core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/MaskByValues.scala

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,14 @@
2121

2222
package org.locationtech.rasterframes.expressions.transformers
2323

24-
import geotrellis.raster.{NODATA, Tile, d2i}
24+
import geotrellis.raster.{NODATA, Tile}
2525
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
26-
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
26+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
2727
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
2828
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, TernaryExpression}
2929
import org.apache.spark.sql.catalyst.util.ArrayData
30-
import org.apache.spark.sql.types._
3130
import org.apache.spark.sql.{Column, TypedColumn}
32-
import org.locationtech.rasterframes.expressions.DynamicExtractors.{intArrayExtractor, tileExtractor}
31+
import org.locationtech.rasterframes.expressions.DynamicExtractors.intArrayExtractor
3332
import org.locationtech.rasterframes.expressions.{RasterResult, row}
3433
import org.locationtech.rasterframes.tileEncoder
3534

@@ -48,12 +47,11 @@ import org.locationtech.rasterframes.tileEncoder
4847
..."""
4948
)
5049
case class MaskByValues(targetTile: Expression, maskTile: Expression, maskValues: Expression)
51-
extends TernaryExpression
50+
extends TernaryExpression with MaskExpression
5251
with CodegenFallback
5352
with RasterResult {
5453
override def nodeName: String = "rf_mask_by_values"
5554

56-
def dataType: DataType = targetTile.dataType
5755
def first: Expression = targetTile
5856
def second: Expression = maskTile
5957
def third: Expression = maskValues
@@ -62,26 +60,21 @@ case class MaskByValues(targetTile: Expression, maskTile: Expression, maskValues
6260
MaskByValues(newFirst, newSecond, newThird)
6361

6462
override def checkInputDataTypes(): TypeCheckResult =
65-
if (!tileExtractor.isDefinedAt(targetTile.dataType)) {
66-
TypeCheckFailure(s"Input type '${targetTile.dataType}' does not conform to a raster type.")
67-
} else if (!tileExtractor.isDefinedAt(maskTile.dataType)) {
68-
TypeCheckFailure(s"Input type '${maskTile.dataType}' does not conform to a raster type.")
69-
} else if (!intArrayExtractor.isDefinedAt(maskValues.dataType)) {
63+
if (!intArrayExtractor.isDefinedAt(maskValues.dataType)) {
7064
TypeCheckFailure(s"Input type '${maskValues.dataType}' does not translate to an array<int>.")
71-
} else TypeCheckSuccess
65+
} else checkTileDataTypes()
7266

73-
private lazy val targetTileExtractor = tileExtractor(targetTile.dataType)
74-
private lazy val maskTileExtractor = tileExtractor(maskTile.dataType)
7567
private lazy val maskValuesExtractor = intArrayExtractor(maskValues.dataType)
7668

7769
override protected def nullSafeEval(targetInput: Any, maskInput: Any, maskValuesInput: Any): Any = {
7870
val (targetTile, targetCtx) = targetTileExtractor(row(targetInput))
7971
val (mask, maskCtx) = maskTileExtractor(row(maskInput))
8072
val maskValues: Array[Int] = maskValuesExtractor(maskValuesInput.asInstanceOf[ArrayData])
8173

82-
val result = targetTile.dualCombine(mask)
74+
val result = maskEval(targetTile, mask,
75+
{ (v, m) => if (maskValues.contains(m)) NODATA else v },
8376
{ (v, m) => if (maskValues.contains(m)) NODATA else v }
84-
{ (v, m) => if (maskValues.contains(d2i(m))) NODATA else v }
77+
)
8578

8679
toInternalRow(result, targetCtx)
8780
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* This software is licensed under the Apache 2 license, quoted below.
3+
*
4+
* Copyright 2019 Astraea, Inc.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
7+
* use this file except in compliance with the License. You may obtain a copy of
8+
* the License at
9+
*
10+
* [http://www.apache.org/licenses/LICENSE-2.0]
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15+
* License for the specific language governing permissions and limitations under
16+
* the License.
17+
*
18+
* SPDX-License-Identifier: Apache-2.0
19+
*
20+
*/
21+
22+
package org.locationtech.rasterframes.expressions.transformers
23+
24+
import geotrellis.raster.Tile
25+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
26+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
27+
import org.apache.spark.sql.catalyst.expressions.Expression
28+
import org.apache.spark.sql.types.DataType
29+
import org.locationtech.rasterframes.expressions.DynamicExtractors.tileExtractor
30+
31+
import spire.syntax.cfor._
32+
33+
trait MaskExpression { self: Expression =>
34+
35+
def targetTile: Expression
36+
def maskTile: Expression
37+
38+
def dataType: DataType = targetTile.dataType
39+
40+
protected lazy val targetTileExtractor = tileExtractor(targetTile.dataType)
41+
protected lazy val maskTileExtractor = tileExtractor(maskTile.dataType)
42+
43+
def checkTileDataTypes(): TypeCheckResult = {
44+
if (!tileExtractor.isDefinedAt(targetTile.dataType)) {
45+
TypeCheckFailure(s"Input type '${targetTile.dataType}' does not conform to a raster type.")
46+
} else if (!tileExtractor.isDefinedAt(maskTile.dataType)) {
47+
TypeCheckFailure(s"Input type '${maskTile.dataType}' does not conform to a raster type.")
48+
} else TypeCheckSuccess
49+
}
50+
51+
def maskEval(targetTile: Tile, maskTile: Tile, maskInt: (Int, Int) => Int, maskDouble: (Double, Int) => Double): Tile = {
52+
val result = targetTile.mutable
53+
54+
if (targetTile.cellType.isFloatingPoint) {
55+
cfor(0)(_ < targetTile.rows, _ + 1) { row =>
56+
cfor(0)(_ < targetTile.cols, _ + 1) { col =>
57+
val v = targetTile.getDouble(col, row)
58+
val m = maskTile.get(col, row)
59+
result.setDouble(col, row, maskDouble(v, m))
60+
}
61+
}
62+
} else {
63+
cfor(0)(_ < targetTile.rows, _ + 1) { row =>
64+
cfor(0)(_ < targetTile.cols, _ + 1) { col =>
65+
val v = targetTile.get(col, row)
66+
val m = maskTile.get(col, row)
67+
result.set(col, row, maskInt(v, m))
68+
}
69+
}
70+
}
71+
72+
result
73+
}
74+
}

0 commit comments

Comments
 (0)