Skip to content

Commit 802c865

Browse files
DeNeutoyben rhodes
andauthored
Add net torque removal (#30)
Co-authored-by: ben rhodes <benrhodes@bens-MacBook-Pro.local>
1 parent bcbe80f commit 802c865

File tree

3 files changed

+444
-1
lines changed

3 files changed

+444
-1
lines changed

orb_models/forcefield/graph_regressor.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def __init__(
118118
target: Union[str, PropertyDefinition],
119119
dropout: Optional[float] = None,
120120
remove_mean: bool = True,
121+
remove_torque_for_nonpbc_systems: bool = True,
121122
):
122123
"""Initializes the NodeHead MLP.
123124
@@ -128,6 +129,8 @@ def __init__(
128129
target: either the name of a PropertyDefinition or a PropertyDefinition itself.
129130
dropout: The level of dropout to apply.
130131
remove_mean: Whether to remove the mean of the node features.
132+
remove_torque_for_nonpbc_systems: Whether to remove net torque from the
133+
force predictions for non-PBC systems.
131134
"""
132135
super().__init__()
133136
if isinstance(target, str):
@@ -153,11 +156,13 @@ def __init__(
153156
)
154157

155158
self.remove_mean = remove_mean
159+
self.remove_torque_for_nonpbc_systems = remove_torque_for_nonpbc_systems
156160

157161
def forward(self, batch: base.AtomGraphs) -> base.AtomGraphs:
158162
"""Predictions with raw logits (no sigmoid/softmax or any inverse transformations)."""
159163
feat = batch.node_features[_KEY]
160164
pred = self.mlp(feat)
165+
161166
if self.remove_mean:
162167
system_means = segment_ops.aggregate_nodes(
163168
pred, batch.n_node, reduction="mean"
@@ -166,6 +171,12 @@ def forward(self, batch: base.AtomGraphs) -> base.AtomGraphs:
166171
system_means, batch.n_node, dim=0
167172
)
168173
pred = pred - node_broadcasted_means
174+
175+
if self.remove_torque_for_nonpbc_systems:
176+
pred = selectively_remove_net_torque_for_nonpbc_systems(
177+
pred, batch.positions, batch.system_features["cell"], batch.n_node
178+
)
179+
169180
batch.node_features["node_pred"] = pred
170181
return batch
171182

@@ -723,3 +734,114 @@ def cross_entropy_loss(
723734
f"{metric_prefix}_loss": loss.item(),
724735
},
725736
)
737+
738+
def selectively_remove_net_torque_for_nonpbc_systems(
739+
pred: torch.Tensor,
740+
positions: torch.Tensor,
741+
cell: torch.Tensor,
742+
n_node: torch.Tensor,
743+
):
744+
"""Remove net torque from non-PBC-system forces, but preserve PBC-system forces.
745+
746+
Args:
747+
pred: The predicted forces of shape (n_atoms_in_batch, 3).
748+
positions: The positions of shape (n_atoms_in_batch, 3).
749+
cell: The cell of shape (n_batch, 3, 3).
750+
n_node: The number of nodes per graph, of shape (n_batch,).
751+
"""
752+
nopbc_graph = torch.all(cell == 0.0, dim=(1, 2))
753+
if torch.any(nopbc_graph):
754+
if torch.all(nopbc_graph):
755+
pred = remove_net_torque(positions, pred, n_node)
756+
else:
757+
# Handle a mixed batch of pbc and non-pbc systems
758+
batch_indices = torch.repeat_interleave(
759+
torch.arange(cell.size(0), device=n_node.device), n_node
760+
)
761+
nopbc_atom = nopbc_graph[batch_indices]
762+
adjusted_pred_non_pbc = remove_net_torque(
763+
positions[nopbc_atom], pred[nopbc_atom], n_node[nopbc_graph]
764+
)
765+
pred = pred.clone()
766+
pred[nopbc_atom] = adjusted_pred_non_pbc
767+
768+
return pred
769+
770+
771+
def remove_net_torque(
772+
positions: torch.Tensor,
773+
forces: torch.Tensor,
774+
n_nodes: torch.Tensor,
775+
) -> torch.Tensor:
776+
"""Adjust the predicted forces to eliminate net torque for each graph in the batch.
777+
778+
We frame the problem of net-torque-elimination as a constrained optimisation problem;
779+
what is the minimal additive adjustment (in L2 norm) that eliminates net torque?
780+
781+
This analytically solvable with Lagrange multipliers and the solution involves cheap
782+
linear algebra operations (cross products and the inversion of 3x3 matrices).
783+
784+
Args:
785+
positions : torch.Tensor of shape (N, 3)
786+
Positions of atoms (concatenated for all graphs in the batch).
787+
forces : torch.Tensor of shape (N, 3)
788+
Predicted forces on atoms.
789+
n_nodes : torch.Tensor of shape (B,)
790+
Number of nodes in each graph, where B is the number of graphs in the batch.
791+
792+
Returns:
793+
adjusted_forces : torch.Tensor of shape (N, 3)
794+
Adjusted forces with zero net torque and net force for each graph.
795+
"""
796+
B = n_nodes.shape[0]
797+
tau_total, r = compute_net_torque(positions, forces, n_nodes)
798+
799+
# Compute scalar s per graph: sum_i ||r_i||^2
800+
r_squared = torch.sum(r**2, dim=1) # Shape: (N,)
801+
s = segment_ops.aggregate_nodes(r_squared, n_nodes, "sum") # Shape: (B,)
802+
803+
# Compute matrix S per graph: sum_i outer(r_i, r_i)
804+
r_unsqueezed = r.unsqueeze(2) # Shape: (N, 3, 1)
805+
r_T_unsqueezed = r.unsqueeze(1) # Shape: (N, 1, 3)
806+
outer_products = r_unsqueezed @ r_T_unsqueezed # Shape: (N, 3, 3)
807+
S = segment_ops.aggregate_nodes(outer_products, n_nodes, "sum") # Shape: (B, 3, 3)
808+
809+
# Compute M = S - sI
810+
I = ( # noqa: E741
811+
torch.eye(3, device=positions.device).unsqueeze(0).expand(B, -1, -1)
812+
) # Shape: (B, 3, 3)
813+
M = S - (s.view(-1, 1, 1)) * I # Shape: (B, 3, 3)
814+
815+
# Right-hand side vector b per graph
816+
b = -tau_total # Shape: (B, 3)
817+
818+
# Solve M * mu = b for mu per graph
819+
try:
820+
mu = torch.linalg.solve(M, b.unsqueeze(2)).squeeze(2) # Shape: (B, 3)
821+
except RuntimeError:
822+
# Handle singular matrix M by using the pseudo-inverse
823+
M_pinv = torch.linalg.pinv(M) # Shape: (B, 3, 3)
824+
mu = torch.bmm(M_pinv, b.unsqueeze(2)).squeeze(2) # Shape: (B, 3)
825+
826+
# Compute adjustments to forces
827+
mu_batch = torch.repeat_interleave(mu, n_nodes, dim=0) # Shape: (N, 3)
828+
forces_delta = torch.linalg.cross(r, mu_batch) # Shape: (N, 3)
829+
830+
# Adjusted forces
831+
adjusted_forces = forces + forces_delta # Shape: (N, 3)
832+
833+
return adjusted_forces
834+
835+
836+
def compute_net_torque(
837+
positions: torch.Tensor,
838+
forces: torch.Tensor,
839+
n_nodes: torch.Tensor,
840+
) -> Tuple[torch.Tensor, torch.Tensor]:
841+
"""Compute the net torque on a system of particles."""
842+
com = segment_ops.aggregate_nodes(positions, n_nodes, "mean")
843+
com_repeat = torch.repeat_interleave(com, n_nodes, dim=0) # Shape: (N, 3)
844+
com_relative_positions = positions - com_repeat # Shape: (N, 3)
845+
torques = torch.linalg.cross(com_relative_positions, forces) # Shape: (N, 3)
846+
net_torque = segment_ops.aggregate_nodes(torques, n_nodes, "sum")
847+
return net_torque, com_relative_positions

tests/test_model_backward_compatibility.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def test_optimization(model_fn):
4141
atoms = bulk("Cu", "fcc", a=3.58, cubic=True)
4242
orbff = model_fn(device="cpu")
4343
calc = ORBCalculator(orbff, device="cpu")
44-
atoms.set_calculator(calc)
44+
atoms.calc = calc
4545
atoms.rattle(0.5)
4646
rattled_energy = atoms.get_potential_energy()
4747
dyn = BFGS(atoms)

0 commit comments

Comments
 (0)