Skip to content

Commit c345c11

Browse files
committed
Cleanup ScalaUDF usage
1 parent 89c6627 commit c345c11

File tree

9 files changed

+12
-29
lines changed

9 files changed

+12
-29
lines changed

core/src/main/scala/org/apache/spark/sql/rf/RasterSourceUDT.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@
2222
package org.apache.spark.sql.rf
2323

2424
import org.apache.spark.sql.catalyst.InternalRow
25-
import org.apache.spark.sql.types.{DataType, UDTRegistration, UserDefinedType, _}
26-
import org.locationtech.rasterframes.expressions.transformers.RasterRefToTile
25+
import org.apache.spark.sql.types._
2726
import org.locationtech.rasterframes.ref.RFRasterSource
2827
import org.locationtech.rasterframes.util.KryoSupport
2928

core/src/main/scala/org/locationtech/rasterframes/expressions/UnaryRasterAggregate.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ trait UnaryRasterAggregate extends DeclarativeAggregate {
4040

4141
def children = Seq(child)
4242

43-
protected def tileOpAsExpressionNew[R: TypeTag](name: String, op: Tile => R): Expression => ScalaUDF =
44-
udfexprNew[R, Any](name, (dataType: DataType) => (a: Any) => if(a == null) null.asInstanceOf[R] else op(UnaryRasterAggregate.extractTileFromAny(dataType, a)))
43+
protected def tileOpAsExpression[R: TypeTag](name: String, op: Tile => R): Expression => ScalaUDF =
44+
udfexpr[R, Any, Tile](name, (dataType: DataType) => (a: Any) => if(a == null) null.asInstanceOf[R] else op(UnaryRasterAggregate.extractTileFromAny(dataType, a)))
4545
}
4646

4747
object UnaryRasterAggregate {

core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetCRS.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
package org.locationtech.rasterframes.expressions.accessors
2323

24-
import geotrellis.proj4.{CRS, LatLng}
24+
import geotrellis.proj4.CRS
2525
import org.apache.spark.sql.catalyst.InternalRow
2626
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2727
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}

core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/CellCountAggregate.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ abstract class CellCountAggregate(isData: Boolean) extends UnaryRasterAggregate
4343
val initialValues = Seq(Literal(0L))
4444

4545
private def CellTest: Expression => ScalaUDF =
46-
if (isData) tileOpAsExpressionNew("rf_data_cells", DataCells.op)
47-
else tileOpAsExpressionNew("rf_no_data_cells", NoDataCells.op)
46+
if (isData) tileOpAsExpression("rf_data_cells", DataCells.op)
47+
else tileOpAsExpression("rf_no_data_cells", NoDataCells.op)
4848

4949
val updateExpressions = Seq(If(IsNull(child), count, Add(count, CellTest(child))))
5050

core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/CellMeanAggregate.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ case class CellMeanAggregate(child: Expression) extends UnaryRasterAggregate {
5454
// Cant' figure out why we can't just use the Expression directly
5555
// this is necessary to properly handle null rows. For example,
5656
// if we use `tilestats.Sum` directly, we get an NPE when the stage is executed.
57-
private val DataCellCounts = tileOpAsExpressionNew("rf_data_cells", DataCells.op)
58-
private val SumCells = tileOpAsExpressionNew("sum_cells", Sum.op)
57+
private val DataCellCounts = tileOpAsExpression("rf_data_cells", DataCells.op)
58+
private val SumCells = tileOpAsExpression("sum_cells", Sum.op)
5959

6060
val updateExpressions = Seq(
6161
// TODO: Figure out why this doesn't work. See above.

core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/LocalMeanAggregate.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ case class LocalMeanAggregate(child: Expression) extends UnaryRasterAggregate {
4848

4949
def aggBufferAttributes: Seq[AttributeReference] = Seq(count, sum)
5050

51-
private lazy val Defined: Expression => ScalaUDF = tileOpAsExpressionNew("defined_cells", local.Defined.apply)
51+
private lazy val Defined: Expression => ScalaUDF = tileOpAsExpression("defined_cells", local.Defined.apply)
5252

5353
lazy val initialValues: Seq[Expression] = Seq(
5454
Literal.create(null, dataType),

core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -53,21 +53,9 @@ package object expressions {
5353

5454
/** As opposed to `udf`, this constructs an unwrapped ScalaUDF Expression from a function. */
5555
private[expressions]
56-
def udfexpr[RT: TypeTag, A1: TypeTag](name: String, f: A1 => RT): Expression => ScalaUDF = (child: Expression) => {
56+
def udfexpr[RT: TypeTag, A1: TypeTag, A1T: TypeTag](name: String, f: DataType => A1 => RT): Expression => ScalaUDF = (exp: Expression) => {
5757
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
58-
ScalaUDF(f, dataType, Seq(child), Option(ExpressionEncoder[RT]()) :: Nil, udfName = Some(name))
59-
}
60-
61-
private[expressions]
62-
def udfexprNew[RT: TypeTag, A1: TypeTag](name: String, f: DataType => A1 => RT): Expression => ScalaUDF = (exp: Expression) => {
63-
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
64-
ScalaUDF((row: A1) => f(exp.dataType)(row), dataType, exp :: Nil, Option(ExpressionEncoder[RT]().resolveAndBind()) :: Nil)
65-
}
66-
67-
private[expressions]
68-
def udfexprNewUntyped[RT: TypeTag, A1: TypeTag](name: String, f: DataType => A1 => RT): Expression => ScalaUDF = (exp: Expression) => {
69-
val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
70-
ScalaUDF((row: A1) => f(exp.dataType)(row), dataType, exp :: Nil)
58+
ScalaUDF((row: A1) => f(exp.dataType)(row), dataType, exp :: Nil, Option(ExpressionEncoder[A1T]().resolveAndBind()) :: Nil)
7159
}
7260

7361
def register(sqlContext: SQLContext): Unit = {

core/src/main/scala/org/locationtech/rasterframes/functions/package.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ import geotrellis.vector.Extent
2727
import org.apache.spark.sql.functions.udf
2828
import org.apache.spark.sql.{Row, SQLContext}
2929
import org.locationtech.jts.geom.Geometry
30-
import org.locationtech.rasterframes._
3130
import org.locationtech.rasterframes.encoders.syntax._
3231
import org.locationtech.rasterframes.util.ResampleMethod
3332

core/src/main/scala/org/locationtech/rasterframes/ref/RFRasterSource.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,17 @@ package org.locationtech.rasterframes.ref
2424
import java.net.URI
2525
import com.github.blemale.scaffeine.Scaffeine
2626
import com.typesafe.scalalogging.LazyLogging
27-
import frameless.Injection
2827
import geotrellis.proj4.CRS
2928
import geotrellis.raster._
3029
import geotrellis.raster.io.geotiff.Tags
3130
import geotrellis.vector.Extent
3231
import org.apache.hadoop.conf.Configuration
3332
import org.apache.spark.annotation.Experimental
3433
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
35-
import org.apache.spark.sql.rf.{RasterSourceUDT}
34+
import org.apache.spark.sql.rf.RasterSourceUDT
3635
import org.locationtech.rasterframes.model.TileContext
37-
import org.locationtech.rasterframes.util.KryoSupport
3836
import org.locationtech.rasterframes.{NOMINAL_TILE_DIMS, rfConfig}
3937

40-
import java.nio.ByteBuffer
4138
import scala.concurrent.duration.{Duration, FiniteDuration}
4239

4340
/**

0 commit comments

Comments
 (0)