@@ -46,6 +46,17 @@ def _load_set(folder, nopbc: bool):
46
46
def to_system_data (folder , type_map = None , labels = True ):
47
47
# data is empty
48
48
data = load_type (folder )
49
+ old_type_map = data ["atom_names" ].copy ()
50
+ if type_map is not None :
51
+ assert isinstance (type_map , list )
52
+ missing_type = [i for i in old_type_map if i not in type_map ]
53
+ assert (
54
+ not missing_type
55
+ ), f"These types are missing in selected type_map: { missing_type } !"
56
+ index_map = np .array ([type_map .index (i ) for i in old_type_map ])
57
+ data ["atom_names" ] = type_map .copy ()
58
+ else :
59
+ index_map = None
49
60
data ["orig" ] = np .zeros ([3 ])
50
61
if os .path .isfile (os .path .join (folder , "nopbc" )):
51
62
data ["nopbc" ] = True
@@ -63,7 +74,12 @@ def to_system_data(folder, type_map=None, labels=True):
63
74
nframes = np .reshape (cells , [- 1 , 3 , 3 ]).shape [0 ]
64
75
all_cells .append (np .reshape (cells , [nframes , 3 , 3 ]))
65
76
all_coords .append (np .reshape (coords , [nframes , - 1 , 3 ]))
66
- all_real_atom_types .append (np .reshape (real_atom_types , [nframes , - 1 ]))
77
+ if index_map is None :
78
+ all_real_atom_types .append (np .reshape (real_atom_types , [nframes , - 1 ]))
79
+ else :
80
+ all_real_atom_types .append (
81
+ np .reshape (index_map [real_atom_types ], [nframes , - 1 ])
82
+ )
67
83
if eners is not None :
68
84
eners = np .reshape (eners , [nframes ])
69
85
if labels :
0 commit comments