Skip to content

Commit 4e12f2e

Browse files
committed
Attempt to create Expression for MaskByValues to enable sql api
Signed-off-by: Jason T. Brown <[email protected]>
1 parent acab757 commit 4e12f2e

File tree

4 files changed

+29
-5
lines changed

4 files changed

+29
-5
lines changed

core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -316,10 +316,7 @@ trait RasterFunctions {
316316
list, replace the value with NODATA.
317317
If `inverse` is True, the cells in `mask_tile` that are not in `mask_values` list become NODATA */
318318
def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Column, inverse: Boolean): TypedColumn[Any, Tile] =
319-
if (!inverse)
320-
Mask.MaskByValue(sourceTile, rf_local_is_in(maskTile, maskValues), lit(1))
321-
else
322-
Mask.MaskByValue(sourceTile, rf_local_is_in(maskTile, maskValues), lit(0))
319+
Mask.MaskByValues(sourceTile, maskTile, maskValues, lit(inverse))
323320

324321
/** Generate a tile with the values from `data_tile`, but where cells in the `mask_tile` are in the `mask_values`
325322
list, replace the value with NODATA. */

core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ package object expressions {
126126

127127
registry.registerExpression[Mask.MaskByDefined]("rf_mask")
128128
registry.registerExpression[Mask.MaskByValue]("rf_mask_by_value")
129+
registry.registerExpression[Mask.MaskByValues]("rf_mask_by_values")
129130
registry.registerExpression[Mask.InverseMaskByValue]("rf_inverse_mask_by_value")
130131
registry.registerExpression[Mask.InverseMaskByDefined]("rf_inverse_mask")
131132

core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Mask.scala

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,18 @@ package org.locationtech.rasterframes.expressions.transformers
2424
import com.typesafe.scalalogging.Logger
2525
import geotrellis.raster
2626
import geotrellis.raster.Tile
27-
import geotrellis.raster.mapalgebra.local.{Defined, InverseMask => gtInverseMask, Mask => gtMask}
27+
import geotrellis.raster.mapalgebra.local.{Defined, InverseMask gtInverseMask, Mask gtMask}
2828
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2929
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
3030
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
3131
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, Literal, TernaryExpression}
32+
import org.apache.spark.sql.functions.lit
3233
import org.apache.spark.sql.rf.TileUDT
3334
import org.apache.spark.sql.types.DataType
3435
import org.apache.spark.sql.{Column, TypedColumn}
3536
import org.locationtech.rasterframes.encoders.CatalystSerializer._
3637
import org.locationtech.rasterframes.expressions.DynamicExtractors._
38+
import org.locationtech.rasterframes.expressions.localops.IsIn
3739
import org.locationtech.rasterframes.expressions.row
3840
import org.slf4j.LoggerFactory
3941

@@ -169,4 +171,26 @@ object Mask {
169171
def apply(srcTile: Column, maskingTile: Column, maskValue: Column): TypedColumn[Any, Tile] =
170172
new Column(InverseMaskByValue(srcTile.expr, maskingTile.expr, maskValue.expr)).as[Tile]
171173
}
174+
175+
@ExpressionDescription(
176+
usage = "_FUNC_(data, mask, maskValues, inverse) - Generate a tile with the values from `data` tile but where cells in the `mask` tile are in the `maskValues` list, replace the value with NODATA. If `inverse` is true, the cells in `mask` that are not in `maskValues` list become NODATA",
177+
arguments =
178+
"""
179+
180+
""",
181+
examples =
182+
"""
183+
> SELECT _FUNC_(data, mask, array(1, 2, 3), false)
184+
185+
"""
186+
)
187+
case class MaskByValues(dataTile: Expression, maskTile: Expression, maskValues: Expression, inverse: Boolean)
188+
extends Mask(dataTile, IsIn(maskTile, maskValues), inverse, false) {
189+
override def nodeName: String = "rf_mask_by_values"
190+
}
191+
object MaskByValues {
192+
def apply(dataTile: Column, maskTile: Column, maskValues: Column, inverse: Column): TypedColumn[Any, Tile] =
193+
new Column(MaskByValues(dataTile.expr, maskTile.expr, maskValues.expr, inverse.expr)).as[Tile]
194+
}
195+
172196
}

core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,8 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers {
713713
}
714714

715715
it("should mask tile by another identified by specified values") {
716+
checkDocs("rf_mask_by_values")
717+
716718
val squareIncrementingPRT = ProjectedRasterTile(squareIncrementingTile(six.rows), six.extent, six.crs)
717719
val df = Seq((six, squareIncrementingPRT))
718720
.toDF("tile", "mask")

0 commit comments

Comments
 (0)