Skip to content

Commit 037c2b8

Browse files
support assigning 'type_map' for mixed_type (#540)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b45911a commit 037c2b8

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

dpdata/deepmd/mixed.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,17 @@ def _load_set(folder, nopbc: bool):
4646
def to_system_data(folder, type_map=None, labels=True):
4747
# data is empty
4848
data = load_type(folder)
49+
old_type_map = data["atom_names"].copy()
50+
if type_map is not None:
51+
assert isinstance(type_map, list)
52+
missing_type = [i for i in old_type_map if i not in type_map]
53+
assert (
54+
not missing_type
55+
), f"These types are missing in selected type_map: {missing_type} !"
56+
index_map = np.array([type_map.index(i) for i in old_type_map])
57+
data["atom_names"] = type_map.copy()
58+
else:
59+
index_map = None
4960
data["orig"] = np.zeros([3])
5061
if os.path.isfile(os.path.join(folder, "nopbc")):
5162
data["nopbc"] = True
@@ -63,7 +74,12 @@ def to_system_data(folder, type_map=None, labels=True):
6374
nframes = np.reshape(cells, [-1, 3, 3]).shape[0]
6475
all_cells.append(np.reshape(cells, [nframes, 3, 3]))
6576
all_coords.append(np.reshape(coords, [nframes, -1, 3]))
66-
all_real_atom_types.append(np.reshape(real_atom_types, [nframes, -1]))
77+
if index_map is None:
78+
all_real_atom_types.append(np.reshape(real_atom_types, [nframes, -1]))
79+
else:
80+
all_real_atom_types.append(
81+
np.reshape(index_map[real_atom_types], [nframes, -1])
82+
)
6783
if eners is not None:
6884
eners = np.reshape(eners, [nframes])
6985
if labels:

0 commit comments

Comments
 (0)