diff --git a/MCintegration/maps.py b/MCintegration/maps.py index 333ce44..8cf3499 100644 --- a/MCintegration/maps.py +++ b/MCintegration/maps.py @@ -99,6 +99,22 @@ def __init__(self, dim, ninc=1000, device=None, dtype=torch.float32): f"'ninc' must be a scalar or a 1D array of length {self.dim}." ) + # Preallocate tensors to minimize memory allocations + self.max_ninc = self.ninc.max().item() + # Preallocate temporary tensors for adapt + self.sum_f = torch.zeros( + self.dim, self.max_ninc, dtype=self.dtype, device=self.device + ) + self.n_f = torch.zeros( + self.dim, self.max_ninc, dtype=self.dtype, device=self.device + ) + self.avg_f = torch.ones( + (self.dim, self.max_ninc), dtype=self.dtype, device=self.device + ) + self.tmp_f = torch.zeros( + (self.dim, self.max_ninc), dtype=self.dtype, device=self.device + ) + self.make_uniform() def adaptive_training( @@ -109,6 +125,16 @@ def adaptive_training( epoch=10, alpha=0.5, ): + """ + Perform adaptive training to adjust the grid based on the training function. + + Args: + batch_size (int): Number of samples per batch. + f (callable): Training function that takes x and fx as inputs. + f_dim (int, optional): Dimension of the function f. Defaults to 1. + epoch (int, optional): Number of training epochs. Defaults to 10. + alpha (float, optional): Adaptation rate. Defaults to 0.5. + """ q0 = Uniform(self.dim, device=self.device, dtype=self.dtype) sample = Configuration( batch_size, self.dim, f_dim, device=self.device, dtype=self.dtype @@ -122,6 +148,7 @@ def adaptive_training( self.add_training_data(sample) self.adapt(alpha) + @torch.no_grad() def add_training_data(self, sample): """Add training data ``f`` for ``u``-space points ``u``. @@ -139,121 +166,133 @@ def add_training_data(self, sample): point ``u[j, d]`` in ``u``-space. """ fval = (sample.detJ * sample.weight) ** 2 - if self.sum_f is None: - self.sum_f = torch.zeros_like(self.inc) - self.n_f = torch.zeros_like(self.inc) + TINY iu = torch.floor(sample.u * self.ninc).long() for d in range(self.dim): indices = iu[:, d] self.sum_f[d].scatter_add_(0, indices, fval.abs()) self.n_f[d].scatter_add_(0, indices, torch.ones_like(fval)) - def adapt(self, alpha=0.0): - """Adapt grid to accumulated training data. - - ``self.adapt(...)`` projects the training data onto - each axis independently and maps it into ``x`` space. - It shrinks ``x``-grid increments in regions where the - projected training data is large, and grows increments - where the projected data is small. The grid along - any direction is unchanged if the training data - is constant along that direction. - - The number of increments along a direction can be - changed by setting parameter ``ninc`` (array or number). + @torch.no_grad() + def adapt(self, alpha=0.5): + """ + Adapt the grid based on accumulated training data. - The grid does not change if no training data has - been accumulated, unless ``ninc`` is specified, in - which case the number of increments is adjusted - while preserving the relative density of increments - at different values of ``x``. + Shrinks grid increments in regions where the accumulated f is large, + and grows them where f is small. The adaptation speed is controlled by alpha. Args: - alpha (float): Determines the speed with which the grid - adapts to training data. Large (postive) values imply + alpha (float, optional): Determines the speed with which the grid + adapts to training data. Large (positive) values imply rapid evolution; small values (much less than one) imply slow evolution. 
Typical values are of order one. Choosing ``alpha<0`` causes adaptation to the unmodified training data (usually not a good idea). """ + # Aggregate training data across distributed processes if applicable if torch.distributed.is_initialized(): torch.distributed.all_reduce(self.sum_f, op=torch.distributed.ReduceOp.SUM) torch.distributed.all_reduce(self.n_f, op=torch.distributed.ReduceOp.SUM) + + # Initialize a new grid tensor new_grid = torch.empty( - (self.dim, torch.max(self.ninc) + 1), - dtype=self.dtype, - device=self.device, + (self.dim, self.max_ninc + 1), dtype=self.dtype, device=self.device ) - avg_f = torch.ones(self.inc.shape[1], dtype=self.dtype, device=self.device) + if alpha > 0: - tmp_f = torch.empty(self.inc.shape[1], dtype=self.dtype, device=self.device) + tmp_f = torch.empty(self.max_ninc, dtype=self.dtype, device=self.device) + + # avg_f = torch.ones(self.inc.shape[1], dtype=self.dtype, device=self.device) for d in range(self.dim): - ninc = self.ninc[d] + ninc = self.ninc[d].item() + if alpha != 0: - if self.sum_f is not None: - mask = self.n_f[d, :] > 0 - avg_f[mask] = self.sum_f[d, mask] / self.n_f[d, mask] - avg_f[~mask] = 0.0 - if alpha > 0: # smooth - tmp_f[0] = torch.abs(7.0 * avg_f[0] + avg_f[1]) / 8.0 + # Compute average f for current dimension where n_f > 0 + mask = self.n_f[d, :ninc] > 0 # Shape: (ninc,) + avg_f = torch.where( + mask, + self.sum_f[d, :ninc] / self.n_f[d, :ninc], + torch.zeros_like(self.sum_f[d, :ninc]), + ) # Shape: (ninc,) + + if alpha > 0: + # Smooth avg_f + tmp_f[0] = (7.0 * avg_f[0] + avg_f[1]).abs() / 8.0 # Shape: () tmp_f[ninc - 1] = ( - torch.abs(7.0 * avg_f[ninc - 1] + avg_f[ninc - 2]) / 8.0 - ) + 7.0 * avg_f[ninc - 1] + avg_f[ninc - 2] + ).abs() / 8.0 # Shape: () tmp_f[1 : ninc - 1] = ( - torch.abs( - 6.0 * avg_f[1 : ninc - 1] - + avg_f[: ninc - 2] - + avg_f[2:ninc] - ) - / 8.0 + 6.0 * avg_f[1 : ninc - 1] + avg_f[: ninc - 2] + avg_f[2:ninc] + ).abs() / 8.0 + + # Normalize tmp_f to ensure the sum is 1 + sum_f = torch.sum(tmp_f[:ninc]).clamp_min_(TINY) # Scalar + avg_f = tmp_f[:ninc] / sum_f + TINY # Shape: (ninc,) + + # Apply non-linear transformation controlled by alpha + avg_f = (-(1 - avg_f) / torch.log(avg_f)).pow_( + alpha + ) # Shape: (ninc,) + + # Compute the target accumulated f per increment + f_ninc = avg_f.sum() / ninc # Scalar + + new_grid[d, 0] = self.grid[d, 0] + new_grid[d, ninc] = self.grid[d, ninc] + + target_cumulative_weights = ( + torch.arange(1, ninc, device=self.device) * f_ninc + ) # Calculate the target cumulative weights for each new grid point + + cumulative_avg_f = torch.cat( + ( + torch.tensor([0.0], device=self.device), + torch.cumsum(avg_f, dim=0), ) - sum_f = torch.sum(tmp_f[:ninc]) - if sum_f > 0: - avg_f[:ninc] = tmp_f[:ninc] / sum_f + TINY - else: - avg_f[:ninc] = TINY - avg_f[:ninc] = ( - -(1 - avg_f[:ninc]) / torch.log(avg_f[:ninc]) - ) ** alpha - - new_grid[d, 0] = self.grid[d, 0] - new_grid[d, ninc] = self.grid[d, ninc] - f_ninc = torch.sum(avg_f[:ninc]) / ninc - - j = -1 - acc_f = 0 - for i in range(1, ninc): - while acc_f < f_ninc: - j += 1 - if j < ninc: - acc_f += avg_f[j] - else: - break - else: - acc_f -= f_ninc - new_grid[d, i] = ( - self.grid[d, j + 1] - (acc_f / avg_f[j]) * self.inc[d, j] + ) # Calculate the cumulative sum of avg_f + interval_indices = ( + torch.searchsorted( + cumulative_avg_f, target_cumulative_weights, right=True ) - continue - break + - 1 + ) # Find the intervals in the original grid where the target weights fall + # Extract the necessary values using the interval 
indices + grid_left = self.grid[d, interval_indices] + inc_relevant = self.inc[d, interval_indices] + avg_f_relevant = avg_f[interval_indices] + cumulative_avg_f_relevant = cumulative_avg_f[interval_indices] + + # Calculate the fractional position within each interval + fractional_positions = ( + target_cumulative_weights - cumulative_avg_f_relevant + ) / avg_f_relevant + + # Calculate the new grid points using vectorized operations + new_grid[d, 1:ninc] = grid_left + fractional_positions * inc_relevant + else: + # If alpha == 0 or no training data, retain the existing grid + new_grid[d, :] = self.grid[d, :] + + # Assign the newly computed grid self.grid = new_grid - self.inc = torch.empty( - (self.dim, self.grid.shape[1] - 1), - dtype=self.dtype, - device=self.device, - ) + + # Update increments based on the new grid + # Compute the difference between consecutive grid points + self.inc.zero_() # Reset increments to zero for d in range(self.dim): self.inc[d, : self.ninc[d]] = ( self.grid[d, 1 : self.ninc[d] + 1] - self.grid[d, : self.ninc[d]] ) + + # Clear accumulated training data for the next adaptation cycle self.clear() + @torch.no_grad() def make_uniform(self): self.inc = torch.empty( - self.dim, self.ninc.max(), dtype=self.dtype, device=self.device + self.dim, self.max_ninc, dtype=self.dtype, device=self.device ) self.grid = torch.empty( - self.dim, self.ninc.max() + 1, dtype=self.dtype, device=self.device + self.dim, self.max_ninc + 1, dtype=self.dtype, device=self.device ) for d in range(self.dim): @@ -271,81 +310,132 @@ def make_uniform(self): def extract_grid(self): "Return a list of lists specifying the map's grid." - grid = [] + grid_list = [] for d in range(self.dim): ng = self.ninc[d] + 1 - grid.append(self.grid[d, :ng].tolist()) - return grid + grid_list.append(self.grid[d, :ng].tolist()) + return grid_list + @torch.no_grad() def clear(self): "Clear information accumulated by :meth:`AdaptiveMap.add_training_data`." 
-        self.sum_f = None
-        self.n_f = None
+        self.sum_f.zero_()
+        self.n_f.zero_()
 
     @torch.no_grad()
     def forward(self, u):
-        # u = u.to(self.device)
         u_ninc = u * self.ninc
         iu = torch.floor(u_ninc).long()
-        du_ninc = u_ninc - torch.floor(u_ninc).long()
+        du_ninc = u_ninc - iu
+
+        batch_size = u.size(0)
+        # Clamp iu to [0, ninc-1] to handle out-of-bounds indices
+        min_tensor = torch.zeros((1, self.dim), dtype=iu.dtype, device=self.device)
+        max_tensor = (self.ninc - 1).unsqueeze(0).to(iu.dtype)  # Shape: (1, dim)
+        iu_clamped = torch.clamp(iu, min=min_tensor, max=max_tensor)
+
+        grid_expanded = self.grid.unsqueeze(0).expand(batch_size, -1, -1)
+        inc_expanded = self.inc.unsqueeze(0).expand(batch_size, -1, -1)
+
+        grid_gather = torch.gather(grid_expanded, 2, iu_clamped.unsqueeze(2)).squeeze(
+            2
+        )  # Shape: (batch_size, dim)
+        inc_gather = torch.gather(inc_expanded, 2, iu_clamped.unsqueeze(2)).squeeze(2)
+
+        x = grid_gather + inc_gather * du_ninc
+        log_detJ = (inc_gather * self.ninc).log_().sum(dim=1)
+
+        # Handle out-of-bounds samples (iu >= ninc, i.e. u == 1) by snapping x to the
+        # upper grid boundary. No extra Jacobian term is added: iu_clamped already
+        # selects inc[d, ninc[d]-1], so log_detJ contains log(inc[d, ninc[d]-1] * ninc[d]).
+        out_of_bounds = iu >= self.ninc
+        if out_of_bounds.any():
+            boundary_grid = (
+                self.grid[torch.arange(self.dim, device=self.device), self.ninc]
+                .unsqueeze(0)
+                .expand(batch_size, -1)
+            )
+            # x = torch.where(out_of_bounds, boundary_grid, x)
+            x[out_of_bounds] = boundary_grid[out_of_bounds]
 
-        x = torch.empty_like(u)
-        detJ = torch.ones(u.shape[0], device=x.device)
-        # self.detJ.fill_(1.0)
-        for d in range(self.dim):
-            # Handle the case where iu < ninc
-            ninc = self.ninc[d]
-            mask = iu[:, d] < ninc
-            if mask.any():
-                x[mask, d] = (
-                    self.grid[d, iu[mask, d]]
-                    + self.inc[d, iu[mask, d]] * du_ninc[mask, d]
-                )
-                detJ[mask] *= self.inc[d, iu[mask, d]] * ninc
-
-            # Handle the case where iu >= ninc
-            mask_inv = ~mask
-            if mask_inv.any():
-                x[mask_inv, d] = self.grid[d, ninc]
-                detJ[mask_inv] *= self.inc[d, ninc - 1] * ninc
-
-        return x, detJ.log_()
+
+        return x, log_detJ
 
     @torch.no_grad()
     def inverse(self, x):
-        # self.detJ.fill_(1.0)
-        x = x.to(self.device)
+        """
+        Inverse map from x-space to u-space.
+
+        Args:
+            x (torch.Tensor): Tensor of shape (batch_size, dim) representing points in x-space.
+
+        Returns:
+            u (torch.Tensor): Tensor of shape (batch_size, dim) representing points in u-space.
+            log_detJ (torch.Tensor): Tensor of shape (batch_size,) representing the log determinant of the Jacobian.
+        """
+        x = x.to(self.device)
+        batch_size, dim = x.shape
+
+        # Initialize output tensors
         u = torch.empty_like(x)
-        detJ = torch.ones(x.shape[0], device=x.device)
-        for d in range(self.dim):
-            ninc = self.ninc[d]
-            iu = torch.searchsorted(self.grid[d, :], x[:, d].contiguous(), right=True)
-
-            mask_valid = (iu > 0) & (iu <= ninc)
-            mask_lower = iu <= 0
-            mask_upper = iu > ninc
-
-            # Handle valid range (0 < iu <= ninc)
-            if mask_valid.any():
-                iui_valid = iu[mask_valid] - 1
-                u[mask_valid, d] = (
-                    iui_valid
-                    + (x[mask_valid, d] - self.grid[d, iui_valid])
-                    / self.inc[d, iui_valid]
-                ) / ninc
-                detJ[mask_valid] *= self.inc[d, iui_valid] * ninc
-
-            # Handle lower bound (iu <= 0)
-            if mask_lower.any():
-                u[mask_lower, d] = 0.0
-                detJ[mask_lower] *= self.inc[d, 0] * ninc
-
-            # Handle upper bound (iu > ninc)
-            if mask_upper.any():
-                u[mask_upper, d] = 1.0
-                detJ[mask_upper] *= self.inc[d, ninc - 1] * ninc
-
-        return u, detJ.log_()
+        log_detJ = torch.zeros(batch_size, device=self.device, dtype=self.dtype)
+
+        # Loop over each dimension to perform inverse mapping
+        for d in range(dim):
+            # Extract the grid and increment for dimension d
+            grid_d = self.grid[d]  # Shape: (max_ninc + 1,)
+            inc_d = self.inc[d]  # Shape: (max_ninc,)
+
+            # ninc_d = self.ninc[d].float()  # Scalar tensor
+            ninc_d = self.ninc[d]  # Scalar tensor
+
+            # Perform searchsorted to find indices where x should be inserted to maintain order
+            # torch.searchsorted returns indices in [0, max_ninc + 1]
+            iu = (
+                torch.searchsorted(grid_d, x[:, d].contiguous(), right=True) - 1
+            )  # Shape: (batch_size,)
+
+            # Clamp indices to [0, ninc_d - 1] to ensure they are within valid range
+            iu_clamped = torch.clamp(iu, min=0, max=ninc_d - 1)  # Shape: (batch_size,)
+
+            # Gather grid and increment values based on iu_clamped
+            # grid_gather and inc_gather have shape (batch_size,)
+            grid_gather = grid_d[iu_clamped]  # Shape: (batch_size,)
+            inc_gather = inc_d[iu_clamped]  # Shape: (batch_size,)
+
+            # Compute du: fractional part within the increment
+            du = (x[:, d] - grid_gather) / (inc_gather + TINY)  # Shape: (batch_size,)
+
+            # Compute u for dimension d
+            u[:, d] = (du + iu_clamped) / ninc_d  # Shape: (batch_size,)
+
+            # Compute log determinant contribution for dimension d
+            log_detJ += (inc_gather * ninc_d + TINY).log_()  # Shape: (batch_size,)
+
+            # Handle out-of-bounds cases by clamping u to [0, 1].
+            # No extra Jacobian terms are added: iu_clamped already selects the
+            # boundary increments, so their factors are included in log_detJ above.
+            # Lower bound: x <= grid[d, 0]
+            lower_mask = x[:, d] <= grid_d[0]  # Shape: (batch_size,)
+            if lower_mask.any():
+                u[:, d].masked_fill_(lower_mask, 0.0)
+
+            # Upper bound: x >= grid[d, ninc_d]
+            upper_mask = x[:, d] >= grid_d[ninc_d]  # Shape: (batch_size,)
+            if upper_mask.any():
+                u[:, d].masked_fill_(upper_mask, 1.0)
+
+        return u, log_detJ
 
 
 # class NormalizingFlow(Map):
diff --git a/MCintegration/maps_test.py b/MCintegration/maps_test.py
index 909c48a..d54cd36 100644
--- a/MCintegration/maps_test.py
+++ b/MCintegration/maps_test.py
@@ -81,6 +81,11 @@ def setUp(self):
         self.sample = Configuration(
             batch_size=3, dim=2, f_dim=1, device=self.device, dtype=self.dtype
         )
+        self.sample.u.uniform_(0, 1)
+        self.sample.x[:] = self.sample.u
+        self.sample.fx.uniform_(0, 1)
+        self.sample.weight.fill_(1.0)
+        self.sample.detJ.fill_(1.0)
 
     def tearDown(self):
         # Teardown after each test
@@ -118,8 +123,12 @@ def test_clear(self):
         # Test clearing accumulated data
         self.vegas.add_training_data(self.sample)
         self.vegas.clear()
-        self.assertIsNone(self.vegas.sum_f)
-        self.assertIsNone(self.vegas.n_f)
+        # self.assertIsNone(self.vegas.sum_f)
+        # self.assertIsNone(self.vegas.n_f)
+        self.assertTrue(torch.all(self.vegas.sum_f == 0).item())
+        self.assertTrue(torch.all(self.vegas.n_f == 0).item())
+        # self.assertEqual(self.vegas.sum_f, torch.zeros_like(self.vegas.sum_f))
+        # self.assertEqual(self.vegas.n_f, torch.zeros_like(self.vegas.n_f))
 
     def test_forward(self):
         # Test forward transformation
@@ -153,8 +162,8 @@ def test_make_uniform(self):
         self.assertEqual(self.vegas.grid.shape, (2, self.ninc + 1))
         self.assertEqual(self.vegas.inc.shape, (2, self.ninc))
         self.assertTrue(torch.equal(self.vegas.grid, self.init_grid))
-        self.assertIsNone(self.vegas.sum_f)
-        self.assertIsNone(self.vegas.n_f)
+        self.assertTrue(torch.all(self.vegas.sum_f == 0).item())
+        self.assertTrue(torch.all(self.vegas.n_f == 0).item())
 
     def test_edge_cases(self):
         # Test edge cases
diff --git a/examples/vegas_profile.py b/examples/vegas_profile.py
new file mode 100644
index 0000000..e84f677
--- /dev/null
+++ b/examples/vegas_profile.py
@@ -0,0 +1,150 @@
+# Profiling example for the VEGAS map with MonteCarlo/MarkovChainMonteCarlo integrators.
+import torch
+import logging
+from MCintegration import MonteCarlo, MarkovChainMonteCarlo
+from MCintegration import Vegas, set_seed, get_device
+# from torch.autograd import profiler, ProfilerActivity
+# import torch.utils.benchmark as benchmark
+
+set_seed(42)
+device = get_device()
+# device = torch.device("mps")
+dtype = torch.float32
+
+logging.basicConfig(
+    format="%(levelname)s:%(asctime)s %(message)s",
+    level=logging.INFO,
+    datefmt="%Y-%m-%d %H:%M:%S",
+)
+logger: logging.Logger = logging.getLogger(__name__)
+logger.setLevel(level=logging.INFO)
+output_dir = "./snapshot_output"
+
+
+def sharp_peak(x, f):
+    f[:, 0] = torch.sum((x - 0.5) ** 2, dim=-1)
+    f[:, 0] *= -200
+    f[:, 0].exp_()
+    return f[:, 0]
+
+
+def sharp_integrands(x, f):
+    f[:, 0] = torch.sum((x - 0.5) ** 2, dim=-1)
+    f[:, 0] *= -200
+    f[:, 0].exp_()
+    f[:, 1] = f[:, 0] * x[:, 0]
+    f[:, 2] = f[:, 0] * x[:, 0] ** 2
+    return f.mean(dim=-1)
+
+
+def func(x, f):
+    f[:, 0] = torch.log(x[:, 0]) / torch.sqrt(x[:, 0])
+    return f[:, 0]
+
+
+alpha = 2.0
+ninc = 1000
+# n_eval = 1000000
+n_eval = 500000
+batch_size = 10000
+n_therm = 10
+
+print("\nCalculate the integral log(x)/x^0.5 in the bounds [0, 1]")
+dim = 1
+bounds = [[0, 1]] * dim
+print("Training VEGAS map...")
+vegas_map = Vegas(dim, device=device, ninc=ninc, dtype=dtype)
+
+torch.cuda.memory._record_memory_history(max_entries=100000)
+vegas_map.adaptive_training(100000, func, epoch=10, alpha=alpha)
+try:
+    torch.cuda.memory._dump_snapshot(f"{output_dir}/vegas_training.pickle")
+except Exception as e:
+    logger.error(f"Failed to capture memory snapshot: {e}")
+torch.cuda.memory._record_memory_history(enabled=None)
+
+vegas_integrator = MonteCarlo(
+    bounds,
+    func,
+    maps=vegas_map,
+    batch_size=batch_size,
+)
+res = vegas_integrator(n_eval)
+print("VEGAS Integral results: ", res)
+
+vegasmcmc_integrator = MarkovChainMonteCarlo(
+    bounds,
+    func,
+    maps=vegas_map,
+    batch_size=batch_size,
+    nburnin=n_therm,
+)
+res = vegasmcmc_integrator(n_eval, mix_rate=0.5)
+print("VEGAS-MarkovChainMonteCarlo Integral results: ", res)
+print(type(res))
+print(res.sum_neval)
+print(res.itn_results)
+print(res.nitn)
+
+# VEGAS and VEGAS-MarkovChainMonteCarlo integration of a vector-valued integrand.
+print("\nCalculate the integral [h(X), x1 * h(X), x1^2 * h(X)] in the bounds [0, 1]^4")
+print("h(X) = exp(-200 * ((x1-0.5)^2 + (x2-0.5)^2 + (x3-0.5)^2 + (x4-0.5)^2))")
+
+dim = 4
+bounds = [(0, 1)] * dim +vegas_map = Vegas(dim, device=device, ninc=ninc, dtype=dtype) +print("train VEGAS map for h(X)...") +vegas_map.adaptive_training(20000, sharp_peak, epoch=10, alpha=alpha) +# print(vegas_map.extract_grid()) + +print("VEGAS Integral results:") +vegas_integrator = MonteCarlo( + bounds, + sharp_integrands, + f_dim=3, + maps=vegas_map, + batch_size=batch_size, +) +res = vegas_integrator(neval=500000) +print( + " I[0] =", + res[0], + " I[1] =", + res[1], + " I[2] =", + res[2], + " I[1]/I[0] =", + res[1] / res[0], +) +print(type(res)) +print(type(res[0])) +print(res[0].sum_neval) +print(res[0].itn_results) +print(res[0].nitn) + + +print("VEGAS-MarkovChainMonteCarlo Integral results:") +vegasmcmc_integrator = MarkovChainMonteCarlo( + bounds, + sharp_integrands, + f_dim=3, + maps=vegas_map, + batch_size=batch_size, + nburnin=n_therm, +) +res = vegasmcmc_integrator(neval=500000, mix_rate=0.5) +print( + " I[0] =", + res[0], + " I[1] =", + res[1], + " I[2] =", + res[2], + " I[1]/I[0] =", + res[1] / res[0], +) +print(type(res)) +print(type(res[0])) +print(res[0].sum_neval) +print(res[0].itn_results) +print(res[0].nitn)
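+
+# Optional round-trip sanity check (illustrative addition, hypothetical names
+# u_test/x_test/u_back): verify that the trained map's forward() and inverse()
+# transforms are mutually consistent, i.e. inverse(forward(u)) ~ u and the two
+# log-Jacobians cancel up to numerical error.
+u_test = torch.rand(batch_size, dim, device=device, dtype=dtype)
+x_test, log_detj_fwd = vegas_map.forward(u_test)
+u_back, log_detj_inv = vegas_map.inverse(x_test)
+print("max |u - inverse(forward(u))| =", (u_test - u_back).abs().max().item())
+print("max |log detJ_fwd + log detJ_inv| =", (log_detj_fwd + log_detj_inv).abs().max().item())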