Skip to content

Commit 73a52e6

Browse files
committed
Fixes to scala classification model example
1 parent 108ee65 commit 73a52e6

File tree

1 file changed

+14
-32
lines changed

1 file changed

+14
-32
lines changed

datasource/src/test/scala/examples/Classification.scala renamed to datasource/src/test/scala/examples/ClassificationRasterSource.scala

Lines changed: 14 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -21,25 +21,22 @@
2121

2222
package examples
2323

24-
import java.net.URL
25-
2624
import geotrellis.raster._
27-
import geotrellis.raster.io.geotiff.reader.GeoTiffReader
28-
import geotrellis.raster.render.{ColorRamps, IndexedColorMap, Png}
25+
import geotrellis.raster.render.{ColorRamp, ColorRamps, Png}
2926
import org.apache.spark.ml.Pipeline
3027
import org.apache.spark.ml.classification.DecisionTreeClassifier
3128
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
3229
import org.apache.spark.ml.feature.VectorAssembler
33-
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
3430
import org.apache.spark.sql._
3531
import org.locationtech.rasterframes._
3632
import org.locationtech.rasterframes.datasource.raster._
3733
import org.locationtech.rasterframes.ml.{NoDataFilter, TileExploder}
3834

39-
object Classification extends App {
35+
36+
object ClassificationRasterSource extends App {
4037

4138
// Utility for reading imagery from our test data set
42-
def href(name: String): URL = getClass.getResource(s"/$name")
39+
def href(name: String) = "https://raw.githubusercontent.com/locationtech/rasterframes/develop/core/src/test/resources/" + name
4340

4441
implicit val spark = SparkSession.builder()
4542
.master("local[*]")
@@ -61,8 +58,11 @@ object Classification extends App {
6158

6259
val catalog = s"${bandColNames.mkString(",")},target\n${bandSrcs.mkString(",")}, $labelSrc"
6360

61+
6462
// For each identified band, load the associated image file
6563
val abt = spark.read.raster.fromCSV(catalog, bandColNames :+ "target": _*).load()
64+
.withColumn("crs", rf_crs($"band_4"))
65+
.withColumn("extent", rf_extent($"band_4"))
6666

6767
// We should see a single spatial_key column along with 4 columns of tiles.
6868
abt.printSchema()
@@ -101,33 +101,11 @@ object Classification extends App {
101101
.setPredictionCol("prediction")
102102
.setMetricName("f1")
103103

104-
// Use a parameter grid to determine what the optimal max tree depth is for this data
105-
val paramGrid = new ParamGridBuilder()
106-
//.addGrid(classifier.maxDepth, Array(1, 2, 3, 4))
107-
.build()
108-
109-
// Configure the cross validator
110-
val trainer = new CrossValidator()
111-
.setEstimator(pipeline)
112-
.setEvaluator(evaluator)
113-
.setEstimatorParamMaps(paramGrid)
114-
.setNumFolds(4)
115-
116-
// Push the "go" button
117-
val model = trainer.fit(abt)
118-
119-
// Format the `paramGrid` settings of the resultant model
120-
val metrics = model.getEstimatorParamMaps
121-
.map(_.toSeq.map(p => s"${p.param.name} = ${p.value}"))
122-
.map(_.mkString(", "))
123-
.zip(model.avgMetrics)
124-
125-
// Render the parameter/performance association
126-
metrics.toSeq.toDF("params", "metric").show(false)
104+
val model = pipeline.fit(abt)
127105

128106
// Score the original data set, including cells
129107
// without target values.
130-
val scored = model.bestModel.transform(abt)
108+
val scored = model.transform(abt.drop("target"))
131109

132110
// Add up class membership results
133111
scored.groupBy($"prediction" as "class").count().show
@@ -141,7 +119,11 @@ object Classification extends App {
141119
)
142120
)
143121

144-
val pngBytes = retiled.select(rf_render_png($"target", ColorRamps.Viridis)).first
122+
val clusterColors = ColorRamp(
123+
ColorRamps.Viridis.toColorMap((0 until 3).toArray).colors
124+
)
125+
126+
val pngBytes = retiled.select(rf_render_png($"prediction", clusterColors)).first
145127

146128
Png(pngBytes).write("classified.png")
147129

0 commit comments

Comments
 (0)