2121
2222package examples
2323
24- import java .net .URL
25-
2624import geotrellis .raster ._
27- import geotrellis .raster .io .geotiff .reader .GeoTiffReader
28- import geotrellis .raster .render .{ColorRamps , IndexedColorMap , Png }
25+ import geotrellis .raster .render .{ColorRamp , ColorRamps , Png }
2926import org .apache .spark .ml .Pipeline
3027import org .apache .spark .ml .classification .DecisionTreeClassifier
3128import org .apache .spark .ml .evaluation .MulticlassClassificationEvaluator
3229import org .apache .spark .ml .feature .VectorAssembler
33- import org .apache .spark .ml .tuning .{CrossValidator , ParamGridBuilder }
3430import org .apache .spark .sql ._
3531import org .locationtech .rasterframes ._
3632import org .locationtech .rasterframes .datasource .raster ._
3733import org .locationtech .rasterframes .ml .{NoDataFilter , TileExploder }
3834
39- object Classification extends App {
35+
36+ object ClassificationRasterSource extends App {
4037
4138 // // Utility for reading imagery from our test data set
42- def href (name : String ): URL = getClass.getResource( s " / $name " )
39+ def href (name : String ) = " https://raw.githubusercontent.com/locationtech/rasterframes/develop/core/src/test/resources/ " + name
4340
4441 implicit val spark = SparkSession .builder()
4542 .master(" local[*]" )
@@ -61,8 +58,11 @@ object Classification extends App {
6158
6259 val catalog = s " ${bandColNames.mkString(" ," )},target \n ${bandSrcs.mkString(" ," )}, $labelSrc"
6360
61+
6462 // For each identified band, load the associated image file
6563 val abt = spark.read.raster.fromCSV(catalog, bandColNames :+ " target" : _* ).load()
64+ .withColumn(" crs" , rf_crs($" band_4" ))
65+ .withColumn(" extent" , rf_extent($" band_4" ))
6666
6767 // We should see a single spatial_key column along with 4 columns of tiles.
6868 abt.printSchema()
@@ -101,33 +101,11 @@ object Classification extends App {
101101 .setPredictionCol(" prediction" )
102102 .setMetricName(" f1" )
103103
104- // Use a parameter grid to determine what the optimal max tree depth is for this data
105- val paramGrid = new ParamGridBuilder ()
106- // .addGrid(classifier.maxDepth, Array(1, 2, 3, 4))
107- .build()
108-
109- // Configure the cross validator
110- val trainer = new CrossValidator ()
111- .setEstimator(pipeline)
112- .setEvaluator(evaluator)
113- .setEstimatorParamMaps(paramGrid)
114- .setNumFolds(4 )
115-
116- // Push the "go" button
117- val model = trainer.fit(abt)
118-
119- // Format the `paramGrid` settings resultant model
120- val metrics = model.getEstimatorParamMaps
121- .map(_.toSeq.map(p ⇒ s " ${p.param.name} = ${p.value}" ))
122- .map(_.mkString(" , " ))
123- .zip(model.avgMetrics)
124-
125- // Render the parameter/performance association
126- metrics.toSeq.toDF(" params" , " metric" ).show(false )
104+ val model = pipeline.fit(abt)
127105
128106 // Score the original data set, including cells
129107 // without target values.
130- val scored = model.bestModel. transform(abt)
108+ val scored = model.transform(abt.drop( " target " ) )
131109
132110 // Add up class membership results
133111 scored.groupBy($" prediction" as " class" ).count().show
@@ -141,7 +119,11 @@ object Classification extends App {
141119 )
142120 )
143121
144- val pngBytes = retiled.select(rf_render_png($" target" , ColorRamps .Viridis )).first
122+ val clusterColors = ColorRamp (
123+ ColorRamps .Viridis .toColorMap((0 until 3 ).toArray).colors
124+ )
125+
126+ val pngBytes = retiled.select(rf_render_png($" prediction" , clusterColors)).first
145127
146128 Png (pngBytes).write(" classified.png" )
147129
0 commit comments