@@ -118,6 +118,7 @@ def __init__(
118118 target : Union [str , PropertyDefinition ],
119119 dropout : Optional [float ] = None ,
120120 remove_mean : bool = True ,
121+ remove_torque_for_nonpbc_systems : bool = True ,
121122 ):
122123 """Initializes the NodeHead MLP.
123124
@@ -128,6 +129,8 @@ def __init__(
128129 target: either the name of a PropertyDefinition or a PropertyDefinition itself.
129130 dropout: The level of dropout to apply.
130131 remove_mean: Whether to remove the mean of the node features.
132+ remove_torque_for_nonpbc_systems: Whether to remove net torque from the
133+ force predictions for non-PBC systems.
131134 """
132135 super ().__init__ ()
133136 if isinstance (target , str ):
@@ -153,11 +156,13 @@ def __init__(
153156 )
154157
155158 self .remove_mean = remove_mean
159+ self .remove_torque_for_nonpbc_systems = remove_torque_for_nonpbc_systems
156160
157161 def forward (self , batch : base .AtomGraphs ) -> base .AtomGraphs :
158162 """Predictions with raw logits (no sigmoid/softmax or any inverse transformations)."""
159163 feat = batch .node_features [_KEY ]
160164 pred = self .mlp (feat )
165+
161166 if self .remove_mean :
162167 system_means = segment_ops .aggregate_nodes (
163168 pred , batch .n_node , reduction = "mean"
@@ -166,6 +171,12 @@ def forward(self, batch: base.AtomGraphs) -> base.AtomGraphs:
166171 system_means , batch .n_node , dim = 0
167172 )
168173 pred = pred - node_broadcasted_means
174+
175+ if self .remove_torque_for_nonpbc_systems :
176+ pred = selectively_remove_net_torque_for_nonpbc_systems (
177+ pred , batch .positions , batch .system_features ["cell" ], batch .n_node
178+ )
179+
169180 batch .node_features ["node_pred" ] = pred
170181 return batch
171182
@@ -723,3 +734,114 @@ def cross_entropy_loss(
723734 f"{ metric_prefix } _loss" : loss .item (),
724735 },
725736 )
737+
738+ def selectively_remove_net_torque_for_nonpbc_systems (
739+ pred : torch .Tensor ,
740+ positions : torch .Tensor ,
741+ cell : torch .Tensor ,
742+ n_node : torch .Tensor ,
743+ ):
744+ """Remove net torque from non-PBC-system forces, but preserve PBC-system forces.
745+
746+ Args:
747+ pred: The predicted forces of shape (n_atoms_in_batch, 3).
748+ positions: The positions of shape (n_atoms_in_batch, 3).
749+ cell: The cell of shape (n_batch, 3, 3).
750+ n_node: The number of nodes per graph, of shape (n_batch,).
751+ """
752+ nopbc_graph = torch .all (cell == 0.0 , dim = (1 , 2 ))
753+ if torch .any (nopbc_graph ):
754+ if torch .all (nopbc_graph ):
755+ pred = remove_net_torque (positions , pred , n_node )
756+ else :
757+ # Handle a mixed batch of pbc and non-pbc systems
758+ batch_indices = torch .repeat_interleave (
759+ torch .arange (cell .size (0 ), device = n_node .device ), n_node
760+ )
761+ nopbc_atom = nopbc_graph [batch_indices ]
762+ adjusted_pred_non_pbc = remove_net_torque (
763+ positions [nopbc_atom ], pred [nopbc_atom ], n_node [nopbc_graph ]
764+ )
765+ pred = pred .clone ()
766+ pred [nopbc_atom ] = adjusted_pred_non_pbc
767+
768+ return pred
769+
770+
771+ def remove_net_torque (
772+ positions : torch .Tensor ,
773+ forces : torch .Tensor ,
774+ n_nodes : torch .Tensor ,
775+ ) -> torch .Tensor :
776+ """Adjust the predicted forces to eliminate net torque for each graph in the batch.
777+
778+ We frame the problem of net-torque-elimination as a constrained optimisation problem;
779+ what is the minimal additive adjustment (in L2 norm) that eliminates net torque?
780+
781+ This analytically solvable with Lagrange multipliers and the solution involves cheap
782+ linear algebra operations (cross products and the inversion of 3x3 matrices).
783+
784+ Args:
785+ positions : torch.Tensor of shape (N, 3)
786+ Positions of atoms (concatenated for all graphs in the batch).
787+ forces : torch.Tensor of shape (N, 3)
788+ Predicted forces on atoms.
789+ n_nodes : torch.Tensor of shape (B,)
790+ Number of nodes in each graph, where B is the number of graphs in the batch.
791+
792+ Returns:
793+ adjusted_forces : torch.Tensor of shape (N, 3)
794+ Adjusted forces with zero net torque and net force for each graph.
795+ """
796+ B = n_nodes .shape [0 ]
797+ tau_total , r = compute_net_torque (positions , forces , n_nodes )
798+
799+ # Compute scalar s per graph: sum_i ||r_i||^2
800+ r_squared = torch .sum (r ** 2 , dim = 1 ) # Shape: (N,)
801+ s = segment_ops .aggregate_nodes (r_squared , n_nodes , "sum" ) # Shape: (B,)
802+
803+ # Compute matrix S per graph: sum_i outer(r_i, r_i)
804+ r_unsqueezed = r .unsqueeze (2 ) # Shape: (N, 3, 1)
805+ r_T_unsqueezed = r .unsqueeze (1 ) # Shape: (N, 1, 3)
806+ outer_products = r_unsqueezed @ r_T_unsqueezed # Shape: (N, 3, 3)
807+ S = segment_ops .aggregate_nodes (outer_products , n_nodes , "sum" ) # Shape: (B, 3, 3)
808+
809+ # Compute M = S - sI
810+ I = ( # noqa: E741
811+ torch .eye (3 , device = positions .device ).unsqueeze (0 ).expand (B , - 1 , - 1 )
812+ ) # Shape: (B, 3, 3)
813+ M = S - (s .view (- 1 , 1 , 1 )) * I # Shape: (B, 3, 3)
814+
815+ # Right-hand side vector b per graph
816+ b = - tau_total # Shape: (B, 3)
817+
818+ # Solve M * mu = b for mu per graph
819+ try :
820+ mu = torch .linalg .solve (M , b .unsqueeze (2 )).squeeze (2 ) # Shape: (B, 3)
821+ except RuntimeError :
822+ # Handle singular matrix M by using the pseudo-inverse
823+ M_pinv = torch .linalg .pinv (M ) # Shape: (B, 3, 3)
824+ mu = torch .bmm (M_pinv , b .unsqueeze (2 )).squeeze (2 ) # Shape: (B, 3)
825+
826+ # Compute adjustments to forces
827+ mu_batch = torch .repeat_interleave (mu , n_nodes , dim = 0 ) # Shape: (N, 3)
828+ forces_delta = torch .linalg .cross (r , mu_batch ) # Shape: (N, 3)
829+
830+ # Adjusted forces
831+ adjusted_forces = forces + forces_delta # Shape: (N, 3)
832+
833+ return adjusted_forces
834+
835+
836+ def compute_net_torque (
837+ positions : torch .Tensor ,
838+ forces : torch .Tensor ,
839+ n_nodes : torch .Tensor ,
840+ ) -> Tuple [torch .Tensor , torch .Tensor ]:
841+ """Compute the net torque on a system of particles."""
842+ com = segment_ops .aggregate_nodes (positions , n_nodes , "mean" )
843+ com_repeat = torch .repeat_interleave (com , n_nodes , dim = 0 ) # Shape: (N, 3)
844+ com_relative_positions = positions - com_repeat # Shape: (N, 3)
845+ torques = torch .linalg .cross (com_relative_positions , forces ) # Shape: (N, 3)
846+ net_torque = segment_ops .aggregate_nodes (torques , n_nodes , "sum" )
847+ return net_torque , com_relative_positions
0 commit comments