Skip to content

Commit 108ee65

Browse files
authored
Merge pull request #448 from s22s/feature/extra-examples
Reconstituted two forms of the mini classification example in Scala.
2 parents 43bd3b3 + fe80a85 commit 108ee65

File tree

2 files changed

+314
-0
lines changed

2 files changed

+314
-0
lines changed
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
/*
2+
* This software is licensed under the Apache 2 license, quoted below.
3+
*
4+
* Copyright 2020 Astraea, Inc.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
7+
* use this file except in compliance with the License. You may obtain a copy of
8+
* the License at
9+
*
10+
* [http://www.apache.org/licenses/LICENSE-2.0]
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15+
* License for the specific language governing permissions and limitations under
16+
* the License.
17+
*
18+
* SPDX-License-Identifier: Apache-2.0
19+
*
20+
*/
21+
22+
package examples

import geotrellis.raster._
import geotrellis.raster.io.geotiff.reader.GeoTiffReader
import geotrellis.raster.render.{ColorRamps, IndexedColorMap}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql._
import org.locationtech.rasterframes._
import org.locationtech.rasterframes.ml.{NoDataFilter, TileExploder}

/**
 * Mini supervised land-cover classification example: loads Landsat 8 bands 2-7
 * plus a label raster as RasterFrame layers, trains a decision tree via
 * cross-validation, scores the full scene, and writes the result as a PNG.
 */
object Classification extends App {

  /** Utility for reading imagery from our test data set (bundled on the classpath). */
  def readTiff(name: String) = GeoTiffReader.readSingleband(getClass.getResource(s"/$name").getPath)

  implicit val spark = SparkSession.builder()
    .master("local[*]")
    .appName(getClass.getName)
    .withKryoSerialization
    .getOrCreate()
    .withRasterFrames

  import spark.implicits._

  // The first step is to load multiple bands of imagery and construct
  // a single RasterFrame from them.
  val filenamePattern = "L8-%s-Elkton-VA.tiff"
  val bandNumbers = 2 to 7
  val bandColNames = bandNumbers.map(b => s"band_$b").toArray
  val tileSize = 128

  // For each identified band, load the associated image file, convert it to a
  // layer, and spatially join the layers into a single RasterFrame.
  val joinedRF = bandNumbers
    .map { b => (b, filenamePattern.format("B" + b)) }
    .map { case (b, f) => (b, readTiff(f)) }
    .map { case (b, t) => t.projectedRaster.toLayer(tileSize, tileSize, s"band_$b") }
    .reduce(_ spatialJoin _)
    .withCRS()
    .withExtent()

  // We should see a single spatial_key column along with 6 columns of tiles (bands 2-7).
  joinedRF.printSchema()

  // Similarly pull in the target label data.
  val targetCol = "target"

  // Load the target label raster. We have to convert the cell type to
  // Double to meet expectations of SparkML.
  val target = readTiff(filenamePattern.format("Labels"))
    .mapTile(_.convert(DoubleConstantNoDataCellType))
    .projectedRaster
    .toLayer(tileSize, tileSize, targetCol)

  // Take a peek at what kind of label data we have to work with.
  target.select(rf_agg_stats(target(targetCol))).show

  // The "analysis base table": imagery spatially joined with labels.
  val abt = joinedRF.spatialJoin(target)

  // SparkML requires that each observation be in its own row, and those
  // observations be packed into a single `Vector`. The first step is to
  // "explode" the tiles into a single row per cell/pixel.
  val exploder = new TileExploder()

  // Drop rows where any input band or the label is NoData.
  val noDataFilter = new NoDataFilter()
    .setInputCols(bandColNames :+ targetCol)

  // To "vectorize" the band columns we use the SparkML `VectorAssembler`.
  val assembler = new VectorAssembler()
    .setInputCols(bandColNames)
    .setOutputCol("features")

  // Using a decision tree for classification.
  val classifier = new DecisionTreeClassifier()
    .setLabelCol(targetCol)
    .setFeaturesCol(assembler.getOutputCol)

  // Assemble the model pipeline.
  val pipeline = new Pipeline()
    .setStages(Array(exploder, noDataFilter, assembler, classifier))

  // Configure how we're going to evaluate our model's performance.
  val evaluator = new MulticlassClassificationEvaluator()
    .setLabelCol(targetCol)
    .setPredictionCol("prediction")
    .setMetricName("f1")

  // Use a parameter grid to determine what the optimal max tree depth is for this data.
  val paramGrid = new ParamGridBuilder()
    //.addGrid(classifier.maxDepth, Array(1, 2, 3, 4))
    .build()

  // Configure the cross validator.
  val trainer = new CrossValidator()
    .setEstimator(pipeline)
    .setEvaluator(evaluator)
    .setEstimatorParamMaps(paramGrid)
    .setNumFolds(4)

  // Push the "go" button.
  val model = trainer.fit(abt)

  // Format the `paramGrid` settings of the resultant model.
  val metrics = model.getEstimatorParamMaps
    .map(_.toSeq.map(p => s"${p.param.name} = ${p.value}"))
    .map(_.mkString(", "))
    .zip(model.avgMetrics)

  // Render the parameter/performance association.
  metrics.toSeq.toDF("params", "metric").show(false)

  // Score the original data set, including cells
  // without target values.
  val scored = model.bestModel.transform(joinedRF)

  // Add up class membership results.
  scored.groupBy($"prediction" as "class").count().show

  scored.show(10)

  // Recover the layer metadata so we can re-assemble the exploded, per-pixel
  // predictions back into tiles keyed by CRS/extent.
  val tlm = joinedRF.tileLayerMetadata.left.get

  val retiled: DataFrame = scored.groupBy($"crs", $"extent").agg(
    rf_assemble_tile(
      $"column_index", $"row_index", $"prediction",
      tlm.tileCols, tlm.tileRows, IntConstantNoDataCellType
    )
  )

  val rf: RasterFrameLayer = retiled.toLayer(tlm)

  // NOTE(review): 186 x 169 presumably matches the Elkton scene's pixel
  // dimensions — confirm against the source imagery.
  val raster = rf.toRaster($"prediction", 186, 169)

  // Map the (assumed 3) predicted class indices onto a Viridis color ramp.
  val clusterColors = IndexedColorMap.fromColorMap(
    ColorRamps.Viridis.toColorMap((0 until 3).toArray)
  )

  raster.tile.renderPng(clusterColors).write("classified.png")

  spark.stop()
}
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
/*
2+
* This software is licensed under the Apache 2 license, quoted below.
3+
*
4+
* Copyright 2020 Astraea, Inc.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
7+
* use this file except in compliance with the License. You may obtain a copy of
8+
* the License at
9+
*
10+
* [http://www.apache.org/licenses/LICENSE-2.0]
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
14+
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
15+
* License for the specific language governing permissions and limitations under
16+
* the License.
17+
*
18+
* SPDX-License-Identifier: Apache-2.0
19+
*
20+
*/
21+
22+
package examples

import java.net.URL

import geotrellis.raster._
import geotrellis.raster.io.geotiff.reader.GeoTiffReader
import geotrellis.raster.render.{ColorRamps, IndexedColorMap, Png}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql._
import org.locationtech.rasterframes._
import org.locationtech.rasterframes.datasource.raster._
import org.locationtech.rasterframes.ml.{NoDataFilter, TileExploder}

/**
 * Mini supervised land-cover classification example using the RasterFrames
 * raster DataSource: loads Landsat 8 bands 2-7 plus a label raster via an
 * in-memory CSV catalog, trains a decision tree via cross-validation, scores
 * the scene, and renders the predictions as a PNG.
 */
object Classification extends App {

  /** Utility for locating imagery from our test data set on the classpath. */
  def href(name: String): URL = getClass.getResource(s"/$name")

  implicit val spark = SparkSession.builder()
    .master("local[*]")
    .appName(getClass.getName)
    .withKryoSerialization
    .getOrCreate()
    .withRasterFrames

  import spark.implicits._

  // The first step is to load multiple bands of imagery and construct
  // a single RasterFrame from them.
  val filenamePattern = "L8-%s-Elkton-VA.tiff"
  val bandNumbers = 2 to 7
  val bandColNames = bandNumbers.map(b => s"band_$b").toArray
  val bandSrcs = bandNumbers.map(n => filenamePattern.format("B" + n)).map(href)
  val labelSrc = href(filenamePattern.format("Labels"))
  val tileSize = 128

  // Build an in-memory one-row CSV catalog: a header naming one column per band
  // plus "target", then the corresponding raster URIs. No space after the final
  // comma — a leading space would become part of the label URI.
  val catalog = s"${bandColNames.mkString(",")},target\n${bandSrcs.mkString(",")},$labelSrc"

  // Load all bands plus the target labels in one read via the raster DataSource.
  val abt = spark.read.raster.fromCSV(catalog, bandColNames :+ "target": _*).load()

  // We should see 6 tile columns (bands 2-7) along with the target labels.
  abt.printSchema()

  // Similarly pull in the target label data.
  val targetCol = "target"

  // Take a peek at what kind of label data we have to work with.
  abt.select(rf_agg_stats(abt(targetCol))).show

  // SparkML requires that each observation be in its own row, and those
  // observations be packed into a single `Vector`. The first step is to
  // "explode" the tiles into a single row per cell/pixel.
  val exploder = new TileExploder()

  // Drop rows where any input band or the label is NoData.
  val noDataFilter = new NoDataFilter()
    .setInputCols(bandColNames :+ targetCol)

  // To "vectorize" the band columns we use the SparkML `VectorAssembler`.
  val assembler = new VectorAssembler()
    .setInputCols(bandColNames)
    .setOutputCol("features")

  // Using a decision tree for classification.
  val classifier = new DecisionTreeClassifier()
    .setLabelCol(targetCol)
    .setFeaturesCol(assembler.getOutputCol)

  // Assemble the model pipeline.
  val pipeline = new Pipeline()
    .setStages(Array(exploder, noDataFilter, assembler, classifier))

  // Configure how we're going to evaluate our model's performance.
  val evaluator = new MulticlassClassificationEvaluator()
    .setLabelCol(targetCol)
    .setPredictionCol("prediction")
    .setMetricName("f1")

  // Use a parameter grid to determine what the optimal max tree depth is for this data.
  val paramGrid = new ParamGridBuilder()
    //.addGrid(classifier.maxDepth, Array(1, 2, 3, 4))
    .build()

  // Configure the cross validator.
  val trainer = new CrossValidator()
    .setEstimator(pipeline)
    .setEvaluator(evaluator)
    .setEstimatorParamMaps(paramGrid)
    .setNumFolds(4)

  // Push the "go" button.
  val model = trainer.fit(abt)

  // Format the `paramGrid` settings of the resultant model.
  val metrics = model.getEstimatorParamMaps
    .map(_.toSeq.map(p => s"${p.param.name} = ${p.value}"))
    .map(_.mkString(", "))
    .zip(model.avgMetrics)

  // Render the parameter/performance association.
  metrics.toSeq.toDF("params", "metric").show(false)

  // Score the original data set, including cells
  // without target values.
  val scored = model.bestModel.transform(abt)

  // Add up class membership results.
  scored.groupBy($"prediction" as "class").count().show

  scored.show(10)

  // Re-assemble the exploded, per-pixel predictions back into 186 x 169 tiles
  // keyed by CRS/extent. The assembled column takes the name of the cell-value
  // column, i.e. "prediction".
  val retiled: DataFrame = scored.groupBy($"crs", $"extent").agg(
    rf_assemble_tile(
      $"column_index", $"row_index", $"prediction",
      186, 169, IntConstantNoDataCellType
    )
  )

  // Render the assembled prediction tile (not "target", which is no longer
  // present after the groupBy/agg) as a PNG using the Viridis color ramp.
  val pngBytes = retiled.select(rf_render_png($"prediction", ColorRamps.Viridis)).first

  Png(pngBytes).write("classified.png")

  spark.stop()
}

0 commit comments

Comments
 (0)