Skip to content

Commit f50a9f8

Browse files
authored
Merge branch 'master' into dGPFantasize
2 parents 6c2fd48 + ee35601 commit f50a9f8

File tree

7 files changed

+49
-12
lines changed

7 files changed

+49
-12
lines changed

.conda/meta.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ requirements:
1818
run:
1919
- pytorch>=1.11
2020
- scikit-learn
21-
- linear_operator>=0.2.0
21+
- linear_operator>=0.4.0
2222

2323
test:
2424
imports:

docs/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,5 @@ sphinx_autodoc_typehints
88
nbsphinx
99
m2r2
1010
pyro-ppl
11-
linear_operator>=0.2.0
11+
linear_operator>=0.4.0
1212
torch>=1.11

gpytorch/metrics/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
mean_standardized_log_loss,
55
negative_log_predictive_density,
66
quantile_coverage_error,
7+
standardized_mean_squared_error,
78
)
89

910
__all__ = [
1011
"mean_absolute_error",
1112
"mean_squared_error",
13+
"standardized_mean_squared_error",
1214
"mean_standardized_log_loss",
1315
"negative_log_predictive_density",
1416
"quantile_coverage_error",

gpytorch/metrics/metrics.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from math import pi
2+
from typing import Optional
23

34
import torch
45

@@ -12,7 +13,7 @@ def mean_absolute_error(
1213
test_y: torch.Tensor,
1314
):
1415
"""
15-
Mean Absolute Error.
16+
Mean absolute error.
1617
"""
1718
combine_dim = -2 if isinstance(pred_dist, MultitaskMultivariateNormal) else -1
1819
return torch.abs(pred_dist.mean - test_y).mean(dim=combine_dim)
@@ -24,7 +25,7 @@ def mean_squared_error(
2425
squared: bool = True,
2526
):
2627
"""
27-
Mean Squared Error.
28+
Mean squared error.
2829
"""
2930
combine_dim = -2 if isinstance(pred_dist, MultitaskMultivariateNormal) else -1
3031
res = torch.square(pred_dist.mean - test_y).mean(dim=combine_dim)
@@ -33,29 +34,59 @@ def mean_squared_error(
3334
return res
3435

3536

37+
def standardized_mean_squared_error(
38+
pred_dist: MultivariateNormal,
39+
test_y: torch.Tensor,
40+
):
41+
"""Standardized mean squared error.
42+
43+
Standardizes the mean squared error by the variance of the test data.
44+
"""
45+
return mean_squared_error(pred_dist, test_y, squared=True) / test_y.var()
46+
47+
3648
def negative_log_predictive_density(
3749
pred_dist: MultivariateNormal,
3850
test_y: torch.Tensor,
3951
):
52+
"""Negative log predictive density.
53+
54+
Computes the negative predictive log density normalized by the size of the test data.
55+
"""
4056
combine_dim = -2 if isinstance(pred_dist, MultitaskMultivariateNormal) else -1
4157
return -pred_dist.log_prob(test_y) / test_y.shape[combine_dim]
4258

4359

4460
def mean_standardized_log_loss(
4561
pred_dist: MultivariateNormal,
4662
test_y: torch.Tensor,
63+
train_y: Optional[torch.Tensor] = None,
4764
):
4865
"""
49-
Mean Standardized Log Loss.
50-
Reference: Page No. 23,
51-
Gaussian Processes for Machine Learning,
52-
Carl Edward Rasmussen and Christopher K. I. Williams,
53-
The MIT Press, 2006. ISBN 0-262-18253-X
66+
Mean standardized log loss.
67+
68+
Computes the average *standardized* log loss, which subtracts the loss obtained
69+
under the trivial model which predicts with the mean and variance of the training
70+
data from the mean log loss. See p.23 of Rasmussen and Williams (2006).
71+
72+
If no training data is supplied, the mean log loss is computed.
5473
"""
5574
combine_dim = -2 if isinstance(pred_dist, MultitaskMultivariateNormal) else -1
75+
5676
f_mean = pred_dist.mean
5777
f_var = pred_dist.variance
58-
return (0.5 * torch.log(2 * pi * f_var) + torch.square(test_y - f_mean) / (2 * f_var)).mean(dim=combine_dim)
78+
loss_model = (0.5 * torch.log(2 * pi * f_var) + torch.square(test_y - f_mean) / (2 * f_var)).mean(dim=combine_dim)
79+
res = loss_model
80+
81+
if train_y is not None:
82+
data_mean = train_y.mean(dim=combine_dim)
83+
data_var = train_y.var()
84+
loss_trivial_model = (
85+
0.5 * torch.log(2 * pi * data_var) + torch.square(test_y - data_mean) / (2 * data_var)
86+
).mean(dim=combine_dim)
87+
res = res - loss_trivial_model
88+
89+
return res
5990

6091

6192
def quantile_coverage_error(

gpytorch/variational/nearest_neighbor_variational_strategy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ def _stochastic_kl_helper(self, kl_indices: Tensor) -> Tensor:
293293

294294
# compute interp_term
295295
cov = self.model.covar_module.forward(nearest_neighbors, nearest_neighbors)
296-
cross_cov = self.model.covar_module.forward(nearest_neighbors, inducing_points.unsqueeze(-2))
296+
cross_cov = to_dense(self.model.covar_module.forward(nearest_neighbors, inducing_points.unsqueeze(-2)))
297297
interp_term = torch.linalg.solve(
298298
cov + self.jitter_val * torch.eye(self.k, device=self.inducing_points.device), cross_cov
299299
).squeeze(-1)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def find_version(*file_paths):
4040
torch_min = "1.11"
4141
install_requires = [
4242
"scikit-learn",
43-
"linear_operator>=0.2.0",
43+
"linear_operator>=0.4.0",
4444
]
4545
# if recent dev version of PyTorch is installed, no need to install stable
4646
try:

test/metrics/test_metrics.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
mean_standardized_log_loss,
1515
negative_log_predictive_density,
1616
quantile_coverage_error,
17+
standardized_mean_squared_error,
1718
)
1819
from gpytorch.models import ExactGP
1920

@@ -126,6 +127,9 @@ def test_negative_log_predictive_density(self):
126127
def test_mean_standardized_log_loss(self):
127128
self._test_metric(mean_standardized_log_loss)
128129

130+
def test_standardized_mean_squared_error(self):
131+
self._test_metric(standardized_mean_squared_error)
132+
129133
def test_quantile_coverage_error(self):
130134
self._test_metric(
131135
quantile_coverage_error,

0 commit comments

Comments
 (0)