|
| 1 | +import glob |
| 2 | +import os |
| 3 | +import shutil |
| 4 | + |
| 5 | +import numpy as np |
| 6 | + |
| 7 | + |
| 8 | +def load_type(folder): |
| 9 | + data = {} |
| 10 | + data["atom_names"] = [] |
| 11 | + # if find type_map.raw, use it |
| 12 | + assert os.path.isfile( |
| 13 | + os.path.join(folder, "type_map.raw") |
| 14 | + ), "Mixed type system must have type_map.raw!" |
| 15 | + with open(os.path.join(folder, "type_map.raw")) as fp: |
| 16 | + data["atom_names"] = fp.read().split() |
| 17 | + |
| 18 | + return data |
| 19 | + |
| 20 | + |
| 21 | +def formula(atom_names, atom_numbs): |
| 22 | + """ |
| 23 | + Return the formula of this system, like C3H5O2 |
| 24 | + """ |
| 25 | + return "".join( |
| 26 | + ["{}{}".format(symbol, numb) for symbol, numb in zip(atom_names, atom_numbs)] |
| 27 | + ) |
| 28 | + |
| 29 | + |
| 30 | +def _cond_load_data(fname): |
| 31 | + tmp = None |
| 32 | + if os.path.isfile(fname): |
| 33 | + tmp = np.load(fname) |
| 34 | + return tmp |
| 35 | + |
| 36 | + |
| 37 | +def _load_set(folder, nopbc: bool): |
| 38 | + coords = np.load(os.path.join(folder, "coord.npy")) |
| 39 | + if nopbc: |
| 40 | + cells = np.zeros((coords.shape[0], 3, 3)) |
| 41 | + else: |
| 42 | + cells = np.load(os.path.join(folder, "box.npy")) |
| 43 | + eners = _cond_load_data(os.path.join(folder, "energy.npy")) |
| 44 | + forces = _cond_load_data(os.path.join(folder, "force.npy")) |
| 45 | + virs = _cond_load_data(os.path.join(folder, "virial.npy")) |
| 46 | + real_atom_types = np.load(os.path.join(folder, "real_atom_types.npy")) |
| 47 | + return cells, coords, eners, forces, virs, real_atom_types |
| 48 | + |
| 49 | + |
| 50 | +def to_system_data(folder, type_map=None, labels=True): |
| 51 | + # data is empty |
| 52 | + data = load_type(folder) |
| 53 | + data["orig"] = np.zeros([3]) |
| 54 | + if os.path.isfile(os.path.join(folder, "nopbc")): |
| 55 | + data["nopbc"] = True |
| 56 | + sets = sorted(glob.glob(os.path.join(folder, "set.*"))) |
| 57 | + assert len(sets) == 1, "Mixed type must have only one set!" |
| 58 | + cells, coords, eners, forces, virs, real_atom_types = _load_set( |
| 59 | + sets[0], data.get("nopbc", False) |
| 60 | + ) |
| 61 | + nframes = np.reshape(cells, [-1, 3, 3]).shape[0] |
| 62 | + cells = np.reshape(cells, [nframes, 3, 3]) |
| 63 | + coords = np.reshape(coords, [nframes, -1, 3]) |
| 64 | + real_atom_types = np.reshape(real_atom_types, [nframes, -1]) |
| 65 | + natom = real_atom_types.shape[1] |
| 66 | + if labels: |
| 67 | + if eners is not None and eners.size > 0: |
| 68 | + eners = np.reshape(eners, [nframes]) |
| 69 | + if forces is not None and forces.size > 0: |
| 70 | + forces = np.reshape(forces, [nframes, -1, 3]) |
| 71 | + if virs is not None and virs.size > 0: |
| 72 | + virs = np.reshape(virs, [nframes, 3, 3]) |
| 73 | + data_list = [] |
| 74 | + while True: |
| 75 | + if real_atom_types.size == 0: |
| 76 | + break |
| 77 | + temp_atom_numbs = [ |
| 78 | + np.count_nonzero(real_atom_types[0] == i) |
| 79 | + for i in range(len(data["atom_names"])) |
| 80 | + ] |
| 81 | + # temp_formula = formula(data['atom_names'], temp_atom_numbs) |
| 82 | + temp_idx = np.arange(real_atom_types.shape[0])[ |
| 83 | + (real_atom_types == real_atom_types[0]).all(-1) |
| 84 | + ] |
| 85 | + rest_idx = np.arange(real_atom_types.shape[0])[ |
| 86 | + (real_atom_types != real_atom_types[0]).any(-1) |
| 87 | + ] |
| 88 | + temp_data = data.copy() |
| 89 | + temp_data["atom_numbs"] = temp_atom_numbs |
| 90 | + temp_data["atom_types"] = real_atom_types[0] |
| 91 | + real_atom_types = real_atom_types[rest_idx] |
| 92 | + temp_data["cells"] = cells[temp_idx] |
| 93 | + cells = cells[rest_idx] |
| 94 | + temp_data["coords"] = coords[temp_idx] |
| 95 | + coords = coords[rest_idx] |
| 96 | + if labels: |
| 97 | + if eners is not None and eners.size > 0: |
| 98 | + temp_data["energies"] = eners[temp_idx] |
| 99 | + eners = eners[rest_idx] |
| 100 | + if forces is not None and forces.size > 0: |
| 101 | + temp_data["forces"] = forces[temp_idx] |
| 102 | + forces = forces[rest_idx] |
| 103 | + if virs is not None and virs.size > 0: |
| 104 | + temp_data["virials"] = virs[temp_idx] |
| 105 | + virs = virs[rest_idx] |
| 106 | + data_list.append(temp_data) |
| 107 | + return data_list |
| 108 | + |
| 109 | + |
| 110 | +def dump(folder, data, comp_prec=np.float32, remove_sets=True): |
| 111 | + os.makedirs(folder, exist_ok=True) |
| 112 | + sets = sorted(glob.glob(os.path.join(folder, "set.*"))) |
| 113 | + if len(sets) > 0: |
| 114 | + if remove_sets: |
| 115 | + for ii in sets: |
| 116 | + shutil.rmtree(ii) |
| 117 | + else: |
| 118 | + raise RuntimeError( |
| 119 | + "found " |
| 120 | + + str(sets) |
| 121 | + + " in " |
| 122 | + + folder |
| 123 | + + "not a clean deepmd raw dir. please firstly clean set.* then try compress" |
| 124 | + ) |
| 125 | + # if not converted to mixed |
| 126 | + if "real_atom_types" not in data: |
| 127 | + from dpdata import LabeledSystem, System |
| 128 | + |
| 129 | + if "energies" in data: |
| 130 | + temp_sys = LabeledSystem(data=data) |
| 131 | + else: |
| 132 | + temp_sys = System(data=data) |
| 133 | + temp_sys.convert_to_mixed_type() |
| 134 | + # dump raw |
| 135 | + np.savetxt(os.path.join(folder, "type.raw"), data["atom_types"], fmt="%d") |
| 136 | + np.savetxt(os.path.join(folder, "type_map.raw"), data["real_atom_names"], fmt="%s") |
| 137 | + # BondOrder System |
| 138 | + if "bonds" in data: |
| 139 | + np.savetxt( |
| 140 | + os.path.join(folder, "bonds.raw"), |
| 141 | + data["bonds"], |
| 142 | + header="begin_atom, end_atom, bond_order", |
| 143 | + ) |
| 144 | + if "formal_charges" in data: |
| 145 | + np.savetxt(os.path.join(folder, "formal_charges.raw"), data["formal_charges"]) |
| 146 | + # reshape frame properties and convert prec |
| 147 | + nframes = data["cells"].shape[0] |
| 148 | + cells = np.reshape(data["cells"], [nframes, 9]).astype(comp_prec) |
| 149 | + coords = np.reshape(data["coords"], [nframes, -1]).astype(comp_prec) |
| 150 | + eners = None |
| 151 | + forces = None |
| 152 | + virials = None |
| 153 | + real_atom_types = None |
| 154 | + if "energies" in data: |
| 155 | + eners = np.reshape(data["energies"], [nframes]).astype(comp_prec) |
| 156 | + if "forces" in data: |
| 157 | + forces = np.reshape(data["forces"], [nframes, -1]).astype(comp_prec) |
| 158 | + if "virials" in data: |
| 159 | + virials = np.reshape(data["virials"], [nframes, 9]).astype(comp_prec) |
| 160 | + if "atom_pref" in data: |
| 161 | + atom_pref = np.reshape(data["atom_pref"], [nframes, -1]).astype(comp_prec) |
| 162 | + if "real_atom_types" in data: |
| 163 | + real_atom_types = np.reshape(data["real_atom_types"], [nframes, -1]).astype( |
| 164 | + np.int64 |
| 165 | + ) |
| 166 | + # dump frame properties: cell, coord, energy, force and virial |
| 167 | + set_folder = os.path.join(folder, "set.%03d" % 0) |
| 168 | + os.makedirs(set_folder) |
| 169 | + np.save(os.path.join(set_folder, "box"), cells) |
| 170 | + np.save(os.path.join(set_folder, "coord"), coords) |
| 171 | + if eners is not None: |
| 172 | + np.save(os.path.join(set_folder, "energy"), eners) |
| 173 | + if forces is not None: |
| 174 | + np.save(os.path.join(set_folder, "force"), forces) |
| 175 | + if virials is not None: |
| 176 | + np.save(os.path.join(set_folder, "virial"), virials) |
| 177 | + if real_atom_types is not None: |
| 178 | + np.save(os.path.join(set_folder, "real_atom_types"), real_atom_types) |
| 179 | + if "atom_pref" in data: |
| 180 | + np.save(os.path.join(set_folder, "atom_pref"), atom_pref) |
| 181 | + try: |
| 182 | + os.remove(os.path.join(folder, "nopbc")) |
| 183 | + except OSError: |
| 184 | + pass |
| 185 | + if data.get("nopbc", False): |
| 186 | + with open(os.path.join(folder, "nopbc"), "w") as fw_nopbc: |
| 187 | + pass |
| 188 | + |
| 189 | + |
| 190 | +def mix_system(*system, type_map, split_num=200, **kwargs): |
| 191 | + """Mix the systems into mixed_type ones |
| 192 | +
|
| 193 | + Parameters |
| 194 | + ---------- |
| 195 | + *system : System |
| 196 | + The systems to mix |
| 197 | + type_map : list of str |
| 198 | + Maps atom type to name |
| 199 | + split_num : int |
| 200 | + Number of frames in each system |
| 201 | +
|
| 202 | + Returns |
| 203 | + ------- |
| 204 | + mixed_systems: dict |
| 205 | + dict of mixed system with key '{atom_numbs}/sys.xxx' |
| 206 | + """ |
| 207 | + mixed_systems = {} |
| 208 | + temp_systems = {} |
| 209 | + atom_numbs_sys_index = {} # index of sys |
| 210 | + atom_numbs_frame_index = {} # index of frames in cur sys |
| 211 | + for sys in system: |
| 212 | + tmp_sys = sys.copy() |
| 213 | + natom = tmp_sys.get_natoms() |
| 214 | + tmp_sys.convert_to_mixed_type(type_map=type_map) |
| 215 | + if str(natom) not in atom_numbs_sys_index: |
| 216 | + atom_numbs_sys_index[str(natom)] = 0 |
| 217 | + if str(natom) not in atom_numbs_frame_index: |
| 218 | + atom_numbs_frame_index[str(natom)] = 0 |
| 219 | + atom_numbs_frame_index[str(natom)] += tmp_sys.get_nframes() |
| 220 | + if str(natom) not in temp_systems or not temp_systems[str(natom)]: |
| 221 | + temp_systems[str(natom)] = tmp_sys |
| 222 | + else: |
| 223 | + temp_systems[str(natom)].append(tmp_sys) |
| 224 | + if atom_numbs_frame_index[str(natom)] >= split_num: |
| 225 | + while True: |
| 226 | + sys_split, temp_systems[str(natom)], rest_num = split_system( |
| 227 | + temp_systems[str(natom)], split_num=split_num |
| 228 | + ) |
| 229 | + sys_name = ( |
| 230 | + f"{str(natom)}/sys." + "%.6d" % atom_numbs_sys_index[str(natom)] |
| 231 | + ) |
| 232 | + mixed_systems[sys_name] = sys_split |
| 233 | + atom_numbs_sys_index[str(natom)] += 1 |
| 234 | + if rest_num < split_num: |
| 235 | + atom_numbs_frame_index[str(natom)] = rest_num |
| 236 | + break |
| 237 | + for natom in temp_systems: |
| 238 | + if atom_numbs_frame_index[natom] > 0: |
| 239 | + sys_name = f"{natom}/sys." + "%.6d" % atom_numbs_sys_index[natom] |
| 240 | + mixed_systems[sys_name] = temp_systems[natom] |
| 241 | + return mixed_systems |
| 242 | + |
| 243 | + |
| 244 | +def split_system(sys, split_num=100): |
| 245 | + rest = sys.get_nframes() - split_num |
| 246 | + if rest <= 0: |
| 247 | + return sys, None, 0 |
| 248 | + else: |
| 249 | + split_sys = sys.sub_system(range(split_num)) |
| 250 | + rest_sys = sys.sub_system(range(split_num, sys.get_nframes())) |
| 251 | + return split_sys, rest_sys, rest |
0 commit comments