@@ -2264,7 +2264,7 @@ def find_mismatched_symmetry(
22642264def load_msa_from_msa_dir (
22652265 msa_dir : str | None ,
22662266 file_id : str ,
2267- chain_id_to_chem_types : Dict [str , List [int ]],
2267+ chain_id_to_residue : Dict [str , Dict [ str , List [int ] ]],
22682268 max_msas_per_chain : int | None = None ,
22692269 randomly_truncate : bool = True ,
22702270 raise_missing_exception : bool = False ,
@@ -2279,16 +2279,16 @@ def load_msa_from_msa_dir(
22792279 return {}
22802280
22812281 msas = {}
2282- for chain_id in chain_id_to_chem_types :
2282+ for chain_id in chain_id_to_residue :
22832283 msa_fpaths = glob .glob (os .path .join (msa_dir , f"{ file_id } { chain_id } _*.a3m" ))
22842284
22852285 if not msa_fpaths :
22862286 msas [chain_id ] = None
22872287 continue
22882288
22892289 # NOTE: A single chain-specific MSA file contains alignments for all polymer residues in the chain,
2290- # but the ligand (and some "unmappable" modified polymer residues) are not included in the MSA file
2291- # and therefore must be manually inserted into the MSAs as unknown amino acid residues.
2290+ # but the chain's ligands are not included in the MSA file and therefore must be manually inserted
2291+ # into the MSAs as unknown amino acid residues.
22922292 assert len (msa_fpaths ) == 1 , (
22932293 f"{ len (msa_fpaths )} MSA files found for chain { chain_id } of file { file_id } . "
22942294 "Please ensure that one MSA file is present for each chain."
@@ -2310,7 +2310,7 @@ def load_msa_from_msa_dir(
23102310 )
23112311 msas [chain_id ] = msa
23122312
2313- features = make_msa_features (msas , chain_id_to_chem_types )
2313+ features = make_msa_features (msas , chain_id_to_residue )
23142314 features = make_msa_mask (features )
23152315
23162316 return features
@@ -2606,10 +2606,12 @@ def pdb_input_to_molecule_input(
26062606
26072607 # concat for all of additional_molecule_feats
26082608
2609+ # NOTE: `Biomolecule.residue_index` is 1-based originally
2610+ residue_index = torch .from_numpy (biomol .residue_index ) - 1
2611+
26092612 additional_molecule_feats = torch .stack (
26102613 (
2611- # NOTE: `Biomolecule.residue_index` is 1-based originally
2612- torch .from_numpy (biomol .residue_index ) - 1 ,
2614+ residue_index ,
26132615 torch .arange (num_tokens ),
26142616 torch .from_numpy (biomol .chain_index ),
26152617 entity_ids ,
@@ -2790,15 +2792,63 @@ def pdb_input_to_molecule_input(
27902792 num_present_atoms = mol_total_atoms - num_missing_atom_indices
27912793 assert num_present_atoms == int (biomol .atom_mask .sum ())
27922794
2795+ # handle `atom_indices_for_frame` for the PAE
2796+
2797+ atom_indices_for_frame = tensor (
2798+ [default (indices , (- 1 , - 1 , - 1 )) for indices in atom_indices_for_frame ]
2799+ )
2800+
2801+ # build offsets for all indices
2802+
2803+ # derive `atom_lens` based on `one_token_per_atom`, for ligands and modified biomolecules
2804+ atoms_per_molecule = tensor ([mol .GetNumAtoms () for mol in molecules ])
2805+ ones = torch .ones_like (atoms_per_molecule )
2806+
2807+ # `is_molecule_mod` can either be
2808+ # 1. Bool['n'], in which case it will only be used for determining `one_token_per_atom`, or
2809+ # 2. Bool['n num_mods'], where it will be passed to Alphafold3 for molecule modification embeds
2810+ is_molecule_mod = tensor (is_molecule_mod )
2811+ is_molecule_any_mod = False
2812+
2813+ if is_molecule_mod .ndim == 2 :
2814+ is_molecule_any_mod = is_molecule_mod [unique_chain_residue_indices ].any (dim = - 1 )
2815+ else :
2816+ is_molecule_any_mod = is_molecule_mod [unique_chain_residue_indices ]
2817+
2818+ # get `one_token_per_atom`
2819+ # default to what the paper did, which is ligands and any modified biomolecule
2820+ is_ligand = is_molecule_types [unique_chain_residue_indices ][..., IS_LIGAND_INDEX ]
2821+ one_token_per_atom = is_ligand | is_molecule_any_mod
2822+
2823+ assert len (molecules ) == len (one_token_per_atom )
2824+
2825+ # derive the number of repeats needed to expand molecule lengths to token lengths
2826+ token_repeats = torch .where (one_token_per_atom , atoms_per_molecule , ones )
2827+
2828+ # craft offsets for all atom indices
2829+ atom_indices_offsets = repeat_interleave (
2830+ exclusive_cumsum (atoms_per_molecule ), token_repeats , dim = 0
2831+ )
2832+
2833+ # offset only positive atom indices
2834+ distogram_atom_indices = offset_only_positive (distogram_atom_indices , atom_indices_offsets )
2835+ molecule_atom_indices = offset_only_positive (molecule_atom_indices , atom_indices_offsets )
2836+ atom_indices_for_frame = offset_only_positive (
2837+ atom_indices_for_frame , atom_indices_offsets [..., None ]
2838+ )
2839+
27932840 # retrieve multiple sequence alignments (MSAs) for each chain
27942841 # NOTE: if they are not locally available, `Nones` will be used
27952842 msa_chain_ids = list (dict .fromkeys (biomol .chain_id .tolist ()))
2796- chain_id_to_chem_types = {
2797- chain_id : biomol .chemtype [biomol .chain_id == chain_id ].tolist ()
2843+ chain_id_to_residue = {
2844+ chain_id : {
2845+ "chemtype" : biomol .chemtype [biomol .chain_id == chain_id ].tolist (),
2846+ "residue_index" : residue_index [biomol .chain_id == chain_id ].tolist (),
2847+ }
27982848 for chain_id in msa_chain_ids
27992849 }
28002850 msa_features = load_msa_from_msa_dir (
2801- i .msa_dir , file_id , chain_id_to_chem_types , max_msas_per_chain = i .max_msas_per_chain
2851+ i .msa_dir , file_id , chain_id_to_residue , max_msas_per_chain = i .max_msas_per_chain
28022852 )
28032853
28042854 msa = msa_features .get ("msa" )
@@ -2817,6 +2867,10 @@ def pdb_input_to_molecule_input(
28172867 num_msas = len (msa ) if exists (msa ) else 1
28182868
28192869 if exists (msa ):
2870+ assert (
2871+ msa .shape [- 1 ] == num_tokens
2872+ ), f"The number of tokens in the MSA ({ msa .shape [- 1 ]} ) does not match the number of tokens in the biomolecule ({ num_tokens } ). "
2873+
28202874 has_deletion = torch .clip (msa_features ["deletion_matrix" ], 0.0 , 1.0 )
28212875 deletion_value = torch .atan (msa_features ["deletion_matrix" ] / 3.0 ) * (2.0 / torch .pi )
28222876
@@ -2883,51 +2937,6 @@ def pdb_input_to_molecule_input(
28832937 is_resolved_label = ((resolution >= 0.1 ) & (resolution <= 3.0 )).item ()
28842938 resolved_labels = torch .full ((num_atoms ,), is_resolved_label , dtype = torch .long )
28852939
2886- # handle `atom_indices_for_frame` for the PAE
2887-
2888- atom_indices_for_frame = tensor (
2889- [default (indices , (- 1 , - 1 , - 1 )) for indices in atom_indices_for_frame ]
2890- )
2891-
2892- # build offsets for all indices
2893-
2894- # derive `atom_lens` based on `one_token_per_atom`, for ligands and modified biomolecules
2895- atoms_per_molecule = tensor ([mol .GetNumAtoms () for mol in molecules ])
2896- ones = torch .ones_like (atoms_per_molecule )
2897-
2898- # `is_molecule_mod` can either be
2899- # 1. Bool['n'], in which case it will only be used for determining `one_token_per_atom`, or
2900- # 2. Bool['n num_mods'], where it will be passed to Alphafold3 for molecule modification embeds
2901- is_molecule_mod = tensor (is_molecule_mod )
2902- is_molecule_any_mod = False
2903-
2904- if is_molecule_mod .ndim == 2 :
2905- is_molecule_any_mod = is_molecule_mod [unique_chain_residue_indices ].any (dim = - 1 )
2906- else :
2907- is_molecule_any_mod = is_molecule_mod [unique_chain_residue_indices ]
2908-
2909- # get `one_token_per_atom`
2910- # default to what the paper did, which is ligands and any modified biomolecule
2911- is_ligand = is_molecule_types [unique_chain_residue_indices ][..., IS_LIGAND_INDEX ]
2912- one_token_per_atom = is_ligand | is_molecule_any_mod
2913-
2914- assert len (molecules ) == len (one_token_per_atom )
2915-
2916- # derive the number of repeats needed to expand molecule lengths to token lengths
2917- token_repeats = torch .where (one_token_per_atom , atoms_per_molecule , ones )
2918-
2919- # craft offsets for all atom indices
2920- atom_indices_offsets = repeat_interleave (
2921- exclusive_cumsum (atoms_per_molecule ), token_repeats , dim = 0
2922- )
2923-
2924- # offset only positive atom indices
2925- distogram_atom_indices = offset_only_positive (distogram_atom_indices , atom_indices_offsets )
2926- molecule_atom_indices = offset_only_positive (molecule_atom_indices , atom_indices_offsets )
2927- atom_indices_for_frame = offset_only_positive (
2928- atom_indices_for_frame , atom_indices_offsets [..., None ]
2929- )
2930-
29312940 # create molecule input
29322941
29332942 molecule_input = MoleculeInput (
0 commit comments