
Commit 8b9484a

feat(loss): add AMSE loss with SHT caching to address spectral double penalty (#164)
* feat(loss): add AMSE loss with SHT caching to address spectral double penalty
* test: add unit tests for AMSENormalizedLoss with gradient and CUDA checks
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Formatting: format code according to guidelines
* Fix: minor bug in code fixed
* Fix
* refactor: convert loss tests to pytest-style functions
* pre commit task
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Revert "pre commit task" (this reverts commit 62a23f1)
* Formatting code according to standards
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Removed node_modules folder
* Removed unnecessary comment and formatted according to standards

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent c0fb0ea commit 8b9484a
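For reference, the AMSE loss added by this commit can be summarized as follows. The symbols below are ours, read directly off the code in the diff further down, and the small epsilon terms used for numerical stability are omitted. With per-degree spherical-harmonic power spectra P_pred(l) and P_tgt(l), per-degree coherence C(l), and per-feature variance sigma_c^2:

$$
\begin{aligned}
A(\ell) &= \Bigl(\sqrt{P_{\mathrm{pred}}(\ell)} - \sqrt{P_{\mathrm{tgt}}(\ell)}\Bigr)^{2} \\
D(\ell) &= 2\sqrt{P_{\mathrm{pred}}(\ell)\,P_{\mathrm{tgt}}(\ell)}\,\bigl(1 - C(\ell)\bigr) \\
\mathcal{L} &= \operatorname{mean}_{b,c}\Biggl[\frac{1}{\sigma_c^{2}}\sum_{\ell}\bigl(A(\ell) + D(\ell)\bigr)\Biggr]
\end{aligned}
$$

Here each power spectrum sums |coefficient|^2 over the order m, and C(l) is the real part of the cross-power summed over m, divided by sqrt(P_pred(l) P_tgt(l)).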

File tree

graph_weather/models/losses.py
tests/test_asme_loss.py

2 files changed: +205, -0 lines


graph_weather/models/losses.py

Lines changed: 103 additions & 0 deletions
@@ -2,6 +2,8 @@
 
 import numpy as np
 import torch
+import torch.nn as nn
+import torch_harmonics as th
 
 
 class NormalizedMSELoss(torch.nn.Module):
@@ -90,3 +92,104 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor):
 
         assert not torch.isnan(out).any()
         return out.mean()
+
+
+# Spectrally Adjusted Mean Squared Error (AMSE) loss
+class AMSENormalizedLoss(nn.Module):
+    """
+    Spectrally Adjusted Mean Squared Error (AMSE) Loss.
+
+    This loss function is designed to address the "double penalty" issue in spatial
+    forecasting by separately penalizing amplitude and phase differences in the
+    spectral domain.
+
+    It applies the Spherical Harmonic Transform (SHT) to both predictions and targets,
+    computes the power spectral density (PSD), and then evaluates two terms:
+    1. Amplitude Error (difference in spectral amplitudes).
+    2. Decorrelation Error (phase misalignment / coherence loss).
+
+    This implementation follows the formulation in "Fixing the Double Penalty in
+    Data-Driven Weather Forecasting Through a Modified Spherical Harmonic Loss
+    Function" (ICML 2025 Poster).
+
+    Args:
+        feature_variance (list or torch.Tensor): Variance of each physical feature,
+            used for normalization (length C).
+        epsilon (float): Small constant for numerical stability.
+    """
+
+    def __init__(self, feature_variance: list | torch.Tensor, epsilon: float = 1e-9):
+        super().__init__()
+        if not isinstance(feature_variance, torch.Tensor):
+            feature_variance = torch.tensor(feature_variance, dtype=torch.float32)
+        else:
+            feature_variance = feature_variance.clone().detach().float()
+
+        self.register_buffer("feature_variance", feature_variance)
+
+        # SHT cache to avoid re-initializing on every forward pass, since the object
+        # performs some expensive pre-computation when it's initialized; doing this
+        # repeatedly inside the training loop can add unnecessary overhead.
+        self.epsilon = epsilon
+        self.sht_cache = {}
+
+    def _get_sht(self, nlat: int, nlon: int, device: torch.device) -> th.RealSHT:
+        """
+        Helper to get a cached SHT object, creating it if it doesn't exist.
+
+        This prevents re-initializing the SHT object on every forward pass.
+        """
+        key = (nlat, nlon, device)
+        if key not in self.sht_cache:
+            self.sht_cache[key] = th.RealSHT(nlat, nlon, grid="equiangular").to(device)
+        return self.sht_cache[key]
+
+    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass to compute the AMSE loss.
+
+        Args:
+            pred (torch.Tensor): Predicted tensor of shape (B, C, H, W).
+            target (torch.Tensor): Ground truth tensor of shape (B, C, H, W).
+
+        Returns:
+            torch.Tensor: Scalar loss value (averaged over batch and features).
+        """
+        if pred.shape != target.shape:
+            raise ValueError("Prediction and target tensors must have the same shape.")
+        if pred.ndim != 4:
+            raise ValueError("Input tensors must be 4D: (batch, channels, lat, lon)")
+
+        batch_size, num_channels, nlat, nlon = pred.shape
+
+        # Reshape to (B*C, H, W) to process all variables at once
+        pred_reshaped = pred.view(batch_size * num_channels, nlat, nlon)
+        target_reshaped = target.view(batch_size * num_channels, nlat, nlon)
+
+        # Get the (potentially cached) SHT object
+        sht = self._get_sht(nlat, nlon, pred.device)
+        pred_coeffs = sht(pred_reshaped)  # (B*C, L, M) complex
+        target_coeffs = sht(target_reshaped)  # (B*C, L, M) complex
+
+        # Compute power spectral densities (PSD): sum |coeff|^2 over M
+        pred_psd = torch.sum(torch.abs(pred_coeffs) ** 2, dim=-1)  # (B*C, L)
+        target_psd = torch.sum(torch.abs(target_coeffs) ** 2, dim=-1)  # (B*C, L)
+
+        # Compute spectral coherence between prediction and target
+        cross_power = pred_coeffs * torch.conj(target_coeffs)  # (B*C, L, M)
+        coherence_num = torch.sum(cross_power.real, dim=-1)  # (B*C, L)
+        coherence_denom = torch.sqrt(pred_psd * target_psd)
+        coherence = coherence_num / (coherence_denom + self.epsilon)  # (B*C, L)
+
+        # Compute amplitude error: difference in sqrt(PSD)
+        amp_error = (
+            torch.sqrt(pred_psd + self.epsilon) - torch.sqrt(target_psd + self.epsilon)
+        ) ** 2
+
+        # Compute decorrelation error
+        decor_error = 2.0 * coherence_denom * (1.0 - coherence)
+
+        # Total spectral loss per sample
+        spectral_loss = torch.sum(amp_error + decor_error, dim=-1)  # (B*C,)
+
+        # Reshape back to (B, C)
+        spectral_loss = spectral_loss.view(batch_size, num_channels)
+
+        # Normalize by feature-wise variance and compute the mean loss
+        normalized_loss = spectral_loss / (self.feature_variance + self.epsilon)
+        return normalized_loss.mean()
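For context, and not part of the diff: a minimal sketch of how the new loss could be wired into a training step. The dimensions, the all-ones per-channel variances, and the random stand-in tensors below are illustrative assumptions, not values from this repository.

import torch

from graph_weather.models.losses import AMSENormalizedLoss

# Hypothetical dimensions: (batch, channels, lat, lon)
B, C, H, W = 2, 4, 32, 64
feature_variance = torch.ones(C)  # placeholder per-channel variances

criterion = AMSENormalizedLoss(feature_variance=feature_variance)

pred = torch.randn(B, C, H, W, requires_grad=True)  # stand-in for a model output
target = torch.randn(B, C, H, W)

loss = criterion(pred, target)  # scalar, normalized by per-feature variance
loss.backward()  # gradients flow back through the spherical harmonic transform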

tests/test_asme_loss.py

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
+import pytest
+import torch
+import torch_harmonics as th
+
+from graph_weather.models.losses import AMSENormalizedLoss
+
+
+@pytest.fixture
+def default_shape() -> tuple[int, int, int, int]:
+    """Return a default tensor shape (B, C, H, W) for test inputs."""
+    return 2, 3, 32, 64
+
+
+@pytest.fixture
+def feature_variance(default_shape: tuple) -> torch.Tensor:
+    """Return a synthetic feature variance tensor, one value per channel."""
+    _, num_channels, _, _ = default_shape
+    return (torch.rand(num_channels) + 0.5).clone().detach()
+
+
+@pytest.fixture
+def loss_fn(feature_variance: torch.Tensor) -> AMSENormalizedLoss:
+    """Instantiate the AMSENormalizedLoss with mock feature variance."""
+    return AMSENormalizedLoss(feature_variance=feature_variance)
+
+
+def test_zero_loss_for_identical_inputs(loss_fn: AMSENormalizedLoss, default_shape: tuple):
+    """Loss should be zero when prediction and target tensors are identical."""
+    pred = torch.randn(default_shape)
+    target = pred.clone()
+    loss = loss_fn(pred, target)
+    assert torch.allclose(loss, torch.tensor(0.0), atol=1e-6)
+
+
+def test_positive_loss_for_different_inputs(loss_fn: AMSENormalizedLoss, default_shape: tuple):
+    """Loss should be strictly positive when inputs differ."""
+    pred = torch.randn(default_shape)
+    target = torch.randn(default_shape)
+    loss = loss_fn(pred, target)
+    assert loss.item() > 0.0
+
+
+def test_gradient_flow(loss_fn: AMSENormalizedLoss, default_shape: tuple):
+    """Check that gradients can flow through the loss for backpropagation."""
+    pred = torch.randn(default_shape, requires_grad=True)
+    target = torch.randn(default_shape)
+    loss = loss_fn(pred, target)
+    loss.backward()
+    assert pred.grad is not None
+    assert torch.sum(torch.abs(pred.grad)) > 0
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_cuda_execution(feature_variance: torch.Tensor, default_shape: tuple):
+    """Verify that the loss runs on GPU and returns a finite CUDA tensor."""
+    device = torch.device("cuda")
+    loss_fn_cuda = AMSENormalizedLoss(feature_variance=feature_variance).to(device)
+    pred = torch.randn(default_shape, device=device)
+    target = torch.randn(default_shape, device=device)
+    loss = loss_fn_cuda(pred, target)
+    assert loss.is_cuda
+    assert torch.isfinite(loss)
+
+
+def test_known_value_simple_case(feature_variance: torch.Tensor):
+    """
+    Validate loss against a known spectral case.
+
+    This test generates synthetic spectral coefficients and applies the inverse
+    spherical harmonic transform to ensure the AMSE loss produces expected values.
+    """
+    nlat, nlon = 16, 32
+    batch_size, num_channels = 1, feature_variance.shape[0]
+
+    sht_forward_temp = th.RealSHT(nlat, nlon, grid="equiangular")
+    lmax, mmax = sht_forward_temp.lmax, sht_forward_temp.mmax
+    coeffs_shape = (batch_size * num_channels, lmax, mmax)
+
+    # Place known energy in the (l=1, m=0) band
+    target_coeffs = torch.zeros(coeffs_shape, dtype=torch.complex64)
+    target_coeffs[:, 1, 0] = 1.0 + 0.0j
+    pred_coeffs = target_coeffs * 0.5
+
+    # Inverse SHT to get spatial-domain data
+    isht = th.InverseRealSHT(nlat, nlon, grid="equiangular")
+    target = isht(target_coeffs).view(batch_size, num_channels, nlat, nlon)
+    pred = isht(pred_coeffs).view(batch_size, num_channels, nlat, nlon)
+
+    # Manually compute the expected normalized spectral loss
+    psd_target_l1 = 1.0**2
+    psd_pred_l1 = 0.5**2
+    amp_error_l1 = (
+        torch.sqrt(torch.tensor(psd_pred_l1)) - torch.sqrt(torch.tensor(psd_target_l1))
+    ) ** 2
+    expected_spectral_loss_per_channel = amp_error_l1
+    expected_normalized_loss = (expected_spectral_loss_per_channel / feature_variance).mean()
+
+    # Compare to the actual loss
+    loss_fn = AMSENormalizedLoss(feature_variance=feature_variance)
+    actual_loss = loss_fn(pred, target)
+
+    assert torch.allclose(actual_loss, expected_normalized_loss, atol=1e-5)
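A quick sanity check on the expected value in test_known_value_simple_case, in exact arithmetic and ignoring the epsilon terms: because pred_coeffs is exactly 0.5 * target_coeffs, the coherence at l = 1 equals 1, the decorrelation term vanishes, and only the amplitude term survives (all other degrees carry zero power):

$$
\begin{aligned}
P_{\mathrm{tgt}}(1) &= |1|^{2} = 1, \qquad P_{\mathrm{pred}}(1) = |0.5|^{2} = 0.25, \\
C(1) &= \frac{\operatorname{Re}\bigl(0.5 \cdot \overline{1}\bigr)}{\sqrt{0.25 \cdot 1}} = \frac{0.5}{0.5} = 1, \\
D(1) &= 2\sqrt{0.25 \cdot 1}\,(1 - 1) = 0, \\
A(1) &= \bigl(\sqrt{0.25} - \sqrt{1}\bigr)^{2} = 0.25,
\end{aligned}
$$

so the expected loss is the mean over channels of 0.25 / feature_variance, which is what expected_normalized_loss computes in the test.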

0 commit comments
