Skip to content

Commit 1ebd005

Browse files
pickle free centroids hdf (#1076)
* hazard.io: avoid pickling geometries and compress hdf5 files * 'changelog' * fix column drop * add unit tests with multiple wkb and xy columns * avoid crs resetting * explicitly ask for Points Co-authored-by: Lukas Riedel <[email protected]> * fix typo * fix point condition * add comment about filtering geometry columns * remove ducplicated methods and obsolete lines --------- Co-authored-by: Lukas Riedel <[email protected]>
1 parent 530e957 commit 1ebd005

File tree

3 files changed

+104
-60
lines changed

3 files changed

+104
-60
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ CLIMADA tutorials. [#872](https://github.com/CLIMADA-project/climada_python/pull
221221
- `Impact.write_hdf5` now throws an error if `event_name` is does not contain strings exclusively [#894](https://github.com/CLIMADA-project/climada_python/pull/894)
222222
- Split `climada.hazard.trop_cyclone` module into smaller submodules without affecting module usage [#911](https://github.com/CLIMADA-project/climada_python/pull/911)
223223
- `yearly_steps` parameter of `TropCyclone.apply_climate_scenario_knu` has been made explicit [#991](https://github.com/CLIMADA-project/climada_python/pull/991)
224+
- `Hazard.write_hdf5` writes centroids as x,y columns (or as wkb in case of polygons) at a compression level of 9, not as pickled `Shapely` objects anymore, which reduces the size of the files significantly.
224225

225226
### Fixed
226227

climada/hazard/centroids/centr.py

Lines changed: 56 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,14 @@ def __eq__(self, other):
177177

178178
try:
179179
pd.testing.assert_frame_equal(self.gdf, other.gdf, check_like=True)
180-
return True
181180
except AssertionError:
182181
return False
183182

183+
if not (self.gdf.geometry == other.gdf.geometry).all():
184+
return False
185+
186+
return True
187+
184188
def to_default_crs(self, inplace=True):
185189
"""Project the current centroids to the default CRS (epsg4326)
186190
@@ -483,11 +487,11 @@ def plot(self, *, axis=None, figsize=(9, 13), **kwargs):
483487
-------
484488
ax : cartopy.mpl.geoaxes.GeoAxes instance
485489
"""
486-
if axis == None:
490+
if axis is None:
487491
fig, axis = plt.subplots(
488492
figsize=figsize, subplot_kw={"projection": ccrs.PlateCarree()}
489493
)
490-
if type(axis) != cartopy.mpl.geoaxes.GeoAxes:
494+
if type(axis) is not cartopy.mpl.geoaxes.GeoAxes:
491495
raise AttributeError(
492496
f"The axis provided is of type: {type(axis)} "
493497
"The function requires a cartopy.mpl.geoaxes.GeoAxes."
@@ -906,23 +910,40 @@ def write_hdf5(self, file_name, mode="w"):
906910
(path and) file name to write to.
907911
"""
908912
LOGGER.info("Writing %s", file_name)
909-
store = pd.HDFStore(file_name, mode=mode)
910-
pandas_df = pd.DataFrame(self.gdf)
911-
for col in pandas_df.columns:
912-
if str(pandas_df[col].dtype) == "geometry":
913-
pandas_df[col] = np.asarray(self.gdf[col])
914-
915-
# Avoid pandas PerformanceWarning when writing HDF5 data
916-
with warnings.catch_warnings():
917-
warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning)
918-
# Write dataframe
919-
store.put("centroids", pandas_df)
920-
921-
store.get_storer("centroids").attrs.metadata = {
922-
"crs": CRS.from_user_input(self.crs).to_wkt()
923-
}
924-
925-
store.close()
913+
xycols = []
914+
wkbcols = []
915+
store = pd.HDFStore(file_name, mode=mode, complevel=9)
916+
try:
917+
pandas_df = pd.DataFrame(self.gdf)
918+
# we replace all columns of type geometry
919+
# - with according x and y columns if they are strictly `Point`s
920+
# - with wkb values if they have other shapes
921+
# this saves a lot of time and disk space
922+
for col in pandas_df.columns:
923+
if str(pandas_df[col].dtype) == "geometry":
924+
if (self.gdf[col].geom_type == "Point").all():
925+
pandas_df[col + ".x"] = self.gdf[col].x
926+
pandas_df[col + ".y"] = self.gdf[col].y
927+
pandas_df.drop(columns=[col], inplace=True)
928+
xycols.append(col)
929+
else:
930+
pandas_df[col] = self.gdf[col].to_wkb()
931+
wkbcols.append(col)
932+
933+
# Avoid pandas PerformanceWarning when writing HDF5 data
934+
with warnings.catch_warnings():
935+
warnings.simplefilter("ignore", category=pd.errors.PerformanceWarning)
936+
# Write dataframe
937+
store.put("centroids", pandas_df)
938+
939+
centroids_metadata = {"crs": CRS.from_user_input(self.crs).to_wkt()}
940+
if xycols:
941+
centroids_metadata["xy_columns"] = xycols
942+
if wkbcols:
943+
centroids_metadata["wkb_columns"] = wkbcols
944+
store.get_storer("centroids").attrs.metadata = centroids_metadata
945+
finally:
946+
store.close()
926947

927948
@classmethod
928949
def from_hdf5(cls, file_name):
@@ -950,7 +971,21 @@ def from_hdf5(cls, file_name):
950971
# in previous versions of CLIMADA and/or geopandas,
951972
# the CRS was stored in '_crs'/'crs'
952973
crs = metadata.get("crs")
953-
gdf = gpd.GeoDataFrame(store["centroids"], crs=crs)
974+
gdf = gpd.GeoDataFrame(store["centroids"])
975+
with warnings.catch_warnings():
976+
# setting a column named 'geometry' triggers a future warning
977+
# with geopandas 0.14
978+
warnings.simplefilter(action="ignore", category=FutureWarning)
979+
980+
for xycol in metadata.get("xy_columns", []):
981+
gdf[xycol] = gpd.points_from_xy(
982+
x=gdf[xycol + ".x"], y=gdf[xycol + ".y"], crs=crs
983+
)
984+
gdf.drop(columns=[xycol + ".x", xycol + ".y"], inplace=True)
985+
for wkbcol in metadata.get("wkb_columns", []):
986+
gdf[wkbcol] = gpd.GeoSeries.from_wkb(gdf[wkbcol], crs=crs)
987+
gdf.set_geometry("geometry", inplace=True)
988+
954989
except TypeError:
955990
with h5py.File(file_name, "r") as data:
956991
gdf = cls._gdf_from_legacy_hdf5(data.get("centroids"))

climada/hazard/centroids/test/test_centr.py

Lines changed: 47 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -556,8 +556,6 @@ def test_write_read_excel(self):
556556
def test_from_raster_file(self):
557557
"""Test from_raster_file"""
558558
width, height = 50, 60
559-
o_lat, o_lon = (10.42822096697894, -69.33714959699981)
560-
res_lat, res_lon = (-0.009000000000000341, 0.009000000000000341)
561559

562560
centr_ras = Centroids.from_raster_file(
563561
HAZ_DEMO_FL, window=Window(0, 0, width, height)
@@ -725,6 +723,51 @@ def test_read_write_hdf5(self):
725723
self.assertTrue(centroids_w == centroids_r)
726724
tmpfile.unlink()
727725

726+
def test_read_write_hdf5_with_additional_columns(self):
727+
tmpfile = Path("test_write_hdf5.out.hdf5")
728+
crs = CRS.from_user_input(ALT_CRS)
729+
centroids_w = Centroids(
730+
lat=VEC_LAT,
731+
lon=VEC_LON,
732+
crs=crs,
733+
region_id=REGION_ID,
734+
on_land=ON_LAND,
735+
)
736+
centroids_w.gdf = (
737+
centroids_w.gdf.join(
738+
gpd.GeoDataFrame(
739+
{"more_points": [shapely.Point(i, i) for i in range(8)]}
740+
).set_geometry("more_points")
741+
)
742+
.join(
743+
gpd.GeoDataFrame(
744+
{
745+
"some_shapes": [
746+
shapely.Point((2, 2)),
747+
shapely.Point((3, 3)),
748+
shapely.Polygon([(0, 0), (1, 1), (1, 0), (0, 0)]),
749+
shapely.LineString([(0, 1), (1, 0)]),
750+
]
751+
* 2
752+
}
753+
).set_geometry("some_shapes")
754+
)
755+
.join(
756+
gpd.GeoDataFrame(
757+
{
758+
"more_shapes": [
759+
shapely.LineString([(0, 1), (1, 2)]),
760+
]
761+
* 8
762+
}
763+
).set_geometry("more_shapes", crs=DEF_CRS)
764+
)
765+
)
766+
centroids_w.write_hdf5(tmpfile)
767+
centroids_r = Centroids.from_hdf5(tmpfile)
768+
self.assertEqual(centroids_w, centroids_r)
769+
tmpfile.unlink()
770+
728771
def test_from_hdf5_nonexistent_file(self):
729772
"""Test raising FileNotFoundError when creating Centroids object from a nonexistent HDF5 file"""
730773
file_name = "/path/to/nonexistentfile.h5"
@@ -901,41 +944,6 @@ def test_union(self):
901944
cent.on_land, np.concatenate([on_land, on_land2, [None, None]])
902945
)
903946

904-
def test_select_pass(self):
905-
"""Test Centroids.select method"""
906-
region_id = np.zeros(VEC_LAT.size)
907-
region_id[[2, 4]] = 10
908-
centr = Centroids(lat=VEC_LAT, lon=VEC_LON, region_id=region_id)
909-
910-
fil_centr = centr.select(reg_id=10)
911-
self.assertEqual(fil_centr.size, 2)
912-
self.assertEqual(fil_centr.lat[0], VEC_LAT[2])
913-
self.assertEqual(fil_centr.lat[1], VEC_LAT[4])
914-
self.assertEqual(fil_centr.lon[0], VEC_LON[2])
915-
self.assertEqual(fil_centr.lon[1], VEC_LON[4])
916-
self.assertTrue(np.array_equal(fil_centr.region_id, np.ones(2) * 10))
917-
918-
def test_select_extent_pass(self):
919-
"""Test select extent"""
920-
centr = Centroids(
921-
lat=np.array([-5, -3, 0, 3, 5]),
922-
lon=np.array([-180, -175, -170, 170, 175]),
923-
region_id=np.zeros(5),
924-
)
925-
ext_centr = centr.select(extent=[-175, -170, -5, 5])
926-
np.testing.assert_array_equal(ext_centr.lon, np.array([-175, -170]))
927-
np.testing.assert_array_equal(ext_centr.lat, np.array([-3, 0]))
928-
929-
# Cross antimeridian, version 1
930-
ext_centr = centr.select(extent=[170, -175, -5, 5])
931-
np.testing.assert_array_equal(ext_centr.lon, np.array([-180, -175, 170, 175]))
932-
np.testing.assert_array_equal(ext_centr.lat, np.array([-5, -3, 3, 5]))
933-
934-
# Cross antimeridian, version 2
935-
ext_centr = centr.select(extent=[170, 185, -5, 5])
936-
np.testing.assert_array_equal(ext_centr.lon, np.array([-180, -175, 170, 175]))
937-
np.testing.assert_array_equal(ext_centr.lat, np.array([-5, -3, 3, 5]))
938-
939947
def test_get_meta(self):
940948
"""
941949
Test that the `get_meta` method correctly generates metadata
@@ -968,7 +976,7 @@ def test_get_meta(self):
968976
self.assertTrue(u_coord.equal_crs(meta["crs"], expected_meta["crs"]))
969977
self.assertTrue(meta["transform"].almost_equals(expected_meta["transform"]))
970978

971-
def test_get_closest_point(self):
979+
def test_get_closest_point_1(self):
972980
"""Test get_closest_point"""
973981
for n, (lat, lon) in enumerate(LATLON):
974982
x, y, idx = self.centr.get_closest_point(lon * 0.99, lat * 1.01)
@@ -978,7 +986,7 @@ def test_get_closest_point(self):
978986
self.assertEqual(self.centr.lon[n], x)
979987
self.assertEqual(self.centr.lat[n], y)
980988

981-
def test_get_closest_point(self):
989+
def test_get_closest_point_2(self):
982990
"""Test get_closest_point"""
983991
for y_sign in [1, -1]:
984992
meta = {

0 commit comments

Comments
 (0)