Skip to content

Commit dedd1b8

Browse files
iProzdOutisLi
authored andcommitted
feat(pt): support spin virial
1 parent 9553e6e commit dedd1b8

File tree

12 files changed

+331
-40
lines changed

12 files changed

+331
-40
lines changed

deepmd/pt/loss/ener_spin.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,22 @@ def forward(
279279
rmse_ae.detach(), find_atom_ener
280280
)
281281

282+
if self.has_v and "virial" in model_pred and "virial" in label:
283+
find_virial = label.get("find_virial", 0.0)
284+
pref_v = pref_v * find_virial
285+
diff_v = label["virial"] - model_pred["virial"].reshape(-1, 9)
286+
l2_virial_loss = torch.mean(torch.square(diff_v))
287+
if not self.inference:
288+
more_loss["l2_virial_loss"] = self.display_if_exist(
289+
l2_virial_loss.detach(), find_virial
290+
)
291+
loss += atom_norm * (pref_v * l2_virial_loss)
292+
rmse_v = l2_virial_loss.sqrt() * atom_norm
293+
more_loss["rmse_v"] = self.display_if_exist(rmse_v.detach(), find_virial)
294+
if mae:
295+
mae_v = torch.mean(torch.abs(diff_v)) * atom_norm
296+
more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial)
297+
282298
if not self.inference:
283299
more_loss["rmse"] = torch.sqrt(loss.detach())
284300
return model_pred, loss, more_loss

deepmd/pt/model/model/make_model.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def forward_common(
138138
fparam: torch.Tensor | None = None,
139139
aparam: torch.Tensor | None = None,
140140
do_atomic_virial: bool = False,
141+
coord_corr_for_virial: torch.Tensor | None = None,
141142
) -> dict[str, torch.Tensor]:
142143
"""Return model prediction.
143144
@@ -156,6 +157,9 @@ def forward_common(
156157
atomic parameter. nf x nloc x nda
157158
do_atomic_virial
158159
If calculate the atomic virial.
160+
coord_corr_for_virial
161+
The coordinates correction of the atoms for virial.
162+
shape: nf x (nloc x 3)
159163
160164
Returns
161165
-------
@@ -183,6 +187,14 @@ def forward_common(
183187
mixed_types=True,
184188
box=bb,
185189
)
190+
if coord_corr_for_virial is not None:
191+
coord_corr_for_virial = coord_corr_for_virial.to(cc.dtype)
192+
extended_coord_corr = torch.gather(
193+
coord_corr_for_virial, 1, mapping.unsqueeze(-1).expand(-1, -1, 3)
194+
)
195+
else:
196+
extended_coord_corr = None
197+
186198
model_predict_lower = self.forward_common_lower(
187199
extended_coord,
188200
extended_atype,
@@ -191,6 +203,7 @@ def forward_common(
191203
do_atomic_virial=do_atomic_virial,
192204
fparam=fp,
193205
aparam=ap,
206+
extended_coord_corr=extended_coord_corr,
194207
)
195208
model_predict = communicate_extended_output(
196209
model_predict_lower,
@@ -247,6 +260,7 @@ def forward_common_lower(
247260
do_atomic_virial: bool = False,
248261
comm_dict: dict[str, torch.Tensor] | None = None,
249262
extra_nlist_sort: bool = False,
263+
extended_coord_corr: torch.Tensor | None = None,
250264
) -> dict[str, torch.Tensor]:
251265
"""Return model prediction. Lower interface that takes
252266
extended atomic coordinates and types, nlist, and mapping
@@ -273,6 +287,8 @@ def forward_common_lower(
273287
The data needed for communication for parallel inference.
274288
extra_nlist_sort
275289
whether to forcibly sort the nlist.
290+
extended_coord_corr
291+
coordinates correction for virial in extended region. nf x (nall x 3)
276292
277293
Returns
278294
-------
@@ -305,6 +321,7 @@ def forward_common_lower(
305321
do_atomic_virial=do_atomic_virial,
306322
create_graph=self.training,
307323
mask=atomic_ret["mask"] if "mask" in atomic_ret else None,
324+
extended_coord_corr=extended_coord_corr,
308325
)
309326
model_predict = self.output_type_cast(model_predict, input_prec)
310327
return model_predict

deepmd/pt/model/model/spin_model.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,14 @@ def process_spin_input(
5959
coord = coord.reshape(nframes, nloc, 3)
6060
spin = spin.reshape(nframes, nloc, 3)
6161
atype_spin = torch.concat([atype, atype + self.ntypes_real], dim=-1)
62-
virtual_coord = coord + spin * (self.virtual_scale_mask.to(atype.device))[
63-
atype
64-
].reshape([nframes, nloc, 1])
62+
spin_dist = spin * (self.virtual_scale_mask.to(atype.device))[atype].reshape(
63+
[nframes, nloc, 1]
64+
)
65+
virtual_coord = coord + spin_dist
6566
coord_spin = torch.concat([coord, virtual_coord], dim=-2)
66-
return coord_spin, atype_spin
67+
# for spin virial corr
68+
coord_corr = torch.concat([torch.zeros_like(coord), -spin_dist], dim=-2)
69+
return coord_spin, atype_spin, coord_corr
6770

6871
def process_spin_input_lower(
6972
self,
@@ -83,13 +86,18 @@ def process_spin_input_lower(
8386
"""
8487
nframes, nall = extended_coord.shape[:2]
8588
nloc = nlist.shape[1]
86-
virtual_extended_coord = extended_coord + extended_spin * (
89+
extended_spin_dist = extended_spin * (
8790
self.virtual_scale_mask.to(extended_atype.device)
8891
)[extended_atype].reshape([nframes, nall, 1])
92+
virtual_extended_coord = extended_coord + extended_spin_dist
8993
virtual_extended_atype = extended_atype + self.ntypes_real
9094
extended_coord_updated = concat_switch_virtual(
9195
extended_coord, virtual_extended_coord, nloc
9296
)
97+
# for spin virial corr
98+
extended_coord_corr = concat_switch_virtual(
99+
torch.zeros_like(extended_coord), -extended_spin_dist, nloc
100+
)
93101
extended_atype_updated = concat_switch_virtual(
94102
extended_atype, virtual_extended_atype, nloc
95103
)
@@ -105,6 +113,7 @@ def process_spin_input_lower(
105113
extended_atype_updated,
106114
nlist_updated,
107115
mapping_updated,
116+
extended_coord_corr,
108117
)
109118

110119
def process_spin_output(
@@ -376,7 +385,7 @@ def spin_sampled_func() -> list[dict[str, Any]]:
376385
sampled = sampled_func()
377386
spin_sampled = []
378387
for sys in sampled:
379-
coord_updated, atype_updated = self.process_spin_input(
388+
coord_updated, atype_updated, _ = self.process_spin_input(
380389
sys["coord"], sys["atype"], sys["spin"]
381390
)
382391
tmp_dict = {
@@ -407,7 +416,9 @@ def forward_common(
407416
do_atomic_virial: bool = False,
408417
) -> dict[str, torch.Tensor]:
409418
nframes, nloc = atype.shape
410-
coord_updated, atype_updated = self.process_spin_input(coord, atype, spin)
419+
coord_updated, atype_updated, coord_corr_for_virial = self.process_spin_input(
420+
coord, atype, spin
421+
)
411422
if aparam is not None:
412423
aparam = self.expand_aparam(aparam, nloc * 2)
413424
model_ret = self.backbone_model.forward_common(
@@ -417,6 +428,7 @@ def forward_common(
417428
fparam=fparam,
418429
aparam=aparam,
419430
do_atomic_virial=do_atomic_virial,
431+
coord_corr_for_virial=coord_corr_for_virial,
420432
)
421433
model_output_type = self.backbone_model.model_output_type()
422434
if "mask" in model_output_type:
@@ -463,6 +475,7 @@ def forward_common_lower(
463475
extended_atype_updated,
464476
nlist_updated,
465477
mapping_updated,
478+
extended_coord_corr_for_virial,
466479
) = self.process_spin_input_lower(
467480
extended_coord, extended_atype, extended_spin, nlist, mapping=mapping
468481
)
@@ -478,6 +491,7 @@ def forward_common_lower(
478491
do_atomic_virial=do_atomic_virial,
479492
comm_dict=comm_dict,
480493
extra_nlist_sort=extra_nlist_sort,
494+
extended_coord_corr=extended_coord_corr_for_virial,
481495
)
482496
model_output_type = self.backbone_model.model_output_type()
483497
if "mask" in model_output_type:
@@ -550,6 +564,11 @@ def translated_output_def(self) -> dict[str, Any]:
550564
output_def["force"].squeeze(-2)
551565
output_def["force_mag"] = deepcopy(out_def_data["energy_derv_r_mag"])
552566
output_def["force_mag"].squeeze(-2)
567+
if self.do_grad_c("energy"):
568+
output_def["virial"] = deepcopy(out_def_data["energy_derv_c_redu"])
569+
output_def["virial"].squeeze(-2)
570+
output_def["atom_virial"] = deepcopy(out_def_data["energy_derv_c"])
571+
output_def["atom_virial"].squeeze(-3)
553572
return output_def
554573

555574
def forward(
@@ -578,7 +597,10 @@ def forward(
578597
if self.backbone_model.do_grad_r("energy"):
579598
model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2)
580599
model_predict["force_mag"] = model_ret["energy_derv_r_mag"].squeeze(-2)
581-
# not support virial by far
600+
if self.backbone_model.do_grad_c("energy"):
601+
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
602+
if do_atomic_virial:
603+
model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3)
582604
return model_predict
583605

584606
@torch.jit.export
@@ -615,5 +637,10 @@ def forward_lower(
615637
model_predict["extended_force_mag"] = model_ret[
616638
"energy_derv_r_mag"
617639
].squeeze(-2)
618-
# not support virial by far
640+
if self.backbone_model.do_grad_c("energy"):
641+
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
642+
if do_atomic_virial:
643+
model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze(
644+
-3
645+
)
619646
return model_predict

deepmd/pt/model/model/transform_output.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def fit_output_to_model_output(
156156
do_atomic_virial: bool = False,
157157
create_graph: bool = True,
158158
mask: torch.Tensor | None = None,
159+
extended_coord_corr: torch.Tensor | None = None,
159160
) -> dict[str, torch.Tensor]:
160161
"""Transform the output of the fitting network to
161162
the model output.
@@ -192,6 +193,12 @@ def fit_output_to_model_output(
192193
model_ret[kk_derv_r] = dr
193194
if vdef.c_differentiable:
194195
assert dc is not None
196+
if extended_coord_corr is not None:
197+
dc_corr = (
198+
dr.squeeze(-2).unsqueeze(-1)
199+
@ extended_coord_corr.unsqueeze(-2)
200+
).view(list(dc.shape[:-2]) + [1, 9]) # noqa: RUF005
201+
dc = dc + dc_corr
195202
model_ret[kk_derv_c] = dc
196203
model_ret[kk_derv_c + "_redu"] = torch.sum(
197204
model_ret[kk_derv_c].to(redu_prec), dim=1

source/api_c/include/deepmd.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2602,9 +2602,9 @@ class DeepSpinModelDevi : public DeepBaseModelDevi {
26022602
for (int j = 0; j < natoms * 3; j++) {
26032603
force_mag[i][j] = force_mag_flat[i * natoms * 3 + j];
26042604
}
2605-
// for (int j = 0; j < 9; j++) {
2606-
// virial[i][j] = virial_flat[i * 9 + j];
2607-
// }
2605+
for (int j = 0; j < 9; j++) {
2606+
virial[i][j] = virial_flat[i * 9 + j];
2607+
}
26082608
}
26092609
};
26102610
/**
@@ -2705,9 +2705,9 @@ class DeepSpinModelDevi : public DeepBaseModelDevi {
27052705
for (int j = 0; j < natoms * 3; j++) {
27062706
force_mag[i][j] = force_mag_flat[i * natoms * 3 + j];
27072707
}
2708-
// for (int j = 0; j < 9; j++) {
2709-
// virial[i][j] = virial_flat[i * 9 + j];
2710-
// }
2708+
for (int j = 0; j < 9; j++) {
2709+
virial[i][j] = virial_flat[i * 9 + j];
2710+
}
27112711
for (int j = 0; j < natoms; j++) {
27122712
atom_energy[i][j] = atom_energy_flat[i * natoms + j];
27132713
}

source/api_c/src/c_api.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -862,11 +862,11 @@ void DP_DeepSpinModelDeviCompute_variant(DP_DeepSpinModelDevi* dp,
862862
flatten_vector(fm_flat, fm);
863863
std::copy(fm_flat.begin(), fm_flat.end(), force_mag);
864864
}
865-
// if (virial) {
866-
// std::vector<VALUETYPE> v_flat;
867-
// flatten_vector(v_flat, v);
868-
// std::copy(v_flat.begin(), v_flat.end(), virial);
869-
// }
865+
if (virial) {
866+
std::vector<VALUETYPE> v_flat;
867+
flatten_vector(v_flat, v);
868+
std::copy(v_flat.begin(), v_flat.end(), virial);
869+
}
870870
if (atomic_energy) {
871871
std::vector<VALUETYPE> ae_flat;
872872
flatten_vector(ae_flat, ae);

source/api_cc/src/DeepSpinPT.cc

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -251,8 +251,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
251251
c10::IValue energy_ = outputs.at("energy");
252252
c10::IValue force_ = outputs.at("extended_force");
253253
c10::IValue force_mag_ = outputs.at("extended_force_mag");
254-
// spin model not suported yet
255-
// c10::IValue virial_ = outputs.at("virial");
254+
c10::IValue virial_ = outputs.at("virial");
256255
torch::Tensor flat_energy_ = energy_.toTensor().view({-1});
257256
torch::Tensor cpu_energy_ = flat_energy_.to(torch::kCPU);
258257
ener.assign(cpu_energy_.data_ptr<ENERGYTYPE>(),
@@ -267,11 +266,11 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
267266
dforce_mag.assign(
268267
cpu_force_mag_.data_ptr<VALUETYPE>(),
269268
cpu_force_mag_.data_ptr<VALUETYPE>() + cpu_force_mag_.numel());
270-
// spin model not suported yet
271-
// torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType);
272-
// torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU);
273-
// virial.assign(cpu_virial_.data_ptr<VALUETYPE>(),
274-
// cpu_virial_.data_ptr<VALUETYPE>() + cpu_virial_.numel());
269+
270+
torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType);
271+
torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU);
272+
virial.assign(cpu_virial_.data_ptr<VALUETYPE>(),
273+
cpu_virial_.data_ptr<VALUETYPE>() + cpu_virial_.numel());
275274

276275
// bkw map
277276
force.resize(static_cast<size_t>(nframes) * fwd_map.size() * 3);
@@ -415,8 +414,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
415414
c10::IValue energy_ = outputs.at("energy");
416415
c10::IValue force_ = outputs.at("force");
417416
c10::IValue force_mag_ = outputs.at("force_mag");
418-
// spin model not suported yet
419-
// c10::IValue virial_ = outputs.at("virial");
417+
c10::IValue virial_ = outputs.at("virial");
420418
torch::Tensor flat_energy_ = energy_.toTensor().view({-1});
421419
torch::Tensor cpu_energy_ = flat_energy_.to(torch::kCPU);
422420
ener.assign(cpu_energy_.data_ptr<ENERGYTYPE>(),
@@ -431,11 +429,10 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
431429
force_mag.assign(
432430
cpu_force_mag_.data_ptr<VALUETYPE>(),
433431
cpu_force_mag_.data_ptr<VALUETYPE>() + cpu_force_mag_.numel());
434-
// spin model not suported yet
435-
// torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType);
436-
// torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU);
437-
// virial.assign(cpu_virial_.data_ptr<VALUETYPE>(),
438-
// cpu_virial_.data_ptr<VALUETYPE>() + cpu_virial_.numel());
432+
torch::Tensor flat_virial_ = virial_.toTensor().view({-1}).to(floatType);
433+
torch::Tensor cpu_virial_ = flat_virial_.to(torch::kCPU);
434+
virial.assign(cpu_virial_.data_ptr<VALUETYPE>(),
435+
cpu_virial_.data_ptr<VALUETYPE>() + cpu_virial_.numel());
439436
if (atomic) {
440437
// c10::IValue atom_virial_ = outputs.at("atom_virial");
441438
c10::IValue atom_energy_ = outputs.at("atom_energy");

source/tests/pt/model/test_autodiff.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,11 +141,17 @@ def test(
141141
cell = (cell) + 5.0 * torch.eye(3, device="cpu")
142142
coord = torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator)
143143
coord = torch.matmul(coord, cell)
144+
spin = torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator)
144145
atype = torch.IntTensor([0, 0, 0, 1, 1])
145146
# assumes input to be numpy tensor
146147
coord = coord.numpy()
148+
spin = spin.numpy()
147149
cell = cell.numpy()
148-
test_keys = ["energy", "force", "virial"]
150+
test_spin = getattr(self, "test_spin", False)
151+
if not test_spin:
152+
test_keys = ["energy", "force", "virial"]
153+
else:
154+
test_keys = ["energy", "force", "force_mag", "virial"]
149155

150156
def np_infer(
151157
new_cell,
@@ -157,6 +163,7 @@ def np_infer(
157163
).unsqueeze(0),
158164
torch.tensor(new_cell, device="cpu").unsqueeze(0),
159165
atype,
166+
spins=torch.tensor(spin, device=env.DEVICE).unsqueeze(0),
160167
)
161168
# detach
162169
ret = {key: to_numpy_array(result[key].squeeze(0)) for key in test_keys}
@@ -251,3 +258,11 @@ def setUp(self) -> None:
251258
self.type_split = False
252259
self.test_spin = True
253260
self.model = get_model(model_params).to(env.DEVICE)
261+
262+
263+
class TestEnergyModelSpinSeAVirial(unittest.TestCase, VirialTest):
264+
def setUp(self) -> None:
265+
model_params = copy.deepcopy(model_spin)
266+
self.type_split = False
267+
self.test_spin = True
268+
self.model = get_model(model_params).to(env.DEVICE)

source/tests/pt/model/test_ener_spin_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def test_input_output_process(self) -> None:
115115
nframes, nloc = self.coord.shape[:2]
116116
self.real_ntypes = self.model.spin.get_ntypes_real()
117117
# 1. test forward input process
118-
coord_updated, atype_updated = self.model.process_spin_input(
118+
coord_updated, atype_updated, _ = self.model.process_spin_input(
119119
self.coord, self.atype, self.spin
120120
)
121121
# compare atypes of real and virtual atoms
@@ -174,6 +174,7 @@ def test_input_output_process(self) -> None:
174174
extended_atype_updated,
175175
nlist_updated,
176176
mapping_updated,
177+
_,
177178
) = self.model.process_spin_input_lower(
178179
extended_coord, extended_atype, extended_spin, nlist, mapping=mapping
179180
)

0 commit comments

Comments
 (0)