Skip to content

Commit e9e0bab

Browse files
authored
Update inputs.py (#262)
1 parent 872d736 commit e9e0bab

File tree

1 file changed

+43
-5
lines changed

1 file changed

+43
-5
lines changed

alphafold3_pytorch/inputs.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,7 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
819819

820820
molecule_atom_indices = i.molecule_atom_indices
821821
distogram_atom_indices = i.distogram_atom_indices
822+
atom_indices_for_frame = i.atom_indices_for_frame
822823

823824
if exists(missing_token_indices) and missing_token_indices.shape[-1]:
824825
is_missing_molecule_atom = einx.equal(
@@ -827,9 +828,29 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
827828
is_missing_distogram_atom = einx.equal(
828829
"n missing, n -> n missing", missing_token_indices, distogram_atom_indices
829830
).any(dim=-1)
831+
is_missing_atom_indices_for_frame = einx.equal(
832+
"n missing, n three -> n three missing", missing_token_indices, atom_indices_for_frame
833+
).any(dim=-1)
830834

831835
molecule_atom_indices = molecule_atom_indices.masked_fill(is_missing_molecule_atom, -1)
832836
distogram_atom_indices = distogram_atom_indices.masked_fill(is_missing_distogram_atom, -1)
837+
atom_indices_for_frame = atom_indices_for_frame.masked_fill(
838+
is_missing_atom_indices_for_frame, -1
839+
)
840+
841+
# sanity-check the atom indices
842+
if not (-1 <= molecule_atom_indices.min() <= molecule_atom_indices.max() < total_atoms):
843+
raise ValueError(
844+
f"Invalid molecule atom indices found in `molecule_to_atom_input()` for {i.filepath}: {molecule_atom_indices}"
845+
)
846+
if not (-1 <= distogram_atom_indices.min() <= distogram_atom_indices.max() < total_atoms):
847+
raise ValueError(
848+
f"Invalid distogram atom indices found in `molecule_to_atom_input()` for {i.filepath}: {distogram_atom_indices}"
849+
)
850+
if not (-1 <= atom_indices_for_frame.min() <= atom_indices_for_frame.max() < total_atoms):
851+
raise ValueError(
852+
f"Invalid atom indices for frame found in `molecule_to_atom_input()` for {i.filepath}: {atom_indices_for_frame}"
853+
)
833854

834855
# handle atom positions
835856

@@ -856,9 +877,9 @@ def molecule_to_atom_input(mol_input: MoleculeInput) -> AtomInput:
856877
atompair_inputs=atompair_inputs,
857878
molecule_atom_lens=atom_lens.long(),
858879
molecule_ids=i.molecule_ids,
859-
molecule_atom_indices=i.molecule_atom_indices,
860-
distogram_atom_indices=i.distogram_atom_indices,
861-
atom_indices_for_frame=i.atom_indices_for_frame,
880+
molecule_atom_indices=molecule_atom_indices,
881+
distogram_atom_indices=distogram_atom_indices,
882+
atom_indices_for_frame=atom_indices_for_frame,
862883
is_molecule_mod=is_molecule_mod,
863884
msa=i.msa,
864885
templates=i.templates,
@@ -3025,12 +3046,14 @@ def pdb_input_to_molecule_input(
30253046

30263047
current_atom_index = 0
30273048
current_res_index = -1
3049+
current_chain_index = -1
30283050

3029-
for mol_type, atom_mask, chemid, res_index in zip(
3051+
for mol_type, atom_mask, chemid, res_index, res_chain_index in zip(
30303052
molecule_atom_types,
30313053
biomol.atom_mask,
30323054
biomol.chemid,
30333055
biomol.residue_index,
3056+
biomol.chain_index,
30343057
):
30353058
residue_constants = get_residue_constants(
30363059
mol_type.replace("protein", "peptide").replace("mod_", "")
@@ -3045,11 +3068,12 @@ def pdb_input_to_molecule_input(
30453068

30463069
if is_atomized_residue(mol_type):
30473070
# collect indices for each ligand and modified polymer residue token (i.e., atom)
3048-
if current_res_index == res_index:
3071+
if current_res_index == res_index and current_chain_index == res_chain_index:
30493072
current_atom_index += 1
30503073
else:
30513074
current_atom_index = 0
30523075
current_res_index = res_index
3076+
current_chain_index = res_chain_index
30533077

30543078
# NOTE: we have to dynamically determine the token center atom index for atomized residues
30553079
token_center_atom_index = np.where(atom_mask)[0][0]
@@ -3439,6 +3463,20 @@ def pdb_input_to_molecule_input(
34393463
)
34403464
num_atoms = atom_pos.shape[0]
34413465

3466+
# sanity-check the atom indices
3467+
if not (-1 <= distogram_atom_indices.min() <= distogram_atom_indices.max() < num_atoms):
3468+
raise ValueError(
3469+
f"Invalid distogram atom indices found in `pdb_input_to_molecule_input()` for {filepath}: {distogram_atom_indices}"
3470+
)
3471+
if not (-1 <= molecule_atom_indices.min() <= molecule_atom_indices.max() < num_atoms):
3472+
raise ValueError(
3473+
f"Invalid molecule atom indices found in `pdb_input_to_molecule_input()` for {filepath}: {molecule_atom_indices}"
3474+
)
3475+
if not (-1 <= atom_indices_for_frame.min() <= atom_indices_for_frame.max() < num_atoms):
3476+
raise ValueError(
3477+
f"Invalid atom indices for frame found in `pdb_input_to_molecule_input()` for {filepath}: {atom_indices_for_frame}"
3478+
)
3479+
34423480
# create atom_parent_ids using the `Biomolecule` object, which governs in the atom
34433481
# encoder / decoder which atom attends to which, where a design choice is made such
34443482
# that mmCIF author chain indices are directly adopted to group atoms belonging to

0 commit comments

Comments
 (0)