Skip to content

Commit e748f93

Browse files
committed
Expressions returning primitives can't return null, so changing how
constant tiles are created to keep catalyst optimizer from pruning them.
1 parent e06bdd5 commit e748f93

File tree

4 files changed

+13
-13
lines changed

4 files changed

+13
-13
lines changed

core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -231,9 +231,8 @@ trait RasterFunctions {
231231

232232
/** Constructor for tile column with a single cell value. */
233233
def rf_make_constant_tile(value: Number, cols: Int, rows: Int, cellTypeName: String): TypedColumn[Any, Tile] = {
234-
import org.apache.spark.sql.rf.TileUDT.tileSerializer
235-
val constTile = encoders.serialized_literal(F.makeConstantTile(value, cols, rows, cellTypeName))
236-
withTypedAlias(s"rf_make_constant_tile($value, $cols, $rows, $cellTypeName)")(constTile)
234+
val constTile = udf(() => F.makeConstantTile(value, cols, rows, cellTypeName))
235+
withTypedAlias(s"rf_make_constant_tile($value, $cols, $rows, $cellTypeName)")(constTile.apply())
237236
}
238237

239238
/** Create a column constant tiles of zero */

core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/DataCells.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@
2121

2222
package org.locationtech.rasterframes.expressions.tilestats
2323

24+
import org.locationtech.rasterframes.expressions.{NullToValue, UnaryRasterOp}
2425
import geotrellis.raster._
25-
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
26+
import org.apache.spark.sql.{Column, TypedColumn}
2627
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription}
28+
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
2729
import org.apache.spark.sql.types.{DataType, LongType}
28-
import org.apache.spark.sql.{Column, TypedColumn}
29-
import org.locationtech.rasterframes.expressions.UnaryRasterOp
3030
import org.locationtech.rasterframes.model.TileContext
3131

3232
@ExpressionDescription(
@@ -40,10 +40,11 @@ import org.locationtech.rasterframes.model.TileContext
4040
357"""
4141
)
4242
case class DataCells(child: Expression) extends UnaryRasterOp
43-
with CodegenFallback {
43+
with CodegenFallback with NullToValue {
4444
override def nodeName: String = "rf_data_cells"
4545
override def dataType: DataType = LongType
4646
override protected def eval(tile: Tile, ctx: Option[TileContext]): Any = DataCells.op(tile)
47+
override def na: Any = 0L
4748
}
4849
object DataCells {
4950
import org.locationtech.rasterframes.encoders.StandardEncoders.PrimitiveEncoders.longEnc

core/src/main/scala/org/locationtech/rasterframes/expressions/tilestats/NoDataCells.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@
2121

2222
package org.locationtech.rasterframes.expressions.tilestats
2323

24+
import org.locationtech.rasterframes.expressions.{NullToValue, UnaryRasterOp}
2425
import geotrellis.raster._
25-
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
26+
import org.apache.spark.sql.{Column, TypedColumn}
2627
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription}
28+
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
2729
import org.apache.spark.sql.types.{DataType, LongType}
28-
import org.apache.spark.sql.{Column, TypedColumn}
29-
import org.locationtech.rasterframes.expressions.UnaryRasterOp
3030
import org.locationtech.rasterframes.model.TileContext
3131

3232
@ExpressionDescription(
@@ -40,10 +40,11 @@ import org.locationtech.rasterframes.model.TileContext
4040
12"""
4141
)
4242
case class NoDataCells(child: Expression) extends UnaryRasterOp
43-
with CodegenFallback {
43+
with CodegenFallback with NullToValue {
4444
override def nodeName: String = "rf_no_data_cells"
4545
override def dataType: DataType = LongType
4646
override protected def eval(tile: Tile, ctx: Option[TileContext]): Any = NoDataCells.op(tile)
47+
override def na: Any = 0L
4748
}
4849
object NoDataCells {
4950
import org.locationtech.rasterframes.encoders.StandardEncoders.PrimitiveEncoders.longEnc

core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,15 +374,14 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers {
374374

375375
checkDocs("rf_no_data_cells")
376376
}
377+
377378
it("should properly count data and nodata cells on constant tiles") {
378379
val rf = Seq(randPRT).toDF("tile")
379380

380381
val df = rf
381382
.withColumn("make", rf_make_constant_tile(99, 3, 4, ByteConstantNoDataCellType))
382383
.withColumn("make2", rf_with_no_data($"make", 99))
383384

384-
df.show(false)
385-
386385
val counts = df.select(
387386
rf_no_data_cells($"make").alias("nodata1"),
388387
rf_data_cells($"make").alias("data1"),

0 commit comments

Comments
 (0)