Skip to content

Commit aeb5e23

Browse files
authored
Fix error messages for ApproximateGP.get_fantasy_model (#2374)
[Fixes #2370]
1 parent d19f52e commit aeb5e23

10 files changed

+42
-24
lines changed

gpytorch/variational/_variational_strategy.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -264,32 +264,38 @@ def get_fantasy_model(
264264
# whitened / unwhitened variational strategies
265265
if not self.has_fantasy_strategy:
266266
raise NotImplementedError(
267-
"No fantasy model support for ",
268-
self.__name__,
269-
". Only VariationalStrategy and UnwhitenedVariationalStrategy are currently supported.",
267+
f"No fantasy model support for {self.__class__.__name__}. "
268+
"Only VariationalStrategy and UnwhitenedVariationalStrategy are currently supported."
270269
)
270+
else:
271+
from . import CholeskyVariationalDistribution # Circular import otherwise
272+
273+
if not isinstance(self._variational_distribution, CholeskyVariationalDistribution):
274+
raise NotImplementedError(
275+
"Fantasy models are only support for variational models with CholeskyVariationalDistribution."
276+
)
277+
271278
if not isinstance(self.model.likelihood, GaussianLikelihood):
272279
raise NotImplementedError(
273-
"No fantasy model support for ",
274-
self.model.likelihood,
275-
". Only GaussianLikelihoods are currently supported.",
280+
f"No fantasy model support for {self.model.likelihood.__class__.__name__}. "
281+
"Only GaussianLikelihoods are currently supported."
276282
)
277283
# we assume that either the user has given the model a mean_module and a covar_module
278284
# or that it will be passed into the get_fantasy_model function. we check for these.
279285
if mean_module is None:
280286
mean_module = getattr(self.model, "mean_module", None)
281287
if mean_module is None:
282288
raise ModuleNotFoundError(
283-
"Either you must provide a mean_module as input to get_fantasy_model",
284-
"or it must be an attribute of the model called mean_module.",
289+
"Either you must provide a mean_module as input to get_fantasy_model "
290+
"or it must be an attribute of the model called mean_module."
285291
)
286292
if covar_module is None:
287293
covar_module = getattr(self.model, "covar_module", None)
288294
if covar_module is None:
289295
# raise an error
290296
raise ModuleNotFoundError(
291-
"Either you must provide a covar_module as input to get_fantasy_model",
292-
"or it must be an attribute of the model called covar_module.",
297+
"Either you must provide a covar_module as input to get_fantasy_model "
298+
"or it must be an attribute of the model called covar_module."
293299
)
294300

295301
# first we construct an exact model over the inducing points with the inducing covariance

gpytorch/variational/nearest_neighbor_variational_strategy.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
from torch import LongTensor, Tensor
1010

1111
from ..distributions import MultivariateNormal
12-
from ..models import ApproximateGP
12+
from ..models import ApproximateGP, ExactGP
13+
from ..module import Module
1314
from ..utils.errors import CachingError
1415
from ..utils.memoize import add_to_cache, cached, pop_from_cache
1516
from ..utils.nearest_neighbors import NNUtil
@@ -238,6 +239,19 @@ def forward(
238239
# Return the distribution
239240
return MultivariateNormal(predictive_mean, DiagLinearOperator(predictive_var))
240241

242+
def get_fantasy_model(
243+
self,
244+
inputs: Tensor,
245+
targets: Tensor,
246+
mean_module: Optional[Module] = None,
247+
covar_module: Optional[Module] = None,
248+
**kwargs,
249+
) -> ExactGP:
250+
raise NotImplementedError(
251+
f"No fantasy model support for {self.__class__.__name__}. "
252+
"Only VariationalStrategy and UnwhitenedVariationalStrategy are currently supported."
253+
)
254+
241255
def _set_training_iterator(self) -> None:
242256
self._training_indices_iter = 0
243257
training_indices = torch.randperm(self.M - self.k, device=self.inducing_points.device) + self.k

test/variational/test_ciq_variational_strategy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def test_eval_iteration(self, *args, **kwargs):
3838
self.assertEqual(ciq_mock.call_count, 2) # One for each evaluation call
3939

4040
def test_fantasy_call(self, *args, **kwargs):
41-
with self.assertRaises(AttributeError):
41+
with self.assertRaises(NotImplementedError):
4242
super().test_fantasy_call(*args, **kwargs)
4343

4444

test/variational/test_grid_interpolation_variational_strategy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def test_eval_iteration(self, *args, **kwargs):
7676
self.assertFalse(ciq_mock.called)
7777

7878
def test_fantasy_call(self, *args, **kwargs):
79-
with self.assertRaises(AttributeError):
79+
with self.assertRaises(NotImplementedError):
8080
super().test_fantasy_call(*args, **kwargs)
8181

8282

test/variational/test_independent_multitask_variational_strategy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def test_eval_iteration(self, *args, expected_batch_shape=None, **kwargs):
6161
super().test_eval_iteration(*args, expected_batch_shape=expected_batch_shape, **kwargs)
6262

6363
def test_fantasy_call(self, *args, **kwargs):
64-
with self.assertRaises(AttributeError):
64+
with self.assertRaises(NotImplementedError):
6565
super().test_fantasy_call(*args, **kwargs)
6666

6767

test/variational/test_lmc_variational_strategy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def test_eval_iteration(self, *args, expected_batch_shape=None, **kwargs):
6969
self.assertFalse(ciq_mock.called)
7070

7171
def test_fantasy_call(self, *args, **kwargs):
72-
with self.assertRaises(AttributeError):
72+
with self.assertRaises(NotImplementedError):
7373
super().test_fantasy_call(*args, **kwargs)
7474

7575

test/variational/test_nearest_neighbor_variational_strategy.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,13 @@ def _make_model_and_likelihood(
4242
distribution_cls=gpytorch.variational.MeanFieldVariationalDistribution,
4343
constant_mean=True,
4444
):
45-
class _VNNGPRegressionModel(gpytorch.models.GP):
45+
class _VNNGPRegressionModel(gpytorch.models.ApproximateGP):
4646
def __init__(self, inducing_points, k, training_batch_size):
47-
super(_VNNGPRegressionModel, self).__init__()
48-
4947
variational_distribution = distribution_cls(num_inducing, batch_shape=batch_shape)
50-
51-
self.variational_strategy = strategy_cls(
48+
variational_strategy = strategy_cls(
5249
self, inducing_points, variational_distribution, k=k, training_batch_size=training_batch_size
5350
)
51+
super().__init__(variational_strategy)
5452

5553
if constant_mean:
5654
self.mean_module = gpytorch.means.ConstantMean()
@@ -252,7 +250,7 @@ def test_eval_larger_pred_batch(self):
252250
)
253251

254252
def test_fantasy_call(self, *args, **kwargs):
255-
with self.assertRaises(AttributeError):
253+
with self.assertRaises(NotImplementedError):
256254
super().test_fantasy_call(*args, **kwargs)
257255

258256

test/variational/test_orthogonally_decoupled_variational_strategy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def test_eval_iteration(self, *args, **kwargs):
5858
self.assertEqual(cholesky_mock.call_count, 1) # One to compute cache, that's it!
5959

6060
def test_fantasy_call(self, *args, **kwargs):
61-
with self.assertRaises(AttributeError):
61+
with self.assertRaises(NotImplementedError):
6262
super().test_fantasy_call(*args, **kwargs)
6363

6464

test/variational/test_unwhitened_variational_strategy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_fantasy_call(self, *args, **kwargs):
4545
if self.distribution_cls is gpytorch.variational.CholeskyVariationalDistribution:
4646
return super().test_fantasy_call(*args, **kwargs)
4747

48-
with self.assertRaises(AttributeError):
48+
with self.assertRaises(NotImplementedError):
4949
super().test_fantasy_call(*args, **kwargs)
5050

5151

test/variational/test_variational_strategy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def test_fantasy_call(self, *args, **kwargs):
4242
if self.distribution_cls is gpytorch.variational.CholeskyVariationalDistribution:
4343
return super().test_fantasy_call(*args, **kwargs)
4444

45-
with self.assertRaises(AttributeError):
45+
with self.assertRaises(NotImplementedError):
4646
super().test_fantasy_call(*args, **kwargs)
4747

4848

0 commit comments

Comments
 (0)