11from math import pi
2+ from typing import Optional
23
34import torch
45
@@ -12,7 +13,7 @@ def mean_absolute_error(
1213 test_y : torch .Tensor ,
1314):
1415 """
15- Mean Absolute Error .
16+ Mean absolute error .
1617 """
1718 combine_dim = - 2 if isinstance (pred_dist , MultitaskMultivariateNormal ) else - 1
1819 return torch .abs (pred_dist .mean - test_y ).mean (dim = combine_dim )
@@ -24,7 +25,7 @@ def mean_squared_error(
2425 squared : bool = True ,
2526):
2627 """
27- Mean Squared Error .
28+ Mean squared error .
2829 """
2930 combine_dim = - 2 if isinstance (pred_dist , MultitaskMultivariateNormal ) else - 1
3031 res = torch .square (pred_dist .mean - test_y ).mean (dim = combine_dim )
@@ -33,29 +34,59 @@ def mean_squared_error(
3334 return res
3435
3536
def standardized_mean_squared_error(
    pred_dist: MultivariateNormal,
    test_y: torch.Tensor,
):
    """Standardized mean squared error.

    Standardizes the mean squared error by the variance of the test data.
    """
    mse = mean_squared_error(pred_dist, test_y, squared=True)
    # Dividing by the (whole-tensor) variance of the targets makes the
    # metric scale-free across datasets.
    return mse / test_y.var()
47+
def negative_log_predictive_density(
    pred_dist: MultivariateNormal,
    test_y: torch.Tensor,
):
    """Negative log predictive density.

    Computes the negative predictive log density normalized by the size of the test data.
    """
    # Multitask predictions carry a trailing task dimension, so the data
    # dimension sits at -2 rather than -1.
    if isinstance(pred_dist, MultitaskMultivariateNormal):
        num_data = test_y.shape[-2]
    else:
        num_data = test_y.shape[-1]
    return -pred_dist.log_prob(test_y) / num_data
4258
4359
def _pointwise_gaussian_nll(target, mean, var):
    # Per-point negative log density of a Gaussian with the given mean and
    # variance: 0.5*log(2*pi*var) + (target - mean)^2 / (2*var).
    return 0.5 * torch.log(2 * pi * var) + torch.square(target - mean) / (2 * var)


def mean_standardized_log_loss(
    pred_dist: MultivariateNormal,
    test_y: torch.Tensor,
    train_y: Optional[torch.Tensor] = None,
):
    """
    Mean standardized log loss.

    Averages the pointwise Gaussian negative log loss of the predictive
    distribution over the test points. When ``train_y`` is supplied, the loss
    of the trivial model -- a Gaussian using the training data's mean and
    variance -- is subtracted, yielding the *standardized* log loss of p.23 of
    Rasmussen and Williams (2006). Without ``train_y``, the plain mean log
    loss is returned.
    """
    # Multitask predictions reduce over the data dimension at -2; single-task
    # over the last dimension.
    combine_dim = -2 if isinstance(pred_dist, MultitaskMultivariateNormal) else -1

    res = _pointwise_gaussian_nll(test_y, pred_dist.mean, pred_dist.variance).mean(dim=combine_dim)

    if train_y is not None:
        # NOTE(review): the trivial model's variance is taken over the whole
        # training tensor (``var()`` with no dim), while its mean is reduced
        # over ``combine_dim`` -- confirm this asymmetry is intended for the
        # multitask case.
        trivial_nll = _pointwise_gaussian_nll(test_y, train_y.mean(dim=combine_dim), train_y.var())
        res = res - trivial_nll.mean(dim=combine_dim)

    return res
5990
6091
6192def quantile_coverage_error (
0 commit comments