Skip to content

Commit 9dc6a63

Browse files
committed
Updated rf_assemble_tile to accept literal or columnar tile size
specifications.
1 parent 9b5da3c commit 9dc6a63

File tree

4 files changed

+70
-51
lines changed

4 files changed

+70
-51
lines changed

core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ trait RasterFunctions {
8282
def rf_assemble_tile(columnIndex: Column, rowIndex: Column, cellData: Column, tileCols: Int, tileRows: Int, ct: CellType): TypedColumn[Any, Tile] =
8383
rf_convert_cell_type(TileAssembler(columnIndex, rowIndex, cellData, lit(tileCols), lit(tileRows)), ct).as(cellData.columnName).as[Tile](singlebandTileEncoder)
8484

85+
/** Create a Tile from a column of cell data with location indexes and preform cell conversion. */
86+
def rf_assemble_tile(columnIndex: Column, rowIndex: Column, cellData: Column, tileCols: Int, tileRows: Int): TypedColumn[Any, Tile] =
87+
TileAssembler(columnIndex, rowIndex, cellData, lit(tileCols), lit(tileRows))
88+
8589
/** Create a Tile from a column of cell data with location indexes. */
8690
def rf_assemble_tile(columnIndex: Column, rowIndex: Column, cellData: Column, tileCols: Column, tileRows: Column): TypedColumn[Any, Tile] =
8791
TileAssembler(columnIndex, rowIndex, cellData, tileCols, tileRows)

pyrasterframes/src/main/python/docs/reference.pymd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ from IPython.display import display
1717
import os.path
1818

1919
spark = pyrasterframes.get_spark_session()
20-
``
20+
```
2121

2222
## List of Available SQL and Python Functions
2323

pyrasterframes/src/main/python/docs/supervised-learning.pymd

Lines changed: 47 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -72,17 +72,18 @@ crses = df.select('crs.crsProj4').distinct().collect()
7272
print('Found ', len(crses), 'distinct CRS.')
7373
crs = crses[0][0]
7474

75-
label_df = spark.read.geojson(os.path.join(resource_dir_uri(), 'luray-labels.geojson')) \
76-
.select('id', st_reproject('geometry', lit('EPSG:4326'), lit(crs)).alias('geometry')) \
77-
.hint('broadcast')
75+
label_df = spark.read.geojson(
76+
os.path.join(resource_dir_uri(), 'luray-labels.geojson')) \
77+
.select('id', st_reproject('geometry', lit('EPSG:4326'), lit(crs)).alias('geometry')) \
78+
.hint('broadcast')
7879

79-
df_joined = df.join(label_df, st_intersects(st_geometry('extent'), 'geometry'))
80+
df_joined = df.join(label_df, st_intersects(st_geometry('extent'), 'geometry')) \
81+
.withColumn('dims', rf_dimensions('B01'))
82+
83+
df_labeled = df_joined.withColumn('label',
84+
rf_rasterize('geometry', st_geometry('extent'), 'id', 'dims.cols', 'dims.rows')
85+
)
8086

81-
df_joined.createOrReplaceTempView('df_joined')
82-
df_labeled = spark.sql("""
83-
SELECT *, rf_rasterize(geometry, st_geometry(extent), id, rf_dimensions(B01).cols, rf_dimensions(B01).rows) AS label
84-
FROM df_joined
85-
""")
8687
```
8788

8889
## Masking Poor Quality Cells
@@ -92,17 +93,20 @@ To filter only for good quality pixels, we follow roughly the same procedure as
9293
```python, make_mask
9394
from pyspark.sql.functions import lit
9495

95-
mask_part = df_labeled.withColumn('nodata', rf_local_equal('scl', lit(0))) \
96-
.withColumn('defect', rf_local_equal('scl', lit(1))) \
97-
.withColumn('cloud8', rf_local_equal('scl', lit(8))) \
98-
.withColumn('cloud9', rf_local_equal('scl', lit(9))) \
99-
.withColumn('cirrus', rf_local_equal('scl', lit(10)))
100-
101-
df_mask_inv = mask_part.withColumn('mask', rf_local_add('nodata', 'defect')) \
102-
.withColumn('mask', rf_local_add('mask', 'cloud8')) \
103-
.withColumn('mask', rf_local_add('mask', 'cloud9')) \
104-
.withColumn('mask', rf_local_add('mask', 'cirrus')) \
105-
.drop('nodata', 'defect', 'cloud8', 'cloud9', 'cirrus')
96+
mask_part = df_labeled \
97+
.withColumn('nodata', rf_local_equal('scl', lit(0))) \
98+
.withColumn('defect', rf_local_equal('scl', lit(1))) \
99+
.withColumn('cloud8', rf_local_equal('scl', lit(8))) \
100+
.withColumn('cloud9', rf_local_equal('scl', lit(9))) \
101+
.withColumn('cirrus', rf_local_equal('scl', lit(10)))
102+
103+
df_mask_inv = mask_part \
104+
.withColumn('mask', rf_local_add('nodata', 'defect')) \
105+
.withColumn('mask', rf_local_add('mask', 'cloud8')) \
106+
.withColumn('mask', rf_local_add('mask', 'cloud9')) \
107+
.withColumn('mask', rf_local_add('mask', 'cirrus')) \
108+
.drop('nodata', 'defect', 'cloud8', 'cloud9', 'cirrus')
109+
106110
# at this point the mask contains 0 for good cells and 1 for defect, etc
107111
# convert cell type and set value 1 to NoData
108112
df_mask = df_mask_inv.withColumn('mask',
@@ -159,7 +163,9 @@ pipeline.getStages()
159163
The next step is to actually run each step of the Pipeline we created, including fitting the decision tree model. We filter the DataFrame for only _tiles_ intersecting the label raster because the label shapes are relatively sparse over the imagery. It would be logically equivalent to either include or exclude thi step, but it is more efficient to filter because it will mean less data going into the pipeline.
160164

161165
```python, train
162-
model = pipeline.fit(df_mask.filter(rf_tile_sum('label') > 0).cache())
166+
model_input = df_mask.filter(rf_tile_sum('label') > 0).cache()
167+
display(model_input)
168+
model = pipeline.fit(model_input)
163169
```
164170

165171
## Model Evaluation
@@ -171,9 +177,11 @@ prediction_df = model.transform(df_mask) \
171177
.drop(assembler.getOutputCol()).cache()
172178
prediction_df.printSchema()
173179

174-
eval = MulticlassClassificationEvaluator(predictionCol=classifier.getPredictionCol(),
175-
labelCol=classifier.getLabelCol(),
176-
metricName='accuracy')
180+
eval = MulticlassClassificationEvaluator(
181+
predictionCol=classifier.getPredictionCol(),
182+
labelCol=classifier.getLabelCol(),
183+
metricName='accuracy'
184+
)
177185

178186
accuracy = eval.evaluate(prediction_df)
179187
print("\nAccuracy:", accuracy)
@@ -185,7 +193,7 @@ As an example of using the flexibility provided by DataFrames, the code below co
185193
cnf_mtrx = prediction_df.groupBy(classifier.getPredictionCol()) \
186194
.pivot(classifier.getLabelCol()) \
187195
.count() \
188-
.sort(classifier.getPredictionCol())
196+
.sort(classifier.getPredictionCol())
189197
cnf_mtrx
190198
```
191199

@@ -195,40 +203,33 @@ Because the pipeline included a `TileExploder`, we will recreate the tiled data
195203

196204
```python, assemble_prediction
197205
scored = model.transform(df_mask.drop('label'))
198-
scored.createOrReplaceTempView('scored')
199-
200-
retiled = spark.sql("""
201-
SELECT extent, crs,
202-
rf_assemble_tile(column_index, row_index, prediction, 128, 128) as prediction,
203-
rf_assemble_tile(column_index, row_index, B04, 128, 128) as red,
204-
rf_assemble_tile(column_index, row_index, B03, 128, 128) as grn,
205-
rf_assemble_tile(column_index, row_index, B02, 128, 128) as blu
206-
FROM scored
207-
GROUP BY extent, crs
208-
""")
209206

207+
retiled = scored \
208+
.groupBy('extent', 'crs') \
209+
.agg(
210+
rf_assemble_tile('column_index', 'row_index', 'prediction', 128, 128).alias('prediction'),
211+
rf_assemble_tile('column_index', 'row_index', 'B04', 128, 128).alias('red'),
212+
rf_assemble_tile('column_index', 'row_index', 'B03', 128, 128).alias('grn'),
213+
rf_assemble_tile('column_index', 'row_index', 'B02', 128, 128).alias('blu')
214+
)
210215
retiled.printSchema()
211216
```
212217

213218
Take a look at a sample of the resulting output and the corresponding area's red-green-blue composite image.
214219

215220
```python, display_rgb
216-
sample = retiled.select('prediction', 'red', 'grn', 'blu') \
221+
sample = retiled \
222+
.select('prediction', rf_rgb_composite('red', 'grn', 'blu').alias('rgb')) \
217223
.sort(-rf_tile_sum(rf_local_equal('prediction', lit(3.0)))) \
218224
.first()
219225

220-
sample_prediction = sample['prediction']
221-
222-
red = sample['red'].cells
223-
grn = sample['grn'].cells
224-
blu = sample['blu'].cells
225-
sample_rgb = np.concatenate([red[ :, :, None], grn[:, :, None] , blu[ :, :, None]], axis=2)
226-
mins = np.nanmin(sample_rgb, axis=(0,1))
227-
plt.imshow((sample_rgb - mins)/ (np.nanmax(sample_rgb, axis=(0,1)) - mins))
226+
sample_rgb = sample['rgb']
227+
mins = np.nanmin(sample_rgb.cells, axis=(0,1))
228+
plt.imshow((sample_rgb.cells - mins) / (np.nanmax(sample_rgb.cells, axis=(0,1)) - mins))
228229
```
229230

230231
Recall the label coding: 1 is forest (purple), 2 is cropland (green) and 3 is developed areas(yellow).
231232

232233
```python, display_prediction
233-
display(sample_prediction)
234+
display(sample['prediction'])
234235
```

pyrasterframes/src/main/python/pyrasterframes/rasterfunctions.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,27 @@ def rf_cell_types():
5353
return [CellType(str(ct)) for ct in _context_call('rf_cell_types')]
5454

5555

56-
def rf_assemble_tile(col_index, row_index, cell_data_col, num_cols, num_rows, cell_type):
56+
def rf_assemble_tile(col_index, row_index, cell_data_col, num_cols, num_rows, cell_type=None):
5757
"""Create a Tile from a column of cell data with location indices"""
5858
jfcn = RFContext.active().lookup('rf_assemble_tile')
59-
return Column(
60-
jfcn(_to_java_column(col_index), _to_java_column(row_index), _to_java_column(cell_data_col), num_cols, num_rows,
61-
_parse_cell_type(cell_type)))
6259

60+
if isinstance(num_cols, Column):
61+
num_cols = _to_java_column(num_cols)
62+
63+
if isinstance(num_rows, Column):
64+
num_rows = _to_java_column(num_rows)
65+
66+
if cell_type is None:
67+
return Column(jfcn(
68+
_to_java_column(col_index), _to_java_column(row_index), _to_java_column(cell_data_col),
69+
num_cols, num_rows
70+
))
71+
72+
else:
73+
return Column(jfcn(
74+
_to_java_column(col_index), _to_java_column(row_index), _to_java_column(cell_data_col),
75+
num_cols, num_rows, _parse_cell_type(cell_type)
76+
))
6377

6478
def rf_array_to_tile(array_col, num_cols, num_rows):
6579
"""Convert array in `array_col` into a Tile of dimensions `num_cols` and `num_rows'"""

0 commit comments

Comments
 (0)