3535
3636import torch
3737import torch .fft
38+ import numpy as np
3839
3940
4041def sech (x ):
@@ -53,6 +54,39 @@ def Martens_model(q, params):
5354 return Va * (sech (2.0 * q [0 ]))** 2 + 0.5 * Vb * (q [1 ] + Vc * (q [0 ]** 2 - 1.0 ) )** 2
5455
5556
57+ # Define Tully's simple avoided crossing diabatic potential matrix
58+ def tully_potential_matrix (Q , params ):
59+ """
60+ Q: Tensor with shape [ndof, Ngrid]
61+ Returns diabatic potential matrix [2, 2, Ngrid]
62+ """
63+ x = Q [0 ] # Assume 1D nuclear coordinate
64+
65+ A = params .get ("A" , 0.01 )
66+ B = params .get ("B" , 1.6 )
67+ C = params .get ("C" , 0.005 )
68+ D = params .get ("D" , 1.0 )
69+
70+
71+ V11 = torch .where (
72+ x >= 0 ,
73+ A * (1 - torch .exp (- B * x )),
74+ - A * (1 - torch .exp (B * x ))
75+ )
76+
77+ #V11 = A * (1 - torch.exp(-B * x)) # Diabatic state 1 potential
78+ V22 = - V11 # Diabatic state 2 potential (mirror)
79+ V12 = C * torch .exp (- D * x ** 2 ) # Coupling between diabatic states
80+
81+ shape = x .shape
82+ Vmat = torch .zeros ((* shape , 2 , 2 ), dtype = torch .cfloat )
83+ Vmat [..., 0 , 0 ] = V11
84+ Vmat [..., 1 , 1 ] = V22
85+ Vmat [..., 0 , 1 ] = V12
86+ Vmat [..., 1 , 0 ] = torch .conj (V12 )
87+
88+ return Vmat
89+
5690
5791def gaussian_wavepacket (q , params ):
5892 """
@@ -211,3 +245,311 @@ def solve(self):
211245 self .save ()
212246
213247
248+
249+
250+ class exact_tdse_solver_multistate :
251+ def __init__ (self , params ):
252+ """
253+ Initializes the TDSE solver.
254+
255+ Parameters:
256+ params (dict): Dictionary of simulation parameters:
257+ - prefix (str): Filename prefix for output
258+ - grid_size (list[int]): Number of points per spatial dimension
259+ - q_min (list[float]), q_max (list[float]): Spatial bounds
260+ - save_every_n_steps (int): Interval for recording data
261+ - dt (float): Time step
262+ - nsteps (int): Number of time steps
263+ - mass (list[float]): Masses per dimension
264+ - potential_fn_params (dict): Parameters for potential energy surface
265+ - psi0_fn_params (dict): Parameters for initial wavepacket
266+ - device (torch.device): Computation device
267+ - Nstates (int): Number of electronic states
268+ - representation (str): "diabatic" or "adiabatic"
269+ - initial_state_index (int): Which state to initialize
270+ - method (str): Propagation scheme ('miller-colton', 'split-operator', 'crank-nicolson')
271+ """
272+
273+ self .prefix = params .get ("prefix" , "exact-solution" )
274+ self .grid_size = torch .tensor (params .get ("grid_size" , [64 , 64 ]))
275+ self .ndim = len (self .grid_size )
276+ self .q_min = torch .tensor (params .get ("q_min" , [- 10.0 ]* self .ndim ))
277+ self .q_max = torch .tensor (params .get ("q_max" , [10.0 ]* self .ndim ))
278+ self .save_every_n_steps = params .get ("save_every_n_steps" , 1 )
279+ self .dt = params .get ("dt" , 0.01 )
280+ self .nsteps = params .get ("nsteps" , 500 )
281+ self .mass = torch .tensor (params .get ("mass" , [2000.0 ]* self .ndim ))
282+ self .potential_fn = params .get ("potential_fn" , None )
283+ self .potential_fn_params = params .get ("potential_fn_params" , {})
284+ self .psi0_fn = params .get ("psi0_fn" , None )
285+ self .psi0_fn_params = params .get ("psi0_fn_params" , {})
286+ self .device = params .get ("device" , torch .device ("cuda" if torch .cuda .is_available () else "cpu" ))
287+ self .hbar = 1.0
288+ self .Nstates = params .get ("Nstates" , 2 )
289+ self .representation = params .get ("representation" , "diabatic" )
290+ self .initial_state_index = params .get ("initial_state_index" , 0 )
291+ self .method = params .get ("method" , "miller-colton" ).lower ()
292+
293+ self .time = []
294+ self .kinetic_energy = []
295+ self .potential_energy = []
296+ self .total_energy = []
297+ self .population_right = []
298+ self .norm = []
299+
300+ def initialize_grids (self ):
301+ """
302+ Constructs real- and momentum-space grids, initializes kinetic energy operator
303+ and wavefunction on grid in chosen representation.
304+ """
305+ print ("Initializing grids" )
306+ print ("grid_size = " , self .grid_size )
307+ self .ngrid = int (self .grid_size .prod ().item ())
308+ print ("self.ngrid = " , self .ngrid )
309+
310+ # Real-space grid
311+ q_axes = [torch .linspace (self .q_min [i ], self .q_max [i ], self .grid_size [i ]) for i in range (self .ndim )]
312+ q_grids = torch .meshgrid (* q_axes , indexing = "ij" )
313+ self .dq = torch .tensor ([q_axes [i ][1 ] - q_axes [i ][0 ] for i in range (self .ndim )])
314+ self .Q = torch .stack (q_grids )
315+ print ("Q = " , self .Q .shape )
316+ print ("dq = " , self .dq )
317+
318+ # Momentum-space grid
319+ self .dk = 2 * torch .pi / (self .grid_size * self .dq )
320+ k_axes = [torch .fft .fftshift (torch .arange (- self .grid_size [i ] // 2 , self .grid_size [i ] // 2 )) * self .dk [i ] for i in range (self .ndim )]
321+ k_grids = torch .meshgrid (* k_axes , indexing = "ij" )
322+ self .K = torch .stack (k_grids )
323+ print ("K = " , self .K .shape )
324+ print ("dk = " , self .dk )
325+
326+ # Volume elements
327+ self .dV = self .dq .prod ()
328+ self .dVk = self .dk .prod () #(self.dq / self.grid_size).prod()
329+ print ("dV = " , self .dV )
330+ print ("dVk = " , self .dVk )
331+
332+
333+ # Allocate storage:
334+ self .nsnaps = self .nsteps // self .save_every_n_steps + 1 # how many snapshots to save
335+
336+ # Diabatic properties
337+ self .psi_r_dia = torch .zeros ((* self .grid_size , self .Nstates ), dtype = torch .cfloat )
338+ self .psi_k_dia = torch .zeros_like (self .psi_r_dia )
339+ self .psi_r_dia_all = torch .zeros (( self .nsnaps , * self .grid_size , self .Nstates ), dtype = torch .cfloat )
340+ self .psi_k_dia_all = torch .zeros ((self .nsnaps , * self .grid_size , self .Nstates ), dtype = torch .cfloat )
341+ self .rho_dia_all = torch .zeros ((self .nsnaps , self .Nstates , self .Nstates ), dtype = torch .cfloat )
342+
343+ # Adiabatic properties
344+ self .psi_r_adi = torch .zeros ((* self .grid_size , self .Nstates ), dtype = torch .cfloat )
345+ self .psi_k_adi = torch .zeros_like (self .psi_r_adi )
346+ self .psi_r_adi_all = torch .zeros (( self .nsnaps , * self .grid_size , self .Nstates ), dtype = torch .cfloat )
347+ self .psi_k_adi_all = torch .zeros ((self .nsnaps , * self .grid_size , self .Nstates ), dtype = torch .cfloat )
348+ self .rho_adi_all = torch .zeros ((self .nsnaps , self .Nstates , self .Nstates ), dtype = torch .cfloat )
349+
350+
351+ def update_adi_r (self ):
352+ """
353+ Convert diabatic to adiabatic r-space wavefunction: C_adi = U * C_dia
354+ """
355+ self .psi_r_adi = torch .einsum ("...ij, ...j->...i" , self .eigvecs , self .psi_r_dia )
356+
357+ def update_dia_r (self ):
358+ """
359+ Convert adiabatic to diabatic r-space wavefunction: C_dia = U.H * C_adi
360+ """
361+ self .psi_r_dia = torch .einsum ("...ij, ...j->...i" , self .eigvecs .conj ().transpose (- 2 ,- 1 ), self .psi_r_adi )
362+
363+
364+ def transform_r2k (self , rep ):
365+ """
366+ Compute the k-space diabatic and adiabatic wavefunctions from the r-space counterparts
367+ rep: 0 - diabatic, 1 - adiabatic
368+ """
369+ # Define which axes are spatial (everything except last one = Nstates)
370+ spatial_dims = tuple (range (self .ndim ))
371+
372+ # Apply FFT over spatial dimensions for all states at once
373+ if rep == 0 :
374+ self .psi_k_dia = torch .fft .fftn (self .psi_r_dia , dim = spatial_dims ) # norm='forward')
375+ elif rep == 1 :
376+ self .psi_k_adi = torch .fft .fftn (self .psi_r_adi , dim = spatial_dims ) # norm='forward')
377+
378+ def transform_k2r (self , rep ):
379+ """
380+ Compute the k-space diabatic and adiabatic wavefunctions from the r-space counterparts
381+ rep: 0 - diabatic, 1 - adiabatic
382+ """
383+ # Define which axes are spatial (everything except last one = Nstates)
384+ spatial_dims = tuple (range (self .ndim ))
385+
386+ # Apply inverse FFT over same spatial dimensions
387+ if rep == 0 :
388+ self .psi_r_dia = torch .fft .ifftn (self .psi_k_dia , dim = spatial_dims ) # norm='forward')
389+ elif rep == 1 :
390+ self .psi_r_adi = torch .fft .ifftn (self .psi_k_adi , dim = spatial_dims ) # norm='forward')
391+
392+
393+ def initialize_operators (self ):
394+ # Kinetic energy operator - same for all states
395+ self .T = 0.5 * torch .sum ((self .hbar * self .K ) ** 2 / self .mass .view (self .ndim , * [1 ]* self .ndim ), dim = 0 ) # T: [*grid_size]
396+ print ("T.shape = " , self .T .shape )
397+
398+ # Element-wise exponentials are fine here
399+ self .expT = torch .exp (- 1j * self .T * self .dt / self .hbar ) # expT: [*grid_size]
400+ print ("expT.shape = " , self .expT .shape )
401+
402+ # Potential energy operator - matrix for multiple states
403+ self .V = self .potential_fn (self .Q , self .potential_fn_params ) # V: [*grid_size, Nstates, Nstates]
404+ print ("V.shape = " , self .V .shape )
405+
406+ # dia <-> adi transformation for all points
407+ # V U = E U => H = U.H * E * U
408+ # E = <psi_adi | H | psi_adi >
409+ # V = <psi_dia | H | psi_dia >
410+ # | psi_adi > U = | psi_dia >
411+ # | Psi > = | psi_dia> C_dia = | psi_adi > C_adi = | psi_adi > U C_dia
412+ # so C_adi = U C_dia
413+
414+ self .eigvals , self .eigvecs = torch .linalg .eigh (self .V ) # V: [*grid_size, Nstates, Nstates]
415+ self .exp_diag = torch .exp (- 0.5j * self .dt * self .eigvals )
416+ self .expV_half = self .eigvecs .conj ().transpose (- 2 , - 1 ) @ torch .diag_embed (self .exp_diag ) @ self .eigvecs
417+ print ("expV_half.shape = " , self .expV_half .shape )
418+
419+ # Initialize the wavefunction in r-space
420+ wfn = self .psi0_fn (self .Q , self .psi0_fn_params )
421+
422+ if self .representation == "diabatic" :
423+ self .psi_r_dia [..., self .initial_state_index ] = wfn
424+ self .update_adi_r () # dia -> adi
425+ elif self .representation == "adiabatic" :
426+ self .psi_r_adi [..., self .initial_state_index ] = wfn
427+ self .update_dia_r () # adi -> dia
428+ #print("psi.shape = ", self.psi_r_dia.shape)
429+
430+ # Update the k-space wavefunctions:
431+ self .transform_r2k (0 ) # psi_r_dia -> psi_k_dia
432+ self .transform_r2k (1 ) # psi_r_adi -> psi_k_adi
433+
434+
435+ def propagate (self ):
436+ """
437+ Time-propagates the wavefunction using the selected propagation method.
438+ Supported methods: 'miller-colton', 'split-operator', 'crank-nicolson'
439+ """
440+ for step in range (self .nsteps ):
441+
442+ #====================== Saving and computing properties ==================
443+
444+ if step % self .save_every_n_steps == 0 :
445+ istep = int (step / self .save_every_n_steps )
446+
447+ # Diabatic r-space wavefunctions
448+ self .psi_r_dia_all [istep ] = self .psi_r_dia
449+
450+ # Compute other representations:
451+ self .update_adi_r (); self .psi_r_adi_all [istep ] = self .psi_r_adi ; # adiabatic in r-space
452+ self .transform_r2k (0 ); self .psi_k_dia_all [istep ] = self .psi_k_dia ; # diabatic in k-space
453+ self .transform_r2k (1 ); self .psi_k_adi_all [istep ] = self .psi_k_adi ; # adiabatic in k-space
454+
455+
456+ # Diabatic density matrix
457+ self .rho_dia_all [istep ] = torch .einsum ("...i, ...j->ij" , self .psi_r_dia , self .psi_r_dia .conj () ) * self .dV
458+
459+ # Adiabatic density matrix
460+ self .rho_adi_all [istep ] = torch .einsum ("...i, ...j->ij" , self .psi_r_adi , self .psi_r_adi .conj () ) * self .dV
461+
462+ # Kinetic energy
463+ KE = torch .einsum ("...i,...,...i->" , self .psi_k_dia .conj (), self .T , self .psi_k_dia ) * (self .dq / self .grid_size ).prod ()
464+
465+ # Full potential energy: PE = ∫ ψ* V ψ dx
466+ PE = torch .einsum ("...i,...ij,...j->" , self .psi_r_dia .conj (), self .V , self .psi_r_dia ) * self .dV
467+
468+
469+ nrm = torch .sum (torch .abs (self .psi_r_dia ) ** 2 ) * self .dV
470+ x_coords = self .Q [0 ]
471+ right_mask = self .Q [0 ] > 0
472+ #pop_right = torch.sum(self.prob_density[right_mask]) * self.dV
473+
474+ self .norm .append ( nrm )
475+ self .time .append (step * self .dt )
476+ self .kinetic_energy .append ( KE .real .item () )
477+ self .potential_energy .append ( PE .real .item () )
478+ self .total_energy .append ( KE + PE )
479+ #self.population_right.append(pop_right.item())
480+
481+ print (f"Step { step } : Norm = { nrm :.4f} " )
482+
483+
484+ #=================== Doing computations ==========================
485+ if self .method == "crank-nicolson" :
486+ # Use Crank-Nicolson scheme for time propagation
487+ # self.crank_nicolson_step()
488+ pass # not implemented
489+
490+ else :
491+ # Apply half step of potential evolution in adiabatic representation
492+ # self.potential_half_step()
493+
494+ if self .method == "miller-colton" :
495+ pass # not implemented
496+ # Miller-Colton propagation in momentum space
497+ #for i in range(self.Nstates):
498+ # self.psi_k[i] = torch.fft.fftn(self.psi[i])
499+ # self.psi_k[i] *= torch.exp(-1j * self.T * self.dt / (2 * self.hbar))
500+ # self.psi[i] = torch.fft.ifftn(self.psi_k[i])
501+
502+ elif self .method == "split-operator" :
503+ # Half-step propagation in real space
504+ self .psi_r_dia = torch .einsum ("...ij,...j->...i" , self .expV_half , self .psi_r_dia ) #self.expV_half @ self.psi
505+
506+ # Full-step in reciprocal space
507+ self .transform_r2k (0 )
508+
509+ # Multiply by expT — make sure it broadcasts correctly
510+ self .psi_k_dia *= self .expT .unsqueeze (- 1 ) # if expT.shape == [*,], expand to [*, 1]
511+
512+ # Apply inverse FFT over same spatial dimensions
513+ self .transform_k2r (0 )
514+
515+ # Another half of the propagation in real space
516+ self .psi_r_dia = torch .einsum ("...ij,...j->...i" , self .expV_half , self .psi_r_dia ) #self.expV_half @ self.psi
517+
518+
519+ def save (self ):
520+ """Saves the grid, wavefunction, and observables to disk."""
521+ torch .save ( {"grid_size" :self .grid_size ,
522+ "ndim" :self .ndim ,
523+ "q_min" :self .q_min , "q_max" :self .q_max ,
524+ "save_every_n_steps" : self .save_every_n_steps ,
525+ "dt" :self .dt , "nsteps" :self .nsteps ,
526+ "mass" :self .mass ,
527+ "psi_r_adi" :self .psi_r_adi ,
528+ "psi_r_dia" :self .psi_r_dia ,
529+ "psi_k_adi" :self .psi_k_adi ,
530+ "psi_k_dia" :self .psi_k_dia ,
531+ "rho_dia_all" :self .rho_dia_all ,
532+ "rho_adi_all" :self .rho_adi_all ,
533+ "psi_r_dia_all" :self .psi_r_dia_all ,
534+ "psi_r_adi_all" :self .psi_r_adi_all ,
535+ "psi_k_dia_all" :self .psi_k_dia_all ,
536+ "psi_k_adi_all" :self .psi_k_adi_all ,
537+ "time" :self .time ,
538+ "Q" :self .Q , "K" :self .K , "dq" :self .dq , "dk" :self .dk ,
539+ "dV" :self .dV , "dVk" :self .dVk ,
540+ "kinetic_energy" :self .kinetic_energy ,
541+ "potential_energy" :self .potential_energy ,
542+ "total_energy" :self .total_energy ,
543+ "population_right" :self .population_right ,
544+ "norm" :self .norm ,
545+ "V" :self .V , "T" :self .T ,
546+ }, F"{ self .prefix } .pt" )
547+
548+ def solve (self ):
549+ """Runs the full simulation and saves results."""
550+ self .initialize_grids ()
551+ self .initialize_operators ()
552+ self .propagate ()
553+ self .save ()
554+
555+
0 commit comments