/*
 * This software is licensed under the Apache 2 license, quoted below.
 *
 * Copyright 2020 Astraea, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 *     [http://www.apache.org/licenses/LICENSE-2.0]
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 *
 */

package examples

import org.locationtech.rasterframes._
import geotrellis.raster._
import geotrellis.raster.io.geotiff.reader.GeoTiffReader
import geotrellis.raster.render.{ColorRamps, IndexedColorMap}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql._
import org.locationtech.rasterframes.ml.{NoDataFilter, TileExploder}

object Classification extends App {

  // Utility for reading imagery from our test data set
  def readTiff(name: String) = GeoTiffReader.readSingleband(getClass.getResource(s"/$name").getPath)

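  // Start a local Spark session with Kryo serialization and RasterFrames support enabled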
  implicit val spark = SparkSession.builder()
    .master("local[*]")
    .appName(getClass.getName)
    .withKryoSerialization
    .getOrCreate()
    .withRasterFrames

  import spark.implicits._

  // The first step is to load multiple bands of imagery and construct
  // a single RasterFrame from them.
  val filenamePattern = "L8-%s-Elkton-VA.tiff"
  val bandNumbers = 2 to 7
  val bandColNames = bandNumbers.map(b ⇒ s"band_$b").toArray
  val tileSize = 128

  // For each identified band, load the associated image file
  val joinedRF = bandNumbers
    .map { b ⇒ (b, filenamePattern.format("B" + b)) }
    .map { case (b, f) ⇒ (b, readTiff(f)) }
    .map { case (b, t) ⇒ t.projectedRaster.toLayer(tileSize, tileSize, s"band_$b") }
    .reduce(_ spatialJoin _)
    .withCRS()
    .withExtent()

  // We should see a single spatial_key column along with six columns of band tiles
  // (plus the crs and extent columns).
  joinedRF.printSchema()

  // Similarly pull in the target label data.
  val targetCol = "target"

  // Load the target label raster. We have to convert the cell type to
  // Double to meet expectations of SparkML
  val target = readTiff(filenamePattern.format("Labels"))
    .mapTile(_.convert(DoubleConstantNoDataCellType))
    .projectedRaster
    .toLayer(tileSize, tileSize, targetCol)

  // Take a peek at what kind of label data we have to work with.
  target.select(rf_agg_stats(target(targetCol))).show

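  // Spatially join the band tiles with the label tiles to form the table we will train on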
  val abt = joinedRF.spatialJoin(target)

  // SparkML requires that each observation be in its own row, and those
  // observations be packed into a single `Vector`. The first step is to
  // "explode" the tiles into a single row per cell/pixel
  val exploder = new TileExploder()

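  // Drop any rows where the band or target cells contain NoData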
  val noDataFilter = new NoDataFilter()
    .setInputCols(bandColNames :+ targetCol)

  // To "vectorize" the band columns we use the SparkML `VectorAssembler`
  val assembler = new VectorAssembler()
    .setInputCols(bandColNames)
    .setOutputCol("features")

  // Using a decision tree for classification
  val classifier = new DecisionTreeClassifier()
    .setLabelCol(targetCol)
    .setFeaturesCol(assembler.getOutputCol)

  // Assemble the model pipeline
  val pipeline = new Pipeline()
    .setStages(Array(exploder, noDataFilter, assembler, classifier))

  // Configure how we're going to evaluate our model's performance.
  val evaluator = new MulticlassClassificationEvaluator()
    .setLabelCol(targetCol)
    .setPredictionCol("prediction")
    .setMetricName("f1")

  // Use a parameter grid to determine what the optimal max tree depth is for this data
  val paramGrid = new ParamGridBuilder()
    // .addGrid(classifier.maxDepth, Array(1, 2, 3, 4))
    .build()

  // Configure the cross validator
  val trainer = new CrossValidator()
    .setEstimator(pipeline)
    .setEvaluator(evaluator)
    .setEstimatorParamMaps(paramGrid)
    .setNumFolds(4)

  // Push the "go" button
  val model = trainer.fit(abt)

  // Pair each `paramGrid` setting with the cross-validation metric of its resulting model
  val metrics = model.getEstimatorParamMaps
    .map(_.toSeq.map(p ⇒ s"${p.param.name} = ${p.value}"))
    .map(_.mkString(", "))
    .zip(model.avgMetrics)

  // Render the parameter/performance association
  metrics.toSeq.toDF("params", "metric").show(false)

  // Score the original data set, including cells
  // without target values.
  val scored = model.bestModel.transform(joinedRF)

  // Add up class membership results
  scored.groupBy($"prediction" as "class").count().show

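  // Take a peek at a few of the scored rows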
  scored.show(10)

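  // Pull out the layer metadata so the per-cell predictions can be reassembled into tiles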
  val tlm = joinedRF.tileLayerMetadata.left.get

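  // Group the exploded cells by CRS and extent and reassemble them into prediction tiles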
  val retiled: DataFrame = scored.groupBy($"crs", $"extent").agg(
    rf_assemble_tile(
      $"column_index", $"row_index", $"prediction",
      tlm.tileCols, tlm.tileRows, IntConstantNoDataCellType
    )
  )

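  // Convert the reassembled tiles back into a RasterFrameLayer using the original layer metadata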
  val rf: RasterFrameLayer = retiled.toLayer(tlm)

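  // Collect the prediction tiles into a single 186 x 169 raster covering the scene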
  val raster = rf.toRaster($"prediction", 186, 169)

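  // Map each of the three predicted classes to a color from the Viridis ramp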
  val clusterColors = IndexedColorMap.fromColorMap(
    ColorRamps.Viridis.toColorMap((0 until 3).toArray)
  )

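  // Render the classified raster to a PNG file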
  raster.tile.renderPng(clusterColors).write("classified.png")

  spark.stop()
}