diff --git a/emle/_analyzer.py b/emle/_analyzer.py index b8c71d2..5de9c45 100644 --- a/emle/_analyzer.py +++ b/emle/_analyzer.py @@ -49,7 +49,7 @@ def __init__( parser=None, q_total=None, start=None, - end=None + end=None, ): """ Constructor. @@ -168,7 +168,7 @@ def __init__( self.qm_xyz, self.q_total, ) - self.atomic_alpha = 1. / _torch.diagonal(self.A_thole, dim1=1, dim2=2)[:, ::3] + self.atomic_alpha = 1.0 / _torch.diagonal(self.A_thole, dim1=1, dim2=2)[:, ::3] self.alpha = self._get_mol_alpha(self.A_thole, self.atomic_numbers) mask = (self.atomic_numbers > 0).unsqueeze(-1) diff --git a/emle/_orca_parser.py b/emle/_orca_parser.py index 94bd554..b3d6726 100644 --- a/emle/_orca_parser.py +++ b/emle/_orca_parser.py @@ -125,7 +125,6 @@ def __init__(self, filename, decompose=False, alpha=False): try: with _tarfile.open(filename, "r") as tar: - self._tar = tar self.names = self._get_names(tar) diff --git a/emle/calculator.py b/emle/calculator.py index 6491624..9d79352 100644 --- a/emle/calculator.py +++ b/emle/calculator.py @@ -107,6 +107,10 @@ def __init__( interpolate_steps=None, restart=False, device=None, + lj_mode=None, + lj_params_mm=None, + lj_params_qm=None, + lj_xyz_qm=None, orca_template=None, energy_frequency=0, energy_file="emle_energy.txt", @@ -301,6 +305,26 @@ def __init__( ! Angs NoUseSym *xyzfile 0 1 inpfile.xyz + lj_mode: str + The mode to use for the Lennard-Jones parameters. Options are: + "fixed": + Lennard-Jones parameters are fixed and provided in lj_params_qm. + "flexible": + Lennard-Jones parameters are determined from a configuration. + This requires lj_xyz_qm and atomic_numbers to be provided. + + lj_params_mm: List[List[float]], Tuple[List[List[Float]]], numpy.ndarray, torch.Tensor + Lennard-Jones parameters for each atom in the MM region (sigma, epsilon) in units of nanometers (sigma) + and kJ/mol (epsilon). This is required for both "fixed" and "flexible" modes. 
+ + lj_params_qm: List[List[float]], Tuple[List[List[Float]]], numpy.ndarray, torch.Tensor + Lennard-Jones parameters for each atom in the QM region (sigma, epsilon) in units of nanometers (sigma) + and kJ/mol (epsilon). This is required if the "lj_mode" is "fixed" and lj_xyz_qm is not provided. + Takes precedence over lj_xyz_qm. + + lj_xyz_qm: List[List[float]], Tuple[List[List[Float]]], numpy.ndarray, torch.Tensor + Positions of the atoms in the QM region for which the Lennard-Jones parameters are to be determined. + energy_frequency: int The frequency of logging energies to file. If 0, then no energies are logged. @@ -459,6 +483,47 @@ raise TypeError(msg) self._qm_charge = qm_charge + if lj_mode is not None: + if lj_mode.lower().replace(" ", "") not in ["fixed", "flexible"]: + msg = f"Unsupported Lennard-Jones mode: {lj_mode}. Options are: 'fixed', 'flexible'" + _logger.error(msg) + raise ValueError(msg) + + self._lj_mode = lj_mode.lower().replace(" ", "") + + if lj_params_mm is None: + msg = ( + "lj_params_mm must be provided if lj_mode is 'fixed' or 'flexible'" + ) + _logger.error(msg) + raise ValueError(msg) + else: + if not isinstance( + lj_params_mm, (list, tuple, _np.ndarray, _torch.Tensor) + ) or not isinstance( + lj_params_mm[0], (list, tuple, _np.ndarray, _torch.Tensor) + ): + msg = "lj_params_mm must be a list of lists, tuples, or arrays" + _logger.error(msg) + raise TypeError(msg) + self._lj_params_mm = _torch.tensor( + lj_params_mm, dtype=self._device.dtype, device=self._device + ) + + if lambda_interpolate is not None: + if not isinstance( + lj_params_qm, (list, tuple, _np.ndarray, _torch.Tensor) + ) or not isinstance( + lj_params_qm[0], (list, tuple, _np.ndarray, _torch.Tensor) + ): + msg = "When using interpolation, 'lj_params_qm' must be provided and must be a list of lists, tuples, or arrays."
+ _logger.error(msg) + raise TypeError(msg) + + self._lj_params_qm = _torch.tensor( + lj_params_qm, dtype=self._device.dtype, device=self._device + ) + # Create the EMLE model instance. self._emle = _EMLE( model=model, @@ -468,6 +533,9 @@ def __init__( mm_charges=self._mm_charges, qm_charge=self._qm_charge, device=self._device, + lj_mode=lj_mode, + lj_params_qm=lj_params_qm, + lj_xyz_qm=lj_xyz_qm, ) # Validate the backend(s). @@ -868,6 +936,8 @@ def __init__( method="mm", mm_charges=self._mm_charges, device=self._device, + lj_mode="fixed" if self._lj_mode is not None else None, + lj_params_qm=self._lj_params_qm, ) else: @@ -1131,6 +1201,7 @@ def _calculate_energy_and_gradients( xyz_qm, xyz_mm, atoms=None, + lj_params_mm=None, charge=0, ): """ @@ -1154,6 +1225,9 @@ def _calculate_energy_and_gradients( atoms: ase.Atoms The atoms object for the QM region. + lj_params_mm: numpy.ndarray, (N_MM_ATOMS, 2, 2) + The LJ parameters for the MM atoms. + charge: int The total charge of the QM region. @@ -1286,7 +1360,9 @@ def _calculate_energy_and_gradients( if base_model is None: try: if len(xyz_mm) > 0: - E = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm, charge) + E = self._emle( + atomic_numbers, charges_mm, xyz_qm, xyz_mm, lj_params_mm, charge + ) dE_dxyz_qm, dE_dxyz_mm = _torch.autograd.grad( E.sum(), (xyz_qm, xyz_mm), allow_unused=allow_unused ) @@ -1327,12 +1403,15 @@ def _calculate_energy_and_gradients( _logger.error(msg) raise RuntimeError(msg) - if (self._qbc_deviation): + if self._qbc_deviation: E_std = _torch.std(base_model._E_vac_qbc).item() max_f_std = _torch.max(_torch.std(base_model._grads_qbc, axis=0)).item() with open(self._qbc_deviation, "a") as f: f.write(f"{E_std:12.5f}{max_f_std:12.5f}\n") - if self._qbc_deviation_threshold and max_f_std > self._qbc_deviation_threshold: + if ( + self._qbc_deviation_threshold + and max_f_std > self._qbc_deviation_threshold + ): msg = "Force deviation threshold reached!" 
+ raise ValueError(msg) @@ -1371,7 +1450,7 @@ else: offset = int(not self._restart) lam = self._lambda_interpolate[0] + ( - (self._step / (self._interpolate_steps - offset)) + self._step / (self._interpolate_steps - offset) ) * (self._lambda_interpolate[1] - self._lambda_interpolate[0]) if lam < 0.0: lam = 0.0 @@ -1503,12 +1582,19 @@ _logger.error(msg) raise ValueError(msg) + if self._lj_mode is not None: + # Determine LJ parameters for the MM region. + lj_params_mm = self._lj_params_mm[idx_mm, :, :] + else: + lj_params_mm = None + # Compute the energy and gradients. E_vac, grad_vac, E_tot, grad_qm, grad_mm = self._calculate_energy_and_gradients( atomic_numbers, charges_mm, xyz_qm, xyz_mm, + lj_params_mm=lj_params_mm, ) # Store the number of MM atoms. diff --git a/emle/models/_ani.py index f21ff16..c922c25 100644 --- a/emle/models/_ani.py +++ b/emle/models/_ani.py @@ -375,7 +375,7 @@ def forward( ------- result: torch.Tensor (3,) or (3, BATCH) - The ANI2x and static and induced EMLE energy components in Hartree. + The ANI2x and static, induced and LJ EMLE energy components in Hartree. """ # Batch the inputs if necessary. if atomic_numbers.ndim == 1: @@ -407,4 +407,4 @@ - E_emle = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm, qm_charge) + E_emle = self._emle(atomic_numbers, charges_mm, xyz_qm, xyz_mm, qm_charge=qm_charge) # Return the ANI2x and EMLE energy components. - return _torch.stack((E_vac, E_emle[0], E_emle[1])) + return _torch.stack((E_vac, E_emle[0], E_emle[1], E_emle[2])) diff --git a/emle/models/_emle.py index 2d6b29c..778d8c1 100644 --- a/emle/models/_emle.py +++ b/emle/models/_emle.py @@ -27,17 +27,18 @@ __all__ = ["EMLE"] -import numpy as _np import os as _os +from typing import Union + +import numpy as _np import scipy.io as _scipy_io import torch as _torch import torchani as _torchani - from torch import Tensor -from typing import Union -from . import _patches from .
import EMLEBase as _EMLEBase +from . import _patches +from .._units import _HARTREE_TO_KJ_MOL, _BOHR_TO_ANGSTROM # Monkey-patch the TorchANI BuiltInModel and BuiltinEnsemble classes so that # they call self.aev_computer using args only to allow forward hooks to work @@ -86,6 +87,9 @@ def __init__( atomic_numbers=None, qm_charge=0, mm_charges=None, + lj_mode=None, + lj_params_qm=None, + lj_xyz_qm=None, device=None, dtype=None, create_aev_calculator=True, @@ -137,6 +141,26 @@ def __init__( List of MM charges for atoms in the QM region in units of mod electron charge. This is required if the 'mm' method is specified. + lj_mode: str + How the LJ parameters are calculated. + "flexible": + Lennard-Jones parameters are calculated dynamically for a given configuration. + "fixed": + Lennard-Jones parameters are fixed, i.e. independent of the configuration. + Requires specifying the LJ parameters for each atom in the QM region or to + provide an initial configuration. + None + Lennard-Jones parameters and interactions are not included. + + lj_params_qm: List[List[float]], Tuple[List[List[Float]]], numpy.ndarray, torch.Tensor + Lennard-Jones parameters for each atom in the QM region (sigma, epsilon) in units of nanometers (sigma) + and kJ/mol (epsilon). This is required if the "lj_mode" is "fixed" and lj_param_qm is not provided. + Takes precedence over lj_xyz_qm. + + lj_xyz_qm: List[List[float]], Tuple[List[List[Float]]], numpy.ndarray, torch.Tensor + Positions of QM atoms in Angstrom. This is required if the "lj_mode" is "fixed" + and lj_param_qm is not provided. + device: torch.device The device on which to run the model. @@ -234,6 +258,53 @@ def __init__( # Use the default species. 
species = self._species + if lj_mode is not None: + if not isinstance(lj_mode, str): + raise TypeError("'lj_mode' must be of type 'str'") + + lj_mode = lj_mode.lower().replace(" ", "") + if lj_mode not in {"flexible", "fixed"}: + raise ValueError("'lj_mode' must be 'flexible' or 'fixed'") + + if lj_mode == "fixed": + if lj_params_qm is None and lj_xyz_qm is None: + raise ValueError( + "lj_params_qm or lj_xyz_qm must be provided if lj_mode is 'fixed'" + ) + + if lj_params_qm is not None: + if not isinstance( + lj_params_qm, (list, tuple, _np.ndarray, _torch.Tensor) + ) or not isinstance( + lj_params_qm[0], (list, tuple, _np.ndarray, _torch.Tensor) + ): + raise TypeError( + "lj_params_qm must be a list of lists, tuples, or arrays" + ) + + lj_params_qm = _torch.tensor( + lj_params_qm, dtype=dtype, device=device + ) + self._lj_epsilon_qm = lj_params_qm[:, 1] / _HARTREE_TO_KJ_MOL + self._lj_sigma_qm = lj_params_qm[:, 0] * 10.0 / _BOHR_TO_ANGSTROM + self._lj_xyz_qm = None + else: + if not isinstance( + lj_xyz_qm, (list, tuple, _np.ndarray, _torch.Tensor) + ) or not isinstance( + lj_xyz_qm[0], (list, tuple, _np.ndarray, _torch.Tensor) + ): + raise TypeError( + "lj_xyz_qm must be a list of lists, tuples, or arrays" + ) + self._lj_epsilon_qm = None + self._lj_sigma_qm = None + self._lj_xyz_qm = _torch.tensor( + lj_xyz_qm, dtype=dtype, device=device + ) + + self._lj_mode = lj_mode + if device is not None: if not isinstance(device, _torch.device): raise TypeError("'device' must be of type 'torch.device'") @@ -276,6 +347,11 @@ def __init__( if "sqrtk_ref" in params else None ), + "ref_values_c6": ( + _torch.tensor(params["c6_ref"], dtype=dtype, device=device) + if "c6_ref" in params + else None + ), } if method == "mm": @@ -351,11 +427,15 @@ def __init__( q_core, emle_aev_computer=emle_aev_computer, alpha_mode=self._alpha_mode, + lj_mode=self._lj_mode, species=params.get("species", self._species), device=device, dtype=dtype, ) + if lj_xyz_qm is not None: + pass + def to(self, 
*args, **kwargs): """ Performs Tensor dtype and/or device conversion on the model. @@ -417,6 +497,7 @@ def forward( charges_mm: Tensor, xyz_qm: Tensor, xyz_mm: Tensor, + lj_params_mm: Tensor = None, qm_charge: Union[int, Tensor] = 0, ) -> Tensor: """ @@ -440,11 +521,14 @@ def forward( qm_charge: int or torch.Tensor (BATCH,) The charge on the QM region. + lj_params_mm: torch.Tensor (N_MM_ATOMS, 2) or (BATCH, N_MM_ATOMS, 2) + Lennard-Jones parameters for MM atoms in nanometers (sigma) and kJ/mol (epsilon). + Returns ------- - result: torch.Tensor (2,) or (2, BATCH) - The static and induced EMLE energy components in Hartree. + result: torch.Tensor (3,) or (3, BATCH) + The static, induced, and LJ EMLE energy components in Hartree. """ # Store the inputs as internal attributes. self._atomic_numbers = atomic_numbers @@ -459,6 +543,9 @@ def forward( self._xyz_qm = self._xyz_qm.unsqueeze(0) self._xyz_mm = self._xyz_mm.unsqueeze(0) + if lj_params_mm is not None: + self._lj_params_mm = lj_params_mm.unsqueeze(0) + batch_size = self._atomic_numbers.shape[0] # Ensure qm_charge is a tensor and repeat for batch size if necessary @@ -479,20 +566,16 @@ def forward( 2, batch_size, dtype=self._xyz_qm.dtype, device=self._xyz_qm.device ) - # Get the parameters from the base model: - # valence widths, core charges, valence charges, A_thole tensor - # These are returned as batched tensors, so we need to extract the - # first element of each. - s, q_core, q_val, A_thole = self._emle_base( + # Get the parameters from the base model. + s, q_core, q_val, A_thole, c6 = self._emle_base( self._atomic_numbers, self._xyz_qm, qm_charge, ) # Convert coordinates to Bohr. - ANGSTROM_TO_BOHR = 1.8897261258369282 - xyz_qm_bohr = self._xyz_qm * ANGSTROM_TO_BOHR - xyz_mm_bohr = self._xyz_mm * ANGSTROM_TO_BOHR + xyz_qm_bohr = self._xyz_qm / _BOHR_TO_ANGSTROM + xyz_mm_bohr = self._xyz_mm / _BOHR_TO_ANGSTROM # Compute the static energy. 
if self._method == "mm": @@ -523,4 +606,45 @@ def forward( E_static, dtype=self._charges_mm.dtype, device=self._device ) - return _torch.stack((E_static, E_ind), dim=0) + # Compute the LJ energy + if self._lj_mode is None: + E_lj = _torch.zeros_like( + E_static, dtype=self._charges_mm.dtype, device=self._device + ) + else: + # Convert MM LJ parameters + sigma_mm = lj_params_mm[:, :, 0] * 10.0 / _BOHR_TO_ANGSTROM + epsilon_mm = lj_params_mm[:, :, 1] / _HARTREE_TO_KJ_MOL + + if self._lj_mode == "flexible": + alpha_qm = self._emle_base.get_isotropic_polarizabilities(A_thole) + sigma_qm, epsilon_qm = self._emle_base.get_lj_parameters(c6, alpha_qm) + + elif self._lj_mode == "fixed": + if self._lj_sigma_qm is None or self._lj_epsilon_qm is None: + if self._lj_xyz_qm is None: + raise RuntimeError( + "LJ mode is 'fixed', but LJ parameters are not set and lj_xyz_qm is missing." + ) + lj_xyz_qm = ( + self._lj_xyz_qm.unsqueeze(0) + if self._lj_xyz_qm.ndim == 2 + else self._lj_xyz_qm + ) + _, _, _, A_thole, c6 = self._emle_base( + self._atomic_numbers[0:1, :], lj_xyz_qm, qm_charge[0:1] + ) + alpha_qm = self._emle_base.get_isotropic_polarizabilities(A_thole) + self._lj_sigma_qm, self._lj_epsilon_qm = ( + self._emle_base.get_lj_parameters(c6, alpha_qm) + ) + + sigma_qm = self._lj_sigma_qm.expand(batch_size, -1) + epsilon_qm = self._lj_epsilon_qm.expand(batch_size, -1) + + # Compute Lennard-Jones energy + E_lj = self._emle_base.get_lj_energy( + sigma_qm, epsilon_qm, sigma_mm, epsilon_mm, mesh_data + ) + + return _torch.stack((E_static, E_ind, E_lj), dim=0) diff --git a/emle/models/_emle_base.py b/emle/models/_emle_base.py index d374775..a8dbcd1 100644 --- a/emle/models/_emle_base.py +++ b/emle/models/_emle_base.py @@ -27,14 +27,12 @@ __all__ = ["EMLEBase"] -import numpy as _np - -import torch as _torch - -from torch import Tensor from typing import Tuple +import numpy as _np +import torch as _torch import torchani as _torchani +from torch import Tensor try: import NNPOps as 
_NNPOps @@ -66,6 +64,7 @@ emle_aev_computer=None, species=None, alpha_mode="species", + lj_mode=None, device=None, dtype=None, ): @@ -95,6 +94,10 @@ scaling factors are obtained with GPR using the values learned for each reference environment + lj_mode: str + Mode for calculating the Lennard-Jones potential. + If None, the Lennard-Jones potential is not calculated. + emle_aev_computer: EMLEAEVComputer EMLE AEV computer instance used to compute AEVs (masked and normalized). @@ -107,7 +110,6 @@ dtype: torch.dtype The data type to use for the models floating point tensors. """ - # Call the base class constructor. super().__init__() @@ -116,7 +118,14 @@ raise TypeError("'params' must be of type 'dict'") if not all( k in params - for k in ["a_QEq", "a_Thole", "ref_values_s", "ref_values_chi", "k_Z"] + for k in [ + "a_QEq", + "a_Thole", + "ref_values_s", + "ref_values_chi", + "k_Z", + "ref_values_c6", + ] ): raise ValueError( "'params' must contain keys 'a_QEq', 'a_Thole', 'ref_values_s', 'ref_values_chi', and 'k_Z'" @@ -206,6 +215,18 @@ ) raise ValueError(msg) + if lj_mode is not None: + assert lj_mode in ["fixed", "flexible"], "Invalid Lennard-Jones mode" + try: + self.ref_values_c6 = _torch.nn.Parameter(params["ref_values_c6"]) + except KeyError: + msg = ( + "Missing 'ref_values_c6' key in params. This is required when " + "using the Lennard-Jones potential." + ) + raise ValueError(msg) + self._lj_mode = lj_mode + # Validate the species. if species is None: # Use the default species. @@ -241,6 +262,12 @@ else: ref_mean_sqrtk, c_sqrtk = self._get_c(n_ref, self.ref_values_sqrtk, Kinv) + if lj_mode is not None: + ref_mean_c6, c_c6 = self._get_c(n_ref, self.ref_values_c6, Kinv) + else: + ref_mean_c6 = _torch.zeros_like(ref_mean_s, dtype=dtype, device=device) + c_c6 = _torch.zeros_like(c_s, dtype=dtype, device=device) + # Store the current device.
self._device = device @@ -253,9 +280,11 @@ def __init__( self.register_buffer("_ref_mean_s", ref_mean_s) self.register_buffer("_ref_mean_chi", ref_mean_chi) self.register_buffer("_ref_mean_sqrtk", ref_mean_sqrtk) + self.register_buffer("_ref_mean_c6", ref_mean_c6) self.register_buffer("_c_s", c_s) self.register_buffer("_c_chi", c_chi) self.register_buffer("_c_sqrtk", c_sqrtk) + self.register_buffer("_c_c6", c_c6) def to(self, *args, **kwargs): """ @@ -270,9 +299,11 @@ def to(self, *args, **kwargs): self._ref_mean_s = self._ref_mean_s.to(*args, **kwargs) self._ref_mean_chi = self._ref_mean_chi.to(*args, **kwargs) self._ref_mean_sqrtk = self._ref_mean_sqrtk.to(*args, **kwargs) + self._ref_mean_c6 = self._ref_mean_c6.to(*args, **kwargs) self._c_s = self._c_s.to(*args, **kwargs) self._c_chi = self._c_chi.to(*args, **kwargs) self._c_sqrtk = self._c_sqrtk.to(*args, **kwargs) + self._c_c6 = self._c_c6.to(*args, **kwargs) self.k_Z = _torch.nn.Parameter(self.k_Z.to(*args, **kwargs)) # Check for a device type in args and update the device attribute. @@ -296,9 +327,11 @@ def cuda(self, **kwargs): self._ref_mean_s = self._ref_mean_s.cuda(**kwargs) self._ref_mean_chi = self._ref_mean_chi.cuda(**kwargs) self._ref_mean_sqrtk = self._ref_mean_sqrtk.cuda(**kwargs) + self._ref_mean_c6 = self._ref_mean_c6.cuda(**kwargs) self._c_s = self._c_s.cuda(**kwargs) self._c_chi = self._c_chi.cuda(**kwargs) self._c_sqrtk = self._c_sqrtk.cuda(**kwargs) + self._c_c6 = self._c_c6.cuda(**kwargs) self.k_Z = _torch.nn.Parameter(self.k_Z.cuda(**kwargs)) # Update the device attribute. 
@@ -318,10 +351,12 @@ def cpu(self, **kwargs): self._n_ref = self._n_ref.cpu(**kwargs) self._ref_mean_s = self._ref_mean_s.cpu(**kwargs) self._ref_mean_chi = self._ref_mean_chi.cpu(**kwargs) - self._ref_mean_sqrtk = self._ref_mean_sqrtk.to(**kwargs) + self._ref_mean_sqrtk = self._ref_mean_sqrtk.cpu(**kwargs) + self._ref_mean_c6 = self._ref_mean_c6.cpu(**kwargs) self._c_s = self._c_s.cpu(**kwargs) self._c_chi = self._c_chi.cpu(**kwargs) self._c_sqrtk = self._c_sqrtk.cpu(**kwargs) + self._c_c6 = self._c_c6.cpu(**kwargs) self.k_Z = _torch.nn.Parameter(self.k_Z.cpu(**kwargs)) # Update the device attribute. @@ -340,9 +375,11 @@ def double(self): self._ref_mean_s = self._ref_mean_s.double() self._ref_mean_chi = self._ref_mean_chi.double() self._ref_mean_sqrtk = self._ref_mean_sqrtk.double() + self._ref_mean_c6 = self._ref_mean_c6.double() self._c_s = self._c_s.double() self._c_chi = self._c_chi.double() self._c_sqrtk = self._c_sqrtk.double() + self._c_c6 = self._c_c6.double() self.k_Z = _torch.nn.Parameter(self.k_Z.double()) return self @@ -357,9 +394,11 @@ def float(self): self._ref_mean_s = self._ref_mean_s.float() self._ref_mean_chi = self._ref_mean_chi.float() self._ref_mean_sqrtk = self._ref_mean_sqrtk.float() + self._ref_mean_c6 = self._ref_mean_c6.float() self._c_s = self._c_s.float() self._c_chi = self._c_chi.float() self._c_sqrtk = self._c_sqrtk.float() + self._c_c6 = self._c_c6.float() self.k_Z = _torch.nn.Parameter(self.k_Z.float()) return self @@ -386,8 +425,9 @@ def forward(self, atomic_numbers, xyz_qm, q_total): result: (torch.Tensor (N_BATCH, N_QM_ATOMS,), torch.Tensor (N_BATCH, N_QM_ATOMS,), torch.Tensor (N_BATCH, N_QM_ATOMS,), - torch.Tensor (N_BATCH, N_QM_ATOMS * 3, N_QM_ATOMS * 3,)) - Valence widths, core charges, valence charges, A_thole tensor + torch.Tensor (N_BATCH, N_QM_ATOMS * 3, N_QM_ATOMS * 3,), + torch.Tensor (N_BATCH, N_QM_ATOMS,)) + Valence widths, core charges, valence charges, A_thole tensor, C6 coefficients """ # Mask for padded 
coordinates. @@ -410,7 +450,6 @@ def forward(self, atomic_numbers, xyz_qm, q_total): xyz_qm_bohr = xyz_qm * ANGSTROM_TO_BOHR r_data = self._get_r_data(xyz_qm_bohr, mask) - q_core = self._q_core[species_id] * mask q = self._get_q(r_data, s, chi, q_total, mask) q_val = q - q_core @@ -425,7 +464,12 @@ A_thole = self._get_A_thole(r_data, s, q_val, k, self.a_Thole) - return s, q_core, q_val, A_thole + if self._lj_mode is not None: + c6 = self._gpr(aev, self._ref_mean_c6, self._c_c6, species_id) + else: + c6 = None + + return s, q_core, q_val, A_thole, c6 @classmethod def _get_Kinv(cls, ref_features, sigma): @@ -1021,6 +1065,6 @@ def _get_f1_slater(r: Tensor, s: Tensor) -> Tensor: @staticmethod def _get_T0_slater(r: Tensor, s: Tensor) -> Tensor: """ Internal method, calculates T0 tensor for Slater densities. Parameters ---------- @@ -1038,3 +1082,133 @@ results: torch.Tensor (N_BATCH, MAX_QM_ATOMS, MAX_MM_ATOMS) """ return (1 - (1 + r / (s * 2)) * _torch.exp(-r / s)) / r + + @staticmethod + def get_lj_energy( + sigma_qm: Tensor, + epsilon_qm: Tensor, + sigma_mm: Tensor, + epsilon_mm: Tensor, + mesh_data: Tuple[Tensor, Tensor, Tensor], + ) -> Tensor: + """ + Calculate the Lennard-Jones energy. + + Parameters + ---------- + + sigma_qm: Tensor (N_BATCH, N_QM_ATOMS) + Lennard-Jones sigma values in Bohr. + + epsilon_qm: Tensor (N_BATCH, N_QM_ATOMS) + Lennard-Jones epsilon values in atomic units. + + sigma_mm: Tensor (N_BATCH, N_MM_ATOMS) + Lennard-Jones sigma values in Bohr. + + epsilon_mm: Tensor (N_BATCH, N_MM_ATOMS) + Lennard-Jones epsilon values in atomic units. + + mesh_data: Tuple[Tensor, Tensor, Tensor] + Mesh data tuple containing (r_inv, r_vec, s_outer_product). + r_inv: Tensor (N_BATCH, N_QM_ATOMS, N_MM_ATOMS) of inverse QM-MM distances. + + Returns + ------- + + Tensor (N_BATCH,) + Total Lennard-Jones energy for each batch element in atomic units.
+ """ + # Lorentz-Berthelot combining rules + # sigma (N_BATCH, N_QM, N_MM) + # epsilon (N_BATCH, N_QM, N_MM) + sigma = 0.5 * (sigma_qm[:, :, None] + sigma_mm[:, None, :]) + epsilon_product = epsilon_qm[:, :, None] * epsilon_mm[:, None, :] + epsilon = _torch.where(epsilon_product > 0, _torch.sqrt(epsilon_product), 0.0) + + # Get distances + # r_inv (N_BATCH, N_QM, N_MM) + r_inv, _, _ = mesh_data + sigma_r_inv_6 = (sigma * r_inv) ** 6 + sigma_r_inv_12 = sigma_r_inv_6 * sigma_r_inv_6 + + # Calculate pairwise energy matrix (N_BATCH, N_QM, N_MM) + lj_energy = 4 * epsilon * (sigma_r_inv_12 - sigma_r_inv_6) + + # Sum over QM and MM atoms for each batch element + lj_energy = lj_energy.sum(dim=(1, 2)) + + return lj_energy + + @staticmethod + def get_isotropic_polarizabilities(A_thole: _torch.Tensor) -> _torch.Tensor: + """ + Calculate isotropic polarizabilities from the A_thole tensor. + + Parameters + ---------- + + A_thole : torch.Tensor(N_BATCH, 3N_ATOMS, 3N_ATOMS) + Full polarizability tensor in block form. + + Returns + ------- + + torch.Tensor(N_BATCH, N_ATOMS) + Isotropic polarizabilities per atom. + """ + + def _get_traces(A_thole: _torch.Tensor) -> _torch.Tensor: + """ + Compute the trace of the inverse of each 3x3 block in each polarizability tensor. + """ + n_mol, dim, _ = A_thole.shape + if dim % 3 != 0: + raise ValueError("Dimension of A_thole must be divisible by 3.") + + n_atoms = dim // 3 + traces = _torch.empty( + (n_mol, n_atoms), dtype=A_thole.dtype, device=A_thole.device + ) + + for mol_idx in range(n_mol): + for atom_idx in range(n_atoms): + block = A_thole[ + mol_idx, + 3 * atom_idx : 3 * atom_idx + 3, + 3 * atom_idx : 3 * atom_idx + 3, + ] + inv_block = _torch.inverse(block) + traces[mol_idx, atom_idx] = _torch.trace(inv_block) + + return traces + + return _get_traces(A_thole) / 3.0 + + def get_lj_parameters( + self, c6: _torch.Tensor, alpha: _torch.Tensor + ) -> Tuple[Tensor, Tensor]: + """ + Calculate Lennard-Jones sigma and epsilon parameters. 
+ + Parameters + ---------- + + c6: _torch.Tensor(N_BATCH, N_ATOMS) + C6 coefficients per atom. + + alpha: _torch.Tensor(N_BATCH, N_ATOMS) + Isotropic polarizabilities per atom. + + Returns + ------- + + Tuple[torch.Tensor, torch.Tensor] + Tuple containing the sigma (Bohr) and epsilon (Hartree) LJ parameters for each atom. + """ + radius = 2.54 * alpha ** (1.0 / 7.0) + rmin = 2 * radius + sigma = rmin / (2 ** (1.0 / 6.0)) + epsilon = c6 / (2 * rmin**6.0) + + return sigma, epsilon diff --git a/emle/models/_mace.py b/emle/models/_mace.py index ff1d379..1e1a33c 100644 --- a/emle/models/_mace.py +++ b/emle/models/_mace.py @@ -123,7 +123,7 @@ def __init__( Available pre-trained models are 'mace-off23-small', 'mace-off23-medium', 'mace-off23-large'. To use a locally trained MACE model, provide the path to the model file. If None, the MACE-OFF23(S) model will be used by default. - If more than one model is provided, only the energy from the first model will be returned + If more than one model is provided, only the energy from the first model will be returned in the forward pass, but the energy and forces from all models will be stored. 
atomic_numbers: List[int], Tuple[int], numpy.ndarray, torch.Tensor (N_ATOMS,) @@ -198,16 +198,25 @@ def __init__( ) if not isinstance(mace_model, (list, tuple)): - mace_model = [mace_model] if mace_model is None or isinstance(mace_model, str) else None + mace_model = ( + [mace_model] + if mace_model is None or isinstance(mace_model, str) + else None + ) - if mace_model is None or any(not isinstance(i, (str, type(None))) for i in mace_model): - raise TypeError("'mace_model' must be a list, tuple, or str, with elements of type str or None") + if mace_model is None or any( + not isinstance(i, (str, type(None))) for i in mace_model + ): + raise TypeError( + "'mace_model' must be a list, tuple, or str, with elements of type str or None" + ) from mace.tools.scripts_utils import extract_config_mace_model + self._mace_models = _torch.nn.ModuleList() for model in mace_model: source_model = self._load_mace_model(model, device) - + # Extract the config from the model. config = extract_config_mace_model(source_model) @@ -281,7 +290,7 @@ def _load_mace_model(mace_model: str, device: _torch.device): Path to the MACE model file or the name of the pre-trained MACE model. device: torch.device Device on which to load the model. 
- + Returns ------- source_model: torch.nn.Module @@ -392,9 +401,9 @@ def to(self, *args, **kwargs): """ self._emle = self._emle.to(*args, **kwargs) self._mace = self._mace.to(*args, **kwargs) - self._mace_models = _torch.nn.ModuleList([ - model.to(*args, **kwargs) for model in self._mace_models - ]) + self._mace_models = _torch.nn.ModuleList( + [model.to(*args, **kwargs) for model in self._mace_models] + ) return self def cpu(self, **kwargs): @@ -405,9 +414,9 @@ def cpu(self, **kwargs): self._mace = self._mace.cpu(**kwargs) if self._atomic_numbers is not None: self._atomic_numbers = self._atomic_numbers.cpu(**kwargs) - self._mace_models = _torch.nn.ModuleList([ - model.cpu(**kwargs) for model in self._mace_models - ]) + self._mace_models = _torch.nn.ModuleList( + [model.cpu(**kwargs) for model in self._mace_models] + ) return self def cuda(self, **kwargs): @@ -418,9 +427,9 @@ def cuda(self, **kwargs): self._mace = self._mace.cuda(**kwargs) if self._atomic_numbers is not None: self._atomic_numbers = self._atomic_numbers.cuda(**kwargs) - self._mace_models = _torch.nn.ModuleList([ - model.cuda(**kwargs) for model in self._mace_models - ]) + self._mace_models = _torch.nn.ModuleList( + [model.cuda(**kwargs) for model in self._mace_models] + ) return self def double(self): @@ -429,9 +438,9 @@ def double(self): """ self._emle = self._emle.double() self._mace = self._mace.double() - self._mace_models = _torch.nn.ModuleList([ - model.double() for model in self._mace_models - ]) + self._mace_models = _torch.nn.ModuleList( + [model.double() for model in self._mace_models] + ) return self def float(self): @@ -440,9 +449,9 @@ def float(self): """ self._emle = self._emle.float() self._mace = self._mace.float() - self._mace_models = _torch.nn.ModuleList([ - model.float() for model in self._mace_models - ]) + self._mace_models = _torch.nn.ModuleList( + [model.float() for model in self._mace_models] + ) return self def forward( @@ -478,7 +487,7 @@ def forward( ------- result: 
torch.Tensor (3,) - The ANI2x and static and induced EMLE energy components in Hartree. + The MACE and static, induced and LJ EMLE energy components in Hartree. """ # Get the device. device = xyz_qm.device @@ -497,8 +506,17 @@ def forward( num_models = len(self._mace_models) # Create tensors to store the data for QbC. - self._E_vac_qbc = _torch.empty(num_models, num_batches, dtype=self._dtype, device=device) - self._grads_qbc = _torch.empty(num_models, num_batches, xyz_qm.shape[1], 3, dtype=self._dtype, device=device) + self._E_vac_qbc = _torch.empty( + num_models, num_batches, dtype=self._dtype, device=device + ) + self._grads_qbc = _torch.empty( + num_models, + num_batches, + xyz_qm.shape[1], + 3, + dtype=self._dtype, + device=device, + ) # Create tensors to store the results. results_E_vac = _torch.empty(num_batches, dtype=self._dtype, device=device) @@ -550,26 +568,32 @@ def forward( E_vac = self._mace(input_dict, compute_force=False)["interaction_energy"] - assert ( - E_vac is not None - ), "The model did not return any energy. Please check the input." + assert E_vac is not None, ( + "The model did not return any energy. Please check the input." + ) results_E_vac[i] = E_vac[0] * EV_TO_HARTREE # Decouple the positions from the computation graph for the next models. - input_dict["positions"] = input_dict["positions"].clone().detach().requires_grad_(True) + input_dict["positions"] = ( + input_dict["positions"].clone().detach().requires_grad_(True) + ) # Do inference for the other models. if len(self._mace_models) > 1: for j, mace in enumerate(self._mace_models): - E_vac_qbc = mace(input_dict, compute_force=False)["interaction_energy"] + E_vac_qbc = mace(input_dict, compute_force=False)[ + "interaction_energy" + ] - assert ( - E_vac_qbc is not None - ), "The model did not return any energy. Please check the input." + assert E_vac_qbc is not None, ( + "The model did not return any energy. Please check the input." 
+ ) # Calculate the gradients - grads_qbc = _torch.autograd.grad([E_vac_qbc], [input_dict["positions"]])[0] + grads_qbc = _torch.autograd.grad( + [E_vac_qbc], [input_dict["positions"]] + )[0] assert grads_qbc is not None, "Gradient computation failed" # Store the results. @@ -582,6 +606,7 @@ def forward( zero = _torch.tensor(0.0, dtype=xyz_qm.dtype, device=device) results_E_emle_static[i] = zero results_E_emle_induced[i] = zero + results_E_emle_lj[i] = zero else: # Get the EMLE energy components. E_emle = self._emle( @@ -589,8 +614,14 @@ def forward( ) results_E_emle_static[i] = E_emle[0][0] results_E_emle_induced[i] = E_emle[1][0] + results_E_emle_lj[i] = E_emle[2][0] # Return the MACE and EMLE energy components. return _torch.stack( - [results_E_vac, results_E_emle_static, results_E_emle_induced] + [ + results_E_vac, + results_E_emle_static, + results_E_emle_induced, + results_E_emle_lj, + ] ) diff --git a/emle/models/_patches.py b/emle/models/_patches.py index 2e67ee1..b21f4c9 100644 --- a/emle/models/_patches.py +++ b/emle/models/_patches.py @@ -32,6 +32,7 @@ # convert atom species from string to long tensor model0.species_to_tensor(['C', 'H', 'H', 'H', 'H']) """ + import os import torch import torchani diff --git a/emle/train/_loss.py b/emle/train/_loss.py index 0d1eb14..8248d3c 100644 --- a/emle/train/_loss.py +++ b/emle/train/_loss.py @@ -134,7 +134,7 @@ def forward(self, atomic_numbers, xyz, q_mol, q_target): self._update_chi_gpr(self._emle_base) # Calculate q_core and q_val - _, q_core, q_val, _ = self._emle_base(atomic_numbers, xyz, q_mol) + _, q_core, q_val, *_ = self._emle_base(atomic_numbers, xyz, q_mol) mask = atomic_numbers > 0 target = q_target[mask] @@ -275,7 +275,7 @@ def forward( self._update_sqrtk_gpr(self._emle_base) # Calculate A_thole and alpha_mol. 
- _, _, _, A_thole = self._emle_base(atomic_numbers, xyz, q_mol) + _, _, _, A_thole, *_ = self._emle_base(atomic_numbers, xyz, q_mol) alpha_mol = self._get_alpha_mol(A_thole, atomic_numbers > 0) triu_row, triu_col = _torch.triu_indices(3, 3, offset=0) @@ -311,3 +311,74 @@ def _update_sqrtk_gpr(emle_base): emle_base.ref_values_sqrtk, emle_base._Kinv, ) + + +class DispersionCoefficientLoss(_BaseLoss): + """ + Loss function for dispersion coefficients. Used to train ref_values_C6. + """ + + def __init__(self, emle_base, loss=_torch.nn.MSELoss()): + super().__init__() + + from ..models._emle_base import EMLEBase + + if not isinstance(emle_base, EMLEBase): + raise TypeError("emle_base must be an instance of EMLEBase") + self._emle_base = emle_base + + if not isinstance(loss, _torch.nn.Module): + raise TypeError("loss must be an instance of torch.nn.Module") + self._loss = loss + + self._pol = None + + def forward(self, atomic_numbers, xyz, q_mol, c6_target): + """ + Forward pass. + + Parameters + ---------- + atomic_numbers: torch.Tensor(N_BATCH, MAX_N_ATOMS) + Atomic numbers. + + xyz: torch.Tensor(N_BATCH, MAX_N_ATOMS, 3) + Cartesian coordinates. + + q_mol: torch.Tensor(N_BATCH, MAX_N_ATOMS) + Molecular charges. + + c6_target: torch.Tensor(N_BATCH, MAX_N_ATOMS) + Target dispersion coefficients. + """ + # Update reference values for C6. + self._update_c6_gpr(self._emle_base) + + # Calculate C6. + s, q_core, q_val, A_thole, c6 = self._emle_base(atomic_numbers, xyz, q_mol) + # Calculate isotropic polarizabilities if not already calculated. + if self._pol is None: + self._pol = self._emle_base.get_isotropic_polarizabilities(A_thole).detach() + + # Mask out dummy atoms. + mask = atomic_numbers > 0 + target = c6_target[mask] + values = c6 + values = values[mask] + + # Calculate loss. 
+ loss = self._loss(values, target) + + return ( + loss, + self._get_rmse(values, target), + self._get_max_error(values, target), + ) + + @staticmethod + def _update_c6_gpr(emle_base): + emle_base._ref_mean_c6, emle_base._c_c6 = emle_base._get_c( + emle_base._n_ref, + emle_base.ref_values_c6, + emle_base._Kinv, + ) diff --git a/emle/train/_trainer.py b/emle/train/_trainer.py index 2ab296c..4b8f174 100644 --- a/emle/train/_trainer.py +++ b/emle/train/_trainer.py @@ -31,6 +31,7 @@ from ._ivm import IVM as _IVM from ._loss import QEqLoss as _QEqLoss from ._loss import TholeLoss as _TholeLoss +from ._loss import DispersionCoefficientLoss as _DispersionCoefficientLoss from ._utils import pad_to_max as _pad_to_max from ._utils import mean_by_z as _mean_by_z @@ -41,6 +42,7 @@ def __init__( emle_base=_EMLEBase, qeq_loss=_QEqLoss, thole_loss=_TholeLoss, + dispersion_coefficient_loss=_DispersionCoefficientLoss, log_level=None, log_file=None, ): @@ -56,6 +58,12 @@ def __init__( raise TypeError("thole_loss must be a reference to TholeLoss") self._thole_loss = thole_loss + if dispersion_coefficient_loss is not _DispersionCoefficientLoss: + raise TypeError( + "dispersion_coefficient_loss must be a reference to DispersionCoefficientLoss" + ) + self._dispersion_coefficient_loss = dispersion_coefficient_loss + # First handle the logger. 
if log_level is None: log_level = "INFO" @@ -267,7 +275,7 @@ def _train_loop( optimizer.step() if (epoch + 1) % print_every == 0: _logger.info( - f"Epoch {epoch+1}: Loss ={loss.item():9.4f} " + f"Epoch {epoch + 1}: Loss ={loss.item():9.4f} " f"RMSE ={rmse.item():9.4f} " f"Max Error ={max_error.item():9.4f}" ) @@ -293,7 +301,8 @@ def train( q_core, q_val, alpha, - train_mask, + c6=None, + train_mask=None, alpha_mode="reference", sigma=1e-3, ivm_thr=0.05, @@ -301,6 +310,7 @@ def train( lr_qeq=0.05, lr_thole=0.05, lr_sqrtk=0.05, + lr_c6=0.05, print_every=10, computer_n_species=None, computer_zid_map=None, @@ -333,6 +343,9 @@ def train( alpha: array or tensor or list of tensor/arrays of shape (N_BATCH, 3, 3) Atomic polarizabilities. + c6: array or tensor or list of tensor/arrays of shape (N_BATCH, N_ATOMS, N_ATOMS) + C6 dispersion coefficients. If None, the C6 dispersion coefficients are not trained. + train_mask: torch.Tensor(N_BATCH,) Mask for training samples. @@ -357,6 +370,9 @@ def train( lr_sqrtk: float Learning rate for sqrtk. + lr_c6: float + Learning rate for c6. + print_every: int How often to print training progress. @@ -427,6 +443,11 @@ def train( alpha_train = alpha_train.to(device=device, dtype=dtype) species = species.to(device=device, dtype=_torch.int64) + if c6 is not None: + c6 = _pad_to_max(c6) + c6_train = c6[train_mask] + c6_train = c6_train.to(device=device, dtype=dtype) + # Get zid mapping. zid_mapping = self._get_zid_mapping(species) zid_train = zid_mapping[z_train] @@ -504,6 +525,15 @@ def train( if alpha_mode == "reference" else None ), + "ref_values_c6": ( + _torch.ones( + *ref_values_s.shape, + dtype=ref_values_s.dtype, + device=_torch.device(device), + ) + if c6 is not None + else None + ), } # Create the EMLE base instance. 
@@ -515,10 +545,10 @@ def train( emle_aev_computer=emle_aev_computer, species=species, alpha_mode=alpha_mode, + lj_mode="fixed", device=_torch.device(device), dtype=dtype, ) - # Fit chi, a_QEq (QEq over chi predicted with GPR). _logger.info("Fitting a_QEq and chi values...") self._train_model( @@ -538,7 +568,6 @@ def train( self._qeq_loss._update_chi_gpr(emle_base) _logger.debug(f"Optimized a_QEq: {emle_base.a_QEq.data.item()}") - # Fit a_Thole, k_Z (uses volumes predicted by QEq model). _logger.info("Fitting a_Thole and k_Z values...") self._train_model( @@ -576,6 +605,24 @@ def train( # (now inconsistent since not updated after the last epoch) self._thole_loss._update_sqrtk_gpr(emle_base) + if c6 is not None: + _logger.info("Fitting ref_values_c6 values...") + self._train_model( + loss_class=self._dispersion_coefficient_loss, + opt_param_names=["ref_values_c6"], + lr=lr_c6, + epochs=2000, + print_every=print_every, + emle_base=emle_base, + atomic_numbers=z_train, + xyz=xyz_train, + q_mol=q_mol_train, + c6_target=c6_train, + ) + + # Update reference values for C6. + self._dispersion_coefficient_loss._update_c6_gpr(emle_base) + # Create the final model. 
emle_model = { "q_core": q_core_z, @@ -587,6 +634,7 @@ def train( "sqrtk_ref": ( emle_base.ref_values_sqrtk if alpha_mode == "reference" else None ), + "c6_ref": (emle_base.ref_values_c6 if c6 is not None else None), "species": species, "alpha_mode": alpha_mode, "n_ref": n_ref, @@ -603,11 +651,12 @@ def train( return emle_base emle_base._alpha_mode = "species" - s_pred, q_core_pred, q_val_pred, A_thole = emle_base( + s_pred, q_core_pred, q_val_pred, A_thole, c6_pred = emle_base( z.to(device=device, dtype=_torch.int64), xyz.to(device=device, dtype=dtype), q_mol, ) + z_mask = _torch.tensor(z > 0, device=device) plot_data = { "s_emle": s_pred, @@ -632,6 +681,10 @@ def train( A_thole, z_mask ) + if c6 is not None: + plot_data["c6_qm"] = c6 + plot_data["c6_emle"] = c6_pred + self._write_model_to_file(plot_data, plot_data_filename) return emle_base diff --git a/versioneer.py b/versioneer.py index de97d90..080191b 100644 --- a/versioneer.py +++ b/versioneer.py @@ -514,9 +514,7 @@ def run_command( return stdout, process.returncode -LONG_VERSION_PY[ - "git" -] = r''' +LONG_VERSION_PY["git"] = r''' # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build @@ -1840,9 +1838,9 @@ def get_versions(verbose: bool = False) -> Dict[str, Any]: handlers = HANDLERS.get(cfg.VCS) assert handlers, "unrecognized VCS '%s'" % cfg.VCS verbose = verbose or bool(cfg.verbose) # `bool()` used to avoid `None` - assert ( - cfg.versionfile_source is not None - ), "please set versioneer.versionfile_source" + assert cfg.versionfile_source is not None, ( + "please set versioneer.versionfile_source" + ) assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" versionfile_abs = os.path.join(root, cfg.versionfile_source)