@@ -2485,7 +2485,7 @@ def load_msa_from_msa_dir(
24852485 """Load MSA from a directory containing MSA files."""
24862486 if verbose and (not_exists (msa_dir ) or not os .path .exists (msa_dir )):
24872487 logger .warning (
2488- f"{ msa_dir } does not exist. Dummy MSA features for each chain of file { file_id } will instead be loaded."
2488+ f"{ msa_dir } MSA directory does not exist. Dummy MSA features for each chain of file { file_id } will instead be loaded."
24892489 )
24902490
24912491 msas = {}
@@ -2600,35 +2600,32 @@ def load_templates_from_templates_dir(
26002600 kalign_binary_path : str | None = None ,
26012601 template_cutoff_date : datetime | None = None ,
26022602 randomly_sample_num_templates : bool = False ,
2603- raise_missing_exception : bool = False ,
26042603 verbose : bool = False ,
26052604) -> FeatureDict :
26062605 """Load templates from a directory containing template PDB mmCIF files."""
2607- if (
2608- not_exists (templates_dir ) or not os .path .exists (templates_dir )
2609- ) and raise_missing_exception :
2610- raise FileNotFoundError (f"{ templates_dir } does not exist." )
2611- elif not_exists (templates_dir ) or not os .path .exists (templates_dir ):
2612- if verbose :
2613- logger .warning (
2614- f"{ templates_dir } does not exist. Skipping template loading by returning `Nones`."
2615- )
2616- return {}
2606+ if verbose and (not_exists (templates_dir ) or not os .path .exists (templates_dir )):
2607+ logger .warning (
2608+ f"{ templates_dir } templates directory does not exist. Dummy template features for each chain of file { file_id } will instead be loaded."
2609+ )
26172610
2618- if (not_exists (mmcif_dir ) or not os .path .exists (mmcif_dir )) and raise_missing_exception :
2619- raise FileNotFoundError (f"{ mmcif_dir } does not exist." )
2620- elif not_exists (mmcif_dir ) or not os .path .exists (mmcif_dir ):
2621- if verbose :
2622- logger .warning (
2623- f"{ mmcif_dir } does not exist. Skipping template loading by returning `Nones`."
2624- )
2625- return {}
2611+ if verbose and (not_exists (mmcif_dir ) or not os .path .exists (mmcif_dir )):
2612+ logger .warning (
2613+ f"{ mmcif_dir } mmCIF templates directory does not exist. Dummy template features for each chain of file { file_id } will instead be loaded."
2614+ )
26262615
26272616 templates = defaultdict (list )
26282617 for chain_id in chain_id_to_residue :
2629- template_fpaths = glob .glob (os .path .join (templates_dir , f"{ file_id } { chain_id } _*.m8" ))
2618+ template_fpaths = (
2619+ glob .glob (os .path .join (templates_dir , f"{ file_id } { chain_id } _*.m8" ))
2620+ if exists (templates_dir )
2621+ else []
2622+ )
26302623
26312624 if not template_fpaths :
2625+ if verbose :
2626+ logger .warning (
2627+ f"Could not find template for chain { chain_id } of file { file_id } . A dummy template will be installed for this chain."
2628+ )
26322629 templates [chain_id ] = []
26332630 continue
26342631
@@ -2808,30 +2805,31 @@ def pdb_input_to_molecule_input(
28082805 exists (feat )
28092806 for feat in [msa , msa_row_mask , has_deletion , deletion_value , profile , deletion_mean ]
28102807 )
2811- if all_msa_features_exist :
2812- assert (
2813- msa .shape [- 1 ] == num_tokens
2814- ), f"The number of tokens in the MSA ({ msa .shape [- 1 ]} ) does not match the number of tokens in the biomolecule ({ num_tokens } ). "
2815-
2816- additional_msa_feats = torch .stack (
2817- [
2818- has_deletion ,
2819- deletion_value ,
2820- ],
2821- dim = - 1 ,
2822- )
28232808
2824- additional_token_feats = torch .cat (
2825- [
2826- profile ,
2827- deletion_mean [:, None ],
2828- ],
2829- dim = - 1 ,
2830- )
2809+ assert all_msa_features_exist , "All MSA features must be derived for each example."
2810+ assert (
2811+ msa .shape [- 1 ] == num_tokens
2812+ ), f"The number of tokens in the MSA ({ msa .shape [- 1 ]} ) does not match the number of tokens in the biomolecule ({ num_tokens } ). "
2813+
2814+ additional_msa_feats = torch .stack (
2815+ [
2816+ has_deletion ,
2817+ deletion_value ,
2818+ ],
2819+ dim = - 1 ,
2820+ )
2821+
2822+ additional_token_feats = torch .cat (
2823+ [
2824+ profile ,
2825+ deletion_mean [:, None ],
2826+ ],
2827+ dim = - 1 ,
2828+ )
28312829
2832- # convert the MSA into a one-hot representation
2833- msa = make_one_hot (msa , NUM_MSA_ONE_HOT )
2834- msa_row_mask = msa_row_mask .bool ()
2830+ # convert the MSA into a one-hot representation
2831+ msa = make_one_hot (msa , NUM_MSA_ONE_HOT )
2832+ msa_row_mask = msa_row_mask .bool ()
28352833
28362834 # retrieve templates for each chain
28372835
0 commit comments