Skip to content

Commit 8bac84a

Browse files
authored
Merge pull request #358 from s22s/feature/rs-reuse-refactor
Tweaks to RasterSource-related code for easier extendability.
2 parents f4a9a7c + 50c69a9 commit 8bac84a

File tree

6 files changed

+75
-62
lines changed

6 files changed

+75
-62
lines changed

core/src/main/scala/org/locationtech/rasterframes/expressions/generators/ExplodeTiles.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import geotrellis.raster._
2525
import org.apache.spark.sql._
2626
import org.apache.spark.sql.catalyst.InternalRow
2727
import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, CodegenFallback, UnsafeRowWriter}
28-
import org.apache.spark.sql.catalyst.expressions.{Expression, Generator, GenericInternalRow, UnsafeRow}
28+
import org.apache.spark.sql.catalyst.expressions.{Expression, Generator, UnsafeRow}
2929
import org.apache.spark.sql.types._
3030
import org.locationtech.rasterframes._
3131
import org.locationtech.rasterframes.expressions.DynamicExtractors

core/src/main/scala/org/locationtech/rasterframes/extensions/ReprojectToLayer.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,21 +27,23 @@ import org.apache.spark.sql.functions.broadcast
2727
import org.locationtech.rasterframes._
2828
import org.locationtech.rasterframes.util._
2929
object ReprojectToLayer {
30-
3130
def apply(df: DataFrame, tlm: TileLayerMetadata[SpatialKey]): RasterFrameLayer = {
3231
// create a destination dataframe with crs and extend columns
3332
// use RasterJoin to do the rest.
3433
val gb = tlm.gridBounds
3534
val crs = tlm.crs
3635

36+
import df.sparkSession.implicits._
37+
implicit val enc = Encoders.tuple(spatialKeyEncoder, extentEncoder, crsEncoder)
38+
3739
val gridItems = for {
3840
(col, row) <- gb.coordsIter
3941
sk = SpatialKey(col, row)
4042
e = tlm.mapTransform(sk)
4143
} yield (sk, e, crs)
4244

43-
val dest = df.sparkSession.createDataFrame(gridItems.toSeq)
44-
.toDF(SPATIAL_KEY_COLUMN.columnName, EXTENT_COLUMN.columnName, CRS_COLUMN.columnName)
45+
val dest = gridItems.toSeq.toDF(SPATIAL_KEY_COLUMN.columnName, EXTENT_COLUMN.columnName, CRS_COLUMN.columnName)
46+
dest.show(false)
4547
val joined = RasterJoin(broadcast(dest), df)
4648

4749
joined.asLayer(SPATIAL_KEY_COLUMN, tlm)

datasource/src/main/scala/org/locationtech/rasterframes/datasource/geotiff/GeoTiffDataSource.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ class GeoTiffDataSource
9090
}
9191

9292
val tags = Tags(
93-
RFBuildInfo.toMap.filter(_._1.toLowerCase().contains("version")).mapValues(_.toString),
93+
RFBuildInfo.toMap.filter(_._1.toLowerCase() == "version").mapValues(_.toString),
9494
tileCols.map(c => Map("RF_COL" -> c.columnName)).toList
9595
)
9696

datasource/src/main/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceDataSource.scala

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,16 @@
2121

2222
package org.locationtech.rasterframes.datasource.raster
2323

24+
import java.net.URI
25+
import java.util.UUID
26+
2427
import org.locationtech.rasterframes._
2528
import org.locationtech.rasterframes.util._
26-
import org.apache.spark.sql.SQLContext
29+
import org.apache.spark.sql.{DataFrame, DataFrameReader, SQLContext}
2730
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}
2831
import org.locationtech.rasterframes.model.TileDimensions
32+
import shapeless.tag
33+
import shapeless.tag.@@
2934

3035
class RasterSourceDataSource extends DataSourceRegister with RelationProvider {
3136
import RasterSourceDataSource._
@@ -58,6 +63,8 @@ object RasterSourceDataSource {
5863
}
5964
/** Container for specifying raster paths. */
6065
case class RasterSourceCatalog(csv: String, bandColumnNames: String*) extends WithBandColumns {
66+
protected def tmpTableName() = UUID.randomUUID().toString.replace("-", "")
67+
6168
def registerAsTable(sqlContext: SQLContext): RasterSourceCatalogRef = {
6269
import sqlContext.implicits._
6370
val lines = csv
@@ -95,7 +102,6 @@ object RasterSourceDataSource {
95102
/** Container for specifying where to select raster paths from. */
96103
case class RasterSourceCatalogRef(tableName: String, bandColumnNames: String*) extends WithBandColumns
97104

98-
private[raster]
99105
implicit class ParamsDictAccessors(val parameters: Map[String, String]) extends AnyVal {
100106
def tokenize(csv: String): Seq[String] = csv.split(',').map(_.trim)
101107

@@ -151,4 +157,60 @@ object RasterSourceDataSource {
151157
}
152158
}
153159
}
160+
161+
/** Mixin for adding extension methods on DataFrameReader for RasterSourceDataSource-like readers. */
162+
trait CatalogReaderOptionsSupport[ReaderTag] {
163+
type TaggedReader = DataFrameReader @@ ReaderTag
164+
val reader: TaggedReader
165+
166+
protected def tmpTableName() = UUID.randomUUID().toString.replace("-", "")
167+
168+
/** Set the zero-based band indexes to read. Defaults to Seq(0). */
169+
def withBandIndexes(bandIndexes: Int*): TaggedReader =
170+
tag[ReaderTag][DataFrameReader](
171+
reader.option(RasterSourceDataSource.BAND_INDEXES_PARAM, bandIndexes.mkString(","))
172+
)
173+
174+
def withTileDimensions(cols: Int, rows: Int): TaggedReader =
175+
tag[ReaderTag][DataFrameReader](
176+
reader.option(RasterSourceDataSource.TILE_DIMS_PARAM, s"$cols,$rows")
177+
)
178+
179+
/** Indicate if tile reading should be delayed until cells are fetched. Defaults to `true`. */
180+
def withLazyTiles(state: Boolean): TaggedReader =
181+
tag[ReaderTag][DataFrameReader](
182+
reader.option(RasterSourceDataSource.LAZY_TILES_PARAM, state))
183+
184+
def fromCatalog(catalog: DataFrame, bandColumnNames: String*): TaggedReader =
185+
tag[ReaderTag][DataFrameReader] {
186+
val tmpName = tmpTableName()
187+
catalog.createOrReplaceTempView(tmpName)
188+
reader
189+
.option(RasterSourceDataSource.CATALOG_TABLE_PARAM, tmpName)
190+
.option(RasterSourceDataSource.CATALOG_TABLE_COLS_PARAM, bandColumnNames.mkString(",")): DataFrameReader
191+
}
192+
193+
def fromCatalog(tableName: String, bandColumnNames: String*): TaggedReader =
194+
tag[ReaderTag][DataFrameReader](
195+
reader.option(RasterSourceDataSource.CATALOG_TABLE_PARAM, tableName)
196+
.option(RasterSourceDataSource.CATALOG_TABLE_COLS_PARAM, bandColumnNames.mkString(","))
197+
)
198+
199+
def fromCSV(catalogCSV: String, bandColumnNames: String*): TaggedReader =
200+
tag[ReaderTag][DataFrameReader](
201+
reader.option(RasterSourceDataSource.CATALOG_CSV_PARAM, catalogCSV)
202+
.option(RasterSourceDataSource.CATALOG_TABLE_COLS_PARAM, bandColumnNames.mkString(","))
203+
)
204+
205+
def from(newlineDelimPaths: String): TaggedReader =
206+
tag[ReaderTag][DataFrameReader](
207+
reader.option(RasterSourceDataSource.PATHS_PARAM, newlineDelimPaths)
208+
)
209+
210+
def from(paths: Seq[String]): TaggedReader =
211+
from(paths.mkString("\n"))
212+
213+
def from(uris: Seq[URI])(implicit d: DummyImplicit): TaggedReader =
214+
from(uris.map(_.toASCIIString))
215+
}
154216
}

datasource/src/main/scala/org/locationtech/rasterframes/datasource/raster/package.scala

Lines changed: 3 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,11 @@
2121

2222
package org.locationtech.rasterframes.datasource
2323

24-
import java.net.URI
25-
import java.util.UUID
26-
27-
import org.apache.spark.sql.{DataFrame, DataFrameReader}
24+
import org.apache.spark.sql.DataFrameReader
2825
import shapeless.tag
2926
import shapeless.tag.@@
3027
package object raster {
3128

32-
private[raster] def tmpTableName() = UUID.randomUUID().toString.replace("-", "")
33-
3429
trait RasterSourceDataFrameReaderTag
3530
type RasterSourceDataFrameReader = DataFrameReader @@ RasterSourceDataFrameReaderTag
3631

@@ -42,52 +37,6 @@ package object raster {
4237
}
4338

4439
/** Adds option methods relevant to RasterSourceDataSource. */
45-
implicit class RasterSourceDataFrameReaderHasOptions(val reader: RasterSourceDataFrameReader) {
46-
/** Set the zero-based band indexes to read. Defaults to Seq(0). */
47-
def withBandIndexes(bandIndexes: Int*): RasterSourceDataFrameReader =
48-
tag[RasterSourceDataFrameReaderTag][DataFrameReader](
49-
reader.option(RasterSourceDataSource.BAND_INDEXES_PARAM, bandIndexes.mkString(",")))
50-
51-
def withTileDimensions(cols: Int, rows: Int): RasterSourceDataFrameReader =
52-
tag[RasterSourceDataFrameReaderTag][DataFrameReader](
53-
reader.option(RasterSourceDataSource.TILE_DIMS_PARAM, s"$cols,$rows")
54-
)
55-
56-
/** Indicate if tile reading should be delayed until cells are fetched. Defaults to `true`. */
57-
def withLazyTiles(state: Boolean): RasterSourceDataFrameReader =
58-
tag[RasterSourceDataFrameReaderTag][DataFrameReader](
59-
reader.option(RasterSourceDataSource.LAZY_TILES_PARAM, state))
60-
61-
def fromCatalog(catalog: DataFrame, bandColumnNames: String*): RasterSourceDataFrameReader =
62-
tag[RasterSourceDataFrameReaderTag][DataFrameReader] {
63-
val tmpName = tmpTableName()
64-
catalog.createOrReplaceTempView(tmpName)
65-
reader
66-
.option(RasterSourceDataSource.CATALOG_TABLE_PARAM, tmpName)
67-
.option(RasterSourceDataSource.CATALOG_TABLE_COLS_PARAM, bandColumnNames.mkString(",")): DataFrameReader
68-
}
69-
70-
def fromCatalog(tableName: String, bandColumnNames: String*): RasterSourceDataFrameReader =
71-
tag[RasterSourceDataFrameReaderTag][DataFrameReader](
72-
reader.option(RasterSourceDataSource.CATALOG_TABLE_PARAM, tableName)
73-
.option(RasterSourceDataSource.CATALOG_TABLE_COLS_PARAM, bandColumnNames.mkString(","))
74-
)
75-
76-
def fromCSV(catalogCSV: String, bandColumnNames: String*): RasterSourceDataFrameReader =
77-
tag[RasterSourceDataFrameReaderTag][DataFrameReader](
78-
reader.option(RasterSourceDataSource.CATALOG_CSV_PARAM, catalogCSV)
79-
.option(RasterSourceDataSource.CATALOG_TABLE_COLS_PARAM, bandColumnNames.mkString(","))
80-
)
81-
82-
def from(newlineDelimPaths: String): RasterSourceDataFrameReader =
83-
tag[RasterSourceDataFrameReaderTag][DataFrameReader](
84-
reader.option(RasterSourceDataSource.PATHS_PARAM, newlineDelimPaths)
85-
)
86-
87-
def from(paths: Seq[String]): RasterSourceDataFrameReader =
88-
from(paths.mkString("\n"))
89-
90-
def from(uris: Seq[URI])(implicit d: DummyImplicit): RasterSourceDataFrameReader =
91-
from(uris.map(_.toASCIIString))
92-
}
40+
implicit class RasterSourceDataFrameReaderHasOptions(val reader: RasterSourceDataFrameReader)
41+
extends RasterSourceDataSource.CatalogReaderOptionsSupport[RasterSourceDataFrameReaderTag]
9342
}

pyrasterframes/src/main/python/docs/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ RasterFrames is released under the [Apache 2.0 License](https://github.com/locat
3131
* [Vector Data](vector-data.md)
3232
* [Raster Processing](raster-processing.md)
3333
* [Numpy and Pandas](numpy-pandas.md)
34-
* [API Languages](languages.md)
34+
* [Scala and SQL](languages.md)
3535
* [Function Reference](reference.md)
3636
* [Release Notes](release-notes.md)
3737
@@@

0 commit comments

Comments
 (0)