Skip to content

Commit 12b0ff8

Browse files
JubekuclessigJulian Kuehnert
authored
Implement first function for latitude weighting (ecmwf#705)
* Changed logging level for some messages. * Refactored loss computation to improve performance. * Working around ruff issue * - Refactored code to improve structure and readability - Fixed problem with incomplete normalization over loss functions - Solved problem with mse_weighted as loss function when mse is specified * Fixed problems with multi-worker training * add location weights, first commit * assertion on mask and len(location_weights) * restructuring of location weights and fixes in mse_channel_location_weighted function * fix coords_raw dependency on offset and fstep * ruff * addressing review commits and fixing bug * rm location_weight from default stream config --------- Co-authored-by: Christian Lessig <christian.lessig@ecmwf.int> Co-authored-by: Julian Kuehnert <julian.kuehnert@ecwmf.int>
1 parent f517759 commit 12b0ff8

File tree

3 files changed

+30
-10
lines changed

3 files changed

+30
-10
lines changed

packages/evaluate/src/weathergen/evaluate/plotter.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
import xarray as xr
99

1010
from weathergen.utils.config import _load_private_conf
11-
work_dir = Path( _load_private_conf(None)['path_shared_working_dir']) / 'assets/cartopy'
11+
12+
work_dir = Path(_load_private_conf(None)["path_shared_working_dir"]) / "assets/cartopy"
1213
import cartopy
13-
cartopy.config['data_dir'] = str(work_dir)
14-
cartopy.config['pre_existing_data_dir'] = str(work_dir)
14+
15+
cartopy.config["data_dir"] = str(work_dir)
16+
cartopy.config["pre_existing_data_dir"] = str(work_dir)
1517
os.environ["CARTOPY_DATA_DIR"] = str(work_dir)
1618

1719
np.seterr(divide="ignore", invalid="ignore")

src/weathergen/train/loss.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,15 @@ def mse_channel_location_weighted(
135135
mask_nan = ~torch.isnan(target)
136136
pred = pred[0] if pred.shape[0] == 0 else pred.mean(0)
137137

138-
diff2 = torch.square(torch.where(mask_nan, target, 0) - torch.where(mask_nan, pred, 0)).mean(0)
139-
wl = weights_points
140-
loss_chs = ((diff2.transpose(1, 0) * wl).transpose(1, 0) if wl else diff2).mean(0)
138+
diff2 = torch.square(torch.where(mask_nan, target, 0) - torch.where(mask_nan, pred, 0))
139+
if weights_points is not None:
140+
diff2 = (diff2.transpose(1, 0) * weights_points).transpose(1, 0)
141+
loss_chs = diff2.mean(0)
141142
loss = torch.mean(loss_chs * weights_channels if weights_channels else loss_chs)
142143

143144
return loss, loss_chs
145+
146+
147+
def cosine_latitude(stream_data, forecast_offset, fstep, min_value=1e-3, max_value=1.0):
148+
latitudes_radian = stream_data.target_coords_raw[forecast_offset + fstep][:, 0] * np.pi / 180
149+
return (max_value - min_value) * np.cos(latitudes_radian) + min_value

src/weathergen/train/loss_calculator.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,16 @@ def _get_weights(self, stream_info):
105105

106106
return stream_info_loss_weight, weights_channels
107107

108+
def _get_location_weights(self, stream_info, stream_data, forecast_offset, fstep):
109+
location_weight_type = stream_info.get("location_weight", None)
110+
if location_weight_type is None:
111+
return None
112+
weights_locations_fct = getattr(losses, location_weight_type)
113+
weights_locations = weights_locations_fct(stream_data, forecast_offset, fstep)
114+
weights_locations = weights_locations.to(device=self.device, non_blocking=True)
115+
116+
return weights_locations
117+
108118
def _get_substep_masks(self, stream_info, fstep, stream_data):
109119
"""
110120
Find substeps and create corresponding masks (reused across loss functions)
@@ -140,7 +150,7 @@ def _loss_per_loss_function(
140150

141151
ctr_substeps = 0
142152
for mask_t in substep_masks:
143-
assert mask_t.sum() == len(weights_locations) if weights_locations else True
153+
assert mask_t.sum() == len(weights_locations) if weights_locations is not None else True
144154

145155
loss, loss_chs = loss_fct(
146156
target[mask_t], pred[:, mask_t], weights_channels, weights_locations
@@ -220,9 +230,6 @@ def compute_loss(
220230

221231
stream_data = streams_data[i_batch][i_stream_info]
222232

223-
# TODO: set from stream info
224-
weights_locations = None
225-
226233
loss_fsteps = torch.tensor(0.0, device=self.device, requires_grad=True)
227234
ctr_fsteps = 0
228235
for fstep, target in enumerate(targets):
@@ -240,6 +247,11 @@ def compute_loss(
240247
# get weigths for current streams
241248
stream_loss_weight, weights_channels = self._get_weights(stream_info)
242249

250+
# get weights for locations
251+
weights_locations = self._get_location_weights(
252+
stream_info, stream_data, self.cf.forecast_offset, fstep
253+
)
254+
243255
# get masks for sub-time steps
244256
substep_masks = self._get_substep_masks(stream_info, fstep, stream_data)
245257

0 commit comments

Comments
 (0)