
Commit e313542

Merge pull request #362 from s22s/fix/360
Fixed handling of aggregate extent and image size on geotiff write.
2 parents: 8bac84a + 1b7d35f

8 files changed: +100 −52 lines

core/src/it/scala/org/locationtech/rasterframes/ref/RasterRefIT.scala

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ class RasterRefIT extends TestEnvironment {
           rf_crs($"red"), rf_extent($"red"), rf_tile($"red"), rf_tile($"green"), rf_tile($"blue"))
         .toDF

-      val raster = TileRasterizerAggregate(df, redScene.crs, None, None)
+      val raster = TileRasterizerAggregate.collect(df, redScene.crs, None, None)

       forEvery(raster.tile.statisticsDouble) { stats =>
         stats should be ('defined)

core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/ProjectedLayerMetadataAggregate.scala

Lines changed: 6 additions & 0 deletions
@@ -77,6 +77,10 @@ class ProjectedLayerMetadataAggregate(destCRS: CRS, destDims: TileDimensions) ex
     import org.locationtech.rasterframes.encoders.CatalystSerializer._
     val buf = buffer.to[BufferRecord]

+    if (buf.isEmpty) {
+      throw new IllegalArgumentException("Can not collect metadata from empty data frame.")
+    }
+
     val re = RasterExtent(buf.extent, buf.cellSize)
     val layout = LayoutDefinition(re, destDims.cols, destDims.rows)

@@ -152,6 +156,8 @@ object ProjectedLayerMetadataAggregate {
         buffer(i) = encoded(i)
       }
     }
+
+    def isEmpty: Boolean = extent == null || cellType == null || cellSize == null
   }

   private[expressions]
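
For context, a minimal test-style sketch of what the new guard buys (an assumption here is that `TileRasterizerAggregate.collect` computes its layout through this aggregate; `df` stands for any tile DataFrame from the suites below):

import geotrellis.proj4.LatLng
import org.locationtech.rasterframes.expressions.aggregates.TileRasterizerAggregate

// A zero-row input leaves the metadata buffer empty, so the aggregate now
// fails fast with IllegalArgumentException instead of failing later on null fields.
val emptyDf = df.limit(0)
an[IllegalArgumentException] should be thrownBy {
  TileRasterizerAggregate.collect(emptyDf, LatLng, None, None)
}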

core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/TileRasterizerAggregate.scala

Lines changed: 1 addition & 1 deletion
@@ -119,7 +119,7 @@ object TileRasterizerAggregate {
     new TileRasterizerAggregate(prd)(crsCol, extentCol, tileCol).as(nodeName).as[Raster[Tile]]
   }

-  def apply(df: DataFrame, destCRS: CRS, destExtent: Option[Extent], rasterDims: Option[TileDimensions]): ProjectedRaster[MultibandTile] = {
+  def collect(df: DataFrame, destCRS: CRS, destExtent: Option[Extent], rasterDims: Option[TileDimensions]): ProjectedRaster[MultibandTile] = {
     val tileCols = WithDataFrameMethods(df).tileColumns
     require(tileCols.nonEmpty, "need at least one tile column")
     // Select the anchoring Tile, Extent and CRS columns
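
The rename from `apply` to `collect` makes the driver-side materialization explicit at call sites. A minimal usage sketch, mirroring the integration test above (it assumes `df` already carries CRS, extent, and tile columns, and that the whole mosaic fits in driver memory):

import geotrellis.proj4.LatLng
import org.locationtech.rasterframes.expressions.aggregates.TileRasterizerAggregate

// Reproject and mosaic every tile in `df` into one in-memory multiband raster;
// passing None for extent and dimensions lets both be derived from the data.
val raster = TileRasterizerAggregate.collect(df, LatLng, None, None)
println(s"${raster.tile.bandCount} band(s) covering ${raster.extent} in ${raster.crs}")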

core/src/main/scala/org/locationtech/rasterframes/extensions/RFSpatialColumnMethods.scala

Lines changed: 10 additions & 0 deletions
@@ -34,6 +34,7 @@ import org.apache.spark.sql.functions.{asc, udf => sparkUdf}
 import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 import org.locationtech.geomesa.curve.Z2SFC
 import org.locationtech.rasterframes.StandardColumns
+import org.locationtech.rasterframes.encoders.serialized_literal

 /**
  * RasterFrameLayer extension methods associated with adding spatially descriptive columns.

@@ -71,6 +72,15 @@ trait RFSpatialColumnMethods extends MethodExtensions[RasterFrameLayer] with Sta
     val key2Extent = sparkUdf(keyCol2Extent)
     self.withColumn(colName, key2Extent(self.spatialKeyColumn)).certify
   }
+  /**
+    * Append a column containing the CRS of the layer.
+    *
+    * @param colName name of column to append. Defaults to "crs"
+    * @return updated RasterFrameLayer
+    */
+  def withCRS(colName: String = CRS_COLUMN.columnName): RasterFrameLayer = {
+    self.withColumn(colName, serialized_literal(self.crs)).certify
+  }

   /**
    * Append a column containing the bounds of the row's spatial key.
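
A short sketch of how the new method is meant to pair with the existing `withExtent` when a layer needs explicit georeferencing columns (column names assume the defaults, "extent" and "crs"; `layer` is any `RasterFrameLayer`):

// Surface the layer's CRS and per-tile extents as ordinary columns so that
// downstream consumers (such as the GeoTIFF writer) can read them off the rows.
val withGeoref: RasterFrameLayer = layer.withExtent().withCRS()
withGeoref.select("extent", "crs").show(3, truncate = false)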

datasource/src/main/scala/org/locationtech/rasterframes/datasource/geotiff/GeoTiffDataSource.scala

Lines changed: 13 additions & 19 deletions
@@ -67,27 +67,21 @@ class GeoTiffDataSource
       require(tileCols.nonEmpty, "Could not find any tile columns.")

-      val raster = if (df.isAlreadyLayer) {
-        val layer = df.certify
-        val tlm = layer.tileLayerMetadata.merge
-
-        // If no desired image size is given, write at full size.
-        val TileDimensions(cols, rows) = parameters.rasterDimensions
-          .getOrElse {
-            val actualSize = tlm.layout.toRasterExtent().gridBoundsFor(tlm.extent)
-            TileDimensions(actualSize.width, actualSize.height)
-          }
-
-        // Should we really play traffic cop here?
-        if (cols.toDouble * rows * 64.0 > Runtime.getRuntime.totalMemory() * 0.5)
-          logger.warn(
-            s"You've asked for the construction of a very large image ($cols x $rows), destined for ${path}. Out of memory error likely.")
-
-        layer.toMultibandRaster(tileCols, cols.toInt, rows.toInt)
-      } else {
-        require(parameters.crs.nonEmpty, "A destination CRS must be provided")
-        TileRasterizerAggregate(df, parameters.crs.get, None, parameters.rasterDimensions)
-      }
+      val destCRS = parameters.crs.orElse(df.asLayerSafely.map(_.crs)).getOrElse(
+        throw new IllegalArgumentException("A destination CRS must be provided")
+      )
+
+      val input = df.asLayerSafely.map(layer =>
+        (layer.crsColumns.isEmpty, layer.extentColumns.isEmpty) match {
+          case (true, true) => layer.withExtent().withCRS()
+          case (true, false) => layer.withCRS()
+          case (false, true) => layer.withExtent()
+          case _ => layer
+        }).getOrElse(df)
+
+      val raster = TileRasterizerAggregate.collect(input, destCRS, None, parameters.rasterDimensions)

       val tags = Tags(
         RFBuildInfo.toMap.filter(_._1.toLowerCase() == "version").mapValues(_.toString),
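
From the caller's perspective the writer now handles layers and unstructured rasters through one path. A sketch of the two usages exercised by the spec below (the datasource import is assumed to bring the `geotiff` writer syntax into scope; `df`, `layer`, and `crs` come from the surrounding test context):

import org.locationtech.rasterframes.datasource.geotiff._

// Unstructured raster DataFrame: the destination CRS must be supplied.
df.write.geotiff.withCRS(crs).save("target/unstructured.tif")

// Layer input: CRS (and extent) are recovered from the layer itself.
layer.write.geotiff.save("target/layer.tif")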

datasource/src/test/scala/org/locationtech/rasterframes/datasource/geotiff/GeoTiffDataSourceSpec.scala

Lines changed: 60 additions & 18 deletions
@@ -20,14 +20,16 @@
  */
 package org.locationtech.rasterframes.datasource.geotiff

-import java.nio.file.Paths
+import java.nio.file.{Path, Paths}

 import geotrellis.proj4._
+import geotrellis.raster.CellType
 import geotrellis.raster.io.geotiff.{MultibandGeoTiff, SinglebandGeoTiff}
 import geotrellis.vector.Extent
 import org.locationtech.rasterframes._
 import org.apache.spark.sql.functions._
 import org.locationtech.rasterframes.TestEnvironment
+import org.locationtech.rasterframes.datasource.raster._

 /**
  * @since 1/14/18

@@ -89,6 +91,15 @@ class GeoTiffDataSourceSpec

   describe("GeoTiff writing") {

+    def checkTiff(file: Path, cols: Int, rows: Int, extent: Extent, cellType: Option[CellType] = None) = {
+      val outputTif = SinglebandGeoTiff(file.toString)
+      outputTif.tile.dimensions should be ((cols, rows))
+      outputTif.extent should be (extent)
+      cellType.foreach(ct =>
+        outputTif.cellType should be (ct)
+      )
+    }
+
     it("should write GeoTIFF RF to parquet") {
       val rf = spark.read.format("geotiff").load(cogPath.toASCIIString).asLayer
       assert(write(rf))

@@ -104,6 +115,9 @@ class GeoTiffDataSourceSpec
       noException shouldBe thrownBy {
         rf.write.format("geotiff").save(out.toString)
       }
+      val extent = rf.tileLayerMetadata.merge.extent
+
+      checkTiff(out, 1028, 989, extent)
     }

     it("should write unstructured raster") {

@@ -116,10 +130,10 @@
       val crs = df.select(rf_crs($"proj_raster")).first()

-      val out = Paths.get("target", "unstructured.tif").toString
+      val out = Paths.get("target", "unstructured.tif")

       noException shouldBe thrownBy {
-        df.write.geotiff.withCRS(crs).save(out)
+        df.write.geotiff.withCRS(crs).save(out.toString)
       }

       val (inCols, inRows) = {

@@ -129,11 +143,7 @@
       inCols should be (774)
       inRows should be (500) //from gdalinfo

-      val outputTif = SinglebandGeoTiff(out)
-      outputTif.imageData.cols should be (inCols)
-      outputTif.imageData.rows should be (inRows)
-
-      // TODO check datatype, extent.
+      checkTiff(out, inCols, inRows, Extent(431902.5, 4313647.5, 443512.5, 4321147.5))
     }

     it("should round trip unstructured raster from COG"){

@@ -163,30 +173,26 @@
       dfExtent shouldBe resourceExtent

-      val out = Paths.get("target", "unstructured_cog.tif").toString
+      val out = Paths.get("target", "unstructured_cog.tif")

       noException shouldBe thrownBy {
-        df.write.geotiff.withCRS(crs).save(out)
+        df.write.geotiff.withCRS(crs).save(out.toString)
       }

       val (inCols, inRows, inExtent, inCellType) = {
         val tif = readSingleband("LC08_B7_Memphis_COG.tiff")
         val id = tif.imageData
         (id.cols, id.rows, tif.extent, tif.cellType)
       }
-      inCols should be (963)
-      inRows should be (754) //from gdalinfo
+      inCols should be (resourceCols)
+      inRows should be (resourceRows) //from gdalinfo
       inExtent should be (resourceExtent)

-      val outputTif = SinglebandGeoTiff(out)
-      outputTif.imageData.cols should be (inCols)
-      outputTif.imageData.rows should be (inRows)
-      outputTif.extent should be (resourceExtent)
-      outputTif.cellType should be (inCellType)
+      checkTiff(out, inCols, inRows, resourceExtent, Some(inCellType))
     }

     it("should write GeoTIFF without layer") {
-      import org.locationtech.rasterframes.datasource.raster._
+
       val pr = col("proj_raster_b0")
       val rf = spark.read.raster.withBandIndexes(0, 1, 2).load(rgbCogSamplePath.toASCIIString)

@@ -217,6 +223,42 @@
           .save(out.toString)
       }
     }
+
+      checkTiff(out, 128, 128,
+        Extent(-76.52586750038186, 36.85907177863949, -76.17461216980891, 37.1303690755922))
+    }
+
+    it("should produce the correct subregion from layer") {
+      import spark.implicits._
+      val rf = SinglebandGeoTiff(TestData.singlebandCogPath.getPath)
+        .projectedRaster.toLayer(128, 128).withExtent()
+
+      val out = Paths.get("target", "example3-geotiff.tif")
+      logger.info(s"Writing to $out")
+
+      val bitOfLayer = rf.filter($"spatial_key.col" === 0 && $"spatial_key.row" === 0)
+      val expectedExtent = bitOfLayer.select($"extent".as[Extent]).first()
+      bitOfLayer.write.geotiff.save(out.toString)
+
+      checkTiff(out, 128, 128, expectedExtent)
+    }
+
+    it("should produce the correct subregion without layer") {
+      import spark.implicits._
+
+      val rf = spark.read.raster
+        .withTileDimensions(128, 128)
+        .load(TestData.singlebandCogPath.toASCIIString)
+
+      val out = Paths.get("target", "example3-geotiff.tif")
+      logger.info(s"Writing to $out")
+
+      val bitOfLayer = rf.filter(st_intersects(st_makePoint(754245, 3893385), rf_geometry($"proj_raster")))
+      val expectedExtent = bitOfLayer.select(rf_extent($"proj_raster")).first()
+      val crs = bitOfLayer.select(rf_crs($"proj_raster")).first()
+      bitOfLayer.write.geotiff.withCRS(crs).save(out.toString)
+
+      checkTiff(out, 128, 128, expectedExtent)
     }

     def s(band: Int): String =

docs/src/main/paradox/release-notes.md

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 ### 0.8.2

+* Fixed handling of aggregate extent and image size on GeoTIFF writing. ([#362](https://github.com/locationtech/rasterframes/issues/362))
 * Fixed issue with `RasterSourceDataSource` swallowing exceptions. ([#267](https://github.com/locationtech/rasterframes/issues/267))
 * Fixed SparkML memory pressure issue caused by unnecessary reevaluation, overallocation, and primitive boxing. ([#343](https://github.com/locationtech/rasterframes/issues/343))
 * Fixed Parquet serialization issue with `RasterRef`s ([#338](https://github.com/locationtech/rasterframes/issues/338))

experimental/src/it/scala/org/locationtech/rasterframes/experimental/datasource/awspds/L8CatalogRelationTest.scala

Lines changed: 8 additions & 13 deletions
@@ -108,12 +108,13 @@ class L8CatalogRelationTest extends TestEnvironment {
       stats.mean shouldBe > (10000.0)
     }

-    ignore("should construct an RGB composite") {
-      val aoi = Extent(31.115, 29.963, 31.148, 29.99)
+    it("should construct an RGB composite") {
+      val aoiLL = Extent(31.115, 29.963, 31.148, 29.99)
+
       val scene = catalog
         .where(
           to_date($"acquisition_date") === to_date(lit("2019-07-03")) &&
-            st_intersects(st_geometry($"bounds_wgs84"), geomLit(aoi.jtsGeom))
+            st_intersects(st_geometry($"bounds_wgs84"), geomLit(aoiLL.jtsGeom))
         )
         .orderBy("cloud_cover_pct")
         .limit(1)

@@ -122,19 +123,13 @@ class L8CatalogRelationTest extends TestEnvironment {
         .fromCatalog(scene, "B4", "B3", "B2")
         .withTileDimensions(256, 256)
         .load()
-        .where(st_contains(rf_geometry($"B4"), st_reproject(geomLit(aoi.jtsGeom), lit("EPSG:4326"), rf_crs($"B4"))))
-
+        .limit(1)

       noException should be thrownBy {
-        val raster = TileRasterizerAggregate(df, LatLng, Some(aoi), None)
-        println(raster)
+        val raster = TileRasterizerAggregate.collect(df, LatLng, Some(aoiLL), None)
+        raster.tile.bandCount should be (3)
+        raster.extent.area > 0
       }
-
-      // import geotrellis.raster.io.geotiff.{GeoTiffOptions, MultibandGeoTiff, Tiled}
-      // import geotrellis.raster.io.geotiff.compression.{DeflateCompression}
-      // import geotrellis.raster.io.geotiff.tags.codes.ColorSpace
-      // val tiffOptions = GeoTiffOptions(Tiled, DeflateCompression, ColorSpace.RGB)
-      // MultibandGeoTiff(raster, raster.crs, tiffOptions).write("target/composite.tif")
     }
   }
 }
