Skip to content

Commit 5e7ab79

Browse files
authored
Merge pull request #10 from numericalEFT/pchou
add unittests and improve several type checkings
2 parents 6bea2cc + c112593 commit 5e7ab79

File tree

10 files changed

+980
-38
lines changed

10 files changed

+980
-38
lines changed

src/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@ class BaseDistribution(nn.Module):
1212
def __init__(self, bounds, device="cpu", dtype=torch.float64):
1313
super().__init__()
1414
self.dtype = dtype
15-
# self.bounds = bounds
1615
if isinstance(bounds, (list, np.ndarray)):
1716
self.bounds = torch.tensor(bounds, dtype=dtype, device=device)
1817
elif isinstance(bounds, torch.Tensor):
19-
self.bounds = bounds
18+
self.bounds = bounds.to(dtype=dtype, device=device)
2019
else:
21-
raise ValueError("Unsupported map specification")
20+
raise ValueError("'bounds' must be a list, numpy array, or torch tensor.")
21+
2222
self.dim = self.bounds.shape[0]
2323
self.device = device
2424

src/base_test.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import unittest
2+
import torch
3+
import numpy as np
4+
from base import BaseDistribution, Uniform
5+
6+
7+
class TestBaseDistribution(unittest.TestCase):
8+
def setUp(self):
9+
# Common setup for all tests
10+
self.bounds_list = [[0.0, 1.0], [2.0, 3.0]]
11+
self.bounds_tensor = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
12+
self.device = "cpu"
13+
self.dtype = torch.float64
14+
15+
def test_init_with_list(self):
16+
base_dist = BaseDistribution(self.bounds_list, self.device, self.dtype)
17+
self.assertEqual(base_dist.bounds.tolist(), self.bounds_list)
18+
self.assertEqual(base_dist.dim, 2)
19+
self.assertEqual(base_dist.device, self.device)
20+
self.assertEqual(base_dist.dtype, self.dtype)
21+
22+
def test_init_with_tensor(self):
23+
base_dist = BaseDistribution(self.bounds_tensor, self.device, self.dtype)
24+
self.assertTrue(torch.equal(base_dist.bounds, self.bounds_tensor))
25+
self.assertEqual(base_dist.dim, 2)
26+
self.assertEqual(base_dist.device, self.device)
27+
self.assertEqual(base_dist.dtype, self.dtype)
28+
29+
def test_init_with_invalid_bounds(self):
30+
with self.assertRaises(ValueError):
31+
BaseDistribution("invalid_bounds", self.device, self.dtype)
32+
33+
def test_sample_not_implemented(self):
34+
base_dist = BaseDistribution(self.bounds_list, self.device, self.dtype)
35+
with self.assertRaises(NotImplementedError):
36+
base_dist.sample()
37+
38+
def tearDown(self):
39+
# Common teardown for all tests
40+
pass
41+
42+
43+
class TestUniform(unittest.TestCase):
44+
def setUp(self):
45+
# Common setup for all tests
46+
self.bounds_list = [[0.0, 1.0], [2.0, 3.0]]
47+
self.bounds_tensor = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
48+
self.device = "cpu"
49+
self.dtype = torch.float64
50+
self.uniform_dist = Uniform(self.bounds_list, self.device, self.dtype)
51+
52+
def test_init_with_list(self):
53+
self.assertEqual(self.uniform_dist.bounds.tolist(), self.bounds_list)
54+
self.assertEqual(self.uniform_dist.dim, 2)
55+
self.assertEqual(self.uniform_dist.device, self.device)
56+
self.assertEqual(self.uniform_dist.dtype, self.dtype)
57+
58+
def test_init_with_tensor(self):
59+
uniform_dist = Uniform(self.bounds_tensor, self.device, self.dtype)
60+
self.assertTrue(torch.equal(uniform_dist.bounds, self.bounds_tensor))
61+
self.assertEqual(uniform_dist.dim, 2)
62+
self.assertEqual(uniform_dist.device, self.device)
63+
self.assertEqual(uniform_dist.dtype, self.dtype)
64+
65+
def test_sample_within_bounds(self):
66+
nsamples = 1000
67+
samples, log_detJ = self.uniform_dist.sample(nsamples)
68+
self.assertEqual(samples.shape, (nsamples, 2))
69+
self.assertTrue(torch.all(samples[:, 0] >= 0.0))
70+
self.assertTrue(torch.all(samples[:, 0] <= 1.0))
71+
self.assertTrue(torch.all(samples[:, 1] >= 2.0))
72+
self.assertTrue(torch.all(samples[:, 1] <= 3.0))
73+
self.assertEqual(log_detJ.shape, (nsamples,))
74+
self.assertTrue(
75+
torch.allclose(
76+
log_detJ, torch.tensor([np.log(1.0) + np.log(1.0)] * nsamples)
77+
)
78+
)
79+
80+
def test_sample_with_single_sample(self):
81+
samples, log_detJ = self.uniform_dist.sample(1)
82+
self.assertEqual(samples.shape, (1, 2))
83+
self.assertTrue(torch.all(samples[:, 0] >= 0.0))
84+
self.assertTrue(torch.all(samples[:, 0] <= 1.0))
85+
self.assertTrue(torch.all(samples[:, 1] >= 2.0))
86+
self.assertTrue(torch.all(samples[:, 1] <= 3.0))
87+
self.assertEqual(log_detJ.shape, (1,))
88+
self.assertTrue(
89+
torch.allclose(log_detJ, torch.tensor([np.log(1.0) + np.log(1.0)]))
90+
)
91+
92+
def test_sample_with_zero_samples(self):
93+
samples, log_detJ = self.uniform_dist.sample(0)
94+
self.assertEqual(samples.shape, (0, 2))
95+
self.assertEqual(log_detJ.shape, (0,))
96+
97+
def tearDown(self):
98+
# Common teardown for all tests
99+
pass
100+
101+
102+
if __name__ == "__main__":
103+
unittest.main()

src/integrators.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,8 @@ def __call__(self, f: Callable, **kwargs):
103103
results = np.array([RAvg() for _ in range(f_size)])
104104
for i in range(f_size):
105105
_mean = values[:, i].mean().item()
106-
_std = values[:, i].std().item() / self.nbatch**0.5
107-
results[i].add(gvar.gvar(_mean, _std))
106+
_var = values[:, i].var().item() / self.nbatch
107+
results[i].update(_mean, _var, self.neval)
108108
if f_size == 1:
109109
return results[0]
110110
else:
@@ -247,15 +247,17 @@ def one_step(current_y, current_x, current_weight, current_jac):
247247
results_ref = RAvg()
248248

249249
mean_ref = refvalues.mean().item()
250-
std_ref = refvalues.std().item() / self.nbatch**0.5
250+
var_ref = refvalues.var().item() / self.nbatch
251251

252-
results_ref.add(gvar.gvar(mean_ref, std_ref))
252+
results_ref.update(mean_ref, var_ref, self.neval)
253253
for i in range(f_size):
254254
_mean = values[:, i].mean().item()
255-
_std = values[:, i].std().item() / self.nbatch**0.5
256-
results[i].add(gvar.gvar(_mean, _std))
255+
_var = values[:, i].var().item() / self.nbatch
256+
results[i].update(_mean, _var, self.neval)
257257

258258
if f_size == 1:
259-
return results[0] / results_ref * self._rangebounds.prod()
259+
res = results[0] / results_ref * self._rangebounds.prod()
260+
result = RAvg(itn_results=[res], sum_neval=self.neval)
261+
return result
260262
else:
261263
return results / results_ref * self._rangebounds.prod().item()

0 commit comments

Comments
 (0)