Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions neuralforecast/common/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,6 +991,8 @@ def _get_loc_scale(self, y_idx, add_channel_dim=False):
def _compute_valid_loss(
self, insample_y, outsample_y, output, outsample_mask, y_idx
):
output_from_scaled_distribution = False

if self.loss.is_distribution_output:
y_loc, y_scale = self._get_loc_scale(y_idx)
distr_args = self.loss.scale_decouple(
Expand All @@ -1001,17 +1003,20 @@ def _compute_valid_loss(
):
_, _, quants = self.loss.sample(distr_args=distr_args)
output = quants
output_from_scaled_distribution = True
elif isinstance(self.valid_loss, losses.BasePointLoss):
distr = self.loss.get_distribution(distr_args=distr_args)
output = distr.mean
output_from_scaled_distribution = True

# Validation Loss evaluation
# Validation loss evaluation
if self.valid_loss.is_distribution_output:
valid_loss = self.valid_loss(
y=outsample_y, distr_args=distr_args, mask=outsample_mask
)
else:
output = self._inv_normalization(y_hat=output, y_idx=y_idx)
if not output_from_scaled_distribution:
output = self._inv_normalization(y_hat=output, y_idx=y_idx)
valid_loss = self.valid_loss(
y=outsample_y, y_hat=output, y_insample=insample_y, mask=outsample_mask
)
Expand Down