@@ -72,17 +72,18 @@ crses = df.select('crs.crsProj4').distinct().collect()
 print('Found ', len(crses), 'distinct CRS.')
 crs = crses[0][0]
 
-label_df = spark.read.geojson(os.path.join(resource_dir_uri(), 'luray-labels.geojson')) \
-    .select('id', st_reproject('geometry', lit('EPSG:4326'), lit(crs)).alias('geometry')) \
-    .hint('broadcast')
+label_df = spark.read.geojson(
+        os.path.join(resource_dir_uri(), 'luray-labels.geojson')) \
+    .select('id', st_reproject('geometry', lit('EPSG:4326'), lit(crs)).alias('geometry')) \
+    .hint('broadcast')
 
-df_joined = df.join(label_df, st_intersects(st_geometry('extent'), 'geometry'))
+df_joined = df.join(label_df, st_intersects(st_geometry('extent'), 'geometry')) \
+    .withColumn('dims', rf_dimensions('B01'))
+
+df_labeled = df_joined.withColumn('label',
+    rf_rasterize('geometry', st_geometry('extent'), 'id', 'dims.cols', 'dims.rows')
+)
 
-df_joined.createOrReplaceTempView('df_joined')
-df_labeled = spark.sql("""
-SELECT *, rf_rasterize(geometry, st_geometry(extent), id, rf_dimensions(B01).cols, rf_dimensions(B01).rows) AS label
-FROM df_joined
-""")
 ```
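Both versions reproject the labels into `crses[0][0]`, which quietly assumes every scene shares a single CRS. A guard along these lines would make that assumption explicit (a sketch, not part of the change):

```python
# Sketch: the label reprojection targets crses[0][0], which is only safe
# when all scenes share one CRS; fail fast otherwise.
assert len(crses) == 1, f'Expected a single CRS, found {len(crses)}'
```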
 
 ## Masking Poor Quality Cells
@@ -92,17 +93,20 @@ To filter only for good quality pixels, we follow roughly the same procedure as
 ```python, make_mask
 from pyspark.sql.functions import lit
 
-mask_part = df_labeled.withColumn('nodata', rf_local_equal('scl', lit(0))) \
-    .withColumn('defect', rf_local_equal('scl', lit(1))) \
-    .withColumn('cloud8', rf_local_equal('scl', lit(8))) \
-    .withColumn('cloud9', rf_local_equal('scl', lit(9))) \
-    .withColumn('cirrus', rf_local_equal('scl', lit(10)))
-
-df_mask_inv = mask_part.withColumn('mask', rf_local_add('nodata', 'defect')) \
-    .withColumn('mask', rf_local_add('mask', 'cloud8')) \
-    .withColumn('mask', rf_local_add('mask', 'cloud9')) \
-    .withColumn('mask', rf_local_add('mask', 'cirrus')) \
-    .drop('nodata', 'defect', 'cloud8', 'cloud9', 'cirrus')
+mask_part = df_labeled \
+    .withColumn('nodata', rf_local_equal('scl', lit(0))) \
+    .withColumn('defect', rf_local_equal('scl', lit(1))) \
+    .withColumn('cloud8', rf_local_equal('scl', lit(8))) \
+    .withColumn('cloud9', rf_local_equal('scl', lit(9))) \
+    .withColumn('cirrus', rf_local_equal('scl', lit(10)))
+
+df_mask_inv = mask_part \
+    .withColumn('mask', rf_local_add('nodata', 'defect')) \
+    .withColumn('mask', rf_local_add('mask', 'cloud8')) \
+    .withColumn('mask', rf_local_add('mask', 'cloud9')) \
+    .withColumn('mask', rf_local_add('mask', 'cirrus')) \
+    .drop('nodata', 'defect', 'cloud8', 'cloud9', 'cirrus')
+
 # at this point the mask contains 0 for good cells and 1 for defect, etc
 # convert cell type and set value 1 to NoData
 df_mask = df_mask_inv.withColumn('mask',
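The hunk ends before the conversion itself is visible. Per the comments above, a hypothetical completion using the standard RasterFrames functions `rf_convert_cell_type` and `rf_with_no_data` could look like this (a sketch of the intent, not the file's actual code):

```python
# Hypothetical completion: make 'mask' an unsigned 8-bit tile and declare
# the value 1 (poor-quality cells) to be NoData.
df_mask = df_mask_inv.withColumn(
    'mask', rf_with_no_data(rf_convert_cell_type('mask', 'uint8'), 1))
```

Depending on the RasterFrames version, `rf_mask_by_values` may also collapse the whole equal/add cascade into a single call per band.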
@@ -159,7 +163,9 @@ pipeline.getStages()
 The next step is to run each stage of the Pipeline we created, including fitting the decision tree model. We filter the DataFrame for only _tiles_ intersecting the label raster because the label shapes are relatively sparse over the imagery. The result is logically equivalent whether we include or exclude this step, but filtering is more efficient because less data flows into the pipeline.
 
 ```python, train
-model = pipeline.fit(df_mask.filter(rf_tile_sum('label') > 0).cache())
+model_input = df_mask.filter(rf_tile_sum('label') > 0).cache()
+display(model_input)
+model = pipeline.fit(model_input)
 ```
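Since the filter is justified by label sparsity, a quick count makes the saving concrete (a sketch reusing the names introduced above):

```python
# Sketch: quantify how many tiles the sparsity filter keeps for training.
total_tiles = df_mask.count()
print(f'Training on {model_input.count()} of {total_tiles} tiles')
```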
 
 ## Model Evaluation
@@ -171,9 +177,11 @@ prediction_df = model.transform(df_mask) \
     .drop(assembler.getOutputCol()).cache()
 prediction_df.printSchema()
 
-eval = MulticlassClassificationEvaluator(predictionCol=classifier.getPredictionCol(),
-                                         labelCol=classifier.getLabelCol(),
-                                         metricName='accuracy')
+eval = MulticlassClassificationEvaluator(
+    predictionCol=classifier.getPredictionCol(),
+    labelCol=classifier.getLabelCol(),
+    metricName='accuracy'
+)
 
 accuracy = eval.evaluate(prediction_df)
 print("\nAccuracy:", accuracy)
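Accuracy alone can flatter a class-imbalanced scene. The same evaluator reports other aggregate metrics by swapping `metricName`, a standard Spark ML pattern (a sketch):

```python
# Sketch: reuse the evaluator for additional aggregate metrics.
for metric in ['f1', 'weightedPrecision', 'weightedRecall']:
    print(metric, eval.setMetricName(metric).evaluate(prediction_df))
```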
@@ -185,7 +193,7 @@ As an example of using the flexibility provided by DataFrames, the code below co
 cnf_mtrx = prediction_df.groupBy(classifier.getPredictionCol()) \
     .pivot(classifier.getLabelCol()) \
     .count() \
-    .sort(classifier.getPredictionCol())
+    .sort(classifier.getPredictionCol())
 cnf_mtrx
 ```
 
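Raw counts are hard to compare across classes of different sizes. Because the pivoted matrix is tiny, it is safe to pull it to the driver and row-normalize it with pandas (a sketch):

```python
# Sketch: row-normalize so each predicted class sums to 1; the diagonal
# then reads as per-class precision.
pdf = cnf_mtrx.toPandas().set_index(classifier.getPredictionCol())
print(pdf.div(pdf.sum(axis=1), axis=0))
```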
@@ -195,40 +203,33 @@ Because the pipeline included a `TileExploder`, we will recreate the tiled data
 
 ```python, assemble_prediction
 scored = model.transform(df_mask.drop('label'))
-scored.createOrReplaceTempView('scored')
-
-retiled = spark.sql("""
-SELECT extent, crs,
-       rf_assemble_tile(column_index, row_index, prediction, 128, 128) as prediction,
-       rf_assemble_tile(column_index, row_index, B04, 128, 128) as red,
-       rf_assemble_tile(column_index, row_index, B03, 128, 128) as grn,
-       rf_assemble_tile(column_index, row_index, B02, 128, 128) as blu
-FROM scored
-GROUP BY extent, crs
-""")
 
+retiled = scored \
+    .groupBy('extent', 'crs') \
+    .agg(
+        rf_assemble_tile('column_index', 'row_index', 'prediction', 128, 128).alias('prediction'),
+        rf_assemble_tile('column_index', 'row_index', 'B04', 128, 128).alias('red'),
+        rf_assemble_tile('column_index', 'row_index', 'B03', 128, 128).alias('grn'),
+        rf_assemble_tile('column_index', 'row_index', 'B02', 128, 128).alias('blu')
+    )
 retiled.printSchema()
 ```
 
 Take a look at a sample of the resulting output and the corresponding area's red-green-blue composite image.
 
 ```python, display_rgb
-sample = retiled.select('prediction', 'red', 'grn', 'blu') \
+sample = retiled \
+    .select('prediction', rf_rgb_composite('red', 'grn', 'blu').alias('rgb')) \
     .sort(-rf_tile_sum(rf_local_equal('prediction', lit(3.0)))) \
     .first()
 
-sample_prediction = sample['prediction']
-
-red = sample['red'].cells
-grn = sample['grn'].cells
-blu = sample['blu'].cells
-sample_rgb = np.concatenate([red[:, :, None], grn[:, :, None], blu[:, :, None]], axis=2)
-mins = np.nanmin(sample_rgb, axis=(0, 1))
-plt.imshow((sample_rgb - mins) / (np.nanmax(sample_rgb, axis=(0, 1)) - mins))
+sample_rgb = sample['rgb']
+mins = np.nanmin(sample_rgb.cells, axis=(0, 1))
+plt.imshow((sample_rgb.cells - mins) / (np.nanmax(sample_rgb.cells, axis=(0, 1)) - mins))
 ```
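Note that `rf_rgb_composite` moves the band compositing into Spark, replacing the manual numpy channel stacking that the old version performed on the driver after `first()`.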
 
 Recall the label coding: 1 is forest (purple), 2 is cropland (green), and 3 is developed areas (yellow).
 
 ```python, display_prediction
-display(sample_prediction)
+display(sample['prediction'])
 ```
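`display` renders the prediction tile with default colors. Pinning a matplotlib color map lines the rendering up with that coding, since viridis runs purple to green to yellow over the range 1 to 3 (a sketch):

```python
# Sketch: pin the color map so 1=forest (purple), 2=cropland (green),
# and 3=developed (yellow), matching the coding described above.
plt.imshow(sample['prediction'].cells, cmap='viridis', vmin=1, vmax=3)
plt.colorbar()
```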