Skip to content

Commit b671207

Browse files
committed
Switched ExplodeTiles to use UnsafeRow for a slight reduction in memory pressure.
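Reworked TileExplodeBench to call ExplodeTiles.eval directly on pre-serialized InternalRows instead of running a DataFrame query.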
1 parent c19ad68 commit b671207

File tree

4 files changed

+27
-679
lines changed

4 files changed

+27
-679
lines changed

bench/src/main/scala/org/locationtech/rasterframes/bench/TileExplodeBench.scala

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,12 @@ package org.locationtech.rasterframes.bench
2222

2323
import java.util.concurrent.TimeUnit
2424

25+
import org.apache.spark.sql.catalyst.InternalRow
26+
import org.apache.spark.sql.catalyst.expressions.BoundReference
27+
import org.apache.spark.sql.rf.TileUDT
2528
import org.locationtech.rasterframes._
26-
import org.apache.spark.sql._
27-
import org.apache.spark.sql.functions._
29+
import org.locationtech.rasterframes.expressions.generators.ExplodeTiles
2830
import org.openjdk.jmh.annotations._
29-
3031
/**
3132
*
3233
* @author sfitch
@@ -36,33 +37,32 @@ import org.openjdk.jmh.annotations._
3637
@State(Scope.Benchmark)
3738
@OutputTimeUnit(TimeUnit.MILLISECONDS)
3839
class TileExplodeBench extends SparkEnv {
39-
import spark.implicits._
4040

41-
@Param(Array("uint8", "uint16ud255", "float32", "float64"))
41+
//@Param(Array("uint8", "uint16ud255", "float32", "float64"))
42+
@Param(Array("uint16ud255"))
4243
var cellTypeName: String = _
4344

4445
@Param(Array("256"))
4546
var tileSize: Int = _
4647

47-
@Param(Array("100"))
48+
@Param(Array("2000"))
4849
var numTiles: Int = _
4950

5051
@transient
51-
var tiles: DataFrame = _
52+
var tiles: Array[InternalRow] = _
53+
54+
var exploder: ExplodeTiles = _
5255

5356
@Setup(Level.Trial)
5457
def setupData(): Unit = {
55-
tiles = Seq.fill(numTiles)(randomTile(tileSize, tileSize, cellTypeName))
56-
.toDF("tile").repartition(10)
57-
}
58-
59-
@Benchmark
60-
def arrayExplode() = {
61-
tiles.select(posexplode(rf_tile_to_array_double($"tile"))).count()
58+
tiles = Array.fill(numTiles)(randomTile(tileSize, tileSize, cellTypeName))
59+
.map(t => InternalRow(TileUDT.tileSerializer.toInternalRow(t)))
60+
val expr = BoundReference(0, TileType, true)
61+
exploder = new ExplodeTiles(1.0, None, Seq(expr))
6262
}
63-
6463
@Benchmark
6564
def tileExplode() = {
66-
tiles.select(rf_explode_tiles($"tile")).count()
65+
for(t <- tiles)
66+
exploder.eval(t)
6767
}
6868
}

core/src/main/scala/org/locationtech/rasterframes/expressions/generators/ExplodeTiles.scala

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ package org.locationtech.rasterframes.expressions.generators
2424
import geotrellis.raster._
2525
import org.apache.spark.sql._
2626
import org.apache.spark.sql.catalyst.InternalRow
27-
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
28-
import org.apache.spark.sql.catalyst.expressions.{Expression, Generator, GenericInternalRow}
27+
import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, CodegenFallback, UnsafeRowWriter}
28+
import org.apache.spark.sql.catalyst.expressions.{Expression, Generator, GenericInternalRow, UnsafeRow}
2929
import org.apache.spark.sql.types._
3030
import org.locationtech.rasterframes._
3131
import org.locationtech.rasterframes.expressions.DynamicExtractors
@@ -87,14 +87,17 @@ case class ExplodeTiles(
8787
cfor(0)(_ < rows, _ + 1) { row =>
8888
cfor(0)(_ < cols, _ + 1) { col =>
8989
val rowIndex = row * cols + col
90-
val outCols = Array.ofDim[Any](numOutCols)
91-
outCols(0) = col
92-
outCols(1) = row
90+
val outRow = new UnsafeRow(numOutCols)
91+
val buffer = new BufferHolder(outRow)
92+
val writer = new UnsafeRowWriter(buffer, numOutCols)
93+
writer.write(0, col)
94+
writer.write(1, row)
9395
cfor(0)(_ < tiles.length, _ + 1) { index =>
9496
val tile = tiles(index)
95-
outCols(index + 2) = if(tile == null) doubleNODATA else tile.getDouble(col, row)
97+
val cell: Double = if (tile == null) doubleNODATA else tile.getDouble(col, row)
98+
writer.write(index + 2, cell)
9699
}
97-
retval(rowIndex) = new GenericInternalRow(outCols)
100+
retval(rowIndex) = outRow
98101
}
99102
}
100103
if(sampleFraction > 0.0 && sampleFraction < 1.0) sample(retval)

project/plugins.sbt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ addSbtPlugin("com.typesafe.sbt" % "sbt-ghpages" % "0.6.2")
88
addSbtPlugin("com.typesafe.sbt" % "sbt-site" % "1.3.2")
99
addSbtPlugin("com.lightbend.paradox" % "sbt-paradox" % "0.5.5")
1010
addSbtPlugin("io.github.jonas" % "sbt-paradox-material-theme" % "0.6.0")
11-
addSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.3.6")
11+
addSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.3.3")
1212
addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "2.1")
1313
addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.1.1")
1414
addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.4.1")

0 commit comments

Comments
 (0)