-
Notifications
You must be signed in to change notification settings - Fork 599
Expand file tree
/
Copy pathener_spin.py
More file actions
356 lines (341 loc) · 14.8 KB
/
ener_spin.py
File metadata and controls
356 lines (341 loc) · 14.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
)
import torch
import torch.nn.functional as F
from deepmd.pt.loss.loss import (
TaskLoss,
)
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.env import (
GLOBAL_PT_FLOAT_PRECISION,
)
from deepmd.utils.data import (
DataRequirementItem,
)
class EnergySpinLoss(TaskLoss):
    """Loss for spin models: energy, real force, magnetic force, virial, atomic energy.

    Each term is weighted by a prefactor that is linearly interpolated
    between its ``start_pref_*`` and ``limit_pref_*`` values according to
    the ratio of the current learning rate to the starter learning rate
    (see ``forward``). A term is only computed when both its prefactors
    are nonzero, unless ``inference`` is set, in which case every term
    present in the model output / labels is reported.
    """

    def __init__(
        self,
        starter_learning_rate: float = 1.0,
        start_pref_e: float = 0.0,
        limit_pref_e: float = 0.0,
        start_pref_fr: float = 0.0,
        limit_pref_fr: float = 0.0,
        start_pref_fm: float = 0.0,
        limit_pref_fm: float = 0.0,
        start_pref_v: float = 0.0,
        limit_pref_v: float = 0.0,
        start_pref_ae: float = 0.0,
        limit_pref_ae: float = 0.0,
        enable_atom_ener_coeff: bool = False,
        use_l1_all: bool = False,
        inference: bool = False,
        **kwargs: Any,
    ) -> None:
        r"""Construct a layer to compute loss on energy, real force, magnetic force and virial.

        Parameters
        ----------
        starter_learning_rate : float
            The learning rate at the start of the training.
        start_pref_e : float
            The prefactor of energy loss at the start of the training.
        limit_pref_e : float
            The prefactor of energy loss at the end of the training.
        start_pref_fr : float
            The prefactor of real force loss at the start of the training.
        limit_pref_fr : float
            The prefactor of real force loss at the end of the training.
        start_pref_fm : float
            The prefactor of magnetic force loss at the start of the training.
        limit_pref_fm : float
            The prefactor of magnetic force loss at the end of the training.
        start_pref_v : float
            The prefactor of virial loss at the start of the training.
        limit_pref_v : float
            The prefactor of virial loss at the end of the training.
        start_pref_ae : float
            The prefactor of atomic energy loss at the start of the training.
        limit_pref_ae : float
            The prefactor of atomic energy loss at the end of the training.
        enable_atom_ener_coeff : bool
            if true, the energy will be computed as \sum_i c_i E_i
        use_l1_all : bool
            Whether to use L1 loss, if False (default), it will use L2 loss.
        inference : bool
            If true, it will output all losses found in output, ignoring the pre-factors.
        **kwargs
            Other keyword arguments.
        """
        super().__init__()
        self.starter_learning_rate = starter_learning_rate
        # A term is active only when BOTH its start and limit prefactors are
        # nonzero; in inference mode every term is reported regardless.
        self.has_e = (start_pref_e != 0.0 and limit_pref_e != 0.0) or inference
        self.has_fr = (start_pref_fr != 0.0 and limit_pref_fr != 0.0) or inference
        self.has_fm = (start_pref_fm != 0.0 and limit_pref_fm != 0.0) or inference
        self.has_v = (start_pref_v != 0.0 and limit_pref_v != 0.0) or inference
        self.has_ae = (start_pref_ae != 0.0 and limit_pref_ae != 0.0) or inference
        self.start_pref_e = start_pref_e
        self.limit_pref_e = limit_pref_e
        self.start_pref_fr = start_pref_fr
        self.limit_pref_fr = limit_pref_fr
        self.start_pref_fm = start_pref_fm
        self.limit_pref_fm = limit_pref_fm
        self.start_pref_v = start_pref_v
        self.limit_pref_v = limit_pref_v
        self.start_pref_ae = start_pref_ae
        self.limit_pref_ae = limit_pref_ae
        self.enable_atom_ener_coeff = enable_atom_ener_coeff
        self.use_l1_all = use_l1_all
        self.inference = inference

    def forward(
        self,
        input_dict: dict[str, torch.Tensor],
        model: torch.nn.Module,
        label: dict[str, torch.Tensor],
        natoms: int,
        learning_rate: float,
        mae: bool = False,
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor, dict[str, torch.Tensor]]:
        """Return energy loss with magnetic labels.

        Parameters
        ----------
        input_dict : dict[str, torch.Tensor]
            Model inputs.
        model : torch.nn.Module
            Model to be used to output the predictions.
        label : dict[str, torch.Tensor]
            Labels.
        natoms : int
            The local atom number.
        learning_rate : float
            The current learning rate; sets the interpolation point of the
            per-term prefactors between their start and limit values.
        mae : bool
            If true, additionally record mean-absolute-error metrics in
            ``more_loss`` for display.

        Returns
        -------
        model_pred: dict[str, torch.Tensor]
            Model predictions.
        loss: torch.Tensor
            Loss for model to minimize.
        more_loss: dict[str, torch.Tensor]
            Other losses for display.
        """
        model_pred = model(**input_dict)
        # Linear interpolation of each prefactor in coef = lr / lr_start:
        # pref == start_pref at lr_start, and -> limit_pref as lr -> 0.
        coef = learning_rate / self.starter_learning_rate
        pref_e = self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * coef
        pref_fr = self.limit_pref_fr + (self.start_pref_fr - self.limit_pref_fr) * coef
        pref_fm = self.limit_pref_fm + (self.start_pref_fm - self.limit_pref_fm) * coef
        pref_v = self.limit_pref_v + (self.start_pref_v - self.limit_pref_v) * coef
        pref_ae = self.limit_pref_ae + (self.start_pref_ae - self.limit_pref_ae) * coef
        loss = torch.tensor(0.0, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)
        more_loss = {}
        # more_loss['log_keys'] = []  # showed when validation on the fly
        # more_loss['test_keys'] = []  # showed when doing dp test
        # Energy and virial contributions are normalized per atom.
        atom_norm = 1.0 / natoms
        if self.has_e and "energy" in model_pred and "energy" in label:
            energy_pred = model_pred["energy"]
            energy_label = label["energy"]
            if self.enable_atom_ener_coeff and "atom_energy" in model_pred:
                atom_ener_pred = model_pred["atom_energy"]
                # when ener_coeff (\nu) is defined, the energy is defined as
                # E = \sum_i \nu_i E_i
                # instead of the sum of atomic energies.
                #
                # A case is that we want to train reaction energy
                # A + B -> C + D
                # E = - E(A) - E(B) + E(C) + E(D)
                # A, B, C, D could be put far away from each other
                atom_ener_coeff = label["atom_ener_coeff"]
                atom_ener_coeff = atom_ener_coeff.reshape(atom_ener_pred.shape)
                energy_pred = torch.sum(atom_ener_coeff * atom_ener_pred, dim=1)
            # find_* flags mark whether the label exists in the data; a
            # missing label zeroes the prefactor so the term contributes 0.
            find_energy = label.get("find_energy", 0.0)
            pref_e = pref_e * find_energy
            if not self.use_l1_all:
                l2_ener_loss = torch.mean(torch.square(energy_pred - energy_label))
                if not self.inference:
                    more_loss["l2_ener_loss"] = self.display_if_exist(
                        l2_ener_loss.detach(), find_energy
                    )
                loss += atom_norm * (pref_e * l2_ener_loss)
                rmse_e = l2_ener_loss.sqrt() * atom_norm
                more_loss["rmse_e"] = self.display_if_exist(
                    rmse_e.detach(), find_energy
                )
                # more_loss['log_keys'].append('rmse_e')
            else:  # use l1 and for all atoms
                # Sum-reduced L1 enters the loss; mean-reduced L1 is
                # recorded separately for display.
                l1_ener_loss = F.l1_loss(
                    energy_pred.reshape(-1),
                    energy_label.reshape(-1),
                    reduction="sum",
                )
                loss += pref_e * l1_ener_loss
                more_loss["mae_e"] = self.display_if_exist(
                    F.l1_loss(
                        energy_pred.reshape(-1),
                        energy_label.reshape(-1),
                        reduction="mean",
                    ).detach(),
                    find_energy,
                )
                # more_loss['log_keys'].append('rmse_e')
            if mae:
                mae_e = torch.mean(torch.abs(energy_pred - energy_label)) * atom_norm
                more_loss["mae_e"] = self.display_if_exist(mae_e.detach(), find_energy)
                mae_e_all = torch.mean(torch.abs(energy_pred - energy_label))
                more_loss["mae_e_all"] = self.display_if_exist(
                    mae_e_all.detach(), find_energy
                )
        if self.has_fr and "force" in model_pred and "force" in label:
            find_force_r = label.get("find_force", 0.0)
            pref_fr = pref_fr * find_force_r
            if not self.use_l1_all:
                diff_fr = label["force"] - model_pred["force"]
                l2_force_real_loss = torch.mean(torch.square(diff_fr))
                if not self.inference:
                    more_loss["l2_force_r_loss"] = self.display_if_exist(
                        l2_force_real_loss.detach(), find_force_r
                    )
                loss += (pref_fr * l2_force_real_loss).to(GLOBAL_PT_FLOAT_PRECISION)
                rmse_fr = l2_force_real_loss.sqrt()
                more_loss["rmse_fr"] = self.display_if_exist(
                    rmse_fr.detach(), find_force_r
                )
                if mae:
                    mae_fr = torch.mean(torch.abs(diff_fr))
                    more_loss["mae_fr"] = self.display_if_exist(
                        mae_fr.detach(), find_force_r
                    )
            else:
                l1_force_real_loss = F.l1_loss(
                    label["force"], model_pred["force"], reduction="none"
                )
                more_loss["mae_fr"] = self.display_if_exist(
                    l1_force_real_loss.mean().detach(), find_force_r
                )
                # Sum over components, mean over atoms, sum over frames.
                l1_force_real_loss = l1_force_real_loss.sum(-1).mean(-1).sum()
                loss += (pref_fr * l1_force_real_loss).to(GLOBAL_PT_FLOAT_PRECISION)
        if self.has_fm and "force_mag" in model_pred and "force_mag" in label:
            find_force_m = label.get("find_force_mag", 0.0)
            pref_fm = pref_fm * find_force_m
            nframes = model_pred["force_mag"].shape[0]
            # mask_mag selects the atoms carrying a magnetic moment; the
            # expand broadcasts it over the 3 force components.
            # NOTE(review): assumes mask_mag is (nframes, nloc, 1) boolean
            # and every frame keeps the same number of magnetic atoms
            # (required for the view back to (nframes, -1, 3)) — confirm.
            atomic_mask = model_pred["mask_mag"].expand([-1, -1, 3])
            label_force_mag = label["force_mag"][atomic_mask].view(nframes, -1, 3)
            model_pred_force_mag = model_pred["force_mag"][atomic_mask].view(
                nframes, -1, 3
            )
            if not self.use_l1_all:
                diff_fm = label_force_mag - model_pred_force_mag
                l2_force_mag_loss = torch.mean(torch.square(diff_fm))
                if not self.inference:
                    more_loss["l2_force_m_loss"] = self.display_if_exist(
                        l2_force_mag_loss.detach(), find_force_m
                    )
                loss += (pref_fm * l2_force_mag_loss).to(GLOBAL_PT_FLOAT_PRECISION)
                rmse_fm = l2_force_mag_loss.sqrt()
                more_loss["rmse_fm"] = self.display_if_exist(
                    rmse_fm.detach(), find_force_m
                )
                if mae:
                    mae_fm = torch.mean(torch.abs(diff_fm))
                    more_loss["mae_fm"] = self.display_if_exist(
                        mae_fm.detach(), find_force_m
                    )
            else:
                l1_force_mag_loss = F.l1_loss(
                    label_force_mag, model_pred_force_mag, reduction="none"
                )
                more_loss["mae_fm"] = self.display_if_exist(
                    l1_force_mag_loss.mean().detach(), find_force_m
                )
                l1_force_mag_loss = l1_force_mag_loss.sum(-1).mean(-1).sum()
                loss += (pref_fm * l1_force_mag_loss).to(GLOBAL_PT_FLOAT_PRECISION)
        if self.has_ae and "atom_energy" in model_pred and "atom_ener" in label:
            atom_ener = model_pred["atom_energy"]
            atom_ener_label = label["atom_ener"]
            find_atom_ener = label.get("find_atom_ener", 0.0)
            pref_ae = pref_ae * find_atom_ener
            atom_ener_reshape = atom_ener.reshape(-1)
            atom_ener_label_reshape = atom_ener_label.reshape(-1)
            l2_atom_ener_loss = torch.square(
                atom_ener_label_reshape - atom_ener_reshape
            ).mean()
            if not self.inference:
                more_loss["l2_atom_ener_loss"] = self.display_if_exist(
                    l2_atom_ener_loss.detach(), find_atom_ener
                )
            loss += (pref_ae * l2_atom_ener_loss).to(GLOBAL_PT_FLOAT_PRECISION)
            rmse_ae = l2_atom_ener_loss.sqrt()
            more_loss["rmse_ae"] = self.display_if_exist(
                rmse_ae.detach(), find_atom_ener
            )
        if self.has_v and "virial" in model_pred and "virial" in label:
            find_virial = label.get("find_virial", 0.0)
            pref_v = pref_v * find_virial
            # Virial is flattened to 9 components per frame for comparison.
            diff_v = label["virial"] - model_pred["virial"].reshape(-1, 9)
            l2_virial_loss = torch.mean(torch.square(diff_v))
            if not self.inference:
                more_loss["l2_virial_loss"] = self.display_if_exist(
                    l2_virial_loss.detach(), find_virial
                )
            loss += atom_norm * (pref_v * l2_virial_loss)
            rmse_v = l2_virial_loss.sqrt() * atom_norm
            more_loss["rmse_v"] = self.display_if_exist(rmse_v.detach(), find_virial)
            if mae:
                mae_v = torch.mean(torch.abs(diff_v)) * atom_norm
                more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial)
        if not self.inference:
            more_loss["rmse"] = torch.sqrt(loss.detach())
        return model_pred, loss, more_loss

    @property
    def label_requirement(self) -> list[DataRequirementItem]:
        """Return data label requirements needed for this loss calculation."""
        # Only request labels for the active loss terms; none is mandatory
        # (must=False), so frames lacking a label are handled via find_* flags.
        label_requirement = []
        if self.has_e:
            label_requirement.append(
                DataRequirementItem(
                    "energy",
                    ndof=1,
                    atomic=False,
                    must=False,
                    high_prec=True,
                )
            )
        if self.has_fr:
            label_requirement.append(
                DataRequirementItem(
                    "force",
                    ndof=3,
                    atomic=True,
                    must=False,
                    high_prec=False,
                )
            )
        if self.has_fm:
            label_requirement.append(
                DataRequirementItem(
                    "force_mag",
                    ndof=3,
                    atomic=True,
                    must=False,
                    high_prec=False,
                )
            )
        if self.has_v:
            label_requirement.append(
                DataRequirementItem(
                    "virial",
                    ndof=9,
                    atomic=False,
                    must=False,
                    high_prec=False,
                )
            )
        if self.has_ae:
            label_requirement.append(
                DataRequirementItem(
                    "atom_ener",
                    ndof=1,
                    atomic=True,
                    must=False,
                    high_prec=False,
                )
            )
        return label_requirement