From 5ca73510ee601a2d354ee9cf7fa638d439913188 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Mon, 14 Apr 2025 17:13:56 +0100 Subject: [PATCH 1/8] Add initial LJ implementation --- emle/models/_emle.py | 70 ++++++++++++- emle/models/_emle_base.py | 206 +++++++++++++++++++++++++++++++++++++- 2 files changed, 267 insertions(+), 9 deletions(-) diff --git a/emle/models/_emle.py b/emle/models/_emle.py index 2d6b29c..692d53e 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -27,17 +27,17 @@ __all__ = ["EMLE"] -import numpy as _np import os as _os +from typing import Union + +import numpy as _np import scipy.io as _scipy_io import torch as _torch import torchani as _torchani - from torch import Tensor -from typing import Union -from . import _patches from . import EMLEBase as _EMLEBase +from . import _patches # Monkey-patch the TorchANI BuiltInModel and BuiltinEnsemble classes so that # they call self.aev_computer using args only to allow forward hooks to work @@ -83,6 +83,7 @@ def __init__( model=None, method="electrostatic", alpha_mode="species", + lj_method=None, atomic_numbers=None, qm_charge=0, mm_charges=None, @@ -123,6 +124,15 @@ def __init__( scaling factors are obtained with GPR using the values learned for each reference environment + lj_method: str + How the LJ parameters are calculated. + "dynamic": + Lennard-Jones parameters are calculated dynamically from EMLE parameters. + "static": + Lennard-Jones parameters are fixed. + None + Lennard-Jones parameters and interactions are not included. + atomic_numbers: List[int], Tuple[int], numpy.ndarray, torch.Tensor Atomic numbers for the QM region. This allows use of optimised AEV symmetry functions from the NNPOps package. Only use this option @@ -179,6 +189,16 @@ def __init__( raise ValueError("'alpha_mode' must be 'species' or 'reference'") self._alpha_mode = alpha_mode + if lj_method is not None: + if not isinstance(lj_method, str): + raise TypeError("'lj_method' must be of type 'str'") + lj_method = lj_method.lower().replace(" ", "") + if lj_method not in {"dynamic", "static"}: + raise ValueError("'lj_method' must be 'dynamic' or 'static'") + if lj_method == "static": + raise NotImplementedError("Static LJ model not implemented") + self._lj_method = lj_method + if atomic_numbers is not None: if isinstance(atomic_numbers, (_np.ndarray, _torch.Tensor)): atomic_numbers = atomic_numbers.tolist() @@ -523,4 +543,46 @@ def forward( E_static, dtype=self._charges_mm.dtype, device=self._device ) + if self._lj_method is not None: + if self._lj_method == "dynamic": + # Calculate the isotropic polarizabilities and cube of the vdW radii. + alpha_qm = self._emle_base.calculate_isotropic_polarizabilities(A_thole) + rcubed_qm = -60 * q_val * s**3 / ANGSTROM_TO_BOHR**3 + + # Calculate the LJ parameters. + sigma_qm, epsilon_qm = self._emle_base.calculate_atomic_lj_parameters( + self._atomic_numbers, + rcubed_qm, + alpha_qm, + ) + + # TODO: How to handle this properly? + # Calculate the LJ parameters for the MM atoms. + sigma_tip3p_O = 3.0050806999206543 * 1.8897259886 # ≈ 5.676 Bohr + epsilon_tip3p_O = ( + 0.4382217228412628 * 0.00038087980 + ) # ≈ 0.000167 Hartree + + # Hydrogen + sigma_tip3p_H = 2.608452081680298 * 1.8897259886 # ≈ 4.927 Bohr + epsilon_tip3p_H = ( + 0.01849512942135334 * 0.00038087980 + ) # ≈ 7.043e-6 Hartree + + sigma_mm = charges_mm + sigma_mm[charges_mm < 0] = sigma_tip3p_O + sigma_mm[charges_mm > 0] = sigma_tip3p_H + epsilon_mm = charges_mm + epsilon_mm[charges_mm < 0] = epsilon_tip3p_O + epsilon_mm[charges_mm > 0] = epsilon_tip3p_H + + # Compute the LJ energy. + E_lj = self._emle_base.get_lj_energy( + sigma_qm, epsilon_qm, sigma_mm, epsilon_mm, mesh_data + ) + else: + raise NotImplementedError("Static LJ model not implemented") + + return _torch.stack((E_static, E_ind, E_lj), dim=0) + return _torch.stack((E_static, E_ind), dim=0) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index d374775..b21bba3 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -27,14 +27,12 @@ __all__ = ["EMLEBase"] -import numpy as _np - -import torch as _torch - -from torch import Tensor from typing import Tuple +import numpy as _np +import torch as _torch import torchani as _torchani +from torch import Tensor try: import NNPOps as _NNPOps @@ -57,6 +55,30 @@ class EMLEBase(_torch.nn.Module): # Store the list of supported species. _species = [1, 6, 7, 8, 16] + # Values in atomic units, taken from: + # Chu, X., & Dalgarno, A. (2004). + # Linear response time-dependent density functional theory for van der Waals coefficients. + # The Journal of Chemical Physics, 121(9), 4083--4088. http://doi.org/10.1063/1.1779576 + # Bfree: Ha.Bohr**6 (https://pubs.acs.org/doi/10.1021/acs.jctc.6b00027) #TODO: convert to Angstrom^6? + C6_COEFFICIENTS = { + 0: 0.0, # Dummy atom + 1: 6.5, # H + 6: 46.6, # C + 7: 24.2, # N + 8: 15.6, # O + 16: 134.0, # S + } + + # Calculated free atom volumes in Angstrom^3 at B3LYP/cc-pVTZ #TODO: convert to Bohr^3? + RCUBED_FREE = { + 0: 0.0, # Dummy atom + 1: 1.172520801342987, # H + 6: 5.044203900156338, # C + 7: 3.7853134153705934, # N + 8: 3.135943377417264, # O + 16: 11.015796687279208, # S + } + def __init__( self, params, @@ -227,6 +249,19 @@ def __init__( species_map[s] = i species_map = _torch.tensor(species_map, dtype=_torch.int64, device=device) + # Create lookup tensors for free atom volumes and C6 coefficients + self._vol_isolated = _torch.zeros( + max(self.RCUBED_FREE.keys()) + 1, dtype=_torch.float + ) + for Z, vol in self.RCUBED_FREE.items(): + self._vol_isolated[Z] = vol + + self._c6 = _torch.zeros( + max(self.C6_COEFFICIENTS.keys()) + 1, dtype=_torch.float + ) + for Z, c6_coeff in self.C6_COEFFICIENTS.items(): + self._c6[Z] = c6_coeff + # Compute the inverse of the K matrix. Kinv = self._get_Kinv(ref_features, 1e-3) @@ -1038,3 +1073,164 @@ def _get_T0_slater(r: Tensor, s: Tensor) -> Tensor: results: torch.Tensor (N_BATCH, MAX_QM_ATOMS, MAX_MM_ATOMS) """ return (1 - (1 + r / (s * 2)) * _torch.exp(-r / s)) / r + + @staticmethod + def get_lj_energy( + sigma_qm: Tensor, + epsilon_qm: Tensor, + sigma_mm: Tensor, + epsilon_mm: Tensor, + mesh_data: Tuple[Tensor, Tensor, Tensor], + ) -> Tensor: + """ + Calculate the Lennard-Jones energy. + + Parameters + ---------- + + sigma_qm: Tensor (N_BATCH, N_QM_ATOMS) + Lennard-Jones sigma values in Bohr. + + epsilon_qm: Tensor (N_BATCH, N_QM_ATOMS) + Lennard-Jones epsilon values in atomic units. + + sigma_mm: Tensor (N_BATCH, N_MM_ATOMS) + Lennard-Jones sigma values in Bohr. + + epsilon_mm: Tensor (N_BATCH, N_MM_ATOMS) + Lennard-Jones epsilon values in atomic units. + + mesh_data: Tuple[Tensor, Tensor, Tensor] + Mesh data tuple containing (r_inv, r_vec, s_outer_product). + r_inv: Tensor (N_BATCH, N_QM_ATOMS, N_MM_ATOMS) of inverse QM-MM distances. + + Returns + ------- + + Tensor (N_BATCH,) + Total Lennard-Jones energy for each batch element in atomic units. + """ + # Lorentz-Berthelot combining rules + # sigma (N_BATCH, N_QM, N_MM) + # epsilon (N_BATCH, N_QM, N_MM) + sigma = 0.5 * (sigma_qm[:, :, None] + sigma_mm[:, None, :]) + epsilon_product = epsilon_qm[:, :, None] * epsilon_mm[:, None, :] + epsilon = _torch.where(epsilon_product > 0, _torch.sqrt(epsilon_product), 0.0) + + # Get distances + # r_inv (N_BATCH, N_QM, N_MM) + r_inv, _, _ = mesh_data + sigma_r_inv_6 = (sigma * r_inv) ** 6 + sigma_r_inv_12 = sigma_r_inv_6 * sigma_r_inv_6 + + # Calculate pairwise energy matrix (N_BATCH, N_QM, N_MM) + pairwise_energy = 4 * epsilon * (sigma_r_inv_12 - sigma_r_inv_6) + + # Sum over QM and MM atoms for each batch element + total_energy = pairwise_energy.sum(dim=(1, 2)) + + return total_energy + + @staticmethod + def calculate_isotropic_polarizabilities(A_thole: _torch.Tensor) -> _torch.Tensor: + """ + Calculate isotropic polarizabilities from the A_thole tensor. + + Parameters + ---------- + + A_thole : torch.Tensor(N_BATCH, 3N_ATOMS, 3N_ATOMS) + Full polarizability tensor in block form. + + Returns + ------- + + torch.Tensor(N_BATCH, N_ATOMS) + Isotropic polarizabilities per atom. + """ + + def _get_traces(A_thole: _torch.Tensor) -> _torch.Tensor: + """ + Compute the trace of the inverse of each 3x3 block in each polarizability tensor. + """ + n_mol, dim, _ = A_thole.shape + if dim % 3 != 0: + raise ValueError("Dimension of A_thole must be divisible by 3.") + + n_atoms = dim // 3 + traces = _torch.empty( + (n_mol, n_atoms), dtype=A_thole.dtype, device=A_thole.device + ) + + for mol_idx in range(n_mol): + for atom_idx in range(n_atoms): + block = A_thole[ + mol_idx, + 3 * atom_idx : 3 * atom_idx + 3, + 3 * atom_idx : 3 * atom_idx + 3, + ] + inv_block = _torch.inverse(block) + traces[mol_idx, atom_idx] = _torch.trace(inv_block) + + return traces + + return _get_traces(A_thole) / 3.0 + + def calculate_atomic_lj_parameters( + self, atomic_numbers: Tensor, rcubed: Tensor, alpha: Tensor + ) -> Tuple[Tensor, Tensor]: + """ + Calculate Lennard-Jones sigma and epsilon parameters. + + Parameters + ---------- + + atomic_numbers : torch.Tensor(N_BATCH, N_ATOMS) + Atomic numbers of atoms in the molecule. + + rcubed : torch.Tensor(N_BATCH, N_ATOMS) + Cube of the vdW radii of atoms in the molecule in Bohr^3. + + alpha : torch.Tensor(N_BATCH, N_ATOMS) + Isotropic polarizabilities per atom in Bohr^3. + + Returns + ------- + + Tuple[torch.Tensor, torch.Tensor] + Tuple containing the sigma (Bohr) and epsilon (Hartree) LJ parameters for each atom. + """ + # Mask out dummy atoms + mask = atomic_numbers > 0 + alpha = alpha * mask + rcubed = rcubed * mask + + # Get free atom volumes + try: + vol_isolated = self._vol_isolated[atomic_numbers] + except KeyError as e: + raise ValueError(f"Missing RCUBED_FREE entry for atomic number {e.args[0]}") + + # Volume scaling factor + scaling = rcubed / vol_isolated + + # Get C6 coefficients + try: + c6 = self._c6[atomic_numbers] + except KeyError as e: + raise ValueError( + f"Missing C6_COEFFICIENTS entry for atomic number {e.args[0]}" + ) + + # Scale C6 coefficients + c6_scaled = c6 * scaling**2 + + # Calculate vdW radius from polarizability (Fedorov-Tkatchenko relation) + radius = 2.54 * alpha ** (1.0 / 7.0) + rmin = 2 * radius + + # Calculate Lennard-Jones parameters + sigma = rmin / (2 ** (1.0 / 6.0)) + epsilon = c6_scaled / (2 * rmin**6.0) + + return sigma, epsilon From 69e6585bf6351e208255be6719f33104f67e53fb Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Thu, 24 Jul 2025 08:37:52 +0100 Subject: [PATCH 2/8] Add C6 training, and LJ implementation --- emle/models/_emle.py | 51 +++----------- emle/models/_emle_base.py | 144 ++++++++++++++++---------------------- emle/train/_loss.py | 74 +++++++++++++++++++- emle/train/_trainer.py | 68 ++++++++++++++++-- 4 files changed, 205 insertions(+), 132 deletions(-) diff --git a/emle/models/_emle.py b/emle/models/_emle.py index 692d53e..3f36ed7 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -296,6 +296,11 @@ def __init__( if "sqrtk_ref" in params else None ), + "ref_values_C6": ( + _torch.tensor(params["ref_values_C6"], dtype=dtype, device=device) + if "ref_values_C6" in params + else None + ), } if method == "mm": @@ -371,6 +376,7 @@ def __init__( q_core, emle_aev_computer=emle_aev_computer, alpha_mode=self._alpha_mode, + lj_mode=self._lj_method, species=params.get("species", self._species), device=device, dtype=dtype, @@ -499,11 +505,8 @@ def forward( 2, batch_size, dtype=self._xyz_qm.dtype, device=self._xyz_qm.device ) - # Get the parameters from the base model: - # valence widths, core charges, valence charges, A_thole tensor - # These are returned as batched tensors, so we need to extract the - # first element of each. - s, q_core, q_val, A_thole = self._emle_base( + # Get the parameters from the base model. + s, q_core, q_val, A_thole, C6 = self._emle_base( self._atomic_numbers, self._xyz_qm, qm_charge, @@ -545,41 +548,9 @@ def forward( if self._lj_method is not None: if self._lj_method == "dynamic": - # Calculate the isotropic polarizabilities and cube of the vdW radii. - alpha_qm = self._emle_base.calculate_isotropic_polarizabilities(A_thole) - rcubed_qm = -60 * q_val * s**3 / ANGSTROM_TO_BOHR**3 - - # Calculate the LJ parameters. - sigma_qm, epsilon_qm = self._emle_base.calculate_atomic_lj_parameters( - self._atomic_numbers, - rcubed_qm, - alpha_qm, - ) - - # TODO: How to handle this properly? - # Calculate the LJ parameters for the MM atoms. - sigma_tip3p_O = 3.0050806999206543 * 1.8897259886 # ≈ 5.676 Bohr - epsilon_tip3p_O = ( - 0.4382217228412628 * 0.00038087980 - ) # ≈ 0.000167 Hartree - - # Hydrogen - sigma_tip3p_H = 2.608452081680298 * 1.8897259886 # ≈ 4.927 Bohr - epsilon_tip3p_H = ( - 0.01849512942135334 * 0.00038087980 - ) # ≈ 7.043e-6 Hartree - - sigma_mm = charges_mm - sigma_mm[charges_mm < 0] = sigma_tip3p_O - sigma_mm[charges_mm > 0] = sigma_tip3p_H - epsilon_mm = charges_mm - epsilon_mm[charges_mm < 0] = epsilon_tip3p_O - epsilon_mm[charges_mm > 0] = epsilon_tip3p_H - - # Compute the LJ energy. - E_lj = self._emle_base.get_lj_energy( - sigma_qm, epsilon_qm, sigma_mm, epsilon_mm, mesh_data - ) + alpha_qm = self._emle_base.get_isotropic_polarizabilities(A_thole) + sigma_qm, epsilon_qm = self._emle_base.get_lj_parameters(C6, alpha_qm) + E_lj = self._emle_base.get_lj_energy(sigma_qm, epsilon_qm, sigma_mm, epsilon_mm, mesh_data) else: raise NotImplementedError("Static LJ model not implemented") diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index b21bba3..4ec6cf2 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -43,7 +43,6 @@ except: _has_nnpops = False - class EMLEBase(_torch.nn.Module): """ Base class for the EMLE model. This is used to compute valence shell @@ -55,30 +54,6 @@ class EMLEBase(_torch.nn.Module): # Store the list of supported species. _species = [1, 6, 7, 8, 16] - # Values in atomic units, taken from: - # Chu, X., & Dalgarno, A. (2004). - # Linear response time-dependent density functional theory for van der Waals coefficients. - # The Journal of Chemical Physics, 121(9), 4083--4088. http://doi.org/10.1063/1.1779576 - # Bfree: Ha.Bohr**6 (https://pubs.acs.org/doi/10.1021/acs.jctc.6b00027) #TODO: convert to Angstrom^6? - C6_COEFFICIENTS = { - 0: 0.0, # Dummy atom - 1: 6.5, # H - 6: 46.6, # C - 7: 24.2, # N - 8: 15.6, # O - 16: 134.0, # S - } - - # Calculated free atom volumes in Angstrom^3 at B3LYP/cc-pVTZ #TODO: convert to Bohr^3? - RCUBED_FREE = { - 0: 0.0, # Dummy atom - 1: 1.172520801342987, # H - 6: 5.044203900156338, # C - 7: 3.7853134153705934, # N - 8: 3.135943377417264, # O - 16: 11.015796687279208, # S - } - def __init__( self, params, @@ -88,6 +63,7 @@ def __init__( emle_aev_computer=None, species=None, alpha_mode="species", + lj_mode=None, device=None, dtype=None, ): @@ -116,6 +92,10 @@ def __init__( "reference": scaling factors are obtained with GPR using the values learned for each reference environment + + lj_mode: str + Mode for calculating the Lennard-Jones potential. + If None, the Lennard-Jones potential is not calculated. emle_aev_computer: EMLEAEVComputer EMLE AEV computer instance used to compute AEVs (masked and normalized). @@ -227,6 +207,18 @@ def __init__( "using 'reference' alpha mode." ) raise ValueError(msg) + + if lj_mode is not None: + assert lj_mode in ["static", "dynamic"], "Invalid Lennard-Jones mode" + try: + self.ref_values_C6 = _torch.nn.Parameter(params["ref_values_C6"]) + except: + msg = ( + "Missing 'ref_values_C6' key in params. This is required when " + "using the Lennard-Jones potential." + ) + raise ValueError(msg) + self._lj_mode = lj_mode # Validate the species. if species is None: @@ -249,19 +241,6 @@ def __init__( species_map[s] = i species_map = _torch.tensor(species_map, dtype=_torch.int64, device=device) - # Create lookup tensors for free atom volumes and C6 coefficients - self._vol_isolated = _torch.zeros( - max(self.RCUBED_FREE.keys()) + 1, dtype=_torch.float - ) - for Z, vol in self.RCUBED_FREE.items(): - self._vol_isolated[Z] = vol - - self._c6 = _torch.zeros( - max(self.C6_COEFFICIENTS.keys()) + 1, dtype=_torch.float - ) - for Z, c6_coeff in self.C6_COEFFICIENTS.items(): - self._c6[Z] = c6_coeff - # Compute the inverse of the K matrix. Kinv = self._get_Kinv(ref_features, 1e-3) @@ -276,6 +255,12 @@ def __init__( else: ref_mean_sqrtk, c_sqrtk = self._get_c(n_ref, self.ref_values_sqrtk, Kinv) + if lj_mode is not None: + ref_mean_C6, c_C6 = self._get_c(n_ref, self.ref_values_C6, Kinv) + else: + ref_mean_C6 = _torch.zeros_like(ref_mean_s, dtype=dtype, device=device) + c_C6 = _torch.zeros_like(c_s, dtype=dtype, device=device) + # Store the current device. self._device = device @@ -288,9 +273,11 @@ def __init__( self.register_buffer("_ref_mean_s", ref_mean_s) self.register_buffer("_ref_mean_chi", ref_mean_chi) self.register_buffer("_ref_mean_sqrtk", ref_mean_sqrtk) + self.register_buffer("_ref_mean_C6", ref_mean_C6) self.register_buffer("_c_s", c_s) self.register_buffer("_c_chi", c_chi) self.register_buffer("_c_sqrtk", c_sqrtk) + self.register_buffer("_c_C6", c_C6) def to(self, *args, **kwargs): """ @@ -305,9 +292,11 @@ def to(self, *args, **kwargs): self._ref_mean_s = self._ref_mean_s.to(*args, **kwargs) self._ref_mean_chi = self._ref_mean_chi.to(*args, **kwargs) self._ref_mean_sqrtk = self._ref_mean_sqrtk.to(*args, **kwargs) + self._ref_mean_C6 = self._ref_mean_C6.to(*args, **kwargs) self._c_s = self._c_s.to(*args, **kwargs) self._c_chi = self._c_chi.to(*args, **kwargs) self._c_sqrtk = self._c_sqrtk.to(*args, **kwargs) + self._c_C6 = self._c_C6.to(*args, **kwargs) self.k_Z = _torch.nn.Parameter(self.k_Z.to(*args, **kwargs)) # Check for a device type in args and update the device attribute. @@ -331,9 +320,11 @@ def cuda(self, **kwargs): self._ref_mean_s = self._ref_mean_s.cuda(**kwargs) self._ref_mean_chi = self._ref_mean_chi.cuda(**kwargs) self._ref_mean_sqrtk = self._ref_mean_sqrtk.cuda(**kwargs) + self._ref_mean_C6 = self._ref_mean_C6.cuda(**kwargs) self._c_s = self._c_s.cuda(**kwargs) self._c_chi = self._c_chi.cuda(**kwargs) self._c_sqrtk = self._c_sqrtk.cuda(**kwargs) + self._c_C6 = self._c_C6.cuda(**kwargs) self.k_Z = _torch.nn.Parameter(self.k_Z.cuda(**kwargs)) # Update the device attribute. @@ -353,10 +344,12 @@ def cpu(self, **kwargs): self._n_ref = self._n_ref.cpu(**kwargs) self._ref_mean_s = self._ref_mean_s.cpu(**kwargs) self._ref_mean_chi = self._ref_mean_chi.cpu(**kwargs) - self._ref_mean_sqrtk = self._ref_mean_sqrtk.to(**kwargs) + self._ref_mean_sqrtk = self._ref_mean_sqrtk.cpu(**kwargs) + self._ref_mean_C6 = self._ref_mean_C6.cpu(**kwargs) self._c_s = self._c_s.cpu(**kwargs) self._c_chi = self._c_chi.cpu(**kwargs) self._c_sqrtk = self._c_sqrtk.cpu(**kwargs) + self._c_C6 = self._c_C6.cpu(**kwargs) self.k_Z = _torch.nn.Parameter(self.k_Z.cpu(**kwargs)) # Update the device attribute. @@ -375,9 +368,11 @@ def double(self): self._ref_mean_s = self._ref_mean_s.double() self._ref_mean_chi = self._ref_mean_chi.double() self._ref_mean_sqrtk = self._ref_mean_sqrtk.double() + self._ref_mean_C6 = self._ref_mean_C6.double() self._c_s = self._c_s.double() self._c_chi = self._c_chi.double() self._c_sqrtk = self._c_sqrtk.double() + self._c_C6 = self._c_C6.double() self.k_Z = _torch.nn.Parameter(self.k_Z.double()) return self @@ -392,9 +387,11 @@ def float(self): self._ref_mean_s = self._ref_mean_s.float() self._ref_mean_chi = self._ref_mean_chi.float() self._ref_mean_sqrtk = self._ref_mean_sqrtk.float() + self._ref_mean_C6 = self._ref_mean_C6.float() self._c_s = self._c_s.float() self._c_chi = self._c_chi.float() self._c_sqrtk = self._c_sqrtk.float() + self._c_C6 = self._c_C6.float() self.k_Z = _torch.nn.Parameter(self.k_Z.float()) return self @@ -421,8 +418,9 @@ def forward(self, atomic_numbers, xyz_qm, q_total): result: (torch.Tensor (N_BATCH, N_QM_ATOMS,), torch.Tensor (N_BATCH, N_QM_ATOMS,), torch.Tensor (N_BATCH, N_QM_ATOMS,), - torch.Tensor (N_BATCH, N_QM_ATOMS * 3, N_QM_ATOMS * 3,)) - Valence widths, core charges, valence charges, A_thole tensor + torch.Tensor (N_BATCH, N_QM_ATOMS * 3, N_QM_ATOMS * 3,), + torch.Tensor (N_BATCH, N_QM_ATOMS,)) + Valence widths, core charges, valence charges, A_thole tensor, C6 coefficients """ # Mask for padded coordinates. @@ -460,7 +458,12 @@ def forward(self, atomic_numbers, xyz_qm, q_total): A_thole = self._get_A_thole(r_data, s, q_val, k, self.a_Thole) - return s, q_core, q_val, A_thole + if self._lj_mode is not None: + C6 = self._gpr(aev, self._ref_mean_C6, self._c_C6, species_id) + else: + C6 = None + + return s, q_core, q_val, A_thole, C6 @classmethod def _get_Kinv(cls, ref_features, sigma): @@ -1056,6 +1059,7 @@ def _get_f1_slater(r: Tensor, s: Tensor) -> Tensor: @staticmethod def _get_T0_slater(r: Tensor, s: Tensor) -> Tensor: """ + # Get distances Internal method, calculates T0 tensor for Slater densities. Parameters @@ -1073,6 +1077,7 @@ def _get_T0_slater(r: Tensor, s: Tensor) -> Tensor: results: torch.Tensor (N_BATCH, MAX_QM_ATOMS, MAX_MM_ATOMS) """ return (1 - (1 + r / (s * 2)) * _torch.exp(-r / s)) / r + @staticmethod def get_lj_energy( @@ -1124,15 +1129,15 @@ def get_lj_energy( sigma_r_inv_12 = sigma_r_inv_6 * sigma_r_inv_6 # Calculate pairwise energy matrix (N_BATCH, N_QM, N_MM) - pairwise_energy = 4 * epsilon * (sigma_r_inv_12 - sigma_r_inv_6) + lj_energy = 4 * epsilon * (sigma_r_inv_12 - sigma_r_inv_6) # Sum over QM and MM atoms for each batch element - total_energy = pairwise_energy.sum(dim=(1, 2)) + lj_energy = lj_energy.sum(dim=(1, 2)) - return total_energy + return lj_energy @staticmethod - def calculate_isotropic_polarizabilities(A_thole: _torch.Tensor) -> _torch.Tensor: + def get_isotropic_polarizabilities(A_thole: _torch.Tensor) -> _torch.Tensor: """ Calculate isotropic polarizabilities from the A_thole tensor. @@ -1176,8 +1181,8 @@ def _get_traces(A_thole: _torch.Tensor) -> _torch.Tensor: return _get_traces(A_thole) / 3.0 - def calculate_atomic_lj_parameters( - self, atomic_numbers: Tensor, rcubed: Tensor, alpha: Tensor + def get_lj_parameters( + self, c6: _torch.Tensor, alpha: _torch.Tensor ) -> Tuple[Tensor, Tensor]: """ Calculate Lennard-Jones sigma and epsilon parameters. @@ -1185,14 +1190,11 @@ def calculate_atomic_lj_parameters( Parameters ---------- - atomic_numbers : torch.Tensor(N_BATCH, N_ATOMS) - Atomic numbers of atoms in the molecule. - - rcubed : torch.Tensor(N_BATCH, N_ATOMS) - Cube of the vdW radii of atoms in the molecule in Bohr^3. - - alpha : torch.Tensor(N_BATCH, N_ATOMS) - Isotropic polarizabilities per atom in Bohr^3. + c6: _torch.Tensor(N_BATCH, N_ATOMS) + C6 coefficients per atom. + + alpha: _torch.Tensor(N_BATCH, N_ATOMS) + Isotropic polarizabilities per atom. Returns ------- @@ -1200,37 +1202,9 @@ def calculate_atomic_lj_parameters( Tuple[torch.Tensor, torch.Tensor] Tuple containing the sigma (Bohr) and epsilon (Hartree) LJ parameters for each atom. """ - # Mask out dummy atoms - mask = atomic_numbers > 0 - alpha = alpha * mask - rcubed = rcubed * mask - - # Get free atom volumes - try: - vol_isolated = self._vol_isolated[atomic_numbers] - except KeyError as e: - raise ValueError(f"Missing RCUBED_FREE entry for atomic number {e.args[0]}") - - # Volume scaling factor - scaling = rcubed / vol_isolated - - # Get C6 coefficients - try: - c6 = self._c6[atomic_numbers] - except KeyError as e: - raise ValueError( - f"Missing C6_COEFFICIENTS entry for atomic number {e.args[0]}" - ) - - # Scale C6 coefficients - c6_scaled = c6 * scaling**2 - - # Calculate vdW radius from polarizability (Fedorov-Tkatchenko relation) radius = 2.54 * alpha ** (1.0 / 7.0) rmin = 2 * radius - - # Calculate Lennard-Jones parameters sigma = rmin / (2 ** (1.0 / 6.0)) - epsilon = c6_scaled / (2 * rmin**6.0) + epsilon = c6 / (2 * rmin**6.0) return sigma, epsilon diff --git a/emle/train/_loss.py b/emle/train/_loss.py index 0d1eb14..23a403c 100644 --- a/emle/train/_loss.py +++ b/emle/train/_loss.py @@ -134,7 +134,7 @@ def forward(self, atomic_numbers, xyz, q_mol, q_target): self._update_chi_gpr(self._emle_base) # Calculate q_core and q_val - _, q_core, q_val, _ = self._emle_base(atomic_numbers, xyz, q_mol) + _, q_core, q_val, *_ = self._emle_base(atomic_numbers, xyz, q_mol) mask = atomic_numbers > 0 target = q_target[mask] @@ -275,7 +275,7 @@ def forward( self._update_sqrtk_gpr(self._emle_base) # Calculate A_thole and alpha_mol. - _, _, _, A_thole = self._emle_base(atomic_numbers, xyz, q_mol) + _, _, _, A_thole, *_ = self._emle_base(atomic_numbers, xyz, q_mol) alpha_mol = self._get_alpha_mol(A_thole, atomic_numbers > 0) triu_row, triu_col = _torch.triu_indices(3, 3, offset=0) @@ -311,3 +311,73 @@ def _update_sqrtk_gpr(emle_base): emle_base.ref_values_sqrtk, emle_base._Kinv, ) + + +class DispersionCoefficientLoss(_BaseLoss): + """ + Loss function for dispersion coefficients. Used to train ref_values_C6. + """ + def __init__(self, emle_base, loss=_torch.nn.MSELoss()): + super().__init__() + + from ..models._emle_base import EMLEBase + + if not isinstance(emle_base, EMLEBase): + raise TypeError("emle_base must be an instance of EMLEBase") + self._emle_base = emle_base + + if not isinstance(loss, _torch.nn.Module): + raise TypeError("loss must be an instance of torch.nn.Module") + self._loss = loss + + self._pol = None + + def forward(self, atomic_numbers, xyz, q_mol, C6_target): + """ + Forward pass. + + Parameters + ---------- + atomic_numbers: torch.Tensor(N_BATCH, MAX_N_ATOMS) + Atomic numbers. + + xyz: torch.Tensor(N_BATCH, MAX_N_ATOMS, 3) + Cartesian coordinates. + + q_mol: torch.Tensor(N_BATCH, MAX_N_ATOMS) + Molecular charges. + + C6_target: torch.Tensor(N_BATCH, MAX_N_ATOMS) + Target dispersion coefficients. + """ + # Update reference values for C6. + self._update_C6_gpr(self._emle_base) + + # Calculate C6. + s, q_core, q_val, A_thole, C6 = self._emle_base(atomic_numbers, xyz, q_mol) + # Calculate isotropic polarizabilities if not already calculated. + if self._pol is None: + self._pol = self._emle_base.calculate_isotropic_polarizabilities(A_thole).detach() + + # Mask out dummy atoms. + mask = atomic_numbers > 0 + target = C6_target[mask] + values = C6 + values = values[mask] + + # Calculate loss. + loss = self._loss(values, target) + + return ( + loss, + self._get_rmse(values, target), + self._get_max_error(values, target), + ) + + @staticmethod + def _update_C6_gpr(emle_base): + emle_base._ref_mean_C6, emle_base._c_C6 = emle_base._get_c( + emle_base._n_ref, + emle_base.ref_values_C6, + emle_base._Kinv, + ) \ No newline at end of file diff --git a/emle/train/_trainer.py b/emle/train/_trainer.py index 2ab296c..4b231e3 100644 --- a/emle/train/_trainer.py +++ b/emle/train/_trainer.py @@ -31,6 +31,7 @@ from ._ivm import IVM as _IVM from ._loss import QEqLoss as _QEqLoss from ._loss import TholeLoss as _TholeLoss +from ._loss import DispersionCoefficientLoss as _DispersionCoefficientLoss from ._utils import pad_to_max as _pad_to_max from ._utils import mean_by_z as _mean_by_z @@ -41,6 +42,7 @@ def __init__( emle_base=_EMLEBase, qeq_loss=_QEqLoss, thole_loss=_TholeLoss, + dispersion_coefficient_loss=_DispersionCoefficientLoss, log_level=None, log_file=None, ): @@ -56,6 +58,10 @@ def __init__( raise TypeError("thole_loss must be a reference to TholeLoss") self._thole_loss = thole_loss + if dispersion_coefficient_loss is not _DispersionCoefficientLoss: + raise TypeError("dispersion_coefficient_loss must be a reference to DispersionCoefficientLoss") + self._dispersion_coefficient_loss = dispersion_coefficient_loss + # First handle the logger. if log_level is None: log_level = "INFO" @@ -293,7 +299,8 @@ def train( q_core, q_val, alpha, - train_mask, + C6=None, + train_mask=None, alpha_mode="reference", sigma=1e-3, ivm_thr=0.05, @@ -301,6 +308,7 @@ def train( lr_qeq=0.05, lr_thole=0.05, lr_sqrtk=0.05, + lr_C6=0.05, print_every=10, computer_n_species=None, computer_zid_map=None, @@ -333,6 +341,9 @@ def train( alpha: array or tensor or list of tensor/arrays of shape (N_BATCH, 3, 3) Atomic polarizabilities. + C6: array or tensor or list of tensor/arrays of shape (N_BATCH, N_ATOMS, N_ATOMS) + C6 dispersion coefficients. If None, the C6 dispersion coefficients are not trained. + train_mask: torch.Tensor(N_BATCH,) Mask for training samples. @@ -357,6 +368,9 @@ def train( lr_sqrtk: float Learning rate for sqrtk. + lr_C6: float + Learning rate for C6. + print_every: int How often to print training progress. @@ -427,6 +441,11 @@ def train( alpha_train = alpha_train.to(device=device, dtype=dtype) species = species.to(device=device, dtype=_torch.int64) + if C6 is not None: + C6 = _pad_to_max(C6) + C6_train = C6[train_mask] + C6_train = C6_train.to(device=device, dtype=dtype) + # Get zid mapping. zid_mapping = self._get_zid_mapping(species) zid_train = zid_mapping[z_train] @@ -504,6 +523,15 @@ def train( if alpha_mode == "reference" else None ), + "ref_values_C6": ( + _torch.ones( + *ref_values_s.shape, + dtype=ref_values_s.dtype, + device=_torch.device(device), + ) + if C6 is not None + else None + ), } # Create the EMLE base instance. @@ -515,10 +543,10 @@ def train( emle_aev_computer=emle_aev_computer, species=species, alpha_mode=alpha_mode, + lj_mode="static" if C6 is not None else None, device=_torch.device(device), dtype=dtype, ) - # Fit chi, a_QEq (QEq over chi predicted with GPR). _logger.info("Fitting a_QEq and chi values...") self._train_model( @@ -536,9 +564,8 @@ def train( # Update GPR constants for chi # (now inconsistent since not updated after the last epoch) self._qeq_loss._update_chi_gpr(emle_base) - + """ _logger.debug(f"Optimized a_QEq: {emle_base.a_QEq.data.item()}") - # Fit a_Thole, k_Z (uses volumes predicted by QEq model). _logger.info("Fitting a_Thole and k_Z values...") self._train_model( @@ -575,6 +602,24 @@ def train( # Update GPR constants for sqrtk # (now inconsistent since not updated after the last epoch) self._thole_loss._update_sqrtk_gpr(emle_base) + """ + if C6 is not None: + _logger.info("Fitting ref_values_C6 values...") + self._train_model( + loss_class=self._dispersion_coefficient_loss, + opt_param_names=["ref_values_C6"], + lr=lr_C6, + epochs=1000, + print_every=print_every, + emle_base=emle_base, + atomic_numbers=z_train, + xyz=xyz_train, + q_mol=q_mol_train, + C6_target=C6_train, + ) + + # Update reference values for C6. + self._dispersion_coefficient_loss._update_C6_gpr(emle_base) # Create the final model. emle_model = { @@ -587,6 +632,9 @@ def train( "sqrtk_ref": ( emle_base.ref_values_sqrtk if alpha_mode == "reference" else None ), + "ref_values_C6": ( + emle_base.ref_values_C6 if C6 is not None else None + ), "species": species, "alpha_mode": alpha_mode, "n_ref": n_ref, @@ -603,11 +651,17 @@ def train( return emle_base emle_base._alpha_mode = "species" - s_pred, q_core_pred, q_val_pred, A_thole = emle_base( + emle_base_output = emle_base( z.to(device=device, dtype=_torch.int64), xyz.to(device=device, dtype=dtype), q_mol, ) + + s_pred = emle_base_output[0] + q_core_pred = emle_base_output[1] + q_val_pred = emle_base_output[2] + A_thole = emle_base_output[3] + z_mask = _torch.tensor(z > 0, device=device) plot_data = { "s_emle": s_pred, @@ -631,6 +685,10 @@ def train( plot_data["alpha_reference"] = self._thole_loss._get_alpha_mol( A_thole, z_mask ) + + if C6 is not None: + C6_pred = emle_base_output[4] + plot_data["C6_emle"] = C6_pred self._write_model_to_file(plot_data, plot_data_filename) From e556fe7d6a65d235c0dc296c2d80284301f8cf8d Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Thu, 24 Jul 2025 09:44:40 +0100 Subject: [PATCH 3/8] Ruff formatting --- emle/_analyzer.py | 4 ++-- emle/_orca_parser.py | 1 - emle/calculator.py | 9 ++++++--- emle/models/_patches.py | 1 + versioneer.py | 10 ++++------ 5 files changed, 13 insertions(+), 12 deletions(-) diff --git a/emle/_analyzer.py b/emle/_analyzer.py index b8c71d2..5de9c45 100644 --- a/emle/_analyzer.py +++ b/emle/_analyzer.py @@ -49,7 +49,7 @@ def __init__( parser=None, q_total=None, start=None, - end=None + end=None, ): """ Constructor. @@ -168,7 +168,7 @@ def __init__( self.qm_xyz, self.q_total, ) - self.atomic_alpha = 1. / _torch.diagonal(self.A_thole, dim1=1, dim2=2)[:, ::3] + self.atomic_alpha = 1.0 / _torch.diagonal(self.A_thole, dim1=1, dim2=2)[:, ::3] self.alpha = self._get_mol_alpha(self.A_thole, self.atomic_numbers) mask = (self.atomic_numbers > 0).unsqueeze(-1) diff --git a/emle/_orca_parser.py b/emle/_orca_parser.py index 94bd554..b3d6726 100644 --- a/emle/_orca_parser.py +++ b/emle/_orca_parser.py @@ -125,7 +125,6 @@ def __init__(self, filename, decompose=False, alpha=False): try: with _tarfile.open(filename, "r") as tar: - self._tar = tar self.names = self._get_names(tar) diff --git a/emle/calculator.py b/emle/calculator.py index 6491624..1d57a29 100644 --- a/emle/calculator.py +++ b/emle/calculator.py @@ -1327,12 +1327,15 @@ def _calculate_energy_and_gradients( _logger.error(msg) raise RuntimeError(msg) - if (self._qbc_deviation): + if self._qbc_deviation: E_std = _torch.std(base_model._E_vac_qbc).item() max_f_std = _torch.max(_torch.std(base_model._grads_qbc, axis=0)).item() with open(self._qbc_deviation, "a") as f: f.write(f"{E_std:12.5f}{max_f_std:12.5f}\n") - if self._qbc_deviation_threshold and max_f_std > self._qbc_deviation_threshold: + if ( + self._qbc_deviation_threshold + and max_f_std > self._qbc_deviation_threshold + ): msg = "Force deviation threshold reached!" raise ValueError(msg) @@ -1371,7 +1374,7 @@ def _calculate_energy_and_gradients( else: offset = int(not self._restart) lam = self._lambda_interpolate[0] + ( - (self._step / (self._interpolate_steps - offset)) + self._step / (self._interpolate_steps - offset) ) * (self._lambda_interpolate[1] - self._lambda_interpolate[0]) if lam < 0.0: lam = 0.0 diff --git a/emle/models/_patches.py b/emle/models/_patches.py index 2e67ee1..b21f4c9 100644 --- a/emle/models/_patches.py +++ b/emle/models/_patches.py @@ -32,6 +32,7 @@ # convert atom species from string to long tensor model0.species_to_tensor(['C', 'H', 'H', 'H', 'H']) """ + import os import torch import torchani diff --git a/versioneer.py b/versioneer.py index de97d90..080191b 100644 --- a/versioneer.py +++ b/versioneer.py @@ -514,9 +514,7 @@ def run_command( return stdout, process.returncode -LONG_VERSION_PY[ - "git" -] = r''' +LONG_VERSION_PY["git"] = r''' # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build @@ -1840,9 +1838,9 @@ def get_versions(verbose: bool = False) -> Dict[str, Any]: handlers = HANDLERS.get(cfg.VCS) assert handlers, "unrecognized VCS '%s'" % cfg.VCS verbose = verbose or bool(cfg.verbose) # `bool()` used to avoid `None` - assert ( - cfg.versionfile_source is not None - ), "please set versioneer.versionfile_source" + assert cfg.versionfile_source is not None, ( + "please set versioneer.versionfile_source" + ) assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" versionfile_abs = os.path.join(root, cfg.versionfile_source) From 2edd76b58f034e8dd77f71b31116b7c5491d2789 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Thu, 24 Jul 2025 09:45:10 +0100 Subject: [PATCH 4/8] Add "fixed" mode and miscellaneous fixes/updates --- emle/models/_ani.py | 4 +- emle/models/_emle.py | 143 +++++++++++++++++++++++++++++--------- emle/models/_emle_base.py | 12 ++-- emle/models/_mace.py | 99 +++++++++++++++++--------- emle/train/_loss.py | 21 +++--- emle/train/_trainer.py | 14 ++-- 6 files changed, 202 insertions(+), 91 deletions(-) diff --git a/emle/models/_ani.py b/emle/models/_ani.py index f21ff16..c922c25 100644 --- a/emle/models/_ani.py +++ b/emle/models/_ani.py @@ -375,7 +375,7 @@ def forward( ------- result: torch.Tensor (3,) or (3, BATCH) - The ANI2x and static and induced EMLE energy components in Hartree. + The ANI2x and static, induced and LJ EMLE energy components in Hartree. """ # Batch the inputs if necessary. if atomic_numbers.ndim == 1: @@ -407,4 +407,4 @@ def forward( E_emle = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm, qm_charge) # Return the ANI2x and EMLE energy components. - return _torch.stack((E_vac, E_emle[0], E_emle[1])) + return _torch.stack((E_vac, E_emle[0], E_emle[1], E_emle[2])) diff --git a/emle/models/_emle.py b/emle/models/_emle.py index 3f36ed7..3d3c6a5 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -38,6 +38,7 @@ from . import EMLEBase as _EMLEBase from . import _patches +from .._units import _HARTREE_TO_KJ_MOL, _BOHR_TO_ANGSTROM # Monkey-patch the TorchANI BuiltInModel and BuiltinEnsemble classes so that # they call self.aev_computer using args only to allow forward hooks to work @@ -83,10 +84,12 @@ def __init__( model=None, method="electrostatic", alpha_mode="species", - lj_method=None, atomic_numbers=None, qm_charge=0, mm_charges=None, + lj_mode=None, + lj_params_qm=None, + lj_xyz_qm=None, device=None, dtype=None, create_aev_calculator=True, @@ -124,15 +127,6 @@ def __init__( scaling factors are obtained with GPR using the values learned for each reference environment - lj_method: str - How the LJ parameters are calculated. - "dynamic": - Lennard-Jones parameters are calculated dynamically from EMLE parameters. - "static": - Lennard-Jones parameters are fixed. - None - Lennard-Jones parameters and interactions are not included. - atomic_numbers: List[int], Tuple[int], numpy.ndarray, torch.Tensor Atomic numbers for the QM region. This allows use of optimised AEV symmetry functions from the NNPOps package. Only use this option @@ -147,6 +141,26 @@ def __init__( List of MM charges for atoms in the QM region in units of mod electron charge. This is required if the 'mm' method is specified. + lj_mode: str + How the LJ parameters are calculated. + "flexible": + Lennard-Jones parameters are calculated dynamically for a given configuration. + "fixed": + Lennard-Jones parameters are fixed, i.e. independent of the configuration. + Requires specifying the LJ parameters for each atom in the QM region or to + provide an initial configuration. + None + Lennard-Jones parameters and interactions are not included. + + lj_params_qm: List[List[float]], Tuple[List[List[Float]]], numpy.ndarray, torch.Tensor + Lennard-Jones parameters for each atom in the QM region (sigma, epsilon) in units of Angstrom (sigma) + and kJ/mol (epsilon). This is required if the "lj_mode" is "fixed" and lj_param_qm is not provided. + Takes precedence over lj_xyz_qm. + + lj_xyz_qm: List[List[float]], Tuple[List[List[Float]]], numpy.ndarray, torch.Tensor + Positions of QM atoms in Angstrom. This is required if the "lj_mode" is "fixed" + and lj_param_qm is not provided. + device: torch.device The device on which to run the model. @@ -189,16 +203,6 @@ def __init__( raise ValueError("'alpha_mode' must be 'species' or 'reference'") self._alpha_mode = alpha_mode - if lj_method is not None: - if not isinstance(lj_method, str): - raise TypeError("'lj_method' must be of type 'str'") - lj_method = lj_method.lower().replace(" ", "") - if lj_method not in {"dynamic", "static"}: - raise ValueError("'lj_method' must be 'dynamic' or 'static'") - if lj_method == "static": - raise NotImplementedError("Static LJ model not implemented") - self._lj_method = lj_method - if atomic_numbers is not None: if isinstance(atomic_numbers, (_np.ndarray, _torch.Tensor)): atomic_numbers = atomic_numbers.tolist() @@ -254,6 +258,58 @@ def __init__( # Use the default species. species = self._species + if lj_mode is not None: + if not isinstance(lj_mode, str): + raise TypeError("'lj_mode' must be of type 'str'") + + lj_mode = lj_mode.lower().replace(" ", "") + if lj_mode not in {"flexible", "fixed"}: + raise ValueError("'lj_mode' must be 'flexible' or 'fixed'") + + if lj_mode == "fixed": + if lj_params_qm is None and lj_xyz_qm is None: + raise ValueError( + "lj_params_qm or lj_xyz_qm must be provided if lj_mode is 'fixed'" + ) + + if lj_params_qm is not None: + if not isinstance( + lj_params_qm, (list, tuple, _np.ndarray, _torch.Tensor) + ) or not isinstance( + lj_params_qm[0], (list, tuple, _np.ndarray, _torch.Tensor) + ): + raise TypeError( + "lj_params_qm must be a list of lists, tuples, or arrays" + ) + if len(lj_params_qm) != len(atomic_numbers): + raise ValueError( + "lj_params_qm must have the same length as the number of QM atoms" + ) + + lj_params_qm = _torch.tensor( + lj_params_qm, dtype=dtype, device=device + ) + self._lj_epsilon_qm = lj_params_qm[:, 1] / _HARTREE_TO_KJ_MOL + self._lj_sigma_qm = lj_params_qm[:, 0] / _BOHR_TO_ANGSTROM + lj_xyz_qm = None + else: + if not isinstance( + lj_xyz_qm, (list, tuple, _np.ndarray, _torch.Tensor) + ) or not isinstance( + lj_xyz_qm[0], (list, tuple, _np.ndarray, _torch.Tensor) + ): + raise TypeError( + "lj_xyz_qm must be a list of lists, tuples, or arrays" + ) + if len(lj_xyz_qm) != len(atomic_numbers): + raise ValueError( + "lj_xyz_qm must have the same length as the number of QM atoms" + ) + + lj_xyz_qm = _torch.tensor(lj_xyz_qm, dtype=dtype, device=device) + + self._lj_mode = lj_mode + if device is not None: if not isinstance(device, _torch.device): raise TypeError("'device' must be of type 'torch.device'") @@ -382,6 +438,16 @@ def __init__( dtype=dtype, ) + if lj_xyz_qm: + # Get the LJ parameters for the passed configuration + _, _, _, A_thole, C6 = self._emle_base( + self.atomic_numbers, lj_xyz_qm, qm_charge + ) + alpha_qm = self._emle_base.get_isotropic_polarizabilities(A_thole) + sigma_qm, epsilon_qm = self._emle_base.get_lj_parameters(C6, alpha_qm) + self._lj_sigma_qm = sigma_qm + self._lj_epsilon_qm = epsilon_qm + def to(self, *args, **kwargs): """ Performs Tensor dtype and/or device conversion on the model. @@ -466,11 +532,17 @@ def forward( qm_charge: int or torch.Tensor (BATCH,) The charge on the QM region. + sigma_mm: torch.Tensor (N_MM_ATOMS,) or (BATCH, N_MM_ATOMS) + Lennard-Jones sigma parameters for MM atoms in Angstrom. + + epsilon_mm: torch.Tensor (N_MM_ATOMS,) or (BATCH, N_MM_ATOMS) + Lennard-Jones epsilon parameters for MM atoms in kJ/mol. + Returns ------- - result: torch.Tensor (2,) or (2, BATCH) - The static and induced EMLE energy components in Hartree. + result: torch.Tensor (3,) or (3, BATCH) + The static, induced, and LJ EMLE energy components in Hartree. """ # Store the inputs as internal attributes. self._atomic_numbers = atomic_numbers @@ -513,9 +585,8 @@ def forward( ) # Convert coordinates to Bohr. - ANGSTROM_TO_BOHR = 1.8897261258369282 - xyz_qm_bohr = self._xyz_qm * ANGSTROM_TO_BOHR - xyz_mm_bohr = self._xyz_mm * ANGSTROM_TO_BOHR + xyz_qm_bohr = self._xyz_qm / _BOHR_TO_ANGSTROM + xyz_mm_bohr = self._xyz_mm / _BOHR_TO_ANGSTROM # Compute the static energy. if self._method == "mm": @@ -546,14 +617,20 @@ def forward( E_static, dtype=self._charges_mm.dtype, device=self._device ) - if self._lj_method is not None: - if self._lj_method == "dynamic": + # Compute the LJ energy. + if self._lj_mode is not None: + if self._lj_mode == "flexible": alpha_qm = self._emle_base.get_isotropic_polarizabilities(A_thole) sigma_qm, epsilon_qm = self._emle_base.get_lj_parameters(C6, alpha_qm) - E_lj = self._emle_base.get_lj_energy(sigma_qm, epsilon_qm, sigma_mm, epsilon_mm, mesh_data) - else: - raise NotImplementedError("Static LJ model not implemented") - - return _torch.stack((E_static, E_ind, E_lj), dim=0) + elif self._lj_mode == "fixed": + sigma_qm = self._lj_sigma_qm.expand(batch_size, -1) + epsilon_qm = self._lj_epsilon_qm.expand(batch_size, -1) + E_lj = self._emle_base.get_lj_energy( + sigma_qm, epsilon_qm, sigma_mm, epsilon_mm, mesh_data + ) + else: + E_lj = _torch.zeros_like( + E_static, dtype=self._charges_mm.dtype, device=self._device + ) - return _torch.stack((E_static, E_ind), dim=0) + return _torch.stack((E_static, E_ind, E_lj), dim=0) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index 4ec6cf2..679eb4c 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -43,6 +43,7 @@ except: _has_nnpops = False + class EMLEBase(_torch.nn.Module): """ Base class for the EMLE model. This is used to compute valence shell @@ -92,9 +93,9 @@ def __init__( "reference": scaling factors are obtained with GPR using the values learned for each reference environment - + lj_mode: str - Mode for calculating the Lennard-Jones potential. + Mode for calculating the Lennard-Jones potential. If None, the Lennard-Jones potential is not calculated. emle_aev_computer: EMLEAEVComputer @@ -207,7 +208,7 @@ def __init__( "using 'reference' alpha mode." ) raise ValueError(msg) - + if lj_mode is not None: assert lj_mode in ["static", "dynamic"], "Invalid Lennard-Jones mode" try: @@ -1077,7 +1078,6 @@ def _get_T0_slater(r: Tensor, s: Tensor) -> Tensor: results: torch.Tensor (N_BATCH, MAX_QM_ATOMS, MAX_MM_ATOMS) """ return (1 - (1 + r / (s * 2)) * _torch.exp(-r / s)) / r - @staticmethod def get_lj_energy( @@ -1192,9 +1192,9 @@ def get_lj_parameters( c6: _torch.Tensor(N_BATCH, N_ATOMS) C6 coefficients per atom. - + alpha: _torch.Tensor(N_BATCH, N_ATOMS) - Isotropic polarizabilities per atom. + Isotropic polarizabilities per atom. Returns ------- diff --git a/emle/models/_mace.py b/emle/models/_mace.py index ff1d379..1e1a33c 100644 --- a/emle/models/_mace.py +++ b/emle/models/_mace.py @@ -123,7 +123,7 @@ def __init__( Available pre-trained models are 'mace-off23-small', 'mace-off23-medium', 'mace-off23-large'. To use a locally trained MACE model, provide the path to the model file. If None, the MACE-OFF23(S) model will be used by default. - If more than one model is provided, only the energy from the first model will be returned + If more than one model is provided, only the energy from the first model will be returned in the forward pass, but the energy and forces from all models will be stored. atomic_numbers: List[int], Tuple[int], numpy.ndarray, torch.Tensor (N_ATOMS,) @@ -198,16 +198,25 @@ def __init__( ) if not isinstance(mace_model, (list, tuple)): - mace_model = [mace_model] if mace_model is None or isinstance(mace_model, str) else None + mace_model = ( + [mace_model] + if mace_model is None or isinstance(mace_model, str) + else None + ) - if mace_model is None or any(not isinstance(i, (str, type(None))) for i in mace_model): - raise TypeError("'mace_model' must be a list, tuple, or str, with elements of type str or None") + if mace_model is None or any( + not isinstance(i, (str, type(None))) for i in mace_model + ): + raise TypeError( + "'mace_model' must be a list, tuple, or str, with elements of type str or None" + ) from mace.tools.scripts_utils import extract_config_mace_model + self._mace_models = _torch.nn.ModuleList() for model in mace_model: source_model = self._load_mace_model(model, device) - + # Extract the config from the model. config = extract_config_mace_model(source_model) @@ -281,7 +290,7 @@ def _load_mace_model(mace_model: str, device: _torch.device): Path to the MACE model file or the name of the pre-trained MACE model. device: torch.device Device on which to load the model. - + Returns ------- source_model: torch.nn.Module @@ -392,9 +401,9 @@ def to(self, *args, **kwargs): """ self._emle = self._emle.to(*args, **kwargs) self._mace = self._mace.to(*args, **kwargs) - self._mace_models = _torch.nn.ModuleList([ - model.to(*args, **kwargs) for model in self._mace_models - ]) + self._mace_models = _torch.nn.ModuleList( + [model.to(*args, **kwargs) for model in self._mace_models] + ) return self def cpu(self, **kwargs): @@ -405,9 +414,9 @@ def cpu(self, **kwargs): self._mace = self._mace.cpu(**kwargs) if self._atomic_numbers is not None: self._atomic_numbers = self._atomic_numbers.cpu(**kwargs) - self._mace_models = _torch.nn.ModuleList([ - model.cpu(**kwargs) for model in self._mace_models - ]) + self._mace_models = _torch.nn.ModuleList( + [model.cpu(**kwargs) for model in self._mace_models] + ) return self def cuda(self, **kwargs): @@ -418,9 +427,9 @@ def cuda(self, **kwargs): self._mace = self._mace.cuda(**kwargs) if self._atomic_numbers is not None: self._atomic_numbers = self._atomic_numbers.cuda(**kwargs) - self._mace_models = _torch.nn.ModuleList([ - model.cuda(**kwargs) for model in self._mace_models - ]) + self._mace_models = _torch.nn.ModuleList( + [model.cuda(**kwargs) for model in self._mace_models] + ) return self def double(self): @@ -429,9 +438,9 @@ def double(self): """ self._emle = self._emle.double() self._mace = self._mace.double() - self._mace_models = _torch.nn.ModuleList([ - model.double() for model in self._mace_models - ]) + self._mace_models = _torch.nn.ModuleList( + [model.double() for model in self._mace_models] + ) return self def float(self): @@ -440,9 +449,9 @@ def float(self): """ self._emle = self._emle.float() self._mace = self._mace.float() - self._mace_models = _torch.nn.ModuleList([ - model.float() for model in self._mace_models - ]) + self._mace_models = _torch.nn.ModuleList( + [model.float() for model in self._mace_models] + ) return self def forward( @@ -478,7 +487,7 @@ def forward( ------- result: torch.Tensor (3,) - The ANI2x and static and induced EMLE energy components in Hartree. + The MACE and static, induced and LJ EMLE energy components in Hartree. """ # Get the device. device = xyz_qm.device @@ -497,8 +506,17 @@ def forward( num_models = len(self._mace_models) # Create tensors to store the data for QbC. - self._E_vac_qbc = _torch.empty(num_models, num_batches, dtype=self._dtype, device=device) - self._grads_qbc = _torch.empty(num_models, num_batches, xyz_qm.shape[1], 3, dtype=self._dtype, device=device) + self._E_vac_qbc = _torch.empty( + num_models, num_batches, dtype=self._dtype, device=device + ) + self._grads_qbc = _torch.empty( + num_models, + num_batches, + xyz_qm.shape[1], + 3, + dtype=self._dtype, + device=device, + ) # Create tensors to store the results. results_E_vac = _torch.empty(num_batches, dtype=self._dtype, device=device) @@ -550,26 +568,32 @@ def forward( E_vac = self._mace(input_dict, compute_force=False)["interaction_energy"] - assert ( - E_vac is not None - ), "The model did not return any energy. Please check the input." + assert E_vac is not None, ( + "The model did not return any energy. Please check the input." + ) results_E_vac[i] = E_vac[0] * EV_TO_HARTREE # Decouple the positions from the computation graph for the next models. - input_dict["positions"] = input_dict["positions"].clone().detach().requires_grad_(True) + input_dict["positions"] = ( + input_dict["positions"].clone().detach().requires_grad_(True) + ) # Do inference for the other models. if len(self._mace_models) > 1: for j, mace in enumerate(self._mace_models): - E_vac_qbc = mace(input_dict, compute_force=False)["interaction_energy"] + E_vac_qbc = mace(input_dict, compute_force=False)[ + "interaction_energy" + ] - assert ( - E_vac_qbc is not None - ), "The model did not return any energy. Please check the input." + assert E_vac_qbc is not None, ( + "The model did not return any energy. Please check the input." + ) # Calculate the gradients - grads_qbc = _torch.autograd.grad([E_vac_qbc], [input_dict["positions"]])[0] + grads_qbc = _torch.autograd.grad( + [E_vac_qbc], [input_dict["positions"]] + )[0] assert grads_qbc is not None, "Gradient computation failed" # Store the results. @@ -582,6 +606,7 @@ def forward( zero = _torch.tensor(0.0, dtype=xyz_qm.dtype, device=device) results_E_emle_static[i] = zero results_E_emle_induced[i] = zero + results_E_emle_lj[i] = zero else: # Get the EMLE energy components. E_emle = self._emle( @@ -589,8 +614,14 @@ def forward( ) results_E_emle_static[i] = E_emle[0][0] results_E_emle_induced[i] = E_emle[1][0] + results_E_emle_lj[i] = E_emle[2][0] # Return the MACE and EMLE energy components. return _torch.stack( - [results_E_vac, results_E_emle_static, results_E_emle_induced] + [ + results_E_vac, + results_E_emle_static, + results_E_emle_induced, + results_E_emle_lj, + ] ) diff --git a/emle/train/_loss.py b/emle/train/_loss.py index 23a403c..dfcdd1a 100644 --- a/emle/train/_loss.py +++ b/emle/train/_loss.py @@ -317,6 +317,7 @@ class DispersionCoefficientLoss(_BaseLoss): """ Loss function for dispersion coefficients. Used to train ref_values_C6. """ + def __init__(self, emle_base, loss=_torch.nn.MSELoss()): super().__init__() @@ -329,7 +330,7 @@ def __init__(self, emle_base, loss=_torch.nn.MSELoss()): if not isinstance(loss, _torch.nn.Module): raise TypeError("loss must be an instance of torch.nn.Module") self._loss = loss - + self._pol = None def forward(self, atomic_numbers, xyz, q_mol, C6_target): @@ -343,7 +344,7 @@ def forward(self, atomic_numbers, xyz, q_mol, C6_target): xyz: torch.Tensor(N_BATCH, MAX_N_ATOMS, 3) Cartesian coordinates. - + q_mol: torch.Tensor(N_BATCH, MAX_N_ATOMS) Molecular charges. @@ -357,27 +358,29 @@ def forward(self, atomic_numbers, xyz, q_mol, C6_target): s, q_core, q_val, A_thole, C6 = self._emle_base(atomic_numbers, xyz, q_mol) # Calculate isotropic polarizabilities if not already calculated. if self._pol is None: - self._pol = self._emle_base.calculate_isotropic_polarizabilities(A_thole).detach() + self._pol = self._emle_base.calculate_isotropic_polarizabilities( + A_thole + ).detach() - # Mask out dummy atoms. + # Mask out dummy atoms. mask = atomic_numbers > 0 target = C6_target[mask] - values = C6 - values = values[mask] + values = C6 + values = values[mask] # Calculate loss. - loss = self._loss(values, target) + loss = self._loss(values, target) return ( loss, self._get_rmse(values, target), self._get_max_error(values, target), ) - + @staticmethod def _update_C6_gpr(emle_base): emle_base._ref_mean_C6, emle_base._c_C6 = emle_base._get_c( emle_base._n_ref, emle_base.ref_values_C6, emle_base._Kinv, - ) \ No newline at end of file + ) diff --git a/emle/train/_trainer.py b/emle/train/_trainer.py index 4b231e3..1c95ee6 100644 --- a/emle/train/_trainer.py +++ b/emle/train/_trainer.py @@ -59,7 +59,9 @@ def __init__( self._thole_loss = thole_loss if dispersion_coefficient_loss is not _DispersionCoefficientLoss: - raise TypeError("dispersion_coefficient_loss must be a reference to DispersionCoefficientLoss") + raise TypeError( + "dispersion_coefficient_loss must be a reference to DispersionCoefficientLoss" + ) self._dispersion_coefficient_loss = dispersion_coefficient_loss # First handle the logger. @@ -273,7 +275,7 @@ def _train_loop( optimizer.step() if (epoch + 1) % print_every == 0: _logger.info( - f"Epoch {epoch+1}: Loss ={loss.item():9.4f} " + f"Epoch {epoch + 1}: Loss ={loss.item():9.4f} " f"RMSE ={rmse.item():9.4f} " f"Max Error ={max_error.item():9.4f}" ) @@ -632,9 +634,7 @@ def train( "sqrtk_ref": ( emle_base.ref_values_sqrtk if alpha_mode == "reference" else None ), - "ref_values_C6": ( - emle_base.ref_values_C6 if C6 is not None else None - ), + "ref_values_C6": (emle_base.ref_values_C6 if C6 is not None else None), "species": species, "alpha_mode": alpha_mode, "n_ref": n_ref, @@ -656,7 +656,7 @@ def train( xyz.to(device=device, dtype=dtype), q_mol, ) - + s_pred = emle_base_output[0] q_core_pred = emle_base_output[1] q_val_pred = emle_base_output[2] @@ -685,7 +685,7 @@ def train( plot_data["alpha_reference"] = self._thole_loss._get_alpha_mol( A_thole, z_mask ) - + if C6 is not None: C6_pred = emle_base_output[4] plot_data["C6_emle"] = C6_pred From 294bc0d41839cd107846a87cb8520031bcdcc08d Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Thu, 24 Jul 2025 09:55:24 +0100 Subject: [PATCH 5/8] Fix font cases --- emle/models/_emle.py | 14 +++++----- emle/models/_emle_base.py | 40 ++++++++++++++-------------- emle/train/_loss.py | 18 ++++++------- emle/train/_trainer.py | 55 ++++++++++++++++++--------------------- 4 files changed, 61 insertions(+), 66 deletions(-) diff --git a/emle/models/_emle.py b/emle/models/_emle.py index 3d3c6a5..4637c51 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -352,9 +352,9 @@ def __init__( if "sqrtk_ref" in params else None ), - "ref_values_C6": ( - _torch.tensor(params["ref_values_C6"], dtype=dtype, device=device) - if "ref_values_C6" in params + "ref_values_c6": ( + _torch.tensor(params["c6_ref"], dtype=dtype, device=device) + if "c6_ref" in params else None ), } @@ -440,11 +440,11 @@ def __init__( if lj_xyz_qm: # Get the LJ parameters for the passed configuration - _, _, _, A_thole, C6 = self._emle_base( + _, _, _, A_thole, c6 = self._emle_base( self.atomic_numbers, lj_xyz_qm, qm_charge ) alpha_qm = self._emle_base.get_isotropic_polarizabilities(A_thole) - sigma_qm, epsilon_qm = self._emle_base.get_lj_parameters(C6, alpha_qm) + sigma_qm, epsilon_qm = self._emle_base.get_lj_parameters(c6, alpha_qm) self._lj_sigma_qm = sigma_qm self._lj_epsilon_qm = epsilon_qm @@ -578,7 +578,7 @@ def forward( ) # Get the parameters from the base model. - s, q_core, q_val, A_thole, C6 = self._emle_base( + s, q_core, q_val, A_thole, c6 = self._emle_base( self._atomic_numbers, self._xyz_qm, qm_charge, @@ -621,7 +621,7 @@ def forward( if self._lj_mode is not None: if self._lj_mode == "flexible": alpha_qm = self._emle_base.get_isotropic_polarizabilities(A_thole) - sigma_qm, epsilon_qm = self._emle_base.get_lj_parameters(C6, alpha_qm) + sigma_qm, epsilon_qm = self._emle_base.get_lj_parameters(c6, alpha_qm) elif self._lj_mode == "fixed": sigma_qm = self._lj_sigma_qm.expand(batch_size, -1) epsilon_qm = self._lj_epsilon_qm.expand(batch_size, -1) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index 679eb4c..1c9fa50 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -212,10 +212,10 @@ def __init__( if lj_mode is not None: assert lj_mode in ["static", "dynamic"], "Invalid Lennard-Jones mode" try: - self.ref_values_C6 = _torch.nn.Parameter(params["ref_values_C6"]) + self.ref_values_c6 = _torch.nn.Parameter(params["c6_ref"]) except: msg = ( - "Missing 'ref_values_C6' key in params. This is required when " + "Missing 'c6_ref' key in params. This is required when " "using the Lennard-Jones potential." ) raise ValueError(msg) @@ -257,10 +257,10 @@ def __init__( ref_mean_sqrtk, c_sqrtk = self._get_c(n_ref, self.ref_values_sqrtk, Kinv) if lj_mode is not None: - ref_mean_C6, c_C6 = self._get_c(n_ref, self.ref_values_C6, Kinv) + ref_mean_c6, c_c6 = self._get_c(n_ref, self.ref_values_c6, Kinv) else: - ref_mean_C6 = _torch.zeros_like(ref_mean_s, dtype=dtype, device=device) - c_C6 = _torch.zeros_like(c_s, dtype=dtype, device=device) + ref_mean_c6 = _torch.zeros_like(ref_mean_s, dtype=dtype, device=device) + c_c6 = _torch.zeros_like(c_s, dtype=dtype, device=device) # Store the current device. self._device = device @@ -274,11 +274,11 @@ def __init__( self.register_buffer("_ref_mean_s", ref_mean_s) self.register_buffer("_ref_mean_chi", ref_mean_chi) self.register_buffer("_ref_mean_sqrtk", ref_mean_sqrtk) - self.register_buffer("_ref_mean_C6", ref_mean_C6) + self.register_buffer("_ref_mean_c6", ref_mean_c6) self.register_buffer("_c_s", c_s) self.register_buffer("_c_chi", c_chi) self.register_buffer("_c_sqrtk", c_sqrtk) - self.register_buffer("_c_C6", c_C6) + self.register_buffer("_c_c6", c_c6) def to(self, *args, **kwargs): """ @@ -293,11 +293,11 @@ def to(self, *args, **kwargs): self._ref_mean_s = self._ref_mean_s.to(*args, **kwargs) self._ref_mean_chi = self._ref_mean_chi.to(*args, **kwargs) self._ref_mean_sqrtk = self._ref_mean_sqrtk.to(*args, **kwargs) - self._ref_mean_C6 = self._ref_mean_C6.to(*args, **kwargs) + self._ref_mean_c6 = self._ref_mean_c6.to(*args, **kwargs) self._c_s = self._c_s.to(*args, **kwargs) self._c_chi = self._c_chi.to(*args, **kwargs) self._c_sqrtk = self._c_sqrtk.to(*args, **kwargs) - self._c_C6 = self._c_C6.to(*args, **kwargs) + self._c_c6 = self._c_c6.to(*args, **kwargs) self.k_Z = _torch.nn.Parameter(self.k_Z.to(*args, **kwargs)) # Check for a device type in args and update the device attribute. @@ -321,11 +321,11 @@ def cuda(self, **kwargs): self._ref_mean_s = self._ref_mean_s.cuda(**kwargs) self._ref_mean_chi = self._ref_mean_chi.cuda(**kwargs) self._ref_mean_sqrtk = self._ref_mean_sqrtk.cuda(**kwargs) - self._ref_mean_C6 = self._ref_mean_C6.cuda(**kwargs) + self._ref_mean_c6 = self._ref_mean_c6.cuda(**kwargs) self._c_s = self._c_s.cuda(**kwargs) self._c_chi = self._c_chi.cuda(**kwargs) self._c_sqrtk = self._c_sqrtk.cuda(**kwargs) - self._c_C6 = self._c_C6.cuda(**kwargs) + self._c_c6 = self._c_c6.cuda(**kwargs) self.k_Z = _torch.nn.Parameter(self.k_Z.cuda(**kwargs)) # Update the device attribute. @@ -346,11 +346,11 @@ def cpu(self, **kwargs): self._ref_mean_s = self._ref_mean_s.cpu(**kwargs) self._ref_mean_chi = self._ref_mean_chi.cpu(**kwargs) self._ref_mean_sqrtk = self._ref_mean_sqrtk.cpu(**kwargs) - self._ref_mean_C6 = self._ref_mean_C6.cpu(**kwargs) + self._ref_mean_c6 = self._ref_mean_c6.cpu(**kwargs) self._c_s = self._c_s.cpu(**kwargs) self._c_chi = self._c_chi.cpu(**kwargs) self._c_sqrtk = self._c_sqrtk.cpu(**kwargs) - self._c_C6 = self._c_C6.cpu(**kwargs) + self._c_c6 = self._c_c6.cpu(**kwargs) self.k_Z = _torch.nn.Parameter(self.k_Z.cpu(**kwargs)) # Update the device attribute. @@ -369,11 +369,11 @@ def double(self): self._ref_mean_s = self._ref_mean_s.double() self._ref_mean_chi = self._ref_mean_chi.double() self._ref_mean_sqrtk = self._ref_mean_sqrtk.double() - self._ref_mean_C6 = self._ref_mean_C6.double() + self._ref_mean_c6 = self._ref_mean_c6.double() self._c_s = self._c_s.double() self._c_chi = self._c_chi.double() self._c_sqrtk = self._c_sqrtk.double() - self._c_C6 = self._c_C6.double() + self._c_c6 = self._c_c6.double() self.k_Z = _torch.nn.Parameter(self.k_Z.double()) return self @@ -388,11 +388,11 @@ def float(self): self._ref_mean_s = self._ref_mean_s.float() self._ref_mean_chi = self._ref_mean_chi.float() self._ref_mean_sqrtk = self._ref_mean_sqrtk.float() - self._ref_mean_C6 = self._ref_mean_C6.float() + self._ref_mean_c6 = self._ref_mean_c6.float() self._c_s = self._c_s.float() self._c_chi = self._c_chi.float() self._c_sqrtk = self._c_sqrtk.float() - self._c_C6 = self._c_C6.float() + self._c_c6 = self._c_c6.float() self.k_Z = _torch.nn.Parameter(self.k_Z.float()) return self @@ -460,11 +460,11 @@ def forward(self, atomic_numbers, xyz_qm, q_total): A_thole = self._get_A_thole(r_data, s, q_val, k, self.a_Thole) if self._lj_mode is not None: - C6 = self._gpr(aev, self._ref_mean_C6, self._c_C6, species_id) + c6 = self._gpr(aev, self._ref_mean_c6, self._c_c6, species_id) else: - C6 = None + c6 = None - return s, q_core, q_val, A_thole, C6 + return s, q_core, q_val, A_thole, c6 @classmethod def _get_Kinv(cls, ref_features, sigma): diff --git a/emle/train/_loss.py b/emle/train/_loss.py index dfcdd1a..ca6fee8 100644 --- a/emle/train/_loss.py +++ b/emle/train/_loss.py @@ -333,7 +333,7 @@ def __init__(self, emle_base, loss=_torch.nn.MSELoss()): self._pol = None - def forward(self, atomic_numbers, xyz, q_mol, C6_target): + def forward(self, atomic_numbers, xyz, q_mol, c6_target): """ Forward pass. @@ -348,14 +348,14 @@ def forward(self, atomic_numbers, xyz, q_mol, C6_target): q_mol: torch.Tensor(N_BATCH, MAX_N_ATOMS) Molecular charges. - C6_target: torch.Tensor(N_BATCH, MAX_N_ATOMS) + c6_target: torch.Tensor(N_BATCH, MAX_N_ATOMS) Target dispersion coefficients. """ # Update reference values for C6. - self._update_C6_gpr(self._emle_base) + self._update_c6_gpr(self._emle_base) # Calculate C6. - s, q_core, q_val, A_thole, C6 = self._emle_base(atomic_numbers, xyz, q_mol) + s, q_core, q_val, A_thole, c6 = self._emle_base(atomic_numbers, xyz, q_mol) # Calculate isotropic polarizabilities if not already calculated. if self._pol is None: self._pol = self._emle_base.calculate_isotropic_polarizabilities( @@ -364,8 +364,8 @@ def forward(self, atomic_numbers, xyz, q_mol, C6_target): # Mask out dummy atoms. mask = atomic_numbers > 0 - target = C6_target[mask] - values = C6 + target = c6_target[mask] + values = c6 values = values[mask] # Calculate loss. @@ -378,9 +378,9 @@ def forward(self, atomic_numbers, xyz, q_mol, C6_target): ) @staticmethod - def _update_C6_gpr(emle_base): - emle_base._ref_mean_C6, emle_base._c_C6 = emle_base._get_c( + def _update_c6_gpr(emle_base): + emle_base._ref_mean_c6, emle_base._c_c6 = emle_base._get_c( emle_base._n_ref, - emle_base.ref_values_C6, + emle_base.ref_values_c6, emle_base._Kinv, ) diff --git a/emle/train/_trainer.py b/emle/train/_trainer.py index 1c95ee6..833de8c 100644 --- a/emle/train/_trainer.py +++ b/emle/train/_trainer.py @@ -301,7 +301,7 @@ def train( q_core, q_val, alpha, - C6=None, + c6=None, train_mask=None, alpha_mode="reference", sigma=1e-3, @@ -310,7 +310,7 @@ def train( lr_qeq=0.05, lr_thole=0.05, lr_sqrtk=0.05, - lr_C6=0.05, + lr_c6=0.05, print_every=10, computer_n_species=None, computer_zid_map=None, @@ -343,7 +343,7 @@ def train( alpha: array or tensor or list of tensor/arrays of shape (N_BATCH, 3, 3) Atomic polarizabilities. - C6: array or tensor or list of tensor/arrays of shape (N_BATCH, N_ATOMS, N_ATOMS) + c6: array or tensor or list of tensor/arrays of shape (N_BATCH, N_ATOMS, N_ATOMS) C6 dispersion coefficients. If None, the C6 dispersion coefficients are not trained. train_mask: torch.Tensor(N_BATCH,) @@ -370,8 +370,8 @@ def train( lr_sqrtk: float Learning rate for sqrtk. - lr_C6: float - Learning rate for C6. + lr_c6: float + Learning rate for c6. print_every: int How often to print training progress. @@ -443,10 +443,10 @@ def train( alpha_train = alpha_train.to(device=device, dtype=dtype) species = species.to(device=device, dtype=_torch.int64) - if C6 is not None: - C6 = _pad_to_max(C6) - C6_train = C6[train_mask] - C6_train = C6_train.to(device=device, dtype=dtype) + if c6 is not None: + c6 = _pad_to_max(c6) + c6_train = c6[train_mask] + c6_train = c6_train.to(device=device, dtype=dtype) # Get zid mapping. zid_mapping = self._get_zid_mapping(species) @@ -525,13 +525,13 @@ def train( if alpha_mode == "reference" else None ), - "ref_values_C6": ( + "ref_values_c6": ( _torch.ones( *ref_values_s.shape, dtype=ref_values_s.dtype, device=_torch.device(device), ) - if C6 is not None + if c6 is not None else None ), } @@ -545,7 +545,7 @@ def train( emle_aev_computer=emle_aev_computer, species=species, alpha_mode=alpha_mode, - lj_mode="static" if C6 is not None else None, + lj_mode="static" if c6 is not None else None, device=_torch.device(device), dtype=dtype, ) @@ -566,7 +566,7 @@ def train( # Update GPR constants for chi # (now inconsistent since not updated after the last epoch) self._qeq_loss._update_chi_gpr(emle_base) - """ + _logger.debug(f"Optimized a_QEq: {emle_base.a_QEq.data.item()}") # Fit a_Thole, k_Z (uses volumes predicted by QEq model). _logger.info("Fitting a_Thole and k_Z values...") @@ -604,24 +604,24 @@ def train( # Update GPR constants for sqrtk # (now inconsistent since not updated after the last epoch) self._thole_loss._update_sqrtk_gpr(emle_base) - """ - if C6 is not None: - _logger.info("Fitting ref_values_C6 values...") + + if c6 is not None: + _logger.info("Fitting ref_values_c6 values...") self._train_model( loss_class=self._dispersion_coefficient_loss, - opt_param_names=["ref_values_C6"], - lr=lr_C6, + opt_param_names=["ref_values_c6"], + lr=lr_c6, epochs=1000, print_every=print_every, emle_base=emle_base, atomic_numbers=z_train, xyz=xyz_train, q_mol=q_mol_train, - C6_target=C6_train, + c6_target=c6_train, ) # Update reference values for C6. - self._dispersion_coefficient_loss._update_C6_gpr(emle_base) + self._dispersion_coefficient_loss._update_c6_gpr(emle_base) # Create the final model. emle_model = { @@ -634,7 +634,7 @@ def train( "sqrtk_ref": ( emle_base.ref_values_sqrtk if alpha_mode == "reference" else None ), - "ref_values_C6": (emle_base.ref_values_C6 if C6 is not None else None), + "c6_ref": (emle_base.ref_values_c6 if c6 is not None else None), "species": species, "alpha_mode": alpha_mode, "n_ref": n_ref, @@ -651,17 +651,12 @@ def train( return emle_base emle_base._alpha_mode = "species" - emle_base_output = emle_base( + s_pred, q_core_pred, q_val_pred, A_thole, c6_pred = emle_base( z.to(device=device, dtype=_torch.int64), xyz.to(device=device, dtype=dtype), q_mol, ) - s_pred = emle_base_output[0] - q_core_pred = emle_base_output[1] - q_val_pred = emle_base_output[2] - A_thole = emle_base_output[3] - z_mask = _torch.tensor(z > 0, device=device) plot_data = { "s_emle": s_pred, @@ -686,9 +681,9 @@ def train( A_thole, z_mask ) - if C6 is not None: - C6_pred = emle_base_output[4] - plot_data["C6_emle"] = C6_pred + if c6 is not None: + plot_data["c6_qm"] = c6 + plot_data["c6_emle"] = c6_pred self._write_model_to_file(plot_data, plot_data_filename) From 7b1c8faf89a14ca4c941a6578718f2fc3891a61c Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Thu, 24 Jul 2025 15:03:16 +0100 Subject: [PATCH 6/8] Small fixes --- emle/models/_emle.py | 20 ++++++++++++++------ emle/models/_emle_base.py | 10 ++++------ emle/train/_loss.py | 2 +- emle/train/_trainer.py | 4 ++-- 4 files changed, 21 insertions(+), 15 deletions(-) diff --git a/emle/models/_emle.py b/emle/models/_emle.py index 4637c51..8d80ad3 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -432,7 +432,7 @@ def __init__( q_core, emle_aev_computer=emle_aev_computer, alpha_mode=self._alpha_mode, - lj_mode=self._lj_method, + lj_mode=self._lj_mode, species=params.get("species", self._species), device=device, dtype=dtype, @@ -509,6 +509,7 @@ def forward( charges_mm: Tensor, xyz_qm: Tensor, xyz_mm: Tensor, + lj_params_mm: Tensor = None, qm_charge: Union[int, Tensor] = 0, ) -> Tensor: """ @@ -532,11 +533,8 @@ def forward( qm_charge: int or torch.Tensor (BATCH,) The charge on the QM region. - sigma_mm: torch.Tensor (N_MM_ATOMS,) or (BATCH, N_MM_ATOMS) - Lennard-Jones sigma parameters for MM atoms in Angstrom. - - epsilon_mm: torch.Tensor (N_MM_ATOMS,) or (BATCH, N_MM_ATOMS) - Lennard-Jones epsilon parameters for MM atoms in kJ/mol. + lj_params_mm: torch.Tensor (N_MM_ATOMS, 2) or (BATCH, N_MM_ATOMS, 2) + Lennard-Jones parameters for MM atoms in nanometers (sigma) and kJ/mol (epsilon). Returns ------- @@ -557,6 +555,9 @@ def forward( self._xyz_qm = self._xyz_qm.unsqueeze(0) self._xyz_mm = self._xyz_mm.unsqueeze(0) + if lj_params_mm is not None: + self._lj_params_mm = lj_params_mm.unsqueeze(0) + batch_size = self._atomic_numbers.shape[0] # Ensure qm_charge is a tensor and repeat for batch size if necessary @@ -619,12 +620,19 @@ def forward( # Compute the LJ energy. if self._lj_mode is not None: + sigma_mm = lj_params_mm[:, :, 0] * 10.0 / _BOHR_TO_ANGSTROM + epsilon_mm = lj_params_mm[:, :, 1] / _HARTREE_TO_KJ_MOL + if self._lj_mode == "flexible": alpha_qm = self._emle_base.get_isotropic_polarizabilities(A_thole) sigma_qm, epsilon_qm = self._emle_base.get_lj_parameters(c6, alpha_qm) elif self._lj_mode == "fixed": sigma_qm = self._lj_sigma_qm.expand(batch_size, -1) epsilon_qm = self._lj_epsilon_qm.expand(batch_size, -1) + import numpy as np + np.savetxt("sigma_qm.txt", sigma_qm.detach().cpu().numpy()) + np.savetxt("epsilon_qm.txt", epsilon_qm.detach().cpu().numpy()) + E_lj = self._emle_base.get_lj_energy( sigma_qm, epsilon_qm, sigma_mm, epsilon_mm, mesh_data ) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index 1c9fa50..e35066b 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -110,7 +110,6 @@ def __init__( dtype: torch.dtype The data type to use for the models floating point tensors. """ - # Call the base class constructor. super().__init__() @@ -119,7 +118,7 @@ def __init__( raise TypeError("'params' must be of type 'dict'") if not all( k in params - for k in ["a_QEq", "a_Thole", "ref_values_s", "ref_values_chi", "k_Z"] + for k in ["a_QEq", "a_Thole", "ref_values_s", "ref_values_chi", "k_Z", "ref_values_c6"] ): raise ValueError( "'params' must contain keys 'a_QEq', 'a_Thole', 'ref_values_s', 'ref_values_chi', and 'k_Z'" @@ -210,12 +209,12 @@ def __init__( raise ValueError(msg) if lj_mode is not None: - assert lj_mode in ["static", "dynamic"], "Invalid Lennard-Jones mode" + assert lj_mode in ["fixed", "flexible"], "Invalid Lennard-Jones mode" try: - self.ref_values_c6 = _torch.nn.Parameter(params["c6_ref"]) + self.ref_values_c6 = _torch.nn.Parameter(params["ref_values_c6"]) except: msg = ( - "Missing 'c6_ref' key in params. This is required when " + "Missing 'ref_values_c6' key in params. This is required when " "using the Lennard-Jones potential." ) raise ValueError(msg) @@ -444,7 +443,6 @@ def forward(self, atomic_numbers, xyz_qm, q_total): xyz_qm_bohr = xyz_qm * ANGSTROM_TO_BOHR r_data = self._get_r_data(xyz_qm_bohr, mask) - q_core = self._q_core[species_id] * mask q = self._get_q(r_data, s, chi, q_total, mask) q_val = q - q_core diff --git a/emle/train/_loss.py b/emle/train/_loss.py index ca6fee8..cf8fdb6 100644 --- a/emle/train/_loss.py +++ b/emle/train/_loss.py @@ -358,7 +358,7 @@ def forward(self, atomic_numbers, xyz, q_mol, c6_target): s, q_core, q_val, A_thole, c6 = self._emle_base(atomic_numbers, xyz, q_mol) # Calculate isotropic polarizabilities if not already calculated. if self._pol is None: - self._pol = self._emle_base.calculate_isotropic_polarizabilities( + self._pol = self._emle_base.get_isotropic_polarizabilities( A_thole ).detach() diff --git a/emle/train/_trainer.py b/emle/train/_trainer.py index 833de8c..314ef6b 100644 --- a/emle/train/_trainer.py +++ b/emle/train/_trainer.py @@ -545,7 +545,7 @@ def train( emle_aev_computer=emle_aev_computer, species=species, alpha_mode=alpha_mode, - lj_mode="static" if c6 is not None else None, + lj_mode="fixed", device=_torch.device(device), dtype=dtype, ) @@ -611,7 +611,7 @@ def train( loss_class=self._dispersion_coefficient_loss, opt_param_names=["ref_values_c6"], lr=lr_c6, - epochs=1000, + epochs=2000, print_every=print_every, emle_base=emle_base, atomic_numbers=z_train, From 69dce173d7147c8e50416993eb8102cc09ccd6ba Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Thu, 24 Jul 2025 17:12:47 +0100 Subject: [PATCH 7/8] Fixed and add LJ to calculator --- emle/calculator.py | 83 ++++++++++++++++++++++++++++++++++++++- emle/models/_emle.py | 38 ++++++++++++------ emle/models/_emle_base.py | 9 ++++- emle/train/_loss.py | 4 +- emle/train/_trainer.py | 4 +- 5 files changed, 119 insertions(+), 19 deletions(-) diff --git a/emle/calculator.py b/emle/calculator.py index 1d57a29..9d79352 100644 --- a/emle/calculator.py +++ b/emle/calculator.py @@ -107,6 +107,10 @@ def __init__( interpolate_steps=None, restart=False, device=None, + lj_mode=None, + lj_params_mm=None, + lj_params_qm=None, + lj_xyz_qm=None, orca_template=None, energy_frequency=0, energy_file="emle_energy.txt", @@ -301,6 +305,26 @@ def __init__( ! Angs NoUseSym *xyzfile 0 1 inpfile.xyz + lj_mode: str + The mode to use for the Lennard-Jones parameters. Options are: + "fixed": + Lennard-Jones parameters are fixed and provided in lj_params_qm. + "flexible": + Lennard-Jones parameters are determined from a configuration. + This requires lj_xyz_qm and atomic_numbers to be provided. + + lj_params_mm: List[List[float]], Tuple[List[List[Float]]], numpy.ndarray, torch.Tensor + Lennard-Jones parameters for each atom in the MM region (sigma, epsilon) in units of nanometers (sigma) + and kJ/mol (epsilon). This is required for both "fixed" and "flexible" modes. + + lj_params_qm: List[List[float]], Tuple[List[List[Float]]], numpy.ndarray, torch.Tensor + Lennard-Jones parameters for each atom in the QM region (sigma, epsilon) in units of nanometers (sigma) + and kJ/mol (epsilon). This is required if the "lj_mode" is "fixed" and lj_xyz_qm is not provided. + Takes precedence over lj_xyz_qm. + + lj_xyz_qm: List[List[float]], Tuple[List[List[Float]]], numpy.ndarray, torch.Tensor + Positions of the atoms in the QM region for which the Lennard-Jones parameters are to be determined. + energy_frequency: int The frequency of logging energies to file. If 0, then no energies are logged. @@ -459,6 +483,47 @@ def __init__( raise TypeError(msg) self._qm_charge = qm_charge + if lj_mode is not None: + if lj_mode.lower().replace(" ", "") not in ["fixed", "flexible"]: + msg = f"Unsupported Lennard-Jones mode: {lj_mode}. Options are: 'fixed', 'flexible'" + _logger.error(msg) + raise ValueError(msg) + + self._lj_mode = lj_mode.lower().replace(" ", "") + + if lj_params_mm is None: + msg = ( + "lj_params_mm must be provided if lj_mode is 'fixed' or 'flexible'" + ) + _logger.error(msg) + raise ValueError(msg) + else: + if not isinstance( + lj_params_mm, (list, tuple, _np.ndarray, _torch.Tensor) + ) or not isinstance( + lj_params_mm[0], (list, tuple, _np.ndarray, _torch.Tensor) + ): + msg = "lj_params_mm must be a list of lists, tuples, or arrays" + _logger.error(msg) + raise TypeError(msg) + self._lj_params_mm = _torch.tensor( + lj_params_mm, dtype=self._device.dtype, device=self._device + ) + + if lambda_interpolate is not None: + if not isinstance( + lj_params_mm, (list, tuple, _np.ndarray, _torch.Tensor) + ) or not isinstance( + lj_params_mm[0], (list, tuple, _np.ndarray, _torch.Tensor) + ): + msg = "When using interpolation, 'lj_params_qm' must be provided and must be a list of lists, tuples, or arrays." + _logger.error(msg) + raise TypeError(msg) + + self._lj_params_qm = _torch.tensor( + lj_params_qm, dtype=self._device.dtype, device=self._device + ) + # Create the EMLE model instance. self._emle = _EMLE( model=model, @@ -468,6 +533,9 @@ def __init__( mm_charges=self._mm_charges, qm_charge=self._qm_charge, device=self._device, + lj_mode=lj_mode, + lj_params_qm=lj_params_qm, + lj_xyz_qm=lj_xyz_qm, ) # Validate the backend(s). @@ -868,6 +936,8 @@ def __init__( method="mm", mm_charges=self._mm_charges, device=self._device, + lj_mode="fixed" if self._lj_mode is not None else None, + lj_params_qm=self._lj_params_qm, ) else: @@ -1131,6 +1201,7 @@ def _calculate_energy_and_gradients( xyz_qm, xyz_mm, atoms=None, + lj_params_mm=None, charge=0, ): """ @@ -1154,6 +1225,9 @@ def _calculate_energy_and_gradients( atoms: ase.Atoms The atoms object for the QM region. + lj_params_mm: numpy.ndarray, (N_MM_ATOMS, 2, 2) + The LJ parameters for the MM atoms. + charge: int The total charge of the QM region. @@ -1286,7 +1360,9 @@ def _calculate_energy_and_gradients( if base_model is None: try: if len(xyz_mm) > 0: - E = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm, charge) + E = self._emle( + atomic_numbers, charges_mm, xyz_qm, xyz_mm, lj_params_mm, charge + ) dE_dxyz_qm, dE_dxyz_mm = _torch.autograd.grad( E.sum(), (xyz_qm, xyz_mm), allow_unused=allow_unused ) @@ -1506,12 +1582,17 @@ def _sire_callback(self, atomic_numbers, charges_mm, xyz_qm, xyz_mm, idx_mm=None _logger.error(msg) raise ValueError(msg) + if self._lj_mode is not None: + # Determine LJ parameters for the QM region. + lj_params_mm = self._lj_params_mm[idx_mm, :, :] + # Compute the energy and gradients. E_vac, grad_vac, E_tot, grad_qm, grad_mm = self._calculate_energy_and_gradients( atomic_numbers, charges_mm, xyz_qm, xyz_mm, + lj_params_mm=lj_params_mm, ) # Store the number of MM atoms. diff --git a/emle/models/_emle.py b/emle/models/_emle.py index 8d80ad3..5328a9a 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -153,7 +153,7 @@ def __init__( Lennard-Jones parameters and interactions are not included. lj_params_qm: List[List[float]], Tuple[List[List[Float]]], numpy.ndarray, torch.Tensor - Lennard-Jones parameters for each atom in the QM region (sigma, epsilon) in units of Angstrom (sigma) + Lennard-Jones parameters for each atom in the QM region (sigma, epsilon) in units of nanometers (sigma) and kJ/mol (epsilon). This is required if the "lj_mode" is "fixed" and lj_param_qm is not provided. Takes precedence over lj_xyz_qm. @@ -281,16 +281,12 @@ def __init__( raise TypeError( "lj_params_qm must be a list of lists, tuples, or arrays" ) - if len(lj_params_qm) != len(atomic_numbers): - raise ValueError( - "lj_params_qm must have the same length as the number of QM atoms" - ) lj_params_qm = _torch.tensor( lj_params_qm, dtype=dtype, device=device ) self._lj_epsilon_qm = lj_params_qm[:, 1] / _HARTREE_TO_KJ_MOL - self._lj_sigma_qm = lj_params_qm[:, 0] / _BOHR_TO_ANGSTROM + self._lj_sigma_qm = lj_params_qm[:, 0] * 10.0 / _BOHR_TO_ANGSTROM lj_xyz_qm = None else: if not isinstance( @@ -301,6 +297,12 @@ def __init__( raise TypeError( "lj_xyz_qm must be a list of lists, tuples, or arrays" ) + + if atomic_numbers is None: + raise ValueError( + "atomic_numbers must be provided if LJ parameters are to be determined from a configuration" + ) + if len(lj_xyz_qm) != len(atomic_numbers): raise ValueError( "lj_xyz_qm must have the same length as the number of QM atoms" @@ -438,15 +440,30 @@ def __init__( dtype=dtype, ) - if lj_xyz_qm: + if lj_xyz_qm is not None: + raise NotImplementedError( + "LJ parameters for fixed mode with a configuration is not implemented yet" + ) + """ + atomic_numbers = _torch.as_tensor(atomic_numbers) + lj_xyz_qm = _torch.as_tensor(lj_xyz_qm) + qm_charge = _torch.as_tensor(qm_charge) + if atomic_numbers.ndim == 1: + atomic_numbers = atomic_numbers.unsqueeze(0) + if lj_xyz_qm.ndim == 2: + lj_xyz_qm = lj_xyz_qm.unsqueeze(0) + if qm_charge.ndim == 0: + qm_charge = qm_charge.unsqueeze(0) + # Get the LJ parameters for the passed configuration _, _, _, A_thole, c6 = self._emle_base( - self.atomic_numbers, lj_xyz_qm, qm_charge + atomic_numbers, lj_xyz_qm, qm_charge ) alpha_qm = self._emle_base.get_isotropic_polarizabilities(A_thole) sigma_qm, epsilon_qm = self._emle_base.get_lj_parameters(c6, alpha_qm) self._lj_sigma_qm = sigma_qm self._lj_epsilon_qm = epsilon_qm + """ def to(self, *args, **kwargs): """ @@ -620,7 +637,7 @@ def forward( # Compute the LJ energy. if self._lj_mode is not None: - sigma_mm = lj_params_mm[:, :, 0] * 10.0 / _BOHR_TO_ANGSTROM + sigma_mm = lj_params_mm[:, :, 0] * 10.0 / _BOHR_TO_ANGSTROM epsilon_mm = lj_params_mm[:, :, 1] / _HARTREE_TO_KJ_MOL if self._lj_mode == "flexible": @@ -629,9 +646,6 @@ def forward( elif self._lj_mode == "fixed": sigma_qm = self._lj_sigma_qm.expand(batch_size, -1) epsilon_qm = self._lj_epsilon_qm.expand(batch_size, -1) - import numpy as np - np.savetxt("sigma_qm.txt", sigma_qm.detach().cpu().numpy()) - np.savetxt("epsilon_qm.txt", epsilon_qm.detach().cpu().numpy()) E_lj = self._emle_base.get_lj_energy( sigma_qm, epsilon_qm, sigma_mm, epsilon_mm, mesh_data diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index e35066b..a8dbcd1 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -118,7 +118,14 @@ def __init__( raise TypeError("'params' must be of type 'dict'") if not all( k in params - for k in ["a_QEq", "a_Thole", "ref_values_s", "ref_values_chi", "k_Z", "ref_values_c6"] + for k in [ + "a_QEq", + "a_Thole", + "ref_values_s", + "ref_values_chi", + "k_Z", + "ref_values_c6", + ] ): raise ValueError( "'params' must contain keys 'a_QEq', 'a_Thole', 'ref_values_s', 'ref_values_chi', and 'k_Z'" diff --git a/emle/train/_loss.py b/emle/train/_loss.py index cf8fdb6..8248d3c 100644 --- a/emle/train/_loss.py +++ b/emle/train/_loss.py @@ -358,9 +358,7 @@ def forward(self, atomic_numbers, xyz, q_mol, c6_target): s, q_core, q_val, A_thole, c6 = self._emle_base(atomic_numbers, xyz, q_mol) # Calculate isotropic polarizabilities if not already calculated. if self._pol is None: - self._pol = self._emle_base.get_isotropic_polarizabilities( - A_thole - ).detach() + self._pol = self._emle_base.get_isotropic_polarizabilities(A_thole).detach() # Mask out dummy atoms. mask = atomic_numbers > 0 diff --git a/emle/train/_trainer.py b/emle/train/_trainer.py index 314ef6b..4b8f174 100644 --- a/emle/train/_trainer.py +++ b/emle/train/_trainer.py @@ -566,7 +566,7 @@ def train( # Update GPR constants for chi # (now inconsistent since not updated after the last epoch) self._qeq_loss._update_chi_gpr(emle_base) - + _logger.debug(f"Optimized a_QEq: {emle_base.a_QEq.data.item()}") # Fit a_Thole, k_Z (uses volumes predicted by QEq model). _logger.info("Fitting a_Thole and k_Z values...") @@ -604,7 +604,7 @@ def train( # Update GPR constants for sqrtk # (now inconsistent since not updated after the last epoch) self._thole_loss._update_sqrtk_gpr(emle_base) - + if c6 is not None: _logger.info("Fitting ref_values_c6 values...") self._train_model( From 7112c9a786d909926dd3761f03e448da050f3c40 Mon Sep 17 00:00:00 2001 From: Joao Morado Date: Fri, 25 Jul 2025 09:02:03 +0100 Subject: [PATCH 8/8] Add fixed mode with initial configuration --- emle/models/_emle.py | 76 ++++++++++++++++++++------------------------ 1 file changed, 34 insertions(+), 42 deletions(-) diff --git a/emle/models/_emle.py b/emle/models/_emle.py index 5328a9a..778d8c1 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -287,7 +287,7 @@ def __init__( ) self._lj_epsilon_qm = lj_params_qm[:, 1] / _HARTREE_TO_KJ_MOL self._lj_sigma_qm = lj_params_qm[:, 0] * 10.0 / _BOHR_TO_ANGSTROM - lj_xyz_qm = None + self._lj_xyz_qm = None else: if not isinstance( lj_xyz_qm, (list, tuple, _np.ndarray, _torch.Tensor) @@ -297,18 +297,11 @@ def __init__( raise TypeError( "lj_xyz_qm must be a list of lists, tuples, or arrays" ) - - if atomic_numbers is None: - raise ValueError( - "atomic_numbers must be provided if LJ parameters are to be determined from a configuration" - ) - - if len(lj_xyz_qm) != len(atomic_numbers): - raise ValueError( - "lj_xyz_qm must have the same length as the number of QM atoms" - ) - - lj_xyz_qm = _torch.tensor(lj_xyz_qm, dtype=dtype, device=device) + self._lj_epsilon_qm = None + self._lj_sigma_qm = None + self._lj_xyz_qm = _torch.tensor( + lj_xyz_qm, dtype=dtype, device=device + ) self._lj_mode = lj_mode @@ -441,29 +434,7 @@ def __init__( ) if lj_xyz_qm is not None: - raise NotImplementedError( - "LJ parameters for fixed mode with a configuration is not implemented yet" - ) - """ - atomic_numbers = _torch.as_tensor(atomic_numbers) - lj_xyz_qm = _torch.as_tensor(lj_xyz_qm) - qm_charge = _torch.as_tensor(qm_charge) - if atomic_numbers.ndim == 1: - atomic_numbers = atomic_numbers.unsqueeze(0) - if lj_xyz_qm.ndim == 2: - lj_xyz_qm = lj_xyz_qm.unsqueeze(0) - if qm_charge.ndim == 0: - qm_charge = qm_charge.unsqueeze(0) - - # Get the LJ parameters for the passed configuration - _, _, _, A_thole, c6 = self._emle_base( - atomic_numbers, lj_xyz_qm, qm_charge - ) - alpha_qm = self._emle_base.get_isotropic_polarizabilities(A_thole) - sigma_qm, epsilon_qm = self._emle_base.get_lj_parameters(c6, alpha_qm) - self._lj_sigma_qm = sigma_qm - self._lj_epsilon_qm = epsilon_qm - """ + pass def to(self, *args, **kwargs): """ @@ -635,24 +606,45 @@ def forward( E_static, dtype=self._charges_mm.dtype, device=self._device ) - # Compute the LJ energy. - if self._lj_mode is not None: + # Compute the LJ energy + if self._lj_mode is None: + E_lj = _torch.zeros_like( + E_static, dtype=self._charges_mm.dtype, device=self._device + ) + else: + # Convert MM LJ parameters sigma_mm = lj_params_mm[:, :, 0] * 10.0 / _BOHR_TO_ANGSTROM epsilon_mm = lj_params_mm[:, :, 1] / _HARTREE_TO_KJ_MOL if self._lj_mode == "flexible": alpha_qm = self._emle_base.get_isotropic_polarizabilities(A_thole) sigma_qm, epsilon_qm = self._emle_base.get_lj_parameters(c6, alpha_qm) + elif self._lj_mode == "fixed": + if self._lj_sigma_qm is None or self._lj_epsilon_qm is None: + if self._lj_xyz_qm is None: + raise RuntimeError( + "LJ mode is 'fixed', but LJ parameters are not set and lj_xyz_qm is missing." + ) + lj_xyz_qm = ( + self._lj_xyz_qm.unsqueeze(0) + if self._lj_xyz_qm.ndim == 2 + else self._lj_xyz_qm + ) + _, _, _, A_thole, c6 = self._emle_base( + self._atomic_numbers[0:1, :], lj_xyz_qm, qm_charge[0:1] + ) + alpha_qm = self._emle_base.get_isotropic_polarizabilities(A_thole) + self._lj_sigma_qm, self._lj_epsilon_qm = ( + self._emle_base.get_lj_parameters(c6, alpha_qm) + ) + sigma_qm = self._lj_sigma_qm.expand(batch_size, -1) epsilon_qm = self._lj_epsilon_qm.expand(batch_size, -1) + # Compute Lennard-Jones energy E_lj = self._emle_base.get_lj_energy( sigma_qm, epsilon_qm, sigma_mm, epsilon_mm, mesh_data ) - else: - E_lj = _torch.zeros_like( - E_static, dtype=self._charges_mm.dtype, device=self._device - ) return _torch.stack((E_static, E_ind, E_lj), dim=0)