@@ -25,14 +25,15 @@ import geotrellis.layer._
2525import geotrellis .proj4 .CRS
2626import geotrellis .raster .reproject .Reproject
2727import geotrellis .raster .resample .{Bilinear , ResampleMethod }
28- import geotrellis .raster .{ArrayTile , CellType , Dimensions , MultibandTile , ProjectedRaster , Tile }
28+ import geotrellis .raster .{ArrayTile , CellType , Dimensions , MultibandTile , MutableArrayTile , ProjectedRaster , Tile }
2929import geotrellis .vector .Extent
30- import org .apache .spark .sql .expressions .{ MutableAggregationBuffer , UserDefinedAggregateFunction }
31- import org .apache .spark .sql .types .{ DataType , StructField , StructType }
32- import org .apache .spark .sql .{Column , DataFrame , Row , TypedColumn }
30+ import org .apache .spark .sql .expressions .Aggregator
31+ import org .apache .spark .sql .functions . udaf
32+ import org .apache .spark .sql .{Column , DataFrame , Encoder , TypedColumn }
3333import org .locationtech .rasterframes ._
34- import org .locationtech .rasterframes .encoders .syntax . _
34+ import org .locationtech .rasterframes .encoders .StandardEncoders
3535import org .locationtech .rasterframes .expressions .aggregates .TileRasterizerAggregate .ProjectedRasterDefinition
36+ import org .locationtech .rasterframes .tiles .ProjectedRasterTile
3637import org .locationtech .rasterframes .util ._
3738import org .slf4j .LoggerFactory
3839
@@ -41,48 +42,26 @@ import org.slf4j.LoggerFactory
4142 * `Tile`, `CRS` and `Extent` columns.
4243 * @param prd aggregation settings
4344 */
44- class TileRasterizerAggregate (prd : ProjectedRasterDefinition ) extends UserDefinedAggregateFunction {
45-
45+ class TileRasterizerAggregate (prd : ProjectedRasterDefinition ) extends Aggregator [ProjectedRasterTile , Tile , Tile ] {
4646 val projOpts = Reproject .Options .DEFAULT .copy(method = prd.sampler)
4747
48- def deterministic : Boolean = true
49-
50- def inputSchema : StructType = StructType (Seq (
51- StructField (" crs" , crsUDT, false ),
52- StructField (" extent" , extentEncoder.schema, false ),
53- StructField (" tile" , tileUDT)
54- ))
55-
56- def bufferSchema : StructType = StructType (Seq (
57- StructField (" tile_buffer" , tileUDT)
58- ))
59-
60- def dataType : DataType = tileUDT
61-
62- def initialize (buffer : MutableAggregationBuffer ): Unit =
63- buffer(0 ) = ArrayTile .empty(prd.destinationCellType, prd.totalCols, prd.totalRows)
64-
65- def update (buffer : MutableAggregationBuffer , input : Row ): Unit = {
66- val crs : CRS = input.getAs[CRS ](0 )
67- val extent : Extent = input.getAs[Row ](1 ).as[Extent ]
68-
69- val localExtent = extent.reproject(crs, prd.destinationCRS)
48+ override def zero : MutableArrayTile = ArrayTile .empty(prd.destinationCellType, prd.totalCols, prd.totalRows)
7049
50+ override def reduce (b : Tile , a : ProjectedRasterTile ): Tile = {
51+ val localExtent = a.extent.reproject(a.crs, prd.destinationCRS)
7152 if (prd.destinationExtent.intersects(localExtent)) {
72- val localTile = input.getAs[Tile ](2 ).reproject(extent, crs, prd.destinationCRS, projOpts)
73- val bt = buffer.getAs[Tile ](0 )
74- val merged = bt.merge(prd.destinationExtent, localExtent, localTile.tile, prd.sampler)
75- buffer(0 ) = merged
76- }
53+ val localTile = a.tile.reproject(a.extent, a.crs, prd.destinationCRS, projOpts)
54+ b.merge(prd.destinationExtent, localExtent, localTile.tile, prd.sampler)
55+ } else b
7756 }
7857
79- def merge (buffer1 : MutableAggregationBuffer , buffer2 : Row ): Unit = {
80- val leftTile = buffer1.getAs[ Tile ]( 0 )
81- val rightTile = buffer2.getAs[ Tile ]( 0 )
82- buffer1( 0 ) = leftTile.merge(rightTile)
83- }
58+ override def merge (b1 : Tile , b2 : Tile ): Tile = b1.merge(b2)
59+
60+ override def finish ( reduction : Tile ) : Tile = reduction
61+
62+ override def bufferEncoder : Encoder [ Tile ] = StandardEncoders .tileEncoder
8463
85- def evaluate ( buffer : Row ) : Tile = buffer.getAs[ Tile ]( 0 )
64+ override def outputEncoder : Encoder [ Tile ] = StandardEncoders .tileEncoder
8665}
8766
8867object TileRasterizerAggregate {
@@ -107,7 +86,8 @@ object TileRasterizerAggregate {
10786 logger.warn(
10887 s " You've asked for the construction of a very large image ( ${prd.totalCols} x ${prd.totalRows}). Out of memory error likely. " )
10988
110- new TileRasterizerAggregate (prd)(crsCol, extentCol, tileCol)
89+ udaf(new TileRasterizerAggregate (prd))
90+ .apply(crsCol, extentCol, tileCol)
11191 .as(" rf_agg_overview_raster" )
11292 .as[Tile ]
11393 }
0 commit comments