Skip to content

Commit f00936b

Browse files
Balandatfacebook-github-bot
authored andcommitted
Move mock utils from test folder to botorch.utils.mock (#171)
Summary: Pull Request resolved: #171 This makes it easier to import MockModel and MockPosterior for testing purposes. Reviewed By: danielrjiang Differential Revision: D15795194 fbshipit-source-id: 5ba3acbf908850dd5e7c2125e4b121e5dfd7165d
1 parent ec4aca7 commit f00936b

File tree

6 files changed

+26
-13
lines changed

6 files changed

+26
-13
lines changed

test/mock.py renamed to botorch/utils/mock.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
from typing import List, Optional
77

88
import torch
9-
from botorch.models.model import Model
10-
from botorch.posteriors import Posterior
119
from torch import Tensor
1210

11+
from ..models.model import Model
12+
from ..posteriors import Posterior
13+
1314

1415
EMPTY_SIZE = torch.Size()
1516

test/acquisition/test_analytic.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717
)
1818
from botorch.exceptions import UnsupportedError
1919
from botorch.models import FixedNoiseGP, SingleTaskGP
20-
21-
from ..mock import MockModel, MockPosterior
20+
from botorch.utils.mock import MockModel, MockPosterior
2221

2322

2423
NEI_NOISE = [-0.099, -0.004, 0.227, -0.182, 0.018, 0.334, -0.270, 0.156, -0.237, 0.052]

test/acquisition/test_monte_carlo.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,7 @@
1414
qUpperConfidenceBound,
1515
)
1616
from botorch.acquisition.sampler import IIDNormalSampler, SobolQMCNormalSampler
17-
18-
from ..mock import MockModel, MockPosterior
19-
20-
21-
# TODO: T41739913 Implement tests for all MCAcquisitionFunctions
17+
from botorch.utils.mock import MockModel, MockPosterior
2218

2319

2420
class TestMCAcquisitionFunction(unittest.TestCase):

test/acquisition/test_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,9 @@
99
from botorch.acquisition import monte_carlo, utils
1010
from botorch.acquisition.objective import MCAcquisitionObjective
1111
from botorch.acquisition.sampler import IIDNormalSampler, SobolQMCNormalSampler
12+
from botorch.utils.mock import MockModel, MockPosterior
1213
from torch import Tensor
1314

14-
from ..mock import MockModel, MockPosterior
15-
1615

1716
class DummyMCObjective(MCAcquisitionObjective):
1817
def forward(self, samples: Tensor) -> Tensor:

test/exceptions/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#! /usr/bin/env python3
2+
3+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,40 @@
55
import unittest
66

77
import torch
8-
9-
from .mock import MockModel, MockPosterior
8+
from botorch.utils.mock import MockModel, MockPosterior
109

1110

1211
class TestMock(unittest.TestCase):
1312
def test_MockPosterior(self):
13+
# test basic logic
14+
mp = MockPosterior()
15+
self.assertEqual(mp.device.type, "cpu")
16+
self.assertEqual(mp.dtype, torch.float32)
17+
self.assertEqual(mp.event_shape, torch.Size())
18+
self.assertEqual(
19+
MockPosterior(variance=torch.rand(2)).event_shape, torch.Size([2])
20+
)
21+
# test passing in tensors
1422
mean = torch.rand(2)
1523
variance = torch.eye(2)
1624
samples = torch.rand(1, 2)
1725
mp = MockPosterior(mean=mean, variance=variance, samples=samples)
26+
self.assertEqual(mp.device.type, "cpu")
27+
self.assertEqual(mp.dtype, torch.float32)
1828
self.assertTrue(torch.equal(mp.mean, mean))
1929
self.assertTrue(torch.equal(mp.variance, variance))
2030
self.assertTrue(torch.all(mp.sample() == samples.unsqueeze(0)))
2131
self.assertTrue(
2232
torch.all(mp.sample(torch.Size([2])) == samples.repeat(2, 1, 1))
2333
)
34+
with self.assertRaises(RuntimeError):
35+
mp.sample(sample_shape=torch.Size([2]), base_samples=torch.rand(3))
2436

2537
def test_MockModel(self):
2638
mp = MockPosterior()
2739
mm = MockModel(mp)
2840
X = torch.empty(0)
2941
self.assertEqual(mm.posterior(X), mp)
42+
self.assertEqual(mm.num_outputs, 0)
43+
mm.state_dict()
44+
mm.load_state_dict()

0 commit comments

Comments
 (0)