1111from collections .abc import Iterable , Sequence
1212
1313import numpy as np
14+ import xarray as xr
1415
1516_logger = logging .getLogger (__name__ )
1617
@@ -99,17 +100,19 @@ def plot_metric_region(
99100 if ch not in np .atleast_1d (data .channel .values ) or data .isnull ().all ():
100101 continue
101102
102- data , time_dim = _assign_time_coord (data )
103-
104103 selected_data .append (data .sel (channel = ch ))
105104 labels .append (runs [run_id ].get ("label" , run_id ))
106105 run_ids .append (run_id )
107106
108107 if selected_data :
109108 _logger .info (f"Creating plot for { metric } - { region } - { stream } - { ch } ." )
109+
110110 name = create_filename (
111111 prefix = [metric , region ], middle = sorted (set (run_ids )), suffix = [stream , ch ]
112112 )
113+
114+ selected_data , time_dim = _assign_time_coord (selected_data )
115+
113116 plotter .plot (
114117 selected_data ,
115118 labels ,
@@ -120,12 +123,12 @@ def plot_metric_region(
120123 )
121124
122125
123- def _assign_time_coord (data : object ) -> object :
126+ def _assign_time_coord (selected_data : list [ xr . DataArray ] ) -> tuple [ xr . DataArray , str ] :
124127 """Ensure that lead_time coordinate exists in the data array.
125128
126129 Parameters
127130 ----------
128- data : xarray.DataArray
131+ selected_data : list[ xarray.DataArray]
129132 The data array to check.
130133
131134 Returns
@@ -136,23 +139,30 @@ def _assign_time_coord(data: object) -> object:
136139 time_dim : str
137140 The name of the time dimension used for x-axis.
138141 """
139- if "forecast_step" not in data .dims and "forecast_step" not in data .coords :
140- raise ValueError ("forecast_step coordinate not found in data dimensions or coordinates." )
141142
142143 time_dim = "forecast_step"
143144
144- if "lead_time" in data .coords and data ["forecast_step" ].size == data ["lead_time" ].size :
145- data = data .swap_dims ({"forecast_step" : "lead_time" })
145+ for data in selected_data :
146+ if "forecast_step" not in data .dims and "forecast_step" not in data .coords :
147+ raise ValueError (
148+ "forecast_step coordinate not found in data dimensions or coordinates."
149+ )
146150
147- # Prefer lead_time as x_dim if present in dimensions
148- if "lead_time" in data .dims :
149- time_dim = "lead_time"
150- else :
151- _logger .warning (
152- "lead_time coordinate not found or mismatched size; using forecast_step as x-axis."
153- )
151+ if "lead_time" not in data .coords and "lead_time" not in data .dims :
152+ _logger .warning (
153+ "lead_time coordinate not found for all plotted data; "
154+ "using forecast_step as x-axis."
155+ )
156+ return selected_data , time_dim
157+
158+ # Swap forecast_step with lead_time if all available run_ids have lead_time coord
159+ time_dim = "lead_time"
154160
155- return data , time_dim
161+ for i , data in enumerate (selected_data ):
162+ if data .coords ["lead_time" ].shape == data .coords ["forecast_step" ].shape :
163+ selected_data [i ] = data .swap_dims ({"forecast_step" : "lead_time" })
164+
165+ return selected_data , time_dim
156166
157167
158168def ratio_plot_metric_region (
@@ -251,8 +261,6 @@ def heat_maps_metric_region(
251261 if data .isnull ().all ():
252262 continue
253263
254- data , time_dim = _assign_time_coord (data )
255-
256264 selected_data .append (data )
257265 label = runs [run_id ].get ("label" , run_id )
258266 if label != run_id :
@@ -265,6 +273,8 @@ def heat_maps_metric_region(
265273 name = create_filename (
266274 prefix = [metric , region ], middle = sorted (set (run_ids )), suffix = [stream ]
267275 )
276+ selected_data , time_dim = _assign_time_coord (selected_data )
277+
268278 plotter .heat_map (
269279 selected_data ,
270280 labels ,
0 commit comments