
Commit b9023c7

Checked in five ET models
1 parent 3529a0b commit b9023c7

File tree

9 files changed (+169, -2 lines)


README.md

Lines changed: 2 additions & 2 deletions

```diff
@@ -1,2 +1,2 @@
-# spice-models
-Models trained on the SPICE dataset
+# SPICE-Models
+Models trained on the SPICE dataset.
```

five-et/README.md

Lines changed: 45 additions & 0 deletions
This directory contains the five equivariant transformer models described in (insert reference when available).

The script `createSpiceDataset.py` converts the dataset file SPICE.hdf5, downloaded from https://github.com/openmm/spice-dataset/releases, to the format used by [TorchMD-Net](https://github.com/torchmd/torchmd-net). It generates a new file, SPICE-processed.hdf5, which was used for training.

The file `hparams.yaml` contains the configuration used for training the models. All models used identical settings except that `seed` was set to a different value for each one (the numbers 1 through 5). Note that although the file specifies `num_epochs: 1000`, training was halted after 24 hours, when the training job reached the end of its allocated time. This corresponded to 118 epochs.

The files ending in `.ckpt` are checkpoint files for TorchMD-Net 0.2.4 containing the trained models. They should work with later versions as well, though that is not guaranteed. They can be loaded like this:

```python
from torchmdnet.models.model import load_model

model = load_model('model1.ckpt')
```

The `device` argument to `load_model()` can be used to specify the device to load the model on. For example,

```python
import torch

model = load_model('model1.ckpt', device=torch.device('cuda:0'))
```

To compute energy and forces for a molecular conformation, invoke the model's `forward()` method. It takes two arguments: a tensor of length `n_atoms` and dtype `long` containing the atom types, and a tensor of shape `(n_atoms, 3)` and dtype `float32` containing the atom positions in angstroms. It returns two values: the potential energy in kJ/mol, and the force on each atom in kJ/mol/angstrom. Atom types are defined by the element and formal charge of each atom. The mapping is defined in `createSpiceDataset.py` with this dictionary:

```python
typeDict = {('Br', -1): 0, ('Br', 0): 1, ('C', -1): 2, ('C', 0): 3, ('C', 1): 4, ('Ca', 2): 5, ('Cl', -1): 6,
            ('Cl', 0): 7, ('F', -1): 8, ('F', 0): 9, ('H', 0): 10, ('I', -1): 11, ('I', 0): 12, ('K', 1): 13,
            ('Li', 1): 14, ('Mg', 2): 15, ('N', -1): 16, ('N', 0): 17, ('N', 1): 18, ('Na', 1): 19, ('O', -1): 20,
            ('O', 0): 21, ('O', 1): 22, ('P', 0): 23, ('P', 1): 24, ('S', -1): 25, ('S', 0): 26, ('S', 1): 27}
```

For example, the following computes the energy and forces for a pair of ions (Cl- and Na+) positioned 3 angstroms apart.

```python
types = torch.tensor([6, 19], dtype=torch.long)
pos = torch.tensor([[0, 0, 0], [0, 3, 0]], dtype=torch.float32)
energy, forces = model.forward(types, pos)
```
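The type indices `[6, 19]` in the example above come directly from `typeDict`. As a minimal sketch (pure Python, no TorchMD-Net required; the `atom_type()` helper is not part of the repository, just an illustration), the lookup works like this:

```python
# Mapping from (element symbol, formal charge) to atom type index,
# copied verbatim from createSpiceDataset.py.
typeDict = {('Br', -1): 0, ('Br', 0): 1, ('C', -1): 2, ('C', 0): 3, ('C', 1): 4, ('Ca', 2): 5, ('Cl', -1): 6,
            ('Cl', 0): 7, ('F', -1): 8, ('F', 0): 9, ('H', 0): 10, ('I', -1): 11, ('I', 0): 12, ('K', 1): 13,
            ('Li', 1): 14, ('Mg', 2): 15, ('N', -1): 16, ('N', 0): 17, ('N', 1): 18, ('Na', 1): 19, ('O', -1): 20,
            ('O', 0): 21, ('O', 1): 22, ('P', 0): 23, ('P', 1): 24, ('S', -1): 25, ('S', 0): 26, ('S', 1): 27}

def atom_type(symbol, formal_charge):
    """Return the model's atom type index for an element symbol and formal charge."""
    return typeDict[(symbol, formal_charge)]

# The Cl-/Na+ pair from the example above:
ion_types = [atom_type('Cl', -1), atom_type('Na', 1)]
print(ion_types)  # [6, 19], matching the `types` tensor in the example
```

Note that the 28 entries of `typeDict` also explain the `max_z: 28` setting in `hparams.yaml`: the embedding table needs one row per atom type.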

five-et/createSpiceDataset.py

Lines changed: 59 additions & 0 deletions

```python
import numpy as np
from openff.toolkit.topology import Molecule
from openmm.unit import *
from collections import defaultdict
import h5py

typeDict = {('Br', -1): 0, ('Br', 0): 1, ('C', -1): 2, ('C', 0): 3, ('C', 1): 4, ('Ca', 2): 5, ('Cl', -1): 6,
            ('Cl', 0): 7, ('F', -1): 8, ('F', 0): 9, ('H', 0): 10, ('I', -1): 11, ('I', 0): 12, ('K', 1): 13,
            ('Li', 1): 14, ('Mg', 2): 15, ('N', -1): 16, ('N', 0): 17, ('N', 1): 18, ('Na', 1): 19, ('O', -1): 20,
            ('O', 0): 21, ('O', 1): 22, ('P', 0): 23, ('P', 1): 24, ('S', -1): 25, ('S', 0): 26, ('S', 1): 27}

infile = h5py.File('SPICE.hdf5')

# First pass: group the samples by total number of atoms.

groupsByAtomCount = defaultdict(list)
for name in infile:
    group = infile[name]
    count = len(group['atomic_numbers'])
    groupsByAtomCount[count].append(group)

# Create the output file.

filename = 'SPICE-processed.hdf5'
outfile = h5py.File(filename, 'w')

# One pass for each number of atoms, creating a group for it.

print(sorted(list(groupsByAtomCount.keys())))
posScale = 1*bohr/angstrom
energyScale = 1*hartree/item/(kilojoules_per_mole)
forceScale = energyScale/posScale
for count in sorted(groupsByAtomCount.keys()):
    print(count)
    smiles = []
    pos = []
    types = []
    energy = []
    forces = []
    for g in groupsByAtomCount[count]:
        molSmiles = g['smiles'][0]
        mol = Molecule.from_mapped_smiles(molSmiles, allow_undefined_stereo=True)
        molTypes = [typeDict[(atom.element.symbol, atom.formal_charge/elementary_charge)] for atom in mol.atoms]
        assert len(molTypes) == count
        for i, atom in enumerate(mol.atoms):
            assert atom.atomic_number == g['atomic_numbers'][i]
        numConfs = g['conformations'].shape[0]
        for i in range(numConfs):
            smiles.append(molSmiles)
            pos.append(g['conformations'][i])
            types.append(molTypes)
            energy.append(g['formation_energy'][i])
            forces.append(g['dft_total_gradient'][i])
    group = outfile.create_group(f'samples{count}')
    group.create_dataset('smiles', data=smiles, dtype=h5py.string_dtype())
    group.create_dataset('types', data=np.array(types), dtype=np.int8)
    group.create_dataset('pos', data=np.array(pos)*posScale, dtype=np.float32)
    group.create_dataset('energy', data=np.array(energy)*energyScale, dtype=np.float32)
    group.create_dataset('forces', data=-np.array(forces)*forceScale, dtype=np.float32)
```
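The script stores positions, energies, and forces in angstroms and kJ/mol, while SPICE.hdf5 uses atomic units (bohr, hartree), so the `posScale`, `energyScale`, and `forceScale` factors are unit conversions. A minimal sketch of the same conversions without `openmm.unit`, using hardcoded CODATA constants (the numeric values here are my own, not from the repository):

```python
# CODATA 2018 values, hardcoded for illustration.
BOHR_IN_ANGSTROM = 0.529177210903        # 1 bohr in angstroms
HARTREE_IN_KJ_PER_MOL = 2625.4996394799  # 1 hartree in kJ/mol

posScale = BOHR_IN_ANGSTROM           # positions: bohr -> angstrom
energyScale = HARTREE_IN_KJ_PER_MOL   # energies: hartree -> kJ/mol
forceScale = energyScale / posScale   # gradients: hartree/bohr -> kJ/mol/angstrom

# Note the script negates the DFT gradient when writing the 'forces'
# dataset, since force is the negative energy gradient: F = -dE/dx.
print(posScale, energyScale, forceScale)
```

One `forceScale` unit of hartree/bohr is roughly 4961 kJ/mol/angstrom, which is why the stored float32 forces have much larger magnitudes than the raw gradients.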

five-et/hparams.yaml

Lines changed: 63 additions & 0 deletions

```yaml
activation: silu
aggr: add
atom_filter: -1
attn_activation: silu
batch_size: 128
charge: false
conf: null
coord_files: null
cutoff_lower: 0.0
cutoff_upper: 10.0
dataset: HDF5
dataset_arg: null
dataset_root: SPICE-processed.hdf5
derivative: true
distance_influence: both
distributed_backend: ddp
early_stopping_patience: 20
ema_alpha_dy: 1.0
ema_alpha_y: 1.0
embed_files: null
embedding_dimension: 128
energy_files: null
energy_weight: 1.0
force_files: null
force_weight: 1.0
inference_batch_size: 128
load_model: null
log_dir: model1b
lr: 0.0005
lr_factor: 0.5
lr_metric: train_loss
lr_min: 1.0e-07
lr_patience: 0
lr_warmup_steps: 0
max_num_neighbors: 100
max_z: 28
model: equivariant-transformer
neighbor_embedding: true
ngpus: -1
num_epochs: 1000
num_heads: 8
num_layers: 6
num_nodes: 1
num_rbf: 64
num_workers: 16
output_model: Scalar
precision: 32
prior_model: null
rbf_type: expnorm
redirect: true
reduce_op: add
reset_trainer: false
save_interval: 1
seed: 1
spin: false
splits: null
standardize: false
test_interval: 10
test_size: 0.0
train_size: null
trainable_rbf: true
val_size: 0.05
weight_decay: 0.0
```

five-et/model1.ckpt

16.4 MB
Binary file not shown.

five-et/model2.ckpt

16.4 MB
Binary file not shown.

five-et/model3.ckpt

16.4 MB
Binary file not shown.

five-et/model4.ckpt

16.4 MB
Binary file not shown.

five-et/model5.ckpt

16.4 MB
Binary file not shown.
