Skip to content

Commit 137d0d1

Browse files
committed
Fix unit tests for rf_local_is_in
Signed-off-by: Jason T. Brown <[email protected]>
1 parent 7571b0f commit 137d0d1

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

core/src/main/scala/org/locationtech/rasterframes/expressions/localops/IsIn.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ case class IsIn(left: Expression, right: Expression) extends BinaryExpression wi
5252

5353
override def dataType: DataType = left.dataType
5454

55-
@transient private lazy val elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType
55+
@transient private lazy val elementType: DataType = right.dataType.asInstanceOf[ArrayType].elementType
5656

5757
override def checkInputDataTypes(): TypeCheckResult =
5858
if(!tileExtractor.isDefinedAt(left.dataType)) {

core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -977,7 +977,10 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers {
977977
checkDocs("rf_local_is_in")
978978

979979
// tile is 3 by 3 with values, 1 to 9
980-
val df = Seq((byteArrayTile, lit(1), lit(5), lit(10))).toDF("t", "one", "five", "ten")
980+
val df = Seq(byteArrayTile).toDF("t")
981+
.withColumn("one", lit(1))
982+
.withColumn("five", lit(5))
983+
.withColumn("ten", lit(10))
981984
.withColumn("in_expect_2", rf_local_is_in($"t", array($"one", $"five")))
982985
.withColumn("in_expect_1", rf_local_is_in($"t", array($"ten", $"five")))
983986
.withColumn("in_expect_0", rf_local_is_in($"t", array($"ten")))
@@ -988,7 +991,7 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers {
988991
val e1Result = df.select(rf_tile_sum($"in_expect_1")).as[Double].first()
989992
e1Result should be (1.0)
990993

991-
val e0Result = df.select($"in_expect_1").as[Tile].first()
994+
val e0Result = df.select($"in_expect_0").as[Tile].first()
992995
e0Result.toArray() should contain only (0)
993996

994997
// lazy val invalid = df.select(rf_local_is_in($"t", lit("foobar"))).as[Tile].first()

0 commit comments

Comments
 (0)