Skip to content

Commit 40d56d2

Browse files
authored
Merge pull request #332 from s22s/fix/287
Updated ExplodeTiles to work with proj_raster type.
2 parents 158ca8a + 3614338 commit 40d56d2

File tree

6 files changed

+42
-19
lines changed

6 files changed

+42
-19
lines changed

core/src/main/scala/org/locationtech/rasterframes/expressions/generators/ExplodeTiles.scala

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,15 @@
2121

2222
package org.locationtech.rasterframes.expressions.generators
2323

24-
import org.locationtech.rasterframes._
25-
import org.locationtech.rasterframes.encoders.CatalystSerializer._
26-
import org.locationtech.rasterframes.util._
2724
import geotrellis.raster._
2825
import org.apache.spark.sql._
2926
import org.apache.spark.sql.catalyst.InternalRow
3027
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
3128
import org.apache.spark.sql.catalyst.expressions.{Expression, Generator, GenericInternalRow}
32-
import org.apache.spark.sql.rf.TileUDT
3329
import org.apache.spark.sql.types._
30+
import org.locationtech.rasterframes._
31+
import org.locationtech.rasterframes.expressions.DynamicExtractors
32+
import org.locationtech.rasterframes.util._
3433
import spire.syntax.cfor.cfor
3534

3635
/**
@@ -67,8 +66,11 @@ case class ExplodeTiles(
6766
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
6867
val tiles = Array.ofDim[Tile](children.length)
6968
cfor(0)(_ < tiles.length, _ + 1) { index =>
70-
val row = children(index).eval(input).asInstanceOf[InternalRow]
71-
tiles(index) = if(row != null) row.to[Tile](TileUDT.tileSerializer) else null
69+
val c = children(index)
70+
val row = c.eval(input).asInstanceOf[InternalRow]
71+
tiles(index) = if(row != null)
72+
DynamicExtractors.tileExtractor(c.dataType)(row)._1
73+
else null
7274
}
7375
val dims = tiles.filter(_ != null).map(_.dimensions)
7476
if(dims.isEmpty) Seq.empty[InternalRow]

core/src/main/scala/org/locationtech/rasterframes/ml/TileColumnSupport.scala

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121

2222
package org.locationtech.rasterframes.ml
2323

24-
import org.apache.spark.sql.rf.TileUDT
2524
import org.apache.spark.sql.types.{StructField, StructType}
25+
import org.locationtech.rasterframes.expressions.DynamicExtractors
2626

2727
/**
2828
* Utility mix-in for separating out tile columns from non-tile columns.
@@ -31,13 +31,11 @@ import org.apache.spark.sql.types.{StructField, StructType}
3131
*/
3232
trait TileColumnSupport {
3333
protected def isTile(field: StructField) =
34-
field.dataType.typeName.equalsIgnoreCase(TileUDT.typeName)
34+
DynamicExtractors.tileExtractor.isDefinedAt(field.dataType)
3535

3636
type TileFields = Array[StructField]
3737
type NonTileFields = Array[StructField]
3838
protected def selectTileAndNonTileFields(schema: StructType): (TileFields, NonTileFields) = {
39-
val tiles = schema.fields.filter(isTile)
40-
val nonTiles = schema.fields.filterNot(isTile)
41-
(tiles, nonTiles)
39+
schema.fields.partition(f => DynamicExtractors.tileExtractor.isDefinedAt(f.dataType))
4240
}
4341
}

core/src/test/scala/org/locationtech/rasterframes/ml/TileExploderSpec.scala

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,28 +21,48 @@
2121

2222
package org.locationtech.rasterframes.ml
2323

24-
import org.locationtech.rasterframes.TestData
25-
import geotrellis.raster.Tile
26-
import org.apache.spark.sql.functions.lit
27-
import org.locationtech.rasterframes.TestEnvironment
24+
import geotrellis.proj4.LatLng
25+
import geotrellis.raster.{IntCellType, Tile}
26+
import org.apache.spark.sql.functions.{avg, lit}
27+
import org.locationtech.rasterframes.{TestData, TestEnvironment}
2828
/**
2929
*
3030
* @since 2/16/18
3131
*/
3232
class TileExploderSpec extends TestEnvironment with TestData {
3333
describe("Tile explode transformer") {
34-
it("should explode tiles") {
35-
import spark.implicits._
34+
import spark.implicits._
35+
it("should explode tile") {
3636
val df = Seq[(Tile, Tile)]((byteArrayTile, byteArrayTile)).toDF("tile1", "tile2").withColumn("other", lit("stuff"))
3737

3838
val exploder = new TileExploder()
3939
val newSchema = exploder.transformSchema(df.schema)
4040

4141
val exploded = exploder.transform(df)
42+
4243
assert(newSchema === exploded.schema)
4344
assert(exploded.columns.length === 5)
4445
assert(exploded.count() === 9)
4546
write(exploded)
47+
exploded.agg(avg($"tile1")).as[Double].first() should be (byteArrayTile.statisticsDouble.get.mean)
48+
}
49+
50+
it("should explode proj_raster") {
51+
val randPRT = TestData.projectedRasterTile(10, 10, scala.util.Random.nextInt(), extent, LatLng, IntCellType)
52+
53+
val df = Seq(randPRT).toDF("proj_raster").withColumn("other", lit("stuff"))
54+
55+
val exploder = new TileExploder()
56+
val newSchema = exploder.transformSchema(df.schema)
57+
58+
val exploded = exploder.transform(df)
59+
60+
assert(newSchema === exploded.schema)
61+
assert(exploded.columns.length === 4)
62+
assert(exploded.count() === randPRT.size)
63+
write(exploded)
64+
65+
exploded.agg(avg($"proj_raster")).as[Double].first() should be (randPRT.statisticsDouble.get.mean)
4666
}
4767
}
4868
}

docs/src/main/paradox/release-notes.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
## 0.8.x
44

5+
### 0.8.2
6+
7+
* Fixed `TileExploder` to support `proj_raster` struct [(#287)](https://github.com/locationtech/rasterframes/issues/287).
8+
59
### 0.8.1
610

711
* Added `rf_local_no_data`, `rf_local_data` and `rf_interpret_cell_type_as` raster functions.

pyrasterframes/src/main/python/docs/vector-data.pymd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ l8_filtered = l8 \
9393
.filter(st_intersects(l8.geom, st_bufferPoint(l8.paducah, lit(50000.0)))) \
9494
.filter(l8.acquisition_date > '2018-02-01') \
9595
.filter(l8.acquisition_date < '2018-04-01')
96-
l8_filtered.select('product_id', 'entity_id', 'acquisition_date', 'cloud_cover_pct').toPandas()
96+
l8_filtered.select('product_id', 'entity_id', 'acquisition_date', 'cloud_cover_pct')
9797
```
9898

9999
[GeoPandas]: http://geopandas.org

pyrasterframes/src/main/python/tests/ExploderTests.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333

3434
class ExploderTests(TestEnvironment):
3535

36-
@unittest.skip("See issue https://github.com/locationtech/rasterframes/issues/163")
3736
def test_tile_exploder_pipeline_for_prt(self):
3837
# NB the tile is a Projected Raster Tile
3938
df = self.spark.read.raster(self.img_uri)

0 commit comments

Comments
 (0)