@@ -147,6 +147,7 @@ class DatasetConfig(BaseModelWithExtra):
147147 valid_folder : DirectoryPath | None = None
148148 test_folder : DirectoryPath | None = None
149149 convert_pdb_to_atom : bool = False
150+ pdb_to_atom_kwargs : dict = dict ()
150151 train_weighted_sampler : WeightedPDBSamplerConfig | None = None
151152 kwargs : dict = dict ()
152153
@@ -222,6 +223,7 @@ def create_instance(
222223 dataset_kwargs = dataset_config .kwargs
223224
224225 convert_pdb_to_atom = dataset_config .convert_pdb_to_atom
226+ pdb_to_atom_kwargs = dataset_config .pdb_to_atom_kwargs
225227
226228 if convert_pdb_to_atom :
227229 assert dataset_type == 'pdb' , 'must be `pdb` dataset_type if `convert_pdb_to_atom` is set to True'
@@ -233,35 +235,22 @@ def create_instance(
233235 else :
234236 raise ValueError (f'unhandled dataset_type { dataset_type } ' )
235237
236- train_folder , valid_folder , test_folder = tuple ( getattr ( dataset_config , key , None ) for key in ( 'train_folder' , 'valid_folder' , 'test_folder' ))
238+ # create dataset for train, valid, and test
237239
238- if exists ( train_folder ):
239- assert 'dataset' not in trainer_kwargs
240+ for trainer_kwarg_key , config_key in (( 'dataset' , 'train_folder' ), ( 'valid_dataset' , 'valid_folder' ), ( 'test_dataset' , 'test_folder' ) ):
241+ folder = getattr ( dataset_config , config_key , None )
240242
241- dataset = dataset_klass (train_folder , ** dataset_kwargs )
243+ if not exists (folder ):
244+ continue
242245
243- if convert_pdb_to_atom :
244- dataset = pdb_dataset_to_atom_inputs (dataset , return_atom_dataset = True )
245-
246- trainer_kwargs .update (dataset = dataset )
247-
248- if exists (valid_folder ):
249- assert 'valid_dataset' not in trainer_kwargs
250- dataset = dataset_klass (valid_folder , ** dataset_kwargs )
251-
252- if convert_pdb_to_atom :
253- dataset = pdb_dataset_to_atom_inputs (dataset , return_atom_dataset = True )
254-
255- trainer_kwargs .update (valid_dataset = dataset )
246+ assert trainer_kwarg_key not in trainer_kwargs
256247
257- if exists (test_folder ):
258- assert 'test_dataset' not in trainer_kwargs
259- dataset = dataset_klass (test_folder , ** dataset_kwargs )
248+ dataset = dataset_klass (folder , ** dataset_kwargs )
260249
261250 if convert_pdb_to_atom :
262- dataset = pdb_dataset_to_atom_inputs (dataset , return_atom_dataset = True )
251+ dataset = pdb_dataset_to_atom_inputs (dataset , return_atom_dataset = True , ** pdb_to_atom_kwargs )
263252
264- trainer_kwargs .update (test_dataset = dataset )
253+ trainer_kwargs .update (** { trainer_kwarg_key : dataset } )
265254
266255 # handle weighted pdb sampling
267256
0 commit comments