Skip to content

Commit a84ea9f

Browse files
committed
add dplr(torch)
1 parent f736ab2 commit a84ea9f

File tree

8 files changed

+577
-1
lines changed

8 files changed

+577
-1
lines changed

deepmd/pt/modifier/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,13 @@
77
from .base_modifier import (
88
BaseModifier,
99
)
10+
from .dipole_charge import (
11+
DipoleChargeModifier,
12+
)
1013

1114
__all__ = [
1215
"BaseModifier",
16+
"DipoleChargeModifier",
1317
"get_data_modifier",
1418
]
1519

Lines changed: 347 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,347 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import numpy as np
3+
import torch
4+
from torch_admp.pme import (
5+
CoulombForceModule,
6+
)
7+
from torch_admp.utils import (
8+
calc_grads,
9+
)
10+
11+
from deepmd.pt.modifier.base_modifier import (
12+
BaseModifier,
13+
)
14+
from deepmd.pt.utils import (
15+
env,
16+
)
17+
from deepmd.pt.utils.utils import (
18+
to_torch_tensor,
19+
)
20+
21+
22+
@BaseModifier.register("dipole_charge")
23+
class DipoleChargeModifier(BaseModifier):
24+
"""Parameters
25+
----------
26+
model_name
27+
The model file for the DeepDipole model
28+
model_charge_map
29+
Gives the amount of charge for the wfcc
30+
sys_charge_map
31+
Gives the amount of charge for the real atoms
32+
ewald_h
33+
Grid spacing of the reciprocal part of Ewald sum. Unit: A
34+
ewald_beta
35+
Splitting parameter of the Ewald sum. Unit: A^{-1}
36+
"""
37+
38+
def __new__(
39+
cls, *args: tuple, model_name: str | None = None, **kwargs: dict
40+
) -> "DipoleChargeModifier":
41+
return super().__new__(cls, model_name)
42+
43+
def __init__(
44+
self,
45+
model_name: str,
46+
model_charge_map: list[float],
47+
sys_charge_map: list[float],
48+
ewald_h: float = 1.0,
49+
ewald_beta: float = 1.0,
50+
) -> None:
51+
"""Constructor."""
52+
super().__init__()
53+
self.modifier_type = "dipole_charge"
54+
self.model_name = model_name
55+
56+
self.model = torch.jit.load(model_name, map_location=env.DEVICE)
57+
self.rcut = self.model.get_rcut()
58+
self.type_map = self.model.get_type_map()
59+
sel_type = self.model.get_sel_type()
60+
self.sel_type = to_torch_tensor(np.array(sel_type))
61+
self.model_charge_map = to_torch_tensor(np.array(model_charge_map))
62+
self.sys_charge_map = to_torch_tensor(np.array(sys_charge_map))
63+
self._model_charge_map = model_charge_map
64+
self._sys_charge_map = sys_charge_map
65+
66+
# init ewald recp
67+
self.ewald_h = ewald_h
68+
self.ewald_beta = ewald_beta
69+
self.er = CoulombForceModule(
70+
rcut=self.rcut,
71+
rspace=False,
72+
kappa=ewald_beta,
73+
spacing=ewald_h,
74+
)
75+
self.placeholder_pairs = torch.ones((1, 2), device=env.DEVICE, dtype=torch.long)
76+
self.placeholder_ds = torch.ones((1), device=env.DEVICE, dtype=torch.float64)
77+
self.placeholder_buffer_scales = torch.zeros(
78+
(1), device=env.DEVICE, dtype=torch.float64
79+
)
80+
81+
def serialize(self) -> dict:
82+
"""Serialize the modifier.
83+
84+
Returns
85+
-------
86+
dict
87+
The serialized data
88+
"""
89+
data = {
90+
"@class": "Modifier",
91+
"type": self.modifier_type,
92+
"@version": 3,
93+
"model_name": self.model_name,
94+
"model_charge_map": self._model_charge_map,
95+
"sys_charge_map": self._sys_charge_map,
96+
"ewald_h": self.ewald_h,
97+
"ewald_beta": self.ewald_beta,
98+
}
99+
return data
100+
101+
def forward(
102+
self,
103+
coord: torch.Tensor,
104+
atype: torch.Tensor,
105+
box: torch.Tensor | None = None,
106+
fparam: torch.Tensor | None = None,
107+
aparam: torch.Tensor | None = None,
108+
do_atomic_virial: bool = False,
109+
) -> dict[str, torch.Tensor]:
110+
"""Compute energy, force, and virial corrections for dipole-charge systems.
111+
112+
This method extends the system with Wannier Function Charge Centers (WFCC)
113+
by adding dipole vectors to atomic coordinates for selected atom types.
114+
It then calculates the electrostatic interactions using Ewald reciprocal
115+
summation to obtain energy, force, and virial corrections.
116+
117+
Parameters
118+
----------
119+
coord : torch.Tensor
120+
The coordinates of atoms with shape (nframes, natoms, 3)
121+
atype : torch.Tensor
122+
The atom types with shape (nframes, natoms)
123+
box : torch.Tensor | None, optional
124+
The simulation box with shape (nframes, 3, 3), by default None
125+
Note: This modifier can only be applied for periodic systems
126+
fparam : torch.Tensor | None, optional
127+
Frame parameters with shape (nframes, nfp), by default None
128+
aparam : torch.Tensor | None, optional
129+
Atom parameters with shape (nframes, natoms, nap), by default None
130+
do_atomic_virial : bool, optional
131+
Whether to compute atomic virial, by default False
132+
133+
Returns
134+
-------
135+
dict[str, torch.Tensor]
136+
Dictionary containing the correction terms:
137+
- energy: Energy correction tensor with shape (nframes, 1)
138+
- force: Force correction tensor with shape (nframes, natoms+nsel, 3)
139+
- virial: Virial correction tensor with shape (nframes, 3, 3)
140+
"""
141+
if box is None:
142+
raise RuntimeWarning(
143+
"dipole_charge data modifier can only be applied for periodic systems."
144+
)
145+
else:
146+
modifier_pred = {}
147+
nframes = coord.shape[0]
148+
natoms = coord.shape[1]
149+
150+
input_box = box.reshape(nframes, 9)
151+
input_box.requires_grad_(True)
152+
153+
detached_box = input_box.detach()
154+
sfactor = torch.matmul(
155+
torch.linalg.inv(detached_box.reshape(nframes, 3, 3)),
156+
input_box.reshape(nframes, 3, 3),
157+
)
158+
input_coord = torch.matmul(coord, sfactor).reshape(nframes, -1)
159+
160+
extended_coord, extended_charge = self.extend_system(
161+
input_coord,
162+
atype,
163+
input_box,
164+
fparam,
165+
aparam,
166+
)
167+
168+
tot_e = []
169+
# add Ewald reciprocal correction
170+
for ii in range(nframes):
171+
self.er(
172+
extended_coord[ii].reshape((-1, 3)),
173+
input_box[ii].reshape((3, 3)),
174+
self.placeholder_pairs,
175+
self.placeholder_ds,
176+
self.placeholder_buffer_scales,
177+
{"charge": extended_charge[ii].reshape((-1,))},
178+
)
179+
tot_e.append(self.er.reciprocal_energy.unsqueeze(0))
180+
# nframe,
181+
tot_e = torch.concat(tot_e, dim=0)
182+
# nframe, nat * 3
183+
tot_f = -calc_grads(tot_e, input_coord)
184+
# nframe, nat, 3
185+
tot_f = torch.reshape(tot_f, (nframes, natoms, 3))
186+
# nframe, 9
187+
tot_v = calc_grads(tot_e, input_box)
188+
tot_v = torch.reshape(tot_v, (nframes, 3, 3))
189+
# nframe, 3, 3
190+
tot_v = -torch.matmul(
191+
tot_v.transpose(2, 1), input_box.reshape(nframes, 3, 3)
192+
)
193+
194+
modifier_pred["energy"] = tot_e
195+
modifier_pred["force"] = tot_f
196+
modifier_pred["virial"] = tot_v
197+
return modifier_pred
198+
199+
def extend_system(
200+
self,
201+
coord: torch.Tensor,
202+
atype: torch.Tensor,
203+
box: torch.Tensor,
204+
fparam: torch.Tensor | None = None,
205+
aparam: torch.Tensor | None = None,
206+
) -> tuple[torch.Tensor, torch.Tensor]:
207+
"""Extend the system with WFCC (Wannier Function Charge Centers).
208+
209+
Parameters
210+
----------
211+
coord : torch.Tensor
212+
The coordinates of atoms with shape (nframes, natoms * 3)
213+
atype : torch.Tensor
214+
The atom types with shape (nframes, natoms)
215+
box : torch.Tensor
216+
The simulation box with shape (nframes, 9)
217+
fparam : torch.Tensor | None, optional
218+
Frame parameters with shape (nframes, nfp), by default None
219+
aparam : torch.Tensor | None, optional
220+
Atom parameters with shape (nframes, natoms, nap), by default None
221+
222+
Returns
223+
-------
224+
tuple
225+
(extended_coord, extended_charge)
226+
extended_coord : torch.Tensor
227+
Extended coordinates with shape (nframes, (natoms + nsel) * 3)
228+
extended_charge : torch.Tensor
229+
Extended charges with shape (nframes, natoms + nsel)
230+
"""
231+
nframes = coord.shape[0]
232+
mask = make_mask(self.sel_type, atype)
233+
234+
extended_coord = self.extend_system_coord(
235+
coord,
236+
atype,
237+
box,
238+
fparam,
239+
aparam,
240+
)
241+
# Get ion charges based on atom types
242+
# nframe x nat
243+
ion_charge = self.sys_charge_map[atype]
244+
# Initialize wfcc charges
245+
wc_charge = torch.zeros_like(ion_charge)
246+
# Assign charges to selected atom types
247+
for ii, charge in enumerate(self.model_charge_map):
248+
wc_charge[atype == self.sel_type[ii]] = charge
249+
# Get the charges for selected atoms only
250+
wc_charge_selected = wc_charge[mask].reshape(nframes, -1)
251+
# Concatenate ion charges and wfcc charges
252+
extended_charge = torch.cat([ion_charge, wc_charge_selected], dim=1)
253+
return extended_coord, extended_charge
254+
255+
def extend_system_coord(
256+
self,
257+
coord: torch.Tensor,
258+
atype: torch.Tensor,
259+
box: torch.Tensor,
260+
fparam: torch.Tensor | None = None,
261+
aparam: torch.Tensor | None = None,
262+
) -> torch.Tensor:
263+
"""Extend the system with WFCC (Wannier Function Charge Centers).
264+
265+
This function calculates Wannier Function Charge Centers (WFCC) by adding dipole
266+
vectors to atomic coordinates for selected atom types, then concatenates these
267+
WFCC coordinates with the original atomic coordinates.
268+
269+
Parameters
270+
----------
271+
coord : torch.Tensor
272+
The coordinates of atoms with shape (nframes, natoms * 3)
273+
atype : torch.Tensor
274+
The atom types with shape (nframes, natoms)
275+
box : torch.Tensor
276+
The simulation box with shape (nframes, 9)
277+
fparam : torch.Tensor | None, optional
278+
Frame parameters with shape (nframes, nfp), by default None
279+
aparam : torch.Tensor | None, optional
280+
Atom parameters with shape (nframes, natoms, nap), by default None
281+
282+
Returns
283+
-------
284+
all_coord : torch.Tensor
285+
Extended coordinates with shape (nframes, (natoms + nsel) * 3)
286+
where nsel is the number of selected atoms
287+
"""
288+
mask = make_mask(self.sel_type, atype)
289+
290+
nframes = coord.shape[0]
291+
natoms = coord.shape[1] // 3
292+
293+
all_dipole = []
294+
for ii in range(nframes):
295+
dipole_batch = self.model(
296+
coord=coord[ii].reshape(1, -1),
297+
atype=atype[ii].reshape(1, -1),
298+
box=box[ii].reshape(1, -1),
299+
do_atomic_virial=False,
300+
fparam=fparam[ii].reshape(1, -1) if fparam is not None else None,
301+
aparam=aparam[ii].reshape(1, -1) if aparam is not None else None,
302+
)
303+
# Extract dipole from the output dictionary
304+
all_dipole.append(dipole_batch["dipole"])
305+
306+
# nframe x natoms x 3
307+
dipole = torch.cat(all_dipole, dim=0)
308+
assert dipole.shape[0] == nframes
309+
310+
dipole_reshaped = dipole.reshape(nframes, natoms, 3)
311+
coord_reshaped = coord.reshape(nframes, natoms, 3)
312+
_wfcc_coord = coord_reshaped + dipole_reshaped
313+
# Apply mask and reshape
314+
wfcc_coord = _wfcc_coord[mask.unsqueeze(-1).expand_as(_wfcc_coord)]
315+
wfcc_coord = wfcc_coord.reshape(nframes, -1)
316+
all_coord = torch.cat((coord, wfcc_coord), dim=1)
317+
return all_coord
318+
319+
320+
@torch.jit.export
321+
def make_mask(
322+
sel_type: torch.Tensor,
323+
atype: torch.Tensor,
324+
) -> torch.Tensor:
325+
"""Create a boolean mask for selected atom types.
326+
327+
Parameters
328+
----------
329+
sel_type : torch.Tensor
330+
The selected atom types to create a mask for
331+
atype : torch.Tensor
332+
The atom types in the system
333+
334+
Returns
335+
-------
336+
mask : torch.Tensor
337+
Boolean mask where True indicates atoms of selected types
338+
"""
339+
# Ensure tensors are of the right type
340+
sel_type = sel_type.to(torch.long)
341+
atype = atype.to(torch.long)
342+
343+
# Create mask using broadcasting
344+
mask = torch.zeros_like(atype, dtype=torch.bool)
345+
for t in sel_type:
346+
mask = mask | (atype == t)
347+
return mask

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,11 @@ pin_pytorch_cpu = [
164164
# macos x86 has been deprecated
165165
"torch>=2.8,<2.10; platform_machine!='x86_64' or platform_system != 'Darwin'",
166166
"torch; platform_machine=='x86_64' and platform_system == 'Darwin'",
167+
"torch_admp @ git+https://github.com/chiahsinchu/torch-admp.git@v1.1.0a",
167168
]
168169
pin_pytorch_gpu = [
169170
"torch>=2.7,<2.10",
171+
"torch_admp @ git+https://github.com/chiahsinchu/torch-admp.git@v1.1.0a",
170172
]
171173
pin_jax = [
172174
"jax==0.5.0;python_version>='3.10'",
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later

source/tests/pt/test_data_modifier.py renamed to source/tests/pt/modifier/test_data_modifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
DeepmdData,
5353
)
5454

55-
from ..consistent.common import (
55+
from ...consistent.common import (
5656
parameterized,
5757
)
5858

0 commit comments

Comments
 (0)