@@ -217,9 +217,7 @@ class PubChemBatched(PubChem):
217217
218218 READER : Type [dr .ChemDataReader ] = dr .ChemDataReader
219219
220- def __init__ (
221- self , train_batch_size = 10_000_000 , * args , ** kwargs
222- ):
220+ def __init__ (self , train_batch_size = 10_000_000 , * args , ** kwargs ):
223221 super (PubChemBatched , self ).__init__ (* args , ** kwargs )
224222 self .curr_epoch = 0
225223 self .train_batch_size = train_batch_size
@@ -245,7 +243,7 @@ def processed_file_names_dict(self) -> List[str]:
245243 List[str]: List of processed data file names.
246244 """
247245 train_samples = (
248- self ._k if self ._k != self .FULL else 120_000_000 # estimated PubChem size
246+ self ._k if self ._k != self .FULL else 120_000_000 # estimated PubChem size
249247 ) # estimate size
250248 train_samples -= self .val_batch_size + self .test_batch_size
251249 train_batches = (
@@ -279,21 +277,15 @@ def _tokenize_batched(self, data):
279277 batch .append (self .reader .to_data (d ))
280278 if i % 1_000_000 == 0 and i > 0 :
281279 print (f"Saving batch { i // 1_000_000 } " )
282- batch = [
283- b
284- for b in batch
285- if b ["features" ] is not None
286- ]
280+ batch = [b for b in batch if b ["features" ] is not None ]
287281 if self .n_token_limit is not None :
288- batch = [b for b in batch if len (b ["features" ]) <= self .n_token_limit ]
282+ batch = [
283+ b for b in batch if len (b ["features" ]) <= self .n_token_limit
284+ ]
289285 yield batch
290286 batch = []
291287 print ("Saving final batch" )
292- batch = [
293- b
294- for b in batch
295- if b ["features" ] is not None
296- ]
288+ batch = [b for b in batch if b ["features" ] is not None ]
297289 if self .n_token_limit is not None :
298290 batch = [b for b in batch if len (b ["features" ]) <= self .n_token_limit ]
299291 yield batch
@@ -348,13 +340,18 @@ def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader
348340 DataLoader: A DataLoader object for training data.
349341 """
350342 return self .dataloader (
351- "train" if "train" in self .processed_file_names_dict else f"train_{ self .curr_epoch } " ,
343+ (
344+ "train"
345+ if "train" in self .processed_file_names_dict
346+ else f"train_{ self .curr_epoch } "
347+ ),
352348 shuffle = True ,
353349 num_workers = self .num_workers ,
354350 persistent_workers = True ,
355351 ** kwargs ,
356352 )
357353
354+
358355class PubChemDissimilar (PubChem ):
359356 """
360357 Subset of PubChem, but choosing the most dissimilar molecules (according to fingerprint)