Skip to content

Commit 54c4cf7

Browse files
committed
distance_labels auto derived from atom positions + residue atom indices, if the latter two are passed in
1 parent b63a0d9 commit 54c4cf7

File tree

3 files changed

+15
-6
lines changed

3 files changed

+15
-6
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2452,9 +2452,9 @@ def __init__(
24522452
dim_single = 384,
24532453
dim_pairwise = 128,
24542454
dim_token = 768,
2455-
atompair_dist_bins: Float[' dist_bins'] = torch.linspace(3, 20, 37),
2455+
distance_bins: Float[' dist_bins'] = torch.linspace(3, 20, 38),
24562456
ignore_index = -1,
2457-
num_dist_bins = 38,
2457+
num_dist_bins: int | None = None,
24582458
num_plddt_bins = 50,
24592459
num_pde_bins = 64,
24602460
num_pae_bins = 64,
@@ -2626,14 +2626,19 @@ def __init__(
26262626

26272627
# logit heads
26282628

2629+
self.register_buffer('distance_bins', distance_bins)
2630+
num_dist_bins = default(num_dist_bins, len(distance_bins))
2631+
2632+
assert len(distance_bins) == num_dist_bins, '`distance_bins` must have a length equal to the `num_dist_bins` passed in'
2633+
26292634
self.distogram_head = DistogramHead(
26302635
dim_pairwise = dim_pairwise,
26312636
num_dist_bins = num_dist_bins
26322637
)
26332638

26342639
self.confidence_head = ConfidenceHead(
26352640
dim_single_inputs = dim_single_inputs,
2636-
atompair_dist_bins = atompair_dist_bins,
2641+
atompair_dist_bins = distance_bins,
26372642
dim_single = dim_single,
26382643
dim_pairwise = dim_pairwise,
26392644
num_plddt_bins = num_plddt_bins,
@@ -2830,6 +2835,12 @@ def forward(
28302835

28312836
# distogram head
28322837

2838+
if not exists(distance_labels) and atom_pos_given and exists(residue_atom_indices):
2839+
residue_pos = einx.get_at('b (n [w]) c, b n -> b n c', atom_pos, residue_atom_indices)
2840+
residue_dist = torch.cdist(residue_pos, residue_pos, p = 2)
2841+
dist_from_dist_bins = einx.subtract('b m dist, dist_bins -> b m dist dist_bins', residue_dist, self.distance_bins).abs()
2842+
distance_labels = dist_from_dist_bins.argmin(dim = -1)
2843+
28332844
if exists(distance_labels):
28342845
distance_labels = torch.where(pairwise_mask, distance_labels, ignore)
28352846
distogram_logits = self.distogram_head(pairwise)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.0.23"
3+
version = "0.0.24"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,6 @@ def test_alphafold3():
384384
atom_pos = torch.randn(2, atom_seq_len, 3)
385385
residue_atom_indices = torch.randint(0, 27, (2, seq_len))
386386

387-
distance_labels = torch.randint(0, 38, (2, seq_len, seq_len))
388387
pae_labels = torch.randint(0, 64, (2, seq_len, seq_len))
389388
pde_labels = torch.randint(0, 64, (2, seq_len, seq_len))
390389
plddt_labels = torch.randint(0, 50, (2, seq_len))
@@ -427,7 +426,6 @@ def test_alphafold3():
427426
template_mask = template_mask,
428427
atom_pos = atom_pos,
429428
residue_atom_indices = residue_atom_indices,
430-
distance_labels = distance_labels,
431429
pae_labels = pae_labels,
432430
pde_labels = pde_labels,
433431
plddt_labels = plddt_labels,

0 commit comments

Comments
 (0)