Skip to content

Commit 4e5ab18

Browse files
fix: add optional force check (#780)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Bug Fixes** - Enhanced error handling when processing forces in various data formats. - Added conditional checks to prevent potential runtime errors when force data is missing. - Improved robustness of data conversion methods across multiple plugins. - **Refactor** - Streamlined data handling for optional force and virial information. - Implemented safer data extraction methods in ASE, PWmat, and VASP plugins. - Corrected a typographical error in the documentation of the driver methods. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b826633 commit 4e5ab18

File tree

5 files changed

+19
-7
lines changed

5 files changed

+19
-7
lines changed

dpdata/ase_calculator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ def calculate(
6262
self.results["energy"] = data["energies"][0]
6363
# see https://gitlab.com/ase/ase/-/merge_requests/2485
6464
self.results["free_energy"] = data["energies"][0]
65-
self.results["forces"] = data["forces"][0]
65+
if "forces" in data:
66+
self.results["forces"] = data["forces"][0]
6667
if "virials" in data:
6768
self.results["virial"] = data["virials"][0].reshape(3, 3)
6869

dpdata/driver.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,8 @@ def label(self, data: dict) -> dict:
166166
labeled_data = lb_data.copy()
167167
else:
168168
labeled_data["energies"] += lb_data["energies"]
169-
labeled_data["forces"] += lb_data["forces"]
169+
if "forces" in labeled_data and "forces" in lb_data:
170+
labeled_data["forces"] += lb_data["forces"]
170171
if "virials" in labeled_data and "virials" in lb_data:
171172
labeled_data["virials"] += lb_data["virials"]
172173
return labeled_data

dpdata/plugins/ase.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,9 @@ def to_labeled_system(self, data, *args, **kwargs) -> list[ase.Atoms]:
175175
cell=data["cells"][ii],
176176
)
177177

178-
results = {"energy": data["energies"][ii], "forces": data["forces"][ii]}
178+
results = {"energy": data["energies"][ii]}
179+
if "forces" in data:
180+
results["forces"] = data["forces"][ii]
179181
if "virials" in data:
180182
# convert to GPa as this is ase convention
181183
# v_pref = 1 * 1e4 / 1.602176621e6
@@ -296,7 +298,10 @@ def from_labeled_system(
296298
dict_frames["energies"] = np.append(
297299
dict_frames["energies"], tmp["energies"][0]
298300
)
299-
dict_frames["forces"] = np.append(dict_frames["forces"], tmp["forces"][0])
301+
if "forces" in tmp.keys() and "forces" in dict_frames.keys():
302+
dict_frames["forces"] = np.append(
303+
dict_frames["forces"], tmp["forces"][0]
304+
)
300305
if "virials" in tmp.keys() and "virials" in dict_frames.keys():
301306
dict_frames["virials"] = np.append(
302307
dict_frames["virials"], tmp["virials"][0]
@@ -305,7 +310,8 @@ def from_labeled_system(
305310
## Correct the shape of numpy arrays
306311
dict_frames["cells"] = dict_frames["cells"].reshape(-1, 3, 3)
307312
dict_frames["coords"] = dict_frames["coords"].reshape(len(sub_traj), -1, 3)
308-
dict_frames["forces"] = dict_frames["forces"].reshape(len(sub_traj), -1, 3)
313+
if "forces" in dict_frames.keys():
314+
dict_frames["forces"] = dict_frames["forces"].reshape(len(sub_traj), -1, 3)
309315
if "virials" in dict_frames.keys():
310316
dict_frames["virials"] = dict_frames["virials"].reshape(-1, 3, 3)
311317

dpdata/plugins/pwmat.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,13 @@ def from_labeled_system(
3131
data["cells"],
3232
data["coords"],
3333
data["energies"],
34-
data["forces"],
34+
tmp_force,
3535
tmp_virial,
3636
) = dpdata.pwmat.movement.get_frames(
3737
file_name, begin=begin, step=step, convergence_check=convergence_check
3838
)
39+
if tmp_force is not None:
40+
data["forces"] = tmp_force
3941
if tmp_virial is not None:
4042
data["virials"] = tmp_virial
4143
# scale virial to the unit of eV

dpdata/plugins/vasp.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def from_labeled_system(
9595
data["cells"],
9696
data["coords"],
9797
data["energies"],
98-
data["forces"],
98+
tmp_force,
9999
tmp_virial,
100100
) = dpdata.vasp.outcar.get_frames(
101101
file_name,
@@ -104,6 +104,8 @@ def from_labeled_system(
104104
ml=ml,
105105
convergence_check=convergence_check,
106106
)
107+
if tmp_force is not None:
108+
data["forces"] = tmp_force
107109
if tmp_virial is not None:
108110
data["virials"] = tmp_virial
109111
# scale virial to the unit of eV

0 commit comments

Comments
 (0)