11"""Pyg implementation of Graph Net Simulator."""
22
33from collections import OrderedDict
4- from typing import Callable , List , Optional , Literal , Dict , Any , Tuple
4+ from typing import Callable , List , Optional , Literal , Dict , Any , Tuple , Union
55
66import torch
77from torch import nn
1010from orb_models .forcefield import base , segment_ops
1111from orb_models .forcefield .nn_util import build_mlp , get_cutoff , mlp_and_layer_norm
1212from orb_models .forcefield .embedding import AtomEmbedding , AtomEmbeddingBag
13+ from orb_models .forcefield .angular import UnitVector
1314
1415_KEY = "feat"
1516
@@ -274,23 +275,31 @@ def forward(self, nodes: torch.Tensor) -> torch.Tensor:
274275class MoleculeGNS (nn .Module ):
275276 """GNS that works on molecular data."""
276277
277- _deprecated_args = ["noise_scale" , "add_virtual_node" , "self_cond" , "interactions" ]
278+ _deprecated_args = [
279+ "noise_scale" ,
280+ "add_virtual_node" ,
281+ "self_cond" ,
282+ "interactions" ,
283+ "num_node_in_features" ,
284+ "num_edge_in_features" ,
285+ ]
278286
279287 def __init__ (
280288 self ,
281- num_node_in_features : int ,
282- num_node_out_features : int ,
283- num_edge_in_features : int ,
284289 latent_dim : int ,
285290 num_message_passing_steps : int ,
286291 num_mlp_layers : int ,
287292 mlp_hidden_dim : int ,
288293 rbf_transform : Callable ,
289- node_feature_names : Optional [List [str ]] = None ,
290- edge_feature_names : Optional [List [str ]] = None ,
294+ angular_transform : Optional [Callable ] = None ,
295+ outer_product_with_cutoff : bool = False ,
296+ use_embedding : bool = False , # atom type embedding
291297 expects_atom_type_embedding : bool = False ,
292- use_embedding : bool = False ,
293298 interaction_params : Optional [Dict [str , Any ]] = None ,
299+ num_node_out_features : int = 3 ,
300+ extra_embed_dims : Union [int , Tuple [int , int ]] = 0 ,
301+ node_feature_names : Optional [List [str ]] = None ,
302+ edge_feature_names : Optional [List [str ]] = None ,
294303 checkpoint : Optional [str ] = None ,
295304 activation = "ssp" ,
296305 mlp_norm : str = "layer_norm" ,
@@ -299,28 +308,33 @@ def __init__(
299308 """Initializes the molecular GNS.
300309
301310 Args:
302- num_node_in_features (int): Number input nodes features.
303- num_node_out_features (int): Number output nodes features.
304- num_edge_in_features (int): Number input edge features.
305311 latent_dim (int): Latent dimension of processor.
306312 num_message_passing_steps (int): Number of message passing steps.
307313 num_mlp_layers (int): Number of MLP layers.
308314 mlp_hidden_dim (int): MLP hidden dimension.
315+ rbf_transform (Callable): A function that takes in edge lengths and returns
316+ a tensor of RBF features.
317+ angular_transform (Callable): A function that takes in edge vectors and
318+ returns a tensor of angular features.
319+ outer_product_with_cutoff (bool): Create initial edge embeddings via
320+ an outer product of rbf and angular embeddings and an envelope cutoff.
321+ use_embedding: Whether to embed atom types using an embedding table or embedding bag.
322+ expects_atom_type_embedding (bool): Whether or not the model expects the input
323+ to be pre-embedded. This is used for atom type models, because the one-hot
324+ embedding is noised, rather than being explicitly one-hot.
325+ interaction_params (Optional[Dict[str, Any]]): Additional parameters
326+ to pass to the interaction network.
327+ num_node_out_features (int): Number of output node features.
328+ extra_embed_dims (Union[int, Tuple[int, int]]): Number of extra embedding dimensions to use.
329+ If an int, both the node and edge embeddings will have this number of extra dims.
330+ If a tuple, then it is interpreted as [extra_node_embed_dim, extra_edge_embed_dim].
309331 node_feature_names (List[str]): Which tensors from batch.node_features to
310332 concatenate to form the initial node latents. Note: These are "extra"
311333 features - we assume the base atomic number representation is already
312334 included.
313335 edge_feature_names (List[str]): Which tensors from batch.edge_features to
314336 concatenate to form the initial edge latents. Note: These are "extra"
315337 features - we assume the base edge vector features are already included.
316- rbf_transform: An RBF transform to use for the edge features.
317- expects_atom_type_embedding (bool): Whether or not the model expects
318- the input to be pre-embedded. This is used for atom type models,
319- because the one-hot embedding is noised, rather than being
320- explicitly one-hot.
321- use_embedding: Whether to embed atom types using an embedding table or embedding bag.
322- interaction_params (Optional[Dict[str, Any]]): Additional parameters
323- to pass to the interaction network.
324338 checkpoint (bool): Whether or not to use checkpointing.
325339 activation (str): Activation function to use.
326340 mlp_norm (str): Normalization layer to use in the MLP.
@@ -333,9 +347,42 @@ def __init__(
333347 f"The following kwargs are not arguments to GraphRegressor: { kwargs .keys ()} "
334348 )
335349
350+ self .node_feature_names = node_feature_names or []
351+ self .edge_feature_names = edge_feature_names or []
352+
353+ # Edge embedding
354+ self .outer_product_with_cutoff = outer_product_with_cutoff
355+ self .rbf_transform = rbf_transform
356+ if angular_transform is None :
357+ angular_transform = UnitVector ()
358+ self .angular_transform = angular_transform
359+ if self .outer_product_with_cutoff :
360+ self .edge_embed_size = rbf_transform .num_bases * angular_transform .dim # type: ignore
361+ else :
362+ if hasattr (rbf_transform , "num_bases" ):
363+ num_bases = rbf_transform .num_bases
364+ else :
365+ num_bases = rbf_transform .keywords ["num_bases" ] # type: ignore
366+ self .edge_embed_size = num_bases + angular_transform .dim # type: ignore
367+
368+ # Node embedding
369+ self .expects_atom_type_embedding = expects_atom_type_embedding
370+ self .use_embedding = use_embedding
371+ if self .use_embedding :
372+ self .node_embed_size = latent_dim
373+ if self .expects_atom_type_embedding :
374+ # Use embedding bag for atom type diffusion
375+ self .atom_emb = AtomEmbeddingBag (self .node_embed_size , 118 )
376+ else :
377+ self .atom_emb = AtomEmbedding (self .node_embed_size , 118 ) # type: ignore
378+ else :
379+ self .node_embed_size = 118
380+ if isinstance (extra_embed_dims , int ):
381+ extra_embed_dims = (extra_embed_dims , extra_embed_dims ) # type: ignore
382+
336383 self ._encoder = Encoder (
337- num_node_in_features = num_node_in_features ,
338- num_edge_in_features = num_edge_in_features ,
384+ num_node_in_features = self . node_embed_size + extra_embed_dims [ 0 ] ,
385+ num_edge_in_features = self . edge_embed_size + extra_embed_dims [ 1 ] ,
339386 latent_dim = latent_dim ,
340387 num_mlp_layers = num_mlp_layers ,
341388 mlp_hidden_dim = mlp_hidden_dim ,
@@ -370,19 +417,6 @@ def __init__(
370417 checkpoint = checkpoint ,
371418 activation = activation ,
372419 )
373- self .rbf = rbf_transform
374- self .expects_atom_type_embedding = expects_atom_type_embedding
375- self .use_embedding = use_embedding
376-
377- if self .use_embedding :
378- if self .expects_atom_type_embedding :
379- # Use embedding bag for atom type diffusion
380- self .atom_emb = AtomEmbeddingBag (latent_dim , 118 )
381- else :
382- self .atom_emb = AtomEmbedding (latent_dim , 118 ) # type: ignore
383-
384- self .node_feature_names = node_feature_names or []
385- self .edge_feature_names = edge_feature_names or []
386420
387421 def forward (self , batch : base .AtomGraphs ) -> Dict [str , torch .Tensor ]:
388422 """Encode a graph using molecular GNS.
@@ -455,14 +489,22 @@ def featurize_edges(self, batch: base.AtomGraphs) -> torch.Tensor:
455489 vectors = batch .edge_features ["vectors" ]
456490 # replace 0s with 1s to avoid division by zero
457491 lengths = vectors .norm (dim = 1 )
458- non_zero_divisor = torch .where (lengths == 0 , torch .ones_like (lengths ), lengths )
459- unit_vectors = vectors / non_zero_divisor .unsqueeze (1 )
460- rbfs = self .rbf (lengths )
461- edge_features = torch .cat ([rbfs , unit_vectors ], dim = 1 )
462492
463- # This is for backward compatibility with old code
464- # Configs now assume that the base model features are already included
465- # and only specify "extra" features
493+ angular_embedding = self .angular_transform (vectors ) # (nedges, x)
494+ rbfs = self .rbf_transform (lengths ) # (nedges, y)
495+
496+ if self .outer_product_with_cutoff :
497+ cutoff = get_cutoff (lengths )
498+ # (nedges, x, y)
499+ outer_product = rbfs [:, :, None ] * angular_embedding [:, None , :]
500+ # (nedges, x * y)
501+ edge_features = cutoff * outer_product .view (
502+ vectors .shape [0 ], self .edge_embed_size
503+ )
504+ else :
505+ edge_features = torch .cat ([rbfs , angular_embedding ], dim = 1 )
506+
507+ # For backwards compatibility, exclude 'feat'
466508 feature_names = [k for k in self .edge_feature_names if k != "feat" ]
467509 return torch .cat (
468510 [edge_features , * [batch .edge_features [k ] for k in feature_names ]], dim = - 1
0 commit comments