Skip to content

Commit ea303c4

Browse files
authored
Fix calculator naming for conservative models with no direct heads (#76)
* fix calculator error * make sure all readme examples work * add note on compilation * fix loading for direct models
1 parent 3753705 commit ea303c4

File tree

12 files changed

+59
-50
lines changed

12 files changed

+59
-50
lines changed

README.md

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,12 @@ For example, `orb-v3-conservative-inf-omat` is a model that:
7070
- Computes forces/stress as gradients of energy
7171
- Has effectively infinite neighbors (120 in practice)
7272
- Was trained on the OMat24 dataset
73-
```
7473

75-
*We suggest using models trained on OMAT24, as these models are more performant and the data they are trained on uses newer pseudopotentials in VASP (PBE54 vs PBE52)*. `-mpa` models should be used if compatibility with benchmarks (for example, Matbench Discovery) is required.
74+
75+
Orb-v3 models are **compiled** by default and use Pytorch's dynamic batching, which means that they do not need to recompile as graph sizes change. However, the first call to the model will be slower, as the graph is compiled by torch.
76+
77+
78+
**We suggest using models trained on OMAT24**, as these models are more performant and the data they are trained on uses newer pseudopotentials in VASP (PBE54 vs PBE52). `-mpa` models should be used if compatibility with benchmarks (for example, Matbench Discovery) is required.
7679

7780
#### V2 Models
7881

@@ -99,11 +102,13 @@ from orb_models.forcefield.base import batch_graphs
99102

100103
device = "cpu" # or device="cuda"
101104
orbff = pretrained.orb_v3_conservative_inf_omat(
102-
device=device
105+
device=device,
103106
precision="float32-high", # or "float32-highest" / "float64"
104107
)
105108
atoms = bulk('Cu', 'fcc', a=3.58, cubic=True)
106109
graph = atomic_system.ase_atoms_to_atom_graphs(atoms, orbff.system_config, device=device)
110+
atoms = bulk('Cu', 'fcc', a=3.58, cubic=True)
111+
graph = atomic_system.ase_atoms_to_atom_graphs(atoms, orbff.system_config, device=device)
107112

108113
# Optionally, batch graphs for faster inference
109114
# graph = batch_graphs([graph, graph, ...])
@@ -131,7 +136,7 @@ from orb_models.forcefield.calculator import ORBCalculator
131136
device="cpu" # or device="cuda"
132137
# or choose another model using ORB_PRETRAINED_MODELS[model_name]()
133138
orbff = pretrained.orb_v3_conservative_inf_omat(
134-
device=device
139+
device=device,
135140
precision="float32-high", # or "float32-highest" / "float64"
136141
)
137142
calc = ORBCalculator(orbff, device=device)

orb_models/dataset/base_datasets.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from orb_models.forcefield.base import AtomGraphs
1111

1212

13-
1413
class AtomsDataset(ABC, Dataset):
1514
"""AtomsDataset.
1615

orb_models/forcefield/atomic_system.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class SystemConfig:
2525
radius: float
2626
max_num_neighbors: int
2727

28+
2829
def atom_graphs_to_ase_atoms(
2930
graphs: AtomGraphs,
3031
energy: Optional[torch.Tensor] = None,
@@ -83,6 +84,7 @@ def atom_graphs_to_ase_atoms(
8384

8485
return atoms_list
8586

87+
8688
def ase_atoms_to_atom_graphs(
8789
atoms: ase.Atoms,
8890
system_config: SystemConfig,

orb_models/forcefield/calculator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,9 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes):
118118
self.results[property] = to_numpy(out[_property].squeeze())
119119

120120
if self.conservative:
121-
self.results["direct_forces"] = self.results["forces"]
122-
self.results["direct_stress"] = self.results["stress"]
121+
if self.model.forces_name in self.results:
122+
self.results["direct_forces"] = self.results[self.model.forces_name]
123+
if self.model.stress_name in self.results:
124+
self.results["direct_stress"] = self.results[self.model.stress_name]
123125
self.results["forces"] = self.results[self.model.grad_forces_name]
124126
self.results["stress"] = self.results[self.model.grad_stress_name]

orb_models/forcefield/direct_regressor.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,15 @@
33

44
from orb_models.forcefield.pair_repulsion import ZBLBasis
55
from orb_models.forcefield import base
6-
from orb_models.forcefield.forcefield_utils import split_prediction, validate_regressor_inputs
6+
from orb_models.forcefield.forcefield_utils import (
7+
split_prediction,
8+
validate_regressor_inputs,
9+
)
710
from orb_models.forcefield.gns import MoleculeGNS
811
from orb_models.forcefield.load import load_forcefield_state_dict
912
from orb_models.forcefield.atomic_system import SystemConfig
1013

14+
1115
class DirectForcefieldRegressor(torch.nn.Module):
1216
"""Direct Forcefield regressor."""
1317

@@ -70,19 +74,17 @@ def __init__(
7074
param.requires_grad = False
7175

7276
if heads_require_grad is not None:
73-
for head_name, requires_grad in heads_require_grad.items():
74-
assert head_name in self.heads
75-
for param in self.heads[head_name].parameters():
76-
param.requires_grad = requires_grad
77-
77+
for head_name, requires_grad in heads_require_grad.items():
78+
assert head_name in self.heads
79+
for param in self.heads[head_name].parameters():
80+
param.requires_grad = requires_grad
7881

7982
self._system_config = system_config
8083

8184
@property
8285
def system_config(self) -> SystemConfig:
8386
return self._system_config
8487

85-
8688
def forward(
8789
self, batch: base.AtomGraphs
8890
) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
@@ -206,4 +208,4 @@ def _get_raw_repulsion(
206208
for prop_type in property_types:
207209
if prop_type in name and "d3" not in name and "d4" not in name:
208210
return out_pair_repulsion[prop_type]
209-
return None
211+
return None

orb_models/forcefield/forcefield_heads.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -706,4 +706,4 @@ def loss(
706706
f"{name}_mse_raw": ((raw_pred - target) ** 2).mean(),
707707
}
708708

709-
return base.ModelOutput(loss=loss, log=metrics)
709+
return base.ModelOutput(loss=loss, log=metrics)

orb_models/forcefield/forcefield_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from orb_models.forcefield.gns import MoleculeGNS
77

88

9-
109
def validate_regressor_inputs(
1110
heads: Union[Sequence[torch.nn.Module], Mapping[str, torch.nn.Module]],
1211
loss_weights: Dict[str, float],

orb_models/forcefield/load.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def load_forcefield_state_dict(
2525
if skip_artifact_reference_energy is True.
2626
2727
NOTE: We assume that the presence of the prefix "heads." in any key of the
28-
state_dict implies that the state_dict comes from a DirectForcefieldRegressor
28+
state_dict implies that the state_dict comes from a DirectForcefieldRegressor
2929
or ConservativeForcefieldRegressor.
3030
"""
3131
state_dict = dict(state_dict) # Shallow copy

orb_models/forcefield/pretrained.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,13 @@
1414
gaussian_basis_function,
1515
)
1616
from orb_models.forcefield.gns import MoleculeGNS
17-
from orb_models.forcefield.forcefield_heads import ConfidenceHead, EnergyHead, ForceHead, GraphHead, StressHead
17+
from orb_models.forcefield.forcefield_heads import (
18+
ConfidenceHead,
19+
EnergyHead,
20+
ForceHead,
21+
GraphHead,
22+
StressHead,
23+
)
1824
from orb_models.forcefield.rbf import BesselBasis
1925
from orb_models.utils import set_torch_precision
2026

@@ -88,7 +94,7 @@ def orb_v2_architecture(
8894
activation="ssp",
8995
),
9096
"forces": ForceHead(
91-
latent_dim=256,
97+
latent_dim=256,
9298
num_mlp_layers=1,
9399
mlp_hidden_dim=256,
94100
remove_mean=True,
@@ -140,7 +146,7 @@ def orb_v3_conservative_architecture(
140146
head_mlp_depth: int = 1,
141147
num_message_passing_steps: int = 5,
142148
activation: str = "silu",
143-
device: Optional[torch.device] = None,
149+
device: Optional[Union[torch.device, str]] = None,
144150
system_config: Optional[SystemConfig] = None,
145151
) -> ConservativeForcefieldRegressor:
146152
"""The orb-v3 conservative architecture."""
@@ -189,6 +195,7 @@ def orb_v3_conservative_architecture(
189195
pair_repulsion=True,
190196
system_config=system_config,
191197
)
198+
device = get_device(device)
192199
if device is not None and device != torch.device("cpu"):
193200
model.cuda(device)
194201
else:
@@ -268,13 +275,15 @@ def orb_v3_direct_architecture(
268275
pair_repulsion=True,
269276
system_config=system_config,
270277
)
278+
device = get_device(device)
271279
if device is not None and device != torch.device("cpu"):
272280
model.cuda(device)
273281
else:
274282
model = model.cpu()
275283

276284
return model
277285

286+
278287
def orb_v3_conservative_20_omat(
279288
weights_path: str = "https://orbitalmaterials-public-models.s3.us-west-1.amazonaws.com/forcefields/orb-v3/orb-v3-conservative-20-omat-20250404.ckpt", # noqa: E501
280289
device: Union[torch.device, str, None] = None,
@@ -291,6 +300,7 @@ def orb_v3_conservative_20_omat(
291300

292301
return model
293302

303+
294304
def orb_v3_conservative_inf_omat(
295305
weights_path: str = "https://orbitalmaterials-public-models.s3.us-west-1.amazonaws.com/forcefields/orb-v3/orb-v3-conservative-inf-omat-20250404.ckpt", # noqa: E501
296306
device: Union[torch.device, str, None] = None,
@@ -310,6 +320,7 @@ def orb_v3_conservative_inf_omat(
310320

311321
return model
312322

323+
313324
def orb_v3_direct_20_omat(
314325
weights_path: str = "https://orbitalmaterials-public-models.s3.us-west-1.amazonaws.com/forcefields/orb-v3/orb-v3-direct-20-omat-20250404.ckpt", # noqa: E501
315326
device: Union[torch.device, str, None] = None,
@@ -325,6 +336,7 @@ def orb_v3_direct_20_omat(
325336

326337
return model
327338

339+
328340
def orb_v3_direct_inf_omat(
329341
weights_path: str = "https://orbitalmaterials-public-models.s3.us-west-1.amazonaws.com/forcefields/orb-v3/orb-v3-direct-inf-omat-20250404.ckpt", # noqa: E501
330342
device: Union[torch.device, str, None] = None,
@@ -344,6 +356,7 @@ def orb_v3_direct_inf_omat(
344356

345357
return model
346358

359+
347360
def orb_v3_conservative_20_mpa(
348361
weights_path: str = "https://orbitalmaterials-public-models.s3.us-west-1.amazonaws.com/forcefields/orb-v3/orb-v3-conservative-20-mpa-20250404.ckpt", # noqa: E501
349362
device: Union[torch.device, str, None] = None,
@@ -359,6 +372,7 @@ def orb_v3_conservative_20_mpa(
359372

360373
return model
361374

375+
362376
def orb_v3_conservative_inf_mpa(
363377
weights_path: str = "https://orbitalmaterials-public-models.s3.us-west-1.amazonaws.com/forcefields/orb-v3/orb-v3-conservative-inf-mpa-20250404.ckpt", # noqa: E501
364378
device: Union[torch.device, str, None] = None,
@@ -378,6 +392,7 @@ def orb_v3_conservative_inf_mpa(
378392

379393
return model
380394

395+
381396
def orb_v3_direct_20_mpa(
382397
weights_path: str = "https://orbitalmaterials-public-models.s3.us-west-1.amazonaws.com/forcefields/orb-v3/orb-v3-direct-20-mpa-20250404.ckpt", # noqa: E501
383398
device: Union[torch.device, str, None] = None,
@@ -393,6 +408,7 @@ def orb_v3_direct_20_mpa(
393408

394409
return model
395410

411+
396412
def orb_v3_direct_inf_mpa(
397413
weights_path: str = "", # noqa: E501
398414
device: Union[torch.device, str, None] = None,
@@ -469,7 +485,9 @@ def orb_d3_sm_v2(
469485
) -> DirectForcefieldRegressor:
470486
"""Load ORB D3 small v2 with 20 max neighbors, trained on MPTraj + Alexandria."""
471487
system_config = SystemConfig(radius=6.0, max_num_neighbors=20)
472-
model = orb_v2_architecture(num_message_passing_steps=10, device=device, system_config=system_config)
488+
model = orb_v2_architecture(
489+
num_message_passing_steps=10, device=device, system_config=system_config
490+
)
473491
model = load_model_for_inference(
474492
model, weights_path, device, precision=precision, compile=compile
475493
)
@@ -485,7 +503,9 @@ def orb_d3_xs_v2(
485503
) -> DirectForcefieldRegressor:
486504
"""Load ORB D3 xs v2 with 20 max neighbors, trained on MPTraj + Alexandria."""
487505
system_config = SystemConfig(radius=6.0, max_num_neighbors=20)
488-
model = orb_v2_architecture(num_message_passing_steps=5, device=device, system_config=system_config)
506+
model = orb_v2_architecture(
507+
num_message_passing_steps=5, device=device, system_config=system_config
508+
)
489509
model = load_model_for_inference(
490510
model, weights_path, device, precision=precision, compile=compile
491511
)
@@ -547,7 +567,7 @@ def orb_v1_mptraj_only(
547567

548568
ORB_PRETRAINED_MODELS = {
549569
# most performant orb-v3 omat models
550-
"orb-v3-conservative-20-omat": orb_v3_conservative_20_omat,
570+
"orb-v3-conservative-20-omat": orb_v3_conservative_20_omat,
551571
"orb-v3-conservative-inf-omat": orb_v3_conservative_inf_omat,
552572
"orb-v3-direct-20-omat": orb_v3_direct_20_omat,
553573
"orb-v3-direct-inf-omat": orb_v3_direct_inf_omat,

tests/forcefield/conftest.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,14 +167,12 @@ def gns_model():
167167

168168

169169
@pytest.fixture
170-
def conservative_regressor(gns_model, energy_head, force_head, stress_head):
170+
def conservative_regressor(gns_model, energy_head):
171171
return ConservativeForcefieldRegressor(
172-
heads={"energy": energy_head, "forces": force_head, "stress": stress_head},
172+
heads={"energy": energy_head},
173173
model=gns_model,
174174
loss_weights={
175175
"energy": 1.0,
176-
"forces": 1.0,
177-
"stress": 1.0,
178176
"grad_forces": 1.0,
179177
"grad_stress": 1.0,
180178
"rotational_grad": 1.0,

0 commit comments

Comments
 (0)