@@ -158,6 +158,20 @@ def offset_only_positive(t, offset):
158158 t_offsetted = t + offset
159159 return torch .where (is_positive , t_offsetted , t )
160160
161+ @typecheck
162+ def remove_consecutive_duplicate (
163+ t : Int ['n ...' ],
164+ remove_to_value = - 1
165+ ) -> Int ['n ...' ]:
166+
167+ is_duplicate = t [1 :] == t [:- 1 ]
168+
169+ if is_duplicate .ndim == 2 :
170+ is_duplicate = is_duplicate .all (dim = - 1 )
171+
172+ is_duplicate = F .pad (is_duplicate , (1 , 0 ), value = False )
173+ return einx .where ('n, n ..., -> n ... ' , ~ is_duplicate , t , remove_to_value )
174+
161175def compose (* fns : Callable ):
162176 # for chaining from Alphafold3Input -> MoleculeInput -> AtomInput
163177
@@ -1322,6 +1336,10 @@ def molecule_lengthed_molecule_input_to_atom_input(mol_input: MoleculeLengthMole
13221336 molecule_atom_indices = offset_only_positive (molecule_atom_indices , atom_indices_offsets )
13231337 atom_indices_for_frame = offset_only_positive (atom_indices_for_frame , atom_indices_offsets [..., None ])
13241338
1339+ # just use a hack to remove any duplicated indices (ligands and modified biomolecules) in a row
1340+
1341+ atom_indices_for_frame = remove_consecutive_duplicate (atom_indices_for_frame )
1342+
13251343 # handle atom positions
13261344
13271345 atom_pos = i .atom_pos
@@ -1451,6 +1469,13 @@ def alphafold3_input_to_molecule_lengthed_molecule_input(alphafold3_input: Alpha
14511469 ss_rnas = list (i .ss_rna )
14521470 ss_dnas = list (i .ss_dna )
14531471
1472+ # handle atom positions - need atom positions for deriving frame of ligand for PAE
1473+
1474+ atom_pos = i .atom_pos
1475+
1476+ if isinstance (atom_pos , list ):
1477+ atom_pos = torch .cat (atom_pos )
1478+
14541479 # any double stranded nucleic acids is added to single stranded lists with its reverse complement
14551480 # rc stands for reverse complement
14561481
@@ -1568,12 +1593,31 @@ def alphafold3_input_to_molecule_lengthed_molecule_input(alphafold3_input: Alpha
15681593 # convert ligands to rdchem.Mol
15691594
15701595 ligands = list (alphafold3_input .ligands )
1596+
15711597 mol_ligands = [
15721598 (mol_from_smile (ligand ) if isinstance (ligand , str ) else ligand ) for ligand in ligands
15731599 ]
15741600
15751601 molecule_ids .append (tensor ([ligand_id ] * len (mol_ligands )))
1576-
1602+
1603+ # handle frames for the ligands, which depends on knowing the atom positions (section 4.3.2)
1604+
1605+ if exists (atom_pos ):
1606+ ligand_atom_pos_offset = 0
1607+
1608+ for mol in flatten ([* mol_proteins , * mol_ss_rnas , * mol_ss_dnas ]):
1609+ ligand_atom_pos_offset += mol .GetNumAtoms ()
1610+
1611+ for mol_ligand in mol_ligands :
1612+ num_ligand_atoms = mol_ligand .GetNumAtoms ()
1613+ ligand_atom_pos = atom_pos [ligand_atom_pos_offset :(ligand_atom_pos_offset + num_ligand_atoms )]
1614+
1615+ frames = get_frames_from_atom_pos (ligand_atom_pos , filter_colinear_pos = True )
1616+
1617+ atom_indices_for_frame .append (frames .tolist ())
1618+
1619+ ligand_atom_pos_offset += num_ligand_atoms
1620+
15771621 # convert metal ions to rdchem.Mol
15781622
15791623 metal_ions = alphafold3_input .metal_ions
@@ -1760,10 +1804,6 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
17601804 molecule_atom_indices = tensor (molecule_atom_indices )
17611805 molecule_atom_indices = pad_to_len (molecule_atom_indices , num_tokens , value = - 1 )
17621806
1763- # atom positions
1764-
1765- atom_pos = i .atom_pos
1766-
17671807 # handle missing atom indices
17681808
17691809 missing_atom_indices = None
0 commit comments