Skip to content

Commit 46251a7

Browse files
anyangmlpre-commit-ci[bot]wanghan-iapcm
authored
Feat: set force label optional (#772)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Improvements** - Enhanced flexibility in data handling by making forces data optional in the system configuration. - Added a method to check for the presence of forces data in the system. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Han Wang <[email protected]>
1 parent 2905792 commit 46251a7

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

dpdata/system.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,7 +1209,11 @@ class LabeledSystem(System):
12091209
DTYPES: tuple[DataType, ...] = System.DTYPES + (
12101210
DataType("energies", np.ndarray, (Axis.NFRAMES,), deepmd_name="energy"),
12111211
DataType(
1212-
"forces", np.ndarray, (Axis.NFRAMES, Axis.NATOMS, 3), deepmd_name="force"
1212+
"forces",
1213+
np.ndarray,
1214+
(Axis.NFRAMES, Axis.NATOMS, 3),
1215+
required=False,
1216+
deepmd_name="force",
12131217
),
12141218
DataType(
12151219
"virials",
@@ -1269,13 +1273,17 @@ def __add__(self, others):
12691273
raise RuntimeError("Unspported data structure")
12701274
return self.__class__.from_dict({"data": self_copy.data})
12711275

1276+
def has_forces(self) -> bool:
1277+
return "forces" in self.data
1278+
12721279
def has_virial(self) -> bool:
12731280
# return ('virials' in self.data) and (len(self.data['virials']) > 0)
12741281
return "virials" in self.data
12751282

12761283
def affine_map_fv(self, trans, f_idx: int | numbers.Integral):
12771284
assert np.linalg.det(trans) != 0
1278-
self.data["forces"][f_idx] = np.matmul(self.data["forces"][f_idx], trans)
1285+
if self.has_forces():
1286+
self.data["forces"][f_idx] = np.matmul(self.data["forces"][f_idx], trans)
12791287
if self.has_virial():
12801288
self.data["virials"][f_idx] = np.matmul(
12811289
trans.T, np.matmul(self.data["virials"][f_idx], trans)
@@ -1308,7 +1316,8 @@ def correction(self, hl_sys: LabeledSystem) -> LabeledSystem:
13081316
raise RuntimeError("high_sys should be LabeledSystem")
13091317
corrected_sys = self.copy()
13101318
corrected_sys.data["energies"] = hl_sys.data["energies"] - self.data["energies"]
1311-
corrected_sys.data["forces"] = hl_sys.data["forces"] - self.data["forces"]
1319+
if "forces" in self.data and "forces" in hl_sys.data:
1320+
corrected_sys.data["forces"] = hl_sys.data["forces"] - self.data["forces"]
13121321
if "virials" in self.data and "virials" in hl_sys.data:
13131322
corrected_sys.data["virials"] = (
13141323
hl_sys.data["virials"] - self.data["virials"]

0 commit comments

Comments
 (0)