Skip to content

Commit 9a4f156

Browse files
committed
Replaced TileDimensions with Dimension[Int].
1 parent 38d503e commit 9a4f156

File tree

32 files changed

+146
-175
lines changed

32 files changed

+146
-175
lines changed

bench/src/main/scala/org/locationtech/rasterframes/bench/RasterRefBench.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ import org.apache.spark.sql._
2828
import org.locationtech.rasterframes._
2929
import org.locationtech.rasterframes.expressions.generators.RasterSourceToRasterRefs
3030
import org.locationtech.rasterframes.expressions.transformers.RasterRefToTile
31-
import org.locationtech.rasterframes.model.TileDimensions
3231
import org.locationtech.rasterframes.ref.RFRasterSource
3332
import org.openjdk.jmh.annotations._
3433

@@ -47,7 +46,7 @@ class RasterRefBench extends SparkEnv with LazyLogging {
4746
val r2 = RFRasterSource(remoteCOGSingleband2)
4847

4948
singleDF = Seq((r1, r2)).toDF("B1", "B2")
50-
.select(RasterRefToTile(RasterSourceToRasterRefs(Some(TileDimensions(r1.dimensions)), Seq(0), $"B1", $"B2")))
49+
.select(RasterRefToTile(RasterSourceToRasterRefs(Some(r1.dimensions), Seq(0), $"B1", $"B2")))
5150

5251
expandedDF = Seq((r1, r2)).toDF("B1", "B2")
5352
.select(RasterRefToTile(RasterSourceToRasterRefs($"B1", $"B2")))

build.sbt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,11 @@ lazy val pyrasterframes = project
9090
spark("core").value % Provided,
9191
spark("mllib").value % Provided,
9292
spark("sql").value % Provided
93-
)
93+
),
94+
Test / test := (Test / test).dependsOn(experimental / Test / test).value
9495
)
9596

97+
9698
lazy val datasource = project
9799
.configs(IntegrationTest)
98100
.settings(Defaults.itSettings)
@@ -105,6 +107,7 @@ lazy val datasource = project
105107
spark("mllib").value % Provided,
106108
spark("sql").value % Provided
107109
),
110+
Test / test := (Test / test).dependsOn(core / Test / test).value,
108111
initialCommands in console := (initialCommands in console).value +
109112
"""
110113
|import org.locationtech.rasterframes.datasource.geotrellis._
@@ -127,7 +130,7 @@ lazy val experimental = project
127130
),
128131
fork in IntegrationTest := true,
129132
javaOptions in IntegrationTest := Seq("-Xmx2G"),
130-
parallelExecution in IntegrationTest := false
133+
Test / test := (Test / test).dependsOn(datasource / Test / test).value
131134
)
132135

133136
lazy val docs = project

core/src/main/scala/org/apache/spark/sql/rf/VersionShims.scala

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package org.apache.spark.sql.rf
22

33
import java.lang.reflect.Constructor
44

5+
import org.apache.spark.sql.AnalysisException
56
import org.apache.spark.sql.catalyst.FunctionIdentifier
67
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
78
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
@@ -12,7 +13,6 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
1213
import org.apache.spark.sql.execution.datasources.LogicalRelation
1314
import org.apache.spark.sql.sources.BaseRelation
1415
import org.apache.spark.sql.types.DataType
15-
import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SQLContext}
1616

1717
import scala.reflect._
1818
import scala.util.{Failure, Success, Try}
@@ -23,11 +23,6 @@ import scala.util.{Failure, Success, Try}
2323
* @since 2/13/18
2424
*/
2525
object VersionShims {
26-
def readJson(sqlContext: SQLContext, rows: Dataset[String]): DataFrame = {
27-
// NB: Will get a deprecation warning for Spark 2.2.x
28-
sqlContext.read.json(rows.rdd) // <-- deprecation warning expected
29-
}
30-
3126
def updateRelation(lr: LogicalRelation, base: BaseRelation): LogicalPlan = {
3227
val lrClazz = classOf[LogicalRelation]
3328
val ctor = lrClazz.getConstructors.head.asInstanceOf[Constructor[LogicalRelation]]

core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ package org.locationtech.rasterframes
2323
import geotrellis.proj4.CRS
2424
import geotrellis.raster.mapalgebra.local.LocalTileBinaryOp
2525
import geotrellis.raster.render.ColorRamp
26-
import geotrellis.raster.{CellType, Tile}
26+
import geotrellis.raster.{CellType, Dimensions, Tile}
2727
import geotrellis.vector.Extent
2828
import org.apache.spark.annotation.Experimental
2929
import org.apache.spark.sql.functions.{lit, udf}
@@ -35,9 +35,8 @@ import org.locationtech.rasterframes.expressions.aggregates._
3535
import org.locationtech.rasterframes.expressions.generators._
3636
import org.locationtech.rasterframes.expressions.localops._
3737
import org.locationtech.rasterframes.expressions.tilestats._
38-
import org.locationtech.rasterframes.expressions.transformers.RenderPNG.{RenderCompositePNG, RenderColorRampPNG}
38+
import org.locationtech.rasterframes.expressions.transformers.RenderPNG.{RenderColorRampPNG, RenderCompositePNG}
3939
import org.locationtech.rasterframes.expressions.transformers._
40-
import org.locationtech.rasterframes.model.TileDimensions
4140
import org.locationtech.rasterframes.stats._
4241
import org.locationtech.rasterframes.{functions => F}
4342

@@ -51,7 +50,7 @@ trait RasterFunctions {
5150

5251
// format: off
5352
/** Query the number of (cols, rows) in a Tile. */
54-
def rf_dimensions(col: Column): TypedColumn[Any, TileDimensions] = GetDimensions(col)
53+
def rf_dimensions(col: Column): TypedColumn[Any, Dimensions[Int]] = GetDimensions(col)
5554

5655
/** Extracts the bounding box of a geometry as an Extent */
5756
def st_extent(col: Column): TypedColumn[Any, Extent] = GeometryToExtent(col)

core/src/main/scala/org/locationtech/rasterframes/encoders/StandardEncoders.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import java.sql.Timestamp
2727
import org.locationtech.rasterframes.stats.{CellHistogram, CellStatistics, LocalCellStatistics}
2828
import org.locationtech.jts.geom.Envelope
2929
import geotrellis.proj4.CRS
30-
import geotrellis.raster.{CellSize, CellType, Raster, Tile, TileLayout}
30+
import geotrellis.raster.{CellSize, CellType, Dimensions, Raster, Tile, TileLayout}
3131
import geotrellis.layer._
3232
import geotrellis.vector.{Extent, ProjectedExtent}
3333
import org.apache.spark.sql.{Encoder, Encoders}
@@ -70,8 +70,7 @@ trait StandardEncoders extends SpatialEncoders {
7070
implicit def tileContextEncoder: ExpressionEncoder[TileContext] = TileContext.encoder
7171
implicit def tileDataContextEncoder: ExpressionEncoder[TileDataContext] = TileDataContext.encoder
7272
implicit def extentTilePairEncoder: Encoder[(ProjectedExtent, Tile)] = Encoders.tuple(projectedExtentEncoder, singlebandTileEncoder)
73-
74-
73+
implicit def tileDimensionsEncoder: Encoder[Dimensions[Int]] = CatalystSerializerEncoder[Dimensions[Int]](true)
7574
}
7675

7776
object StandardEncoders extends StandardEncoders

core/src/main/scala/org/locationtech/rasterframes/encoders/StandardSerializers.scala

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import com.github.blemale.scaffeine.Scaffeine
2525
import geotrellis.proj4.CRS
2626
import geotrellis.raster._
2727
import geotrellis.layer._
28-
2928
import geotrellis.vector._
3029
import org.apache.spark.sql.types._
3130
import org.locationtech.jts.geom.Envelope
@@ -60,9 +59,11 @@ trait StandardSerializers {
6059
StructField("xmax", DoubleType, false),
6160
StructField("ymax", DoubleType, false)
6261
))
62+
6363
override def to[R](t: Extent, io: CatalystIO[R]): R = io.create(
6464
t.xmin, t.ymin, t.xmax, t.ymax
6565
)
66+
6667
override def from[R](row: R, io: CatalystIO[R]): Extent = Extent(
6768
io.getDouble(row, 0),
6869
io.getDouble(row, 1),
@@ -95,25 +96,31 @@ trait StandardSerializers {
9596
override val schema: StructType = StructType(Seq(
9697
StructField("crsProj4", StringType, false)
9798
))
99+
98100
override def to[R](t: CRS, io: CatalystIO[R]): R = io.create(
99101
io.encode(
100102
// Don't do this... it's 1000x slower to decode.
101103
//t.epsgCode.map(c => "EPSG:" + c).getOrElse(t.toProj4String)
102104
t.toProj4String
103105
)
104106
)
107+
105108
override def from[R](row: R, io: CatalystIO[R]): CRS =
106109
LazyCRS(io.getString(row, 0))
107110
}
108111

109112
implicit val cellTypeSerializer: CatalystSerializer[CellType] = new CatalystSerializer[CellType] {
113+
110114
import StandardSerializers._
115+
111116
override val schema: StructType = StructType(Seq(
112117
StructField("cellTypeName", StringType, false)
113118
))
119+
114120
override def to[R](t: CellType, io: CatalystIO[R]): R = io.create(
115121
io.encode(ct2sCache.get(t))
116122
)
123+
117124
override def from[R](row: R, io: CatalystIO[R]): CellType =
118125
s2ctCache.get(io.getString(row, 0))
119126
}
@@ -229,7 +236,7 @@ trait StandardSerializers {
229236
)
230237
}
231238

232-
implicit def boundsSerializer[T >: Null: CatalystSerializer]: CatalystSerializer[KeyBounds[T]] = new CatalystSerializer[KeyBounds[T]] {
239+
implicit def boundsSerializer[T >: Null : CatalystSerializer]: CatalystSerializer[KeyBounds[T]] = new CatalystSerializer[KeyBounds[T]] {
233240
override val schema: StructType = StructType(Seq(
234241
StructField("minKey", schemaOf[T], true),
235242
StructField("maxKey", schemaOf[T], true)
@@ -246,7 +253,7 @@ trait StandardSerializers {
246253
)
247254
}
248255

249-
def tileLayerMetadataSerializer[T >: Null: CatalystSerializer]: CatalystSerializer[TileLayerMetadata[T]] = new CatalystSerializer[TileLayerMetadata[T]] {
256+
def tileLayerMetadataSerializer[T >: Null : CatalystSerializer]: CatalystSerializer[TileLayerMetadata[T]] = new CatalystSerializer[TileLayerMetadata[T]] {
250257
override val schema: StructType = StructType(Seq(
251258
StructField("cellType", schemaOf[CellType], false),
252259
StructField("layout", schemaOf[LayoutDefinition], false),
@@ -273,6 +280,7 @@ trait StandardSerializers {
273280
}
274281

275282
implicit def rasterSerializer: CatalystSerializer[Raster[Tile]] = new CatalystSerializer[Raster[Tile]] {
283+
276284
import org.apache.spark.sql.rf.TileUDT.tileSerializer
277285

278286
override val schema: StructType = StructType(Seq(
@@ -294,6 +302,22 @@ trait StandardSerializers {
294302
implicit val spatialKeyTLMSerializer = tileLayerMetadataSerializer[SpatialKey]
295303
implicit val spaceTimeKeyTLMSerializer = tileLayerMetadataSerializer[SpaceTimeKey]
296304

305+
implicit val tileDimensionsSerializer: CatalystSerializer[Dimensions[Int]] = new CatalystSerializer[Dimensions[Int]] {
306+
override val schema: StructType = StructType(Seq(
307+
StructField("cols", IntegerType, false),
308+
StructField("rows", IntegerType, false)
309+
))
310+
311+
override protected def to[R](t: Dimensions[Int], io: CatalystIO[R]): R = io.create(
312+
t.cols,
313+
t.rows
314+
)
315+
316+
override protected def from[R](t: R, io: CatalystIO[R]): Dimensions[Int] = Dimensions[Int](
317+
io.getInt(t, 0),
318+
io.getInt(t, 1)
319+
)
320+
}
297321
}
298322

299323
object StandardSerializers {

core/src/main/scala/org/locationtech/rasterframes/expressions/accessors/GetDimensions.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,10 @@ package org.locationtech.rasterframes.expressions.accessors
2323

2424
import org.locationtech.rasterframes.encoders.CatalystSerializer._
2525
import org.locationtech.rasterframes.expressions.OnCellGridExpression
26-
import geotrellis.raster.CellGrid
26+
import geotrellis.raster.{CellGrid, Dimensions}
2727
import org.apache.spark.sql._
2828
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription}
2929
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
30-
import org.locationtech.rasterframes.model.TileDimensions
3130

3231
/**
3332
* Extract a raster's dimensions
@@ -43,12 +42,13 @@ import org.locationtech.rasterframes.model.TileDimensions
4342
case class GetDimensions(child: Expression) extends OnCellGridExpression with CodegenFallback {
4443
override def nodeName: String = "rf_dimensions"
4544

46-
def dataType = schemaOf[TileDimensions]
45+
def dataType = schemaOf[Dimensions[Int]]
4746

48-
override def eval(grid: CellGrid[Int]): Any = TileDimensions(grid.cols, grid.rows).toInternalRow
47+
override def eval(grid: CellGrid[Int]): Any = Dimensions[Int](grid.cols, grid.rows).toInternalRow
4948
}
5049

5150
object GetDimensions {
52-
def apply(col: Column): TypedColumn[Any, TileDimensions] =
53-
new Column(new GetDimensions(col.expr)).as[TileDimensions]
51+
import org.locationtech.rasterframes.encoders.StandardEncoders.tileDimensionsEncoder
52+
def apply(col: Column): TypedColumn[Any, Dimensions[Int]] =
53+
new Column(new GetDimensions(col.expr)).as[Dimensions[Int]]
5454
}

core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/ProjectedLayerMetadataAggregate.scala

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ package org.locationtech.rasterframes.expressions.aggregates
2424
import org.locationtech.rasterframes._
2525
import org.locationtech.rasterframes.encoders.CatalystSerializer
2626
import org.locationtech.rasterframes.encoders.CatalystSerializer._
27-
import org.locationtech.rasterframes.model.TileDimensions
2827
import geotrellis.proj4.{CRS, Transform}
2928
import geotrellis.raster._
3029
import geotrellis.raster.reproject.{Reproject, ReprojectRasterExtent}
@@ -34,7 +33,7 @@ import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAg
3433
import org.apache.spark.sql.types.{DataType, StructField, StructType}
3534
import org.apache.spark.sql.{Column, Row, TypedColumn}
3635

37-
class ProjectedLayerMetadataAggregate(destCRS: CRS, destDims: TileDimensions) extends UserDefinedAggregateFunction {
36+
class ProjectedLayerMetadataAggregate(destCRS: CRS, destDims: Dimensions[Int]) extends UserDefinedAggregateFunction {
3837
import ProjectedLayerMetadataAggregate._
3938

4039
override def inputSchema: StructType = CatalystSerializer[InputRecord].schema
@@ -94,14 +93,14 @@ object ProjectedLayerMetadataAggregate {
9493
/** Primary user facing constructor */
9594
def apply(destCRS: CRS, extent: Column, crs: Column, cellType: Column, tileSize: Column): TypedColumn[Any, TileLayerMetadata[SpatialKey]] =
9695
// Ordering must match InputRecord schema
97-
new ProjectedLayerMetadataAggregate(destCRS, TileDimensions(NOMINAL_TILE_SIZE, NOMINAL_TILE_SIZE))(extent, crs, cellType, tileSize).as[TileLayerMetadata[SpatialKey]]
96+
new ProjectedLayerMetadataAggregate(destCRS, Dimensions(NOMINAL_TILE_SIZE, NOMINAL_TILE_SIZE))(extent, crs, cellType, tileSize).as[TileLayerMetadata[SpatialKey]]
9897

99-
def apply(destCRS: CRS, destDims: TileDimensions, extent: Column, crs: Column, cellType: Column, tileSize: Column): TypedColumn[Any, TileLayerMetadata[SpatialKey]] =
98+
def apply(destCRS: CRS, destDims: Dimensions[Int], extent: Column, crs: Column, cellType: Column, tileSize: Column): TypedColumn[Any, TileLayerMetadata[SpatialKey]] =
10099
// Ordering must match InputRecord schema
101100
new ProjectedLayerMetadataAggregate(destCRS, destDims)(extent, crs, cellType, tileSize).as[TileLayerMetadata[SpatialKey]]
102101

103102
private[expressions]
104-
case class InputRecord(extent: Extent, crs: CRS, cellType: CellType, tileSize: TileDimensions) {
103+
case class InputRecord(extent: Extent, crs: CRS, cellType: CellType, tileSize: Dimensions[Int]) {
105104
def toBufferRecord(destCRS: CRS): BufferRecord = {
106105
val transform = Transform(crs, destCRS)
107106

@@ -125,7 +124,7 @@ object ProjectedLayerMetadataAggregate {
125124
StructField("extent", CatalystSerializer[Extent].schema, false),
126125
StructField("crs", CatalystSerializer[CRS].schema, false),
127126
StructField("cellType", CatalystSerializer[CellType].schema, false),
128-
StructField("tileSize", CatalystSerializer[TileDimensions].schema, false)
127+
StructField("tileSize", CatalystSerializer[Dimensions[Int]].schema, false)
129128
))
130129

131130
override protected def to[R](t: InputRecord, io: CatalystIO[R]): R =
@@ -135,7 +134,7 @@ object ProjectedLayerMetadataAggregate {
135134
io.get[Extent](t, 0),
136135
io.get[CRS](t, 1),
137136
io.get[CellType](t, 2),
138-
io.get[TileDimensions](t, 3)
137+
io.get[Dimensions[Int]](t, 3)
139138
)
140139
}
141140
}

core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/TileRasterizerAggregate.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ package org.locationtech.rasterframes.expressions.aggregates
2424
import geotrellis.proj4.CRS
2525
import geotrellis.raster.reproject.Reproject
2626
import geotrellis.raster.resample.ResampleMethod
27-
import geotrellis.raster.{ArrayTile, CellType, MultibandTile, ProjectedRaster, Raster, Tile}
27+
import geotrellis.raster.{ArrayTile, CellType, Dimensions, MultibandTile, ProjectedRaster, Raster, Tile}
2828
import geotrellis.layer._
2929
import geotrellis.vector.Extent
3030
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
@@ -34,7 +34,6 @@ import org.locationtech.rasterframes._
3434
import org.locationtech.rasterframes.util._
3535
import org.locationtech.rasterframes.encoders.CatalystSerializer._
3636
import org.locationtech.rasterframes.expressions.aggregates.TileRasterizerAggregate.ProjectedRasterDefinition
37-
import org.locationtech.rasterframes.model.TileDimensions
3837
import org.slf4j.LoggerFactory
3938

4039
/**
@@ -119,7 +118,7 @@ object TileRasterizerAggregate {
119118
new TileRasterizerAggregate(prd)(crsCol, extentCol, tileCol).as(nodeName).as[Raster[Tile]]
120119
}
121120

122-
def collect(df: DataFrame, destCRS: CRS, destExtent: Option[Extent], rasterDims: Option[TileDimensions]): ProjectedRaster[MultibandTile] = {
121+
def collect(df: DataFrame, destCRS: CRS, destExtent: Option[Extent], rasterDims: Option[Dimensions[Int]]): ProjectedRaster[MultibandTile] = {
123122
val tileCols = WithDataFrameMethods(df).tileColumns
124123
require(tileCols.nonEmpty, "need at least one tile column")
125124
// Select the anchoring Tile, Extent and CRS columns

core/src/main/scala/org/locationtech/rasterframes/expressions/generators/RasterSourceToRasterRefs.scala

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
package org.locationtech.rasterframes.expressions.generators
2323

24-
import geotrellis.raster.GridBounds
24+
import geotrellis.raster.{Dimensions, GridBounds}
2525
import geotrellis.vector.Extent
2626
import org.apache.spark.sql.catalyst.InternalRow
2727
import org.apache.spark.sql.catalyst.expressions._
@@ -30,8 +30,7 @@ import org.apache.spark.sql.types.{DataType, StructField, StructType}
3030
import org.apache.spark.sql.{Column, TypedColumn}
3131
import org.locationtech.rasterframes.encoders.CatalystSerializer._
3232
import org.locationtech.rasterframes.expressions.generators.RasterSourceToRasterRefs.bandNames
33-
import org.locationtech.rasterframes.model.TileDimensions
34-
import org.locationtech.rasterframes.ref.{RasterRef, RFRasterSource}
33+
import org.locationtech.rasterframes.ref.{RFRasterSource, RasterRef}
3534
import org.locationtech.rasterframes.util._
3635
import org.locationtech.rasterframes.RasterSourceType
3736

@@ -43,7 +42,7 @@ import scala.util.control.NonFatal
4342
*
4443
* @since 9/6/18
4544
*/
46-
case class RasterSourceToRasterRefs(children: Seq[Expression], bandIndexes: Seq[Int], subtileDims: Option[TileDimensions] = None) extends Expression
45+
case class RasterSourceToRasterRefs(children: Seq[Expression], bandIndexes: Seq[Int], subtileDims: Option[Dimensions[Int]] = None) extends Expression
4746
with Generator with CodegenFallback with ExpectsInputTypes {
4847

4948
override def inputTypes: Seq[DataType] = Seq.fill(children.size)(RasterSourceType)
@@ -86,7 +85,7 @@ case class RasterSourceToRasterRefs(children: Seq[Expression], bandIndexes: Seq[
8685

8786
object RasterSourceToRasterRefs {
8887
def apply(rrs: Column*): TypedColumn[Any, RasterRef] = apply(None, Seq(0), rrs: _*)
89-
def apply(subtileDims: Option[TileDimensions], bandIndexes: Seq[Int], rrs: Column*): TypedColumn[Any, RasterRef] =
88+
def apply(subtileDims: Option[Dimensions[Int]], bandIndexes: Seq[Int], rrs: Column*): TypedColumn[Any, RasterRef] =
9089
new Column(new RasterSourceToRasterRefs(rrs.map(_.expr), bandIndexes, subtileDims)).as[RasterRef]
9190

9291
private[rasterframes] def bandNames(basename: String, bandIndexes: Seq[Int]): Seq[String] = bandIndexes match {

0 commit comments

Comments
 (0)