diff --git a/src/integrators.py b/src/integrators.py
index ecec374..f747cb2 100644
--- a/src/integrators.py
+++ b/src/integrators.py
@@ -27,7 +27,8 @@ def __init__(
         self.dtype = dtype
         if maps:
             if not self.dtype == maps.dtype:
-                raise ValueError("Float type of maps should be same as integrator.")
+                raise ValueError(
+                    "Float type of maps should be same as integrator.")
             self.bounds = maps.bounds
         else:
             if not isinstance(bounds, (list, np.ndarray)):
@@ -72,6 +73,7 @@ def __init__(
         device="cpu",
         dtype=torch.float64,
     ):
+
         super().__init__(maps, bounds, q0, neval, nbatch, device, dtype)
 
     def __call__(self, f: Callable, **kwargs):
@@ -91,12 +93,14 @@ def __call__(self, f: Callable, **kwargs):
         )
 
         epoch = self.neval // self.nbatch
-        values = torch.zeros((self.nbatch, f_size), dtype=type_fval, device=self.device)
+        values = torch.zeros((self.nbatch, f_size),
+                             dtype=type_fval, device=self.device)
 
         for iepoch in range(epoch):
            x, log_detJ = self.sample(self.nbatch)
            f_values = f(x)
-            batch_results = self._multiply_by_jacobian(f_values, torch.exp(log_detJ))
+            batch_results = self._multiply_by_jacobian(
+                f_values, torch.exp(log_detJ))
 
            values += batch_results / epoch
 
@@ -123,7 +127,8 @@ def random_walk(dim, bounds, device, dtype, u, **kwargs):
     rangebounds = bounds[:, 1] - bounds[:, 0]
     step_size = kwargs.get("step_size", 0.2)
     step_sizes = rangebounds * step_size
-    step = torch.empty(dim, device=device, dtype=dtype).uniform_(-1, 1) * step_sizes
+    step = torch.empty(dim, device=device,
+                       dtype=dtype).uniform_(-1, 1) * step_sizes
     new_u = (u + step - bounds[:, 0]) % rangebounds + bounds[:, 0]
     return new_u
 
@@ -191,7 +196,8 @@ def _integrand(x):
 
         )
         type_fval = current_fval.dtype
-        current_weight = mix_rate / current_jac + (1 - mix_rate) * current_fval.abs()
+        current_weight = mix_rate / current_jac + \
+            (1 - mix_rate) * current_fval.abs()
         current_weight.masked_fill_(current_weight < epsilon, epsilon)
 
         n_meas = epoch // thinning
@@ -226,8 +232,10 @@ def one_step(current_y, current_x, current_weight, current_jac):
             current_y, current_x, current_weight, current_jac
         )
 
-        values = torch.zeros((self.nbatch, f_size), dtype=type_fval, device=self.device)
-        refvalues = torch.zeros(self.nbatch, dtype=type_fval, device=self.device)
+        values = torch.zeros((self.nbatch, f_size),
+                             dtype=type_fval, device=self.device)
+        refvalues = torch.zeros(
+            self.nbatch, dtype=type_fval, device=self.device)
 
         for imeas in range(n_meas):
             for j in range(thinning):
diff --git a/src/maps.py b/src/maps.py
index 6f6fc1d..5ca410e 100644
--- a/src/maps.py
+++ b/src/maps.py
@@ -15,7 +15,8 @@ def __init__(self, bounds, device="cpu", dtype=torch.float64):
         elif isinstance(bounds, torch.Tensor):
             self.bounds = bounds.to(dtype=dtype, device=device)
         else:
-            raise ValueError("'bounds' must be a list, numpy array, or torch tensor.")
+            raise ValueError(
+                "'bounds' must be a list, numpy array, or torch tensor.")
 
         self.dim = self.bounds.shape[0]
         self.device = device
@@ -29,9 +30,11 @@ def inverse(self, x):
 
 
 class CompositeMap(Map):
-    def __init__(self, maps, device="cpu", dtype=torch.float64):
+    def __init__(self, maps, device="cpu", dtype=None):
         if not maps:
             raise ValueError("Maps can not be empty.")
+        if dtype is None:
+            dtype = maps[-1].dtype
         super().__init__(maps[-1].bounds, device, dtype)
         self.maps = maps
 
@@ -71,17 +74,20 @@ def __init__(self, bounds, ninc=1000, alpha=0.5, device="cpu", dtype=torch.float
 
         # Ensure ninc is a tensor of appropriate shape and type
         if isinstance(ninc, int):
-            self.ninc = torch.full((self.dim,), ninc, dtype=torch.int32, device=device)
+            self.ninc = torch.full(
+                (self.dim,), ninc, dtype=torch.int32, device=device)
         elif isinstance(ninc, (list, np.ndarray)):
             self.ninc = torch.tensor(ninc, dtype=torch.int32, device=device)
         elif isinstance(ninc, torch.Tensor):
             self.ninc = ninc.to(dtype=torch.int32, device=device)
         else:
-            raise ValueError("'ninc' must be an int, list, numpy array, or torch tensor.")
-
+            raise ValueError(
+                "'ninc' must be an int, list, numpy array, or torch tensor.")
+
         # Ensure ninc has the correct shape
         if self.ninc.shape != (self.dim,):
-            raise ValueError(f"'ninc' must be a scalar or a 1D array of length {self.dim}.")
+            raise ValueError(
+                f"'ninc' must be a scalar or a 1D array of length {self.dim}.")
 
         self.make_uniform()
         self.alpha = alpha
@@ -165,13 +171,14 @@ def adapt(self, alpha=0.0):
         """
         new_grid = torch.empty(
             (self.dim, torch.max(self.ninc) + 1),
-            dtype=torch.float64,
+            dtype=self.dtype,
             device=self.device,
         )
-        avg_f = torch.ones(self.inc.shape[1], dtype=torch.float64, device=self.device)
+        avg_f = torch.ones(
+            self.inc.shape[1], dtype=self.dtype, device=self.device)
         if alpha > 0:
             tmp_f = torch.empty(
-                self.inc.shape[1], dtype=torch.float64, device=self.device
+                self.inc.shape[1], dtype=self.dtype, device=self.device
             )
         for d in range(self.dim):
             ninc = self.ninc[d]
@@ -183,11 +190,12 @@ def adapt(self, alpha=0.0):
             if alpha > 0:  # smooth
                 tmp_f[0] = torch.abs(7.0 * avg_f[0] + avg_f[1]) / 8.0
                 tmp_f[ninc - 1] = (
-                    torch.abs(7.0 * avg_f[ninc - 1] + avg_f[ninc - 2]) / 8.0
+                    torch.abs(7.0 * avg_f[ninc - 1] +
+                              avg_f[ninc - 2]) / 8.0
                 )
-                tmp_f[1 : ninc - 1] = (
+                tmp_f[1: ninc - 1] = (
                     torch.abs(
-                        6.0 * avg_f[1 : ninc - 1]
+                        6.0 * avg_f[1: ninc - 1]
                         + avg_f[: ninc - 2]
                         + avg_f[2:ninc]
                     )
@@ -218,17 +226,19 @@ def adapt(self, alpha=0.0):
                 else:
                     acc_f -= f_ninc
                     new_grid[d, i] = (
-                        self.grid[d, j + 1] - (acc_f / avg_f[j]) * self.inc[d, j]
+                        self.grid[d, j + 1] -
+                        (acc_f / avg_f[j]) * self.inc[d, j]
                     )
                     continue
                 break
         self.grid = new_grid
         self.inc = torch.empty(
-            (self.dim, self.grid.shape[1] - 1), dtype=torch.float64, device=self.device
+            (self.dim, self.grid.shape[1] - 1), dtype=self.dtype, device=self.device
         )
         for d in range(self.dim):
             self.inc[d, : self.ninc[d]] = (
-                self.grid[d, 1 : self.ninc[d] + 1] - self.grid[d, : self.ninc[d]]
+                self.grid[d, 1: self.ninc[d] + 1] -
+                self.grid[d, : self.ninc[d]]
             )
         self.clear()
 
@@ -249,7 +259,8 @@ def make_uniform(self):
                 device=self.device,
             )
             self.inc[d, : self.ninc[d]] = (
-                self.grid[d, 1 : self.ninc[d] + 1] - self.grid[d, : self.ninc[d]]
+                self.grid[d, 1: self.ninc[d] + 1] -
+                self.grid[d, : self.ninc[d]]
             )
 
         self.clear()
@@ -305,7 +316,8 @@ def inverse(self, x):
         jac = torch.ones(x.shape[0], device=x.device)
         for d in range(self.dim):
             ninc = self.ninc[d]
-            iu = torch.searchsorted(self.grid[d, :], x[:, d].contiguous(), right=True)
+            iu = torch.searchsorted(
+                self.grid[d, :], x[:, d].contiguous(), right=True)
 
             mask_valid = (iu > 0) & (iu <= ninc)
             mask_lower = iu <= 0
diff --git a/src/vegas_test.py b/src/vegas_test.py
index addc8fc..fa0ce1e 100644
--- a/src/vegas_test.py
+++ b/src/vegas_test.py
@@ -6,7 +6,21 @@
 
 # set_seed(42)
 # device = get_device()
-device = torch.device("cpu")
+device = torch.device("mps")
+# device = torch.device("cpu")
+dtype = torch.float32
+
+
+def reset_nan_to_zero(tensor):
+    """
+    Resets NaN values in a tensor to zero in-place.
+
+    Args:
+        tensor: The PyTorch tensor to modify.
+ """ + mask = torch.isnan( + tensor) # Create a boolean mask where True indicates NaN values + tensor[mask] = 0 # U def integrand_list1(x): @@ -14,6 +28,9 @@ def integrand_list1(x): for d in range(4): dx2 += (x[:, d] - 0.5) ** 2 f = torch.exp(-200 * dx2) + reset_nan_to_zero(f) + if torch.isnan(f).any(): + print("NaN detected in func") return [f, f * x[:, 0], f * x[:, 0] ** 2] @@ -21,22 +38,30 @@ def sharp_peak(x): dx2 = torch.zeros(x.shape[0], dtype=x.dtype, device=x.device) for d in range(4): dx2 += (x[:, d] - 0.5) ** 2 - return torch.exp(-200 * dx2) + res = torch.exp(-200 * dx2) + reset_nan_to_zero(res) + if torch.isnan(res).any(): + print("NaN detected in func") + return res def func(x): - return torch.log(x[:, 0]) / torch.sqrt(x[:, 0]) + res = torch.log(x[:, 0]) / torch.sqrt(x[:, 0]) + reset_nan_to_zero(res) + if torch.isnan(res).any(): + print("NaN detected in func") + return res ninc = 1000 -n_eval = 50000 +n_eval = 500000 n_batch = 10000 n_therm = 10 print("\nCalculate the integral log(x)/x^0.5 in the bounds [0, 1]") print("train VEGAS map") -vegas_map = Vegas([(0, 1)], device=device, ninc=ninc) +vegas_map = Vegas([(0, 1)], device=device, ninc=ninc, dtype=dtype) vegas_map.train(20000, func, epoch=10, alpha=0.5) vegas_integrator = MonteCarlo( @@ -44,6 +69,7 @@ def func(x): neval=1000000, nbatch=n_batch, device=device, + dtype=dtype ) res = vegas_integrator(func) print("VEGAS Integral results: ", res) @@ -54,17 +80,19 @@ def func(x): nbatch=n_batch, nburnin=n_therm, device=device, + dtype=dtype ) res = vegasmcmc_integrator(func, mix_rate=0.5) print("VEGAS-MCMC Integral results: ", res) print(type(res)) # Start Monte Carlo integration, including plain-MC, MCMC, vegas, and vegas-MCMC -print("\nCalculate the integral [h(X), x1 * h(X), x1^2 * h(X)] in the bounds [0, 1]^4") +print( + "\nCalculate the integral [h(X), x1 * h(X), x1^2 * h(X)] in the bounds [0, 1]^4") print("h(X) = exp(-200 * (x1^2 + x2^2 + x3^2 + x4^2))") bounds = [(0, 1)] * 4 -vegas_map = Vegas(bounds, device=device, ninc=ninc) +vegas_map = Vegas(bounds, device=device, ninc=ninc, dtype=dtype) print("train VEGAS map for h(X)...") vegas_map.train(20000, sharp_peak, epoch=10, alpha=0.5) # print(vegas_map.extract_grid()) @@ -75,6 +103,7 @@ def func(x): neval=n_eval, nbatch=n_batch, device=device, + dtype=dtype ) res = vegas_integrator(integrand_list1) print( @@ -101,6 +130,7 @@ def func(x): nbatch=n_batch, nburnin=n_therm, device=device, + dtype=dtype ) res = vegasmcmc_integrator(integrand_list1, mix_rate=0.5) print(