2323from torch .nn .utils .rnn import pad_sequence
2424
2525from loguru import logger
26+ from joblib import Parallel , delayed
27+
2628from pdbeccdutils .core import ccd_reader
2729
2830from rdkit import Chem
@@ -221,7 +223,8 @@ def pdb_dataset_to_atom_inputs(
221223 output_atom_folder : str | Path | None = None ,
222224 indices : Iterable | None = None ,
223225 return_atom_dataset = False ,
224- verbose = True
226+ n_jobs : int = 8 ,
227+ parallel_kwargs : dict = dict ()
225228) -> Path | AtomDataset :
226229
227230 if not exists (output_atom_folder ):
@@ -235,26 +238,21 @@ def pdb_dataset_to_atom_inputs(
235238 if not exists (indices ):
236239 indices = torch .randperm (len (pdb_dataset )).tolist ()
237240
238- indices = iter (indices )
239-
240241 to_atom_input_fn = compose (
241242 pdb_input_to_molecule_input ,
242243 molecule_to_atom_input
243244 )
244245
245- while index := next (indices , None ):
246- if not exists (index ):
247- break
248-
246+ @delayed
247+ def pdb_input_to_atom_file (index , path ):
249248 pdb_input = pdb_dataset [index ]
250249
251250 atom_input = to_atom_input_fn (pdb_input )
252- atom_input_path = output_atom_folder / f'{ index } .pt'
251+ atom_input_path = path / f'{ index } .pt'
253252
254253 atom_input_to_file (atom_input , atom_input_path )
255254
256- if verbose :
257- logger .info (f'converted pdb input with index { index } to { str (atom_input_path )} ' )
255+ Parallel (n_jobs = n_jobs , ** parallel_kwargs )(pdb_input_to_atom_file (index , output_atom_folder ) for index in indices )
258256
259257 if not return_atom_dataset :
260258 return output_atom_folder
0 commit comments