Skip to content

Commit 1d66a6f

Browse files
ben rhodesben rhodes
authored and committed
Finish todos and update readme
1 parent 1176796 commit 1d66a6f

File tree

7 files changed

+57
-25
lines changed

7 files changed

+57
-25
lines changed

README.md

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,17 @@ from orb_models.forcefield import atomic_system, pretrained
6666
from orb_models.forcefield.base import batch_graphs
6767

6868
device = "cpu" # or device="cuda"
69-
orbff = pretrained.orb_v2(device=device)
69+
orbff, system_config = pretrained.orb_v3_conservative_inf_omat(
70+
device=device,
71+
precision="float32-high", # or "float32-highest" / "float64"
72+
)
7073
atoms = bulk('Cu', 'fcc', a=3.58, cubic=True)
71-
graph = atomic_system.ase_atoms_to_atom_graphs(atoms, device=device)
74+
graph = atomic_system.ase_atoms_to_atom_graphs(atoms, system_config, device=device)
7275

7376
# Optionally, batch graphs for faster inference
7477
# graph = batch_graphs([graph, graph, ...])
7578

76-
result = orbff.predict(graph)
79+
result = orbff.predict(graph, split=False)
7780

7881
# Convert to ASE atoms (unbatches the results and transfers to cpu if necessary)
7982
atoms = atomic_system.atom_graphs_to_ase_atoms(
@@ -94,8 +97,12 @@ from orb_models.forcefield import pretrained
9497
from orb_models.forcefield.calculator import ORBCalculator
9598

9699
device="cpu" # or device="cuda"
97-
orbff = pretrained.orb_v2(device=device) # or choose another model using ORB_PRETRAINED_MODELS[model_name]()
98-
calc = ORBCalculator(orbff, device=device)
100+
# or choose another model using ORB_PRETRAINED_MODELS[model_name]()
101+
orbff, system_config = pretrained.orb_v3_conservative_inf_omat(
102+
device=device,
103+
precision="float32-high", # or "float32-highest" / "float64"
104+
)
105+
calc = ORBCalculator(orbff, system_config, device=device)
99106
atoms = bulk('Cu', 'fcc', a=3.58, cubic=True)
100107

101108
atoms.calc = calc
@@ -111,7 +118,7 @@ from ase.optimize import BFGS
111118
atoms.rattle(0.5)
112119
print("Rattled Energy:", atoms.get_potential_energy())
113120

114-
calc = ORBCalculator(orbff, device="cpu") # or device="cuda"
121+
calc = ORBCalculator(orbff, system_config, device="cpu") # or device="cuda"
115122
dyn = BFGS(atoms)
116123
dyn.run(fmax=0.01)
117124
print("Optimized Energy:", atoms.get_potential_energy())
@@ -120,24 +127,43 @@ print("Optimized Energy:", atoms.get_potential_energy())
120127
Or you can use it to run MD simulations. The script, an example input xyz file and a Colab notebook demonstration are available in the [examples directory.](./examples) This should work with any input, simply modify the input_file and cell_size parameters. We recommend using constant volume simulations.
121128

122129

130+
### Floating Point Precision
131+
132+
As shown in usage snippets above, we support 3 floating point precision types: `"float32-high"`, `"float32-highest"` and `"float64"`.
133+
134+
The default value of `"float32-high"` is recommended for maximal acceleration when using A100 / H100 Nvidia GPUs. However, we have observed some performance loss for high-precision calculations involving second and third order properties of the PES. In these cases, we recommend `"float32-highest"`.
135+
136+
In stark contrast to other universal forcefields, we have not found any benefit to using `"float64"`.
137+
123138
### Finetuning
124139
You can finetune the model using your custom dataset.
125140
The dataset should be an [ASE sqlite database](https://wiki.fysik.dtu.dk/ase/ase/db/db.html#module-ase.db.core).
126141
```python
127-
python finetune.py --dataset=<dataset_name> --data_path=<your_data_path>
142+
python finetune.py --dataset=<dataset_name> --data_path=<your_data_path> --base_model=<base_model>
128143
```
129-
After the model is finetuned, checkpoints will, by default, be saved to the ckpts folder in the directory you ran the finetuning script from.
144+
Where base_model is one of:
145+
- "orb_v3_conservative_inf_omat"
146+
- "orb_v3_conservative_20_omat"
147+
- "orb_v3_direct_inf_omat"
148+
- "orb_v3_direct_20_omat"
149+
- "orb_v2"
150+
151+
After the model is finetuned, checkpoints will, by default, be saved to the ckpts folder in the directory you ran the finetuning script from.
130152

131153
You can use the new model and load the checkpoint by:
132154
```python
133155
from orb_models.forcefield import pretrained
134156

135-
model = pretrained.orb_v2(weights_path=<path_to_ckpt>)
157+
model, system_config = getattr(pretrained, <base_model>)(
158+
weights_path=<path_to_ckpt>,
159+
device="cpu", # or device="cuda"
160+
precision="float32-high", # or precision="float32-highest"
161+
)
136162
```
137163

138164
> **Caveats**
139165
>
140-
> Our finetuning script is designed for simplicity and advanced users may wish to develop it further. Please be aware that:
166+
> Our finetuning script is designed for simplicity. We strongly advise users to customise it further for their use-case to get the best performance. Please be aware that:
141167
> - The script assumes that your ASE database rows contain **energy, forces, and stress** data. To train on molecular data without stress, you will need to edit the code.
142168
> - **Early stopping** is not implemented. However, you can use the command line argument `save_every_x_epochs` (default is 5), so "retrospective" early stopping can be applied by selecting a suitable checkpoint.
143169
> - The **learning rate schedule is hardcoded** to be `torch.optim.lr_scheduler.OneCycleLR` with `pct_start=0.05`. The `max_lr`/`min_lr` will be 10x greater/smaller than the `lr` specified via the command line. To get the best performance, you may wish to try other schedulers.
@@ -147,7 +173,6 @@ model = pretrained.orb_v2(weights_path=<path_to_ckpt>)
147173

148174

149175

150-
151176
### Citing
152177

153178
A preprint describing the model in more detail can be found here: https://arxiv.org/abs/2410.22570

examples/NaClWaterMD.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def run_md_simulation(
5252
atoms.set_pbc([True] * 3)
5353

5454
# Set the calculator
55-
atoms.calc = ORBCalculator(model=pretrained.orb_d3_v2(), device=device)
55+
atoms.calc = ORBCalculator(*pretrained.orb_d3_v2(), device=device)
5656

5757
# Set the initial velocities
5858
MaxwellBoltzmannDistribution(atoms, temperature_K=temperature_K)

finetune.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,13 +231,13 @@ def run(args):
231231
device = utils.init_device(device_id=args.device_id)
232232
utils.seed_everything(args.random_seed)
233233

234-
# Make sure to use this flag for matmuls on A100 and H100 GPUs.
234+
# Setting this is 2x faster on A100 and H100
235+
# GPUs and does not appear to hurt training
235236
precision = "float32-high"
236237

237238
# Instantiate model
238-
239-
# TODO (BEN): make base model configurable!
240-
model, system_config = pretrained.orb_v2(device=device, precision=precision)
239+
base_model = args.base_model
240+
model, system_config = getattr(pretrained, base_model)(device=device, precision=precision)
241241

242242
for param in model.parameters():
243243
param.requires_grad = True
@@ -385,6 +385,19 @@ def main():
385385
type=float,
386386
help="Learning rate. 3e-4 is purely a sensible default; you may want to tune this for your problem.",
387387
)
388+
parser.add_argument(
389+
"--base_model",
390+
default="orb_v3_conservative_inf_omat",
391+
type=str,
392+
help="Base model to finetune.",
393+
choices=[
394+
"orb_v3_conservative_inf_omat",
395+
"orb_v3_conservative_20_omat",
396+
"orb_v3_direct_inf_omat",
397+
"orb_v3_direct_20_omat",
398+
"orb_v2",
399+
],
400+
)
388401
args = parser.parse_args()
389402
run(args)
390403

internal/check.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def main(model: str, core_model: str):
3131
)
3232

3333
graph_orig = core_atomic_system.ase_atoms_to_atom_graphs(atoms, sys_config)
34-
graph = atomic_system.ase_atoms_to_atom_graphs(atoms)
34+
graph = atomic_system.ase_atoms_to_atom_graphs(atoms, sys_config)
3535

3636
pred_orig = original_orbff.predict(graph_orig)
3737

orb_models/forcefield/atomic_system.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -152,13 +152,12 @@ def atom_graphs_to_ase_atoms(
152152

153153
return atoms_list
154154

155-
156155
def ase_atoms_to_atom_graphs(
157156
atoms: ase.Atoms,
157+
system_config: SystemConfig,
158158
*,
159159
wrap: bool = True,
160160
edge_method: Optional[EdgeCreationMethod] = None,
161-
system_config: Optional[SystemConfig] = None,
162161
max_num_neighbors: Optional[int] = None,
163162
system_id: Optional[int] = None,
164163
half_supercell: bool = False,
@@ -195,8 +194,6 @@ def ase_atoms_to_atom_graphs(
195194
Returns:
196195
AtomGraphs object
197196
"""
198-
if system_config is None:
199-
system_config = SystemConfig(radius=6.0, max_num_neighbors=20)
200197
if isinstance(atoms.pbc, Iterable) and any(atoms.pbc) and not all(atoms.pbc):
201198
raise NotImplementedError(
202199
"We do not support periodicity along a subset of axes. Please ensure atoms.pbc is "

orb_models/forcefield/pretrained.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,6 @@ def orb_v3_direct_inf_mpa(
398398
return model, SystemConfig(radius=6.0, max_num_neighbors=120)
399399

400400

401-
402401
def orb_v2(
403402
weights_path: str = "https://orbitalmaterials-public-models.s3.us-west-1.amazonaws.com/forcefields/orb-v2-20241011.ckpt", # noqa: E501
404403
device: Union[torch.device, str, None] = None,
@@ -410,8 +409,6 @@ def orb_v2(
410409
model = load_model_for_inference(
411410
model, weights_path, device, precision=precision, compile=compile
412411
)
413-
# TODO (BEN): update all functions to return SystemConfig
414-
# TODO (BEN): search repo for max_num_neighbors and avoid any hardcoding
415412

416413
return model, SystemConfig(radius=6.0, max_num_neighbors=20)
417414

orb_models/forcefield/segment_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def scatter_sum(
159159
Returns:
160160
torch.Tensor: The output tensor with values scattered and summed.
161161
"""
162-
assert reduce == "sum" # for now, TODO
162+
assert reduce == "sum"
163163
index = _broadcast(index, src, dim)
164164
if out is None:
165165
size = list(src.size())

0 commit comments

Comments
 (0)