-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathcompile_dft_dataset.py
More file actions
90 lines (74 loc) · 2.12 KB
/
compile_dft_dataset.py
File metadata and controls
90 lines (74 loc) · 2.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""Compile dataset for UQ test.
This code is modified from the Open Catalyst Project scripts.
"""
import pickle
from pathlib import Path
from tqdm import tqdm
from ase.io import read
import lmdb
import torch
from ocpmodels.datasets.lmdb_dataset import LmdbDataset
from ocpmodels.preprocessing import AtomsToGraphs
a2g = AtomsToGraphs(
max_neigh=50,
radius=6,
r_energy=True,
r_forces=True,
r_fixed=True,
r_distances=False,
r_pbc=True,
)
dft_adsorption_energies = {
"CuZn_CO2": -0.066,
"CuZn_CHOH": 5.951,
"CuZn_OCHO": 5.836,
"CuZn_OHCH3": 2.699,
"CuAlZn_CO2": 6.816,
"CuAlZn_CHOH": -1.824,
"CuAlZn_OCHO": 2.820,
"CuAlZn_OHCH3": -5.615,
}
lmdb_dir = Path("test_lmdb")
i = 0
with lmdb.open(
str(Path("test_lmdb", "00.lmdb")),
map_size=1099511627776 * 2,
subdir=False,
meminit=False,
map_async=True,
) as db:
for _, p in tqdm(
enumerate(
Path("/Users/spru445/Desktop/methanol_chemreasoner_results").rglob("*.xyz")
)
):
print(p)
fname = p.stem
ats = read(str(p))
data_object = a2g.convert(ats)
# add atom tags
data_object.tags = torch.LongTensor(ats.get_tags())
data_object.sid = str(p)
data_object.descriptor = p.stem
if p.stem in dft_adsorption_energies.keys():
data_object.y = dft_adsorption_energies[p.stem]
# if p.stem in dft_adsorption_energies.keys():
# ats.info.update({"dft_energy": dft_adsorption_energies[p.stem]})
# db.write(
# ats,
# data={"info": ats.info},
# y=dft_adsorption_energies[p.stem],
# name=p.stem,
# )
txn = db.begin(write=True)
txn.put(
f"{i}".encode("ascii"),
pickle.dumps(data_object, protocol=-1),
)
txn.commit()
i += 1
# else:
# conn.write(ats, data={"info": ats.info}, name=p.stem)
dataset = LmdbDataset({"src": str(lmdb_dir)})
for i in range(len(dataset)):
print(dataset[i])