Skip to content

Commit 60c4917

Browse files
authored
Merge pull request #337 from s22s/fix/333
Fix for #333 and additional tests in that vein.
2 parents 40d56d2 + b92012e commit 60c4917

File tree

5 files changed

+38
-11
lines changed

5 files changed

+38
-11
lines changed

core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/LocalMeanAggregate.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression
3030
import org.apache.spark.sql.types.DataType
3131
import org.apache.spark.sql.{Column, TypedColumn}
3232
import org.locationtech.rasterframes.TileType
33+
import org.locationtech.rasterframes.expressions.accessors.RealizeTile
3334

3435
@ExpressionDescription(
3536
usage = "_FUNC_(tile) - Computes a new tile contining the mean cell values across all tiles in column.",
@@ -58,11 +59,11 @@ case class LocalMeanAggregate(child: Expression) extends UnaryRasterAggregate {
5859
)
5960
override lazy val updateExpressions: Seq[Expression] = Seq(
6061
If(IsNull(count),
61-
SetCellType(Defined(child), Literal("int32")),
62-
If(IsNull(child), count, BiasedAdd(count, Defined(child)))
62+
SetCellType(RealizeTile(Defined(child)), Literal("int32")),
63+
If(IsNull(child), count, BiasedAdd(count, Defined(RealizeTile(child))))
6364
),
6465
If(IsNull(sum),
65-
SetCellType(child, Literal("float64")),
66+
SetCellType(RealizeTile(child), Literal("float64")),
6667
If(IsNull(child), sum, BiasedAdd(sum, child))
6768
)
6869
)

core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RGBComposite.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ import org.locationtech.rasterframes._
3333
import org.locationtech.rasterframes.encoders.CatalystSerializer._
3434
import org.locationtech.rasterframes.expressions.DynamicExtractors.tileExtractor
3535
import org.locationtech.rasterframes.expressions.row
36-
import org.locationtech.rasterframes.tiles.ProjectedRasterTile
3736

3837
/**
3938
* Expression to combine the given tile columns into an 32-bit RGB composite.

core/src/main/scala/org/locationtech/rasterframes/extensions/DataFrameMethods.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
package org.locationtech.rasterframes.extensions
2323

2424
import geotrellis.proj4.CRS
25-
import geotrellis.raster.{MultibandTile, ProjectedRaster}
2625
import geotrellis.spark.io._
2726
import geotrellis.spark.{SpaceTimeKey, SpatialComponent, SpatialKey, TemporalKey, TileLayerMetadata}
2827
import geotrellis.util.MethodExtensions
@@ -33,9 +32,7 @@ import org.apache.spark.sql.{Column, DataFrame, TypedColumn}
3332
import org.locationtech.rasterframes.StandardColumns._
3433
import org.locationtech.rasterframes.encoders.CatalystSerializer._
3534
import org.locationtech.rasterframes.encoders.StandardEncoders._
36-
import org.locationtech.rasterframes.expressions.{DynamicExtractors, aggregates}
37-
import org.locationtech.rasterframes.expressions.aggregates.TileRasterizerAggregate
38-
import org.locationtech.rasterframes.model.TileDimensions
35+
import org.locationtech.rasterframes.expressions.DynamicExtractors
3936
import org.locationtech.rasterframes.tiles.ProjectedRasterTile
4037
import org.locationtech.rasterframes.util._
4138
import org.locationtech.rasterframes.{MetadataKeys, RasterFrameLayer}

core/src/test/scala/org/locationtech/rasterframes/RasterFunctionsSpec.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -525,11 +525,11 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers {
525525
checkDocs("rf_agg_local_max")
526526
}
527527

528-
ignore("should compute local mean") {
528+
it("should compute local mean") {
529529
checkDocs("rf_agg_local_mean")
530-
// https://github.com/locationtech/rasterframes/issues/333
531530
val df = Seq(two, three, one, six).toDF("tile")
532531
.withColumn("id", monotonically_increasing_id())
532+
533533
df.select(rf_agg_local_mean($"tile")).first() should be(three.toArrayTile())
534534

535535
df.selectExpr("rf_agg_local_mean(tile)").as[Tile].first() should be(three.toArrayTile())
@@ -539,7 +539,6 @@ class RasterFunctionsSpec extends TestEnvironment with RasterMatchers {
539539
.agg(rf_agg_local_mean($"tile"))
540540
.collect()
541541
}
542-
543542
}
544543

545544
it("should compute local data cell counts") {

core/src/test/scala/org/locationtech/rasterframes/TileStatsSpec.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ package org.locationtech.rasterframes
2424
import geotrellis.raster._
2525
import geotrellis.raster.mapalgebra.local.{Max, Min}
2626
import geotrellis.spark._
27+
import org.apache.spark.sql.Column
2728
import org.apache.spark.sql.functions._
2829
import org.locationtech.rasterframes.TestData.randomTile
2930
import org.locationtech.rasterframes.stats.CellHistogram
@@ -317,4 +318,34 @@ class TileStatsSpec extends TestEnvironment with TestData {
317318
ndCount2 should be(count + 1)
318319
}
319320
}
321+
322+
describe("proj_raster handling") {
323+
it("should handle proj_raster structures") {
324+
val df = Seq(lazyPRT, lazyPRT).toDF("tile")
325+
326+
val targets = Seq[Column => Column](
327+
rf_is_no_data_tile,
328+
rf_data_cells,
329+
rf_no_data_cells,
330+
rf_agg_local_max,
331+
rf_agg_local_min,
332+
rf_agg_local_mean,
333+
rf_agg_local_data_cells,
334+
rf_agg_local_no_data_cells,
335+
rf_agg_local_stats,
336+
rf_agg_approx_histogram,
337+
rf_tile_histogram,
338+
rf_tile_stats,
339+
rf_tile_mean,
340+
rf_tile_max,
341+
rf_tile_min
342+
)
343+
344+
forEvery(targets) { f =>
345+
noException shouldBe thrownBy {
346+
df.select(f($"tile")).collect()
347+
}
348+
}
349+
}
350+
}
320351
}

0 commit comments

Comments
 (0)