22 changes: 15 additions & 7 deletions src/integrators.py
@@ -27,7 +27,8 @@ def __init__(
self.dtype = dtype
if maps:
if not self.dtype == maps.dtype:
raise ValueError("Float type of maps should be same as integrator.")
raise ValueError(
"Float type of maps should be same as integrator.")
self.bounds = maps.bounds
else:
if not isinstance(bounds, (list, np.ndarray)):
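For reference, the dtype guard reformatted above still enforces that the map's float type match the integrator's. A minimal sketch of the failure mode (an assumption-laden example: it presumes `Vegas` and `MonteCarlo` are importable from `src.maps` and `src.integrators`, consistent with the test script below):

```python
import torch
from src.maps import Vegas
from src.integrators import MonteCarlo

vmap = Vegas([(0, 1)], dtype=torch.float32)
# dtype mismatch between map and integrator:
# raises ValueError("Float type of maps should be same as integrator.")
MonteCarlo(maps=vmap, neval=10**5, nbatch=10**4, dtype=torch.float64)
```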
@@ -72,6 +73,7 @@ def __init__(
device="cpu",
dtype=torch.float64,
):
+
super().__init__(maps, bounds, q0, neval, nbatch, device, dtype)

def __call__(self, f: Callable, **kwargs):
@@ -91,12 +93,14 @@ def __call__(self, f: Callable, **kwargs):
)

epoch = self.neval // self.nbatch
-        values = torch.zeros((self.nbatch, f_size), dtype=type_fval, device=self.device)
+        values = torch.zeros((self.nbatch, f_size),
+                             dtype=type_fval, device=self.device)

for iepoch in range(epoch):
x, log_detJ = self.sample(self.nbatch)
f_values = f(x)
-            batch_results = self._multiply_by_jacobian(f_values, torch.exp(log_detJ))
+            batch_results = self._multiply_by_jacobian(
+                f_values, torch.exp(log_detJ))

values += batch_results / epoch

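For context, the loop above is the standard mapped Monte Carlo estimator: with $u$ uniform on the unit hypercube and $x = T(u)$ the mapped sample,

$$\int f(x)\,dx \;=\; \mathbb{E}_u\!\left[f(T(u))\,\lvert\det J_T(u)\rvert\right] \;\approx\; \frac{1}{N}\sum_{i=1}^{N} f(x_i)\,e^{\log\det J_i},$$

which is why each batch of `f_values` is multiplied by `torch.exp(log_detJ)` and the batches are averaged over `epoch`.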
@@ -123,7 +127,8 @@ def random_walk(dim, bounds, device, dtype, u, **kwargs):
rangebounds = bounds[:, 1] - bounds[:, 0]
step_size = kwargs.get("step_size", 0.2)
step_sizes = rangebounds * step_size
-    step = torch.empty(dim, device=device, dtype=dtype).uniform_(-1, 1) * step_sizes
+    step = torch.empty(dim, device=device,
+                       dtype=dtype).uniform_(-1, 1) * step_sizes
new_u = (u + step - bounds[:, 0]) % rangebounds + bounds[:, 0]
return new_u

@@ -191,7 +196,8 @@ def _integrand(x):
)
type_fval = current_fval.dtype

-        current_weight = mix_rate / current_jac + (1 - mix_rate) * current_fval.abs()
+        current_weight = mix_rate / current_jac + \
+            (1 - mix_rate) * current_fval.abs()
current_weight.masked_fill_(current_weight < epsilon, epsilon)

n_meas = epoch // thinning
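The clamped quantity above mixes the two natural sampling weights, with $r$ the `mix_rate` argument:

$$w(y) \;=\; \frac{r}{J(y)} + (1 - r)\,\bigl|f(x(y))\bigr|,$$

floored at `epsilon` so later acceptance ratios stay finite. Read as a sampling target, $r = 1$ effectively samples uniformly in the mapped $y$-space (pure VEGAS-style importance sampling), while $r = 0$ samples proportionally to $|f|$.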
Expand Down Expand Up @@ -226,8 +232,10 @@ def one_step(current_y, current_x, current_weight, current_jac):
current_y, current_x, current_weight, current_jac
)

-        values = torch.zeros((self.nbatch, f_size), dtype=type_fval, device=self.device)
-        refvalues = torch.zeros(self.nbatch, dtype=type_fval, device=self.device)
+        values = torch.zeros((self.nbatch, f_size),
+                             dtype=type_fval, device=self.device)
+        refvalues = torch.zeros(
+            self.nbatch, dtype=type_fval, device=self.device)

for imeas in range(n_meas):
for j in range(thinning):
46 changes: 29 additions & 17 deletions src/maps.py
@@ -15,7 +15,8 @@ def __init__(self, bounds, device="cpu", dtype=torch.float64):
elif isinstance(bounds, torch.Tensor):
self.bounds = bounds.to(dtype=dtype, device=device)
else:
raise ValueError("'bounds' must be a list, numpy array, or torch tensor.")
raise ValueError(
"'bounds' must be a list, numpy array, or torch tensor.")

self.dim = self.bounds.shape[0]
self.device = device
@@ -29,9 +30,11 @@ def inverse(self, x):


class CompositeMap(Map):
-    def __init__(self, maps, device="cpu", dtype=torch.float64):
+    def __init__(self, maps, device="cpu", dtype=None):
if not maps:
raise ValueError("Maps can not be empty.")
+        if dtype is None:
+            dtype = maps[-1].dtype
super().__init__(maps[-1].bounds, device, dtype)
self.maps = maps

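With this change `CompositeMap` no longer silently promotes to float64: when `dtype` is omitted it inherits the float type of the last map in the chain. A minimal sketch (assuming `Vegas` and `CompositeMap` come from `src.maps`):

```python
import torch
from src.maps import Vegas, CompositeMap

m1 = Vegas([(0, 1)], ninc=100, dtype=torch.float32)
m2 = Vegas([(0, 1)], ninc=100, dtype=torch.float32)

comp = CompositeMap([m1, m2])        # dtype=None -> inferred from maps[-1]
assert comp.dtype == torch.float32   # previously forced to float64
```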
@@ -71,17 +74,20 @@ def __init__(self, bounds, ninc=1000, alpha=0.5, device="cpu", dtype=torch.float

# Ensure ninc is a tensor of appropriate shape and type
if isinstance(ninc, int):
-            self.ninc = torch.full((self.dim,), ninc, dtype=torch.int32, device=device)
+            self.ninc = torch.full(
+                (self.dim,), ninc, dtype=torch.int32, device=device)
elif isinstance(ninc, (list, np.ndarray)):
self.ninc = torch.tensor(ninc, dtype=torch.int32, device=device)
elif isinstance(ninc, torch.Tensor):
self.ninc = ninc.to(dtype=torch.int32, device=device)
else:
raise ValueError("'ninc' must be an int, list, numpy array, or torch tensor.")

raise ValueError(
"'ninc' must be an int, list, numpy array, or torch tensor.")

# Ensure ninc has the correct shape
if self.ninc.shape != (self.dim,):
raise ValueError(f"'ninc' must be a scalar or a 1D array of length {self.dim}.")
raise ValueError(
f"'ninc' must be a scalar or a 1D array of length {self.dim}.")

self.make_uniform()
self.alpha = alpha
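The branches above normalize every accepted `ninc` form to a `(dim,)` int32 tensor on the target device. A short sketch of what the constructor accepts (again assuming `Vegas` from `src.maps`):

```python
from src.maps import Vegas

Vegas([(0, 1)] * 3, ninc=100)             # scalar -> tensor([100, 100, 100])
Vegas([(0, 1)] * 3, ninc=[50, 100, 200])  # one increment count per dimension
Vegas([(0, 1)] * 3, ninc="100")           # unsupported type -> ValueError
```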
@@ -165,13 +171,14 @@ def adapt(self, alpha=0.0):
"""
new_grid = torch.empty(
(self.dim, torch.max(self.ninc) + 1),
-            dtype=torch.float64,
+            dtype=self.dtype,
device=self.device,
)
-        avg_f = torch.ones(self.inc.shape[1], dtype=torch.float64, device=self.device)
+        avg_f = torch.ones(
+            self.inc.shape[1], dtype=self.dtype, device=self.device)
if alpha > 0:
tmp_f = torch.empty(
-                self.inc.shape[1], dtype=torch.float64, device=self.device
+                self.inc.shape[1], dtype=self.dtype, device=self.device
)
for d in range(self.dim):
ninc = self.ninc[d]
@@ -183,11 +190,12 @@ def adapt(self, alpha=0.0):
if alpha > 0: # smooth
tmp_f[0] = torch.abs(7.0 * avg_f[0] + avg_f[1]) / 8.0
tmp_f[ninc - 1] = (
-                    torch.abs(7.0 * avg_f[ninc - 1] + avg_f[ninc - 2]) / 8.0
+                    torch.abs(7.0 * avg_f[ninc - 1] +
+                              avg_f[ninc - 2]) / 8.0
)
-                tmp_f[1 : ninc - 1] = (
+                tmp_f[1: ninc - 1] = (
torch.abs(
-                        6.0 * avg_f[1 : ninc - 1]
+                        6.0 * avg_f[1: ninc - 1]
+ avg_f[: ninc - 2]
+ avg_f[2:ninc]
)
@@ -218,17 +226,19 @@ def adapt(self, alpha=0.0):
else:
acc_f -= f_ninc
new_grid[d, i] = (
-                        self.grid[d, j + 1] - (acc_f / avg_f[j]) * self.inc[d, j]
+                        self.grid[d, j + 1] -
+                        (acc_f / avg_f[j]) * self.inc[d, j]
)
continue
break
self.grid = new_grid
self.inc = torch.empty(
-            (self.dim, self.grid.shape[1] - 1), dtype=torch.float64, device=self.device
+            (self.dim, self.grid.shape[1] - 1), dtype=self.dtype, device=self.device
)
for d in range(self.dim):
self.inc[d, : self.ninc[d]] = (
-                self.grid[d, 1 : self.ninc[d] + 1] - self.grid[d, : self.ninc[d]]
+                self.grid[d, 1: self.ninc[d] + 1] -
+                self.grid[d, : self.ninc[d]]
)
self.clear()

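The `adapt` hunks above also show the damping applied before the grid is re-binned: the per-increment averages are smoothed with a three-point kernel,

$$\tilde f_0 = \frac{7\bar f_0 + \bar f_1}{8},\qquad \tilde f_i = \frac{\bar f_{i-1} + 6\bar f_i + \bar f_{i+1}}{8},\qquad \tilde f_{n-1} = \frac{\bar f_{n-2} + 7\bar f_{n-1}}{8},$$

which keeps a single noisy increment from dragging grid nodes too far in one adaptation step.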
@@ -249,7 +259,8 @@ def make_uniform(self):
device=self.device,
)
self.inc[d, : self.ninc[d]] = (
-                self.grid[d, 1 : self.ninc[d] + 1] - self.grid[d, : self.ninc[d]]
+                self.grid[d, 1: self.ninc[d] + 1] -
+                self.grid[d, : self.ninc[d]]
)
self.clear()

@@ -305,7 +316,8 @@ def inverse(self, x):
jac = torch.ones(x.shape[0], device=x.device)
for d in range(self.dim):
ninc = self.ninc[d]
-            iu = torch.searchsorted(self.grid[d, :], x[:, d].contiguous(), right=True)
+            iu = torch.searchsorted(
+                self.grid[d, :], x[:, d].contiguous(), right=True)

mask_valid = (iu > 0) & (iu <= ninc)
mask_lower = iu <= 0
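Since `adapt` now allocates `new_grid`, `avg_f`, and `inc` with `self.dtype` instead of hard-coded `torch.float64`, a float32 map should stay float32 through adaptation. A minimal check (a sketch, assuming `Vegas` from `src.maps` and that `adapt` can run on a fresh map, as the `avg_f = torch.ones(...)` default above suggests):

```python
import torch
from src.maps import Vegas

vmap = Vegas([(0, 1)] * 2, ninc=50, dtype=torch.float32)
vmap.adapt(alpha=0.5)
assert vmap.grid.dtype == torch.float32  # adapt() promoted this to float64 before the fix
assert vmap.inc.dtype == torch.float32
```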
44 changes: 37 additions & 7 deletions src/vegas_test.py
@@ -6,44 +6,70 @@

# set_seed(42)
# device = get_device()
device = torch.device("cpu")
device = torch.device("mps")
# device = torch.device("cpu")
dtype = torch.float32


+def reset_nan_to_zero(tensor):
+    """
+    Resets NaN values in a tensor to zero in-place.
+
+    Args:
+        tensor: The PyTorch tensor to modify.
+    """
+    mask = torch.isnan(tensor)  # Boolean mask, True where an entry is NaN
+    tensor[mask] = 0  # Use the mask to zero out the NaN entries in-place

def integrand_list1(x):
dx2 = torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)
for d in range(4):
dx2 += (x[:, d] - 0.5) ** 2
f = torch.exp(-200 * dx2)
+    if torch.isnan(f).any():
+        print("NaN detected in func")
+    reset_nan_to_zero(f)
return [f, f * x[:, 0], f * x[:, 0] ** 2]


def sharp_peak(x):
dx2 = torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)
for d in range(4):
dx2 += (x[:, d] - 0.5) ** 2
-    return torch.exp(-200 * dx2)
+    res = torch.exp(-200 * dx2)
+    if torch.isnan(res).any():
+        print("NaN detected in func")
+    reset_nan_to_zero(res)
+    return res


def func(x):
-    return torch.log(x[:, 0]) / torch.sqrt(x[:, 0])
+    res = torch.log(x[:, 0]) / torch.sqrt(x[:, 0])
+    if torch.isnan(res).any():
+        print("NaN detected in func")
+    reset_nan_to_zero(res)
+    return res


ninc = 1000
-n_eval = 50000
+n_eval = 500000
n_batch = 10000
n_therm = 10

print("\nCalculate the integral log(x)/x^0.5 in the bounds [0, 1]")

print("train VEGAS map")
-vegas_map = Vegas([(0, 1)], device=device, ninc=ninc)
+vegas_map = Vegas([(0, 1)], device=device, ninc=ninc, dtype=dtype)
vegas_map.train(20000, func, epoch=10, alpha=0.5)

vegas_integrator = MonteCarlo(
maps=vegas_map,
neval=1000000,
nbatch=n_batch,
device=device,
+    dtype=dtype
)
res = vegas_integrator(func)
print("VEGAS Integral results: ", res)
@@ -54,17 +80,19 @@ def func(x):
nbatch=n_batch,
nburnin=n_therm,
device=device,
+    dtype=dtype
)
res = vegasmcmc_integrator(func, mix_rate=0.5)
print("VEGAS-MCMC Integral results: ", res)
print(type(res))

# Start Monte Carlo integration, including plain-MC, MCMC, vegas, and vegas-MCMC
print("\nCalculate the integral [h(X), x1 * h(X), x1^2 * h(X)] in the bounds [0, 1]^4")
print(
"\nCalculate the integral [h(X), x1 * h(X), x1^2 * h(X)] in the bounds [0, 1]^4")
print("h(X) = exp(-200 * (x1^2 + x2^2 + x3^2 + x4^2))")

bounds = [(0, 1)] * 4
-vegas_map = Vegas(bounds, device=device, ninc=ninc)
+vegas_map = Vegas(bounds, device=device, ninc=ninc, dtype=dtype)
print("train VEGAS map for h(X)...")
vegas_map.train(20000, sharp_peak, epoch=10, alpha=0.5)
# print(vegas_map.extract_grid())
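A hedged reference value for this peak test: each coordinate contributes a factor $\int_0^1 e^{-200(x-0.5)^2}\,dx = \sqrt{\pi/200}\,\operatorname{erf}(\sqrt{50}) \approx 0.12533$ (the erf tail correction is negligible at this width), so

$$\int_{[0,1]^4} h(X)\,dX \;\approx\; \left(\frac{\pi}{200}\right)^{2} \;\approx\; 2.467\times 10^{-4},$$

and by symmetry of $h$ about $x_1 = 0.5$ the second component satisfies $\int x_1 h = \tfrac{1}{2}\int h$.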
@@ -75,6 +103,7 @@ def func(x):
neval=n_eval,
nbatch=n_batch,
device=device,
+    dtype=dtype
)
res = vegas_integrator(integrand_list1)
print(
@@ -101,6 +130,7 @@ def func(x):
nbatch=n_batch,
nburnin=n_therm,
device=device,
+    dtype=dtype
)
res = vegasmcmc_integrator(integrand_list1, mix_rate=0.5)
print(
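One portability note on this diff: the test now hard-codes `device = torch.device("mps")`, so it will fail on machines without Apple-silicon support. A guarded selection, using the standard availability check, keeps the script runnable everywhere (a sketch; the rest of the script is unchanged):

```python
import torch

# Fall back to CPU when the Metal backend is absent or not built in.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
```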