|
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
|
7 | 7 | import torch |
8 | | -from botorch.posteriors.deterministic import DeterministicPosterior |
9 | 8 | from botorch.posteriors.gpytorch import GPyTorchPosterior |
10 | 9 | from botorch.posteriors.posterior_list import PosteriorList |
11 | 10 | from botorch.posteriors.torch import TorchPosterior |
12 | 11 | from botorch.posteriors.transformed import TransformedPosterior |
13 | 12 | from botorch.sampling.get_sampler import get_sampler |
14 | 13 | from botorch.sampling.list_sampler import ListSampler |
15 | 14 | from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler |
16 | | -from botorch.sampling.stochastic_samplers import StochasticSampler |
17 | 15 | from botorch.utils.testing import BotorchTestCase |
18 | 16 | from gpytorch.distributions import MultivariateNormal |
19 | 17 | from torch.distributions.gamma import Gamma |
|
22 | 20 | class TestGetSampler(BotorchTestCase): |
23 | 21 | def test_get_sampler(self): |
24 | 22 | # Basic usage w/ gpytorch posterior. |
25 | | - posterior = GPyTorchPosterior( |
| 23 | + mvn_posterior = GPyTorchPosterior( |
26 | 24 | distribution=MultivariateNormal(torch.rand(2), torch.eye(2)) |
27 | 25 | ) |
| 26 | + seed = 2 |
| 27 | + n_samples = 10 |
28 | 28 | sampler = get_sampler( |
29 | | - posterior=posterior, sample_shape=torch.Size([10]), seed=2 |
| 29 | + posterior=mvn_posterior, sample_shape=torch.Size([n_samples]), seed=seed |
30 | 30 | ) |
31 | 31 | self.assertIsInstance(sampler, SobolQMCNormalSampler) |
32 | | - self.assertEqual(sampler.seed, 2) |
33 | | - self.assertEqual(sampler.sample_shape, torch.Size([10])) |
| 32 | + self.assertEqual(sampler.seed, seed) |
| 33 | + self.assertEqual(sampler.sample_shape, torch.Size([n_samples])) |
34 | 34 |
|
35 | 35 | # Fallback to IID sampler. |
36 | | - posterior = GPyTorchPosterior( |
| 36 | + big_mvn_posterior = GPyTorchPosterior( |
37 | 37 | distribution=MultivariateNormal(torch.rand(22000), torch.eye(22000)) |
38 | 38 | ) |
39 | | - sampler = get_sampler(posterior=posterior, sample_shape=torch.Size([10])) |
| 39 | + sampler = get_sampler( |
| 40 | + posterior=big_mvn_posterior, sample_shape=torch.Size([n_samples]) |
| 41 | + ) |
40 | 42 | self.assertIsInstance(sampler, IIDNormalSampler) |
41 | | - self.assertEqual(sampler.sample_shape, torch.Size([10])) |
| 43 | + self.assertEqual(sampler.sample_shape, torch.Size([n_samples])) |
42 | 44 |
|
43 | 45 | # Transformed posterior. |
44 | 46 | tf_post = TransformedPosterior( |
45 | | - posterior=posterior, sample_transform=lambda X: X |
| 47 | + posterior=big_mvn_posterior, sample_transform=lambda X: X |
46 | 48 | ) |
47 | | - sampler = get_sampler(posterior=tf_post, sample_shape=torch.Size([10])) |
| 49 | + sampler = get_sampler(posterior=tf_post, sample_shape=torch.Size([n_samples])) |
48 | 50 | self.assertIsInstance(sampler, IIDNormalSampler) |
49 | | - self.assertEqual(sampler.sample_shape, torch.Size([10])) |
| 51 | + self.assertEqual(sampler.sample_shape, torch.Size([n_samples])) |
50 | 52 |
|
51 | | - # PosteriorList with transformed & deterministic. |
52 | | - post_list = PosteriorList( |
53 | | - tf_post, DeterministicPosterior(values=torch.rand(1, 2)) |
54 | | - ) |
| 53 | + # PosteriorList with transformed & original |
| 54 | + post_list = PosteriorList(tf_post, mvn_posterior) |
55 | 55 | sampler = get_sampler(posterior=post_list, sample_shape=torch.Size([5])) |
56 | 56 | self.assertIsInstance(sampler, ListSampler) |
57 | 57 | self.assertIsInstance(sampler.samplers[0], IIDNormalSampler) |
58 | | - self.assertIsInstance(sampler.samplers[1], StochasticSampler) |
| 58 | + self.assertIsInstance(sampler.samplers[1], SobolQMCNormalSampler) |
59 | 59 | for s in sampler.samplers: |
60 | 60 | self.assertEqual(s.sample_shape, torch.Size([5])) |
61 | 61 |
|
|
0 commit comments