Skip to content

Commit 7571b0f

Browse files
committed
Add rf_local_is_in function
Signed-off-by: Jason T. Brown <[email protected]>
1 parent 9ad38ec commit 7571b0f

File tree

4 files changed

+113
-0
lines changed

4 files changed

+113
-0
lines changed

core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,9 @@ trait RasterFunctions {
389389
/** Cellwise inequality comparison between a tile and a scalar. */
390390
def rf_local_unequal[T: Numeric](tileCol: Column, value: T): Column = Unequal(tileCol, value)
391391

392+
/** Test if each cell value is in provided array */
393+
def rf_local_is_in(tileCol: Column, arrayCol: Column) = IsIn(tileCol, arrayCol)
394+
392395
/** Return a tile with ones where the input is NoData, otherwise zero */
393396
def rf_local_no_data(tileCol: Column): Column = Undefined(tileCol)
394397

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/*
2+
* This software is licensed under the Apache 2 license, quoted below.
3+
*
4+
* Copyright 2019 Astraea, Inc.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
7+
* use this file except in compliance with the License. You may obtain a copy of
8+
* the License at
9+
*
10+
* [http://www.apache.org/licenses/LICENSE-2.0]
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15+
* License for the specific language governing permissions and limitations under
16+
* the License.
17+
*
18+
* SPDX-License-Identifier: Apache-2.0
19+
*
20+
*/
21+
22+
package org.locationtech.rasterframes.expressions.localops
23+
24+
import geotrellis.raster.Tile
25+
import geotrellis.raster.mapalgebra.local.IfCell
26+
import org.apache.spark.sql.Column
27+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
28+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
29+
import org.apache.spark.sql.types.{ArrayType, DataType}
30+
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
31+
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, ExpressionDescription}
32+
import org.apache.spark.sql.catalyst.util.ArrayData
33+
import org.apache.spark.sql.rf.TileUDT
34+
import org.locationtech.rasterframes.encoders.CatalystSerializer._
35+
import org.locationtech.rasterframes.expressions.DynamicExtractors._
36+
import org.locationtech.rasterframes.expressions._
37+
38+
@ExpressionDescription(
39+
usage = "_FUNC_(tile, rhs) - In each cell of `tile`, return true if the value is in rhs.",
40+
arguments = """
41+
Arguments:
42+
* tile - tile column to apply abs
43+
* rhs - array to test against
44+
""",
45+
examples = """
46+
Examples:
47+
> SELECT _FUNC_(tile, array);
48+
..."""
49+
)
50+
case class IsIn(left: Expression, right: Expression) extends BinaryExpression with CodegenFallback {
51+
override val nodeName: String = "rf_local_is_in"
52+
53+
override def dataType: DataType = left.dataType
54+
55+
@transient private lazy val elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType
56+
57+
override def checkInputDataTypes(): TypeCheckResult =
58+
if(!tileExtractor.isDefinedAt(left.dataType)) {
59+
TypeCheckFailure(s"Input type '${left.dataType}' does not conform to a raster type.")
60+
} else right.dataType match {
61+
case _: ArrayType TypeCheckSuccess
62+
case _ TypeCheckFailure(s"Input type '${right.dataType}' does not conform to ArrayType.")
63+
}
64+
65+
override protected def nullSafeEval(input1: Any, input2: Any): Any = {
66+
implicit val tileSer = TileUDT.tileSerializer
67+
val (childTile, childCtx) = tileExtractor(left.dataType)(row(input1))
68+
69+
val arr = input2.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
70+
71+
childCtx match {
72+
case Some(ctx) => ctx.toProjectRasterTile(op(childTile, arr)).toInternalRow
73+
case None => op(childTile, arr).toInternalRow
74+
}
75+
76+
}
77+
78+
protected def op(left: Tile, right: IndexedSeq[AnyRef]): Tile = {
79+
def fn(i: Int): Boolean = right.contains(i)
80+
IfCell(left, fn(_), 1, 0)
81+
}
82+
83+
}
84+
85+
object IsIn {
86+
def apply(left: Column, right: Column): Column =
87+
new Column(IsIn(left.expr, right.expr))
88+
}

core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ package object expressions {
8787
registry.registerExpression[GreaterEqual]("rf_local_greater_equal")
8888
registry.registerExpression[Equal]("rf_local_equal")
8989
registry.registerExpression[Unequal]("rf_local_unequal")
90+
registry.registerExpression[IsIn]("rf_local_is_in")
9091
registry.registerExpression[Undefined]("rf_local_no_data")
9192
registry.registerExpression[Defined]("rf_local_data")
9293
registry.registerExpression[Sum]("rf_tile_sum")

core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -972,4 +972,25 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers {
972972
val dResult = df.select($"ld").as[Tile].first()
973973
dResult should be (randNDPRT.localDefined())
974974
}
975+
976+
it("should check values isin"){
977+
checkDocs("rf_local_is_in")
978+
979+
// tile is 3 by 3 with values, 1 to 9
980+
val df = Seq((byteArrayTile, lit(1), lit(5), lit(10))).toDF("t", "one", "five", "ten")
981+
.withColumn("in_expect_2", rf_local_is_in($"t", array($"one", $"five")))
982+
.withColumn("in_expect_1", rf_local_is_in($"t", array($"ten", $"five")))
983+
.withColumn("in_expect_0", rf_local_is_in($"t", array($"ten")))
984+
985+
val e2Result = df.select(rf_tile_sum($"in_expect_2")).as[Double].first()
986+
e2Result should be (2.0)
987+
988+
val e1Result = df.select(rf_tile_sum($"in_expect_1")).as[Double].first()
989+
e1Result should be (1.0)
990+
991+
val e0Result = df.select($"in_expect_1").as[Tile].first()
992+
e0Result.toArray() should contain only (0)
993+
994+
// lazy val invalid = df.select(rf_local_is_in($"t", lit("foobar"))).as[Tile].first()
995+
}
975996
}

0 commit comments

Comments
 (0)