Skip to content

Commit 9dce0a1

Browse files
committed
add the necessary functions for deriving the frames of the ligands, or other biomolecules that are not amino acids or nucleotides
1 parent b766fcb commit 9dce0a1

File tree

3 files changed

+114
-1
lines changed

3 files changed

+114
-1
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,9 @@
139139

140140
# functions
141141

142+
def l2norm(t):
143+
return F.normalize(t, dim = -1)
144+
142145
def flatten(arr):
143146
return [el for sub_arr in arr for el in sub_arr]
144147

@@ -199,6 +202,96 @@ def hard_validate_atom_indices_ascending(
199202

200203
assert (difference >= 0).all(), f'detected invalid {error_msg_field} for in a batch: {present_indices}'
201204

205+
# functions for deriving the frames for ligands
206+
# this follows the logic from Alphafold3 Supplementary section 4.3.2
207+
208+
@typecheck
209+
def get_indices_three_closest_atom_pos(
210+
atom_pos: Float['... n d'],
211+
) -> Int['... 3']:
212+
213+
prec_dims, device = atom_pos.shape[:-2], atom_pos.device
214+
num_atoms, has_batch = atom_pos.shape[-2], atom_pos.ndim == 3
215+
216+
if num_atoms < 3:
217+
return atom_pos.new_full((*prec_dims, 3), -1).long()
218+
219+
if not has_batch:
220+
atom_pos = rearrange(atom_pos, '... -> 1 ...')
221+
222+
atom_dist = torch.cdist(atom_pos, atom_pos)
223+
224+
# mask out the distance to self
225+
226+
eye = torch.eye(num_atoms, device = device, dtype = torch.bool)
227+
228+
atom_dist.masked_fill_(eye, 1e4)
229+
230+
# will use topk on the negative of the distance
231+
232+
neg_distance, two_closest_atom_indices = (-atom_dist).topk(2, dim = -1)
233+
234+
mean_neg_distance = neg_distance.mean(dim = -1)
235+
236+
best_atom_pair_index = mean_neg_distance.argmax(dim = -1)
237+
238+
best_two_atom_neighbors = einx.get_at('... [m] c, ... -> ... c', two_closest_atom_indices, best_atom_pair_index)
239+
240+
# place the chosen atom at the center
241+
242+
three_atom_indices, _ = pack((
243+
best_two_atom_neighbors[..., 0],
244+
best_atom_pair_index,
245+
best_two_atom_neighbors[..., 1],
246+
), 'b *')
247+
248+
if not has_batch:
249+
three_atom_indices = rearrange(three_atom_indices, '1 ... -> ...')
250+
251+
return three_atom_indices
252+
253+
@typecheck
254+
def get_angle_between_edges(
255+
edge1: Float['... 3'],
256+
edge2: Float['... 3']
257+
) -> Float['...']:
258+
cos = torch.dot(l2norm(edge1), l2norm(edge2))
259+
return torch.acos(cos)
260+
261+
@typecheck
262+
def get_frames_from_atom_pos(
263+
atom_pos: Float['... n d'],
264+
filter_colinear_pos: bool = False,
265+
is_colinear_angle_thres: float = 25. # they use 25 degrees as a way of filtering out invalid frames
266+
) -> Int['... 3']:
267+
268+
frames = get_indices_three_closest_atom_pos(atom_pos)
269+
270+
if not filter_colinear_pos:
271+
return frames
272+
273+
# get the edges and derive angles
274+
275+
three_atom_pos = einx.get_at('... [m] c, ... three -> ... three c', atom_pos, frames)
276+
277+
left_pos, center_pos, right_pos = three_atom_pos.unbind(dim = -2)
278+
279+
edges1, edges2 = (left_pos - center_pos), (right_pos - center_pos)
280+
281+
angle = get_angle_between_edges(edges1, edges2)
282+
283+
degree = torch.rad2deg(angle)
284+
285+
is_colinear = (
286+
(degree.abs() < is_colinear_angle_thres) |
287+
((180. - degree.abs()) < is_colinear_angle_thres)
288+
)
289+
290+
# set any three atoms that are colinear to -1 indices
291+
292+
three_atom_indices = einx.where('..., ... three, -> ... three', ~is_colinear, frames, -1)
293+
return three_atom_indices
294+
202295
# atom level, what Alphafold3 accepts
203296

204297
UNCOLLATABLE_ATOM_INPUT_FIELDS = {'filepath'}

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.4.14"
3+
version = "0.4.15"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

tests/test_af3.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@
6666
PDBDataset,
6767
default_extract_atom_feats_fn,
6868
default_extract_atompair_feats_fn,
69+
get_indices_three_closest_atom_pos,
70+
get_angle_between_edges,
71+
get_frames_from_atom_pos
6972
)
7073

7174
from alphafold3_pytorch.utils.model_utils import exclusive_cumsum
@@ -192,6 +195,23 @@ def test_rigid_from_three_points():
192195
rotation, _ = rigid_from_3_points((points, points, points))
193196
assert rotation.shape == (7, 11, 23, 3, 3)
194197

198+
def test_deriving_frames_for_ligands():
199+
points = torch.tensor([
200+
[1., 1., 1.],
201+
[-99, -99, -99],
202+
[0, 0, 0],
203+
[100, 100, 100],
204+
[-1., -1., -1.],
205+
])
206+
207+
frames = get_frames_from_atom_pos(points, filter_colinear_pos = True)
208+
209+
assert torch.allclose(frames, torch.tensor([-1, -1, -1]))
210+
211+
frames = get_frames_from_atom_pos(points, filter_colinear_pos = False)
212+
213+
assert torch.allclose(frames, torch.tensor([0, 2, 4]))
214+
195215
def test_compute_alignment_error():
196216
pred_coords = torch.randn(2, 100, 3)
197217
pred_frames = torch.randn(2, 100, 3, 3)

0 commit comments

Comments
 (0)