@@ -181,11 +181,7 @@ def processed_file_names_dict(self) -> List[str]:
181181 Returns:
182182 List[str]: List of processed data file names.
183183 """
184- return {
185- "train" : "train.pt" ,
186- "test" : "test.pt" ,
187- "validation" : "validation.pt"
188- }
184+ return {"train" : "train.pt" , "test" : "test.pt" , "validation" : "validation.pt" }
189185
190186 def _set_processed_data_props (self ):
191187 """
@@ -221,7 +217,9 @@ class PubChemBatched(PubChem):
221217
222218 READER : Type [dr .ChemDataReader ] = dr .ChemDataReader
223219
224- def __init__ (self , pc_train_batch_idx = 0 , train_batch_size = 10_000_000 , * args , ** kwargs ):
220+ def __init__ (
221+ self , pc_train_batch_idx = 0 , train_batch_size = 10_000_000 , * args , ** kwargs
222+ ):
225223 super (PubChemBatched , self ).__init__ (* args , ** kwargs )
226224 self .pc_train_batch_idx = pc_train_batch_idx
227225 self .train_batch_size = train_batch_size
@@ -254,7 +252,9 @@ def processed_file_names_dict(self) -> List[str]:
254252 {"train" : "train.pt" }
255253 if train_samples <= self .train_batch_size
256254 else {
257- f"train" if i == self .pc_train_batch_idx else f"train_{ i } " : f"train_{ i } .pt"
255+ (
256+ "train" if i == self .pc_train_batch_idx else f"train_{ i } "
257+ ): f"train_{ i } .pt"
258258 for i in range ((train_samples // self .train_batch_size ) + 1 )
259259 }
260260 )
@@ -275,8 +275,16 @@ def setup_processed(self):
275275 )
276276 del data
277277 test , val = train_test_split (test , train_size = self .test_batch_size )
278- torch .save (test , os .path .join (self .processed_dir , self .processed_file_names_dict ["test" ]))
279- torch .save (val , os .path .join (self .processed_dir , self .processed_file_names_dict ["validation" ]))
278+ torch .save (
279+ test ,
280+ os .path .join (self .processed_dir , self .processed_file_names_dict ["test" ]),
281+ )
282+ torch .save (
283+ val ,
284+ os .path .join (
285+ self .processed_dir , self .processed_file_names_dict ["validation" ]
286+ ),
287+ )
280288
281289 # batch training if necessary
282290 if len (train ) > self .train_batch_size :
0 commit comments