77from functools import partial
88from itertools import groupby
99from collections import defaultdict
10+ from collections .abc import Iterable
1011from dataclasses import asdict , dataclass , field
1112from typing import Any , Callable , List , Literal , Set , Tuple , Type
1213
@@ -187,11 +188,12 @@ def dict(self):
187188@typecheck
188189def atom_input_to_file (
189190 atom_input : AtomInput ,
190- path : str ,
191+ path : str | Path ,
191192 overwrite : bool = False
192193) -> Path :
193194
194- path = Path (path )
195+ if isinstance (path , str ):
196+ path = Path (path )
195197
196198 if not overwrite :
197199 assert not path .exists ()
@@ -211,6 +213,53 @@ def file_to_atom_input(path: str | Path) -> AtomInput:
211213 atom_input_dict = torch .load (str (path ))
212214 return AtomInput (** atom_input_dict )
213215
@typecheck
def pdb_dataset_to_atom_inputs(
    pdb_dataset: PDBDataset,
    *,
    output_atom_folder: str | Path | None = None,
    indices: Iterable | None = None,
    return_atom_dataset = False,
    verbose = True
) -> Path | AtomDataset:
    """Convert entries of a PDBDataset to AtomInputs and save each one to disk.

    Every selected entry is run through `pdb_input_to_molecule_input` and then
    `molecule_to_atom_input`, and the result is written with
    `atom_input_to_file` to `<output_atom_folder>/<index>.pt`.

    Args:
        pdb_dataset: dataset to read the PDB inputs from.
        output_atom_folder: destination folder. When omitted, a sibling of the
            (resolved) `pdb_dataset.folder` named `<folder stem>.atom-inputs`
            is used.
        indices: dataset indices to convert, in iteration order. When omitted,
            all indices of the dataset are converted in a random order.
        return_atom_dataset: if true, return an `AtomDataset` over the output
            folder instead of the folder path.
        verbose: if true, log one message per converted entry.

    Returns:
        The output folder path, or an `AtomDataset` built from it when
        `return_atom_dataset` is set.

    Raises:
        AssertionError: from `atom_input_to_file` when a target `.pt` file
            already exists (it is called without `overwrite`).
    """

    # NOTE(review): `PDBDataset` / `AtomDataset` are defined further down in the
    # file, so these annotations presumably rely on postponed evaluation
    # (`from __future__ import annotations`) - confirm at the top of the file.

    if not exists(output_atom_folder):
        pdb_folder = Path(pdb_dataset.folder).resolve()
        parent_folder = pdb_folder.parents[0]
        output_atom_folder = parent_folder / f'{pdb_folder.stem}.atom-inputs'

    if isinstance(output_atom_folder, str):
        output_atom_folder = Path(output_atom_folder)

    if not exists(indices):
        indices = torch.randperm(len(pdb_dataset)).tolist()

    to_atom_input_fn = compose(
        pdb_input_to_molecule_input,
        molecule_to_atom_input
    )

    # a plain `for` is required here: `while index := next(it, None)` treats
    # index 0 as falsy and would stop the conversion as soon as 0 comes up

    for index in indices:
        pdb_input = pdb_dataset[index]

        atom_input = to_atom_input_fn(pdb_input)
        atom_input_path = output_atom_folder / f'{index}.pt'

        atom_input_to_file(atom_input, atom_input_path)

        if verbose:
            logger.info(f'converted pdb input with index {index} to {str(atom_input_path)}')

    if not return_atom_dataset:
        return output_atom_folder

    return AtomDataset(output_atom_folder)
262+
214263# Atom dataset that returns a AtomInput based on folders of atom inputs stored on disk
215264
216265class AtomDataset (Dataset ):
@@ -221,11 +270,13 @@ def __init__(
221270 if isinstance (folder , str ):
222271 folder = Path (folder )
223272
224- assert folder .exists () and folder .is_dir ()
273+ assert folder .exists () and folder .is_dir (), f'atom dataset not found at { str ( folder ) } '
225274
226275 self .folder = folder
227276 self .files = [* folder .glob ('**/*.pt' )]
228277
278+ assert len (self ) > 0 , f'no valid atom .pt files found at { str (folder )} '
279+
229280 def __len__ (self ):
230281 return len (self .files )
231282
@@ -1919,19 +1970,6 @@ def pdb_input_to_molecule_input(pdb_input: PDBInput) -> MoleculeInput:
19191970
19201971# datasets
19211972
1922- # dataset wrapper for returning index along with dataset item
1923- # for caching logic both integrated into trainer and for precaching
1924-
1925- class DatasetWithReturnedIndex (Dataset ):
1926- def __init__ (self , dataset : Dataset ):
1927- self .dataset = dataset
1928-
1929- def __len__ (self ):
1930- return len (self .dataset )
1931-
1932- def __getitem__ (self , idx ):
1933- return idx , self .dataset [idx ]
1934-
19351973# PDB dataset that returns a PDBInput based on folder
19361974
19371975class PDBDataset (Dataset ):
@@ -1953,7 +1991,9 @@ def __init__(
19531991 if isinstance (folder , str ):
19541992 folder = Path (folder )
19551993
1956- assert folder .exists () and folder .is_dir ()
1994+ assert folder .exists () and folder .is_dir (), f'{ str (folder )} does not exist for PDBDataset'
1995+ self .folder = folder
1996+
19571997 self .files = {
19581998 os .path .splitext (os .path .basename (file .name ))[0 ]: file
19591999 for file in folder .glob (os .path .join ("**" , "*.cif" ))
@@ -1967,6 +2007,8 @@ def __init__(
19672007 self .training = training
19682008 self .pdb_input_kwargs = pdb_input_kwargs
19692009
2010+ assert len (self ) > 0 , f'no valid mmcifs / pdbs found at { str (folder )} '
2011+
19702012 def __len__ (self ):
19712013 """Return the number of PDB mmCIF files in the dataset."""
19722014 return len (self .files )
0 commit comments