Skip to content

Commit e342580

Browse files
committed
PR feedback
Signed-off-by: Jason T. Brown <[email protected]>
1 parent 31fced9 commit e342580

File tree

3 files changed

+45
-46
lines changed

3 files changed

+45
-46
lines changed

pyrasterframes/src/main/python/pyrasterframes/__init__.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -117,19 +117,20 @@ def _raster_reader(
117117
lazy_tiles=True,
118118
**options):
119119
"""
120-
Returns a Spark DataFrame from raster data files specified by URI pointers
121-
The returned DataFrame will have a column of (CRS, Extent, Tile) for each URI read
122-
Multiple bands from the same raster file are spread across rows of the DataFrame. See band_indexes param.
123-
If bands from a scene are stored in separate files, provide a DataFrame to the `source` parameter. Each row in the returned DataFrame will contain one (CRS, Extent, Tile) for each item in `catalog_col_names`
120+
Returns a Spark DataFrame from raster data files specified by URIs.
121+
Each row in the returned DataFrame will contain a column with struct of (CRS, Extent, Tile) for each item in
122+
`catalog_col_names`.
123+
Multiple bands from the same raster file are spread across rows of the DataFrame. See `band_indexes` param.
124+
If bands from a scene are stored in separate files, provide a DataFrame to the `source` parameter.
124125
125126
For more details and example usage, consult https://rasterframes.io/raster-read.html
126127
127-
:param source: a string, list of strings, list of lists of strings, a pandas DataFrame or a Spark DataFrame giving URIs to the raster data to read
128-
:param catalog_col_names: required if source is a DataFrame or CSV string. It is a list of strings giving the names of columns containing URIs to read
129-
:param band_indexes: list of integers indicating which bands, zero-based, to read from the raster files specified; default is to read only the first band
130-
:param tile_dimensions: tuple or list of two indicating the default tile dimension as (columns, rows)
131-
:param lazy_tiles: If true (default) only generate minimal references to tile contents; if false, fetch tile cell values
132-
:param options: Additional keyword arguments to pass to the spark DataSource
128+
:param source: a string, list of strings, list of lists of strings, a Pandas DataFrame or a Spark DataFrame giving URIs to the raster data to read.
129+
:param catalog_col_names: required if `source` is a DataFrame or CSV string. It is a list of strings giving the names of columns containing URIs to read.
130+
:param band_indexes: list of integers indicating which bands, zero-based, to read from the raster files specified; default is to read only the first band.
131+
    :param tile_dimensions: tuple or list of two integers indicating the default tile dimension as (columns, rows).
132+
:param lazy_tiles: If true (default) only generate minimal references to tile contents; if false, fetch tile cell values.
133+
:param options: Additional keyword arguments to pass to the Spark DataSource.
133134
"""
134135

135136
from pandas import DataFrame as PdDataFrame

pyrasterframes/src/main/python/tests/RasterFunctionsTests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def test_render_composite(self):
286286
cat = self.spark.createDataFrame([
287287
Row(red=self.l8band_uri(4), green=self.l8band_uri(3), blue=self.l8band_uri(2))
288288
])
289-
rf = self.spark.read.raster(catalog=cat, catalog_col_names=cat.columns)
289+
rf = self.spark.read.raster(cat, catalog_col_names=cat.columns)
290290

291291
# Test composite construction
292292
rgb = rf.select(rf_tile(rf_rgb_composite('red', 'green', 'blue')).alias('rgb')).first()['rgb']

pyrasterframes/src/main/python/tests/RasterSourceTest.py

Lines changed: 33 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from pyrasterframes.rasterfunctions import *
2222
from pyrasterframes.rf_types import *
2323
from pyspark.sql.functions import *
24+
import pandas as pd
25+
from shapely.geometry import Point
2426
import os.path
2527
from unittest import skip
2628
from . import TestEnvironment
@@ -41,6 +43,14 @@ def path(scene, band):
4143
p = scene_dict[scene]
4244
return p.format(band)
4345

46+
def path_pandas_df(self):
47+
return pd.DataFrame([
48+
{'b1': self.path(1, 1), 'b2': self.path(1, 2), 'b3': self.path(1, 3), 'geo': Point(1, 1)},
49+
{'b1': self.path(2, 1), 'b2': self.path(2, 2), 'b3': self.path(2, 3), 'geo': Point(2, 2)},
50+
{'b1': self.path(3, 1), 'b2': self.path(3, 2), 'b3': self.path(3, 3), 'geo': Point(3, 3)},
51+
])
52+
53+
4454
def test_handle_lazy_eval(self):
4555
df = self.spark.read.raster(self.path(1, 1))
4656
ltdf = df.select('proj_raster')
@@ -129,59 +139,41 @@ def test_schemeless_string(self):
129139
self.assertTrue(df.count() > 0)
130140

131141
def test_spark_df_source(self):
132-
import pandas as pd
142+
catalog_columns = ['b1', 'b2', 'b3']
143+
catalog = self.spark.createDataFrame(self.path_pandas_df())
133144

134-
# Create a pandas dataframe (makes it easy to create spark df)
135-
path_pandas = pd.DataFrame([
136-
{'b1': self.path(1, 1), 'b2': self.path(1, 2), 'b3': self.path(1, 3)},
137-
{'b1': self.path(2, 1), 'b2': self.path(2, 2), 'b3': self.path(2, 3)},
138-
{'b1': self.path(3, 1), 'b2': self.path(3, 2), 'b3': self.path(3, 3)},
139-
])
140-
# comma separated list of column names containing URI's to read.
141-
catalog_columns = path_pandas.columns.tolist()
142-
path_table = self.spark.createDataFrame(path_pandas)
143-
144-
path_df = self.spark.read.raster(
145-
path_table,
145+
df = self.spark.read.raster(
146+
catalog,
146147
tile_dimensions=(512, 512),
147148
catalog_col_names=catalog_columns,
148149
lazy_tiles=True # We'll get an OOM error if we try to read 9 scenes all at once!
149150
)
150151

151-
self.assertTrue(len(path_df.columns) == 6) # three bands times {path, tile}
152-
self.assertTrue(path_df.select('b1_path').distinct().count() == 3) # as per scene_dict
153-
b1_paths_maybe = path_df.select('b1_path').distinct().collect()
152+
self.assertTrue(len(df.columns) == 7) # three bands times {path, tile} plus geo
153+
self.assertTrue(df.select('b1_path').distinct().count() == 3) # as per scene_dict
154+
b1_paths_maybe = df.select('b1_path').distinct().collect()
154155
b1_paths = [self.path(s, 1) for s in [1, 2, 3]]
155156
self.assertTrue(all([row.b1_path in b1_paths for row in b1_paths_maybe]))
156157

157158
def test_pandas_source(self):
158-
import pandas as pd
159-
import geopandas
160-
from shapely.geometry import Point
161159

162-
# Create a pandas dataframe (makes it easy to create spark df)
163-
path_pandas = pd.DataFrame([
164-
{'b1': self.path(1, 1), 'b2': self.path(1, 2), 'b3': self.path(1, 3), 'geo': Point(1, 1)},
165-
{'b1': self.path(2, 1), 'b2': self.path(2, 2), 'b3': self.path(2, 3), 'geo': Point(2, 2)},
166-
{'b1': self.path(3, 1), 'b2': self.path(3, 2), 'b3': self.path(3, 3), 'geo': Point(3, 3)},
167-
])
168-
169-
# here a subtle difference with the test_raster_source_catalog_reader test, feed the DataFrame
170-
# not a CSV and not an already created spark DF.
171160
df = self.spark.read.raster(
172-
path_pandas,
161+
self.path_pandas_df(),
173162
catalog_col_names=['b1', 'b2', 'b3']
174163
)
175164
self.assertEqual(len(df.columns), 7) # three path cols, three tile cols, and geo
176165
self.assertTrue('geo' in df.columns)
177166
self.assertTrue(df.select('b1_path').distinct().count() == 3)
178167

179-
# Same test with geopandas
180-
geo_df = geopandas.GeoDataFrame(path_pandas, crs={'init': 'EPSG:4326'}, geometry='geo')
181-
df2 = self.spark.read.raster(geo_df, ['b1', 'b2', 'b3'])
182-
self.assertEqual(len(df2.columns), 7) # three path cols, three tile cols, and geo
183-
self.assertTrue('geo' in df2.columns)
184-
self.assertTrue(df2.select('b1_path').distinct().count() == 3)
168+
def test_geopandas_source(self):
169+
from geopandas import GeoDataFrame
170+
# Same test as test_pandas_source with geopandas
171+
geo_df = GeoDataFrame(self.path_pandas_df(), crs={'init': 'EPSG:4326'}, geometry='geo')
172+
df = self.spark.read.raster(geo_df, ['b1', 'b2', 'b3'])
173+
174+
self.assertEqual(len(df.columns), 7) # three path cols, three tile cols, and geo
175+
self.assertTrue('geo' in df.columns)
176+
self.assertTrue(df.select('b1_path').distinct().count() == 3)
185177

186178
def test_csv_string(self):
187179

@@ -198,3 +190,9 @@ def test_csv_string(self):
198190
df = self.spark.read.raster(s, ['b1', 'b2'])
199191
self.assertEqual(len(df.columns), 3 + 2) # number of columns in original DF plus cardinality of catalog_col_names
200192
self.assertTrue(len(df.take(1))) # non-empty check
193+
194+
def test_catalog_named_arg(self):
195+
    # Through version 0.8.1, reading a catalog was possible via named argument only.
196+
df = self.spark.read.raster(catalog=self.path_pandas_df(), catalog_col_names=['b1', 'b2', 'b3'])
197+
self.assertEqual(len(df.columns), 7) # three path cols, three tile cols, and geo
198+
self.assertTrue(df.select('b1_path').distinct().count() == 3)

0 commit comments

Comments
 (0)