Skip to content

Commit 05037ab

Browse files
committed
Fix ScalaUDF cleanup
1 parent c345c11 commit 05037ab

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

core/src/main/scala/org/locationtech/rasterframes/expressions/UnaryRasterAggregate.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ trait UnaryRasterAggregate extends DeclarativeAggregate {
4141
def children = Seq(child)
4242

4343
protected def tileOpAsExpression[R: TypeTag](name: String, op: Tile => R): Expression => ScalaUDF =
44-
udfexpr[R, Any, Tile](name, (dataType: DataType) => (a: Any) => if(a == null) null.asInstanceOf[R] else op(UnaryRasterAggregate.extractTileFromAny(dataType, a)))
44+
udfiexpr[R, Any](name, (dataType: DataType) => (a: Any) => if(a == null) null.asInstanceOf[R] else op(UnaryRasterAggregate.extractTileFromAny(dataType, a)))
4545
}
4646

4747
object UnaryRasterAggregate {

core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,14 @@ package object expressions {
5151
private[expressions]
5252
def fpTile(t: Tile) = if (t.cellType.isFloatingPoint) t else t.convert(DoubleConstantNoDataCellType)
5353

54-
/** As opposed to `udf`, this constructs an unwrapped ScalaUDF Expression from a function. */
54+
/**
55+
* As opposed to `udf`, this constructs an unwrapped ScalaUDF Expression from a function.
56+
* This ScalaUDF Expression expects the argument of type A1 to match the return type RT at runtime.
57+
*/
5558
private[expressions]
56-
def udfexpr[RT: TypeTag, A1: TypeTag, A1T: TypeTag](name: String, f: DataType => A1 => RT): Expression => ScalaUDF = (exp: Expression) => {
57-
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
58-
ScalaUDF((row: A1) => f(exp.dataType)(row), dataType, exp :: Nil, Option(ExpressionEncoder[A1T]().resolveAndBind()) :: Nil)
59+
def udfiexpr[RT: TypeTag, A1: TypeTag](name: String, f: DataType => A1 => RT): Expression => ScalaUDF = (child: Expression) => {
60+
val ScalaReflection.Schema(dataType, _) = ScalaReflection.schemaFor[RT]
61+
ScalaUDF((row: A1) => f(child.dataType)(row), dataType, Seq(child), Seq(Option(ExpressionEncoder[RT]().resolveAndBind())), udfName = Some(name))
5962
}
6063

6164
def register(sqlContext: SQLContext): Unit = {

0 commit comments

Comments
 (0)