Skip to content

Commit 49018a6

Browse files
committed
account for missing atoms for atom_indices_for_frame, and offset appropriately
1 parent 22a99d7 commit 49018a6

File tree

3 files changed

+18
-9
lines changed

3 files changed

+18
-9
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,11 @@ def pad_to_len(t, length, value = 0, dim = -1):
139139
zeros = (0, 0) * (-dim - 1)
140140
return F.pad(t, (*zeros, 0, max(0, length - t.shape[dim])), value = value)
141141

142+
def offset_only_positive(t, offset):
    """Add `offset` to the non-negative entries of `t`, leaving negatives as-is.

    Despite the name, the test is `t >= 0`, so an entry equal to 0 is shifted
    too. Only strictly negative entries pass through unchanged.
    NOTE(review): negative entries look like "missing index" sentinels (-1)
    in the callers - confirm against the call sites.

    Args:
        t: tensor of values to shift.
        offset: value added to the non-negative entries; anything that
            broadcasts against `t` under `+` (a number or a tensor).

    Returns:
        A new tensor; `t` itself is not modified.
    """
    is_non_negative = t >= 0
    shifted = t + offset

    # select elementwise so that negative entries keep their original value
    return torch.where(is_non_negative, shifted, t)
146+
142147
def compose(*fns: Callable):
143148
# for chaining from Alphafold3Input -> MoleculeInput -> AtomInput
144149

@@ -871,9 +876,7 @@ def molecule_lengthed_molecule_input_to_atom_input(mol_input: MoleculeLengthMole
871876
additional_token_feats = repeat_interleave(i.additional_token_feats, token_repeats, dim = 0)
872877
molecule_ids = repeat_interleave(i.molecule_ids, token_repeats)
873878

874-
atom_indices_offsets = exclusive_cumsum(atoms_per_molecule)
875-
distogram_atom_indices = i.distogram_atom_indices + atom_indices_offsets
876-
molecule_atom_indices = i.molecule_atom_indices + atom_indices_offsets
879+
atom_indices_offsets = repeat_interleave(exclusive_cumsum(atoms_per_molecule), token_repeats, dim = 0)
877880

878881
distogram_atom_indices = repeat_interleave(i.distogram_atom_indices, token_repeats)
879882
molecule_atom_indices = repeat_interleave(i.molecule_atom_indices, token_repeats)
@@ -1018,10 +1021,6 @@ def molecule_lengthed_molecule_input_to_atom_input(mol_input: MoleculeLengthMole
10181021
atom_indices_for_frame = [default(indices, (-1, -1, -1)) for indices in i.atom_indices_for_frame]
10191022
atom_indices_for_frame = tensor(atom_indices_for_frame)
10201023

1021-
atom_indices_for_frame = atom_indices_for_frame + atom_indices_offsets[..., None]
1022-
valid_atom_indices_for_frame = (atom_indices_for_frame >= 0).all(dim = -1)
1023-
1024-
atom_indices_for_frame = einx.where('n, n c, -> n c', valid_atom_indices_for_frame, atom_indices_for_frame, -1)
10251024
atom_indices_for_frame = repeat_interleave(atom_indices_for_frame, token_repeats, dim = 0)
10261025

10271026
# handle maybe atompair embeds
@@ -1155,8 +1154,19 @@ def molecule_lengthed_molecule_input_to_atom_input(mol_input: MoleculeLengthMole
11551154
"n missing, n -> n missing", missing_token_indices, distogram_atom_indices
11561155
).any(dim=-1)
11571156

1157+
is_missing_atom_indices_for_frame = einx.equal(
1158+
"n missing, n c -> n c missing", missing_token_indices, atom_indices_for_frame
1159+
).any(dim=(-1, -2))
1160+
11581161
molecule_atom_indices = molecule_atom_indices.masked_fill(is_missing_molecule_atom, -1)
11591162
distogram_atom_indices = distogram_atom_indices.masked_fill(is_missing_distogram_atom, -1)
1163+
atom_indices_for_frame = atom_indices_for_frame.masked_fill(is_missing_atom_indices_for_frame[..., None], -1)
1164+
1165+
# offsets for all indices
1166+
1167+
distogram_atom_indices = offset_only_positive(distogram_atom_indices, atom_indices_offsets)
1168+
molecule_atom_indices = offset_only_positive(molecule_atom_indices, atom_indices_offsets)
1169+
atom_indices_for_frame = offset_only_positive(atom_indices_for_frame, atom_indices_offsets[..., None])
11601170

11611171
# handle atom positions
11621172

alphafold3_pytorch/life.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,6 @@ def mol_from_template_mmcif_file(
553553
assert 0 <= entry["token_center_atom_idx"] < num_atoms
554554

555555
if exists(entry.get('three_atom_indices_for_frame', None)):
556-
print(num_atoms, entry, entry['three_atom_indices_for_frame'])
557556
assert all([(0 <= i < num_atoms) for i in entry["three_atom_indices_for_frame"]])
558557

559558
assert entry["first_atom_idx"] != entry["last_atom_idx"]

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.3.3"
3+
version = "0.3.5"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" },

0 commit comments

Comments
 (0)