Skip to content

Commit dcde445

Browse files
committed
first remove the reordering of the atoms per monomer, and allow the phosphodiester and peptide atomic bonds to be incorrectly marked
1 parent a758297 commit dcde445

File tree

5 files changed

+7
-93
lines changed

5 files changed

+7
-93
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3351,7 +3351,6 @@ def forward(
33513351
molecule_atom_indices: Int['b n'] | None = None, # the 'token centre atoms' mentioned in the paper, unsure where it is used in the architecture
33523352
num_sample_steps: int | None = None,
33533353
atom_pos: Float['b m 3'] | None = None,
3354-
output_atompos_indices: Int['b m'] | None = None,
33553354
distance_labels: Int['b n n'] | None = None,
33563355
pae_labels: Int['b n n'] | None = None,
33573356
pde_labels: Int['b n n'] | None = None,
@@ -3603,26 +3602,6 @@ def forward(
36033602
if exists(atom_mask):
36043603
sampled_atom_pos = einx.where('b m, b m c, -> b m c', atom_mask, sampled_atom_pos, 0.)
36053604

3606-
if not exists(output_atompos_indices):
3607-
return sampled_atom_pos
3608-
3609-
# in the case the atoms are passed in not ordered canonically
3610-
3611-
order_mask = output_atompos_indices >= 0 # -1 is padding, which means do not order (metal ions, ligands, or entire row if None was passed in)
3612-
3613-
output_atompos_indices = einx.where(
3614-
'b m, b m, m -> b m',
3615-
order_mask,
3616-
output_atompos_indices,
3617-
torch.arange(atom_seq_len, device = self.device)
3618-
)
3619-
3620-
sampled_atom_pos = einx.get_at(
3621-
'b [m] 3, b rm -> b rm 3',
3622-
sampled_atom_pos,
3623-
output_atompos_indices
3624-
)
3625-
36263605
return sampled_atom_pos
36273606

36283607
# if being forced to return loss, but do not have sufficient information to return losses, just return 0

alphafold3_pytorch/inputs.py

Lines changed: 5 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,6 @@ class AtomInput:
135135
template_mask: Bool[' t'] | None = None
136136
msa_mask: Bool[' s'] | None = None
137137
atom_pos: Float['m 3'] | None = None
138-
output_atompos_indices: Int[' m'] | None = None
139138
molecule_atom_indices: Int[' n'] | None = None
140139
distogram_atom_indices: Int[' n'] | None = None
141140
distance_labels: Int['n n'] | None = None
@@ -166,7 +165,6 @@ class BatchedAtomInput:
166165
template_mask: Bool['b t'] | None = None
167166
msa_mask: Bool['b s'] | None = None
168167
atom_pos: Float['b m 3'] | None = None
169-
output_atompos_indices: Int['b m'] | None = None
170168
molecule_atom_indices: Int['b n'] | None = None
171169
distogram_atom_indices: Int['b n'] | None = None
172170
distance_labels: Int['b n n'] | None = None
@@ -215,7 +213,6 @@ class MoleculeInput:
215213
templates: Float['t n n dt'] | None = None
216214
msa: Float['s n dm'] | None = None
217215
atom_pos: List[Float['_ 3']] | Float['m 3'] | None = None
218-
output_atompos_indices: Int[' m'] | None = None
219216
template_mask: Bool[' t'] | None = None
220217
msa_mask: Bool[' s'] | None = None
221218
distance_labels: Int['n n'] | None = None
@@ -355,9 +352,13 @@ def molecule_to_atom_input(
355352
# and not the first biomolecule in the chain, add a single covalent bond between first atom of incoming biomolecule and the last atom of the last biomolecule
356353

357354
if is_chainable_biomolecule and not is_first_mol_in_chain:
355+
356+
358357
atompair_ids[offset, offset - 1] = 1
359358
atompair_ids[offset - 1, offset] = 1
360359

360+
last_mol = mol
361+
361362
# atom_inputs
362363

363364
atom_inputs: List[Float['m dai']] = []
@@ -444,7 +445,6 @@ class Alphafold3Input:
444445
resolved_labels: Int[' n'] | None = None
445446
add_atom_ids: bool = False
446447
add_atompair_ids: bool = False
447-
add_output_atompos_indices: bool = True
448448
directed_bonds: bool = False
449449
extract_atom_feats_fn: Callable[[Atom], Float['m dai']] = default_extract_atom_feats_fn
450450
extract_atompair_feats_fn: Callable[[Mol], Float['m m dapi']] = default_extract_atompair_feats_fn
@@ -814,58 +814,9 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
814814
molecule_atom_indices = tensor(molecule_atom_indices)
815815
molecule_atom_indices = pad_to_len(molecule_atom_indices, num_tokens, value = -1)
816816

817-
# handle atom positions
817+
# todo - handle atom positions for a variable number of atoms (eventually, atoms missing from the mmCIF)
818818

819819
atom_pos = i.atom_pos
820-
output_atompos_indices = None
821-
822-
if exists(atom_pos):
823-
if isinstance(atom_pos, list):
824-
atom_pos = torch.cat(atom_pos, dim = -2)
825-
826-
assert atom_pos.shape[-2] == total_atoms
827-
828-
# to automatically reorder the atom positions back to canonical
829-
830-
if i.add_output_atompos_indices:
831-
offset = 0
832-
833-
reorder_atompos_indices = []
834-
output_atompos_indices = []
835-
836-
for chain in chainable_biomol_entries:
837-
for idx, entry in enumerate(chain):
838-
is_last = idx == (len(chain) - 1)
839-
840-
mol = entry['rdchem_mol']
841-
num_atoms = mol.GetNumAtoms()
842-
atom_reorder_indices = entry['atom_reorder_indices']
843-
844-
if not is_last:
845-
num_atoms -= 1
846-
atom_reorder_indices = atom_reorder_indices[:-1]
847-
848-
reorder_back_indices = atom_reorder_indices.argsort()
849-
850-
output_atompos_indices.append(reorder_back_indices + offset)
851-
reorder_atompos_indices.append(atom_reorder_indices + offset)
852-
853-
offset += num_atoms
854-
855-
output_atompos_indices = torch.cat(output_atompos_indices, dim = -1)
856-
output_atompos_indices = pad_to_length(output_atompos_indices, total_atoms, value = -1)
857-
858-
# if atom positions are passed in, need to be reordered for the bonds between residues / nucleotides to be contiguous
859-
# todo - fix to have no reordering needed (bonds are built not contiguous, just hydroxyl removed)
860-
861-
if i.reorder_atom_pos:
862-
orig_order = torch.arange(total_atoms)
863-
reorder_atompos_indices = torch.cat(reorder_atompos_indices, dim = -1)
864-
reorder_atompos_indices = pad_to_length(reorder_atompos_indices, total_atoms, value = -1)
865-
866-
reorder_indices = torch.where(reorder_atompos_indices != -1, reorder_atompos_indices, orig_order)
867-
868-
atom_pos = atom_pos[reorder_indices]
869820

870821
# create molecule input
871822

@@ -880,7 +831,6 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
880831
additional_token_feats = default(i.additional_token_feats, torch.zeros(num_tokens, 2)),
881832
is_molecule_types = is_molecule_types,
882833
atom_pos = atom_pos,
883-
output_atompos_indices = output_atompos_indices,
884834
templates = i.templates,
885835
msa = i.msa,
886836
template_mask = i.template_mask,

alphafold3_pytorch/trainer.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -176,14 +176,6 @@ def collate_inputs_to_batched_atom_input(
176176

177177
batched_atom_input_dict = dict(tuple(zip(keys, outputs)))
178178

179-
# just ensure output_atompos_indices has full atom_seq_len manually for now
180-
181-
output_atompos_indices = batched_atom_input_dict.get('output_atompos_indices', None)
182-
183-
if exists(output_atompos_indices):
184-
atom_seq_len = batched_atom_input_dict['atom_inputs'].shape[-2]
185-
batched_atom_input_dict.update(output_atompos_indices = pad_or_slice_to(output_atompos_indices, atom_seq_len, dim = -1, pad_value = -1))
186-
187179
# reconstitute dictionary
188180

189181
batched_atom_inputs = BatchedAtomInput(**batched_atom_input_dict)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.120"
3+
version = "0.1.121"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_trainer.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,6 @@ def __getitem__(self, idx):
6363
if random() > 0.5:
6464
msa_mask = torch.ones((7,)).bool()
6565

66-
# randomly reorder output atompos indices for testing
67-
68-
output_atompos_indices = None
69-
if random() > 0.5:
70-
output_atompos_indices = torch.arange(atom_seq_len).long()
71-
7266
# required for training, but omitted on inference
7367

7468
atom_pos = torch.randn(atom_seq_len, 3)
@@ -94,7 +88,6 @@ def __getitem__(self, idx):
9488
msa = msa,
9589
msa_mask = msa_mask,
9690
atom_pos = atom_pos,
97-
output_atompos_indices = output_atompos_indices,
9891
molecule_atom_indices = molecule_atom_indices,
9992
distance_labels = distance_labels,
10093
pae_labels = pae_labels,
@@ -182,7 +175,7 @@ def test_trainer():
182175

183176
# assert can load latest checkpoint by loading from a directory
184177

185-
trainer.load('./checkpoints')
178+
trainer.load('./checkpoints', strict = False)
186179

187180
assert exists(trainer.model_loaded_from_path)
188181

0 commit comments

Comments
 (0)