@@ -24,11 +24,11 @@ package org.locationtech.rasterframes.expressions.transformers
2424import com .typesafe .scalalogging .Logger
2525import geotrellis .raster
2626import geotrellis .raster .Tile
27- import geotrellis .raster .mapalgebra .local .{Defined , InverseMask ⇒ gtInverseMask , Mask ⇒ gtMask }
27+ import geotrellis .raster .mapalgebra .local .{Defined , InverseMask => gtInverseMask , Mask => gtMask }
2828import org .apache .spark .sql .catalyst .analysis .TypeCheckResult
2929import org .apache .spark .sql .catalyst .analysis .TypeCheckResult .{TypeCheckFailure , TypeCheckSuccess }
3030import org .apache .spark .sql .catalyst .expressions .codegen .CodegenFallback
31- import org .apache .spark .sql .catalyst .expressions .{Expression , ExpressionDescription , Literal , TernaryExpression }
31+ import org .apache .spark .sql .catalyst .expressions .{AttributeSet , Expression , ExpressionDescription , Literal , TernaryExpression }
3232import org .apache .spark .sql .functions .lit
3333import org .apache .spark .sql .rf .TileUDT
3434import org .apache .spark .sql .types .DataType
@@ -41,47 +41,51 @@ import org.slf4j.LoggerFactory
4141
4242abstract class Mask (val left : Expression , val middle : Expression , val right : Expression , inverse : Boolean )
4343 extends TernaryExpression with CodegenFallback with Serializable {
44+ def targetExp = left
45+ def maskExp = middle
46+ def maskValueExp = right
4447
4548 @ transient protected lazy val logger = Logger (LoggerFactory .getLogger(getClass.getName))
4649
47-
4850 override def children : Seq [Expression ] = Seq (left, middle, right)
4951
5052 override def checkInputDataTypes (): TypeCheckResult = {
51- if (! tileExtractor.isDefinedAt(left .dataType)) {
52- TypeCheckFailure (s " Input type ' ${left .dataType}' does not conform to a raster type. " )
53- } else if (! tileExtractor.isDefinedAt(middle .dataType)) {
54- TypeCheckFailure (s " Input type ' ${middle .dataType}' does not conform to a raster type. " )
55- } else if (! intArgExtractor.isDefinedAt(right .dataType)) {
56- TypeCheckFailure (s " Input type ' ${right .dataType}' isn't an integral type. " )
53+ if (! tileExtractor.isDefinedAt(targetExp .dataType)) {
54+ TypeCheckFailure (s " Input type ' ${targetExp .dataType}' does not conform to a raster type. " )
55+ } else if (! tileExtractor.isDefinedAt(maskExp .dataType)) {
56+ TypeCheckFailure (s " Input type ' ${maskExp .dataType}' does not conform to a raster type. " )
57+ } else if (! intArgExtractor.isDefinedAt(maskValueExp .dataType)) {
58+ TypeCheckFailure (s " Input type ' ${maskValueExp .dataType}' isn't an integral type. " )
5759 } else TypeCheckSuccess
5860 }
5961 override def dataType : DataType = left.dataType
6062
61- override protected def nullSafeEval (leftInput : Any , middleInput : Any , rightInput : Any ): Any = {
63+ override def makeCopy (newArgs : Array [AnyRef ]): Expression = super .makeCopy(newArgs)
64+
65+ override protected def nullSafeEval (targetInput : Any , maskInput : Any , maskValueInput : Any ): Any = {
6266 implicit val tileSer = TileUDT .tileSerializer
63- val (leftTile, leftCtx ) = tileExtractor(left .dataType)(row(leftInput ))
64- val (rightTile, rightCtx ) = tileExtractor(middle .dataType)(row(middleInput ))
67+ val (targetTile, targetCtx ) = tileExtractor(targetExp .dataType)(row(targetInput ))
68+ val (maskTile, maskCtx ) = tileExtractor(maskExp .dataType)(row(maskInput ))
6569
66- if (leftCtx .isEmpty && rightCtx .isDefined)
70+ if (targetCtx .isEmpty && maskCtx .isDefined)
6771 logger.warn(
6872 s " Right-hand parameter ' ${middle}' provided an extent and CRS, but the left-hand parameter " +
6973 s " ' ${left}' didn't have any. Because the left-hand side defines output type, the right-hand context will be lost. " )
7074
71- if (leftCtx .isDefined && rightCtx .isDefined && leftCtx != rightCtx )
75+ if (targetCtx .isDefined && maskCtx .isDefined && targetCtx != maskCtx )
7276 logger.warn(s " Both ' ${left}' and ' ${middle}' provided an extent and CRS, but they are different. Left-hand side will be used. " )
7377
74- val maskValue = intArgExtractor(right .dataType)(rightInput )
78+ val maskValue = intArgExtractor(maskValueExp .dataType)(maskValueInput )
7579
76- val masking = if (maskValue.value == 0 ) Defined (rightTile )
77- else rightTile
80+ val masking = if (maskValue.value == 0 ) Defined (maskTile )
81+ else maskTile
7882
7983 val result = if (inverse)
80- gtInverseMask(leftTile , masking, maskValue.value, raster.NODATA )
84+ gtInverseMask(targetTile , masking, maskValue.value, raster.NODATA )
8185 else
82- gtMask(leftTile , masking, maskValue.value, raster.NODATA )
86+ gtMask(targetTile , masking, maskValue.value, raster.NODATA )
8387
84- leftCtx match {
88+ targetCtx match {
8589 case Some (ctx) => ctx.toProjectRasterTile(result).toInternalRow
8690 case None => result.toInternalRow
8791 }
@@ -173,24 +177,26 @@ object Mask {
173177 }
174178
175179 @ ExpressionDescription (
176- usage = " _FUNC_(data, mask, maskValues, inverse) - Generate a tile with the values from `data` tile but where cells in the `mask` tile are in the `maskValues` list, replace the value with NODATA. If `inverse` is true, the cells in `mask` that are not in `maskValues` list become NODATA" ,
177- arguments =
178- """
179-
180+ usage = " _FUNC_(data, mask, maskValues) - Generate a tile with the values from `data` tile but where cells in the `mask` tile are in the `maskValues` list, replace the value with NODATA." ,
181+ arguments = """
182+ Arguments:
183+ * target - tile to mask
184+ * mask - masking definition
185+ * maskValues - sequence of values to consider as masks candidates
180186 """ ,
181- examples =
182- """
183- > SELECT _FUNC_(data, mask, array(1, 2, 3), false)
184-
185- """
187+ examples = """
188+ Examples:
189+ > SELECT _FUNC_(data, mask, array(1, 2, 3))
190+ ..."""
186191 )
187- case class MaskByValues (dataTile : Expression , maskTile : Expression , maskValues : Expression , inverse : Boolean )
188- extends Mask (dataTile, IsIn (maskTile, maskValues), inverse, false ) {
192+ case class MaskByValues (dataTile : Expression , maskTile : Expression )
193+ extends Mask (dataTile, maskTile, Literal (1 ), inverse = false ) {
194+ def this (dataTile : Expression , maskTile : Expression , maskValues : Expression ) =
195+ this (dataTile, IsIn (maskTile, maskValues))
189196 override def nodeName : String = " rf_mask_by_values"
190197 }
191198 object MaskByValues {
192- def apply (dataTile : Column , maskTile : Column , maskValues : Column , inverse : Column ): TypedColumn [Any , Tile ] =
193- new Column (MaskByValues (dataTile.expr, maskTile.expr , maskValues.expr, inverse .expr)).as[Tile ]
199+ def apply (dataTile : Column , maskTile : Column , maskValues : Column ): TypedColumn [Any , Tile ] =
200+ new Column (MaskByValues (dataTile.expr, IsIn ( maskTile, maskValues) .expr)).as[Tile ]
194201 }
195-
196202}
0 commit comments