Skip to content

Commit a6ad422

Browse files
committed
Added a multistate generalization of the TD-SE solver written with PyTorch
1 parent 36de489 commit a6ad422

File tree

1 file changed

+342
-0
lines changed

1 file changed

+342
-0
lines changed

src/libra_py/dynamics/exact_torch/compute.py

Lines changed: 342 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
import torch
3737
import torch.fft
38+
import numpy as np
3839

3940

4041
def sech(x):
@@ -53,6 +54,39 @@ def Martens_model(q, params):
5354
return Va * (sech(2.0*q[0]))**2 + 0.5 * Vb * (q[1] + Vc * (q[0]**2 - 1.0 ) )**2
5455

5556

57+
# Define Tully's simple avoided crossing diabatic potential matrix
58+
def tully_potential_matrix(Q, params):
59+
"""
60+
Q: Tensor with shape [ndof, Ngrid]
61+
Returns diabatic potential matrix [2, 2, Ngrid]
62+
"""
63+
x = Q[0] # Assume 1D nuclear coordinate
64+
65+
A = params.get("A", 0.01)
66+
B = params.get("B", 1.6)
67+
C = params.get("C", 0.005)
68+
D = params.get("D", 1.0)
69+
70+
71+
V11 = torch.where(
72+
x >= 0,
73+
A * (1 - torch.exp(-B * x)),
74+
-A * (1 - torch.exp(B * x))
75+
)
76+
77+
#V11 = A * (1 - torch.exp(-B * x)) # Diabatic state 1 potential
78+
V22 = -V11 # Diabatic state 2 potential (mirror)
79+
V12 = C * torch.exp(-D * x**2) # Coupling between diabatic states
80+
81+
shape = x.shape
82+
Vmat = torch.zeros((*shape, 2, 2), dtype=torch.cfloat)
83+
Vmat[..., 0, 0] = V11
84+
Vmat[..., 1, 1] = V22
85+
Vmat[..., 0, 1] = V12
86+
Vmat[..., 1, 0] = torch.conj(V12)
87+
88+
return Vmat
89+
5690

5791
def gaussian_wavepacket(q, params):
5892
"""
@@ -211,3 +245,311 @@ def solve(self):
211245
self.save()
212246

213247

248+
249+
250+
class exact_tdse_solver_multistate:
251+
def __init__(self, params):
252+
"""
253+
Initializes the TDSE solver.
254+
255+
Parameters:
256+
params (dict): Dictionary of simulation parameters:
257+
- prefix (str): Filename prefix for output
258+
- grid_size (list[int]): Number of points per spatial dimension
259+
- q_min (list[float]), q_max (list[float]): Spatial bounds
260+
- save_every_n_steps (int): Interval for recording data
261+
- dt (float): Time step
262+
- nsteps (int): Number of time steps
263+
- mass (list[float]): Masses per dimension
264+
- potential_fn_params (dict): Parameters for potential energy surface
265+
- psi0_fn_params (dict): Parameters for initial wavepacket
266+
- device (torch.device): Computation device
267+
- Nstates (int): Number of electronic states
268+
- representation (str): "diabatic" or "adiabatic"
269+
- initial_state_index (int): Which state to initialize
270+
- method (str): Propagation scheme ('miller-colton', 'split-operator', 'crank-nicolson')
271+
"""
272+
273+
self.prefix = params.get("prefix", "exact-solution")
274+
self.grid_size = torch.tensor(params.get("grid_size", [64, 64]))
275+
self.ndim = len(self.grid_size)
276+
self.q_min = torch.tensor(params.get("q_min", [-10.0]*self.ndim))
277+
self.q_max = torch.tensor(params.get("q_max", [10.0]*self.ndim))
278+
self.save_every_n_steps = params.get("save_every_n_steps", 1)
279+
self.dt = params.get("dt", 0.01)
280+
self.nsteps = params.get("nsteps", 500)
281+
self.mass = torch.tensor(params.get("mass", [2000.0]*self.ndim))
282+
self.potential_fn = params.get("potential_fn", None)
283+
self.potential_fn_params = params.get("potential_fn_params", {})
284+
self.psi0_fn = params.get("psi0_fn", None)
285+
self.psi0_fn_params = params.get("psi0_fn_params", {})
286+
self.device = params.get("device", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
287+
self.hbar = 1.0
288+
self.Nstates = params.get("Nstates", 2)
289+
self.representation = params.get("representation", "diabatic")
290+
self.initial_state_index = params.get("initial_state_index", 0)
291+
self.method = params.get("method", "miller-colton").lower()
292+
293+
self.time = []
294+
self.kinetic_energy = []
295+
self.potential_energy = []
296+
self.total_energy = []
297+
self.population_right = []
298+
self.norm = []
299+
300+
def initialize_grids(self):
301+
"""
302+
Constructs real- and momentum-space grids, initializes kinetic energy operator
303+
and wavefunction on grid in chosen representation.
304+
"""
305+
print("Initializing grids")
306+
print("grid_size = ", self.grid_size)
307+
self.ngrid = int(self.grid_size.prod().item())
308+
print("self.ngrid = ", self.ngrid)
309+
310+
# Real-space grid
311+
q_axes = [torch.linspace(self.q_min[i], self.q_max[i], self.grid_size[i]) for i in range(self.ndim)]
312+
q_grids = torch.meshgrid(*q_axes, indexing="ij")
313+
self.dq = torch.tensor([q_axes[i][1] - q_axes[i][0] for i in range(self.ndim)])
314+
self.Q = torch.stack(q_grids)
315+
print("Q = ", self.Q.shape)
316+
print("dq = ", self.dq)
317+
318+
# Momentum-space grid
319+
self.dk = 2 * torch.pi / (self.grid_size * self.dq)
320+
k_axes = [torch.fft.fftshift(torch.arange(-self.grid_size[i] // 2, self.grid_size[i] // 2)) * self.dk[i] for i in range(self.ndim)]
321+
k_grids = torch.meshgrid(*k_axes, indexing="ij")
322+
self.K = torch.stack(k_grids)
323+
print("K = ", self.K.shape)
324+
print("dk = ", self.dk)
325+
326+
# Volume elements
327+
self.dV = self.dq.prod()
328+
self.dVk = self.dk.prod() #(self.dq / self.grid_size).prod()
329+
print("dV = ", self.dV)
330+
print("dVk = ", self.dVk)
331+
332+
333+
# Allocate storage:
334+
self.nsnaps = self.nsteps // self.save_every_n_steps + 1 # how many snapshots to save
335+
336+
# Diabatic properties
337+
self.psi_r_dia = torch.zeros((*self.grid_size, self.Nstates), dtype=torch.cfloat)
338+
self.psi_k_dia = torch.zeros_like(self.psi_r_dia)
339+
self.psi_r_dia_all = torch.zeros(( self.nsnaps , *self.grid_size, self.Nstates), dtype=torch.cfloat)
340+
self.psi_k_dia_all = torch.zeros((self.nsnaps, *self.grid_size, self.Nstates), dtype=torch.cfloat)
341+
self.rho_dia_all = torch.zeros((self.nsnaps, self.Nstates, self.Nstates), dtype=torch.cfloat)
342+
343+
# Adiabatic properties
344+
self.psi_r_adi = torch.zeros((*self.grid_size, self.Nstates), dtype=torch.cfloat)
345+
self.psi_k_adi = torch.zeros_like(self.psi_r_adi)
346+
self.psi_r_adi_all = torch.zeros(( self.nsnaps , *self.grid_size, self.Nstates), dtype=torch.cfloat)
347+
self.psi_k_adi_all = torch.zeros((self.nsnaps, *self.grid_size, self.Nstates), dtype=torch.cfloat)
348+
self.rho_adi_all = torch.zeros((self.nsnaps, self.Nstates, self.Nstates), dtype=torch.cfloat)
349+
350+
351+
def update_adi_r(self):
352+
"""
353+
Convert diabatic to adiabatic r-space wavefunction: C_adi = U * C_dia
354+
"""
355+
self.psi_r_adi = torch.einsum("...ij, ...j->...i", self.eigvecs, self.psi_r_dia)
356+
357+
def update_dia_r(self):
358+
"""
359+
Convert adiabatic to diabatic r-space wavefunction: C_dia = U.H * C_adi
360+
"""
361+
self.psi_r_dia = torch.einsum("...ij, ...j->...i", self.eigvecs.conj().transpose(-2,-1), self.psi_r_adi)
362+
363+
364+
def transform_r2k(self, rep):
365+
"""
366+
Compute the k-space diabatic and adiabatic wavefunctions from the r-space counterparts
367+
rep: 0 - diabatic, 1 - adiabatic
368+
"""
369+
# Define which axes are spatial (everything except last one = Nstates)
370+
spatial_dims = tuple(range(self.ndim))
371+
372+
# Apply FFT over spatial dimensions for all states at once
373+
if rep==0:
374+
self.psi_k_dia = torch.fft.fftn(self.psi_r_dia, dim=spatial_dims ) # norm='forward')
375+
elif rep==1:
376+
self.psi_k_adi = torch.fft.fftn(self.psi_r_adi, dim=spatial_dims ) # norm='forward')
377+
378+
def transform_k2r(self, rep):
379+
"""
380+
Compute the k-space diabatic and adiabatic wavefunctions from the r-space counterparts
381+
rep: 0 - diabatic, 1 - adiabatic
382+
"""
383+
# Define which axes are spatial (everything except last one = Nstates)
384+
spatial_dims = tuple(range(self.ndim))
385+
386+
# Apply inverse FFT over same spatial dimensions
387+
if rep==0:
388+
self.psi_r_dia = torch.fft.ifftn(self.psi_k_dia, dim=spatial_dims ) # norm='forward')
389+
elif rep==1:
390+
self.psi_r_adi = torch.fft.ifftn(self.psi_k_adi, dim=spatial_dims ) # norm='forward')
391+
392+
393+
def initialize_operators(self):
394+
# Kinetic energy operator - same for all states
395+
self.T = 0.5 * torch.sum((self.hbar * self.K) ** 2 / self.mass.view(self.ndim, *[1]*self.ndim), dim=0) # T: [*grid_size]
396+
print("T.shape = ", self.T.shape)
397+
398+
# Element-wise exponentials are fine here
399+
self.expT = torch.exp(-1j * self.T * self.dt / self.hbar) # expT: [*grid_size]
400+
print("expT.shape = ", self.expT.shape)
401+
402+
# Potential energy operator - matrix for multiple states
403+
self.V = self.potential_fn(self.Q, self.potential_fn_params) # V: [*grid_size, Nstates, Nstates]
404+
print("V.shape = ", self.V.shape)
405+
406+
# dia <-> adi transformation for all points
407+
# V U = E U => H = U.H * E * U
408+
# E = <psi_adi | H | psi_adi >
409+
# V = <psi_dia | H | psi_dia >
410+
# | psi_adi > U = | psi_dia >
411+
# | Psi > = | psi_dia> C_dia = | psi_adi > C_adi = | psi_adi > U C_dia
412+
# so C_adi = U C_dia
413+
414+
self.eigvals, self.eigvecs = torch.linalg.eigh(self.V) # V: [*grid_size, Nstates, Nstates]
415+
self.exp_diag = torch.exp(-0.5j * self.dt * self.eigvals)
416+
self.expV_half = self.eigvecs.conj().transpose(-2, -1) @ torch.diag_embed(self.exp_diag) @ self.eigvecs
417+
print("expV_half.shape = ", self.expV_half.shape)
418+
419+
# Initialize the wavefunction in r-space
420+
wfn = self.psi0_fn(self.Q, self.psi0_fn_params)
421+
422+
if self.representation == "diabatic":
423+
self.psi_r_dia[..., self.initial_state_index] = wfn
424+
self.update_adi_r() # dia -> adi
425+
elif self.representation == "adiabatic":
426+
self.psi_r_adi[..., self.initial_state_index] = wfn
427+
self.update_dia_r() # adi -> dia
428+
#print("psi.shape = ", self.psi_r_dia.shape)
429+
430+
# Update the k-space wavefunctions:
431+
self.transform_r2k(0) # psi_r_dia -> psi_k_dia
432+
self.transform_r2k(1) # psi_r_adi -> psi_k_adi
433+
434+
435+
def propagate(self):
436+
"""
437+
Time-propagates the wavefunction using the selected propagation method.
438+
Supported methods: 'miller-colton', 'split-operator', 'crank-nicolson'
439+
"""
440+
for step in range(self.nsteps):
441+
442+
#====================== Saving and computing properties ==================
443+
444+
if step % self.save_every_n_steps == 0:
445+
istep = int(step / self.save_every_n_steps)
446+
447+
# Diabatic r-space wavefunctions
448+
self.psi_r_dia_all[istep] = self.psi_r_dia
449+
450+
# Compute other representations:
451+
self.update_adi_r(); self.psi_r_adi_all[istep] = self.psi_r_adi; # adiabatic in r-space
452+
self.transform_r2k(0); self.psi_k_dia_all[istep] = self.psi_k_dia; # diabatic in k-space
453+
self.transform_r2k(1); self.psi_k_adi_all[istep] = self.psi_k_adi; # adiabatic in k-space
454+
455+
456+
# Diabatic density matrix
457+
self.rho_dia_all[istep] = torch.einsum("...i, ...j->ij", self.psi_r_dia, self.psi_r_dia.conj() ) * self.dV
458+
459+
# Adiabatic density matrix
460+
self.rho_adi_all[istep] = torch.einsum("...i, ...j->ij", self.psi_r_adi, self.psi_r_adi.conj() ) * self.dV
461+
462+
# Kinetic energy
463+
KE = torch.einsum("...i,...,...i->", self.psi_k_dia.conj(), self.T, self.psi_k_dia ) * (self.dq/self.grid_size).prod()
464+
465+
# Full potential energy: PE = ∫ ψ* V ψ dx
466+
PE = torch.einsum("...i,...ij,...j->", self.psi_r_dia.conj(), self.V, self.psi_r_dia) * self.dV
467+
468+
469+
nrm = torch.sum(torch.abs(self.psi_r_dia) ** 2 ) * self.dV
470+
x_coords = self.Q[0]
471+
right_mask = self.Q[0] > 0
472+
#pop_right = torch.sum(self.prob_density[right_mask]) * self.dV
473+
474+
self.norm.append( nrm )
475+
self.time.append(step * self.dt)
476+
self.kinetic_energy.append( KE.real.item() )
477+
self.potential_energy.append( PE.real.item() )
478+
self.total_energy.append( KE + PE )
479+
#self.population_right.append(pop_right.item())
480+
481+
print(f"Step {step}: Norm = {nrm:.4f}")
482+
483+
484+
#=================== Doing computations ==========================
485+
if self.method == "crank-nicolson":
486+
# Use Crank-Nicolson scheme for time propagation
487+
# self.crank_nicolson_step()
488+
pass # not implemented
489+
490+
else:
491+
# Apply half step of potential evolution in adiabatic representation
492+
# self.potential_half_step()
493+
494+
if self.method == "miller-colton":
495+
pass # not implemented
496+
# Miller-Colton propagation in momentum space
497+
#for i in range(self.Nstates):
498+
# self.psi_k[i] = torch.fft.fftn(self.psi[i])
499+
# self.psi_k[i] *= torch.exp(-1j * self.T * self.dt / (2 * self.hbar))
500+
# self.psi[i] = torch.fft.ifftn(self.psi_k[i])
501+
502+
elif self.method == "split-operator":
503+
# Half-step propagation in real space
504+
self.psi_r_dia = torch.einsum("...ij,...j->...i", self.expV_half, self.psi_r_dia) #self.expV_half @ self.psi
505+
506+
# Full-step in reciprocal space
507+
self.transform_r2k(0)
508+
509+
# Multiply by expT — make sure it broadcasts correctly
510+
self.psi_k_dia *= self.expT.unsqueeze(-1) # if expT.shape == [*,], expand to [*, 1]
511+
512+
# Apply inverse FFT over same spatial dimensions
513+
self.transform_k2r(0)
514+
515+
# Another half of the propagation in real space
516+
self.psi_r_dia = torch.einsum("...ij,...j->...i", self.expV_half, self.psi_r_dia) #self.expV_half @ self.psi
517+
518+
519+
def save(self):
520+
"""Saves the grid, wavefunction, and observables to disk."""
521+
torch.save( {"grid_size":self.grid_size,
522+
"ndim":self.ndim,
523+
"q_min":self.q_min, "q_max":self.q_max,
524+
"save_every_n_steps": self.save_every_n_steps,
525+
"dt":self.dt, "nsteps":self.nsteps,
526+
"mass":self.mass,
527+
"psi_r_adi":self.psi_r_adi,
528+
"psi_r_dia":self.psi_r_dia,
529+
"psi_k_adi":self.psi_k_adi,
530+
"psi_k_dia":self.psi_k_dia,
531+
"rho_dia_all":self.rho_dia_all,
532+
"rho_adi_all":self.rho_adi_all,
533+
"psi_r_dia_all":self.psi_r_dia_all,
534+
"psi_r_adi_all":self.psi_r_adi_all,
535+
"psi_k_dia_all":self.psi_k_dia_all,
536+
"psi_k_adi_all":self.psi_k_adi_all,
537+
"time":self.time,
538+
"Q":self.Q, "K":self.K, "dq":self.dq, "dk":self.dk,
539+
"dV":self.dV, "dVk":self.dVk,
540+
"kinetic_energy":self.kinetic_energy,
541+
"potential_energy":self.potential_energy,
542+
"total_energy":self.total_energy,
543+
"population_right":self.population_right,
544+
"norm":self.norm,
545+
"V":self.V, "T":self.T,
546+
}, F"{self.prefix}.pt" )
547+
548+
def solve(self):
549+
"""Runs the full simulation and saves results."""
550+
self.initialize_grids()
551+
self.initialize_operators()
552+
self.propagate()
553+
self.save()
554+
555+

0 commit comments

Comments
 (0)