Skip to content

Commit 44cab76

Browse files
iluise authored and TillHae committed
Iluise/fix lead time (ecmwf#1571)
* implement reader merge
* working version of merge reader
* linter
* lint
* fix lead time
* update to develop
1 parent 1a04b23 commit 44cab76

File tree

2 files changed

+37
-26
lines changed

2 files changed

+37
-26
lines changed

packages/evaluate/src/weathergen/evaluate/plotting/plot_utils.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from collections.abc import Iterable, Sequence
1212

1313
import numpy as np
14+
import xarray as xr
1415

1516
_logger = logging.getLogger(__name__)
1617

@@ -99,17 +100,19 @@ def plot_metric_region(
99100
if ch not in np.atleast_1d(data.channel.values) or data.isnull().all():
100101
continue
101102

102-
data, time_dim = _assign_time_coord(data)
103-
104103
selected_data.append(data.sel(channel=ch))
105104
labels.append(runs[run_id].get("label", run_id))
106105
run_ids.append(run_id)
107106

108107
if selected_data:
109108
_logger.info(f"Creating plot for {metric} - {region} - {stream} - {ch}.")
109+
110110
name = create_filename(
111111
prefix=[metric, region], middle=sorted(set(run_ids)), suffix=[stream, ch]
112112
)
113+
114+
selected_data, time_dim = _assign_time_coord(selected_data)
115+
113116
plotter.plot(
114117
selected_data,
115118
labels,
@@ -120,12 +123,12 @@ def plot_metric_region(
120123
)
121124

122125

123-
def _assign_time_coord(data: object) -> object:
126+
def _assign_time_coord(selected_data: list[xr.DataArray]) -> tuple[xr.DataArray, str]:
124127
"""Ensure that lead_time coordinate exists in the data array.
125128
126129
Parameters
127130
----------
128-
data : xarray.DataArray
131+
selected_data : list[xarray.DataArray]
129132
The data array to check.
130133
131134
Returns
@@ -136,23 +139,30 @@ def _assign_time_coord(data: object) -> object:
136139
time_dim : str
137140
The name of the time dimension used for x-axis.
138141
"""
139-
if "forecast_step" not in data.dims and "forecast_step" not in data.coords:
140-
raise ValueError("forecast_step coordinate not found in data dimensions or coordinates.")
141142

142143
time_dim = "forecast_step"
143144

144-
if "lead_time" in data.coords and data["forecast_step"].size == data["lead_time"].size:
145-
data = data.swap_dims({"forecast_step": "lead_time"})
145+
for data in selected_data:
146+
if "forecast_step" not in data.dims and "forecast_step" not in data.coords:
147+
raise ValueError(
148+
"forecast_step coordinate not found in data dimensions or coordinates."
149+
)
146150

147-
# Prefer lead_time as x_dim if present in dimensions
148-
if "lead_time" in data.dims:
149-
time_dim = "lead_time"
150-
else:
151-
_logger.warning(
152-
"lead_time coordinate not found or mismatched size; using forecast_step as x-axis."
153-
)
151+
if "lead_time" not in data.coords and "lead_time" not in data.dims:
152+
_logger.warning(
153+
"lead_time coordinate not found for all plotted data; "
154+
"using forecast_step as x-axis."
155+
)
156+
return selected_data, time_dim
157+
158+
# Swap forecast_step with lead_time if all available run_ids have lead_time coord
159+
time_dim = "lead_time"
154160

155-
return data, time_dim
161+
for i, data in enumerate(selected_data):
162+
if data.coords["lead_time"].shape == data.coords["forecast_step"].shape:
163+
selected_data[i] = data.swap_dims({"forecast_step": "lead_time"})
164+
165+
return selected_data, time_dim
156166

157167

158168
def ratio_plot_metric_region(
@@ -251,8 +261,6 @@ def heat_maps_metric_region(
251261
if data.isnull().all():
252262
continue
253263

254-
data, time_dim = _assign_time_coord(data)
255-
256264
selected_data.append(data)
257265
label = runs[run_id].get("label", run_id)
258266
if label != run_id:
@@ -265,6 +273,8 @@ def heat_maps_metric_region(
265273
name = create_filename(
266274
prefix=[metric, region], middle=sorted(set(run_ids)), suffix=[stream]
267275
)
276+
selected_data, time_dim = _assign_time_coord(selected_data)
277+
268278
plotter.heat_map(
269279
selected_data,
270280
labels,

packages/evaluate/src/weathergen/evaluate/utils/utils.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -214,14 +214,15 @@ def calc_scores_per_stream(
214214
reader, map_dir, stream, region, score_data, metrics, fstep
215215
)
216216

217-
lead_time_values = np.array(
218-
[lead_time_map[f].astype(int) for f in metric_stream.forecast_step.values]
219-
).squeeze()
220-
221-
if lead_time_values.shape == metric_stream.forecast_step.shape:
222-
metric_stream = metric_stream.assign_coords(
223-
lead_time=("forecast_step", lead_time_values)
224-
)
217+
if all(lead_time_map[f] is not None for f in lead_time_map):
218+
lead_time_values = np.array(
219+
[lead_time_map[f].astype(int) for f in metric_stream.forecast_step.values]
220+
).squeeze()
221+
222+
if lead_time_values.shape == metric_stream.forecast_step.shape:
223+
metric_stream = metric_stream.assign_coords(
224+
lead_time=("forecast_step", lead_time_values)
225+
)
225226

226227
_logger.info(f"Scores for run {reader.run_id} - {stream} calculated successfully.")
227228

0 commit comments

Comments
 (0)