@@ -2566,6 +2566,13 @@ def pdb_input_to_molecule_input(
25662566 assert len (molecules ) == len (missing_atom_indices )
25672567 assert len (missing_token_indices ) == num_tokens
25682568
2569+ mol_total_atoms = sum ([mol .GetNumAtoms () for mol in molecules ])
2570+ num_missing_atom_indices = sum (
2571+ len (mol_miss_atom_indices ) for mol_miss_atom_indices in missing_atom_indices
2572+ )
2573+ num_present_atoms = mol_total_atoms - num_missing_atom_indices
2574+ assert num_present_atoms == int (biomol .atom_mask .sum ())
2575+
25692576 # TODO: install additional token features once MSAs are available
25702577 # 0: f_profile
25712578 # 1: f_deletion_mean
@@ -2645,6 +2652,7 @@ def __init__(
26452652 spatial_interface_weight : float = 0.4 ,
26462653 crop_size : int = 384 ,
26472654 training : bool | None = None , # extra training flag placed by Alex on PDBInput
2655+ sample_only_pdb_ids : Set [str ] | None = None ,
26482656 ** pdb_input_kwargs ,
26492657 ):
26502658 if isinstance (folder , str ):
@@ -2660,6 +2668,7 @@ def __init__(
26602668 self .sampler = sampler
26612669 self .sample_type = sample_type
26622670 self .training = training
2671+ self .sample_only_pdb_ids = sample_only_pdb_ids
26632672 self .pdb_input_kwargs = pdb_input_kwargs
26642673
26652674 self .cropping_config = {
@@ -2679,6 +2688,12 @@ def __init__(
26792688 if file in sampler_pdb_ids
26802689 }
26812690
2691+ if exists (sample_only_pdb_ids ):
2692+ assert exists (self .sampler ), "A sampler must be provided to use `sample_only_pdb_ids`."
2693+ assert all (
2694+ pdb_id in sampler_pdb_ids for pdb_id in sample_only_pdb_ids
2695+ ), "Some PDB IDs in `sample_only_pdb_ids` are not present in the dataset's sampler mappings."
2696+
26822697 assert len (self ) > 0 , f"No valid mmCIFs / PDBs found at { str (folder )} "
26832698
26842699 def __len__ (self ):
@@ -2690,10 +2705,18 @@ def __getitem__(self, idx: int | str) -> PDBInput:
26902705 sampled_id = None
26912706
26922707 if exists (self .sampler ):
2693- if self .sample_type == "clustered" :
2694- (sampled_id ,) = self .sampler .cluster_based_sample (1 )
2695- else :
2696- (sampled_id ,) = self .sampler .sample (1 )
2708+ sample_fn = (
2709+ self .sampler .cluster_based_sample
2710+ if self .sample_type == "clustered"
2711+ else self .sampler .sample
2712+ )
2713+ (sampled_id ,) = sample_fn (1 )
2714+
2715+ # ensure that the sampled PDB ID is in the specified set of PDB IDs from which to sample
2716+
2717+ if exists (self .sample_only_pdb_ids ):
2718+ while sampled_id [0 ] not in self .sample_only_pdb_ids :
2719+ (sampled_id ,) = sample_fn (1 )
26972720
26982721 pdb_id , chain_id_1 , chain_id_2 = None , None , None
26992722
0 commit comments