@@ -42,6 +42,8 @@ def set_step(self, step):
4242 self ._step = step_choices [step ]
4343
4444
45+ # TODO: these integrators are not general anymore, as they rely on the
46+ # structure of the chain method classes
4547class VerletIntegrator (BaseIntegrator ):
4648 """A quick Verlet integration in Cartesian coordinates
4749
@@ -148,19 +150,24 @@ def _step(self, t, y, h, f):
148150# NEBM integrator for F-S algorithm
149151
150152
151- class FSIntegrator (BaseIntegrator ):
153+ # TODO: these integrators are not general anymore, as they rely on the
154+ # structure of the chain method classes
155+ # TODO: type checks
156+ class FSIntegrator (object ):
152157 """A step integrator considering the action of the band
153158 """
154- def __init__ (self , band , forces , rhs_fun , action_fun , n_images , n_dofs_image ,
159+ def __init__ (self , ChainObj ,
160+ # band, forces, distances, rhs_fun, action_fun,
161+ # n_images, n_dofs_image,
155162 maxSteps = 1000 ,
156163 maxCreep = 5 , actionTol = 1e-10 , forcesTol = 1e-6 ,
157164 etaScale = 1.0 , dEta = 2 , minEta = 0.001 ,
158165 # perturbSeed=42, perturbFactor=0.1,
159166 nTrail = 10 , resetMax = 20
160167 ):
161- super (FSIntegrator , self ).__init__ (band , rhs_fun )
168+ # super(FSIntegrator, self).__init__(band, rhs_fun)
162169
163- self .action_fun = action_fun
170+ self .ChainObj = ChainObj
164171
165172 # Integration parameters
166173 # TODO: move to run function?
@@ -172,24 +179,22 @@ def __init__(self, band, forces, rhs_fun, action_fun, n_images, n_dofs_image,
172179 self .forcesTol = forcesTol
173180 self .actionTol = actionTol
174181 self .maxSteps = maxSteps
175-
182+ self .step = 0
183+ self .nTrail = nTrail
176184 self .i_step = 0
177- self .n_images = n_images
178- self .n_dofs_image = n_dofs_image
185+
186+ # Chain objects:
187+ self .n_images = self .ChainObj .n_images
188+ self .n_dofs_image = self .ChainObj .n_dofs_image
179189 # self.forces_prev = np.zeros_like(band).reshape(n_images, -1)
180190 # self.G :
181- self .forces = forces
182- self .forces_old = np .zeros_like (forces )
191+ self .forces = self .ChainObj .forces
192+ self .distances = self .ChainObj .distances
193+ self .forces_old = np .zeros_like (self .ChainObj .forces )
183194
184- # Rename y to band
185- self .band = self .y
186- self .band_old = np .zeros_like (self .band ) # y -> band
187-
188- self .step = 0
189- self .nTrail = nTrail
190-
191- def run_until (self , t ):
192- pass
195+ # self.band should be just a reference to the band in the ChainObj
196+ self .band = self .ChainObj .band
197+ self .band_old = np .zeros_like (self .band )
193198
194199 def run_for (self , n_steps ):
195200
@@ -202,6 +207,7 @@ def run_for(self, n_steps):
202207 self .trailAction = np .zeros (self .nTrail )
203208 trailPool = cycle (range (self .nTrail )) # cycle through 0,1,...,(nTrail-1),0,1,...
204209 eta = 1.0
210+ self .i_step = 0
205211
206212 # In __init__:
207213 # self.band_last[:] = self.band
@@ -216,7 +222,7 @@ def run_for(self, n_steps):
216222 self .band [:] = self .band_old
217223
218224 # Compute from self.band. Do not update the step at this stage:
219- # This step updates the forces in the G array of the nebm module,
225+ # This step updates the forces and distances in the G array of the nebm module,
220226 # using the current band state self.y
221227 # TODO: remove time from chain method rhs
222228 # make a specific function to update G??
@@ -243,12 +249,14 @@ def run_for(self, n_steps):
243249
244250 self .trailAction
245251
246- self .rhs ( t , self .y )
247- self .action = self .action_fun ()
252+ self .ChainObj . nebm_step ( self .band , ensure_zero_extrema = True )
253+ self .action = self .ChainObj . action_fun ()
248254
249255 self .trailAction [nStart ] = self .action
250256 nStart = next (trailPool )
251257
258+ self .i_step += 1
259+
252260 # Getting averages of forces from the INNER images in the band (no extrema)
253261 # (forces are given by vector G in the chain method code)
254262 # TODO: we might use all band images, not only inner ones
@@ -266,7 +274,7 @@ def run_for(self, n_steps):
266274 exitFlag = True
267275 break # creep loop
268276
269- if (n_steps >= self .maxSteps ):
277+ if (self . i_step >= self .maxSteps ):
270278 print ('Number of steps reached maximum' )
271279 exitFlag = True
272280 break # creep loop
@@ -282,7 +290,7 @@ def run_for(self, n_steps):
282290 print ('' )
283291 resetCount += 1
284292 bestAction = self .action_old
285- RefinePath ( self .band ) # Resets the path to equidistant structures (smoothing kinks?)
293+ self . refine_path ( self . ChainObj . distances , self .band ) # Resets the path to equidistant structures (smoothing kinks?)
286294 # PathChanged[:] = True
287295
288296 if resetCount > self .resetMax :
@@ -309,13 +317,13 @@ def run_for(self, n_steps):
309317 resetCount = 0
310318
311319 # Taken from the string method class
312- def refine_path (self , distances ):
320+ def refine_path (self , distances , band ):
313321 """
314322 """
315323 new_dist = np .linspace (distances [0 ], distances [- 1 ], distances .shape [0 ])
316324 # Restructure the string by interpolating every spin component
317325 # print(self.integrator.y[self.n_dofs_image:self.n_dofs_image + 10])
318- bandrs = self . band .reshape (self .n_images , self .n_dofs_image )
326+ bandrs = band .reshape (self .n_images , self .n_dofs_image )
319327 for i in range (self .n_dofs_image ):
320328
321329 cs = si .CubicSpline (distances , bandrs [:, i ])
0 commit comments