Skip to content

Commit 1c1c632

Browse files
committed
Merge remote-tracking branch 'lt/develop' into feature/mask_by_bits
2 parents bebf00f + 59cacd5 commit 1c1c632

File tree

21 files changed

+191
-228
lines changed

21 files changed

+191
-228
lines changed

core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -312,35 +312,19 @@ trait RasterFunctions {
312312
def rf_mask_by_value(sourceTile: Column, maskTile: Column, maskValue: Int): TypedColumn[Any, Tile] =
313313
rf_mask_by_value(sourceTile, maskTile, maskValue, false)
314314

315-
/** Generate a tile with the values from `data_tile`, but where cells in the `mask_tile` are in the `mask_values`
316-
list, replace the value with NODATA.
317-
If `inverse` is True, the cells in `mask_tile` that are not in `mask_values` list become NODATA */
318-
def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Column, inverse: Boolean): TypedColumn[Any, Tile] =
319-
if (!inverse)
320-
Mask.MaskByValue(sourceTile, rf_local_is_in(maskTile, maskValues), lit(1))
321-
else
322-
Mask.InverseMaskByValue(sourceTile, rf_local_is_in(maskTile, maskValues), lit(0))
323-
324315
/** Generate a tile with the values from `data_tile`, but where cells in the `mask_tile` are in the `mask_values`
325316
list, replace the value with NODATA. */
326317
def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Column): TypedColumn[Any, Tile] =
327-
rf_mask_by_values(sourceTile, maskTile, maskValues, false)
318+
Mask.MaskByValues(sourceTile, maskTile, maskValues)
328319

329320
/** Generate a tile with the values from `data_tile`, but where cells in the `mask_tile` are in the `mask_values`
330-
list, replace the value with NODATA.
331-
If `inverse` is True, the cells in `mask_tile` that are not in `mask_values` list become NODATA */
332-
def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Iterable[Int], inverse: Boolean): TypedColumn[Any, Tile] = {
321+
list, replace the value with NODATA. */
322+
def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Seq[Int]): TypedColumn[Any, Tile] = {
333323
import org.apache.spark.sql.functions.array
334324
val valuesCol: Column = array(maskValues.map(lit).toSeq: _*)
335-
rf_mask_by_values(sourceTile, maskTile, valuesCol, inverse)
325+
rf_mask_by_values(sourceTile, maskTile, valuesCol)
336326
}
337327

338-
/** Generate a tile with the values from `data_tile`, but where cells in the `mask_tile` are in the `mask_values`
339-
list, replace the value with NODATA.
340-
If `inverse` is True, the cells in `mask_tile` that are not in `mask_values` list become NODATA */
341-
def rf_mask_by_values(sourceTile: Column, maskTile: Column, maskValues: Iterable[Int]): TypedColumn[Any, Tile] =
342-
rf_mask_by_values(sourceTile, maskTile, maskValues, false)
343-
344328
/** Where the `maskTile` does **not** contain `NoData`, replace values in the source tile with `NoData` */
345329
def rf_inverse_mask(sourceTile: Column, maskTile: Column): TypedColumn[Any, Tile] =
346330
Mask.InverseMaskByDefined(sourceTile, maskTile)

core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/TileRasterizerAggregate.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ object TileRasterizerAggregate {
138138
}
139139
}
140140

141-
// Scan table and constuct what the TileLayerMetadata would be in the specified destination CRS.
141+
// Scan table and construct what the TileLayerMetadata would be in the specified destination CRS.
142142
val tlm: TileLayerMetadata[SpatialKey] = df
143143
.select(
144144
ProjectedLayerMetadataAggregate(

core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,10 @@ package object expressions {
125125
registry.registerExpression[LocalMeanAggregate]("rf_agg_local_mean")
126126

127127
registry.registerExpression[Mask.MaskByDefined]("rf_mask")
128+
registry.registerExpression[Mask.InverseMaskByDefined]("rf_inverse_mask")
128129
registry.registerExpression[Mask.MaskByValue]("rf_mask_by_value")
129130
registry.registerExpression[Mask.InverseMaskByValue]("rf_inverse_mask_by_value")
130-
registry.registerExpression[Mask.InverseMaskByDefined]("rf_inverse_mask")
131+
registry.registerExpression[Mask.MaskByValues]("rf_mask_by_values")
131132

132133
registry.registerExpression[DebugRender.RenderAscii]("rf_render_ascii")
133134
registry.registerExpression[DebugRender.RenderMatrix]("rf_render_matrix")

core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/Mask.scala

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -34,52 +34,58 @@ import org.apache.spark.sql.types.DataType
3434
import org.apache.spark.sql.{Column, TypedColumn}
3535
import org.locationtech.rasterframes.encoders.CatalystSerializer._
3636
import org.locationtech.rasterframes.expressions.DynamicExtractors._
37+
import org.locationtech.rasterframes.expressions.localops.IsIn
3738
import org.locationtech.rasterframes.expressions.row
3839
import org.slf4j.LoggerFactory
3940

4041
abstract class Mask(val left: Expression, val middle: Expression, val right: Expression, inverse: Boolean)
4142
extends TernaryExpression with CodegenFallback with Serializable {
43+
// aliases.
44+
def targetExp = left
45+
def maskExp = middle
46+
def maskValueExp = right
4247

4348
@transient protected lazy val logger = Logger(LoggerFactory.getLogger(getClass.getName))
4449

45-
4650
override def children: Seq[Expression] = Seq(left, middle, right)
4751

4852
override def checkInputDataTypes(): TypeCheckResult = {
49-
if (!tileExtractor.isDefinedAt(left.dataType)) {
50-
TypeCheckFailure(s"Input type '${left.dataType}' does not conform to a raster type.")
51-
} else if (!tileExtractor.isDefinedAt(middle.dataType)) {
52-
TypeCheckFailure(s"Input type '${middle.dataType}' does not conform to a raster type.")
53-
} else if (!intArgExtractor.isDefinedAt(right.dataType)) {
54-
TypeCheckFailure(s"Input type '${right.dataType}' isn't an integral type.")
53+
if (!tileExtractor.isDefinedAt(targetExp.dataType)) {
54+
TypeCheckFailure(s"Input type '${targetExp.dataType}' does not conform to a raster type.")
55+
} else if (!tileExtractor.isDefinedAt(maskExp.dataType)) {
56+
TypeCheckFailure(s"Input type '${maskExp.dataType}' does not conform to a raster type.")
57+
} else if (!intArgExtractor.isDefinedAt(maskValueExp.dataType)) {
58+
TypeCheckFailure(s"Input type '${maskValueExp.dataType}' isn't an integral type.")
5559
} else TypeCheckSuccess
5660
}
5761
override def dataType: DataType = left.dataType
5862

59-
override protected def nullSafeEval(leftInput: Any, middleInput: Any, rightInput: Any): Any = {
63+
override def makeCopy(newArgs: Array[AnyRef]): Expression = super.makeCopy(newArgs)
64+
65+
override protected def nullSafeEval(targetInput: Any, maskInput: Any, maskValueInput: Any): Any = {
6066
implicit val tileSer = TileUDT.tileSerializer
61-
val (leftTile, leftCtx) = tileExtractor(left.dataType)(row(leftInput))
62-
val (rightTile, rightCtx) = tileExtractor(middle.dataType)(row(middleInput))
67+
val (targetTile, targetCtx) = tileExtractor(targetExp.dataType)(row(targetInput))
68+
val (maskTile, maskCtx) = tileExtractor(maskExp.dataType)(row(maskInput))
6369

64-
if (leftCtx.isEmpty && rightCtx.isDefined)
70+
if (targetCtx.isEmpty && maskCtx.isDefined)
6571
logger.warn(
6672
s"Right-hand parameter '${middle}' provided an extent and CRS, but the left-hand parameter " +
6773
s"'${left}' didn't have any. Because the left-hand side defines output type, the right-hand context will be lost.")
6874

69-
if (leftCtx.isDefined && rightCtx.isDefined && leftCtx != rightCtx)
75+
if (targetCtx.isDefined && maskCtx.isDefined && targetCtx != maskCtx)
7076
logger.warn(s"Both '${left}' and '${middle}' provided an extent and CRS, but they are different. Left-hand side will be used.")
7177

72-
val maskValue = intArgExtractor(right.dataType)(rightInput)
78+
val maskValue = intArgExtractor(maskValueExp.dataType)(maskValueInput)
7379

74-
val masking = if (maskValue.value == 0) Defined(rightTile)
75-
else rightTile
80+
val masking = if (maskValue.value == 0) Defined(maskTile)
81+
else maskTile
7682

7783
val result = if (inverse)
78-
gtInverseMask(leftTile, masking, maskValue.value, raster.NODATA)
84+
gtInverseMask(targetTile, masking, maskValue.value, raster.NODATA)
7985
else
80-
gtMask(leftTile, masking, maskValue.value, raster.NODATA)
86+
gtMask(targetTile, masking, maskValue.value, raster.NODATA)
8187

82-
leftCtx match {
88+
targetCtx match {
8389
case Some(ctx) => ctx.toProjectRasterTile(result).toInternalRow
8490
case None => result.toInternalRow
8591
}
@@ -169,4 +175,28 @@ object Mask {
169175
def apply(srcTile: Column, maskingTile: Column, maskValue: Column): TypedColumn[Any, Tile] =
170176
new Column(InverseMaskByValue(srcTile.expr, maskingTile.expr, maskValue.expr)).as[Tile]
171177
}
178+
179+
@ExpressionDescription(
180+
usage = "_FUNC_(data, mask, maskValues) - Generate a tile with the values from `data` tile but where cells in the `mask` tile are in the `maskValues` list, replace the value with NODATA.",
181+
arguments = """
182+
Arguments:
183+
* target - tile to mask
184+
* mask - masking definition
185+
* maskValues - sequence of values to consider as masks candidates
186+
""",
187+
examples = """
188+
Examples:
189+
> SELECT _FUNC_(data, mask, array(1, 2, 3))
190+
..."""
191+
)
192+
case class MaskByValues(dataTile: Expression, maskTile: Expression)
193+
extends Mask(dataTile, maskTile, Literal(1), inverse = false) {
194+
def this(dataTile: Expression, maskTile: Expression, maskValues: Expression) =
195+
this(dataTile, IsIn(maskTile, maskValues))
196+
override def nodeName: String = "rf_mask_by_values"
197+
}
198+
object MaskByValues {
199+
def apply(dataTile: Column, maskTile: Column, maskValues: Column): TypedColumn[Any, Tile] =
200+
new Column(MaskByValues(dataTile.expr, IsIn(maskTile, maskValues).expr)).as[Tile]
201+
}
172202
}

core/src/main/scala/org/locationtech/rasterframes/util/DataFrameRenderers.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,14 @@ package org.locationtech.rasterframes.util
2424
import geotrellis.raster.render.ColorRamps
2525
import org.apache.spark.sql.Dataset
2626
import org.apache.spark.sql.functions.{base64, concat, concat_ws, length, lit, substring, when}
27+
import org.apache.spark.sql.jts.JTSTypes
2728
import org.apache.spark.sql.types.{StringType, StructField}
2829
import org.locationtech.rasterframes.expressions.DynamicExtractors
2930
import org.locationtech.rasterframes.{rfConfig, rf_render_png, rf_resample}
31+
import org.apache.spark.sql.rf.WithTypeConformity
3032

3133
/**
32-
* DataFrame extensiosn for rendering sample content in a number of ways
34+
* DataFrame extension for rendering sample content in a number of ways
3335
*/
3436
trait DataFrameRenderers {
3537
private val truncateWidth = rfConfig.getInt("max-truncate-row-element-length")
@@ -47,8 +49,9 @@ trait DataFrameRenderers {
4749
lit("\"></img>")
4850
)
4951
else {
52+
val isGeom = WithTypeConformity(c.dataType).conformsTo(JTSTypes.GeometryTypeInstance)
5053
val str = resolved.cast(StringType)
51-
if (truncate)
54+
if (truncate || isGeom)
5255
when(length(str) > lit(truncateWidth),
5356
concat(substring(str, 1, truncateWidth), lit("..."))
5457
)

core/src/test/resources/MCD43A4.A2019111.h30v06.006.2019120033434_01.mrf.aux.xml

Lines changed: 0 additions & 92 deletions
This file was deleted.

core/src/test/scala/org/locationtech/rasterframes/ExtensionMethodSpec.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ import scala.xml.parsing.XhtmlParser
3939
class ExtensionMethodSpec extends TestEnvironment with TestData with SubdivideSupport {
4040
lazy val rf = sampleTileLayerRDD.toLayer
4141

42-
describe("DataFrame exention methods") {
42+
describe("DataFrame extension methods") {
4343
it("should maintain original type") {
4444
val df = rf.withPrefixedColumnNames("_foo_")
4545
"val rf2: RasterFrameLayer = df" should compile
@@ -49,7 +49,7 @@ class ExtensionMethodSpec extends TestEnvironment with TestData with SubdivideSu
4949
"val Some(col) = df.spatialKeyColumn" should compile
5050
}
5151
}
52-
describe("RasterFrameLayer exention methods") {
52+
describe("RasterFrameLayer extension methods") {
5353
it("should provide spatial key column") {
5454
noException should be thrownBy {
5555
rf.spatialKeyColumn
@@ -124,6 +124,10 @@ class ExtensionMethodSpec extends TestEnvironment with TestData with SubdivideSu
124124

125125
val md3 = rf.toMarkdown(truncate=true, renderTiles = false)
126126
md3 shouldNot include("<img")
127+
128+
// Should truncate JTS types even when we don't ask for it.
129+
val md4 = rf.withGeometry().select("geometry").toMarkdown(truncate = false)
130+
md4 should include ("...")
127131
}
128132

129133
it("should render HTML") {

core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import geotrellis.raster._
2828
import geotrellis.raster.render.ColorRamps
2929
import geotrellis.raster.testkit.RasterMatchers
3030
import javax.imageio.ImageIO
31-
import org.apache.spark.sql.Encoders
31+
import org.apache.spark.sql.{Column, Encoders, TypedColumn}
3232
import org.apache.spark.sql.functions._
3333
import org.locationtech.rasterframes.expressions.accessors.ExtractTile
3434
import org.locationtech.rasterframes.model.TileDimensions
@@ -701,7 +701,7 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers {
701701
val withMasked = withMask.withColumn("masked",
702702
rf_inverse_mask_by_value($"tile", $"mask", mask_value))
703703
.withColumn("masked2", rf_mask_by_value($"tile", $"mask", lit(mask_value), true))
704-
704+
withMasked.explain(true)
705705
val result = withMasked.agg(rf_agg_no_data_cells($"tile") < rf_agg_no_data_cells($"masked")).as[Boolean]
706706

707707
result.first() should be(true)
@@ -712,18 +712,27 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers {
712712
checkDocs("rf_inverse_mask_by_value")
713713
}
714714

715-
it("should mask tile by another identified by specified values") {
716-
val squareIncrementingPRT = ProjectedRasterTile(squareIncrementingTile(randPRT.rows), randPRT.extent, randPRT.crs)
717-
val df = Seq((randPRT, squareIncrementingPRT))
715+
it("should mask tile by another identified by sequence of specified values") {
716+
val squareIncrementingPRT = ProjectedRasterTile(squareIncrementingTile(six.rows), six.extent, six.crs)
717+
val df = Seq((six, squareIncrementingPRT))
718718
.toDF("tile", "mask")
719+
719720
val mask_values = Seq(4, 5, 6, 12)
720721

721722
val withMasked = df.withColumn("masked",
722723
rf_mask_by_values($"tile", $"mask", mask_values))
723724

724-
val result = withMasked.agg(rf_agg_no_data_cells($"masked") as "nd").as[Long]
725+
val expected = squareIncrementingPRT.toArray().count(v mask_values.contains(v))
726+
727+
val result = withMasked.agg(rf_agg_no_data_cells($"masked") as "masked_nd")
728+
.first()
729+
730+
result.getAs[BigInt](0) should be (expected)
731+
732+
val withMaskedSql = df.selectExpr("rf_mask_by_values(tile, mask, array(4, 5, 6, 12)) AS masked")
733+
val resultSql = withMaskedSql.agg(rf_agg_no_data_cells($"masked")).as[Long]
734+
resultSql.first() should be (expected)
725735

726-
result.first() should be(mask_values.length)
727736
checkDocs("rf_mask_by_values")
728737
}
729738

@@ -1018,7 +1027,6 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers {
10181027
val e0Result = df.select($"in_expect_0").as[Tile].first()
10191028
e0Result.toArray() should contain only (0)
10201029

1021-
// lazy val invalid = df.select(rf_local_is_in($"t", lit("foobar"))).as[Tile].first()
10221030
}
10231031

10241032
it("should unpack QA bits"){

0 commit comments

Comments
 (0)