Skip to content

Commit 2329379

Browse files
authored
Update inputs.py (#146)
1 parent b353bac commit 2329379

File tree

1 file changed

+16
-19
lines changed

1 file changed

+16
-19
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -635,7 +635,7 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
635635
molecule_atom_indices = i.molecule_atom_indices
636636
distogram_atom_indices = i.distogram_atom_indices
637637

638-
if exists(missing_token_indices):
638+
if exists(missing_token_indices) and missing_token_indices.shape[-1]:
639639
is_missing_molecule_atom = einx.equal(
640640
"n missing, n -> n missing", missing_token_indices, molecule_atom_indices
641641
).any(dim=-1)
@@ -1064,7 +1064,7 @@ def molecule_lengthed_molecule_input_to_atom_input(mol_input: MoleculeLengthMole
10641064

10651065
# mask out molecule atom indices and distogram atom indices where it is in the missing atom indices list
10661066

1067-
if exists(missing_token_indices):
1067+
if exists(missing_token_indices) and missing_token_indices.shape[-1]:
10681068
missing_token_indices = repeat_interleave(missing_token_indices, token_repeats, dim = 0)
10691069

10701070
is_missing_molecule_atom = einx.equal(
@@ -2500,26 +2500,23 @@ def pdb_input_to_molecule_input(
25002500
for mol in molecules
25012501
]
25022502

2503-
if any(molecules_missing_atom_indices):
2504-
missing_atom_indices = []
2505-
missing_token_indices = []
2503+
missing_atom_indices = []
2504+
missing_token_indices = []
25062505

2507-
for mol_miss_atom_indices, mol, mol_type in zip(
2508-
molecules_missing_atom_indices, molecules, molecule_types
2509-
):
2510-
mol_miss_atom_indices = default(mol_miss_atom_indices, [])
2511-
mol_miss_atom_indices = tensor(mol_miss_atom_indices, dtype=torch.long)
2506+
for mol_miss_atom_indices, mol, mol_type in zip(
2507+
molecules_missing_atom_indices, molecules, molecule_types
2508+
):
2509+
mol_miss_atom_indices = default(mol_miss_atom_indices, [])
2510+
mol_miss_atom_indices = tensor(mol_miss_atom_indices, dtype=torch.long)
25122511

2513-
missing_atom_indices.append(mol_miss_atom_indices)
2514-
if is_atomized_residue(mol_type):
2515-
missing_token_indices.extend(
2516-
[mol_miss_atom_indices for _ in range(mol.GetNumAtoms())]
2517-
)
2518-
else:
2519-
missing_token_indices.append(mol_miss_atom_indices)
2512+
missing_atom_indices.append(mol_miss_atom_indices)
2513+
if is_atomized_residue(mol_type):
2514+
missing_token_indices.extend([mol_miss_atom_indices for _ in range(mol.GetNumAtoms())])
2515+
else:
2516+
missing_token_indices.append(mol_miss_atom_indices)
25202517

2521-
assert len(molecules) == len(missing_atom_indices)
2522-
assert len(missing_token_indices) == num_tokens
2518+
assert len(molecules) == len(missing_atom_indices)
2519+
assert len(missing_token_indices) == num_tokens
25232520

25242521
# TODO: install additional token features once MSAs are available
25252522
# 0: f_profile

0 commit comments

Comments
 (0)