Skip to content

Commit 1730af9

Browse files
authored
Merge pull request #429 from s22s/feature/tile-quantile
Add rf_agg_approx_quantiles function
2 parents 73a52e6 + d73e255 commit 1730af9

File tree

10 files changed

+259
-8
lines changed

10 files changed

+259
-8
lines changed

core/src/main/scala/org/locationtech/rasterframes/encoders/StandardSerializers.scala

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,21 @@
2121

2222
package org.locationtech.rasterframes.encoders
2323

24+
import java.nio.ByteBuffer
25+
2426
import com.github.blemale.scaffeine.Scaffeine
2527
import geotrellis.proj4.CRS
2628
import geotrellis.raster._
2729
import geotrellis.spark._
2830
import geotrellis.spark.tiling.LayoutDefinition
2931
import geotrellis.vector._
32+
import org.apache.spark.sql.catalyst.util.QuantileSummaries
3033
import org.apache.spark.sql.types._
3134
import org.locationtech.jts.geom.Envelope
3235
import org.locationtech.rasterframes.TileType
3336
import org.locationtech.rasterframes.encoders.CatalystSerializer.{CatalystIO, _}
3437
import org.locationtech.rasterframes.model.LazyCRS
38+
import org.locationtech.rasterframes.util.KryoSupport
3539

3640
/** Collection of CatalystSerializers for third-party types. */
3741
trait StandardSerializers {
@@ -294,9 +298,23 @@ trait StandardSerializers {
294298
implicit val spatialKeyTLMSerializer = tileLayerMetadataSerializer[SpatialKey]
295299
implicit val spaceTimeKeyTLMSerializer = tileLayerMetadataSerializer[SpaceTimeKey]
296300

301+
implicit val quantileSerializer: CatalystSerializer[QuantileSummaries] = new CatalystSerializer[QuantileSummaries] {
302+
override val schema: StructType = StructType(Seq(
303+
StructField("quantile_serializer_kryo", BinaryType, false)
304+
))
305+
306+
override protected def to[R](t: QuantileSummaries, io: CatalystSerializer.CatalystIO[R]): R = {
307+
val buf = KryoSupport.serialize(t)
308+
io.create(buf.array())
309+
}
310+
311+
override protected def from[R](t: R, io: CatalystSerializer.CatalystIO[R]): QuantileSummaries = {
312+
KryoSupport.deserialize[QuantileSummaries](ByteBuffer.wrap(io.getByteArray(t, 0)))
313+
}
314+
}
297315
}
298316

299-
object StandardSerializers {
317+
object StandardSerializers extends StandardSerializers {
300318
private val s2ctCache = Scaffeine().build[String, CellType](
301319
(s: String) => CellType.fromName(s)
302320
)
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/*
2+
* This software is licensed under the Apache 2 license, quoted below.
3+
*
4+
* Copyright 2019 Astraea, Inc.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
7+
* use this file except in compliance with the License. You may obtain a copy of
8+
* the License at
9+
*
10+
* [http://www.apache.org/licenses/LICENSE-2.0]
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15+
* License for the specific language governing permissions and limitations under
16+
* the License.
17+
*
18+
* SPDX-License-Identifier: Apache-2.0
19+
*
20+
*/
21+
22+
package org.locationtech.rasterframes.expressions.aggregates
23+
24+
import geotrellis.raster.{Tile, isNoData}
25+
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
26+
import org.apache.spark.sql.catalyst.util.QuantileSummaries
27+
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
28+
import org.apache.spark.sql.{Column, Encoder, Row, TypedColumn, types}
29+
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
30+
import org.locationtech.rasterframes.TileType
31+
import org.locationtech.rasterframes.encoders.CatalystSerializer._
32+
import org.locationtech.rasterframes.expressions.accessors.ExtractTile
33+
34+
35+
case class ApproxCellQuantilesAggregate(probabilities: Seq[Double], relativeError: Double) extends UserDefinedAggregateFunction {
36+
import org.locationtech.rasterframes.encoders.StandardSerializers.quantileSerializer
37+
38+
override def inputSchema: StructType = StructType(Seq(
39+
StructField("value", TileType, true)
40+
))
41+
42+
override def bufferSchema: StructType = StructType(Seq(
43+
StructField("buffer", schemaOf[QuantileSummaries], false)
44+
))
45+
46+
override def dataType: types.DataType = DataTypes.createArrayType(DataTypes.DoubleType)
47+
48+
override def deterministic: Boolean = true
49+
50+
override def initialize(buffer: MutableAggregationBuffer): Unit =
51+
buffer.update(0, new QuantileSummaries(QuantileSummaries.defaultCompressThreshold, relativeError).toRow)
52+
53+
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
54+
val qs = buffer.getStruct(0).to[QuantileSummaries]
55+
if (!input.isNullAt(0)) {
56+
val tile = input.getAs[Tile](0)
57+
var result = qs
58+
tile.foreachDouble(d => if (!isNoData(d)) result = result.insert(d))
59+
buffer.update(0, result.toRow)
60+
}
61+
else buffer
62+
}
63+
64+
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
65+
val left = buffer1.getStruct(0).to[QuantileSummaries]
66+
val right = buffer2.getStruct(0).to[QuantileSummaries]
67+
val merged = left.compress().merge(right.compress())
68+
buffer1.update(0, merged.toRow)
69+
}
70+
71+
override def evaluate(buffer: Row): Seq[Double] = {
72+
val summaries = buffer.getStruct(0).to[QuantileSummaries]
73+
probabilities.flatMap(summaries.query)
74+
}
75+
}
76+
77+
object ApproxCellQuantilesAggregate {
78+
private implicit def doubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder()
79+
80+
def apply(
81+
tile: Column,
82+
probabilities: Seq[Double],
83+
relativeError: Double = 0.00001): TypedColumn[Any, Seq[Double]] = {
84+
new ApproxCellQuantilesAggregate(probabilities, relativeError)(ExtractTile(tile))
85+
.as(s"rf_agg_approx_quantiles")
86+
.as[Seq[Double]]
87+
}
88+
}

core/src/main/scala/org/locationtech/rasterframes/expressions/aggregates/HistogramAggregate.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,10 @@ object HistogramAggregate {
9898
import org.locationtech.rasterframes.encoders.StandardEncoders.cellHistEncoder
9999

100100
def apply(col: Column): TypedColumn[Any, CellHistogram] =
101-
new HistogramAggregate()(ExtractTile(col))
101+
apply(col, StreamingHistogram.DEFAULT_NUM_BUCKETS)
102+
103+
def apply(col: Column, numBuckets: Int): TypedColumn[Any, CellHistogram] =
104+
new HistogramAggregate(numBuckets)(ExtractTile(col))
102105
.as(s"rf_agg_approx_histogram($col)")
103106
.as[CellHistogram]
104107

core/src/main/scala/org/locationtech/rasterframes/functions/AggregateFunctions.scala

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,32 @@ trait AggregateFunctions {
5151
/** Compute the cellwise/local count of NoData cells for all Tiles in a column. */
5252
def rf_agg_local_no_data_cells(tile: Column): TypedColumn[Any, Tile] = LocalCountAggregate.LocalNoDataCellsUDAF(tile)
5353

54-
/** Compute the full column aggregate floating point histogram. */
54+
/** Compute the approximate aggregate floating point histogram using a streaming algorithm, with the default of 80 buckets. */
5555
def rf_agg_approx_histogram(tile: Column): TypedColumn[Any, CellHistogram] = HistogramAggregate(tile)
5656

57+
/** Compute the approximate aggregate floating point histogram using a streaming algorithm, with the given number of buckets. */
58+
def rf_agg_approx_histogram(col: Column, numBuckets: Int): TypedColumn[Any, CellHistogram] = {
59+
require(numBuckets > 0, "Must provide a positive number of buckets")
60+
HistogramAggregate(col, numBuckets)
61+
}
62+
63+
/**
64+
* Calculates the approximate quantiles of a tile column of a DataFrame.
65+
* @param tile tile column to extract cells from.
66+
* @param probabilities a list of quantile probabilities
67+
* Each number must belong to [0, 1].
68+
* For example 0 is the minimum, 0.5 is the median, 1 is the maximum.
69+
* @param relativeError The relative target precision to achieve (greater than or equal to 0).
70+
* @return the approximate quantiles at the given probabilities of each column
71+
*/
72+
def rf_agg_approx_quantiles(
73+
tile: Column,
74+
probabilities: Seq[Double],
75+
relativeError: Double = 0.00001): TypedColumn[Any, Seq[Double]] = {
76+
require(probabilities.nonEmpty, "at least one quantile probability is required")
77+
ApproxCellQuantilesAggregate(tile, probabilities, relativeError)
78+
}
79+
5780
/** Compute the full column aggregate floating point statistics. */
5881
def rf_agg_stats(tile: Column): TypedColumn[Any, CellStatistics] = CellStatsAggregate(tile)
5982

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/*
2+
* This software is licensed under the Apache 2 license, quoted below.
3+
*
4+
* Copyright 2018 Astraea, Inc.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
7+
* use this file except in compliance with the License. You may obtain a copy of
8+
* the License at
9+
*
10+
* [http://www.apache.org/licenses/LICENSE-2.0]
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15+
* License for the specific language governing permissions and limitations under
16+
* the License.
17+
*
18+
* SPDX-License-Identifier: Apache-2.0
19+
*
20+
*/
21+
22+
package org.locationtech.rasterframes
23+
24+
import org.locationtech.rasterframes.RasterFunctions
25+
import org.apache.spark.sql.functions.{col, explode}
26+
27+
class RasterFramesStatsSpec extends TestEnvironment with TestData {
28+
29+
import spark.implicits._
30+
31+
val df = TestData.sampleGeoTiff
32+
.toDF()
33+
.withColumn("tilePlus2", rf_local_add(col("tile"), 2))
34+
35+
36+
describe("Tile quantiles through built-in functions") {
37+
38+
it("should compute approx percentiles for a single tile col") {
39+
// Use "explode"
40+
val result = df
41+
.select(rf_explode_tiles($"tile"))
42+
.stat
43+
.approxQuantile("tile", Array(0.10, 0.50, 0.90), 0.00001)
44+
45+
result.length should be(3)
46+
47+
// computing externally with numpy we arrive at 7963, 10068, 12160 for these quantiles
48+
result should contain inOrderOnly(7963.0, 10068.0, 12160.0)
49+
50+
// Use "to_array" and built-in explode
51+
val result2 = df
52+
.select(explode(rf_tile_to_array_double($"tile")) as "tile")
53+
.stat
54+
.approxQuantile("tile", Array(0.10, 0.50, 0.90), 0.00001)
55+
56+
result2.length should be(3)
57+
58+
// computing externally with numpy we arrive at 7963, 10068, 12160 for these quantiles
59+
result2 should contain inOrderOnly(7963.0, 10068.0, 12160.0)
60+
61+
}
62+
}
63+
64+
describe("Tile quantiles through custom aggregate") {
65+
it("should compute approx percentiles for a single tile col") {
66+
val result = df
67+
.select(rf_agg_approx_quantiles($"tile", Seq(0.1, 0.5, 0.9)))
68+
.first()
69+
70+
result.length should be(3)
71+
72+
// computing externally with numpy we arrive at 7963, 10068, 12160 for these quantiles
73+
result should contain inOrderOnly(7963.0, 10068.0, 12160.0)
74+
}
75+
76+
}
77+
}
78+

docs/src/main/paradox/reference.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,14 @@ Aggregates over the `tile` and returns statistical summaries of cell values: num
634634

635635
Aggregates over all of the rows in DataFrame of `tile` and returns a count of each cell value to create a histogram with values are plotted on the x-axis and counts on the y-axis. Related is the @ref:[`rf_tile_histogram`](reference.md#rf-tile-histogram) function which operates on a single row at a time.
636636

637+
### rf_agg_approx_quantiles
638+
639+
Array[Double] rf_agg_approx_quantiles(Tile tile, List[float] probabilities, float relative_error)
640+
641+
__Not supported in SQL.__
642+
643+
Calculates the approximate quantiles of a tile column of a DataFrame. `probabilities` is a list of float values at which to compute the quantiles. These must belong to [0, 1]. For example 0 is the minimum, 0.5 is the median, 1 is the maximum. Returns an array of values approximately at the specified `probabilities`.
644+
637645
### rf_agg_extent
638646

639647
Extent rf_agg_extent(Extent extent)

docs/src/main/paradox/release-notes.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
* Added `rf_render_color_ramp_png` to compute PNG byte array for a single tile column, with specified color ramp.
1515
* In `rf_ipython`, improved rendering of dataframe binary contents with PNG preamble.
1616
* Throw an `IllegalArgumentException` when attempting to apply a mask to a `Tile` whose `CellType` has no NoData defined. ([#409](https://github.com/locationtech/rasterframes/issues/384))
17+
* Add `rf_agg_approx-quantiles` function to compute cell quantiles across an entire column.
1718

1819
### 0.8.4
1920

pyrasterframes/src/main/python/pyrasterframes/rasterfunctions.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,22 @@ def rf_agg_approx_histogram(tile_col):
313313
return _apply_column_function('rf_agg_approx_histogram', tile_col)
314314

315315

316+
def rf_agg_approx_quantiles(tile_col, probabilities, relative_error=0.00001):
317+
"""
318+
Calculates the approximate quantiles of a tile column of a DataFrame.
319+
320+
:param tile_col: column to extract cells from.
321+
:param probabilities: a list of quantile probabilities. Each number must belong to [0, 1].
322+
For example 0 is the minimum, 0.5 is the median, 1 is the maximum.
323+
:param relative_error: The relative target precision to achieve (greater than or equal to 0). Default is 0.00001
324+
:return: An array of values approximately at the specified `probabilities`
325+
"""
326+
327+
_jfn = RFContext.active().lookup('rf_agg_approx_quantiles')
328+
_tile_col = _to_java_column(tile_col)
329+
return Column(_jfn(_tile_col, probabilities, relative_error))
330+
331+
316332
def rf_agg_stats(tile_col):
317333
"""Compute the full column aggregate floating point statistics"""
318334
return _apply_column_function('rf_agg_stats', tile_col)

pyrasterframes/src/main/python/tests/RasterFunctionsTests.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,25 @@
2020

2121
from unittest import skip
2222

23-
import numpy as np
24-
import sys
25-
from numpy.testing import assert_equal
26-
from pyspark import Row
27-
from pyspark.sql.functions import *
2823

2924
import pyrasterframes
3025
from pyrasterframes.rasterfunctions import *
3126
from pyrasterframes.rf_types import *
3227
from pyrasterframes.utils import gdal_version
28+
from pyspark import Row
29+
from pyspark.sql.functions import *
30+
31+
import numpy as np
32+
from numpy.testing import assert_equal, assert_allclose
33+
34+
from unittest import skip
3335
from . import TestEnvironment
3436

3537

3638
class RasterFunctions(TestEnvironment):
3739

3840
def setUp(self):
41+
import sys
3942
if not sys.warnoptions:
4043
import warnings
4144
warnings.simplefilter("ignore")
@@ -138,6 +141,12 @@ def test_aggregations(self):
138141
self.assertEqual(row['rf_agg_no_data_cells(tile)'], 1000)
139142
self.assertEqual(row['rf_agg_stats(tile)'].data_cells, row['rf_agg_data_cells(tile)'])
140143

144+
def test_agg_approx_quantiles(self):
145+
agg = self.rf.agg(rf_agg_approx_quantiles('tile', [0.1, 0.5, 0.9, 0.98]))
146+
result = agg.first()[0]
147+
# expected result from computing in external python process; c.f. scala tests
148+
assert_allclose(result, np.array([7963., 10068., 12160., 14366.]))
149+
141150
def test_sql(self):
142151

143152
self.rf.createOrReplaceTempView("rf_test_sql")

pyrasterframes/src/main/scala/org/locationtech/rasterframes/py/PyRFContext.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,13 @@ class PyRFContext(implicit sparkSession: SparkSession) extends RasterFunctions
191191

192192
def rf_local_unequal_int(col: Column, scalar: Int): Column = rf_local_unequal[Int](col, scalar)
193193

194+
// other function support
195+
/** py4j friendly version of this function */
196+
def rf_agg_approx_quantiles(tile: Column, probabilities: java.util.List[Double], relativeError: Double): TypedColumn[Any, Seq[Double]] = {
197+
import scala.collection.JavaConverters._
198+
rf_agg_approx_quantiles(tile, probabilities.asScala, relativeError)
199+
}
200+
194201
def _make_crs_literal(crsText: String): Column = {
195202
rasterframes.encoders.serialized_literal[CRS](LazyCRS(crsText))
196203
}

0 commit comments

Comments
 (0)