Skip to content

Commit 160f351

Browse files
committed
Try Aggregator implemtnation
1 parent 7f5e078 commit 160f351

File tree

1 file changed

+21
-41
lines changed

1 file changed

+21
-41
lines changed

core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/TileRasterizerAggregate.scala

Lines changed: 21 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,15 @@ import geotrellis.layer._
2525
import geotrellis.proj4.CRS
2626
import geotrellis.raster.reproject.Reproject
2727
import geotrellis.raster.resample.{Bilinear, ResampleMethod}
28-
import geotrellis.raster.{ArrayTile, CellType, Dimensions, MultibandTile, ProjectedRaster, Tile}
28+
import geotrellis.raster.{ArrayTile, CellType, Dimensions, MultibandTile, MutableArrayTile, ProjectedRaster, Tile}
2929
import geotrellis.vector.Extent
30-
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
31-
import org.apache.spark.sql.types.{DataType, StructField, StructType}
32-
import org.apache.spark.sql.{Column, DataFrame, Row, TypedColumn}
30+
import org.apache.spark.sql.expressions.Aggregator
31+
import org.apache.spark.sql.functions.udaf
32+
import org.apache.spark.sql.{Column, DataFrame, Encoder, TypedColumn}
3333
import org.locationtech.rasterframes._
34-
import org.locationtech.rasterframes.encoders.syntax._
34+
import org.locationtech.rasterframes.encoders.StandardEncoders
3535
import org.locationtech.rasterframes.expressions.aggregates.TileRasterizerAggregate.ProjectedRasterDefinition
36+
import org.locationtech.rasterframes.tiles.ProjectedRasterTile
3637
import org.locationtech.rasterframes.util._
3738
import org.slf4j.LoggerFactory
3839

@@ -41,48 +42,26 @@ import org.slf4j.LoggerFactory
4142
* `Tile`, `CRS` and `Extent` columns.
4243
* @param prd aggregation settings
4344
*/
44-
class TileRasterizerAggregate(prd: ProjectedRasterDefinition) extends UserDefinedAggregateFunction {
45-
45+
class TileRasterizerAggregate(prd: ProjectedRasterDefinition) extends Aggregator[ProjectedRasterTile, Tile, Tile] {
4646
val projOpts = Reproject.Options.DEFAULT.copy(method = prd.sampler)
4747

48-
def deterministic: Boolean = true
49-
50-
def inputSchema: StructType = StructType(Seq(
51-
StructField("crs", crsUDT, false),
52-
StructField("extent", extentEncoder.schema, false),
53-
StructField("tile", tileUDT)
54-
))
55-
56-
def bufferSchema: StructType = StructType(Seq(
57-
StructField("tile_buffer", tileUDT)
58-
))
59-
60-
def dataType: DataType = tileUDT
61-
62-
def initialize(buffer: MutableAggregationBuffer): Unit =
63-
buffer(0) = ArrayTile.empty(prd.destinationCellType, prd.totalCols, prd.totalRows)
64-
65-
def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
66-
val crs: CRS = input.getAs[CRS](0)
67-
val extent: Extent = input.getAs[Row](1).as[Extent]
68-
69-
val localExtent = extent.reproject(crs, prd.destinationCRS)
48+
override def zero: MutableArrayTile = ArrayTile.empty(prd.destinationCellType, prd.totalCols, prd.totalRows)
7049

50+
override def reduce(b: Tile, a: ProjectedRasterTile): Tile = {
51+
val localExtent = a.extent.reproject(a.crs, prd.destinationCRS)
7152
if (prd.destinationExtent.intersects(localExtent)) {
72-
val localTile = input.getAs[Tile](2).reproject(extent, crs, prd.destinationCRS, projOpts)
73-
val bt = buffer.getAs[Tile](0)
74-
val merged = bt.merge(prd.destinationExtent, localExtent, localTile.tile, prd.sampler)
75-
buffer(0) = merged
76-
}
53+
val localTile = a.tile.reproject(a.extent, a.crs, prd.destinationCRS, projOpts)
54+
b.merge(prd.destinationExtent, localExtent, localTile.tile, prd.sampler)
55+
} else b
7756
}
7857

79-
def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
80-
val leftTile = buffer1.getAs[Tile](0)
81-
val rightTile = buffer2.getAs[Tile](0)
82-
buffer1(0) = leftTile.merge(rightTile)
83-
}
58+
override def merge(b1: Tile, b2: Tile): Tile = b1.merge(b2)
59+
60+
override def finish(reduction: Tile): Tile = reduction
61+
62+
override def bufferEncoder: Encoder[Tile] = StandardEncoders.tileEncoder
8463

85-
def evaluate(buffer: Row): Tile = buffer.getAs[Tile](0)
64+
override def outputEncoder: Encoder[Tile] = StandardEncoders.tileEncoder
8665
}
8766

8867
object TileRasterizerAggregate {
@@ -107,7 +86,8 @@ object TileRasterizerAggregate {
10786
logger.warn(
10887
s"You've asked for the construction of a very large image (${prd.totalCols} x ${prd.totalRows}). Out of memory error likely.")
10988

110-
new TileRasterizerAggregate(prd)(crsCol, extentCol, tileCol)
89+
udaf(new TileRasterizerAggregate(prd))
90+
.apply(crsCol, extentCol, tileCol)
11191
.as("rf_agg_overview_raster")
11292
.as[Tile]
11393
}

0 commit comments

Comments
 (0)