 from typing import Callable, Union, List, Tuple, Dict
 import torch
 from utils import RAvg
-from maps import Map, Affine, CompositeMap
+from maps import Map, Linear, CompositeMap
 from base import Uniform
 import gvar
 import numpy as np
 
 
 class Integrator:
     """
-    Base class for all integrators.
+    Base class for all integrators. This class is designed to handle integration tasks
+    over a specified domain (bounds) using a sampling method (q0) and optional
+    transformation maps.
     """
 
     def __init__(
         self,
-        bounds: Union[List[Tuple[float, float]], np.ndarray],
+        bounds,
         q0=None,
         maps=None,
         neval: int = 1000,
         nbatch: int = None,
         device="cpu",
-        adapt=False,
+        dtype=torch.float64,
     ):
-        self.adapt = adapt
+        if not isinstance(bounds, (list, np.ndarray)):
+            raise TypeError("bounds must be a list or a NumPy array.")
+        self.dtype = dtype
         self.dim = len(bounds)
         if not q0:
-            q0 = Uniform(bounds, device=device)
-        self.bounds = torch.tensor(bounds, dtype=torch.float64, device=device)
+            q0 = Uniform(bounds, device=device, dtype=dtype)
+        self.bounds = torch.tensor(bounds, dtype=dtype, device=device)
        self.q0 = q0
+        if maps:
+            if self.dtype != maps.dtype:
+                raise ValueError("Float type of maps should be the same as the integrator's.")
         self.maps = maps
         self.neval = neval
         if nbatch is None:
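
The constructor now validates `bounds` and threads a single `dtype` through the default sampler, the bounds tensor, and any attached map. A minimal usage sketch of the new interface (not part of the diff; `MonteCarlo` is the subclass defined below):

# Example sketch, not part of the diff: exercising the new bounds/dtype checks.
import torch

bounds = [(0.0, 1.0), (0.0, 1.0)]  # one (lower, upper) pair per dimension
mc = MonteCarlo(bounds, neval=10_000, dtype=torch.float32)

# A tuple of pairs now fails fast under the new isinstance check:
# MonteCarlo(((0.0, 1.0), (0.0, 1.0)))  ->  TypeError: bounds must be a list or a NumPy array.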
@@ -54,46 +61,41 @@ def sample(self, nsample, **kwargs):
 class MonteCarlo(Integrator):
     def __init__(
         self,
-        bounds: Union[List[Tuple[float, float]], np.ndarray],
+        bounds,
         q0=None,
         maps=None,
-        nitn: int = 10,
         neval: int = 1000,
         nbatch: int = None,
         device="cpu",
-        adapt=False,
+        dtype=torch.float64,
     ):
-        super().__init__(bounds, q0, maps, neval, nbatch, device, adapt)
-        self.nitn = nitn
+        super().__init__(bounds, q0, maps, neval, nbatch, device, dtype)
 
     def __call__(self, f: Callable, **kwargs):
         x, _ = self.sample(self.nbatch)
         f_values = f(x)
         f_size = len(f_values) if isinstance(f_values, (list, tuple)) else 1
-        type_fval = f_values.dtype if f_size == 1 else type(f_values[0].dtype)
-
-        mean = torch.zeros(f_size, dtype=type_fval, device=self.device)
-        var = torch.zeros(f_size, dtype=type_fval, device=self.device)
+        # type_fval = f_values.dtype if f_size == 1 else type(f_values[0].dtype)
+        # mean = torch.zeros(f_size, dtype=type_fval, device=self.device)
+        # var = torch.zeros(f_size, dtype=type_fval, device=self.device)
         # var = torch.zeros((f_size, f_size), dtype=type_fval, device=self.device)
-
-        result = RAvg(weighted=self.adapt)
+        mean = torch.zeros(f_size, dtype=self.dtype, device=self.device)
+        var = torch.zeros(f_size, dtype=self.dtype, device=self.device)
+        result = RAvg()
         epoch = self.neval // self.nbatch
 
-        for itn in range(self.nitn):
-            mean[:] = 0
-            var[:] = 0
-            for _ in range(epoch):
-                x, log_detJ = self.sample(self.nbatch)
-                f_values = f(x)
-                batch_results = self._multiply_by_jacobian(
-                    f_values, torch.exp(log_detJ)
-                )
+        mean[:] = 0
+        var[:] = 0
+        for _ in range(epoch):
+            x, log_detJ = self.sample(self.nbatch)
+            f_values = f(x)
+            batch_results = self._multiply_by_jacobian(f_values, torch.exp(log_detJ))
 
-                mean += torch.mean(batch_results, dim=-1) / epoch
-                var += torch.var(batch_results, dim=-1) / (self.neval * epoch)
+            mean += torch.mean(batch_results, dim=-1) / epoch
+            var += torch.var(batch_results, dim=-1) / (self.neval * epoch)
 
-            result.sum_neval += self.neval
-            result.add(gvar.gvar(mean.item(), (var**0.5).item()))
+        result.sum_neval += self.neval
+        result.add(gvar.gvar(mean.item(), (var**0.5).item()))
         return result
 
     def _multiply_by_jacobian(self, values, jac):
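
With `nitn` gone, `MonteCarlo.__call__` makes a single pass over `neval // nbatch` batches, accumulating the jacobian-weighted batch means and variances into one `RAvg`. A sketch of calling it, assuming an integrand that maps a `(nbatch, dim)` sample tensor to one value per sample:

# Example sketch, not part of the diff.
import torch

def f(x):
    # f(x) = x0**2 + x1**2 on the unit square; the exact integral is 2/3
    return (x**2).sum(dim=-1)

mc = MonteCarlo([(0.0, 1.0), (0.0, 1.0)], neval=100_000)
result = mc(f)
print(result)  # an RAvg, printed gvar-style, e.g. 0.6664(23)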
@@ -105,29 +107,48 @@ def _multiply_by_jacobian(self, values, jac):
         return values * jac
 
 
+def random_walk(dim, bounds, device, dtype, u, **kwargs):
+    rangebounds = bounds[:, 1] - bounds[:, 0]
+    step_size = kwargs.get("step_size", 0.2)
+    step_sizes = rangebounds * step_size
+    step = torch.empty(dim, device=device, dtype=dtype).uniform_(-1, 1) * step_sizes
+    new_u = (u + step - bounds[:, 0]) % rangebounds + bounds[:, 0]
+    return new_u
+
+
+def uniform(dim, bounds, device, dtype, u, **kwargs):
+    rangebounds = bounds[:, 1] - bounds[:, 0]
+    return torch.rand_like(u) * rangebounds + bounds[:, 0]
+
+
+def gaussian(dim, bounds, device, dtype, u, **kwargs):
+    mean = kwargs.get("mean", torch.zeros_like(u))
+    std = kwargs.get("std", torch.ones_like(u))
+    return torch.normal(mean, std)
+
+
 class MCMC(MonteCarlo):
     def __init__(
         self,
-        bounds: Union[List[Tuple[float, float]], np.ndarray],
+        bounds,
         q0=None,
         maps=None,
-        nitn: int = 10,
         neval=10000,
         nbatch=None,
         nburnin=500,
         device="cpu",
-        adapt=False,
+        dtype=torch.float64,
     ):
-        super().__init__(bounds, q0, maps, nitn, neval, nbatch, device, adapt)
+        super().__init__(bounds, q0, maps, neval, nbatch, device, dtype)
         self.nburnin = nburnin
         if maps is None:
-            self.maps = Affine([(0, 1)] * self.dim, device=device)
+            self.maps = Linear([(0, 1)] * self.dim, device=device)
         self._rangebounds = self.bounds[:, 1] - self.bounds[:, 0]
 
     def __call__(
         self,
         f: Callable,
-        proposal_dist="uniform",
+        proposal_dist: Callable = uniform,
         thinning=1,
         mix_rate=0.0,
         **kwargs,
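
Proposal distributions are now module-level callables with the signature `(dim, bounds, device, dtype, u, **kwargs)` rather than strings dispatched inside `MCMC`, so users can supply their own. A sketch of a custom proposal under that contract (not part of the diff; the `scale` kwarg is hypothetical and would arrive through `MCMC.__call__`'s `**kwargs`):

# Example sketch, not part of the diff: a user-defined proposal.
import torch

def cauchy_walk(dim, bounds, device, dtype, u, **kwargs):
    # heavy-tailed random walk, wrapped back into [lower, upper) per dimension
    scale = kwargs.get("scale", 0.1)  # hypothetical tuning kwarg
    rangebounds = bounds[:, 1] - bounds[:, 0]
    step = torch.distributions.Cauchy(0.0, scale).sample(u.shape)
    step = step.to(device=device, dtype=dtype) * rangebounds
    return (u + step - bounds[:, 0]) % rangebounds + bounds[:, 0]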
@@ -146,84 +167,93 @@ def __call__(
         current_weight.masked_fill_(current_weight < epsilon, epsilon)
         # current_fval.masked_fill_(current_fval.abs() < epsilon, epsilon)
 
-        proposed_y = torch.empty_like(current_y)
-        proposed_x = torch.empty_like(current_x)
-        new_fval = torch.empty_like(current_fval)
-        new_weight = torch.empty_like(current_weight)
+        # proposed_y = torch.empty_like(current_y)
+        # proposed_x = torch.empty_like(current_x)
+        # new_fval = torch.empty_like(current_fval)
+        # new_weight = torch.empty_like(current_weight)
 
         f_size = len(current_fval) if isinstance(current_fval, (list, tuple)) else 1
-        type_fval = current_fval.dtype if f_size == 1 else type(current_fval[0].dtype)
-        mean = torch.zeros(f_size, dtype=type_fval, device=self.device)
+        # type_fval = current_fval.dtype if f_size == 1 else type(current_fval[0].dtype)
+        # mean = torch.zeros(f_size, dtype=type_fval, device=self.device)
+        mean = torch.zeros(f_size, dtype=self.dtype, device=self.device)
         mean_ref = torch.zeros_like(mean)
-        var = torch.zeros(f_size, dtype=type_fval, device=self.device)
+        # var = torch.zeros(f_size, dtype=type_fval, device=self.device)
+        var = torch.zeros(f_size, dtype=self.dtype, device=self.device)
         var_ref = torch.zeros_like(mean)
 
-        result = RAvg(weighted=self.adapt)
-        result_ref = RAvg(weighted=self.adapt)
+        result = RAvg()
+        result_ref = RAvg()
 
         epoch = self.neval // self.nbatch
         n_meas = 0
-        for itn in range(self.nitn):
-            for i in range(epoch):
-                proposed_y[:] = self._propose(current_y, proposal_dist, **kwargs)
-                proposed_x[:], new_jac = self.maps.forward(proposed_y)
-                new_jac = torch.exp(new_jac)
 
-                new_fval[:] = f(proposed_x)
-                new_weight = mix_rate / new_jac + (1 - mix_rate) * new_fval.abs()
+        def _propose(current_y, current_fval, current_weight, current_jac):
+            proposed_y = proposal_dist(
+                self.dim, self.bounds, self.device, self.dtype, current_y, **kwargs
+            )
+            proposed_x, new_jac = self.maps.forward(proposed_y)
+            new_jac = torch.exp(new_jac)
+
+            new_fval = f(proposed_x)
+            new_weight = mix_rate / new_jac + (1 - mix_rate) * new_fval.abs()
 
-                acceptance_probs = new_weight / current_weight * new_jac / current_jac
+            acceptance_probs = new_weight / current_weight * new_jac / current_jac
 
-                accept = (
-                    torch.rand(self.nbatch, dtype=torch.float64, device=self.device)
-                    <= acceptance_probs
-                )
+            accept = (
+                torch.rand(self.nbatch, dtype=self.dtype, device=self.device)
+                <= acceptance_probs
+            )
 
-                current_y = torch.where(accept.unsqueeze(1), proposed_y, current_y)
-                current_fval = torch.where(accept, new_fval, current_fval)
-                current_weight = torch.where(accept, new_weight, current_weight)
-                current_jac = torch.where(accept, new_jac, current_jac)
-
-                if i < self.nburnin and itn == 0:
-                    continue
-                elif i % thinning == 0:
-                    n_meas += 1
-                    batch_results = current_fval / current_weight
-
-                    mean += torch.mean(batch_results, dim=-1) / epoch
-                    var += torch.var(batch_results, dim=-1) / epoch
-
-                    batch_results_ref = 1 / (current_jac * current_weight)
-                    mean_ref += torch.mean(batch_results_ref, dim=-1) / epoch
-                    var_ref += torch.var(batch_results_ref, dim=-1) / epoch
-
-            result.sum_neval += self.neval
-            result.add(gvar.gvar(mean.item(), ((var / n_meas) ** 0.5).item()))
-            result_ref.sum_neval += self.nbatch
-            result_ref.add(
-                gvar.gvar(mean_ref.item(), ((var_ref / n_meas) ** 0.5).item())
+            current_y = torch.where(accept.unsqueeze(1), proposed_y, current_y)
+            current_fval = torch.where(accept, new_fval, current_fval)
+            current_weight = torch.where(accept, new_weight, current_weight)
+            current_jac = torch.where(accept, new_jac, current_jac)
+            return current_y, current_fval, current_weight, current_jac
+
+        for i in range(self.nburnin):
+            current_y, current_fval, current_weight, current_jac = _propose(
+                current_y, current_fval, current_weight, current_jac
             )
+        for i in range(epoch // thinning):
+            for j in range(thinning):
+                current_y, current_fval, current_weight, current_jac = _propose(
+                    current_y, current_fval, current_weight, current_jac
+                )
+            n_meas += 1
+            batch_results = current_fval / current_weight
+
+            mean += torch.mean(batch_results, dim=-1) / epoch
+            var += torch.var(batch_results, dim=-1) / epoch
+
+            batch_results_ref = 1 / (current_jac * current_weight)
+            mean_ref += torch.mean(batch_results_ref, dim=-1) / epoch
+            var_ref += torch.var(batch_results_ref, dim=-1) / epoch
+
+        result.sum_neval += self.neval
+        result.add(gvar.gvar(mean.item(), ((var / n_meas) ** 0.5).item()))
+        result_ref.sum_neval += self.nbatch
+        result_ref.add(gvar.gvar(mean_ref.item(), ((var_ref / n_meas) ** 0.5).item()))
 
         return result / result_ref * self._rangebounds.prod()
 
-    def _propose(self, u, proposal_dist, **kwargs):
-        if proposal_dist == "random_walk":
-            step_size = kwargs.get("step_size", 0.2)
-            step_sizes = self._rangebounds * step_size
-            step = (
-                torch.empty(self.dim, device=self.device).uniform_(-1, 1) * step_sizes
-            )
-            new_u = (u + step - self.bounds[:, 0]) % self._rangebounds + self.bounds[
-                :, 0
-            ]
-            return new_u
-            # return (u + (torch.rand_like(u) - 0.5) * step_size) % 1.0
-        elif proposal_dist == "uniform":
-            # return torch.rand_like(u)
-            return torch.rand_like(u) * self._rangebounds + self.bounds[:, 0]
-        # elif proposal_dist == "gaussian":
-        #     mean = kwargs.get("mean", torch.zeros_like(u))
-        #     std = kwargs.get("std", torch.ones_like(u))
-        #     return torch.normal(mean, std)
-        else:
-            raise ValueError(f"Unknown proposal distribution: {proposal_dist}")
+    # def _propose(self, u, proposal_dist, **kwargs):
+    #     if proposal_dist == "random_walk":
+    #         step_size = kwargs.get("step_size", 0.2)
+    #         step_sizes = self._rangebounds * step_size
+    #         step = (
+    #             torch.empty(self.dim, device=self.device).uniform_(-1, 1) * step_sizes
+    #         )
+    #         new_u = (u + step - self.bounds[:, 0]) % self._rangebounds + self.bounds[
+    #             :, 0
+    #         ]
+    #         return new_u
+    #         # return (u + (torch.rand_like(u) - 0.5) * step_size) % 1.0
+    #     elif proposal_dist == "uniform":
+    #         # return torch.rand_like(u)
+    #         return torch.rand_like(u) * self._rangebounds + self.bounds[:, 0]
+    #     # elif proposal_dist == "gaussian":
+    #     #     mean = kwargs.get("mean", torch.zeros_like(u))
+    #     #     std = kwargs.get("std", torch.ones_like(u))
+    #     #     return torch.normal(mean, std)
+    #     else:
+    #         raise ValueError(f"Unknown proposal distribution: {proposal_dist}")
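
The burn-in and thinning logic that lived in the old `nitn`/`epoch` double loop is now an explicit burn-in phase followed by `epoch // thinning` measurement sweeps, with the final estimate normalized by the reference integral of `1 / (jac * weight)`. A sketch of driving the new callable-based API (not part of the diff; reuses `f` from the earlier sketch and the module-level `random_walk` proposal):

# Example sketch, not part of the diff.
mcmc = MCMC([(0.0, 1.0), (0.0, 1.0)], neval=50_000, nburnin=1_000)
result = mcmc(f, proposal_dist=random_walk, thinning=5, step_size=0.1)
print(result)  # importance-weighted estimate, rescaled by the domain volume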