|
9 | 9 | import os |
10 | 10 |
|
11 | 11 | import numpy as np |
| 12 | +from scipy.spatial.transform import Rotation |
12 | 13 | import unittest |
13 | 14 |
|
14 | 15 | from ase import Atoms |
@@ -4966,6 +4967,70 @@ def test_sorted_distances_of_atom(self): |
4966 | 4967 | with self.assertRaises(IndexError): |
4967 | 4968 | converter.sorted_distances_of_atom(xyz_dict, 5) |
4968 | 4969 |
|
| 4970 | + def test_kabsch(self): |
| 4971 | + """Test the kabsch function""" |
| 4972 | + xyz1 = {'symbols': ('O', 'H', 'H'), 'isotopes': (16, 1, 1), |
| 4973 | + 'coords': ((0.0, 0.0, 0.0), |
| 4974 | + (0.0, 0.757, 0.586), |
| 4975 | + (0.0, -0.757, 0.586))} |
| 4976 | + xyz2 = converter.translate_xyz(xyz1, (10.0, 0.0, 0.0)) |
| 4977 | + score = converter.kabsch(xyz1, xyz2) |
| 4978 | + self.assertAlmostEqual(score, 0.0, places=5) |
| 4979 | + |
| 4980 | + r = Rotation.from_quat([0, 0, np.sin(np.pi/4), np.cos(np.pi/4)]) |
| 4981 | + self.assertAlmostEqual(converter.kabsch(xyz1, converter.xyz_from_data(coords=r.apply(np.array(xyz1["coords"])), symbols=xyz1["symbols"], isotopes=xyz1["isotopes"])), 0.0, places=5) |
| 4982 | + xyz2 = {i:v for i, v in xyz1.items()} |
| 4983 | + xyz2['symbols'] = ('O', 'H', 'H', 'H') |
| 4984 | + with self.assertRaises(ValueError): |
| 4985 | + converter.kabsch(xyz1, xyz2) |
| 4986 | + |
| 4987 | + # Wildtype test: Aspirin (21 atoms) |
| 4988 | + aspirin_xyz = { |
| 4989 | + 'symbols': ('O', 'C', 'O', 'C', 'C', 'C', 'C', 'C', 'C', 'H', 'H', 'H', 'H', 'C', 'O', 'O', 'C', 'H', 'H', 'H', 'H'), |
| 4990 | + 'isotopes': (16, 12, 16, 12, 12, 12, 12, 12, 12, 1, 1, 1, 1, 12, 16, 16, 12, 1, 1, 1, 1), |
| 4991 | + 'coords': ((-2.6373, 1.2580, -0.2078), |
| 4992 | + (-1.7770, 0.3860, -0.0632), |
| 4993 | + (-2.1558, -0.8741, 0.1691), |
| 4994 | + (-0.3592, 0.7788, -0.1627), |
| 4995 | + (0.6120, -0.2186, -0.0163), |
| 4996 | + (1.9423, 0.1873, -0.1118), |
| 4997 | + (2.2874, 1.5173, -0.3475), |
| 4998 | + (1.3090, 2.4837, -0.4907), |
| 4999 | + (-0.0249, 2.1150, -0.3989), |
| 5000 | + (2.6841, -0.5841, 0.0004), |
| 5001 | + (3.3216, 1.8105, -0.4206), |
| 5002 | + (1.5971, 3.5230, -0.6740), |
| 5003 | + (-0.7675, 2.8711, -0.5103), |
| 5004 | + (0.1843, -1.6369, 0.2443), |
| 5005 | + (0.8550, -2.5186, 0.7303), |
| 5006 | + (-1.1095, -1.8841, -0.1121), |
| 5007 | + (-1.6441, -3.2084, 0.0818), |
| 5008 | + (-2.6718, -3.1782, -0.2678), |
| 5009 | + (-1.0850, -3.8967, -0.5484), |
| 5010 | + (-1.6041, -3.4832, 1.1345), |
| 5011 | + (-3.5658, 0.9859, -0.1477)) |
| 5012 | + } |
| 5013 | + |
| 5014 | + # Case 1: Pure rotation (should be 0.0) |
| 5015 | + # Rotate 90 degrees about the z-axis, then 45 degrees about the y-axis |
| 5016 | + r = Rotation.from_euler('zy', [90, 45], degrees=True) |
| 5017 | + rotated_coords = r.apply(np.array(aspirin_xyz["coords"])) |
| 5018 | + aspirin_rotated_xyz = converter.xyz_from_data(coords=rotated_coords, |
| 5019 | + symbols=aspirin_xyz["symbols"], |
| 5020 | + isotopes=aspirin_xyz["isotopes"]) |
| 5021 | + score = converter.kabsch(aspirin_xyz, aspirin_rotated_xyz) |
| 5022 | + self.assertAlmostEqual(score, 0.0, places=4) |
| 5023 | + |
| 5024 | + # Case 2: Random structural perturbation (should not be 0.0) |
| 5025 | + # Add random noise to coordinates |
| 5026 | + rng = np.random.RandomState(42) |
| 5027 | + perturbed_coords = np.array(aspirin_xyz["coords"]) + 0.1 * rng.rand(*np.array(aspirin_xyz["coords"]).shape) |
| 5028 | + aspirin_perturbed_xyz = converter.xyz_from_data(coords=perturbed_coords, |
| 5029 | + symbols=aspirin_xyz["symbols"], |
| 5030 | + isotopes=aspirin_xyz["isotopes"]) |
| 5031 | + score = converter.kabsch(aspirin_xyz, aspirin_perturbed_xyz) |
| 5032 | + self.assertGreater(score, 0.01) |
| 5033 | + |
4969 | 5034 | @classmethod |
4970 | 5035 | def tearDownClass(cls): |
4971 | 5036 | """ |
|
0 commit comments