1+ # base.py
2+ # This file contains the base distribution classes for Monte Carlo integration methods
3+ # It defines foundational classes for sampling distributions and transformations
4+
15import torch
26from torch import nn
37import numpy as np
48import sys
59from MCintegration .utils import get_device
610
11+ # Constants for numerical stability
12+ # Small but safe non-zero value
713MINVAL = 10 ** (sys .float_info .min_10_exp + 50 )
8- MAXVAL = 10 ** (sys .float_info .max_10_exp - 50 )
14+ MAXVAL = 10 ** (sys .float_info .max_10_exp - 50 ) # Large but safe value
915EPSILON = 1e-16 # Small value to ensure numerical stability
1016
1117
1218class BaseDistribution (nn .Module ):
1319 """
14- Base distribution of a flow-based model
15- Parameters do not depend of target variable (as is the case for a VAE encoder)
20+ Base distribution class for flow-based models.
21+ This is an abstract base class that provides structure for probability distributions
22+ used in Monte Carlo integration. Parameters do not depend on target variables
23+ (unlike a VAE encoder).
1624 """
1725
1826 def __init__ (self , dim , device = "cpu" , dtype = torch .float32 ):
27+ """
28+ Initialize BaseDistribution.
29+
30+ Args:
31+ dim (int): Dimensionality of the distribution
32+ device (str or torch.device): Device to use for computation
33+ dtype (torch.dtype): Data type for computations
34+ """
1935 super ().__init__ ()
2036 self .dtype = dtype
2137 self .dim = dim
2238 self .device = device
2339
2440 def sample (self , batch_size = 1 , ** kwargs ):
25- """Samples from base distribution
41+ """
42+ Sample from the base distribution.
2643
2744 Args:
28- num_samples: Number of samples to draw from the distriubtion
45+ batch_size (int): Number of samples to draw
46+ **kwargs: Additional arguments
2947
3048 Returns:
31- Samples drawn from the distribution
49+ tuple: (samples, log_det_jacobian)
50+
51+ Raises:
52+ NotImplementedError: This is an abstract method
3253 """
3354 raise NotImplementedError
3455
3556 def sample_with_detJ (self , batch_size = 1 , ** kwargs ):
57+ """
58+ Sample from base distribution with Jacobian determinant (not log).
59+
60+ Args:
61+ batch_size (int): Number of samples to draw
62+ **kwargs: Additional arguments
63+
64+ Returns:
65+ tuple: (samples, det_jacobian)
66+ """
3667 u , detJ = self .sample (batch_size , ** kwargs )
37- detJ .exp_ ()
68+ detJ .exp_ () # Convert log_det to det
3869 return u , detJ
3970
4071
4172class Uniform (BaseDistribution ):
4273 """
43- Multivariate uniform distribution
74+ Multivariate uniform distribution over [0,1]^dim.
75+ Samples from a uniform distribution in the hypercube [0,1]^dim.
4476 """
4577
4678 def __init__ (self , dim , device = "cpu" , dtype = torch .float32 ):
79+ """
80+ Initialize Uniform distribution.
81+
82+ Args:
83+ dim (int): Dimensionality of the distribution
84+ device (str or torch.device): Device to use for computation
85+ dtype (torch.dtype): Data type for computations
86+ """
4787 super ().__init__ (dim , device , dtype )
4888
4989 def sample (self , batch_size = 1 , ** kwargs ):
90+ """
91+ Sample from uniform distribution over [0,1]^dim.
92+
93+ Args:
94+ batch_size (int): Number of samples to draw
95+ **kwargs: Additional arguments
96+
97+ Returns:
98+ tuple: (uniform samples, log_det_jacobian=0)
99+ """
50100 # torch.manual_seed(0) # test seed
51- u = torch .rand ((batch_size , self .dim ), device = self .device , dtype = self .dtype )
52- log_detJ = torch .zeros (batch_size , device = self .device , dtype = self .dtype )
101+ u = torch .rand ((batch_size , self .dim ),
102+ device = self .device , dtype = self .dtype )
103+ log_detJ = torch .zeros (
104+ batch_size , device = self .device , dtype = self .dtype )
53105 return u , log_detJ
54106
55107
56108class LinearMap (nn .Module ):
109+ """
110+ Linear transformation map of the form x = u * A + b.
111+ Maps points from one space to another using a linear transformation.
112+ """
113+
57114 def __init__ (self , A , b , device = None , dtype = torch .float32 ):
115+ """
116+ Initialize LinearMap with scaling A and offset b.
117+
118+ Args:
119+ A (list, numpy.ndarray, torch.Tensor): Scaling factors
120+ b (list, numpy.ndarray, torch.Tensor): Offset values
121+ device (str or torch.device): Device to use for computation
122+ dtype (torch.dtype): Data type for computations
123+ """
58124 if device is None :
59125 self .device = get_device ()
60126 else :
@@ -67,24 +133,54 @@ def __init__(self, A, b, device=None, dtype=torch.float32):
67133 elif isinstance (A , torch .Tensor ):
68134 self .A = A .to (dtype = self .dtype , device = self .device )
69135 else :
70- raise ValueError ("'A' must be a list, numpy array, or torch tensor." )
136+ raise ValueError (
137+ "'A' must be a list, numpy array, or torch tensor." )
71138
72139 if isinstance (b , (list , np .ndarray )):
73140 self .b = torch .tensor (b , dtype = self .dtype , device = self .device )
74141 elif isinstance (b , torch .Tensor ):
75142 self .b = b .to (dtype = self .dtype , device = self .device )
76143 else :
77- raise ValueError ("'b' must be a list, numpy array, or torch tensor." )
144+ raise ValueError (
145+ "'b' must be a list, numpy array, or torch tensor." )
78146
147+ # Pre-compute determinant of Jacobian for efficiency
79148 self ._detJ = torch .prod (self .A )
80149
81150 def forward (self , u ):
151+ """
152+ Apply forward transformation: x = u * A + b.
153+
154+ Args:
155+ u (torch.Tensor): Input points
156+
157+ Returns:
158+ tuple: (transformed points, log_det_jacobian)
159+ """
82160 return u * self .A + self .b , torch .log (self ._detJ .repeat (u .shape [0 ]))
83161
84162 def forward_with_detJ (self , u ):
163+ """
164+ Apply forward transformation with Jacobian determinant (not log).
165+
166+ Args:
167+ u (torch.Tensor): Input points
168+
169+ Returns:
170+ tuple: (transformed points, det_jacobian)
171+ """
85172 u , detJ = self .forward (u )
86- detJ .exp_ ()
173+ detJ .exp_ () # Convert log_det to det
87174 return u , detJ
88175
89176 def inverse (self , x ):
177+ """
178+ Apply inverse transformation: u = (x - b) / A.
179+
180+ Args:
181+ x (torch.Tensor): Input points
182+
183+ Returns:
184+ tuple: (transformed points, log_det_jacobian)
185+ """
90186 return (x - self .b ) / self .A , torch .log (self ._detJ .repeat (x .shape [0 ]))
0 commit comments