Skip to content

Commit 453c8e2

Browse files
authored
Added implementation of the Kabsch algorithm in ARCSpecies API and as a standalone function in converter. (#819)
### About the algorithm Kabsch’s algorithm computes the optimal rotation that superimposes two sets of vectors (represented as matrices) by minimizing the root-mean-square deviation (RMSD) between them. For numerical stability and robustness, this implementation relies on the optimized routine provided by SciPy via [scipy.spatial.transform.Rotation.align_vectors](https://docs.scipy.org/doc/scipy-1.16.2/reference/generated/scipy.spatial.transform.Rotation.align_vectors.html) ### This PR This PR introduces a Kabsch-based alignment utility in `species/converter.py` and integrates it into the `ARCSpecies` API. The implementation explicitly accounts for key assumptions required for molecular alignment, including S2S mapping. To support this workflow, a new helper method, `translate_to_indices`, was added. This method converts an atom mapping into index-based ordering, enabling consistent reordering and alignment of species geometries prior to applying the Kabsch algorithm. Overall, this addition enables robust, minimal-RMSD alignment of mapped species geometries while preserving chemical correspondence.
2 parents a504aa7 + 59837a7 commit 453c8e2

File tree

4 files changed

+150
-0
lines changed

4 files changed

+150
-0
lines changed

arc/species/converter.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
99

1010
from ase import Atoms
11+
from scipy.spatial.transform import Rotation
1112
from openbabel import openbabel as ob
1213
from openbabel import pybel
1314
from rdkit import Chem
@@ -2446,3 +2447,25 @@ def sorted_distances_of_atom(xyz_dict: dict, atom_index: int) -> List[Tuple[int,
24462447

24472448
distances = [(i, d_matrix[atom_index, i]) for i in range(d_matrix.shape[0]) if i != atom_index]
24482449
return sorted(distances, key=lambda x: x[1])
2450+
2451+
2452+
def kabsch(xyz1: dict, xyz2: dict) -> float:
2453+
"""
2454+
Return the kabsch similarity score between two sets of Cartesian coordinates in Ångstrom.
2455+
The algorithm requires the atoms to be ordered the same in both sets of coordinates.
2456+
This will not be directly useful for comparing two conformers of the same species,
2457+
but could be used to compare two different species with the same atom types and counts. (e.g., isomers, reactants and products, etc.).
2458+
2459+
Args:
2460+
xyz1 (dict): The first set of Cartesian coordinates.
2461+
xyz2 (dict): The second set of Cartesian coordinates.
2462+
2463+
Returns:
2464+
float: The Kabsch similarity score in Ångstrom.
2465+
"""
2466+
if xyz1["symbols"] != xyz2["symbols"]:
2467+
raise ValueError("The two xyz coordinates must have the same atom symbols to compute Kabsch score.")
2468+
xyz1, xyz2 = translate_to_center_of_mass(xyz1), translate_to_center_of_mass(xyz2)
2469+
coords1, coords2 = np.array(xyz1['coords']), np.array(xyz2['coords'])
2470+
_, score = Rotation.align_vectors(coords1, coords2)
2471+
return score

arc/species/converter_test.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import os
1010

1111
import numpy as np
12+
from scipy.spatial.transform import Rotation
1213
import unittest
1314

1415
from ase import Atoms
@@ -4966,6 +4967,70 @@ def test_sorted_distances_of_atom(self):
49664967
with self.assertRaises(IndexError):
49674968
converter.sorted_distances_of_atom(xyz_dict, 5)
49684969

4970+
def test_kabsch(self):
4971+
"""Test the kabsch function"""
4972+
xyz1 = {'symbols': ('O', 'H', 'H'), 'isotopes': (16, 1, 1),
4973+
'coords': ((0.0, 0.0, 0.0),
4974+
(0.0, 0.757, 0.586),
4975+
(0.0, -0.757, 0.586))}
4976+
xyz2 = converter.translate_xyz(xyz1, (10.0, 0.0, 0.0))
4977+
score = converter.kabsch(xyz1, xyz2)
4978+
self.assertAlmostEqual(score, 0.0, places=5)
4979+
4980+
r = Rotation.from_quat([0, 0, np.sin(np.pi/4), np.cos(np.pi/4)])
4981+
self.assertAlmostEqual(converter.kabsch(xyz1, converter.xyz_from_data(coords=r.apply(np.array(xyz1["coords"])), symbols=xyz1["symbols"], isotopes=xyz1["isotopes"])), 0.0, places=5)
4982+
xyz2 = {i:v for i, v in xyz1.items()}
4983+
xyz2['symbols'] = ('O', 'H', 'H', 'H')
4984+
with self.assertRaises(ValueError):
4985+
converter.kabsch(xyz1, xyz2)
4986+
4987+
# Wildtype test: Aspirin (21 atoms)
4988+
aspirin_xyz = {
4989+
'symbols': ('O', 'C', 'O', 'C', 'C', 'C', 'C', 'C', 'C', 'H', 'H', 'H', 'H', 'C', 'O', 'O', 'C', 'H', 'H', 'H', 'H'),
4990+
'isotopes': (16, 12, 16, 12, 12, 12, 12, 12, 12, 1, 1, 1, 1, 12, 16, 16, 12, 1, 1, 1, 1),
4991+
'coords': ((-2.6373, 1.2580, -0.2078),
4992+
(-1.7770, 0.3860, -0.0632),
4993+
(-2.1558, -0.8741, 0.1691),
4994+
(-0.3592, 0.7788, -0.1627),
4995+
(0.6120, -0.2186, -0.0163),
4996+
(1.9423, 0.1873, -0.1118),
4997+
(2.2874, 1.5173, -0.3475),
4998+
(1.3090, 2.4837, -0.4907),
4999+
(-0.0249, 2.1150, -0.3989),
5000+
(2.6841, -0.5841, 0.0004),
5001+
(3.3216, 1.8105, -0.4206),
5002+
(1.5971, 3.5230, -0.6740),
5003+
(-0.7675, 2.8711, -0.5103),
5004+
(0.1843, -1.6369, 0.2443),
5005+
(0.8550, -2.5186, 0.7303),
5006+
(-1.1095, -1.8841, -0.1121),
5007+
(-1.6441, -3.2084, 0.0818),
5008+
(-2.6718, -3.1782, -0.2678),
5009+
(-1.0850, -3.8967, -0.5484),
5010+
(-1.6041, -3.4832, 1.1345),
5011+
(-3.5658, 0.9859, -0.1477))
5012+
}
5013+
5014+
# Case 1: Pure rotation (should be 0.0)
5015+
# Rotate 90 degrees about the z-axis, then 45 degrees about the y-axis
5016+
r = Rotation.from_euler('zy', [90, 45], degrees=True)
5017+
rotated_coords = r.apply(np.array(aspirin_xyz["coords"]))
5018+
aspirin_rotated_xyz = converter.xyz_from_data(coords=rotated_coords,
5019+
symbols=aspirin_xyz["symbols"],
5020+
isotopes=aspirin_xyz["isotopes"])
5021+
score = converter.kabsch(aspirin_xyz, aspirin_rotated_xyz)
5022+
self.assertAlmostEqual(score, 0.0, places=4)
5023+
5024+
# Case 2: Random structural perturbation (should not be 0.0)
5025+
# Add random noise to coordinates
5026+
rng = np.random.RandomState(42)
5027+
perturbed_coords = np.array(aspirin_xyz["coords"]) + 0.1 * rng.rand(*np.array(aspirin_xyz["coords"]).shape)
5028+
aspirin_perturbed_xyz = converter.xyz_from_data(coords=perturbed_coords,
5029+
symbols=aspirin_xyz["symbols"],
5030+
isotopes=aspirin_xyz["isotopes"])
5031+
score = converter.kabsch(aspirin_xyz, aspirin_perturbed_xyz)
5032+
self.assertGreater(score, 0.01)
5033+
49695034
@classmethod
49705035
def tearDownClass(cls):
49715036
"""

arc/species/species.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,12 @@
4646
order_atoms_in_mol_list,
4747
remove_dummies,
4848
rmg_mol_from_inchi,
49+
sort_xyz_using_indices,
4950
str_to_xyz,
5051
translate_to_center_of_mass,
5152
xyz_from_data,
5253
xyz_to_str,
54+
kabsch,
5355
)
5456
from arc.species.perceive import perceive_molecule_from_xyz, is_mol_valid
5557
from arc.species.vectors import calculate_angle, calculate_distance, calculate_dihedral_angle
@@ -2142,6 +2144,27 @@ def get_bonds(self) -> List[tuple]:
21422144
for atom2, bond12 in atom1.edges.items():
21432145
bonds.append(tuple(sorted((self.mol.atoms.index(atom1), self.mol.atoms.index(atom2)))))
21442146
return list(set(bonds))
2147+
2148+
def kabsch(self, other: 'ARCSpecies', map_: list) -> float:
2149+
"""
2150+
Calculate the Kabsch RMSD between this species and another species.
2151+
2152+
Args:
2153+
other (ARCSpecies): The other species to compare to.
2154+
map_ (list): A list of atom indices mapping atoms from this species to the other species. (i.e., if
2155+
this species has atoms [A, B, C] and the other species has atoms [C, A, B], then map_ would be [1, 2, 0]
2156+
Returns:
2157+
float: The Kabsch RMSD value.
2158+
"""
2159+
if not isinstance(other, ARCSpecies):
2160+
raise SpeciesError(f'Other must be an ARCSpecies instance, got {type(other)}.\n'
2161+
f'If you meant to use the XYZ coordinates directly, use arc.species.converter.kabsch.')
2162+
2163+
if len(map_) != self.number_of_atoms:
2164+
raise SpeciesError(f'The map_ list must have the same length as the number of atoms in {self.label} '
2165+
f'({self.number_of_atoms}), got {len(map_)}.')
2166+
2167+
return kabsch(self.get_xyz(), sort_xyz_using_indices(other.get_xyz(), map_))
21452168

21462169

21472170
class TSGuess(object):

arc/species/species_test.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2804,6 +2804,45 @@ def test_species_indexing(self):
28042804
self.assertEqual(spc_b.index, 1)
28052805
self.assertEqual(spc_c.index, 2)
28062806

2807+
def test_kabsch(self):
2808+
"""Test the kabsch() method."""
2809+
# Test with self (RMSD should be 0)
2810+
rmsd = self.spc1.kabsch(self.spc1, [0, 1, 2, 3, 4, 5])
2811+
self.assertAlmostEqual(rmsd, 0.0, delta=1e-5)
2812+
2813+
# Test with a copy (RMSD should be 0)
2814+
spc1_copy = self.spc1.copy()
2815+
rmsd = self.spc1.kabsch(spc1_copy, [0, 1, 2, 3, 4, 5])
2816+
self.assertAlmostEqual(rmsd, 0.0, delta=1e-5)
2817+
2818+
# Test with shuffled atoms
2819+
# Create a shuffled version of spc1: swap first two C atoms
2820+
# spc1: C, C, O, H, H, H
2821+
# shuffled: C(1), C(0), O(2), H(3), H(4), H(5)
2822+
# xyz of spc1
2823+
xyz = self.spc1.get_xyz()
2824+
new_coords = (xyz['coords'][1], xyz['coords'][0]) + xyz['coords'][2:]
2825+
new_symbols = (xyz['symbols'][1], xyz['symbols'][0]) + xyz['symbols'][2:]
2826+
new_isotopes = (xyz['isotopes'][1], xyz['isotopes'][0]) + xyz['isotopes'][2:]
2827+
shuffled_xyz = {'symbols': new_symbols, 'isotopes': new_isotopes, 'coords': new_coords}
2828+
spc_shuffled = ARCSpecies(label='shuffled', xyz=shuffled_xyz, smiles='C=C[O]')
2829+
2830+
# Map: we want to pull atoms from spc_shuffled to match spc1
2831+
# spc1[0] is C(0). In spc_shuffled, C(0) is at index 1.
2832+
# spc1[1] is C(1). In spc_shuffled, C(1) is at index 0.
2833+
# Rest are same.
2834+
map_indices = [1, 0, 2, 3, 4, 5]
2835+
rmsd = self.spc1.kabsch(spc_shuffled, map_indices)
2836+
self.assertAlmostEqual(rmsd, 0.0, delta=1e-5)
2837+
2838+
# Test exception
2839+
with self.assertRaises(SpeciesError):
2840+
self.spc1.kabsch("not a species", [])
2841+
2842+
# Test incorrect map_ length
2843+
with self.assertRaises(SpeciesError):
2844+
self.spc1.kabsch(self.spc1, [0, 1, 2])
2845+
28072846

28082847
class TestTSGuess(unittest.TestCase):
28092848
"""

0 commit comments

Comments
 (0)