Skip to content

Commit 7e89944

Browse files
committed
feat(L3)!: switch to using dask
The implemented routine acts on groups of months as independent entities. While one could believe a more pandas-native approach like resample would be preferable, this seems not to be the case; probably because the data I used was date-sorted, such that the data of each grid cell is scattered throughout the entire input dataset.
1 parent 62a27c4 commit 7e89944

File tree

2 files changed

+151
-69
lines changed

2 files changed

+151
-69
lines changed

cryoswath/l2.py

Lines changed: 68 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from dask import dataframe as dd
12
import geopandas as gpd
23
from multiprocessing import Pool
34
import numpy as np
@@ -18,6 +19,7 @@
1819
def from_id(track_idx: pd.DatetimeIndex|str, *,
1920
reprocess: bool = True,
2021
save_or_return: str = "both",
22+
cache: str = None,
2123
cores: int = len(os.sched_getaffinity(0)),
2224
**kwargs) -> tuple[gpd.GeoDataFrame]:
2325
# this function collects processed data and processes the remaining.
@@ -28,13 +30,13 @@ def from_id(track_idx: pd.DatetimeIndex|str, *,
2830
if not isinstance(track_idx, pd.DatetimeIndex):
2931
track_idx = pd.DatetimeIndex(track_idx if isinstance(track_idx, list) else [track_idx])
3032
if track_idx.tz == None:
31-
track_idx.tz_localize("UTC")
33+
track_idx = track_idx.tz_localize("UTC")
3234
# somehow the download thread prevents the processing of tracks. it may
3335
# be due to GIL lock. for now, it is just disabled, so one has to
3436
# download in advance. on the fly is always possible, however, with
3537
# parallel processing this can lead to issues because ESA blocks ftp
3638
# connections if there are too many.
37-
print("Note that you can speed up processing substantially by previously downloading the L1b data.")
39+
print("[note] You can speed up processing substantially by previously downloading the L1b data.")
3840
# stop_event = Event()
3941
# download_thread = Thread(target=l1b.download_wrapper,
4042
# kwargs=dict(track_idx=track_idx, num_processes=8, stop_event=stop_event),
@@ -43,10 +45,36 @@ def from_id(track_idx: pd.DatetimeIndex|str, *,
4345
# download_thread.start()
4446
try:
4547
start_datetime, end_datetime = track_idx.sort_values()[[0,-1]]
48+
# ! below will not return data that is cached, even if save_or_return="both"
49+
# this is a flaw in the current logic. rework.
50+
if cache is not None and save_or_return != "return":
51+
try:
52+
with pd.HDFStore(cache, "r") as hdf:
53+
cached = hdf.select("poca", columns=[])
54+
# for better performance: reduce indices to two per month
55+
sample_rate_ns = int(15*(24*60*60)*1e9)
56+
tmp = cached.index.astype("int64")//sample_rate_ns
57+
tmp = pd.arrays.DatetimeArray(np.append(np.unique(tmp)*sample_rate_ns,
58+
# adding first and last element
59+
# included for debugging. By default, at least the last
60+
# index should not be added, to prevent missing data
61+
cached.index[[0,-1]].astype("int64")))
62+
skip_months = np.unique(tmp.normalize()+pd.DateOffset(day=1))
63+
# print(skip_months)
64+
del cached
65+
except (OSError, KeyError) as err:
66+
if isinstance(err, KeyError):
67+
warnings.warn(f"Removed cache because of KeyError (\"{str(err)}\").")
68+
os.remove(cache)
69+
skip_months = np.empty(0)
4670
swath_list = []
4771
poca_list = []
4872
kwargs["cs_full_file_names"] = load_cs_full_file_names(update="no")
49-
for current_month in pd.date_range(start_datetime.normalize()-pd.offsets.MonthBegin(), end_datetime, freq="MS"):
73+
for current_month in pd.date_range(start_datetime.normalize()-pd.DateOffset(day=1),
74+
end_datetime, freq="MS"):
75+
if cache is not None and save_or_return != "return" and current_month.tz_localize(None) in skip_months:
76+
print("Skipping cached month", current_month.strftime("%Y-%m"))
77+
continue
5078
current_subdir = current_month.strftime(f"%Y{os.path.sep}%m")
5179
l2_paths = pd.DataFrame(columns=["swath", "poca"])
5280
for l2_type in ["swath", "poca"]:
@@ -58,22 +86,45 @@ def from_id(track_idx: pd.DatetimeIndex|str, *,
5886
else:
5987
os.makedirs(os.path.join(data_path, f"L2_{l2_type}", current_subdir))
6088
print("start processing", current_month)
61-
with Pool(processes=cores) as p:
62-
# function is defined at the bottom of this module
63-
collective_swath_poca_list = p.starmap(
64-
process_track,
65-
[(idx, reprocess, l2_paths, save_or_return, data_path, current_subdir, kwargs)
66-
for idx
67-
# indices per month with work-around :/ should be easier
68-
in pd.Series(index=track_idx).loc[current_month:current_month+pd.offsets.MonthBegin(1)].index],
69-
chunksize=1)
89+
if cores > 1:
90+
with Pool(processes=cores) as p:
91+
# function is defined at the bottom of this module
92+
collective_swath_poca_list = p.starmap(
93+
process_track,
94+
[(idx, reprocess, l2_paths, save_or_return, current_subdir, kwargs) for idx
95+
# indices per month with work-around :/ should be easier
96+
in pd.Series(index=track_idx).loc[current_month:current_month+pd.offsets.MonthBegin(1)].index],
97+
chunksize=1)
98+
else:
99+
collective_swath_poca_list = []
100+
for idx in pd.Series(index=track_idx).loc[current_month:current_month+pd.offsets.MonthBegin(1)].index:
101+
collective_swath_poca_list.append(process_track(idx, reprocess, l2_paths, save_or_return,
102+
current_subdir, kwargs))
103+
if cache is not None:
104+
for l2_type, i in zip(["swath", "poca"], [0, 1]):
105+
l2_data = pd.concat([item[i] for item in collective_swath_poca_list])
106+
if l2_type == "swath":
107+
l2_data.index = l2_data.index.get_level_values(0).astype(np.int64) \
108+
+ l2_data.index.get_level_values(1)
109+
l2_data.rename_axis("time", inplace=True)
110+
l2_data = pd.DataFrame(index=l2_data.index,
111+
data=pd.concat([l2_data.h_diff, l2_data.geometry.get_coordinates()],
112+
axis=1, copy=False))
113+
l2_data.astype(dict(h_diff=np.float32, x=np.int32, y=np.int32)).to_hdf(cache, key=l2_type, mode="a", append=True, format="table")
70114
if save_or_return != "save":
71-
for swath_poca_tuple in collective_swath_poca_list: # .get()
72-
swath_list.append(swath_poca_tuple[0])
73-
poca_list.append(swath_poca_tuple[1])
115+
swath_list.append(pd.concat([item[0] for item in collective_swath_poca_list]))
116+
poca_list.append(pd.concat([item[1] for item in collective_swath_poca_list]))
74117
print("done processing", current_month)
75118
if save_or_return != "save":
76-
return pd.concat(swath_list), pd.concat(poca_list)
119+
if swath_list == []:
120+
swath_list = pd.DataFrame()
121+
else:
122+
swath_list = pd.concat(swath_list)
123+
if poca_list == []:
124+
poca_list = pd.DataFrame()
125+
else:
126+
poca_list = pd.concat(poca_list)
127+
return swath_list, poca_list
77128
except:
78129
# print("Waiting for download threads to join.")
79130
# stop_event.set()
@@ -181,7 +232,7 @@ def process_and_save(region_of_interest: str|shapely.Polygon,
181232

182233

183234
# local helper function. can't be defined where it is needed because of namespace issues
184-
def process_track(idx, reprocess, l2_paths, save_or_return, data_path, current_subdir, kwargs):
235+
def process_track(idx, reprocess, l2_paths, save_or_return, current_subdir, kwargs):
185236
print("getting", idx)
186237
# print("kwargs", wargs)
187238
try:

cryoswath/l3.py

Lines changed: 83 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
import dask.dataframe
12
from dateutil.relativedelta import relativedelta
2-
import geopandas as gpd
3+
# import numba
34
import numpy as np
45
import os
56
import pandas as pd
@@ -9,56 +10,109 @@
910
from .misc import *
1011

1112
__all__ = list()
13+
14+
15+
# numba does not help here easily; using the numpy functions is as fast as it gets.
def med_iqr_cnt(data) -> pd.DataFrame:
    """Aggregate a numeric sample into median, inter-quartile range, and count.

    Parameters
    ----------
    data : array-like
        Numeric sample (here: the elevation differences falling into one
        grid cell and time window).

    Returns
    -------
    pd.DataFrame
        Single-row frame with columns "_median", "_iqr", and "_count"
        (column names match the dask `meta` declared by the caller).
    """
    # one quantile call evaluates all three quartiles in a single pass
    quartiles = np.quantile(data, [.25, .5, .75])
    return pd.DataFrame([[quartiles[1], quartiles[2]-quartiles[0], len(data)]],
                        columns=["_median", "_iqr", "_count"])
19+
__all__.append("med_iqr_cnt")
1220

1321

1422
def build_dataset(region_of_interest: str|shapely.Polygon,
1523
start_datetime: str|pd.Timestamp,
1624
end_datetime: str|pd.Timestamp, *,
17-
aggregation_period: relativedelta = relativedelta(months=3),
18-
timestep: relativedelta = relativedelta(months=1),
25+
l2_type: str = "swath",
26+
timestep_months: int = 1,
27+
window_ntimesteps: int = 3,
1928
spatial_res_meter: float = 500,
20-
**kwargs):
29+
agg_func_and_meta: tuple[callable, dict] = (med_iqr_cnt,
30+
{"_median": "f8", "_iqr": "f8", "_count": "i8"}),
31+
**l2_from_id_kwargs):
32+
if window_ntimesteps%2 - 1:
33+
old_window = window_ntimesteps
34+
window_ntimesteps = (window_ntimesteps//2+1)
35+
warnings.warn(f"The window should be a uneven number of time steps. You asked for {old_window}, but it has "+ f"been changed to {window_ntimesteps}.")
36+
# ! end time step should be included.
2137
start_datetime, end_datetime = pd.to_datetime([start_datetime, end_datetime])
22-
print("Building a gridded dataset of elevation estimates for the region",
23-
f"{region_of_interest} from {start_datetime} to {end_datetime} for",
24-
f"a rolling window of {aggregation_period} every {timestep}.")
25-
# if len(aggregation_period.kwds.keys()) != 1 \
26-
# or len(timestep.kwds.keys()) != 1 \
27-
# or list(aggregation_period.kwds.keys())[0] not in ["years", "months", "days"] \
28-
# or list(timestep.kwds.keys())[0] not in ["years", "months", "days"]:
29-
# raise Exception("Only use one of years, months, days for agg_time and timestep.")
38+
print("Building a gridded dataset of elevation estimates for",
39+
"the region "+region_of_interest if isinstance(region_of_interest, str) else "a custom area",
40+
f"from {start_datetime} to {end_datetime} every {timestep_months} for",
41+
f"a rolling window of {window_ntimesteps} time steps.")
3042
if "buffer_region_by" not in locals():
3143
# buffer_by defaults to 30 km to not miss any tracks. Usually,
3244
# 10 km should do.
3345
buffer_region_by = 30_000
34-
time_buffer = (aggregation_period-timestep)/2
46+
time_buffer_months = (window_ntimesteps*timestep_months)//2
47+
ext_t_axis = pd.date_range(start_datetime-pd.DateOffset(months=time_buffer_months),
48+
end_datetime+pd.DateOffset(months=time_buffer_months),
49+
freq=f"{timestep_months}MS",
50+
).astype("int64")
3551
cs_tracks = load_cs_ground_tracks(region_of_interest, start_datetime, end_datetime,
36-
buffer_period_by=time_buffer,buffer_region_by=buffer_region_by)
52+
buffer_period_by=relativedelta(months=time_buffer_months),
53+
buffer_region_by=buffer_region_by)
3754
print("First and last available ground tracks are on",
3855
f"{cs_tracks.index[0]} and {cs_tracks.index[-1]}, respectively.,",
39-
f"{cs_tracks.shape[0]} tracks in total.")
40-
print("Run update_cs_ground_tracks, optionally with `full=True` or",
56+
f"{cs_tracks.shape[0]} tracks in total."
57+
"\n[note] Run update_cs_ground_tracks, optionally with `full=True` or",
4158
"`incremental=True`, if you local ground tracks store is not up to",
4259
"date. Consider pulling the latest version from the repository.")
43-
# I believe passing loading l2 data to the function prevents copying
44-
# on .drop. an alternative would be to define l2_data nonlocal
45-
# within the gridding function
46-
l3_data = med_mad_cnt_grid(l2.from_id(cs_tracks.index, **filter_kwargs(l2.from_id, kwargs)),
47-
start_datetime=start_datetime, end_datetime=end_datetime,
48-
aggregation_period=aggregation_period, timestep=timestep,
49-
spatial_res_meter=spatial_res_meter)
50-
l3_data.to_netcdf(build_path(region_of_interest, timestep, spatial_res_meter, aggregation_period))
60+
61+
print("Storing the essential L2 data in hdf5, downloading and",
62+
"processing L1b files if not available...")
63+
if isinstance(region_of_interest, str):
64+
region_id = region_of_interest
65+
else:
66+
region_id = "_".join([region_of_interest.centroid.x, region_of_interest.centroid.y])
67+
cache_path = os.path.join(data_path, "tmp", region_id)
68+
l2.from_id(cs_tracks.index, save_or_return="save", cache=cache_path,
69+
**filter_kwargs(l2.from_id, l2_from_id_kwargs, blacklist=["save_or_return", "cache"]))
70+
71+
print("Gridding the data...")
72+
# one could drop some of the data before gridding. however, excluding
73+
# off-glacier data is expensive and filtering large differences to the
74+
# DEM can hide issues while statistics like the median and the IQR
75+
# should be fairly robust.
76+
l2_ddf = dask.dataframe.read_hdf(cache_path, l2_type, sorted_index=True)
77+
l2_ddf = l2_ddf.loc[ext_t_axis[0]:ext_t_axis[-1]]
78+
l2_ddf = l2_ddf.repartition(npartitions=3*len(os.sched_getaffinity(0)))
79+
80+
l2_ddf[["x", "y"]] = l2_ddf[["x", "y"]]//spatial_res_meter*spatial_res_meter
81+
l2_ddf["roll_0"] = l2_ddf.index.map_partitions(pd.cut, bins=ext_t_axis, right=False, labels=False, include_lowest=True)
82+
for i in range(1, window_ntimesteps):
83+
l2_ddf[f"roll_{i}"] = l2_ddf.map_partitions(lambda df: df.roll_0-i).persist()
84+
for i in range(window_ntimesteps):
85+
l2_ddf[f"roll_{i}"] = l2_ddf[f"roll_{i}"].map_partitions(lambda series: series.astype("i4")//window_ntimesteps)
86+
87+
roll_res = [None]*window_ntimesteps
88+
for i in range(window_ntimesteps):
89+
roll_res[i] = l2_ddf.rename(columns={f"roll_{i}": "time_idx"}).groupby(["time_idx", "x", "y"], sort=False).h_diff.apply(agg_func_and_meta[0], meta=agg_func_and_meta[1]).persist()
90+
for i in range(window_ntimesteps):
91+
roll_res[i] = roll_res[i].compute().droplevel(3, axis=0)
92+
roll_res[i].index = roll_res[i].index.set_levels(
93+
(roll_res[i].index.levels[0]*window_ntimesteps+i+1), level=0).rename("time", level=0)
94+
95+
l3_data = pd.concat(roll_res).sort_index()\
96+
.loc[(slice(0,len(ext_t_axis)-1),slice(None),slice(None)),:]
97+
l3_data.index = l3_data.index.remove_unused_levels()
98+
l3_data.index = l3_data.index.set_levels(
99+
ext_t_axis[l3_data.index.levels[0]].astype("datetime64[ns]"), level=0)
100+
l3_data = l3_data.query(f"time >= '{start_datetime}' and time <= '{end_datetime}'")
101+
l3_data.to_xarray().to_netcdf(build_path(region_id, timestep_months, spatial_res_meter))
51102
return l3_data
52103
__all__.append("build_dataset")
53104

54105

55-
def build_path(region_of_interest, timestep, spatial_res_meter, aggregation_period):
56-
region_id = find_region_id(region_of_interest)
57-
if list(timestep.kwds.values())[0]!=1:
58-
timestep_str = str(list(timestep.kwds.values())[0])+"-"
106+
def build_path(region_of_interest, timestep_months, spatial_res_meter, aggregation_period):
107+
if not isinstance(region_of_interest, str):
108+
region_id = find_region_id(region_of_interest)
109+
else:
110+
region_id = region_of_interest
111+
if timestep_months != 1:
112+
timestep_str = str(timestep_months)+"-"
59113
else:
60114
timestep_str = ""
61-
timestep_str += list(timestep.kwds.keys())[0][:-1]+"ly"
115+
timestep_str += "monthly"
62116
if spatial_res_meter == 1000:
63117
spatial_res_str = "1km"
64118
elif np.floor(spatial_res_meter/1000) < 2:
@@ -69,29 +123,6 @@ def build_path(region_of_interest, timestep, spatial_res_meter, aggregation_peri
69123
return os.path.join(data_path, "L3", "_".join(
70124
[region_id, timestep_str, spatial_res_str+".nc"]))
71125
__all__.append("build_path")
72-
73-
74-
def med_mad_cnt_grid(l2_data: gpd.GeoDataFrame, *,
75-
start_datetime: pd.Timestamp,
76-
end_datetime: pd.Timestamp,
77-
aggregation_period: relativedelta,
78-
timestep: relativedelta,
79-
spatial_res_meter: float):
80-
def stats(data: pd.Series) -> pd.Series:
81-
median = data.median()
82-
mad = np.abs(data-median).median()
83-
return pd.Series([median, mad, data.shape[0]])
84-
time_axis = pd.date_range(start_datetime+pd.offsets.MonthBegin(0), end_datetime, freq=timestep)
85-
if time_axis.tz == None: time_axis = time_axis.tz_localize("UTC")
86-
# if l2_data.index[0].tz == None: l2_data.index = l2_data.index.tz_localize("UTC")
87-
def rolling_stats(data):
88-
results_list = [None]*aggregation_period.months
89-
for i in range(aggregation_period.months):
90-
results_list[i] = data.groupby(subset.index.get_level_values("time")-pd.offsets.QuarterBegin(1, normalize=True)+pd.DateOffset(months=i)).apply(stats)
91-
result = pd.concat(results_list).unstack().sort_index().rename(columns={0: "med_elev_diff", 1: "mad_elev_diff", 2: "cnt_elev_diff"})#, inplace=True
92-
return result.loc[time_axis.join(result.index, how="inner")]
93-
return l2.grid(l2_data, spatial_res_meter, rolling_stats).to_xarray()
94-
__all__.append("med_mad_cnt_grid")
95126

96127

97128
__all__ = sorted(__all__)

0 commit comments

Comments
 (0)