Skip to content

Commit 57b1322

Browse files
committed
make sure it can work with batched atom positions with masking
1 parent 1f58414 commit 57b1322

File tree

3 files changed

+37
-5
lines changed

3 files changed

+37
-5
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -208,24 +208,42 @@ def hard_validate_atom_indices_ascending(
208208
@typecheck
209209
def get_indices_three_closest_atom_pos(
210210
atom_pos: Float['... n d'],
211+
mask: Bool['... n'] | None = None
211212
) -> Int['... 3']:
212213

213214
prec_dims, device = atom_pos.shape[:-2], atom_pos.device
214215
num_atoms, has_batch = atom_pos.shape[-2], atom_pos.ndim == 3
215216

216-
if num_atoms < 3:
217+
if not exists(mask) and num_atoms < 3:
217218
return atom_pos.new_full((*prec_dims, 3), -1).long()
218219

219220
if not has_batch:
220221
atom_pos = rearrange(atom_pos, '... -> 1 ...')
221222

223+
if exists(mask):
224+
mask = rearrange(mask, '... -> 1 ...')
225+
226+
# figure out which set of atoms are less than 3 for masking out later
227+
228+
if exists(mask):
229+
insufficient_atom_mask = mask.sum(dim = -1) < 3
230+
231+
# get distances between all atoms
232+
222233
atom_dist = torch.cdist(atom_pos, atom_pos)
223234

224235
# mask out the distance to self
225236

226237
eye = torch.eye(num_atoms, device = device, dtype = torch.bool)
227238

228-
atom_dist.masked_fill_(eye, 1e4)
239+
mask_value = 1e4
240+
atom_dist.masked_fill_(eye, mask_value)
241+
242+
# take care of padding
243+
244+
if exists(mask):
245+
pair_mask = einx.logical_and('... i, ... j -> ... i j', mask, mask)
246+
atom_dist.masked_fill_(~pair_mask, mask_value)
229247

230248
# will use topk on the negative of the distance
231249

@@ -245,6 +263,11 @@ def get_indices_three_closest_atom_pos(
245263
best_two_atom_neighbors[..., 1],
246264
), 'b *')
247265

266+
# mask out
267+
268+
if exists(mask):
269+
three_atom_indices = einx.where('..., ... three, -> ... three', ~insufficient_atom_mask, three_atom_indices, -1)
270+
248271
if not has_batch:
249272
three_atom_indices = rearrange(three_atom_indices, '1 ... -> ...')
250273

@@ -261,11 +284,12 @@ def get_angle_between_edges(
261284
@typecheck
262285
def get_frames_from_atom_pos(
263286
atom_pos: Float['... n d'],
287+
mask: Bool['... n'] | None = None,
264288
filter_colinear_pos: bool = False,
265289
is_colinear_angle_thres: float = 25. # they use 25 degrees as a way of filtering out invalid frames
266290
) -> Int['... 3']:
267291

268-
frames = get_indices_three_closest_atom_pos(atom_pos)
292+
frames = get_indices_three_closest_atom_pos(atom_pos, mask = mask)
269293

270294
if not filter_colinear_pos:
271295
return frames

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.4.16"
3+
version = "0.4.17"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

tests/test_af3.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,12 +206,20 @@ def test_deriving_frames_for_ligands():
206206

207207
frames = get_frames_from_atom_pos(points, filter_colinear_pos = True)
208208

209-
assert torch.allclose(frames, torch.tensor([-1, -1, -1]))
209+
assert (frames == -1).all()
210210

211211
frames = get_frames_from_atom_pos(points, filter_colinear_pos = False)
212212

213213
assert torch.allclose(frames, torch.tensor([0, 2, 4]))
214214

215+
# test with mask
216+
217+
mask = torch.tensor([True, True, False, False, False])
218+
219+
frames = get_frames_from_atom_pos(points, mask = mask)
220+
221+
assert (frames == -1).all()
222+
215223
def test_compute_alignment_error():
216224
pred_coords = torch.randn(2, 100, 3)
217225
pred_frames = torch.randn(2, 100, 3, 3)

0 commit comments

Comments
 (0)