Skip to content

Commit 021dc7b

Browse files
authored
Update conservative model for V3 (#68)
* add pair repulsion * update conf head * update calculator * add spherical harmonics * update conservative model to not require direct heads * lint * tweak * fix
1 parent 2ccc0e2 commit 021dc7b

File tree

12 files changed

+2520
-158
lines changed

12 files changed

+2520
-158
lines changed

orb_models/forcefield/angular.py

Lines changed: 1561 additions & 0 deletions
Large diffs are not rendered by default.

orb_models/forcefield/calculator.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,14 @@ def __init__(
7474

7575
self.implemented_properties = model.properties # type: ignore
7676

77+
# TODO: Untangle the spaghetti of how we handle the naming for the heads.
78+
# This is required because ASE will check the implemented_properties for
79+
# the existence of the property before calling the calculator, so it's not
80+
# sufficient to just return the property names from the model and handle
81+
# the conservative case in `calculate`.
82+
if self.conservative:
83+
self.implemented_properties.extend(["forces", "stress"])
84+
7785
def calculate(self, atoms=None, properties=None, system_changes=all_changes):
7886
"""Calculate properties.
7987
@@ -103,7 +111,15 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes):
103111
batch = batch.to(self.device) # type: ignore
104112
out = self.model.predict(batch) # type: ignore
105113
self.results = {}
114+
model_has_direct_heads = (
115+
"forces" in self.model.heads and "stress" in self.model.heads # type: ignore
116+
)
106117
for property in self.implemented_properties:
118+
# The model has no direct heads for forces/stress, so we skip these properties.
119+
if not model_has_direct_heads and property == "forces":
120+
continue
121+
if not model_has_direct_heads and property == "stress":
122+
continue
107123
_property = "energy" if property == "free_energy" else property
108124
self.results[property] = to_numpy(out[_property].squeeze())
109125

orb_models/forcefield/conservative_regressor.py

Lines changed: 143 additions & 110 deletions
Large diffs are not rendered by default.

orb_models/forcefield/forcefield_heads.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,7 @@ def __init__(
356356
dropout: Optional[float] = None,
357357
activation: str = "ssp",
358358
detach_node_features: bool = True,
359+
hard_clamp: bool = True,
359360
):
360361
"""Initializes the ConfidenceHead MLP.
361362
@@ -371,13 +372,16 @@ def __init__(
371372
detach_node_features: If True, detaches node features from computational graph.
372373
This means that the confidence loss has no impact on training the underlying
373374
forcefield model.
375+
hard_clamp: If True, ignore any errors above max_error such that they do not contribute
376+
to the loss, rather than just clamping them to the max_bin.
374377
"""
375378
super().__init__()
376379
self.target = _confidence
377380
self.num_bins = num_bins
378381
self.max_error = max_error
379382
self.detach_node_features = detach_node_features
380-
# Define bin edges (from 0 to max_error)
383+
self.hard_clamp = hard_clamp
384+
self.ignore_index = -100
381385
if binning_scale == "linear":
382386
bins = torch.linspace(0.0, max_error, int(num_bins + 1))
383387
elif binning_scale == "exponential":
@@ -409,9 +413,13 @@ def get_error_bins(self, force_error: torch.Tensor) -> torch.Tensor:
409413
Returns:
410414
Bin indices of shape (n_atoms,)
411415
"""
412-
force_error = torch.clamp(force_error, 0, self.max_error)
413-
bins = torch.bucketize(force_error, self.bin_edges) - 1 # type: ignore
414-
return torch.clamp(bins, 0, self.num_bins - 1)
416+
clamped_error = torch.clamp(force_error, 0, self.max_error)
417+
bins = torch.bucketize(clamped_error, self.bin_edges) - 1 # type: ignore
418+
clamped = torch.clamp(bins, 0, self.num_bins - 1)
419+
420+
if self.hard_clamp:
421+
clamped[force_error > self.max_error] = self.ignore_index
422+
return clamped
415423

416424
def forward(
417425
self, node_features: torch.Tensor, batch: base.AtomGraphs
@@ -451,7 +459,9 @@ def loss(
451459
true_bins = self.get_error_bins(force_error)
452460

453461
# Cross entropy loss
454-
loss = torch.nn.functional.cross_entropy(confidence_logits, true_bins)
462+
loss = torch.nn.functional.cross_entropy(
463+
confidence_logits, true_bins, ignore_index=self.ignore_index
464+
)
455465

456466
# Calculate accuracy
457467
pred_bins = torch.argmax(confidence_logits, dim=-1)

orb_models/forcefield/gns.py

Lines changed: 83 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Pyg implementation of Graph Net Simulator."""
22

33
from collections import OrderedDict
4-
from typing import Callable, List, Optional, Literal, Dict, Any, Tuple
4+
from typing import Callable, List, Optional, Literal, Dict, Any, Tuple, Union
55

66
import torch
77
from torch import nn
@@ -10,6 +10,7 @@
1010
from orb_models.forcefield import base, segment_ops
1111
from orb_models.forcefield.nn_util import build_mlp, get_cutoff, mlp_and_layer_norm
1212
from orb_models.forcefield.embedding import AtomEmbedding, AtomEmbeddingBag
13+
from orb_models.forcefield.angular import UnitVector
1314

1415
_KEY = "feat"
1516

@@ -274,23 +275,31 @@ def forward(self, nodes: torch.Tensor) -> torch.Tensor:
274275
class MoleculeGNS(nn.Module):
275276
"""GNS that works on molecular data."""
276277

277-
_deprecated_args = ["noise_scale", "add_virtual_node", "self_cond", "interactions"]
278+
_deprecated_args = [
279+
"noise_scale",
280+
"add_virtual_node",
281+
"self_cond",
282+
"interactions",
283+
"num_node_in_features",
284+
"num_edge_in_features",
285+
]
278286

279287
def __init__(
280288
self,
281-
num_node_in_features: int,
282-
num_node_out_features: int,
283-
num_edge_in_features: int,
284289
latent_dim: int,
285290
num_message_passing_steps: int,
286291
num_mlp_layers: int,
287292
mlp_hidden_dim: int,
288293
rbf_transform: Callable,
289-
node_feature_names: Optional[List[str]] = None,
290-
edge_feature_names: Optional[List[str]] = None,
294+
angular_transform: Optional[Callable] = None,
295+
outer_product_with_cutoff: bool = False,
296+
use_embedding: bool = False, # atom type embedding
291297
expects_atom_type_embedding: bool = False,
292-
use_embedding: bool = False,
293298
interaction_params: Optional[Dict[str, Any]] = None,
299+
num_node_out_features: int = 3,
300+
extra_embed_dims: Union[int, Tuple[int, int]] = 0,
301+
node_feature_names: Optional[List[str]] = None,
302+
edge_feature_names: Optional[List[str]] = None,
294303
checkpoint: Optional[str] = None,
295304
activation="ssp",
296305
mlp_norm: str = "layer_norm",
@@ -299,28 +308,33 @@ def __init__(
299308
"""Initializes the molecular GNS.
300309
301310
Args:
302-
num_node_in_features (int): Number input nodes features.
303-
num_node_out_features (int): Number output nodes features.
304-
num_edge_in_features (int): Number input edge features.
305311
latent_dim (int): Latent dimension of processor.
306312
num_message_passing_steps (int): Number of message passing steps.
307313
num_mlp_layers (int): Number of MLP layers.
308314
mlp_hidden_dim (int): MLP hidden dimension.
315+
rbf_transform (Callable): A function that takes in edge lengths and returns
316+
a tensor of RBF features.
317+
angular_transform (Callable): A function that takes in edge vectors and
318+
returns a tensor of angular features.
319+
outer_product_with_cutoff (bool): Create initial edge embeddings via
320+
an outer product of RBF and angular embeddings and an envelope cutoff.
321+
use_embedding: Whether to embed atom types using an embedding table or embedding bag.
322+
expects_atom_type_embedding (bool): Whether or not the model expects the input
323+
to be pre-embedded. This is used for atom type models, because the one-hot
324+
embedding is noised, rather than being explicitly one-hot.
325+
interaction_params (Optional[Dict[str, Any]]): Additional parameters
326+
to pass to the interaction network.
327+
num_node_out_features (int): Number output nodes features.
328+
extra_embed_dims (int | Tuple[int, int]): Number of extra embedding dimensions to use.
329+
If an int, both the node and edge embeddings will have this number of extra dims.
330+
If a tuple, then it is interpreted as [extra_node_embed_dim, extra_edge_embed_dim].
309331
node_feature_names (List[str]): Which tensors from batch.node_features to
310332
concatenate to form the initial node latents. Note: These are "extra"
311333
features - we assume the base atomic number representation is already
312334
included.
313335
edge_feature_names (List[str]): Which tensors from batch.edge_features to
314336
concatenate to form the initial edge latents. Note: These are "extra"
315337
features - we assume the base edge vector features are already included.
316-
rbf_transform: An RBF transform to use for the edge features.
317-
expects_atom_type_embedding (bool): Whether or not the model expects
318-
the input to be pre-embedded. This is used for atom type models,
319-
because the one-hot embedding is noised, rather than being
320-
explicitly one-hot.
321-
use_embedding: Whether to embed atom types using an embedding table or embedding bag.
322-
interaction_params (Optional[Dict[str, Any]]): Additional parameters
323-
to pass to the interaction network.
324338
checkpoint (bool): Whether or not to use checkpointing.
325339
activation (str): Activation function to use.
326340
mlp_norm (str): Normalization layer to use in the MLP.
@@ -333,9 +347,42 @@ def __init__(
333347
f"The following kwargs are not arguments to GraphRegressor: {kwargs.keys()}"
334348
)
335349

350+
self.node_feature_names = node_feature_names or []
351+
self.edge_feature_names = edge_feature_names or []
352+
353+
# Edge embedding
354+
self.outer_product_with_cutoff = outer_product_with_cutoff
355+
self.rbf_transform = rbf_transform
356+
if angular_transform is None:
357+
angular_transform = UnitVector()
358+
self.angular_transform = angular_transform
359+
if self.outer_product_with_cutoff:
360+
self.edge_embed_size = rbf_transform.num_bases * angular_transform.dim # type: ignore
361+
else:
362+
if hasattr(rbf_transform, "num_bases"):
363+
num_bases = rbf_transform.num_bases
364+
else:
365+
num_bases = rbf_transform.keywords["num_bases"] # type: ignore
366+
self.edge_embed_size = num_bases + angular_transform.dim # type: ignore
367+
368+
# Node embedding
369+
self.expects_atom_type_embedding = expects_atom_type_embedding
370+
self.use_embedding = use_embedding
371+
if self.use_embedding:
372+
self.node_embed_size = latent_dim
373+
if self.expects_atom_type_embedding:
374+
# Use embedding bag for atom type diffusion
375+
self.atom_emb = AtomEmbeddingBag(self.node_embed_size, 118)
376+
else:
377+
self.atom_emb = AtomEmbedding(self.node_embed_size, 118) # type: ignore
378+
else:
379+
self.node_embed_size = 118
380+
if isinstance(extra_embed_dims, int):
381+
extra_embed_dims = (extra_embed_dims, extra_embed_dims) # type: ignore
382+
336383
self._encoder = Encoder(
337-
num_node_in_features=num_node_in_features,
338-
num_edge_in_features=num_edge_in_features,
384+
num_node_in_features=self.node_embed_size + extra_embed_dims[0],
385+
num_edge_in_features=self.edge_embed_size + extra_embed_dims[1],
339386
latent_dim=latent_dim,
340387
num_mlp_layers=num_mlp_layers,
341388
mlp_hidden_dim=mlp_hidden_dim,
@@ -370,19 +417,6 @@ def __init__(
370417
checkpoint=checkpoint,
371418
activation=activation,
372419
)
373-
self.rbf = rbf_transform
374-
self.expects_atom_type_embedding = expects_atom_type_embedding
375-
self.use_embedding = use_embedding
376-
377-
if self.use_embedding:
378-
if self.expects_atom_type_embedding:
379-
# Use embedding bag for atom type diffusion
380-
self.atom_emb = AtomEmbeddingBag(latent_dim, 118)
381-
else:
382-
self.atom_emb = AtomEmbedding(latent_dim, 118) # type: ignore
383-
384-
self.node_feature_names = node_feature_names or []
385-
self.edge_feature_names = edge_feature_names or []
386420

387421
def forward(self, batch: base.AtomGraphs) -> Dict[str, torch.Tensor]:
388422
"""Encode a graph using molecular GNS.
@@ -455,14 +489,22 @@ def featurize_edges(self, batch: base.AtomGraphs) -> torch.Tensor:
455489
vectors = batch.edge_features["vectors"]
456490
# replace 0s with 1s to avoid division by zero
457491
lengths = vectors.norm(dim=1)
458-
non_zero_divisor = torch.where(lengths == 0, torch.ones_like(lengths), lengths)
459-
unit_vectors = vectors / non_zero_divisor.unsqueeze(1)
460-
rbfs = self.rbf(lengths)
461-
edge_features = torch.cat([rbfs, unit_vectors], dim=1)
462492

463-
# This is for backward compatibility with old code
464-
# Configs now assume that the base model features are already included
465-
# and only specify "extra" features
493+
angular_embedding = self.angular_transform(vectors) # (nedges, x)
494+
rbfs = self.rbf_transform(lengths) # (nedges, y)
495+
496+
if self.outer_product_with_cutoff:
497+
cutoff = get_cutoff(lengths)
498+
# (nedges, x, y)
499+
outer_product = rbfs[:, :, None] * angular_embedding[:, None, :]
500+
# (nedges, x * y)
501+
edge_features = cutoff * outer_product.view(
502+
vectors.shape[0], self.edge_embed_size
503+
)
504+
else:
505+
edge_features = torch.cat([rbfs, angular_embedding], dim=1)
506+
507+
# For backwards compatibility, exclude 'feat'
466508
feature_names = [k for k in self.edge_feature_names if k != "feat"]
467509
return torch.cat(
468510
[edge_features, *[batch.edge_features[k] for k in feature_names]], dim=-1

orb_models/forcefield/graph_regressor.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from orb_models.forcefield import segment_ops
1818
from orb_models.forcefield.gns import MoleculeGNS
1919
from orb_models.forcefield.load import _load_forcefield_state_dict
20+
from orb_models.forcefield.pair_repulsion import ZBLBasis
2021

2122

2223
class GraphRegressor(nn.Module):
@@ -35,6 +36,7 @@ def __init__(
3536
model_requires_grad: bool = True,
3637
cutoff_layers: Optional[int] = None,
3738
loss_weights: Optional[Dict[str, float]] = None,
39+
pair_repulsion: bool = False,
3840
) -> None:
3941
"""Initializes the GraphRegressor.
4042
@@ -65,6 +67,14 @@ def __init__(
6567
self.loss_weights = loss_weights
6668
self.model_requires_grad = model_requires_grad
6769

70+
self.pair_repulsion = pair_repulsion
71+
if self.pair_repulsion:
72+
self.pair_repulsion_fn = ZBLBasis(
73+
p=6,
74+
node_aggregation="sum",
75+
compute_gradients=True,
76+
)
77+
6878
self.model = model
6979
if self.cutoff_layers is not None:
7080
gns = (
@@ -85,6 +95,19 @@ def forward(
8595
for name, head in self.heads.items():
8696
res = head(node_features, batch)
8797
out[name] = res
98+
99+
if self.pair_repulsion:
100+
out_pair_raw = self.pair_repulsion_fn(batch)
101+
for name, head in self.heads.items():
102+
raw = out_pair_raw.get(name, None)
103+
if raw is None:
104+
continue
105+
if name == "energy" and head.atom_avg:
106+
raw = (raw / batch.n_node).unsqueeze(1)
107+
out[name] = out[name] + head.normalizer(
108+
raw,
109+
online=False,
110+
)
88111
return out
89112

90113
def predict(
@@ -100,6 +123,15 @@ def predict(
100123
output[name] = _split_prediction(pred, batch.n_node)
101124
else:
102125
output[name] = pred
126+
127+
if self.pair_repulsion:
128+
out_pair_raw = self.pair_repulsion_fn(batch)
129+
for name, head in self.heads.items():
130+
raw = out_pair_raw.get(name, None)
131+
if raw is None:
132+
continue
133+
output[name] = output[name] + raw
134+
103135
return output
104136

105137
def loss(self, batch: base.AtomGraphs) -> base.ModelOutput:

0 commit comments

Comments
 (0)