
Commit 61d6b6e

Merge pull request #329 from s22s/feature/python-raster-reader-arg-refactor-and-docs
Python raster reader argument refactor
2 parents 60c4917 + 74f09df commit 61d6b6e
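In short: the Python `raster` reader now takes its input as a single positional `source` argument (a URI, list of URIs, CSV string, or catalog DataFrame) in place of the separate `path`/`catalog` keywords, and the underlying DataSource option keys move from camelCase to snake_case. A before/after sketch of a typical call site (`catalog_df` stands in for any Pandas or Spark DataFrame of raster URIs):

```python
# Before (0.8.0): catalog passed by keyword
df = spark.read.raster(catalog=catalog_df, catalog_col_names=['red', 'nir'])

# After this change: the catalog is the first positional `source` argument.
# The old `catalog=` keyword is still honored for 0.8.0 back-compatibility.
df = spark.read.raster(catalog_df, catalog_col_names=['red', 'nir'])
```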

15 files changed: +289 -217 lines changed

datasource/src/main/scala/org/locationtech/rasterframes/datasource/raster/RasterSourceDataSource.scala

Lines changed: 6 additions & 6 deletions
@@ -44,12 +44,12 @@ object RasterSourceDataSource {
   final val SHORT_NAME = "raster"
   final val PATH_PARAM = "path"
   final val PATHS_PARAM = "paths"
-  final val BAND_INDEXES_PARAM = "bandIndexes"
-  final val TILE_DIMS_PARAM = "tileDimensions"
-  final val CATALOG_TABLE_PARAM = "catalogTable"
-  final val CATALOG_TABLE_COLS_PARAM = "catalogColumns"
-  final val CATALOG_CSV_PARAM = "catalogCSV"
-  final val LAZY_TILES_PARAM = "lazyTiles"
+  final val BAND_INDEXES_PARAM = "band_indexes"
+  final val TILE_DIMS_PARAM = "tile_dimensions"
+  final val CATALOG_TABLE_PARAM = "catalog_table"
+  final val CATALOG_TABLE_COLS_PARAM = "catalog_col_names"
+  final val CATALOG_CSV_PARAM = "catalog_csv"
+  final val LAZY_TILES_PARAM = "lazy_tiles"

   final val DEFAULT_COLUMN_NAME = PROJECTED_RASTER_COLUMN.columnName
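These constants are the option keys the `raster` DataSource accepts, so the rename is visible anywhere options are set by hand rather than through `spark.read.raster`. A minimal sketch of direct use from PySpark, assuming the standard Spark `format(...).option(...).load()` plumbing and a placeholder URI:

```python
# Hedged sketch: the renamed snake_case option keys set directly on the DataSource.
df = (spark.read.format('raster')                    # SHORT_NAME = "raster"
      .option('path', 'file:///data/scene_B01.TIF')  # placeholder URI
      .option('band_indexes', '0')                   # was 'bandIndexes'
      .option('tile_dimensions', '256,256')          # was 'tileDimensions'
      .option('lazy_tiles', 'true')                  # was 'lazyTiles'
      .load())
```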

pyrasterframes/src/main/python/docs/languages.pymd

Lines changed: 4 additions & 4 deletions
@@ -42,7 +42,7 @@ red_nir_monthly_2017.printSchema()

 ```python, step_3_python
 red_nir_tiles_monthly_2017 = spark.read.raster(
-    catalog=red_nir_monthly_2017,
+    red_nir_monthly_2017,
     catalog_col_names=['red', 'nir'],
     tile_dimensions=(256, 256)
 )
@@ -97,9 +97,9 @@ sql("""
 CREATE OR REPLACE TEMPORARY VIEW red_nir_tiles_monthly_2017
 USING raster
 OPTIONS (
-    catalogTable='red_nir_monthly_2017',
-    catalogColumns='red,nir',
-    tileDimensions='256,256'
+    catalog_table='red_nir_monthly_2017',
+    catalog_col_names='red,nir',
+    tile_dimensions='256,256'
 )
 """)
 ```

pyrasterframes/src/main/python/docs/local-algebra.pymd

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ catalog_df = spark.createDataFrame([
     Row(red=uri_pattern.format(4), nir=uri_pattern.format(8))
 ])
 df = spark.read.raster(
-    catalog=catalog_df,
+    catalog_df,
     catalog_col_names=['red', 'nir']
 )
 df.printSchema()

pyrasterframes/src/main/python/docs/nodata-handling.pymd

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ from pyspark.sql import Row
 blue_uri = 'https://s22s-test-geotiffs.s3.amazonaws.com/luray_snp/B02.tif'
 scl_uri = 'https://s22s-test-geotiffs.s3.amazonaws.com/luray_snp/SCL.tif'
 cat = spark.createDataFrame([Row(blue=blue_uri, scl=scl_uri),])
-unmasked = spark.read.raster(catalog=cat, catalog_col_names=['blue', 'scl'])
+unmasked = spark.read.raster(cat, catalog_col_names=['blue', 'scl'])
 unmasked.printSchema()
 ```

pyrasterframes/src/main/python/docs/numpy-pandas.pymd

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ cat = spark.read.format('aws-pds-modis-catalog').load() \
         (col('acquisition_date') < lit('2018-02-22'))
     )

-spark_df = spark.read.raster(catalog=cat, catalog_col_names=['B01']) \
+spark_df = spark.read.raster(cat, catalog_col_names=['B01']) \
     .select(
         'acquisition_date',
         'granule_id',

pyrasterframes/src/main/python/docs/raster-read.pymd

Lines changed: 3 additions & 8 deletions
@@ -101,8 +101,6 @@ modis_catalog = spark.read \
     .withColumn('red' , F.concat('base_url', F.lit("_B01.TIF"))) \
     .withColumn('nir' , F.concat('base_url', F.lit("_B02.TIF")))

-modis_catalog.printSchema()
-
 print("Available scenes: ", modis_catalog.count())
 ```
@@ -124,10 +122,7 @@ equator.select('date', 'gid')
 Now that we have prepared our catalog, we simply pass the DataFrame or CSV string to the `raster` DataSource to load the imagery. The `catalog_col_names` parameter gives the columns that contain the URI's to be read.

 ```python, read_catalog
-rf = spark.read.raster(
-    catalog=equator,
-    catalog_col_names=['red', 'nir']
-)
+rf = spark.read.raster(equator, catalog_col_names=['red', 'nir'])
 rf.printSchema()
 ```
@@ -179,7 +174,7 @@ mb.printSchema()

 If a band is passed into `band_indexes` that exceeds the number of bands in the raster, a projected raster column will still be generated in the schema but the column will be full of `null` values.

-You can also pass a `catalog` and `band_indexes` together into the `raster` reader. This will create a projected raster column for the combination of all items passed into `catalog_col_names` and `band_indexes`. Again if a band in `band_indexes` exceeds the number of bands in a raster, it will have a `null` value for the corresponding column.
+You can also pass a _catalog_ and `band_indexes` together into the `raster` reader. This will create a projected raster column for the combination of all items in `catalog_col_names` and `band_indexes`. Again if a band in `band_indexes` exceeds the number of bands in a raster, it will have a `null` value for the corresponding column.

 Here is a trivial example with a _catalog_ over multiband rasters. We specify two columns containing URIs and two bands, resulting in four projected raster columns.
@@ -191,7 +186,7 @@ mb_cat = pd.DataFrame([
     },
 ])
 mb2 = spark.read.raster(
-    catalog=spark.createDataFrame(mb_cat),
+    spark.createDataFrame(mb_cat),
     catalog_col_names=['foo', 'bar'],
     band_indexes=[0, 1],
     tile_dimensions=(64,64)

pyrasterframes/src/main/python/docs/supervised-learning.pymd

Lines changed: 2 additions & 4 deletions
@@ -33,10 +33,8 @@ catalog_df = pd.DataFrame([
     {b: uri_base.format(b) for b in cols}
 ])

-df = spark.read.raster(catalog=catalog_df,
-                       catalog_col_names=cols,
-                       tile_dimensions=(128, 128)
-                      ).repartition(100)
+df = spark.read.raster(catalog_df, catalog_col_names=cols, tile_dimensions=(128, 128)) \
+    .repartition(100)

 df = df.select(
     rf_crs(df.B01).alias('crs'),

pyrasterframes/src/main/python/docs/time-series.pymd

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ We then [reproject](https://gis.stackexchange.com/questions/247770/understanding
 ```python read_catalog
 raster_cols = ['B01', 'B02',] # red and near-infrared respectively
 park_rf = spark.read.raster(
-    catalog=park_cat.select(['acquisition_date', 'granule_id', 'geo_simp'] + raster_cols),
+    park_cat.select(['acquisition_date', 'granule_id', 'geo_simp'] + raster_cols),
     catalog_col_names=raster_cols) \
     .withColumn('park_native', st_reproject('geo_simp', lit('EPSG:4326'), rf_crs('B01'))) \
     .filter(st_intersects('park_native', rf_geometry('B01')))

pyrasterframes/src/main/python/docs/unsupervised-learning.pymd

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ filenamePattern = "L8-B{}-Elkton-VA.tiff"
 catalog_df = pd.DataFrame([
     {'b' + str(b): os.path.join(resource_dir_uri(), filenamePattern.format(b)) for b in range(1, 8)}
 ])
-df = spark.read.raster(catalog=catalog_df, catalog_col_names=catalog_df.columns)
+df = spark.read.raster(catalog_df, catalog_col_names=catalog_df.columns)
 df = df.select(
     rf_crs(df.b1).alias('crs'),
     rf_extent(df.b1).alias('extent'),

pyrasterframes/src/main/python/pyrasterframes/__init__.py

Lines changed: 67 additions & 11 deletions
@@ -110,16 +110,34 @@ def _aliased_writer(df_writer, format_key, path, **options):

 def _raster_reader(
         df_reader,
-        path=None,
-        catalog=None,
+        source=None,
         catalog_col_names=None,
         band_indexes=None,
         tile_dimensions=(256, 256),
         lazy_tiles=True,
         **options):
+    """
+    Returns a Spark DataFrame from raster data files specified by URIs.
+    Each row in the returned DataFrame will contain a column with struct of (CRS, Extent, Tile) for each item in
+    `catalog_col_names`.
+    Multiple bands from the same raster file are spread across rows of the DataFrame. See `band_indexes` param.
+    If bands from a scene are stored in separate files, provide a DataFrame to the `source` parameter.
+
+    For more details and example usage, consult https://rasterframes.io/raster-read.html
+
+    :param source: a string, list of strings, list of lists of strings, a Pandas DataFrame or a Spark DataFrame giving URIs to the raster data to read.
+    :param catalog_col_names: required if `source` is a DataFrame or CSV string. It is a list of strings giving the names of columns containing URIs to read.
+    :param band_indexes: list of integers indicating which bands, zero-based, to read from the raster files specified; default is to read only the first band.
+    :param tile_dimensions: tuple or list of two indicating the default tile dimension as (columns, rows).
+    :param lazy_tiles: If true (default) only generate minimal references to tile contents; if false, fetch tile cell values.
+    :param options: Additional keyword arguments to pass to the Spark DataSource.
+    """

     from pandas import DataFrame as PdDataFrame

+    if 'catalog' in options:
+        source = options['catalog']  # maintain back compatibility with 0.8.0
+
     def to_csv(comp):
         if isinstance(comp, str):
             return comp
@@ -135,37 +153,75 @@ def temp_name():
         band_indexes = [0]

     options.update({
-        "bandIndexes": to_csv(band_indexes),
-        "tileDimensions": to_csv(tile_dimensions),
-        "lazyTiles": lazy_tiles
+        "band_indexes": to_csv(band_indexes),
+        "tile_dimensions": to_csv(tile_dimensions),
+        "lazy_tiles": lazy_tiles
     })

+    # Parse the `source` argument
+    path = None  # to pass into `path` param
+    if isinstance(source, list):
+        if all([isinstance(i, str) for i in source]):
+            path = None
+            catalog = None
+            options.update(dict(paths='\n'.join([str(i) for i in source])))  # pass in "uri1\nuri2\nuri3\n..."
+        if all([isinstance(i, list) for i in source]):
+            # list of lists; we will rely on pandas to:
+            # - coerce all data to str (possibly using objects' __str__ or __repr__)
+            # - ensure data is not "ragged": all sublists are same len
+            path = None
+            catalog_col_names = ['proj_raster_{}'.format(i) for i in range(len(source[0]))]  # assign these names
+            catalog = PdDataFrame(source,
+                                  columns=catalog_col_names,
+                                  dtype=str,
+                                  )
+    elif isinstance(source, str):
+        if '\n' in source or '\r' in source:
+            # then the `source` string is a catalog as a CSV (header is required)
+            path = None
+            catalog = source
+        else:
+            # interpret source as a single URI string
+            path = source
+            catalog = None
+    else:
+        # user has passed in some other type, we will try to interpret as a catalog
+        catalog = source
+
     if catalog is not None:
         if catalog_col_names is None:
             raise Exception("'catalog_col_names' required when DataFrame 'catalog' specified")
+
         if isinstance(catalog, str):
             options.update({
-                "catalogCSV": catalog,
-                "catalogColumns": to_csv(catalog_col_names)
+                "catalog_csv": catalog,
+                "catalog_col_names": to_csv(catalog_col_names)
             })
         elif isinstance(catalog, DataFrame):
+            # check catalog_col_names
+            assert all([c in catalog.columns for c in catalog_col_names]), \
+                "All items in catalog_col_names must be the name of a column in the catalog DataFrame."
             # Create a random view name
             tmp_name = temp_name()
             catalog.createOrReplaceTempView(tmp_name)
             options.update({
-                "catalogTable": tmp_name,
-                "catalogColumns": to_csv(catalog_col_names)
+                "catalog_table": tmp_name,
+                "catalog_col_names": to_csv(catalog_col_names)
             })
         elif isinstance(catalog, PdDataFrame):
+            # check catalog_col_names
+            assert all([c in catalog.columns for c in catalog_col_names]), \
+                "All items in catalog_col_names must be the name of a column in the catalog DataFrame."
+
             # Handle to active spark session
             session = SparkContext._active_spark_context._rf_context._spark_session
             # Create a random view name
             tmp_name = temp_name()
             spark_catalog = session.createDataFrame(catalog)
             spark_catalog.createOrReplaceTempView(tmp_name)
             options.update({
-                "catalogTable": tmp_name,
-                "catalogColumns": to_csv(catalog_col_names)
+                "catalog_table": tmp_name,
+                "catalog_col_names": to_csv(catalog_col_names)
             })

     return df_reader \
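Taken together, the `source` parsing above accepts several input shapes. A usage sketch covering each branch (the URIs and `catalog_df` are placeholders):

```python
# Single URI string: becomes the `path` parameter.
rf1 = spark.read.raster('file:///data/scene_B01.TIF')

# List of URI strings: newline-joined and passed as the 'paths' option.
rf2 = spark.read.raster(['file:///data/a.tif', 'file:///data/b.tif'])

# List of lists of URIs: coerced to a Pandas catalog with generated
# column names 'proj_raster_0', 'proj_raster_1', ...
rf3 = spark.read.raster([['file:///a1.tif', 'file:///a2.tif']])

# Multi-line CSV string (header required): passed as the 'catalog_csv' option.
csv = 'red,nir\nfile:///r.tif,file:///n.tif'
rf4 = spark.read.raster(csv, catalog_col_names=['red', 'nir'])

# Pandas or Spark DataFrame: registered as a temp view via 'catalog_table'.
rf5 = spark.read.raster(catalog_df, catalog_col_names=['red', 'nir'])
```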
