Skip to content

Commit aa44098

Browse files
feat: handle masked forces in test (#4893)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - New Features - Added per-atom weighting for force evaluation: computes and reports weighted MAE/RMSE alongside unweighted metrics, includes weighted metrics in system-average summaries, logs weighted force metrics, and safely handles zero-weight cases. Also propagates the per-atom weight field into reporting. - Tests - Added end-to-end tests validating weighted vs unweighted force MAE/RMSE and verifying evaluator outputs when using per-atom weight masks. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 85ab1b9 commit aa44098

File tree

2 files changed

+121
-2
lines changed

2 files changed

+121
-2
lines changed

deepmd/entrypoints/test.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,7 @@ def test_ener(
291291

292292
data.add("energy", 1, atomic=False, must=False, high_prec=True)
293293
data.add("force", 3, atomic=True, must=False, high_prec=False)
294+
data.add("atom_pref", 1, atomic=True, must=False, high_prec=False, repeat=3)
294295
data.add("virial", 9, atomic=False, must=False, high_prec=False)
295296
if dp.has_efield:
296297
data.add("efield", 3, atomic=True, must=True, high_prec=False)
@@ -313,6 +314,7 @@ def test_ener(
313314
find_force = test_data.get("find_force")
314315
find_virial = test_data.get("find_virial")
315316
find_force_mag = test_data.get("find_force_mag")
317+
find_atom_pref = test_data.get("find_atom_pref")
316318
mixed_type = data.mixed_type
317319
natoms = len(test_data["type"][0])
318320
nframes = test_data["box"].shape[0]
@@ -419,6 +421,16 @@ def test_ener(
419421
diff_f = force - test_data["force"][:numb_test]
420422
mae_f = mae(diff_f)
421423
rmse_f = rmse(diff_f)
424+
size_f = diff_f.size
425+
if find_atom_pref == 1:
426+
atom_weight = test_data["atom_pref"][:numb_test]
427+
weight_sum = np.sum(atom_weight)
428+
if weight_sum > 0:
429+
mae_fw = np.sum(np.abs(diff_f) * atom_weight) / weight_sum
430+
rmse_fw = np.sqrt(np.sum(diff_f * diff_f * atom_weight) / weight_sum)
431+
else:
432+
mae_fw = 0.0
433+
rmse_fw = 0.0
422434
diff_v = virial - test_data["virial"][:numb_test]
423435
mae_v = mae(diff_v)
424436
rmse_v = rmse(diff_v)
@@ -453,8 +465,13 @@ def test_ener(
453465
if not out_put_spin and find_force == 1:
454466
log.info(f"Force MAE : {mae_f:e} eV/A")
455467
log.info(f"Force RMSE : {rmse_f:e} eV/A")
456-
dict_to_return["mae_f"] = (mae_f, force.size)
457-
dict_to_return["rmse_f"] = (rmse_f, force.size)
468+
dict_to_return["mae_f"] = (mae_f, size_f)
469+
dict_to_return["rmse_f"] = (rmse_f, size_f)
470+
if find_atom_pref == 1:
471+
log.info(f"Force weighted MAE : {mae_fw:e} eV/A")
472+
log.info(f"Force weighted RMSE: {rmse_fw:e} eV/A")
473+
dict_to_return["mae_fw"] = (mae_fw, weight_sum)
474+
dict_to_return["rmse_fw"] = (rmse_fw, weight_sum)
458475
if out_put_spin and find_force == 1:
459476
log.info(f"Force atom MAE : {mae_fr:e} eV/A")
460477
log.info(f"Force atom RMSE : {rmse_fr:e} eV/A")
@@ -600,6 +617,9 @@ def print_ener_sys_avg(avg: dict[str, float]) -> None:
600617
if "rmse_f" in avg:
601618
log.info(f"Force MAE : {avg['mae_f']:e} eV/A")
602619
log.info(f"Force RMSE : {avg['rmse_f']:e} eV/A")
620+
if "rmse_fw" in avg:
621+
log.info(f"Force weighted MAE : {avg['mae_fw']:e} eV/A")
622+
log.info(f"Force weighted RMSE: {avg['rmse_fw']:e} eV/A")
603623
else:
604624
log.info(f"Force atom MAE : {avg['mae_fr']:e} eV/A")
605625
log.info(f"Force spin MAE : {avg['mae_fm']:e} eV/uB")

source/tests/pt/test_dp_test.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,19 @@
1515
import torch
1616

1717
from deepmd.entrypoints.test import test as dp_test
18+
from deepmd.entrypoints.test import test_ener as dp_test_ener
19+
from deepmd.infer.deep_eval import (
20+
DeepEval,
21+
)
1822
from deepmd.pt.entrypoints.main import (
1923
get_trainer,
2024
)
2125
from deepmd.pt.utils.utils import (
2226
to_numpy_array,
2327
)
28+
from deepmd.utils.data import (
29+
DeepmdData,
30+
)
2431

2532
from .model.test_permutation import (
2633
model_property,
@@ -140,6 +147,98 @@ def setUp(self) -> None:
140147
json.dump(self.config, fp, indent=4)
141148

142149

150+
class TestDPTestForceWeight(DPTest, unittest.TestCase):
151+
def setUp(self) -> None:
152+
self.detail_file = "test_dp_test_force_weight_detail"
153+
input_json = str(Path(__file__).parent / "water/se_atten.json")
154+
with open(input_json) as f:
155+
self.config = json.load(f)
156+
self.config["training"]["numb_steps"] = 1
157+
self.config["training"]["save_freq"] = 1
158+
system_dir = self._prepare_weighted_system()
159+
data_file = [system_dir]
160+
self.config["training"]["training_data"]["systems"] = data_file
161+
self.config["training"]["validation_data"]["systems"] = data_file
162+
self.config["model"] = deepcopy(model_se_e2_a)
163+
self.system_dir = system_dir
164+
self.input_json = "test_dp_test_force_weight.json"
165+
with open(self.input_json, "w") as fp:
166+
json.dump(self.config, fp, indent=4)
167+
168+
def _prepare_weighted_system(self) -> str:
169+
src = Path(__file__).parent / "water/data/single"
170+
tmp_dir = tempfile.mkdtemp()
171+
shutil.copytree(src, tmp_dir, dirs_exist_ok=True)
172+
set_dir = Path(tmp_dir) / "set.000"
173+
forces = np.load(set_dir / "force.npy")
174+
forces[0, :3] += 1.0
175+
forces[0, -3:] += 10.0
176+
np.save(set_dir / "force.npy", forces)
177+
natoms = forces.shape[1] // 3
178+
atom_pref = np.ones((forces.shape[0], natoms), dtype=forces.dtype)
179+
atom_pref[:, 0] = 2.0
180+
atom_pref[:, -1] = 0.0
181+
np.save(set_dir / "atom_pref.npy", atom_pref)
182+
return tmp_dir
183+
184+
def test_force_weight(self) -> None:
185+
trainer = get_trainer(deepcopy(self.config))
186+
with torch.device("cpu"):
187+
trainer.get_data(is_train=False)
188+
model = torch.jit.script(trainer.model)
189+
tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth")
190+
torch.jit.save(model, tmp_model.name)
191+
dp = DeepEval(tmp_model.name)
192+
data = DeepmdData(
193+
self.system_dir,
194+
set_prefix="set",
195+
shuffle_test=False,
196+
type_map=dp.get_type_map(),
197+
sort_atoms=False,
198+
)
199+
err = dp_test_ener(
200+
dp,
201+
data,
202+
self.system_dir,
203+
numb_test=1,
204+
detail_file=None,
205+
has_atom_ener=False,
206+
)
207+
test_data = data.get_test()
208+
coord = test_data["coord"].reshape([1, -1])
209+
box = test_data["box"][:1]
210+
atype = test_data["type"][0]
211+
ret = dp.eval(
212+
coord,
213+
box,
214+
atype,
215+
fparam=None,
216+
aparam=None,
217+
atomic=False,
218+
efield=None,
219+
mixed_type=False,
220+
spin=None,
221+
)
222+
force_pred = ret[1].reshape([1, -1])
223+
force_true = test_data["force"][:1]
224+
weight = test_data["atom_pref"][:1]
225+
diff = force_pred - force_true
226+
mae_unweighted = np.sum(np.abs(diff)) / diff.size
227+
rmse_unweighted = np.sqrt(np.sum(diff * diff) / diff.size)
228+
denom = weight.sum()
229+
mae_weighted = np.sum(np.abs(diff) * weight) / denom
230+
rmse_weighted = np.sqrt(np.sum(diff * diff * weight) / denom)
231+
np.testing.assert_allclose(err["mae_f"][0], mae_unweighted)
232+
np.testing.assert_allclose(err["rmse_f"][0], rmse_unweighted)
233+
np.testing.assert_allclose(err["mae_fw"][0], mae_weighted)
234+
np.testing.assert_allclose(err["rmse_fw"][0], rmse_weighted)
235+
os.unlink(tmp_model.name)
236+
237+
def tearDown(self) -> None:
238+
super().tearDown()
239+
shutil.rmtree(self.system_dir)
240+
241+
143242
class TestDPTestPropertySeA(unittest.TestCase):
144243
def setUp(self) -> None:
145244
self.detail_file = "test_dp_test_property_detail"

0 commit comments

Comments
 (0)