
Commit 73e3651

Initial refactor of raster reader args in Python API

Signed-off-by: Jason T. Brown <[email protected]>
1 parent 1ea29f2 commit 73e3651

File tree

4 files changed: +225 -156 lines changed

pyrasterframes/src/main/python/pyrasterframes/__init__.py

Lines changed: 43 additions & 2 deletions
@@ -110,13 +110,28 @@ def _aliased_writer(df_writer, format_key, path, **options):
 
 def _raster_reader(
         df_reader,
-        path=None,
-        catalog=None,
+        source=None,
         catalog_col_names=None,
         band_indexes=None,
         tile_dimensions=(256, 256),
         lazy_tiles=True,
         **options):
+    """
+    Returns a Spark DataFrame from a raster data files specified by URI pointers
+    The returned DataFrame will have a column of (CRS, Extent, Tile) for each URI read
+    Multiple bands from the same raster file are spread across rows of the DataFrame. See band_indexes param.
+    If bands from a scene are stored in separate files, provide a DataFrame to the `source` parameter. Each row in the returned DataFrame will contain one (CRS, Extent, Tile) for each item in `catalog_col_names`
+
+    For more details and example usage, consult https://rasterframes.io/raster-read.html
+
+    :param source: a string, list of strings, a pandas DataFrame or a Spark DataFrame giving URIs to the raster data to read
+    :param catalog_col_names: required if source is a DataFrame or CSV string. It is a list of strings giving the names of columns containing URIs to read
+    :param band_indexes: list of integers indicating which bands, zero-based, to read from the raster files specified; default is to read only the first band
+    :param tile_dimensions: tuple or list of two indicating the default tile dimension as (columns, rows)
+    :param lazy_tiles: If true (default) only generate minimal references to tile contents; if false, fetch tile cell values
+    :param options: Additional keyword arguments to pass to the spark DataSource
+    :return:
+    """
 
     from pandas import DataFrame as PdDataFrame
 
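Going by the new docstring, the simplest call forms would look roughly like the sketch below. The URIs are placeholders and `spark` is assumed to be a RasterFrames-enabled SparkSession; this is an illustration of the documented parameter, not code from the commit.

```python
# Single URI string: read one raster into a (CRS, Extent, Tile) column
df = spark.read.raster('https://example.com/scene_B1.TIF')

# Multi-line CSV string (header required): treated as an inline catalog,
# so catalog_col_names must name the columns holding URIs
csv_catalog = """b1,b2
https://example.com/scene_B1.TIF,https://example.com/scene_B2.TIF"""
df = spark.read.raster(csv_catalog, catalog_col_names=['b1', 'b2'])
```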
@@ -140,6 +155,25 @@ def temp_name():
140155
"lazyTiles": lazy_tiles
141156
})
142157

158+
# Parse the `source` argument
159+
path = None # to pass into `path` param
160+
if isinstance(source, list):
161+
path = None
162+
catalog = None
163+
options.update(dict(paths='\n'.join(str(source))))
164+
elif isinstance(source, str):
165+
if '\n' in source or '\r' in source:
166+
# then the `source` string is a catalog as a CSV (header is required)
167+
path = None
168+
catalog = source
169+
else:
170+
# interpret source as a single URI string
171+
path = source
172+
catalog = None
173+
else:
174+
# user has passed in some other type, we will interpret as a catalog
175+
catalog = source
176+
143177
if catalog is not None:
144178
if catalog_col_names is None:
145179
raise Exception("'catalog_col_names' required when DataFrame 'catalog' specified")
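The branching above is easier to follow outside the diff. The standalone sketch below mirrors it in plain Python so the mapping from `source` type to reader input is visible; `describe_source` is a hypothetical helper, and its list branch joins the individual list elements, which appears to be the intent of the newline-delimited `paths` option.

```python
def describe_source(source):
    """Illustrative mirror of the dispatch above: returns (path, catalog, paths_option)."""
    path, catalog, paths_option = None, None, None
    if isinstance(source, list):
        # list of URIs: becomes the newline-delimited `paths` option
        paths_option = '\n'.join(str(s) for s in source)
    elif isinstance(source, str):
        if '\n' in source or '\r' in source:
            catalog = source    # multi-line string is treated as an inline CSV catalog
        else:
            path = source       # plain string is treated as a single URI
    else:
        catalog = source        # anything else (e.g. a DataFrame) is treated as a catalog
    return path, catalog, paths_option


print(describe_source('file:///data/scene.tif'))
print(describe_source(['file:///a.tif', 'file:///b.tif']))
print(describe_source('b1,b2\nfile:///a1.tif,file:///a2.tif'))
```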
@@ -149,6 +183,9 @@ def temp_name():
                 "catalogColumns": to_csv(catalog_col_names)
             })
         elif isinstance(catalog, DataFrame):
+            # check catalog_col_names
+            assert all([c in catalog.columns for c in catalog_col_names]), \
+                "All items in catalog_col_names must be the name of a column in the catalog DataFrame."
             # Create a random view name
             tmp_name = temp_name()
             catalog.createOrReplaceTempView(tmp_name)
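For the Spark DataFrame branch, a minimal usage sketch might look as follows; the column names and URIs are placeholders and `spark` is assumed to be a RasterFrames-enabled session. The context lines above show that the reader registers such a catalog as a temporary view internally.

```python
import pandas as pd

# Small catalog: one column of URIs per band
catalog_pd = pd.DataFrame([
    {'b1': 'https://example.com/scene_B1.TIF', 'b2': 'https://example.com/scene_B2.TIF'},
])
catalog_spark = spark.createDataFrame(catalog_pd)

# One (CRS, Extent, Tile) column is produced per name in catalog_col_names
df = spark.read.raster(catalog_spark, catalog_col_names=['b1', 'b2'])
```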
@@ -157,6 +194,10 @@ def temp_name():
                 "catalogColumns": to_csv(catalog_col_names)
             })
         elif isinstance(catalog, PdDataFrame):
+            # check catalog_col_names
+            assert all([c in catalog.columns for c in catalog_col_names]), \
+                "All items in catalog_col_names must be the name of a column in the catalog DataFrame."
+
            # Handle to active spark session
            session = SparkContext._active_spark_context._rf_context._spark_session
            # Create a random view name
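The new assertions fail fast when a requested column is missing from the catalog. A hedged illustration with a pandas catalog (placeholder URIs, `spark` assumed):

```python
import pandas as pd

catalog_pd = pd.DataFrame([
    {'b1': 'https://example.com/scene_B1.TIF', 'b2': 'https://example.com/scene_B2.TIF'},
])

# OK: both requested names are columns of the catalog
ok_df = spark.read.raster(catalog_pd, catalog_col_names=['b1', 'b2'])

# AssertionError: 'b3' is not a column of the catalog DataFrame
bad_df = spark.read.raster(catalog_pd, catalog_col_names=['b1', 'b3'])
```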

pyrasterframes/src/main/python/tests/PyRasterFramesTests.py

Lines changed: 0 additions & 152 deletions
@@ -410,158 +410,6 @@ def test_raster_join(self):
         self.rf.raster_join(rf_prime, join_exprs=self.rf.extent)
 
 
-class RasterSource(TestEnvironment):
-
-    def test_handle_lazy_eval(self):
-        df = self.spark.read.raster(self.img_uri)
-        ltdf = df.select('proj_raster')
-        self.assertGreater(ltdf.count(), 0)
-        self.assertIsNotNone(ltdf.first())
-
-        tdf = df.select(rf_tile('proj_raster'))
-        self.assertGreater(tdf.count(), 0)
-        self.assertIsNotNone(tdf.first())
-
-    def test_strict_eval(self):
-        df_lazy = self.spark.read.raster(self.img_uri, lazy_tiles=True)
-        # when doing Show on a lazy tile we will see something like RasterRefTile(RasterRef(JVMGeoTiffRasterSource(...
-        # use this trick to get the `show` string
-        show_str_lazy = df_lazy.select('proj_raster')._jdf.showString(1, -1, False)
-        self.assertTrue('RasterRef' in show_str_lazy)
-
-        # again for strict
-        df_strict = self.spark.read.raster(self.img_uri, lazy_tiles=False)
-        show_str_strict = df_strict.select('proj_raster')._jdf.showString(1, -1, False)
-        self.assertTrue('RasterRef' not in show_str_strict)
-
-
-    def test_prt_functions(self):
-        df = self.spark.read.raster(self.img_uri) \
-            .withColumn('crs', rf_crs('proj_raster')) \
-            .withColumn('ext', rf_extent('proj_raster')) \
-            .withColumn('geom', rf_geometry('proj_raster'))
-        df.select('crs', 'ext', 'geom').first()
-
-    def test_raster_source_reader(self):
-        # much the same as RasterSourceDataSourceSpec here; but using https PDS. Takes about 30s to run
-
-        def l8path(b):
-            assert b in range(1, 12)
-            base = "https://s3-us-west-2.amazonaws.com/landsat-pds/c1/L8/199/026/LC08_L1TP_199026_20180919_20180928_01_T1/LC08_L1TP_199026_20180919_20180928_01_T1_B{}.TIF"
-            return base.format(b)
-
-        path_param = '\n'.join([l8path(b) for b in [1, 2, 3]]) # "http://foo.com/file1.tif,http://foo.com/file2.tif"
-        tile_size = 512
-
-        df = self.spark.read.raster(
-            tile_dimensions=(tile_size, tile_size),
-            paths=path_param,
-            lazy_tiles=True,
-        ).cache()
-
-        # schema is tile_path and tile
-        # df.printSchema()
-        self.assertTrue(len(df.columns) == 2 and 'proj_raster_path' in df.columns and 'proj_raster' in df.columns)
-
-        # the most common tile dimensions should be as passed to `options`, showing that options are correctly applied
-        tile_size_df = df.select(rf_dimensions(df.proj_raster).rows.alias('r'), rf_dimensions(df.proj_raster).cols.alias('c')) \
-            .groupby(['r', 'c']).count().toPandas()
-        most_common_size = tile_size_df.loc[tile_size_df['count'].idxmax()]
-        self.assertTrue(most_common_size.r == tile_size and most_common_size.c == tile_size)
-
-        # all rows are from a single source URI
-        path_count = df.groupby(df.proj_raster_path).count()
-        print(path_count.toPandas())
-        self.assertTrue(path_count.count() == 3)
-
-    def test_raster_source_reader_schemeless(self):
-        import os.path
-        path = os.path.join(self.resource_dir, "L8-B8-Robinson-IL.tiff")
-        self.assertTrue(not path.startswith('file://'))
-        df = self.spark.read.raster(path)
-        self.assertTrue(df.count() > 0)
-
-    def test_raster_source_catalog_reader(self):
-        import pandas as pd
-
-        scene_dict = {
-            1: 'http://landsat-pds.s3.amazonaws.com/c1/L8/015/041/LC08_L1TP_015041_20190305_20190309_01_T1/LC08_L1TP_015041_20190305_20190309_01_T1_B{}.TIF',
-            2: 'http://landsat-pds.s3.amazonaws.com/c1/L8/015/042/LC08_L1TP_015042_20190305_20190309_01_T1/LC08_L1TP_015042_20190305_20190309_01_T1_B{}.TIF',
-            3: 'http://landsat-pds.s3.amazonaws.com/c1/L8/016/041/LC08_L1TP_016041_20190224_20190309_01_T1/LC08_L1TP_016041_20190224_20190309_01_T1_B{}.TIF',
-        }
-
-        def path(scene, band):
-            assert band in range(1, 12)
-            p = scene_dict[scene]
-            return p.format(band)
-
-        # Create a pandas dataframe (makes it easy to create spark df)
-        path_pandas = pd.DataFrame([
-            {'b1': path(1, 1), 'b2': path(1, 2), 'b3': path(1, 3)},
-            {'b1': path(2, 1), 'b2': path(2, 2), 'b3': path(2, 3)},
-            {'b1': path(3, 1), 'b2': path(3, 2), 'b3': path(3, 3)},
-        ])
-        # comma separated list of column names containing URI's to read.
-        catalog_columns = ','.join(path_pandas.columns.tolist()) # 'b1,b2,b3'
-        path_table = self.spark.createDataFrame(path_pandas)
-
-        path_df = self.spark.read.raster(
-            tile_dimensions=(512, 512),
-            catalog=path_table,
-            catalog_col_names=catalog_columns,
-            lazy_tiles=True # We'll get an OOM error if we try to read 9 scenes all at once!
-        )
-
-        self.assertTrue(len(path_df.columns) == 6) # three bands times {path, tile}
-        self.assertTrue(path_df.select('b1_path').distinct().count() == 3) # as per scene_dict
-        b1_paths_maybe = path_df.select('b1_path').distinct().collect()
-        b1_paths = [s.format('1') for s in scene_dict.values()]
-        self.assertTrue(all([row.b1_path in b1_paths for row in b1_paths_maybe]))
-
-    def test_raster_source_catalog_reader_with_pandas(self):
-        import pandas as pd
-        import geopandas
-        from shapely.geometry import Point
-
-        scene_dict = {
-            1: 'http://landsat-pds.s3.amazonaws.com/c1/L8/015/041/LC08_L1TP_015041_20190305_20190309_01_T1/LC08_L1TP_015041_20190305_20190309_01_T1_B{}.TIF',
-            2: 'http://landsat-pds.s3.amazonaws.com/c1/L8/015/042/LC08_L1TP_015042_20190305_20190309_01_T1/LC08_L1TP_015042_20190305_20190309_01_T1_B{}.TIF',
-            3: 'http://landsat-pds.s3.amazonaws.com/c1/L8/016/041/LC08_L1TP_016041_20190224_20190309_01_T1/LC08_L1TP_016041_20190224_20190309_01_T1_B{}.TIF',
-        }
-
-        def path(scene, band):
-            assert band in range(1, 12)
-            p = scene_dict[scene]
-            return p.format(band)
-
-        # Create a pandas dataframe (makes it easy to create spark df)
-        path_pandas = pd.DataFrame([
-            {'b1': path(1, 1), 'b2': path(1, 2), 'b3': path(1, 3), 'geo': Point(1, 1)},
-            {'b1': path(2, 1), 'b2': path(2, 2), 'b3': path(2, 3), 'geo': Point(2, 2)},
-            {'b1': path(3, 1), 'b2': path(3, 2), 'b3': path(3, 3), 'geo': Point(3, 3)},
-        ])
-
-        # here a subtle difference with the test_raster_source_catalog_reader test, feed the DataFrame not a CSV and not an already created spark DF.
-        df = self.spark.read.raster(
-            catalog=path_pandas,
-            catalog_col_names=['b1', 'b2', 'b3']
-        )
-        self.assertEqual(len(df.columns), 7) # three path cols, three tile cols, and geo
-        self.assertTrue('geo' in df.columns)
-        self.assertTrue(df.select('b1_path').distinct().count() == 3)
-
-
-        # Same test with geopandas
-        geo_df = geopandas.GeoDataFrame(path_pandas, crs={'init': 'EPSG:4326'}, geometry='geo')
-        df2 = self.spark.read.raster(
-            catalog=geo_df,
-            catalog_col_names=['b1', 'b2', 'b3']
-        )
-        self.assertEqual(len(df2.columns), 7) # three path cols, three tile cols, and geo
-        self.assertTrue('geo' in df2.columns)
-        self.assertTrue(df2.select('b1_path').distinct().count() == 3)
-
-
 def suite():
     function_tests = unittest.TestSuite()
     return function_tests
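The removed tests exercised the old `paths=` and `catalog=` keywords; under the refactored signature those inputs would presumably be handed to the single `source` argument instead. A hypothetical migration sketch, not part of this commit (placeholder URIs, `spark` assumed to be a RasterFrames-enabled session):

```python
# Old style (removed tests): newline-delimited string via the `paths` option
#   spark.read.raster(paths='\n'.join(uris), tile_dimensions=(512, 512), lazy_tiles=True)

# New style: hand the list of URIs directly to `source`
uris = ['https://example.com/scene_B1.TIF', 'https://example.com/scene_B2.TIF']  # placeholders
df = spark.read.raster(uris, tile_dimensions=(512, 512), lazy_tiles=True)
```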
