Skip to content

Commit c90965e

Browse files
authored
[FIX] Prevent inverse normalization of quantile validation loss (#1432)
1 parent a788a8b commit c90965e

File tree

2 files changed: +44 −2 lines changed

neuralforecast/common/_base_model.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -991,6 +991,8 @@ def _get_loc_scale(self, y_idx, add_channel_dim=False):
991991
def _compute_valid_loss(
992992
self, insample_y, outsample_y, output, outsample_mask, y_idx
993993
):
994+
output_from_scaled_distribution = False
995+
994996
if self.loss.is_distribution_output:
995997
y_loc, y_scale = self._get_loc_scale(y_idx)
996998
distr_args = self.loss.scale_decouple(
@@ -1001,17 +1003,20 @@ def _compute_valid_loss(
10011003
):
10021004
_, _, quants = self.loss.sample(distr_args=distr_args)
10031005
output = quants
1006+
output_from_scaled_distribution = True
10041007
elif isinstance(self.valid_loss, losses.BasePointLoss):
10051008
distr = self.loss.get_distribution(distr_args=distr_args)
10061009
output = distr.mean
1010+
output_from_scaled_distribution = True
10071011

1008-
# Validation Loss evaluation
1012+
# Validation loss evaluation
10091013
if self.valid_loss.is_distribution_output:
10101014
valid_loss = self.valid_loss(
10111015
y=outsample_y, distr_args=distr_args, mask=outsample_mask
10121016
)
10131017
else:
1014-
output = self._inv_normalization(y_hat=output, y_idx=y_idx)
1018+
if not output_from_scaled_distribution:
1019+
output = self._inv_normalization(y_hat=output, y_idx=y_idx)
10151020
valid_loss = self.valid_loss(
10161021
y=outsample_y, y_hat=output, y_insample=insample_y, mask=outsample_mask
10171022
)

tests/test_core.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1969,3 +1969,40 @@ def _test_model_additivity(preds_df, expl, model_name, use_polars, n_series, h,
19691969
rtol=1e-3,
19701970
err_msg="Attribution predictions do not match model predictions"
19711971
)
1972+
1973+
def test_compute_valid_loss_distribution_to_quantile_scale():
    """
    Ensure that quantiles sampled from a DistributionLoss trained on
    normalized data land on the original (de-normalized) scale, so that
    quantile-based validation losses compare predictions and targets
    like with like.
    """
    dist_loss = DistributionLoss(distribution='StudentT', level=[80, 90])

    # Normalized network output: raw (df, loc, scale) parameters for StudentT.
    shape = (2, 12, 1)  # (batch_size, horizon, n_series)
    net_output = (
        torch.full(shape, 5.0),  # degrees-of-freedom parameter (raw)
        torch.zeros(shape),      # mean (normalized)
        torch.zeros(shape),      # scale (normalized)
    )

    # Statistics of the un-normalized series.
    series_loc = torch.full(shape, 400.0)
    series_scale = torch.full(shape, 100.0)

    # scale_decouple maps the raw parameters back onto the original scale.
    decoupled_args = dist_loss.scale_decouple(
        net_output, loc=series_loc, scale=series_scale
    )

    # Draw quantiles from the de-normalized distribution.
    _, _, sampled_quants = dist_loss.sample(decoupled_args)

    # The target lives on the original scale (centered around series_loc).
    target_mean = series_loc.mean().item()
    quants_mean = sampled_quants.mean().item()
    ratio = quants_mean / target_mean

    # Quantiles and target must share a scale, i.e. ratio ~ 1.
    assert 0.8 < ratio < 1.2, (
        f"Quantiles mean ({quants_mean:.2f}) and target mean ({target_mean:.2f}) "
        f"are not on the same scale. Ratio: {ratio:.2f}"
    )

0 commit comments

Comments (0)