Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 122 additions & 72 deletions src/libra_py/dynamics/ldr_torch/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,19 @@ def __init__(self, params):
self.prefix = params.get("prefix", "ldr-solution")
self.device = params.get("device", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
self.hbar = 1.0
self.Hamiltonian_scheme = "symmetrized"
self.hamiltonian_scheme = "symmetrized"
self.q0 = torch.tensor(params.get("q0", [0.0]), dtype=torch.float64, device=self.device)
self.p0 = torch.tensor(params.get("p0", [0.0]), dtype=torch.float64, device=self.device)
self.k = torch.tensor(params.get("k", [0.001]), dtype=torch.float64, device=self.device)
self.mass = torch.tensor(params.get("mass", [2000.0]), dtype=torch.float64, device=self.device)
self.alpha = torch.tensor(params.get("alpha", [18.0]), dtype=torch.float64, device=self.device)
self.qgrid = torch.tensor(params.get("qgrid", [[-10 + i * 0.1] for i in range(int((10 - (-10)) / 0.1) + 1)] ), dtype=torch.float64, device=self.device) #(N, D)
self.ngrids = len(self.qgrid) # N
self.ndof = self.qgrid.shape[1]
self.nstates = params.get("nstates", 2)
self.istate = params.get("istate", 0)

self.elec_ampl = params.get("elec_ampl", torch.tensor([1.0+0.j]*self.ngrids, dtype=torch.cdouble))

self.save_every_n_steps = params.get("save_every_n_steps", 1)
self.properties_to_save = params.get("properties_to_save", ["time", "population_right"])
self.dt = params.get("dt", 0.01)
Expand All @@ -60,99 +62,103 @@ def __init__(self, params):

self.E = params.get("E", torch.zeros(self.nstates, self.ngrids, device=self.device) )

Selec_default = torch.zeros(self.ndim, self.ndim, dtype=torch.cdouble, device=self.device)
s_elec_default = torch.zeros(self.ndim, self.ndim, dtype=torch.cdouble, device=self.device)
for i in range(self.nstates):
start, end = i * self.ngrids, (i + 1) * self.ngrids
Selec_default[start:end, start:end] = torch.eye(self.ngrids, device=self.device)
self.Selec = params.get("Selec", Selec_default )
s_elec_default[start:end, start:end] = torch.eye(self.ngrids, device=self.device)
self.s_elec = params.get("s_elec", s_elec_default )

# Computed with LDR methods
self.C0 = torch.zeros(self.ndim, dtype=torch.cdouble, device=self.device)
self.Ccurr = torch.zeros(self.ndim, dtype=torch.cdouble, device=self.device)
self.C_curr = torch.zeros(self.ndim, dtype=torch.cdouble, device=self.device)

self.Snucl = torch.eye(self.ngrids, dtype=torch.cdouble, device=self.device)
self.Tnucl = torch.zeros(self.ngrids, self.ngrids, dtype=torch.cdouble, device=self.device)
self.s_nucl = torch.eye(self.ngrids, dtype=torch.cdouble, device=self.device)
self.t_nucl = torch.zeros(self.ngrids, self.ngrids, dtype=torch.cdouble, device=self.device)

self.S, self.H = torch.zeros(self.ndim, self.ndim, dtype=torch.cdouble, device=self.device), torch.zeros(self.ndim, self.ndim, dtype=torch.cdouble, device=self.device)
self.U = torch.zeros(self.ndim, self.ndim, dtype=torch.cdouble, device=self.device)
self.S_half = torch.zeros(self.ndim, self.ndim, dtype=torch.cdouble, device=self.device)

self.time = []
self.kinetic_energy = []
self.potential_energy = []
self.total_energy = []
self.average_pos = []
self.population_right = []
self.denmat = []
self.norm = []
self.C_save = []

def chi_overlap(self):
"""
Compute nuclear overlap matrix Snucl[i, j] for the mesh qmesh.
Compute nuclear overlap matrix s_nucl[i, j] for the mesh qmesh
from the Gaussian basis, g(x; q) = \exp(-\alpha * (x-q)**2).
"""
delta = self.qgrid[:, None, :] - self.qgrid[None, :, :] # (N, N, D)
exponent = -0.5 * torch.sum(self.alpha * delta**2, dim=2) # (N, N)
self.Snucl = torch.exp(exponent)
self.s_nucl = torch.exp(exponent)

def chi_kinetic(self):
r"""
Compute nuclear kinetic energy matrix Tnucl[i,j] = <g(x; qgrid[i]) | T | g(x; qgrid[j])>,
with T = Σ_ν -½ m_ν^{-1} ∂²/∂x_ν².
"""
Compute nuclear kinetic energy matrix t_nucl[i,j] = <g(x; qgrid[i]) | T | g(x; qgrid[j])>,
with T = \sum_{\nu} -0.5* m_ν^{-1} \partial^{2}/\partial x_{\nu}^2.
"""
delta = self.qgrid[:, None, :] - self.qgrid[None, :, :] # (N, N, D)
tau = self.alpha / (2.0 * self.mass) * (1.0 - self.alpha * delta**2) # (N, N, D)
tau_sum = torch.sum(tau, dim=2) # (N, N)

self.Tnucl = self.Snucl * tau_sum # (N, N)
self.t_nucl = self.s_nucl * tau_sum # (N, N)

def build_compound_overlap(self):
"""
Build the compound nuclear-electronic overlap matrix self.S (ndim, ndim)
"""
N, s, ndim = self.ngrids, self.nstates, self.ndim

# Reshape Selec[a, b] -> (i, n, j, m) with:
# Reshape s_elec[a, b] -> (i, n, j, m) with:
# a = i * N + n
# b = j * N + m
Selec4D = self.Selec.view(s, N, s, N) # (i, n, j, m)
s_elec_4d = self.s_elec.view(s, N, s, N) # (i, n, j, m)

Snucl4D = self.Snucl.unsqueeze(0).unsqueeze(2) # (1, n, 1, m)
s_nucl_4d = self.s_nucl[None, :, None, :] # (1, n, 1, m)

S4D = Selec4D * Snucl4D
S_4d = s_elec_4d * s_nucl_4d

# Reshape back to (ndim, ndim) with compound indices
self.S = S4D.permute(0, 1, 2, 3).reshape(ndim, ndim)
self.S = S_4d.reshape(ndim, ndim)

def build_compound_hamiltonian(self):
"""
Build the compound nuclear-electronic Hamiltonian self.H (ndim, ndim) using different schemes.
"""
N, s, ndim = self.ngrids, self.nstates, self.ndim
scheme = self.Hamiltonian_scheme
Selec4D = self.Selec.view(s, N, s, N) # (s, N, s, N)
T4D = self.Tnucl.unsqueeze(0).unsqueeze(2) # (1, N, 1, N)
S4D = self.Snucl.unsqueeze(0).unsqueeze(2) # (1, N, 1, N)
scheme = self.hamiltonian_scheme
s_elec_4d = self.s_elec.view(s, N, s, N) # (s, N, s, N)
T_4d = self.t_nucl[None, :, None, :] # (1, N, 1, N)
S_4d = self.s_nucl[None, :, None, :] # (1, N, 1, N)

if scheme == 'as_is': # For showing the original non-Hermitian form, not intended to use
Ej4D = self.E[None, None, :, :] # (1, 1, s, N)
bracket4D = T4D + Ej4D * S4D
E_j_4d = self.E[None, None, :, :] # (1, 1, s, N)
bracket_4d = T_4d + E_j_4d * S_4d
elif scheme == 'symmetrized':
Ei4D = self.E[:, :, None, None] # (s, N, 1, 1)
Ej4D = self.E[None, None, :, :] # (1, 1, s, N)
Eavg4D = 0.5 * (Ei4D + Ej4D) # (s, N, s, N)
bracket4D = T4D + Eavg4D * S4D
E_i_4d = self.E[:, :, None, None] # (s, N, 1, 1)
E_j_4d = self.E[None, None, :, :] # (1, 1, s, N)
E_avg_4d = 0.5 * (E_i_4d + E_j_4d) # (s, N, s, N)
bracket_4d = T_4d + E_avg_4d * S_4d
elif scheme == 'diagonal':
# Build Kronecker deltas for electronic and nuclear indices
delta_ij = torch.eye(s, device=self.device).unsqueeze(1).unsqueeze(3) # (s, 1, s, 1)
delta_nm = torch.eye(N, device=self.device).unsqueeze(0).unsqueeze(2) # (1, N, 1, N)
delta4D = delta_ij * delta_nm
delta_ij = torch.eye(s, device=self.device)[:, None, :, None] # (s, 1, s, 1)
delta_nm = torch.eye(N, device=self.device)[None, :, None, :] # (1, N, 1, N)
delta_4d = delta_ij * delta_nm

Ej4D = self.E[None, None, :, :] # (1, 1, s, N)
bracket4D = T4D + Ej4D * S4D * delta4D
E_j_4d = self.E[None, None, :, :] # (1, 1, s, N)
bracket_4d = T_4d + E_j_4d * S_4d * delta_4d

else:
raise ValueError(f"Unknown Hamiltonian scheme: {scheme}")

H4D = Selec4D * bracket4D
self.H = H4D.reshape(ndim, ndim)
H_4d = s_elec_4d * bracket_4d
self.H = H_4d.reshape(ndim, ndim)

def compute_propagator(self):
"""
Expand All @@ -166,7 +172,7 @@ def compute_propagator(self):

evals_S, evecs_S = torch.linalg.eigh(S)

S_half = (evecs_S @ torch.diag(evals_S.sqrt().to(dtype=torch.cdouble)) @ evecs_S.T).to(dtype=torch.cdouble)
self.S_half = (evecs_S @ torch.diag(evals_S.sqrt().to(dtype=torch.cdouble)) @ evecs_S.T).to(dtype=torch.cdouble)
S_invhalf = (evecs_S @ torch.diag((1.0 / evals_S).sqrt().to(dtype=torch.cdouble)) @ evecs_S.T).to(dtype=torch.cdouble)

H_ortho = S_invhalf @ H @ S_invhalf
Expand All @@ -176,7 +182,7 @@ def compute_propagator(self):
exp_diag = torch.diag(torch.exp(-1j * evals_H * dt))
U_ortho = evecs_H @ exp_diag @ evecs_H.conj().T

self.U = S_invhalf @ U_ortho @ S_half
self.U = S_invhalf @ U_ortho @ self.S_half


def initialize_C(self):
Expand Down Expand Up @@ -217,7 +223,7 @@ def initialize_C(self):
delta_eta = -0.5 * torch.dot(xi0 + p0, q0) + 0.5 * torch.dot(xig, qgrid[n]).conj()
exponent = -1.j * 0.5 * torch.dot(delta_xi, torch.matmul(delta_A_inv, delta_xi)) + 1.j * delta_eta

self.C0[index] = torch.exp(exponent)
self.C0[index] = self.elec_ampl[n] * torch.exp(exponent)

# Normalize
overlap = torch.matmul(self.S, self.C0)
Expand All @@ -230,14 +236,14 @@ def propagate(self):
Propagate coefficient.
"""
# Initialize first step with normalized initial wavefunction
self.Ccurr = self.C0.clone()
self.C_curr = self.C0.clone()

print(F"step = 0")
self.save_results(0)

for step in range(1, self.nsteps):
Cvec = self.Ccurr.clone()
self.Ccurr = self.U @ Cvec
C_vec = self.C_curr.clone()
self.C_curr = self.U @ C_vec

if step % self.save_every_n_steps == 0:
print(F"step = {step}")
Expand All @@ -247,53 +253,72 @@ def save_results(self, step):
if "time" in self.properties_to_save:
self.time.append(step*self.dt)
if "norm" in self.properties_to_save:
overlap = torch.matmul(self.S, self.Ccurr)
self.norm.append(torch.sqrt(torch.vdot(self.Ccurr, overlap)))
overlap = torch.matmul(self.S, self.C_curr)
self.norm.append(torch.sqrt(torch.vdot(self.C_curr, overlap)))
if "population_right" in self.properties_to_save:
self.population_right.append(self.compute_populations())
if "denmat" in self.properties_to_save:
self.denmat.append(self.compute_denmat())
if "kinetic_energy" in self.properties_to_save:
self.kinetic_energy.append(self.compute_kinetic_energy())
if "potential_energy" in self.properties_to_save:
self.potential_energy.append(self.compute_potential_energy())
if "total_energy" in self.properties_to_save:
self.total_energy.append(self.compute_total_energy())
if "average_pos" in self.properties_to_save:
self.average_pos.append(self.compute_average_pos())
if "C_save" in self.properties_to_save:
self.C_save.append(self.Ccurr)
self.C_save.append(self.C_curr)

def compute_populations(self):
"""
Compute electronic state population for a single step.
"""
N, s = self.ngrids, self.nstates
Cvec = self.Ccurr
C_vec = self.C_curr

# Compute SC once: shape (ndim,)
SC = self.S @ Cvec
SC = self.S @ C_vec

C_blocks = Cvec.view(s, N)
C_blocks = C_vec.view(s, N)
SC_blocks = SC.view(s, N)

# Compute P[i] = sum_j <C_j|S_{ji}|C_i> = Re[ sum_N (C_j*) * SC_j ]
P = torch.sum(C_blocks.conj() * SC_blocks, dim=1).real

return P

def compute_denmat(self):
"""
Compute electronic density matrix for a single step using the orthogonalization.
"""
N, s = self.ngrids, self.nstates
C_vec = self.C_curr

# Orthogonalize coefficients: C_ortho = S^{1/2} C
C_ortho = self.S_half @ C_vec

C_blocks = C_ortho.view(s, N)

rho = C_blocks @ C_blocks.conj().T # (s, s)

return rho

def compute_kinetic_energy(self):
"""
Compute nuclear kinetic energy as C^+ T C / C^+ S C for a single step.
"""
N, s, ndim = self.ngrids, self.nstates, self.ndim

# Rebuild compound kinetic matrix: T4D * Selec4D
Selec4D = self.Selec.view(s, N, s, N)
T4D = self.Tnucl.unsqueeze(0).unsqueeze(2) # (1, n, 1, m)
T4D_compound = Selec4D * T4D
T_compound = T4D_compound.permute(0, 1, 2, 3).reshape(ndim, ndim)
# Rebuild compound kinetic matrix: T_4d * s_elec_4d
s_elec_4d = self.s_elec.view(s, N, s, N)
T_4d = self.t_nucl[None, :, None, :]
T_compound = (s_elec_4d * T_4d).reshape(ndim, ndim)

Cvec = self.Ccurr
C_vec = self.C_curr

numer = torch.vdot(Cvec, T_compound @ Cvec).real
denom = torch.vdot(Cvec, self.S @ Cvec).real
numer = torch.vdot(C_vec, T_compound @ C_vec).real
denom = torch.vdot(C_vec, self.S @ C_vec).real

return numer / denom

Expand All @@ -304,17 +329,16 @@ def compute_potential_energy(self):
"""
N, s, ndim = self.ngrids, self.nstates, self.ndim

Selec4D = self.Selec.view(s, N, s, N)
S4D = self.Snucl.unsqueeze(0).unsqueeze(2) # (1, n, 1, m)
Ej4D = self.E[None, None, :, :] # (1,1,j,m)
s_elec_4d = self.s_elec.view(s, N, s, N)
S_4d = self.s_nucl[None, :, None, :]
E_j_4d = self.E[None, None, :, :] # (1,1,j,m)

V4D_compound = Selec4D * (Ej4D * S4D)
V_compound = V4D_compound.permute(0, 1, 2, 3).reshape(ndim, ndim)
V_compound = (s_elec_4d * (E_j_4d * S_4d)).reshape(ndim, ndim)

Cvec = self.Ccurr
C_vec = self.C_curr

numer = torch.vdot(Cvec, V_compound @ Cvec).real
denom = torch.vdot(Cvec, self.S @ Cvec).real
numer = torch.vdot(C_vec, V_compound @ C_vec).real
denom = torch.vdot(C_vec, self.S @ C_vec).real

return numer / denom

Expand All @@ -323,13 +347,37 @@ def compute_total_energy(self):
"""
Compute total energy as C^+ H C / C^+ S C for a single step.
"""
Cvec = self.Ccurr
C_vec = self.C_curr

numer = torch.vdot(Cvec, self.H @ Cvec).real
denom = torch.vdot(Cvec, self.S @ Cvec).real
numer = torch.vdot(C_vec, self.H @ C_vec).real
denom = torch.vdot(C_vec, self.S @ C_vec).real

return numer / denom

def compute_average_pos(self):
"""
Compute average position as <q_i> = \sum_i C^+ Q C / C^+ S C for a single step.
"""
N, s, ndim = self.ngrids, self.nstates, self.ndim

C_vec = self.C_curr

denom = torch.vdot(C_vec, self.S @ C_vec).real
s_elec_4d = self.s_elec.view(s, N, s, N)

avg_q = []
for idof in range(self.ndof):
q_med = 0.5 * (self.qgrid[:, None, idof] + self.qgrid[None,:,idof])
q_nucl = self.s_nucl * q_med
Q_4d = q_nucl[None, :, None, :]
Q_4d_compound = s_elec_4d * Q_4d
Q_compound = Q_4d_compound.reshape(ndim, ndim)

numer = torch.vdot(C_vec, Q_compound @ C_vec).real
avg_q.append(numer / denom)

return avg_q

def save(self):
torch.save( {"q0":self.q0,
"p0":self.p0,
Expand All @@ -339,22 +387,24 @@ def save(self):
"qgrid":self.qgrid,
"nstates":self.nstates,
"istate":self.istate,
"Snucl":self.Snucl,
"Tnucl":self.Tnucl,
"s_nucl":self.s_nucl,
"t_nucl":self.t_nucl,
"E":self.E,
"Selec":self.Selec,
"s_elec":self.s_elec,
"S":self.S,
"H":self.H,
"U":self.U,
"C_save":self.C_save,
"save_every_n_steps":self.save_every_n_steps,
"Hamiltonian_scheme": self.Hamiltonian_scheme,
"hamiltonian_scheme": self.hamiltonian_scheme,
"dt":self.dt, "nsteps":self.nsteps,
"time":self.time,
"kinetic_energy":self.kinetic_energy,
"potential_energy":self.potential_energy,
"total_energy":self.total_energy,
"average_pos":self.average_pos,
"population_right":self.population_right,
"denmat":self.denmat,
"norm":self.norm
}, F"{self.prefix}.pt" )

Expand Down