Skip to content

Commit 2da0c5a

Browse files
committed
fix an issue with missing_atom_indices due to ligands being one token per atom
1 parent ddd3027 commit 2da0c5a

File tree

3 files changed

+15
-7
lines changed

3 files changed

+15
-7
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -301,11 +301,11 @@ def molecule_to_atom_input(
301301

302302
assert len(molecules) == len(i.missing_atom_indices), f'{len(i.missing_atom_indices)} missing atom indices does not match the number of molecules given ({len(molecules)})'
303303

304-
missing_atom_indices: List[Int[' _']] = [default(indices, torch.empty((0,), dtype = torch.long)) for indices in i.missing_atom_indices]
305-
304+
missing_atom_indices: List[Int[' _']] = []
306305
missing_atom_mask: List[Bool[' _']] = []
307306

308-
for num_atoms, mol_missing_atom_indices in zip(all_num_atoms, missing_atom_indices):
307+
for num_atoms, mol_missing_atom_indices, is_ligand in zip(all_num_atoms, i.missing_atom_indices, i.is_molecule_types[:, -1]):
308+
mol_missing_atom_indices = default(mol_missing_atom_indices, torch.empty((0,), dtype = torch.long))
309309

310310
mol_miss_atom_mask = torch.zeros(num_atoms, dtype = torch.bool)
311311

@@ -314,8 +314,14 @@ def molecule_to_atom_input(
314314

315315
missing_atom_mask.append(mol_miss_atom_mask)
316316

317-
missing_atom_mask = torch.cat(missing_atom_mask)
317+
if not is_ligand:
318+
missing_atom_indices.append(mol_missing_atom_indices)
319+
else:
320+
for is_missing_atom_in_ligand in mol_miss_atom_mask:
321+
index = tensor([0] if is_missing_atom_in_ligand else [], dtype = torch.long)
322+
missing_atom_indices.append(index)
318323

324+
missing_atom_mask = torch.cat(missing_atom_mask)
319325
missing_atom_indices = pad_sequence(missing_atom_indices, batch_first = True, padding_value = -1)
320326

321327
# handle maybe atompair embeds

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.133"
3+
version = "0.1.134"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_input.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,14 @@ def test_atompos_input():
8787

8888
mock_atompos = [
8989
torch.randn(5, 3), # alanine has 5 non-hydrogen atoms
90-
torch.randn(4, 3) # glycine has 4 non-hydrogen atoms
90+
torch.randn(4, 3), # glycine has 4 non-hydrogen atoms
91+
torch.randn(3, 3) # ligand has 3 carbons
9192
]
9293

9394
train_alphafold3_input = Alphafold3Input(
9495
proteins = [contrived_protein],
95-
missing_atom_indices = [[1, 2], None],
96+
missing_atom_indices = [[1, 2], None, [0, 1]],
97+
ligands = ['CCC'],
9698
atom_pos = mock_atompos
9799
)
98100

0 commit comments

Comments
 (0)