@@ -691,7 +691,7 @@ class MoleculeLengthMoleculeInput:
691691 src_tgt_atom_indices : Int ['n 2' ]
692692 token_bonds : Bool ['n n' ] | None = None
693693 one_token_per_atom : List [bool ] | None = None
694- is_molecule_mod : Bool ['n num_mods' ] | None = None
694+ is_molecule_mod : Bool ['n num_mods' ] | Bool [ 'n' ] | None = None
695695 molecule_atom_indices : List [int | None ] | None = None
696696 distogram_atom_indices : List [int | None ] | None = None
697697 missing_atom_indices : List [Int [' _' ] | None ] | None = None
@@ -724,11 +724,23 @@ def molecule_lengthed_molecule_input_to_atom_input(mol_input: MoleculeLengthMole
724724
725725 # derive `atom_lens` based on `one_token_per_atom`, for ligands and modified biomolecules
726726
727- assert len (molecules ) == len (i .one_token_per_atom )
728-
729727 atoms_per_molecule = tensor ([mol .GetNumAtoms () for mol in molecules ])
730728 ones = torch .ones_like (atoms_per_molecule )
731729
730+ # `is_molecule_mod` can either be
731+ # 1. Bool['n'], in which case it will only be used for determining `one_token_per_atom`, or
732+ # 2. Bool['n num_mods'], where it will be passed to Alphafold3 for molecule modification embeds
733+
734+ is_molecule_mod = i .is_molecule_mod
735+ is_molecule_any_mod = False
736+
737+ if exists (is_molecule_mod ):
738+ if i .is_molecule_mod .ndim == 2 :
739+ is_molecule_any_mod = is_molecule_mod .any (dim = - 1 )
740+ else :
741+ is_molecule_any_mod = is_molecule_mod
742+ is_molecule_mod = None
743+
732744 # get `one_token_per_atom`, which can be fully customizable
733745
734746 if exists (i .one_token_per_atom ):
@@ -737,7 +749,9 @@ def molecule_lengthed_molecule_input_to_atom_input(mol_input: MoleculeLengthMole
737749 # if which molecule is `one_token_per_atom` is not passed in
738750 # default to what the paper did, which is ligands and any modified biomolecule
739751 is_ligand = i .is_molecule_types [..., IS_LIGAND_INDEX ]
740- one_token_per_atom = is_ligand | is_molecule_mod .any (dim = - 1 )
752+ one_token_per_atom = is_ligand | is_molecule_any_mod
753+
754+ assert len (molecules ) == len (one_token_per_atom )
741755
742756 # derive the number of repeats needed to expand molecule lengths to token lengths
743757
@@ -782,7 +796,7 @@ def molecule_lengthed_molecule_input_to_atom_input(mol_input: MoleculeLengthMole
782796 molecule_atom_indices = repeat_interleave (i .molecule_atom_indices , token_repeats )
783797
784798 msa = maybe (repeat_interleave )(i .msa , token_repeats , dim = - 2 )
785- is_molecule_mod = maybe (repeat_interleave )(i .is_molecule_mod , token_repeats , dim = - 2 )
799+ is_molecule_mod = maybe (repeat_interleave )(i .is_molecule_mod , token_repeats , dim = 0 )
786800
787801 templates = maybe (repeat_interleave )(i .templates , token_repeats , dim = - 3 )
788802 templates = maybe (repeat_interleave )(templates , token_repeats , dim = - 2 )
@@ -1340,12 +1354,6 @@ def alphafold3_input_to_molecule_lengthed_molecule_input(alphafold3_input: Alpha
13401354 * mol_metal_ions
13411355 ]
13421356
1343- one_token_per_atom = [
1344- * ((False ,) * len (molecules_without_ligands )),
1345- * ((True ,) * len (mol_ligands )),
1346- * ((False ,) * len (mol_metal_ions )),
1347- ]
1348-
13491357 for mol in molecules :
13501358 Chem .SanitizeMol (mol )
13511359
@@ -1498,7 +1506,6 @@ def get_num_atoms_per_chain(chains: List[List[Mol]]) -> List[int]:
14981506
14991507 molecule_input = MoleculeLengthMoleculeInput (
15001508 molecules = molecules ,
1501- one_token_per_atom = one_token_per_atom ,
15021509 molecule_atom_indices = molecule_atom_indices ,
15031510 distogram_atom_indices = distogram_atom_indices ,
15041511 molecule_ids = molecule_ids ,
0 commit comments