Skip to content

Commit 090d6e1

Browse files
authored
Merge pull request #2382 from fjzzq2002/master
Fix training status of noise model of `HeteroskedasticNoise` after exceptions
2 parents 938d4f9 + 45b28a7 commit 090d6e1

File tree

2 files changed

+47
-7
lines changed

2 files changed

+47
-7
lines changed

gpytorch/likelihoods/noise_models.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -125,13 +125,15 @@ def forward(
125125
if noise is not None:
126126
return DiagLinearOperator(noise)
127127
training = self.noise_model.training # keep track of mode
128-
self.noise_model.eval() # we want the posterior prediction of the noise model
129-
with settings.detach_test_caches(False), settings.debug(False):
130-
if len(params) == 1 and not torch.is_tensor(params[0]):
131-
output = self.noise_model(*params[0])
132-
else:
133-
output = self.noise_model(*params)
134-
self.noise_model.train(training)
128+
try:
129+
self.noise_model.eval() # we want the posterior prediction of the noise model
130+
with settings.detach_test_caches(False), settings.debug(False):
131+
if len(params) == 1 and not torch.is_tensor(params[0]):
132+
output = self.noise_model(*params[0])
133+
else:
134+
output = self.noise_model(*params)
135+
finally:
136+
self.noise_model.train(training)
135137
if not isinstance(output, MultivariateNormal):
136138
raise NotImplementedError("Currently only noise models that return a MultivariateNormal are supported")
137139
# note: this also works with MultitaskMultivariateNormal, where this
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#!/usr/bin/env python3
2+
3+
import unittest
4+
5+
import torch
6+
7+
import gpytorch
8+
from gpytorch.likelihoods import HeteroskedasticNoise
9+
10+
11+
class NumericallyUnstableModelExample(gpytorch.models.GP):
12+
def __init__(self):
13+
super(NumericallyUnstableModelExample, self).__init__()
14+
self.fail_arithmetic = False
15+
16+
def train(self, mode=True):
17+
if mode:
18+
self.fail_arithmetic = False # reset on .train()
19+
super().train(mode=mode)
20+
21+
def forward(self, x):
22+
if self.fail_arithmetic:
23+
raise ArithmeticError()
24+
return gpytorch.distributions.MultivariateNormal(torch.tensor([-3.0]), torch.tensor([[2.0]]))
25+
26+
27+
class TestNoiseModels(unittest.TestCase):
28+
def test_heteroskedasticnoise_error(self):
29+
noise_model = NumericallyUnstableModelExample().to(torch.double)
30+
likelihood = HeteroskedasticNoise(noise_model)
31+
self.assertEqual(noise_model.training, True)
32+
self.assertEqual(likelihood.training, True)
33+
noise_model.fail_arithmetic = True
34+
test_x = torch.tensor([[3.0]])
35+
with self.assertRaises(ArithmeticError):
36+
likelihood(test_x)
37+
self.assertEqual(likelihood.training, True)
38+
likelihood(test_x)

0 commit comments

Comments
 (0)