Skip to content

Commit 52983e3

Browse files
committed
Update doc to use rf_local_is_in when masking; fix #351
Signed-off-by: Jason T. Brown <[email protected]>
1 parent e7b3b90 commit 52983e3

File tree

2 files changed

+26
-41
lines changed

2 files changed

+26
-41
lines changed

pyrasterframes/src/main/python/docs/nodata-handling.pymd

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -105,32 +105,23 @@ Drawing on @ref:[local map algebra](local-algebra.md) techniques, we will create
105105
```python, def_mask
106106
from pyspark.sql.functions import lit
107107

108-
mask_part = unmasked.withColumn('nodata', rf_local_equal('scl', lit(0))) \
109-
.withColumn('defect', rf_local_equal('scl', lit(1))) \
110-
.withColumn('cloud8', rf_local_equal('scl', lit(8))) \
111-
.withColumn('cloud9', rf_local_equal('scl', lit(9))) \
112-
.withColumn('cirrus', rf_local_equal('scl', lit(10)))
113-
114-
one_mask = mask_part.withColumn('mask', rf_local_add('nodata', 'defect')) \
115-
.withColumn('mask', rf_local_add('mask', 'cloud8')) \
116-
.withColumn('mask', rf_local_add('mask', 'cloud9')) \
117-
.withColumn('mask', rf_local_add('mask', 'cirrus'))
118-
119-
cell_types = one_mask.select(rf_cell_type('mask')).distinct()
108+
mask = unmasked.withColumn('mask', rf_local_is_in('scl', [0, 1, 8, 9, 10])
109+
110+
cell_types = mask.select(rf_cell_type('mask')).distinct()
120111
cell_types
121112
```
122113

123114
Because there is not a NoData already defined, we will choose one. In this particular example, the minimum value is greater than zero, so we can use 0 as the NoData value.
124115

125116
```python, pick_nd
126-
blue_min = one_mask.agg(rf_agg_stats('blue').min.alias('blue_min'))
117+
blue_min = mask.agg(rf_agg_stats('blue').min.alias('blue_min'))
127118
blue_min
128119
```
129120

130121
We can now construct the cell type string for our blue band's cell type, designating 0 as NoData.
131122

132123
```python, get_ct_string
133-
blue_ct = one_mask.select(rf_cell_type('blue')).distinct().first()[0][0]
124+
blue_ct = mask.select(rf_cell_type('blue')).distinct().first()[0][0]
134125
masked_blue_ct = CellType(blue_ct).with_no_data_value(0)
135126
masked_blue_ct.cell_type_name
136127
```
@@ -139,9 +130,8 @@ Now we will use the @ref:[`rf_mask_by_value`](reference.md#rf-mask-by-value) to
139130

140131
```python, mask_blu
141132
with_nd = rf_convert_cell_type('blue', masked_blue_ct)
142-
masked = one_mask.withColumn('blue_masked',
143-
rf_mask_by_value(with_nd, 'mask', lit(1))) \
144-
.drop('nodata', 'defect', 'cloud8', 'cloud9', 'cirrus', 'blue')
133+
masked = mask.withColumn('blue_masked',
134+
rf_mask_by_value(with_nd, 'mask', lit(1)))
145135
```
146136

147137
We can verify that the number of NoData cells in the resulting `blue_masked` column matches the total of the boolean `mask` _tile_ to ensure our logic is correct.

pyrasterframes/src/main/python/docs/supervised-learning.pymd

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ catalog_df = pd.DataFrame([
3232
{b: uri_base.format(b) for b in cols}
3333
])
3434

35-
df = spark.read.raster(catalog_df, catalog_col_names=cols, tile_dimensions=(128, 128)) \
35+
df = spark.read.raster(catalog_df, catalog_col_names=cols, tile_dimensions=(256, 256)) \
3636
.repartition(100)
3737

3838
df = df.select(
@@ -91,23 +91,12 @@ To filter only for good quality pixels, we follow roughly the same procedure as
9191
```python, make_mask
9292
from pyspark.sql.functions import lit
9393

94-
mask_part = df_labeled \
95-
.withColumn('nodata', rf_local_equal('scl', lit(0))) \
96-
.withColumn('defect', rf_local_equal('scl', lit(1))) \
97-
.withColumn('cloud8', rf_local_equal('scl', lit(8))) \
98-
.withColumn('cloud9', rf_local_equal('scl', lit(9))) \
99-
.withColumn('cirrus', rf_local_equal('scl', lit(10)))
100-
101-
df_mask_inv = mask_part \
102-
.withColumn('mask', rf_local_add('nodata', 'defect')) \
103-
.withColumn('mask', rf_local_add('mask', 'cloud8')) \
104-
.withColumn('mask', rf_local_add('mask', 'cloud9')) \
105-
.withColumn('mask', rf_local_add('mask', 'cirrus')) \
106-
.drop('nodata', 'defect', 'cloud8', 'cloud9', 'cirrus')
107-
94+
df_labeled = df_labeled \
95+
.withColumn('mask', rf_local_is_in('scl', [0, 1, 8, 9, 10]))
96+
10897
# at this point the mask contains 0 for good cells and 1 for defect, etc
10998
# convert cell type and set value 1 to NoData
110-
df_mask = df_mask_inv.withColumn('mask',
99+
df_mask = df_labeled.withColumn('mask',
111100
rf_with_no_data(rf_convert_cell_type('mask', 'uint8'), 1.0)
112101
)
113102

@@ -213,20 +202,26 @@ retiled.printSchema()
213202
```
214203

215204
Take a look at a sample of the resulting output and the corresponding area's red-green-blue composite image.
205+
Recall the label coding: 1 is forest (purple), 2 is cropland (green) and 3 is developed areas(yellow).
216206

217207
```python, display_rgb
218208
sample = retiled \
219-
.select('prediction', rf_rgb_composite('red', 'grn', 'blu').alias('rgb')) \
209+
.select('prediction', 'red', 'grn', 'blu') \
220210
.sort(-rf_tile_sum(rf_local_equal('prediction', lit(3.0)))) \
221211
.first()
222212

223-
sample_rgb = sample['rgb']
224-
mins = np.nanmin(sample_rgb.cells, axis=(0,1))
225-
plt.imshow((sample_rgb.cells - mins) / (np.nanmax(sample_rgb.cells, axis=(0,1)) - mins))
226-
```
213+
sample_rgb = np.concatenate([sample['red'].cells[:, :, None],
214+
sample['grn'].cells[ :, :, None],
215+
sample['blu'].cells[ :, :, None]], axis=2)
216+
# plot scaled RGB
217+
scaling_quantiles = np.nanpercentile(sample_rgb, [3.00, 97.00], axis=(0,1))
218+
scaled = np.clip(sample_rgb, scaling_quantiles[0, :], scaling_quantiles[1, :])
219+
scaled -= scaling_quantiles[0, :]
220+
scaled /= (scaling_quantiles[1, : ] - scaling_quantiles[0, :])
227221

228-
Recall the label coding: 1 is forest (purple), 2 is cropland (green) and 3 is developed areas(yellow).
222+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
223+
ax1.imshow(scaled)
229224

230-
```python, display_prediction
231-
display(sample['prediction'])
225+
# display prediction
226+
ax2.imshow(sample['prediction'].cells)
232227
```

0 commit comments

Comments
 (0)