File tree Expand file tree Collapse file tree 2 files changed +47
-7
lines changed
Expand file tree Collapse file tree 2 files changed +47
-7
lines changed Original file line number Diff line number Diff line change @@ -125,13 +125,15 @@ def forward(
125125 if noise is not None :
126126 return DiagLinearOperator (noise )
127127 training = self .noise_model .training # keep track of mode
128- self .noise_model .eval () # we want the posterior prediction of the noise model
129- with settings .detach_test_caches (False ), settings .debug (False ):
130- if len (params ) == 1 and not torch .is_tensor (params [0 ]):
131- output = self .noise_model (* params [0 ])
132- else :
133- output = self .noise_model (* params )
134- self .noise_model .train (training )
128+ try :
129+ self .noise_model .eval () # we want the posterior prediction of the noise model
130+ with settings .detach_test_caches (False ), settings .debug (False ):
131+ if len (params ) == 1 and not torch .is_tensor (params [0 ]):
132+ output = self .noise_model (* params [0 ])
133+ else :
134+ output = self .noise_model (* params )
135+ finally :
136+ self .noise_model .train (training )
135137 if not isinstance (output , MultivariateNormal ):
136138 raise NotImplementedError ("Currently only noise models that return a MultivariateNormal are supported" )
137139 # note: this also works with MultitaskMultivariateNormal, where this
Original file line number Diff line number Diff line change 1+ #!/usr/bin/env python3
2+
3+ import unittest
4+
5+ import torch
6+
7+ import gpytorch
8+ from gpytorch .likelihoods import HeteroskedasticNoise
9+
10+
11+ class NumericallyUnstableModelExample (gpytorch .models .GP ):
12+ def __init__ (self ):
13+ super (NumericallyUnstableModelExample , self ).__init__ ()
14+ self .fail_arithmetic = False
15+
16+ def train (self , mode = True ):
17+ if mode :
18+ self .fail_arithmetic = False # reset on .train()
19+ super ().train (mode = mode )
20+
21+ def forward (self , x ):
22+ if self .fail_arithmetic :
23+ raise ArithmeticError ()
24+ return gpytorch .distributions .MultivariateNormal (torch .tensor ([- 3.0 ]), torch .tensor ([[2.0 ]]))
25+
26+
27+ class TestNoiseModels (unittest .TestCase ):
28+ def test_heteroskedasticnoise_error (self ):
29+ noise_model = NumericallyUnstableModelExample ().to (torch .double )
30+ likelihood = HeteroskedasticNoise (noise_model )
31+ self .assertEqual (noise_model .training , True )
32+ self .assertEqual (likelihood .training , True )
33+ noise_model .fail_arithmetic = True
34+ test_x = torch .tensor ([[3.0 ]])
35+ with self .assertRaises (ArithmeticError ):
36+ likelihood (test_x )
37+ self .assertEqual (likelihood .training , True )
38+ likelihood (test_x )
You can’t perform that action at this time.
0 commit comments