@@ -32,7 +32,7 @@ catalog_df = pd.DataFrame([
3232 {b: uri_base.format(b) for b in cols}
3333])
3434
35- df = spark.read.raster(catalog_df, catalog_col_names=cols, tile_dimensions=(128, 128 )) \
35+ df = spark.read.raster(catalog_df, catalog_col_names=cols, tile_dimensions=(256, 256 )) \
3636 .repartition(100)
3737
3838df = df.select(
@@ -91,23 +91,12 @@ To filter only for good quality pixels, we follow roughly the same procedure as
9191```python, make_mask
9292from pyspark.sql.functions import lit
9393
94- mask_part = df_labeled \
95- .withColumn('nodata', rf_local_equal('scl', lit(0))) \
96- .withColumn('defect', rf_local_equal('scl', lit(1))) \
97- .withColumn('cloud8', rf_local_equal('scl', lit(8))) \
98- .withColumn('cloud9', rf_local_equal('scl', lit(9))) \
99- .withColumn('cirrus', rf_local_equal('scl', lit(10)))
100-
101- df_mask_inv = mask_part \
102- .withColumn('mask', rf_local_add('nodata', 'defect')) \
103- .withColumn('mask', rf_local_add('mask', 'cloud8')) \
104- .withColumn('mask', rf_local_add('mask', 'cloud9')) \
105- .withColumn('mask', rf_local_add('mask', 'cirrus')) \
106- .drop('nodata', 'defect', 'cloud8', 'cloud9', 'cirrus')
107-
94+ df_labeled = df_labeled \
95+ .withColumn('mask', rf_local_is_in('scl', [0, 1, 8, 9, 10]))
96+
10897# at this point the mask contains 0 for good cells and 1 for defect, etc
10998# convert cell type and set value 1 to NoData
110- df_mask = df_mask_inv .withColumn('mask',
99+ df_mask = df_labeled .withColumn('mask',
111100 rf_with_no_data(rf_convert_cell_type('mask', 'uint8'), 1.0)
112101)
113102
@@ -213,20 +202,26 @@ retiled.printSchema()
213202```
214203
215204Take a look at a sample of the resulting output and the corresponding area's red-green-blue composite image.
205+ Recall the label coding: 1 is forest (purple), 2 is cropland (green) and 3 is developed areas(yellow).
216206
217207```python, display_rgb
218208sample = retiled \
219- .select('prediction', rf_rgb_composite( 'red', 'grn', 'blu').alias('rgb') ) \
209+ .select('prediction', 'red', 'grn', 'blu') \
220210 .sort(-rf_tile_sum(rf_local_equal('prediction', lit(3.0)))) \
221211 .first()
222212
223- sample_rgb = sample['rgb']
224- mins = np.nanmin(sample_rgb.cells, axis=(0,1))
225- plt.imshow((sample_rgb.cells - mins) / (np.nanmax(sample_rgb.cells, axis=(0,1)) - mins))
226- ```
213+ sample_rgb = np.concatenate([sample['red'].cells[:, :, None],
214+ sample['grn'].cells[ :, :, None],
215+ sample['blu'].cells[ :, :, None]], axis=2)
216+ # plot scaled RGB
217+ scaling_quantiles = np.nanpercentile(sample_rgb, [3.00, 97.00], axis=(0,1))
218+ scaled = np.clip(sample_rgb, scaling_quantiles[0, :], scaling_quantiles[1, :])
219+ scaled -= scaling_quantiles[0, :]
220+ scaled /= (scaling_quantiles[1, : ] - scaling_quantiles[0, :])
227221
228- Recall the label coding: 1 is forest (purple), 2 is cropland (green) and 3 is developed areas(yellow).
222+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
223+ ax1.imshow(scaled)
229224
230- ```python, display_prediction
231- display (sample['prediction'])
225+ # display prediction
226+ ax2.imshow (sample['prediction'].cells )
232227```
0 commit comments