@@ -213,6 +213,7 @@ class MoleculeInput:
213213 token_bonds : Bool ['n n' ]
214214 molecule_atom_indices : List [int | None ] | None = None
215215 distogram_atom_indices : List [int | None ] | None = None
216+ missing_atom_indices : List [Int [' _' ] | None ] | None = None
216217 atom_parent_ids : Int [' m' ] | None = None
217218 additional_token_feats : Float [f'n dtf' ] | None = None
218219 templates : Float ['t n n dt' ] | None = None
@@ -288,6 +289,24 @@ def molecule_to_atom_input(
288289 all_num_atoms = tensor ([mol .GetNumAtoms () for mol in molecules ])
289290 offsets = exclusive_cumsum (all_num_atoms )
290291
292+ # handle maybe missing atom indices
293+
294+ missing_atom_mask = None
295+
296+ if exists (i .missing_atom_indices ) and len (i .missing_atom_indices ) > 0 :
297+
298+ missing_atom_mask = []
299+
300+ for num_atoms , mol_missing_atom_indices in zip (all_num_atoms , i .missing_atom_indices ):
301+ mol_miss_atom_mask = torch .zeros (num_atoms , dtype = torch .bool )
302+
303+ if exists (mol_missing_atom_indices ) and mol_missing_atom_indices .numel () > 0 :
304+ mol_miss_atom_mask .scatter_ (- 1 , mol_missing_atom_indices , True )
305+
306+ missing_atom_mask .append (mol_miss_atom_mask )
307+
308+ missing_atom_mask = torch .cat (missing_atom_mask )
309+
291310 # handle maybe atompair embeds
292311
293312 atompair_ids = None
@@ -420,6 +439,7 @@ def molecule_to_atom_input(
420439 molecule_ids = i .molecule_ids ,
421440 molecule_atom_indices = i .molecule_atom_indices ,
422441 distogram_atom_indices = i .distogram_atom_indices ,
442+ missing_atom_mask = missing_atom_mask ,
423443 additional_token_feats = i .additional_token_feats ,
424444 additional_molecule_feats = i .additional_molecule_feats ,
425445 is_molecule_types = i .is_molecule_types ,
@@ -448,6 +468,7 @@ class Alphafold3Input:
448468 ds_dna : List [Int [' _' ] | str ] = imm_list ()
449469 ds_rna : List [Int [' _' ] | str ] = imm_list ()
450470 atom_parent_ids : Int [' m' ] | None = None
471+ missing_atom_indices : List [List [int ] | None ] = imm_list ()
451472 additional_token_feats : Float [f'n dtf' ] | None = None
452473 templates : Float ['t n n dt' ] | None = None
453474 msa : Float ['s n dm' ] | None = None
@@ -844,10 +865,25 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
844865 src_tgt_atom_indices = tensor (src_tgt_atom_indices )
845866 src_tgt_atom_indices = pad_to_len (src_tgt_atom_indices , num_tokens , value = - 1 , dim = - 2 )
846867
847- # todo - handle atom positions for variable lengthed atoms (eventual missing atoms from mmCIF)
868+ # atom positions
848869
849870 atom_pos = i .atom_pos
850871
872+ # handle missing atom indices
873+
874+ missing_atom_indices = None
875+
876+ if exists (i .missing_atom_indices ) and len (i .missing_atom_indices ) > 0 :
877+ missing_atom_indices = []
878+
879+ for mol_miss_atom_indices in i .missing_atom_indices :
880+ mol_miss_atom_indices = default (mol_miss_atom_indices , [])
881+ mol_miss_atom_indices = tensor (mol_miss_atom_indices , dtype = torch .long )
882+
883+ missing_atom_indices .append (mol_miss_atom_indices )
884+
885+ assert len (molecules ) == len (missing_atom_indices )
886+
851887 # create molecule input
852888
853889 molecule_input = MoleculeInput (
@@ -860,6 +896,7 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
860896 additional_molecule_feats = additional_molecule_feats ,
861897 additional_token_feats = default (i .additional_token_feats , torch .zeros (num_tokens , 2 )),
862898 is_molecule_types = is_molecule_types ,
899+ missing_atom_indices = missing_atom_indices ,
863900 src_tgt_atom_indices = src_tgt_atom_indices ,
864901 atom_pos = atom_pos ,
865902 templates = i .templates ,
0 commit comments