Skip to content

Commit 5953f0f

Browse files
committed
Fixed handling of aggregate extent and image size on geotiff write.
1 parent 0b04372 commit 5953f0f

File tree

5 files changed: +70 / -42 lines changed

core/src/it/scala/org/locationtech/rasterframes/ref/RasterRefIT.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class RasterRefIT extends TestEnvironment {
4848
rf_crs($"red"), rf_extent($"red"), rf_tile($"red"), rf_tile($"green"), rf_tile($"blue"))
4949
.toDF
5050

51-
val raster = TileRasterizerAggregate(df, redScene.crs, None, None)
51+
val raster = TileRasterizerAggregate.collect(df, redScene.crs, None, None)
5252

5353
forEvery(raster.tile.statisticsDouble) { stats =>
5454
stats should be ('defined)

core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/TileRasterizerAggregate.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ object TileRasterizerAggregate {
119119
new TileRasterizerAggregate(prd)(crsCol, extentCol, tileCol).as(nodeName).as[Raster[Tile]]
120120
}
121121

122-
def apply(df: DataFrame, destCRS: CRS, destExtent: Option[Extent], rasterDims: Option[TileDimensions]): ProjectedRaster[MultibandTile] = {
122+
def collect(df: DataFrame, destCRS: CRS, destExtent: Option[Extent], rasterDims: Option[TileDimensions]): ProjectedRaster[MultibandTile] = {
123123
val tileCols = WithDataFrameMethods(df).tileColumns
124124
require(tileCols.nonEmpty, "need at least one tile column")
125125
// Select the anchoring Tile, Extent and CRS columns

core/src/main/scala/org/locationtech/rasterframes/extensions/RFSpatialColumnMethods.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ import org.apache.spark.sql.functions.{asc, udf => sparkUdf}
3434
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
3535
import org.locationtech.geomesa.curve.Z2SFC
3636
import org.locationtech.rasterframes.StandardColumns
37+
import org.locationtech.rasterframes.encoders.serialized_literal
3738

3839
/**
3940
* RasterFrameLayer extension methods associated with adding spatially descriptive columns.
@@ -71,6 +72,15 @@ trait RFSpatialColumnMethods extends MethodExtensions[RasterFrameLayer] with Sta
7172
val key2Extent = sparkUdf(keyCol2Extent)
7273
self.withColumn(colName, key2Extent(self.spatialKeyColumn)).certify
7374
}
75+
/**
76+
* Append a column containing the CRS of the layer.
77+
*
78+
* @param colName name of column to append. Defaults to "crs"
79+
* @return updated RasterFrameLayer
80+
*/
81+
def withCRS(colName: String = CRS_COLUMN.columnName): RasterFrameLayer = {
82+
self.withColumn(colName, serialized_literal(self.crs)).certify
83+
}
7484

7585
/**
7686
* Append a column containing the bounds of the row's spatial key.

datasource/src/main/scala/org/locationtech/rasterframes/datasource/geotiff/GeoTiffDataSource.scala

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -67,27 +67,21 @@ class GeoTiffDataSource
6767

6868
require(tileCols.nonEmpty, "Could not find any tile columns.")
6969

70-
val raster = if (df.isAlreadyLayer) {
71-
val layer = df.certify
72-
val tlm = layer.tileLayerMetadata.merge
73-
74-
// If no desired image size is given, write at full size.
75-
val TileDimensions(cols, rows) = parameters.rasterDimensions
76-
.getOrElse {
77-
val actualSize = tlm.layout.toRasterExtent().gridBoundsFor(tlm.extent)
78-
TileDimensions(actualSize.width, actualSize.height)
79-
}
8070

81-
// Should we really play traffic cop here?
82-
if (cols.toDouble * rows * 64.0 > Runtime.getRuntime.totalMemory() * 0.5)
83-
logger.warn(
84-
s"You've asked for the construction of a very large image ($cols x $rows), destined for ${path}. Out of memory error likely.")
8571

86-
layer.toMultibandRaster(tileCols, cols.toInt, rows.toInt)
87-
} else {
88-
require(parameters.crs.nonEmpty, "A destination CRS must be provided")
89-
TileRasterizerAggregate(df, parameters.crs.get, None, parameters.rasterDimensions)
90-
}
72+
val destCRS = parameters.crs.orElse(df.asLayerSafely.map(_.crs)).getOrElse(
73+
throw new IllegalArgumentException("A destination CRS must be provided")
74+
)
75+
76+
val input = df.asLayerSafely.map(layer =>
77+
(layer.crsColumns.isEmpty, layer.extentColumns.isEmpty) match {
78+
case (true, true) => layer.withExtent().withCRS()
79+
case (true, false) => layer.withCRS()
80+
case (false, true) => layer.withExtent()
81+
case _ => layer
82+
}).getOrElse(df)
83+
84+
val raster = TileRasterizerAggregate.collect(input, destCRS, None, parameters.rasterDimensions)
9185

9286
val tags = Tags(
9387
RFBuildInfo.toMap.filter(_._1.toLowerCase().contains("version")).mapValues(_.toString),

datasource/src/test/scala/org/locationtech/rasterframes/datasource/geotiff/GeoTiffDataSourceSpec.scala

Lines changed: 45 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@
2020
*/
2121
package org.locationtech.rasterframes.datasource.geotiff
2222

23-
import java.nio.file.Paths
23+
import java.nio.file.{Path, Paths}
2424

2525
import geotrellis.proj4._
26+
import geotrellis.raster.CellType
2627
import geotrellis.raster.io.geotiff.{MultibandGeoTiff, SinglebandGeoTiff}
2728
import geotrellis.vector.Extent
2829
import org.locationtech.rasterframes._
@@ -90,6 +91,15 @@ class GeoTiffDataSourceSpec
9091

9192
describe("GeoTiff writing") {
9293

94+
def checkTiff(file: Path, cols: Int, rows: Int, extent: Extent, cellType: Option[CellType] = None) = {
95+
val outputTif = SinglebandGeoTiff(file.toString)
96+
outputTif.tile.dimensions should be ((cols, rows))
97+
outputTif.extent should be (extent)
98+
cellType.foreach(ct =>
99+
outputTif.cellType should be (ct)
100+
)
101+
}
102+
93103
it("should write GeoTIFF RF to parquet") {
94104
val rf = spark.read.format("geotiff").load(cogPath.toASCIIString).asLayer
95105
assert(write(rf))
@@ -105,6 +115,9 @@ class GeoTiffDataSourceSpec
105115
noException shouldBe thrownBy {
106116
rf.write.format("geotiff").save(out.toString)
107117
}
118+
val extent = rf.tileLayerMetadata.merge.extent
119+
120+
checkTiff(out, 1028, 989, extent)
108121
}
109122

110123
it("should write unstructured raster") {
@@ -117,10 +130,10 @@ class GeoTiffDataSourceSpec
117130

118131
val crs = df.select(rf_crs($"proj_raster")).first()
119132

120-
val out = Paths.get("target", "unstructured.tif").toString
133+
val out = Paths.get("target", "unstructured.tif")
121134

122135
noException shouldBe thrownBy {
123-
df.write.geotiff.withCRS(crs).save(out)
136+
df.write.geotiff.withCRS(crs).save(out.toString)
124137
}
125138

126139
val (inCols, inRows) = {
@@ -130,11 +143,7 @@ class GeoTiffDataSourceSpec
130143
inCols should be (774)
131144
inRows should be (500) //from gdalinfo
132145

133-
val outputTif = SinglebandGeoTiff(out)
134-
outputTif.imageData.cols should be (inCols)
135-
outputTif.imageData.rows should be (inRows)
136-
137-
// TODO check datatype, extent.
146+
checkTiff(out, inCols, inRows, Extent(431902.5, 4313647.5, 443512.5, 4321147.5))
138147
}
139148

140149
it("should round trip unstructured raster from COG"){
@@ -164,26 +173,22 @@ class GeoTiffDataSourceSpec
164173

165174
dfExtent shouldBe resourceExtent
166175

167-
val out = Paths.get("target", "unstructured_cog.tif").toString
176+
val out = Paths.get("target", "unstructured_cog.tif")
168177

169178
noException shouldBe thrownBy {
170-
df.write.geotiff.withCRS(crs).save(out)
179+
df.write.geotiff.withCRS(crs).save(out.toString)
171180
}
172181

173182
val (inCols, inRows, inExtent, inCellType) = {
174183
val tif = readSingleband("LC08_B7_Memphis_COG.tiff")
175184
val id = tif.imageData
176185
(id.cols, id.rows, tif.extent, tif.cellType)
177186
}
178-
inCols should be (963)
179-
inRows should be (754) //from gdalinfo
187+
inCols should be (resourceCols)
188+
inRows should be (resourceRows) //from gdalinfo
180189
inExtent should be (resourceExtent)
181190

182-
val outputTif = SinglebandGeoTiff(out)
183-
outputTif.imageData.cols should be (inCols)
184-
outputTif.imageData.rows should be (inRows)
185-
outputTif.extent should be (resourceExtent)
186-
outputTif.cellType should be (inCellType)
191+
checkTiff(out, inCols, inRows, resourceExtent, Some(inCellType))
187192
}
188193

189194
it("should write GeoTIFF without layer") {
@@ -218,9 +223,12 @@ class GeoTiffDataSourceSpec
218223
.save(out.toString)
219224
}
220225
}
226+
227+
checkTiff(out, 128, 128,
228+
Extent(-76.52586750038186, 36.85907177863949, -76.17461216980891, 37.1303690755922))
221229
}
222230

223-
it("should produce the correct subregion") {
231+
it("should produce the correct subregion from layer") {
224232
import spark.implicits._
225233
val rf = SinglebandGeoTiff(TestData.singlebandCogPath.getPath)
226234
.projectedRaster.toLayer(128, 128).withExtent()
@@ -232,9 +240,25 @@ class GeoTiffDataSourceSpec
232240
val expectedExtent = bitOfLayer.select($"extent".as[Extent]).first()
233241
bitOfLayer.write.geotiff.save(out.toString)
234242

235-
val result = SinglebandGeoTiff(out.toString)
236-
result.tile.dimensions should be (128, 128)
237-
result.extent should be (expectedExtent)
243+
checkTiff(out, 128, 128, expectedExtent)
244+
}
245+
246+
it("should produce the correct subregion without layer") {
247+
import spark.implicits._
248+
249+
val rf = spark.read.raster
250+
.withTileDimensions(128, 128)
251+
.load(TestData.singlebandCogPath.toASCIIString)
252+
253+
val out = Paths.get("target", "example3-geotiff.tif")
254+
logger.info(s"Writing to $out")
255+
256+
val bitOfLayer = rf.filter(st_intersects(st_makePoint(754245, 3893385), rf_geometry($"proj_raster")))
257+
val expectedExtent = bitOfLayer.select(rf_extent($"proj_raster")).first()
258+
val crs = bitOfLayer.select(rf_crs($"proj_raster")).first()
259+
bitOfLayer.write.geotiff.withCRS(crs).save(out.toString)
260+
261+
checkTiff(out, 128, 128, expectedExtent)
238262
}
239263

240264
def s(band: Int): String =

0 commit comments

Comments (0)