Skip to content

Commit 14843fc

Browse files
authored
Merge pull request #7 from numericalEFT/api_refact2
Api refactor
2 parents 63fa50d + 7123f22 commit 14843fc

File tree

4 files changed

+156
-123
lines changed

4 files changed

+156
-123
lines changed

src/base.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@ class BaseDistribution(nn.Module):
99
Parameters do not depend of target variable (as is the case for a VAE encoder)
1010
"""
1111

12-
def __init__(self, bounds, device="cpu"):
12+
def __init__(self, bounds, device="cpu", dtype=torch.float64):
1313
super().__init__()
14+
self.dtype = dtype
1415
# self.bounds = bounds
1516
if isinstance(bounds, (list, np.ndarray)):
16-
self.bounds = torch.tensor(bounds, dtype=torch.float64, device=device)
17+
self.bounds = torch.tensor(bounds, dtype=dtype, device=device)
1718
else:
1819
raise ValueError("Unsupported map specification")
1920
self.dim = self.bounds.shape[0]
@@ -36,13 +37,14 @@ class Uniform(BaseDistribution):
3637
Multivariate uniform distribution
3738
"""
3839

39-
def __init__(self, bounds, device="cpu"):
40-
super().__init__(bounds, device)
40+
def __init__(self, bounds, device="cpu", dtype=torch.float64):
41+
super().__init__(bounds, device, dtype)
4142
self._rangebounds = self.bounds[:, 1] - self.bounds[:, 0]
4243

4344
def sample(self, nsamples=1, **kwargs):
4445
u = (
45-
torch.rand((nsamples, self.dim), device=self.device) * self._rangebounds
46+
torch.rand((nsamples, self.dim), device=self.device, dtype=self.dtype)
47+
* self._rangebounds
4648
+ self.bounds[:, 0]
4749
)
4850
log_detJ = torch.log(self._rangebounds).sum().repeat(nsamples)

src/integrators.py

Lines changed: 132 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,40 @@
11
from typing import Callable, Union, List, Tuple, Dict
22
import torch
33
from utils import RAvg
4-
from maps import Map, Affine, CompositeMap
4+
from maps import Map, Linear, CompositeMap
55
from base import Uniform
66
import gvar
77
import numpy as np
88

99

1010
class Integrator:
1111
"""
12-
Base class for all integrators.
12+
Base class for all integrators. This class is designed to handle integration tasks
13+
over a specified domain (bounds) using a sampling method (q0) and optional
14+
transformation maps.
1315
"""
1416

1517
def __init__(
1618
self,
17-
bounds: Union[List[Tuple[float, float]], np.ndarray],
19+
bounds,
1820
q0=None,
1921
maps=None,
2022
neval: int = 1000,
2123
nbatch: int = None,
2224
device="cpu",
23-
adapt=False,
25+
dtype=torch.float64,
2426
):
25-
self.adapt = adapt
27+
if not isinstance(bounds, (list, np.ndarray)):
28+
raise TypeError("bounds must be a list or a NumPy array.")
29+
self.dtype = dtype
2630
self.dim = len(bounds)
2731
if not q0:
28-
q0 = Uniform(bounds, device=device)
29-
self.bounds = torch.tensor(bounds, dtype=torch.float64, device=device)
32+
q0 = Uniform(bounds, device=device, dtype=dtype)
33+
self.bounds = torch.tensor(bounds, dtype=dtype, device=device)
3034
self.q0 = q0
35+
if maps:
36+
if not self.dtype == maps.dtype:
37+
raise ValueError("Float type of maps should be same as integrator.")
3138
self.maps = maps
3239
self.neval = neval
3340
if nbatch is None:
@@ -54,46 +61,41 @@ def sample(self, nsample, **kwargs):
5461
class MonteCarlo(Integrator):
5562
def __init__(
5663
self,
57-
bounds: Union[List[Tuple[float, float]], np.ndarray],
64+
bounds,
5865
q0=None,
5966
maps=None,
60-
nitn: int = 10,
6167
neval: int = 1000,
6268
nbatch: int = None,
6369
device="cpu",
64-
adapt=False,
70+
dtype=torch.float64,
6571
):
66-
super().__init__(bounds, q0, maps, neval, nbatch, device, adapt)
67-
self.nitn = nitn
72+
super().__init__(bounds, q0, maps, neval, nbatch, device, dtype)
6873

6974
def __call__(self, f: Callable, **kwargs):
7075
x, _ = self.sample(self.nbatch)
7176
f_values = f(x)
7277
f_size = len(f_values) if isinstance(f_values, (list, tuple)) else 1
73-
type_fval = f_values.dtype if f_size == 1 else type(f_values[0].dtype)
74-
75-
mean = torch.zeros(f_size, dtype=type_fval, device=self.device)
76-
var = torch.zeros(f_size, dtype=type_fval, device=self.device)
78+
# type_fval = f_values.dtype if f_size == 1 else type(f_values[0].dtype)
79+
# mean = torch.zeros(f_size, dtype=type_fval, device=self.device)
80+
# var = torch.zeros(f_size, dtype=type_fval, device=self.device)
7781
# var = torch.zeros((f_size, f_size), dtype=type_fval, device=self.device)
78-
79-
result = RAvg(weighted=self.adapt)
82+
mean = torch.zeros(f_size, dtype=self.dtype, device=self.device)
83+
var = torch.zeros(f_size, dtype=self.dtype, device=self.device)
84+
result = RAvg()
8085
epoch = self.neval // self.nbatch
8186

82-
for itn in range(self.nitn):
83-
mean[:] = 0
84-
var[:] = 0
85-
for _ in range(epoch):
86-
x, log_detJ = self.sample(self.nbatch)
87-
f_values = f(x)
88-
batch_results = self._multiply_by_jacobian(
89-
f_values, torch.exp(log_detJ)
90-
)
87+
mean[:] = 0
88+
var[:] = 0
89+
for _ in range(epoch):
90+
x, log_detJ = self.sample(self.nbatch)
91+
f_values = f(x)
92+
batch_results = self._multiply_by_jacobian(f_values, torch.exp(log_detJ))
9193

92-
mean += torch.mean(batch_results, dim=-1) / epoch
93-
var += torch.var(batch_results, dim=-1) / (self.neval * epoch)
94+
mean += torch.mean(batch_results, dim=-1) / epoch
95+
var += torch.var(batch_results, dim=-1) / (self.neval * epoch)
9496

95-
result.sum_neval += self.neval
96-
result.add(gvar.gvar(mean.item(), (var**0.5).item()))
97+
result.sum_neval += self.neval
98+
result.add(gvar.gvar(mean.item(), (var**0.5).item()))
9799
return result
98100

99101
def _multiply_by_jacobian(self, values, jac):
@@ -105,29 +107,48 @@ def _multiply_by_jacobian(self, values, jac):
105107
return values * jac
106108

107109

110+
def random_walk(dim, bounds, device, dtype, u, **kwargs):
111+
rangebounds = bounds[:, 1] - bounds[:, 0]
112+
step_size = kwargs.get("step_size", 0.2)
113+
step_sizes = rangebounds * step_size
114+
step = torch.empty(dim, device=device, dtype=dtype).uniform_(-1, 1) * step_sizes
115+
new_u = (u + step - bounds[:, 0]) % rangebounds + bounds[:, 0]
116+
return new_u
117+
118+
119+
def uniform(dim, bounds, device, dtype, u, **kwargs):
120+
rangebounds = bounds[:, 1] - bounds[:, 0]
121+
return torch.rand_like(u) * rangebounds + bounds[:, 0]
122+
123+
124+
def gaussian(dim, bounds, device, dtype, u, **kwargs):
125+
mean = kwargs.get("mean", torch.zeros_like(u))
126+
std = kwargs.get("std", torch.ones_like(u))
127+
return torch.normal(mean, std)
128+
129+
108130
class MCMC(MonteCarlo):
109131
def __init__(
110132
self,
111-
bounds: Union[List[Tuple[float, float]], np.ndarray],
133+
bounds,
112134
q0=None,
113135
maps=None,
114-
nitn: int = 10,
115136
neval=10000,
116137
nbatch=None,
117138
nburnin=500,
118139
device="cpu",
119-
adapt=False,
140+
dtype=torch.float64,
120141
):
121-
super().__init__(bounds, q0, maps, nitn, neval, nbatch, device, adapt)
142+
super().__init__(bounds, q0, maps, neval, nbatch, device, dtype)
122143
self.nburnin = nburnin
123144
if maps is None:
124-
self.maps = Affine([(0, 1)] * self.dim, device=device)
145+
self.maps = Linear([(0, 1)] * self.dim, device=device)
125146
self._rangebounds = self.bounds[:, 1] - self.bounds[:, 0]
126147

127148
def __call__(
128149
self,
129150
f: Callable,
130-
proposal_dist="uniform",
151+
proposal_dist: Callable = uniform,
131152
thinning=1,
132153
mix_rate=0.0,
133154
**kwargs,
@@ -146,84 +167,93 @@ def __call__(
146167
current_weight.masked_fill_(current_weight < epsilon, epsilon)
147168
# current_fval.masked_fill_(current_fval.abs() < epsilon, epsilon)
148169

149-
proposed_y = torch.empty_like(current_y)
150-
proposed_x = torch.empty_like(current_x)
151-
new_fval = torch.empty_like(current_fval)
152-
new_weight = torch.empty_like(current_weight)
170+
# proposed_y = torch.empty_like(current_y)
171+
# proposed_x = torch.empty_like(current_x)
172+
# new_fval = torch.empty_like(current_fval)
173+
# new_weight = torch.empty_like(current_weight)
153174

154175
f_size = len(current_fval) if isinstance(current_fval, (list, tuple)) else 1
155-
type_fval = current_fval.dtype if f_size == 1 else type(current_fval[0].dtype)
156-
mean = torch.zeros(f_size, dtype=type_fval, device=self.device)
176+
# type_fval = current_fval.dtype if f_size == 1 else type(current_fval[0].dtype)
177+
# mean = torch.zeros(f_size, dtype=type_fval, device=self.device)
178+
mean = torch.zeros(f_size, dtype=self.dtype, device=self.device)
157179
mean_ref = torch.zeros_like(mean)
158-
var = torch.zeros(f_size, dtype=type_fval, device=self.device)
180+
# var = torch.zeros(f_size, dtype=type_fval, device=self.device)
181+
var = torch.zeros(f_size, dtype=self.dtype, device=self.device)
159182
var_ref = torch.zeros_like(mean)
160183

161-
result = RAvg(weighted=self.adapt)
162-
result_ref = RAvg(weighted=self.adapt)
184+
result = RAvg()
185+
result_ref = RAvg()
163186

164187
epoch = self.neval // self.nbatch
165188
n_meas = 0
166-
for itn in range(self.nitn):
167-
for i in range(epoch):
168-
proposed_y[:] = self._propose(current_y, proposal_dist, **kwargs)
169-
proposed_x[:], new_jac = self.maps.forward(proposed_y)
170-
new_jac = torch.exp(new_jac)
171189

172-
new_fval[:] = f(proposed_x)
173-
new_weight = mix_rate / new_jac + (1 - mix_rate) * new_fval.abs()
190+
def _propose(current_y, current_fval, current_weight, current_jac):
191+
proposed_y = proposal_dist(
192+
self.dim, self.bounds, self.device, self.dtype, current_y, **kwargs
193+
)
194+
proposed_x, new_jac = self.maps.forward(proposed_y)
195+
new_jac = torch.exp(new_jac)
196+
197+
new_fval = f(proposed_x)
198+
new_weight = mix_rate / new_jac + (1 - mix_rate) * new_fval.abs()
174199

175-
acceptance_probs = new_weight / current_weight * new_jac / current_jac
200+
acceptance_probs = new_weight / current_weight * new_jac / current_jac
176201

177-
accept = (
178-
torch.rand(self.nbatch, dtype=torch.float64, device=self.device)
179-
<= acceptance_probs
180-
)
202+
accept = (
203+
torch.rand(self.nbatch, dtype=self.dtype, device=self.device)
204+
<= acceptance_probs
205+
)
181206

182-
current_y = torch.where(accept.unsqueeze(1), proposed_y, current_y)
183-
current_fval = torch.where(accept, new_fval, current_fval)
184-
current_weight = torch.where(accept, new_weight, current_weight)
185-
current_jac = torch.where(accept, new_jac, current_jac)
186-
187-
if i < self.nburnin and itn == 0:
188-
continue
189-
elif i % thinning == 0:
190-
n_meas += 1
191-
batch_results = current_fval / current_weight
192-
193-
mean += torch.mean(batch_results, dim=-1) / epoch
194-
var += torch.var(batch_results, dim=-1) / epoch
195-
196-
batch_results_ref = 1 / (current_jac * current_weight)
197-
mean_ref += torch.mean(batch_results_ref, dim=-1) / epoch
198-
var_ref += torch.var(batch_results_ref, dim=-1) / epoch
199-
200-
result.sum_neval += self.neval
201-
result.add(gvar.gvar(mean.item(), ((var / n_meas) ** 0.5).item()))
202-
result_ref.sum_neval += self.nbatch
203-
result_ref.add(
204-
gvar.gvar(mean_ref.item(), ((var_ref / n_meas) ** 0.5).item())
207+
current_y = torch.where(accept.unsqueeze(1), proposed_y, current_y)
208+
current_fval = torch.where(accept, new_fval, current_fval)
209+
current_weight = torch.where(accept, new_weight, current_weight)
210+
current_jac = torch.where(accept, new_jac, current_jac)
211+
return current_y, current_fval, current_weight, current_jac
212+
213+
for i in range(self.nburnin):
214+
current_y, current_fval, current_weight, current_jac = _propose(
215+
current_y, current_fval, current_weight, current_jac
205216
)
217+
for i in range(epoch // thinning):
218+
for j in range(thinning):
219+
current_y, current_fval, current_weight, current_jac = _propose(
220+
current_y, current_fval, current_weight, current_jac
221+
)
222+
n_meas += 1
223+
batch_results = current_fval / current_weight
224+
225+
mean += torch.mean(batch_results, dim=-1) / epoch
226+
var += torch.var(batch_results, dim=-1) / epoch
227+
228+
batch_results_ref = 1 / (current_jac * current_weight)
229+
mean_ref += torch.mean(batch_results_ref, dim=-1) / epoch
230+
var_ref += torch.var(batch_results_ref, dim=-1) / epoch
231+
232+
result.sum_neval += self.neval
233+
result.add(gvar.gvar(mean.item(), ((var / n_meas) ** 0.5).item()))
234+
result_ref.sum_neval += self.nbatch
235+
result_ref.add(gvar.gvar(mean_ref.item(), ((var_ref / n_meas) ** 0.5).item()))
206236

207237
return result / result_ref * self._rangebounds.prod()
208238

209-
def _propose(self, u, proposal_dist, **kwargs):
210-
if proposal_dist == "random_walk":
211-
step_size = kwargs.get("step_size", 0.2)
212-
step_sizes = self._rangebounds * step_size
213-
step = (
214-
torch.empty(self.dim, device=self.device).uniform_(-1, 1) * step_sizes
215-
)
216-
new_u = (u + step - self.bounds[:, 0]) % self._rangebounds + self.bounds[
217-
:, 0
218-
]
219-
return new_u
220-
# return (u + (torch.rand_like(u) - 0.5) * step_size) % 1.0
221-
elif proposal_dist == "uniform":
222-
# return torch.rand_like(u)
223-
return torch.rand_like(u) * self._rangebounds + self.bounds[:, 0]
224-
# elif proposal_dist == "gaussian":
225-
# mean = kwargs.get("mean", torch.zeros_like(u))
226-
# std = kwargs.get("std", torch.ones_like(u))
227-
# return torch.normal(mean, std)
228-
else:
229-
raise ValueError(f"Unknown proposal distribution: {proposal_dist}")
239+
# def _propose(self, u, proposal_dist, **kwargs):
240+
# if proposal_dist == "random_walk":
241+
# step_size = kwargs.get("step_size", 0.2)
242+
# step_sizes = self._rangebounds * step_size
243+
# step = (
244+
# torch.empty(self.dim, device=self.device).uniform_(-1, 1) * step_sizes
245+
# )
246+
# new_u = (u + step - self.bounds[:, 0]) % self._rangebounds + self.bounds[
247+
# :, 0
248+
# ]
249+
# return new_u
250+
# # return (u + (torch.rand_like(u) - 0.5) * step_size) % 1.0
251+
# elif proposal_dist == "uniform":
252+
# # return torch.rand_like(u)
253+
# return torch.rand_like(u) * self._rangebounds + self.bounds[:, 0]
254+
# # elif proposal_dist == "gaussian":
255+
# # mean = kwargs.get("mean", torch.zeros_like(u))
256+
# # std = kwargs.get("std", torch.ones_like(u))
257+
# # return torch.normal(mean, std)
258+
# else:
259+
# raise ValueError(f"Unknown proposal distribution: {proposal_dist}")

0 commit comments

Comments
 (0)