Skip to content

Commit 8933bc8

Browse files
committed
address #81, separating out the atom indices for distogram loss vs the "token centre", pointed out by Alex Morehead
1 parent 5b944dd commit 8933bc8

File tree

5 files changed

+90
-44
lines changed

5 files changed

+90
-44
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3347,7 +3347,8 @@ def forward(
33473347
num_recycling_steps: int = 1,
33483348
diffusion_add_bond_loss: bool = False,
33493349
diffusion_add_smooth_lddt_loss: bool = False,
3350-
molecule_atom_indices: Int['b n'] | None = None,
3350+
distogram_atom_indices: Int['b n'] | None = None,
3351+
molecule_atom_indices: Int['b n'] | None = None, # the 'token centre atoms' mentioned in the paper, unsure where it is used in the architecture
33513352
num_sample_steps: int | None = None,
33523353
atom_pos: Float['b m 3'] | None = None,
33533354
output_atompos_indices: Int['b m'] | None = None,
@@ -3379,6 +3380,10 @@ def forward(
33793380
molecule_atom_indices = molecule_atom_indices.masked_fill(~valid_atom_len_mask, 0)
33803381
assert (molecule_atom_indices < molecule_atom_lens)[valid_atom_len_mask].all(), 'molecule_atom_indices cannot have an index that exceeds the length of the atoms for that molecule as given by molecule_atom_lens'
33813382

3383+
if exists(distogram_atom_indices):
3384+
distogram_atom_indices = distogram_atom_indices.masked_fill(~valid_atom_len_mask, 0)
3385+
assert (distogram_atom_indices < molecule_atom_lens)[valid_atom_len_mask].all(), 'distogram_atom_indices cannot have an index that exceeds the length of the atoms for that molecule as given by molecule_atom_lens'
3386+
33823387
assert exists(molecule_atom_lens) or exists(atom_mask)
33833388

33843389
# if atompair inputs are not windowed, window it
@@ -3640,9 +3645,8 @@ def forward(
36403645

36413646
# distogram head
36423647

3643-
if not exists(distance_labels) and atom_pos_given and exists(molecule_atom_indices):
3644-
3645-
molecule_pos = einx.get_at('b [m] c, b n -> b n c', atom_pos, molecule_atom_indices)
3648+
if not exists(distance_labels) and atom_pos_given and exists(distogram_atom_indices):
3649+
molecule_pos = einx.get_at('b [m] c, b n -> b n c', atom_pos, distogram_atom_indices)
36463650
molecule_dist = torch.cdist(molecule_pos, molecule_pos, p = 2)
36473651
dist_from_dist_bins = einx.subtract('b m dist, dist_bins -> b m dist dist_bins', molecule_dist, self.distance_bins).abs()
36483652
distance_labels = dist_from_dist_bins.argmin(dim = -1)
@@ -3678,6 +3682,7 @@ def forward(
36783682
relative_position_encoding,
36793683
additional_molecule_feats,
36803684
is_molecule_types,
3685+
distogram_atom_indices,
36813686
molecule_atom_indices,
36823687
molecule_atom_lens,
36833688
pae_labels,
@@ -3701,6 +3706,7 @@ def forward(
37013706
relative_position_encoding,
37023707
additional_molecule_feats,
37033708
is_molecule_types,
3709+
distogram_atom_indices,
37043710
molecule_atom_indices,
37053711
molecule_atom_lens,
37063712
pae_labels,
@@ -3756,7 +3762,7 @@ def forward(
37563762
should_call_confidence_head = any([*map(exists, confidence_head_labels)])
37573763
return_pae_logits = exists(pae_labels)
37583764

3759-
if calc_diffusion_loss and should_call_confidence_head:
3765+
if calc_diffusion_loss and should_call_confidence_head and exists(molecule_atom_indices):
37603766

37613767
# rollout
37623768

alphafold3_pytorch/inputs.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ class AtomInput:
137137
atom_pos: Float['m 3'] | None = None
138138
output_atompos_indices: Int[' m'] | None = None
139139
molecule_atom_indices: Int[' n'] | None = None
140+
distogram_atom_indices: Int[' n'] | None = None
140141
distance_labels: Int['n n'] | None = None
141142
pae_labels: Int['n n'] | None = None
142143
pde_labels: Int['n n'] | None = None
@@ -167,6 +168,7 @@ class BatchedAtomInput:
167168
atom_pos: Float['b m 3'] | None = None
168169
output_atompos_indices: Int['b m'] | None = None
169170
molecule_atom_indices: Int['b n'] | None = None
171+
distogram_atom_indices: Int['b n'] | None = None
170172
distance_labels: Int['b n n'] | None = None
171173
pae_labels: Int['b n n'] | None = None
172174
pde_labels: Int['b n n'] | None = None
@@ -202,11 +204,12 @@ def default_extract_atompair_feats_fn(mol: Mol):
202204
class MoleculeInput:
203205
molecules: List[Mol]
204206
molecule_token_pool_lens: List[int]
205-
molecule_atom_indices: List[int | None]
206207
molecule_ids: Int[' n']
207208
additional_molecule_feats: Int[f'n {ADDITIONAL_MOLECULE_FEATS}']
208209
is_molecule_types: Bool[f'n {IS_MOLECULE_TYPES}']
209210
token_bonds: Bool['n n']
211+
molecule_atom_indices: List[int | None] | None = None
212+
distogram_atom_indices: List[int | None] | None = None
210213
atom_parent_ids: Int[' m'] | None = None
211214
additional_token_feats: Float[f'n dtf'] | None = None
212215
templates: Float['t n n dt'] | None = None
@@ -398,6 +401,8 @@ def molecule_to_atom_input(
398401
atompair_inputs = atompair_inputs,
399402
molecule_atom_lens = tensor(atom_lens, dtype = torch.long),
400403
molecule_ids = i.molecule_ids,
404+
molecule_atom_indices = i.molecule_atom_indices,
405+
distogram_atom_indices = i.distogram_atom_indices,
401406
additional_token_feats = i.additional_token_feats,
402407
additional_molecule_feats = i.additional_molecule_feats,
403408
is_molecule_types = i.is_molecule_types,
@@ -542,12 +547,15 @@ def alphafold3_input_to_molecule_input(
542547
proteins = i.proteins
543548
mol_proteins = []
544549
protein_entries = []
550+
551+
distogram_atom_indices = []
545552
molecule_atom_indices = []
546553

547554
for protein in proteins:
548555
mol_peptides, protein_entries = map_int_or_string_indices_to_mol(HUMAN_AMINO_ACIDS, protein, chain = True, return_entries = True)
549556
mol_proteins.append(mol_peptides)
550557

558+
distogram_atom_indices.extend([entry['token_center_atom_idx'] for entry in protein_entries])
551559
molecule_atom_indices.extend([entry['distogram_atom_idx'] for entry in protein_entries])
552560

553561
protein_ids = maybe_string_to_int(HUMAN_AMINO_ACIDS, protein) + protein_offset
@@ -810,7 +818,10 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
810818
sym_ids
811819
), dim = -1)
812820

813-
# molecule atom indices
821+
# distogram and token centre atom indices
822+
823+
distogram_atom_indices = tensor(distogram_atom_indices)
824+
distogram_atom_indices = pad_to_len(distogram_atom_indices, num_tokens, value = -1)
814825

815826
molecule_atom_indices = tensor(molecule_atom_indices)
816827
molecule_atom_indices = pad_to_len(molecule_atom_indices, num_tokens, value = -1)
@@ -874,6 +885,7 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
874885
molecules = molecules,
875886
molecule_token_pool_lens = token_pool_lens,
876887
molecule_atom_indices = molecule_atom_indices,
888+
distogram_atom_indices = distogram_atom_indices,
877889
molecule_ids = molecule_ids,
878890
token_bonds = token_bonds,
879891
additional_molecule_feats = additional_molecule_feats,

0 commit comments

Comments
 (0)