|
| 1 | +# SPDX-License-Identifier: LGPL-3.0-or-later |
| 2 | +import numpy as np |
| 3 | +import torch |
| 4 | +from torch_admp.pme import ( |
| 5 | + CoulombForceModule, |
| 6 | +) |
| 7 | +from torch_admp.utils import ( |
| 8 | + calc_grads, |
| 9 | +) |
| 10 | + |
| 11 | +from deepmd.pt.modifier.base_modifier import ( |
| 12 | + BaseModifier, |
| 13 | +) |
| 14 | +from deepmd.pt.utils import ( |
| 15 | + env, |
| 16 | +) |
| 17 | +from deepmd.pt.utils.utils import ( |
| 18 | + to_torch_tensor, |
| 19 | +) |
| 20 | + |
| 21 | + |
| 22 | +@BaseModifier.register("dipole_charge") |
| 23 | +class DipoleChargeModifier(BaseModifier): |
| 24 | + """Parameters |
| 25 | + ---------- |
| 26 | + model_name |
| 27 | + The model file for the DeepDipole model |
| 28 | + model_charge_map |
| 29 | + Gives the amount of charge for the wfcc |
| 30 | + sys_charge_map |
| 31 | + Gives the amount of charge for the real atoms |
| 32 | + ewald_h |
| 33 | + Grid spacing of the reciprocal part of Ewald sum. Unit: A |
| 34 | + ewald_beta |
| 35 | + Splitting parameter of the Ewald sum. Unit: A^{-1} |
| 36 | + """ |
| 37 | + |
| 38 | + def __new__( |
| 39 | + cls, *args: tuple, model_name: str | None = None, **kwargs: dict |
| 40 | + ) -> "DipoleChargeModifier": |
| 41 | + return super().__new__(cls, model_name) |
| 42 | + |
| 43 | + def __init__( |
| 44 | + self, |
| 45 | + model_name: str, |
| 46 | + model_charge_map: list[float], |
| 47 | + sys_charge_map: list[float], |
| 48 | + ewald_h: float = 1.0, |
| 49 | + ewald_beta: float = 1.0, |
| 50 | + ) -> None: |
| 51 | + """Constructor.""" |
| 52 | + super().__init__() |
| 53 | + self.modifier_type = "dipole_charge" |
| 54 | + self.model_name = model_name |
| 55 | + |
| 56 | + self.model = torch.jit.load(model_name, map_location=env.DEVICE) |
| 57 | + self.rcut = self.model.get_rcut() |
| 58 | + self.type_map = self.model.get_type_map() |
| 59 | + sel_type = self.model.get_sel_type() |
| 60 | + self.sel_type = to_torch_tensor(np.array(sel_type)) |
| 61 | + self.model_charge_map = to_torch_tensor(np.array(model_charge_map)) |
| 62 | + self.sys_charge_map = to_torch_tensor(np.array(sys_charge_map)) |
| 63 | + self._model_charge_map = model_charge_map |
| 64 | + self._sys_charge_map = sys_charge_map |
| 65 | + |
| 66 | + # init ewald recp |
| 67 | + self.ewald_h = ewald_h |
| 68 | + self.ewald_beta = ewald_beta |
| 69 | + self.er = CoulombForceModule( |
| 70 | + rcut=self.rcut, |
| 71 | + rspace=False, |
| 72 | + kappa=ewald_beta, |
| 73 | + spacing=ewald_h, |
| 74 | + ) |
| 75 | + self.placeholder_pairs = torch.ones((1, 2), device=env.DEVICE, dtype=torch.long) |
| 76 | + self.placeholder_ds = torch.ones((1), device=env.DEVICE, dtype=torch.float64) |
| 77 | + self.placeholder_buffer_scales = torch.zeros( |
| 78 | + (1), device=env.DEVICE, dtype=torch.float64 |
| 79 | + ) |
| 80 | + |
| 81 | + def serialize(self) -> dict: |
| 82 | + """Serialize the modifier. |
| 83 | +
|
| 84 | + Returns |
| 85 | + ------- |
| 86 | + dict |
| 87 | + The serialized data |
| 88 | + """ |
| 89 | + data = { |
| 90 | + "@class": "Modifier", |
| 91 | + "type": self.modifier_type, |
| 92 | + "@version": 3, |
| 93 | + "model_name": self.model_name, |
| 94 | + "model_charge_map": self._model_charge_map, |
| 95 | + "sys_charge_map": self._sys_charge_map, |
| 96 | + "ewald_h": self.ewald_h, |
| 97 | + "ewald_beta": self.ewald_beta, |
| 98 | + } |
| 99 | + return data |
| 100 | + |
| 101 | + def forward( |
| 102 | + self, |
| 103 | + coord: torch.Tensor, |
| 104 | + atype: torch.Tensor, |
| 105 | + box: torch.Tensor | None = None, |
| 106 | + fparam: torch.Tensor | None = None, |
| 107 | + aparam: torch.Tensor | None = None, |
| 108 | + do_atomic_virial: bool = False, |
| 109 | + ) -> dict[str, torch.Tensor]: |
| 110 | + """Compute energy, force, and virial corrections for dipole-charge systems. |
| 111 | +
|
| 112 | + This method extends the system with Wannier Function Charge Centers (WFCC) |
| 113 | + by adding dipole vectors to atomic coordinates for selected atom types. |
| 114 | + It then calculates the electrostatic interactions using Ewald reciprocal |
| 115 | + summation to obtain energy, force, and virial corrections. |
| 116 | +
|
| 117 | + Parameters |
| 118 | + ---------- |
| 119 | + coord : torch.Tensor |
| 120 | + The coordinates of atoms with shape (nframes, natoms, 3) |
| 121 | + atype : torch.Tensor |
| 122 | + The atom types with shape (nframes, natoms) |
| 123 | + box : torch.Tensor | None, optional |
| 124 | + The simulation box with shape (nframes, 3, 3), by default None |
| 125 | + Note: This modifier can only be applied for periodic systems |
| 126 | + fparam : torch.Tensor | None, optional |
| 127 | + Frame parameters with shape (nframes, nfp), by default None |
| 128 | + aparam : torch.Tensor | None, optional |
| 129 | + Atom parameters with shape (nframes, natoms, nap), by default None |
| 130 | + do_atomic_virial : bool, optional |
| 131 | + Whether to compute atomic virial, by default False |
| 132 | +
|
| 133 | + Returns |
| 134 | + ------- |
| 135 | + dict[str, torch.Tensor] |
| 136 | + Dictionary containing the correction terms: |
| 137 | + - energy: Energy correction tensor with shape (nframes, 1) |
| 138 | + - force: Force correction tensor with shape (nframes, natoms+nsel, 3) |
| 139 | + - virial: Virial correction tensor with shape (nframes, 3, 3) |
| 140 | + """ |
| 141 | + if box is None: |
| 142 | + raise RuntimeWarning( |
| 143 | + "dipole_charge data modifier can only be applied for periodic systems." |
| 144 | + ) |
| 145 | + else: |
| 146 | + modifier_pred = {} |
| 147 | + nframes = coord.shape[0] |
| 148 | + natoms = coord.shape[1] |
| 149 | + |
| 150 | + input_box = box.reshape(nframes, 9) |
| 151 | + input_box.requires_grad_(True) |
| 152 | + |
| 153 | + detached_box = input_box.detach() |
| 154 | + sfactor = torch.matmul( |
| 155 | + torch.linalg.inv(detached_box.reshape(nframes, 3, 3)), |
| 156 | + input_box.reshape(nframes, 3, 3), |
| 157 | + ) |
| 158 | + input_coord = torch.matmul(coord, sfactor).reshape(nframes, -1) |
| 159 | + |
| 160 | + extended_coord, extended_charge = self.extend_system( |
| 161 | + input_coord, |
| 162 | + atype, |
| 163 | + input_box, |
| 164 | + fparam, |
| 165 | + aparam, |
| 166 | + ) |
| 167 | + |
| 168 | + tot_e = [] |
| 169 | + # add Ewald reciprocal correction |
| 170 | + for ii in range(nframes): |
| 171 | + self.er( |
| 172 | + extended_coord[ii].reshape((-1, 3)), |
| 173 | + input_box[ii].reshape((3, 3)), |
| 174 | + self.placeholder_pairs, |
| 175 | + self.placeholder_ds, |
| 176 | + self.placeholder_buffer_scales, |
| 177 | + {"charge": extended_charge[ii].reshape((-1,))}, |
| 178 | + ) |
| 179 | + tot_e.append(self.er.reciprocal_energy.unsqueeze(0)) |
| 180 | + # nframe, |
| 181 | + tot_e = torch.concat(tot_e, dim=0) |
| 182 | + # nframe, nat * 3 |
| 183 | + tot_f = -calc_grads(tot_e, input_coord) |
| 184 | + # nframe, nat, 3 |
| 185 | + tot_f = torch.reshape(tot_f, (nframes, natoms, 3)) |
| 186 | + # nframe, 9 |
| 187 | + tot_v = calc_grads(tot_e, input_box) |
| 188 | + tot_v = torch.reshape(tot_v, (nframes, 3, 3)) |
| 189 | + # nframe, 3, 3 |
| 190 | + tot_v = -torch.matmul( |
| 191 | + tot_v.transpose(2, 1), input_box.reshape(nframes, 3, 3) |
| 192 | + ) |
| 193 | + |
| 194 | + modifier_pred["energy"] = tot_e |
| 195 | + modifier_pred["force"] = tot_f |
| 196 | + modifier_pred["virial"] = tot_v |
| 197 | + return modifier_pred |
| 198 | + |
| 199 | + def extend_system( |
| 200 | + self, |
| 201 | + coord: torch.Tensor, |
| 202 | + atype: torch.Tensor, |
| 203 | + box: torch.Tensor, |
| 204 | + fparam: torch.Tensor | None = None, |
| 205 | + aparam: torch.Tensor | None = None, |
| 206 | + ) -> tuple[torch.Tensor, torch.Tensor]: |
| 207 | + """Extend the system with WFCC (Wannier Function Charge Centers). |
| 208 | +
|
| 209 | + Parameters |
| 210 | + ---------- |
| 211 | + coord : torch.Tensor |
| 212 | + The coordinates of atoms with shape (nframes, natoms * 3) |
| 213 | + atype : torch.Tensor |
| 214 | + The atom types with shape (nframes, natoms) |
| 215 | + box : torch.Tensor |
| 216 | + The simulation box with shape (nframes, 9) |
| 217 | + fparam : torch.Tensor | None, optional |
| 218 | + Frame parameters with shape (nframes, nfp), by default None |
| 219 | + aparam : torch.Tensor | None, optional |
| 220 | + Atom parameters with shape (nframes, natoms, nap), by default None |
| 221 | +
|
| 222 | + Returns |
| 223 | + ------- |
| 224 | + tuple |
| 225 | + (extended_coord, extended_charge) |
| 226 | + extended_coord : torch.Tensor |
| 227 | + Extended coordinates with shape (nframes, (natoms + nsel) * 3) |
| 228 | + extended_charge : torch.Tensor |
| 229 | + Extended charges with shape (nframes, natoms + nsel) |
| 230 | + """ |
| 231 | + nframes = coord.shape[0] |
| 232 | + mask = make_mask(self.sel_type, atype) |
| 233 | + |
| 234 | + extended_coord = self.extend_system_coord( |
| 235 | + coord, |
| 236 | + atype, |
| 237 | + box, |
| 238 | + fparam, |
| 239 | + aparam, |
| 240 | + ) |
| 241 | + # Get ion charges based on atom types |
| 242 | + # nframe x nat |
| 243 | + ion_charge = self.sys_charge_map[atype] |
| 244 | + # Initialize wfcc charges |
| 245 | + wc_charge = torch.zeros_like(ion_charge) |
| 246 | + # Assign charges to selected atom types |
| 247 | + for ii, charge in enumerate(self.model_charge_map): |
| 248 | + wc_charge[atype == self.sel_type[ii]] = charge |
| 249 | + # Get the charges for selected atoms only |
| 250 | + wc_charge_selected = wc_charge[mask].reshape(nframes, -1) |
| 251 | + # Concatenate ion charges and wfcc charges |
| 252 | + extended_charge = torch.cat([ion_charge, wc_charge_selected], dim=1) |
| 253 | + return extended_coord, extended_charge |
| 254 | + |
| 255 | + def extend_system_coord( |
| 256 | + self, |
| 257 | + coord: torch.Tensor, |
| 258 | + atype: torch.Tensor, |
| 259 | + box: torch.Tensor, |
| 260 | + fparam: torch.Tensor | None = None, |
| 261 | + aparam: torch.Tensor | None = None, |
| 262 | + ) -> torch.Tensor: |
| 263 | + """Extend the system with WFCC (Wannier Function Charge Centers). |
| 264 | +
|
| 265 | + This function calculates Wannier Function Charge Centers (WFCC) by adding dipole |
| 266 | + vectors to atomic coordinates for selected atom types, then concatenates these |
| 267 | + WFCC coordinates with the original atomic coordinates. |
| 268 | +
|
| 269 | + Parameters |
| 270 | + ---------- |
| 271 | + coord : torch.Tensor |
| 272 | + The coordinates of atoms with shape (nframes, natoms * 3) |
| 273 | + atype : torch.Tensor |
| 274 | + The atom types with shape (nframes, natoms) |
| 275 | + box : torch.Tensor |
| 276 | + The simulation box with shape (nframes, 9) |
| 277 | + fparam : torch.Tensor | None, optional |
| 278 | + Frame parameters with shape (nframes, nfp), by default None |
| 279 | + aparam : torch.Tensor | None, optional |
| 280 | + Atom parameters with shape (nframes, natoms, nap), by default None |
| 281 | +
|
| 282 | + Returns |
| 283 | + ------- |
| 284 | + all_coord : torch.Tensor |
| 285 | + Extended coordinates with shape (nframes, (natoms + nsel) * 3) |
| 286 | + where nsel is the number of selected atoms |
| 287 | + """ |
| 288 | + mask = make_mask(self.sel_type, atype) |
| 289 | + |
| 290 | + nframes = coord.shape[0] |
| 291 | + natoms = coord.shape[1] // 3 |
| 292 | + |
| 293 | + all_dipole = [] |
| 294 | + for ii in range(nframes): |
| 295 | + dipole_batch = self.model( |
| 296 | + coord=coord[ii].reshape(1, -1), |
| 297 | + atype=atype[ii].reshape(1, -1), |
| 298 | + box=box[ii].reshape(1, -1), |
| 299 | + do_atomic_virial=False, |
| 300 | + fparam=fparam[ii].reshape(1, -1) if fparam is not None else None, |
| 301 | + aparam=aparam[ii].reshape(1, -1) if aparam is not None else None, |
| 302 | + ) |
| 303 | + # Extract dipole from the output dictionary |
| 304 | + all_dipole.append(dipole_batch["dipole"]) |
| 305 | + |
| 306 | + # nframe x natoms x 3 |
| 307 | + dipole = torch.cat(all_dipole, dim=0) |
| 308 | + assert dipole.shape[0] == nframes |
| 309 | + |
| 310 | + dipole_reshaped = dipole.reshape(nframes, natoms, 3) |
| 311 | + coord_reshaped = coord.reshape(nframes, natoms, 3) |
| 312 | + _wfcc_coord = coord_reshaped + dipole_reshaped |
| 313 | + # Apply mask and reshape |
| 314 | + wfcc_coord = _wfcc_coord[mask.unsqueeze(-1).expand_as(_wfcc_coord)] |
| 315 | + wfcc_coord = wfcc_coord.reshape(nframes, -1) |
| 316 | + all_coord = torch.cat((coord, wfcc_coord), dim=1) |
| 317 | + return all_coord |
| 318 | + |
| 319 | + |
| 320 | +@torch.jit.export |
| 321 | +def make_mask( |
| 322 | + sel_type: torch.Tensor, |
| 323 | + atype: torch.Tensor, |
| 324 | +) -> torch.Tensor: |
| 325 | + """Create a boolean mask for selected atom types. |
| 326 | +
|
| 327 | + Parameters |
| 328 | + ---------- |
| 329 | + sel_type : torch.Tensor |
| 330 | + The selected atom types to create a mask for |
| 331 | + atype : torch.Tensor |
| 332 | + The atom types in the system |
| 333 | +
|
| 334 | + Returns |
| 335 | + ------- |
| 336 | + mask : torch.Tensor |
| 337 | + Boolean mask where True indicates atoms of selected types |
| 338 | + """ |
| 339 | + # Ensure tensors are of the right type |
| 340 | + sel_type = sel_type.to(torch.long) |
| 341 | + atype = atype.to(torch.long) |
| 342 | + |
| 343 | + # Create mask using broadcasting |
| 344 | + mask = torch.zeros_like(atype, dtype=torch.bool) |
| 345 | + for t in sel_type: |
| 346 | + mask = mask | (atype == t) |
| 347 | + return mask |
0 commit comments