Skip to content

Commit 92afa34

Browse files
authored
Merge pull request #330 from s22s/feature/agg-raster-refactor
Incremental step toward easier aggregate raster generation.
2 parents 36be700 + 98cbdfd commit 92afa34

File tree

18 files changed

+292
-125
lines changed

18 files changed

+292
-125
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* This software is licensed under the Apache 2 license, quoted below.
3+
*
4+
* Copyright 2019 Astraea, Inc.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
7+
* use this file except in compliance with the License. You may obtain a copy of
8+
* the License at
9+
*
10+
* [http://www.apache.org/licenses/LICENSE-2.0]
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15+
* License for the specific language governing permissions and limitations under
16+
* the License.
17+
*
18+
* SPDX-License-Identifier: Apache-2.0
19+
*
20+
*/
21+
22+
package org.locationtech.rasterframes.ref
23+
24+
import java.net.URI
25+
26+
import geotrellis.proj4.LatLng
27+
import geotrellis.vector.Extent
28+
import org.locationtech.rasterframes._
29+
import org.locationtech.rasterframes.expressions.aggregates.TileRasterizerAggregate
30+
31+
/**
 * Integration test exercising sub-region reads of remotely hosted rasters through `RasterRef`.
 *
 * The single scenario is registered with `ignore`, so it is reported but not executed.
 */
class RasterRefIT extends TestEnvironment {
  describe("practical subregion reads") {
    ignore("should construct a natural color composite") {
      import spark.implicits._

      // URI of the file holding band `idx` of the scene used throughout this test.
      def scene(idx: Int) = {
        val bucket = s"https://landsat-pds.s3.us-west-2.amazonaws.com"
        val key = s"/c1/L8/176/039/LC08_L1TP_176039_20190703_20190718_01_T1/LC08_L1TP_176039_20190703_20190718_01_T1_B$idx.TIF"
        URI.create(bucket + key)
      }

      val redSource = RasterSource(scene(4))
      // Extent arguments are [west, south, east, north]; built in LatLng and
      // reprojected into the CRS reported by the red band's source.
      val subregion = Extent(31.115, 29.963, 31.148, 29.99).reproject(LatLng, redSource.crs)

      // One lazily-read reference per band, all restricted to the same sub-region.
      val red = RasterRef(redSource, 0, Some(subregion), None)
      val green = RasterRef(RasterSource(scene(3)), 0, Some(subregion), None)
      val blue = RasterRef(RasterSource(scene(2)), 0, Some(subregion), None)

      val bands = Seq((red, green, blue)).toDF("red", "green", "blue")
      val redCol = $"red"
      // CRS and extent are taken from the red band only; the tiles come from all three bands.
      val df = bands
        .select(rf_crs(redCol), rf_extent(redCol), rf_tile(redCol), rf_tile($"green"), rf_tile($"blue"))
        .toDF

      val raster = TileRasterizerAggregate(df, redSource.crs, None, None)

      // Every band must yield statistics covering more than 1000 data cells.
      forEvery(raster.tile.statisticsDouble) { bandStats =>
        bandStats should be ('defined)
        bandStats.get.dataCells shouldBe > (1000L)
      }

      // Uncomment to write the composite out for visual inspection:
      //import geotrellis.raster.io.geotiff.{GeoTiffOptions, MultibandGeoTiff, Tiled}
      //import geotrellis.raster.io.geotiff.compression.{DeflateCompression, NoCompression}
      //import geotrellis.raster.io.geotiff.tags.codes.ColorSpace
      //val tiffOptions = GeoTiffOptions(Tiled, DeflateCompression, ColorSpace.RGB)
      //MultibandGeoTiff(raster, raster.crs, tiffOptions).write("target/composite.tif")
    }
  }
}

core/src/main/scala/org/locationtech/rasterframes/encoders/StandardSerializers.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ trait StandardSerializers {
7373
implicit val gridBoundsSerializer: CatalystSerializer[GridBounds] = new CatalystSerializer[GridBounds] {
7474
override def schema: StructType = StructType(Seq(
7575
StructField("colMin", IntegerType, false),
76-
StructField("rowlMin", IntegerType, false),
76+
StructField("rowMin", IntegerType, false),
7777
StructField("colMax", IntegerType, false),
7878
StructField("rowMax", IntegerType, false)
7979
))

core/src/main/scala/org/locationtech/rasterframes/expressions/DynamicExtractors.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ import org.locationtech.rasterframes.tiles.ProjectedRasterTile
3535

3636
private[rasterframes]
3737
object DynamicExtractors {
38-
/** Partial function for pulling a tile and its contesxt from an input row. */
38+
/** Partial function for pulling a tile and its context from an input row. */
3939
lazy val tileExtractor: PartialFunction[DataType, InternalRow => (Tile, Option[TileContext])] = {
4040
case _: TileUDT =>
4141
(row: InternalRow) =>
@@ -47,6 +47,14 @@ object DynamicExtractors {
4747
}
4848
}
4949

50+
/** Partial function, defined only for data types conforming to `RasterRef`,
  * yielding a function that decodes an input row into a `RasterRef`. */
lazy val rasterRefExtractor: PartialFunction[DataType, InternalRow => RasterRef] = {
  case refType if refType.conformsTo[RasterRef] =>
    (encoded: InternalRow) => encoded.to[RasterRef]
}
54+
55+
/** Partial function for pulling just a `Tile` out of an input row. Types accepted by
  * `tileExtractor` are tried first, then types accepted by `rasterRefExtractor`. */
lazy val tileableExtractor: PartialFunction[DataType, InternalRow => Tile] = {
  // Tile-bearing types: keep the tile and drop the optional context.
  val fromTileColumn: PartialFunction[DataType, InternalRow => Tile] =
    tileExtractor.andThen(extract => extract.andThen(_._1))
  // RasterRef types: decode the reference, then ask it for its tile.
  val fromRasterRef: PartialFunction[DataType, InternalRow => Tile] =
    rasterRefExtractor.andThen(extract => extract.andThen(_.tile))
  fromTileColumn.orElse(fromRasterRef)
}
57+
5058
lazy val rowTileExtractor: PartialFunction[DataType, Row => (Tile, Option[TileContext])] = {
5159
case _: TileUDT =>
5260
(row: Row) => (row.to[Tile](TileUDT.tileSerializer), None)

core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/RealizeTile.scala

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,17 @@
2222
package org.locationtech.rasterframes.expressions.accessors
2323

2424
import geotrellis.raster.Tile
25+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
26+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
2527
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
26-
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription}
28+
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, UnaryExpression}
2729
import org.apache.spark.sql.rf.TileUDT
2830
import org.apache.spark.sql.types.DataType
2931
import org.apache.spark.sql.{Column, TypedColumn}
3032
import org.locationtech.rasterframes._
3133
import org.locationtech.rasterframes.encoders.CatalystSerializer._
32-
import org.locationtech.rasterframes.expressions.UnaryRasterOp
33-
import org.locationtech.rasterframes.model.TileContext
34+
import org.locationtech.rasterframes.expressions.DynamicExtractors._
35+
import org.locationtech.rasterframes.expressions._
3436

3537
@ExpressionDescription(
3638
usage = "_FUNC_(raster) - Extracts the Tile component of a RasterSource, ProjectedRasterTile (or Tile) and ensures the cells are fully fetched.",
@@ -39,14 +41,22 @@ import org.locationtech.rasterframes.model.TileContext
3941
> SELECT _FUNC_(raster);
4042
....
4143
""")
42-
case class RealizeTile(child: Expression) extends UnaryRasterOp with CodegenFallback {
44+
case class RealizeTile(child: Expression) extends UnaryExpression with CodegenFallback {
4345
override def dataType: DataType = TileType
4446

4547
override def nodeName: String = "rf_tile"
46-
implicit val tileSer = TileUDT.tileSerializer
4748

48-
override protected def eval(tile: Tile, ctx: Option[TileContext]): Any =
49+
override def checkInputDataTypes(): TypeCheckResult = {
50+
if (!tileableExtractor.isDefinedAt(child.dataType)) {
51+
TypeCheckFailure(s"Input type '${child.dataType}' does not conform to a tiled raster type.")
52+
} else TypeCheckSuccess
53+
}
54+
implicit val tileSer = TileUDT.tileSerializer
55+
override protected def nullSafeEval(input: Any): Any = {
56+
val in = row(input)
57+
val tile = tileableExtractor(child.dataType)(in)
4958
(tile.toArrayTile(): Tile).toInternalRow
59+
}
5060
}
5161

5262
object RealizeTile {

core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/TileRasterizerAggregate.scala

Lines changed: 74 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,18 @@ package org.locationtech.rasterframes.expressions.aggregates
2424
import geotrellis.proj4.CRS
2525
import geotrellis.raster.reproject.Reproject
2626
import geotrellis.raster.resample.ResampleMethod
27-
import geotrellis.raster.{ArrayTile, CellType, Raster, Tile}
28-
import geotrellis.spark.TileLayerMetadata
27+
import geotrellis.raster.{ArrayTile, CellType, MultibandTile, ProjectedRaster, Raster, Tile}
28+
import geotrellis.spark.{SpatialKey, TileLayerMetadata}
2929
import geotrellis.vector.Extent
3030
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
3131
import org.apache.spark.sql.types.{DataType, StructField, StructType}
32-
import org.apache.spark.sql.{Column, Row, TypedColumn}
32+
import org.apache.spark.sql.{Column, DataFrame, Row, TypedColumn}
3333
import org.locationtech.rasterframes._
34+
import org.locationtech.rasterframes.util._
3435
import org.locationtech.rasterframes.encoders.CatalystSerializer._
3536
import org.locationtech.rasterframes.expressions.aggregates.TileRasterizerAggregate.ProjectedRasterDefinition
37+
import org.locationtech.rasterframes.model.TileDimensions
38+
import org.slf4j.LoggerFactory
3639

3740
/**
3841
* Aggregation function for creating a single `geotrellis.raster.Raster[Tile]` from
@@ -88,7 +91,7 @@ class TileRasterizerAggregate(prd: ProjectedRasterDefinition) extends UserDefine
8891
}
8992

9093
object TileRasterizerAggregate {
91-
val nodeName = "rf_tile_rasterizer_aggregate"
94+
val nodeName = "rf_agg_raster"
9295
/** Convenience grouping of parameters needed for running aggregate. */
9396
case class ProjectedRasterDefinition(totalCols: Int, totalRows: Int, cellType: CellType, crs: CRS, extent: Extent, sampler: ResampleMethod = ResampleMethod.DEFAULT)
9497

@@ -102,8 +105,73 @@ object TileRasterizerAggregate {
102105
val rows = actualSize.height
103106
new ProjectedRasterDefinition(cols, rows, tlm.cellType, tlm.crs, tlm.extent, sampler)
104107
}
105-
}
108+
}
109+
110+
@transient
111+
private lazy val logger = LoggerFactory.getLogger(getClass)
112+
113+
/**
 * Builds the aggregate column that rasterizes `tileCol` into a single `Raster[Tile]`
 * laid out according to `prd`. The resulting column is aliased to `nodeName`.
 *
 * A warning is logged (nothing is thrown) when the requested raster looks too large
 * for the JVM's currently allocated memory.
 *
 * @param prd       size, cell type, CRS, extent and sampler of the raster to build
 * @param crsCol    column supplying the CRS argument of the aggregate
 * @param extentCol column supplying the extent argument of the aggregate
 * @param tileCol   column supplying the tile argument of the aggregate
 */
def apply(prd: ProjectedRasterDefinition, crsCol: Column, extentCol: Column, tileCol: Column): TypedColumn[Any, Raster[Tile]] = {
  // NOTE(review): 64.0 is presumably a per-cell size estimate (bits?), yet it is compared
  // against a byte count, and `totalMemory` is the currently allocated heap rather than
  // the maximum. Confirm the intended heuristic.
  val estimatedSize = prd.totalCols.toDouble * prd.totalRows * 64.0
  val warnThreshold = Runtime.getRuntime.totalMemory() * 0.5
  if (estimatedSize > warnThreshold) {
    logger.warn(
      s"You've asked for the construction of a very large image (${prd.totalCols} x ${prd.totalRows}). Out of memory error likely.")
  }

  val rasterizer = new TileRasterizerAggregate(prd)
  rasterizer(crsCol, extentCol, tileCol).as(nodeName).as[Raster[Tile]]
}
121+
122+
/**
 * Rasterizes every tile column of `df` into one band of a single multiband raster
 * expressed in `destCRS`.
 *
 * The anchoring extent/CRS/tile columns are taken from the first "ProjectedRaster"
 * column when one exists; otherwise the DataFrame must contain exactly one CRS column
 * and exactly one Extent column, and the first tile column is used as the anchor.
 *
 * @param df         DataFrame holding at least one tile column
 * @param destCRS    CRS in which the layer metadata (and so the output raster) is computed
 * @param destExtent optional extent to rasterize into; when absent, the extent computed
 *                   from the data is used. Unless `rasterDims` is also given, the output
 *                   keeps the cell dimensions computed for the full data extent.
 * @param rasterDims optional output size in cells; when absent, the size computed
 *                   from the data is used
 * @return the bands (one per tile column, in column order) with the extent and CRS
 *         they were rasterized into
 * @throws IllegalArgumentException if there is no tile column, or if no "ProjectedRaster"
 *         column exists and there is not exactly one CRS and one Extent column
 */
def apply(df: DataFrame, destCRS: CRS, destExtent: Option[Extent], rasterDims: Option[TileDimensions]): ProjectedRaster[MultibandTile] = {
  val tileCols = WithDataFrameMethods(df).tileColumns
  require(tileCols.nonEmpty, "need at least one tile column")

  // Select the anchoring Tile, Extent and CRS columns
  val (extCol, crsCol, tileCol) = df.projRasterColumns.headOption match {
    // Favor "ProjectedRaster" columns
    case Some(prCol) =>
      (rf_extent(prCol), rf_crs(prCol), rf_tile(prCol))
    // If no "ProjectedRaster" column, look for single Extent and CRS columns.
    case None =>
      val crsCols = df.crsColumns
      require(crsCols.size == 1, "Exactly one CRS column must be in DataFrame")
      val extentCols = df.extentColumns
      require(extentCols.size == 1, "Exactly one Extent column must be in DataFrame")
      (extentCols.head, crsCols.head, tileCols.head)
  }

  // Scan table and construct what the TileLayerMetadata would be in the specified destination CRS.
  val tlm: TileLayerMetadata[SpatialKey] = df
    .select(
      ProjectedLayerMetadataAggregate(
        destCRS,
        extCol,
        crsCol,
        rf_cell_type(tileCol),
        rf_dimensions(tileCol)
      ))
    .first()
  logger.debug(s"Collected TileLayerMetadata: ${tlm.toString}")

  // Start from the definition implied by the data, then layer the caller's overrides on
  // top of each other. Previously the extent override was computed and thrown away.
  val fromData = ProjectedRasterDefinition(tlm)
  val resized = rasterDims.fold(fromData) { dims =>
    fromData.copy(totalCols = dims.cols, totalRows = dims.rows)
  }
  val config = destExtent.fold(resized)(ext => resized.copy(extent = ext))

  // One rasterizing aggregate per tile column; only the "tile" field of each result is kept.
  val aggs = tileCols
    .map(t => TileRasterizerAggregate(config, crsCol, extCol, rf_tile(t))("tile").as(t.columnName))

  val row = df.select(aggs: _*).first()

  val bands = (0 until row.size).map(i => row.getAs[Tile](i))

  // Label the result with the extent/CRS actually used for rasterization. These equal
  // the collected metadata's values whenever `destExtent` is not supplied.
  ProjectedRaster(MultibandTile(bands), config.extent, config.crs)
}
109177
}

core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RGBComposite.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,9 @@ case class RGBComposite(red: Expression, green: Expression, blue: Expression) ex
5656
override def nodeName: String = "rf_rgb_composite"
5757

5858
override def dataType: DataType = if(
59-
red.dataType.conformsTo[ProjectedRasterTile] ||
60-
blue.dataType.conformsTo[ProjectedRasterTile] ||
61-
green.dataType.conformsTo[ProjectedRasterTile]
59+
tileExtractor.isDefinedAt(red.dataType) ||
60+
tileExtractor.isDefinedAt(green.dataType) ||
61+
tileExtractor.isDefinedAt(blue.dataType)
6262
) red.dataType
6363
else TileType
6464

core/src/main/scala/org/locationtech/rasterframes/extensions/DataFrameMethods.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
package org.locationtech.rasterframes.extensions
2323

2424
import geotrellis.proj4.CRS
25+
import geotrellis.raster.{MultibandTile, ProjectedRaster}
2526
import geotrellis.spark.io._
2627
import geotrellis.spark.{SpaceTimeKey, SpatialComponent, SpatialKey, TemporalKey, TileLayerMetadata}
2728
import geotrellis.util.MethodExtensions
@@ -32,7 +33,9 @@ import org.apache.spark.sql.{Column, DataFrame, TypedColumn}
3233
import org.locationtech.rasterframes.StandardColumns._
3334
import org.locationtech.rasterframes.encoders.CatalystSerializer._
3435
import org.locationtech.rasterframes.encoders.StandardEncoders._
35-
import org.locationtech.rasterframes.expressions.DynamicExtractors
36+
import org.locationtech.rasterframes.expressions.{DynamicExtractors, aggregates}
37+
import org.locationtech.rasterframes.expressions.aggregates.TileRasterizerAggregate
38+
import org.locationtech.rasterframes.model.TileDimensions
3639
import org.locationtech.rasterframes.tiles.ProjectedRasterTile
3740
import org.locationtech.rasterframes.util._
3841
import org.locationtech.rasterframes.{MetadataKeys, RasterFrameLayer}
@@ -225,7 +228,7 @@ trait DataFrameMethods[DF <: DataFrame] extends MethodExtensions[DF] with Metada
225228
*/
226229
@throws[IllegalArgumentException]
227230
def asLayer: RasterFrameLayer = {
228-
val potentialRF = certifyRasterframe(self)
231+
val potentialRF = certifyLayer(self)
229232

230233
require(
231234
potentialRF.findSpatialKeyField.nonEmpty,
@@ -301,5 +304,5 @@ trait DataFrameMethods[DF <: DataFrame] extends MethodExtensions[DF] with Metada
301304

302305
/** Internal method for slapping the RasterFrameLayer seal of approval on a DataFrame.
303306
* Only call if you are sure it has a spatial key and tile columns and TileLayerMetadata. */
304-
private[rasterframes] def certify = certifyRasterframe(self)
307+
private[rasterframes] def certify = certifyLayer(self)
305308
}

core/src/main/scala/org/locationtech/rasterframes/util/package.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ package object util extends DataFrameRenderers {
7777
type KeyMethodsProvider[K1, K2] = K1 TilerKeyMethods[K1, K2]
7878

7979
/** Internal method for slapping the RasterFrameLayer seal of approval on a DataFrame. */
80-
private[rasterframes] def certifyRasterframe(df: DataFrame): RasterFrameLayer =
80+
private[rasterframes] def certifyLayer(df: DataFrame): RasterFrameLayer =
8181
shapeless.tag[RasterFrameTag][DataFrame](df)
8282

8383

core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -23,53 +23,22 @@ package org.locationtech.rasterframes
2323

2424
import java.io.ByteArrayInputStream
2525

26-
import geotrellis.proj4.LatLng
2726
import geotrellis.raster
2827
import geotrellis.raster._
2928
import geotrellis.raster.render.ColorRamps
3029
import geotrellis.raster.testkit.RasterMatchers
31-
import geotrellis.vector.Extent
3230
import javax.imageio.ImageIO
3331
import org.apache.spark.sql.Encoders
3432
import org.apache.spark.sql.functions._
3533
import org.locationtech.rasterframes.expressions.accessors.ExtractTile
3634
import org.locationtech.rasterframes.model.TileDimensions
37-
import org.locationtech.rasterframes.ref.{RasterRef, RasterSource}
3835
import org.locationtech.rasterframes.stats._
3936
import org.locationtech.rasterframes.tiles.ProjectedRasterTile
4037

4138
class RasterFunctionsSpec extends TestEnvironment with RasterMatchers {
39+
import TestData._
4240
import spark.implicits._
4341

44-
val extent = Extent(10, 20, 30, 40)
45-
val crs = LatLng
46-
val ct = ByteUserDefinedNoDataCellType(-2)
47-
val cols = 10
48-
val rows = cols
49-
val tileSize = cols * rows
50-
val tileCount = 10
51-
val numND = 4
52-
lazy val zero = TestData.projectedRasterTile(cols, rows, 0, extent, crs, ct)
53-
lazy val one = TestData.projectedRasterTile(cols, rows, 1, extent, crs, ct)
54-
lazy val two = TestData.projectedRasterTile(cols, rows, 2, extent, crs, ct)
55-
lazy val three = TestData.projectedRasterTile(cols, rows, 3, extent, crs, ct)
56-
lazy val six = ProjectedRasterTile(three * two, three.extent, three.crs)
57-
lazy val nd = TestData.projectedRasterTile(cols, rows, -2, extent, crs, ct)
58-
lazy val randPRT = TestData.projectedRasterTile(cols, rows, scala.util.Random.nextInt(), extent, crs, ct)
59-
lazy val randNDPRT: Tile = TestData.injectND(numND)(randPRT)
60-
61-
lazy val randDoubleTile = TestData.projectedRasterTile(cols, rows, scala.util.Random.nextGaussian(), extent, crs, DoubleConstantNoDataCellType)
62-
lazy val randDoubleNDTile = TestData.injectND(numND)(randDoubleTile)
63-
lazy val randPositiveDoubleTile = TestData.projectedRasterTile(cols, rows, scala.util.Random.nextDouble() + 1e-6, extent, crs, DoubleConstantNoDataCellType)
64-
65-
val expectedRandNoData: Long = numND * tileCount.toLong
66-
val expectedRandData: Long = cols * rows * tileCount - expectedRandNoData
67-
lazy val randNDTilesWithNull = Seq.fill[Tile](tileCount)(TestData.injectND(numND)(
68-
TestData.randomTile(cols, rows, UByteConstantNoDataCellType)
69-
)).map(ProjectedRasterTile(_, extent, crs)) :+ null
70-
71-
def lazyPRT = RasterRef(RasterSource(TestData.l8samplePath), 0, None, None).tile
72-
7342
implicit val pairEnc = Encoders.tuple(ProjectedRasterTile.prtEncoder, ProjectedRasterTile.prtEncoder)
7443
implicit val tripEnc = Encoders.tuple(ProjectedRasterTile.prtEncoder, ProjectedRasterTile.prtEncoder, ProjectedRasterTile.prtEncoder)
7544

0 commit comments

Comments
 (0)