Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ class BaseDistribution(nn.Module):
def __init__(self, bounds, device="cpu", dtype=torch.float64):
super().__init__()
self.dtype = dtype
# self.bounds = bounds
if isinstance(bounds, (list, np.ndarray)):
self.bounds = torch.tensor(bounds, dtype=dtype, device=device)
elif isinstance(bounds, torch.Tensor):
self.bounds = bounds
self.bounds = bounds.to(dtype=dtype, device=device)
else:
raise ValueError("Unsupported map specification")
raise ValueError("'bounds' must be a list, numpy array, or torch tensor.")

self.dim = self.bounds.shape[0]
self.device = device

Expand Down
103 changes: 103 additions & 0 deletions src/base_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import unittest
import torch
import numpy as np
from base import BaseDistribution, Uniform


class TestBaseDistribution(unittest.TestCase):
    """Unit tests for BaseDistribution construction and error handling."""

    def setUp(self):
        # Common fixtures shared by every test in this class.
        # bounds_tensor deliberately uses torch's default dtype (float32) so
        # the constructor's dtype conversion is exercised.
        self.bounds_list = [[0.0, 1.0], [2.0, 3.0]]
        self.bounds_tensor = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
        self.device = "cpu"
        self.dtype = torch.float64

    def test_init_with_list(self):
        base_dist = BaseDistribution(self.bounds_list, self.device, self.dtype)
        self.assertEqual(base_dist.bounds.tolist(), self.bounds_list)
        self.assertEqual(base_dist.dim, 2)
        self.assertEqual(base_dist.device, self.device)
        self.assertEqual(base_dist.dtype, self.dtype)

    def test_init_with_tensor(self):
        base_dist = BaseDistribution(self.bounds_tensor, self.device, self.dtype)
        # The constructor casts the input tensor to the requested dtype, so
        # compare against a dtype-matched copy: torch.equal returns False for
        # tensors of different dtypes even when all values agree.
        self.assertTrue(
            torch.equal(base_dist.bounds, self.bounds_tensor.to(dtype=self.dtype))
        )
        self.assertEqual(base_dist.bounds.dtype, self.dtype)
        self.assertEqual(base_dist.dim, 2)
        self.assertEqual(base_dist.device, self.device)
        self.assertEqual(base_dist.dtype, self.dtype)

    def test_init_with_invalid_bounds(self):
        # A string is not an accepted bounds specification.
        with self.assertRaises(ValueError):
            BaseDistribution("invalid_bounds", self.device, self.dtype)

    def test_sample_not_implemented(self):
        # sample() is abstract on the base class; subclasses must override it.
        base_dist = BaseDistribution(self.bounds_list, self.device, self.dtype)
        with self.assertRaises(NotImplementedError):
            base_dist.sample()


class TestUniform(unittest.TestCase):
    """Unit tests for the Uniform base distribution over a 2-dim box."""

    def setUp(self):
        # Common fixtures: the box [0, 1] x [2, 3]; bounds_tensor uses torch's
        # default dtype (float32) so the constructor's cast is exercised.
        self.bounds_list = [[0.0, 1.0], [2.0, 3.0]]
        self.bounds_tensor = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
        self.device = "cpu"
        self.dtype = torch.float64
        self.uniform_dist = Uniform(self.bounds_list, self.device, self.dtype)

    def _assert_within_bounds(self, samples):
        # Every coordinate must lie inside its [lower, upper] interval.
        self.assertTrue(torch.all(samples[:, 0] >= 0.0))
        self.assertTrue(torch.all(samples[:, 0] <= 1.0))
        self.assertTrue(torch.all(samples[:, 1] >= 2.0))
        self.assertTrue(torch.all(samples[:, 1] <= 3.0))

    def test_init_with_list(self):
        self.assertEqual(self.uniform_dist.bounds.tolist(), self.bounds_list)
        self.assertEqual(self.uniform_dist.dim, 2)
        self.assertEqual(self.uniform_dist.device, self.device)
        self.assertEqual(self.uniform_dist.dtype, self.dtype)

    def test_init_with_tensor(self):
        uniform_dist = Uniform(self.bounds_tensor, self.device, self.dtype)
        # The constructor casts bounds to the requested dtype, so compare
        # against a dtype-matched copy: torch.equal is dtype-sensitive and
        # returns False for float32-vs-float64 even when values agree.
        self.assertTrue(
            torch.equal(uniform_dist.bounds, self.bounds_tensor.to(dtype=self.dtype))
        )
        self.assertEqual(uniform_dist.bounds.dtype, self.dtype)
        self.assertEqual(uniform_dist.dim, 2)
        self.assertEqual(uniform_dist.device, self.device)
        self.assertEqual(uniform_dist.dtype, self.dtype)

    def test_sample_within_bounds(self):
        nsamples = 1000
        samples, log_detJ = self.uniform_dist.sample(nsamples)
        self.assertEqual(samples.shape, (nsamples, 2))
        self._assert_within_bounds(samples)
        self.assertEqual(log_detJ.shape, (nsamples,))
        # Both intervals have unit length, so log|det J| = log(1) + log(1) = 0.
        # zeros_like keeps the expected tensor's dtype in sync with log_detJ;
        # torch.allclose raises on mismatched dtypes.
        self.assertTrue(torch.allclose(log_detJ, torch.zeros_like(log_detJ)))

    def test_sample_with_single_sample(self):
        samples, log_detJ = self.uniform_dist.sample(1)
        self.assertEqual(samples.shape, (1, 2))
        self._assert_within_bounds(samples)
        self.assertEqual(log_detJ.shape, (1,))
        self.assertTrue(torch.allclose(log_detJ, torch.zeros_like(log_detJ)))

    def test_sample_with_zero_samples(self):
        # Degenerate request: zero samples should yield empty tensors with
        # the correct trailing dimensions.
        samples, log_detJ = self.uniform_dist.sample(0)
        self.assertEqual(samples.shape, (0, 2))
        self.assertEqual(log_detJ.shape, (0,))


# Allow running this file directly, e.g. `python src/base_test.py`.
# Note: pytest/codecov may report this line as uncovered — it only executes
# when the module is the entry point, not when collected by a test runner.
if __name__ == "__main__":
    unittest.main()

Check warning on line 103 in src/base_test.py

View check run for this annotation

Codecov / codecov/patch

src/base_test.py#L103

Added line #L103 was not covered by tests
16 changes: 9 additions & 7 deletions src/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ def __call__(self, f: Callable, **kwargs):
results = np.array([RAvg() for _ in range(f_size)])
for i in range(f_size):
_mean = values[:, i].mean().item()
_std = values[:, i].std().item() / self.nbatch**0.5
results[i].add(gvar.gvar(_mean, _std))
_var = values[:, i].var().item() / self.nbatch
results[i].update(_mean, _var, self.neval)
if f_size == 1:
return results[0]
else:
Expand Down Expand Up @@ -247,15 +247,17 @@ def one_step(current_y, current_x, current_weight, current_jac):
results_ref = RAvg()

mean_ref = refvalues.mean().item()
std_ref = refvalues.std().item() / self.nbatch**0.5
var_ref = refvalues.var().item() / self.nbatch

results_ref.add(gvar.gvar(mean_ref, std_ref))
results_ref.update(mean_ref, var_ref, self.neval)
for i in range(f_size):
_mean = values[:, i].mean().item()
_std = values[:, i].std().item() / self.nbatch**0.5
results[i].add(gvar.gvar(_mean, _std))
_var = values[:, i].var().item() / self.nbatch
results[i].update(_mean, _var, self.neval)

if f_size == 1:
return results[0] / results_ref * self._rangebounds.prod()
res = results[0] / results_ref * self._rangebounds.prod()
result = RAvg(itn_results=[res], sum_neval=self.neval)
return result
else:
return results / results_ref * self._rangebounds.prod().item()
Loading
Loading