@@ -176,12 +176,16 @@ def raw_file_names(self) -> List[str]:
         return ["smiles.txt"]
 
     @property
-    def processed_file_names(self) -> List[str]:
+    def processed_file_names_dict(self) -> List[str]:
         """
         Returns:
             List[str]: List of processed data file names.
         """
-        return ["test.pt", "train.pt", "validation.pt"]
+        return {
+            "train": "train.pt",
+            "test": "test.pt",
+            "validation": "validation.pt"
+        }
 
     def _set_processed_data_props(self):
         """
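A minimal sketch of how the dict-keyed property introduced above could be consumed (the concrete directory path below is an assumption for illustration, not part of this commit):

import os

# Mirrors the mapping returned by processed_file_names_dict above;
# the directory is a made-up example location.
processed_file_names_dict = {
    "train": "train.pt",
    "test": "test.pt",
    "validation": "validation.pt",
}
processed_dir = "data/pubchem/processed"  # assumed location

for split, fname in processed_file_names_dict.items():
    print(split, "->", os.path.join(processed_dir, fname))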
@@ -215,9 +219,12 @@ def _perform_data_preparation(self, *args, **kwargs):
 class PubChemBatched(PubChem):
     """Store train data as batches of 10m, validation and test should each be 100k max"""
 
-    def __init__(self, *args, **kwargs):
+    READER: Type[dr.ChemDataReader] = dr.ChemDataReader
+
+    def __init__(self, pc_train_batch_idx=0, train_batch_size=10_000_000, *args, **kwargs):
         super(PubChemBatched, self).__init__(*args, **kwargs)
-        self.train_batch_size = 10_000_000
+        self.pc_train_batch_idx = pc_train_batch_idx
+        self.train_batch_size = train_batch_size
         if self.k != self.FULL:
             self.val_batch_size = (
                 100_000
@@ -234,7 +241,7 @@ def __init__(self, *args, **kwargs):
         self.test_batch_size = 100_000
 
     @property
-    def processed_file_names(self) -> List[str]:
+    def processed_file_names_dict(self) -> List[str]:
         """
         Returns:
             List[str]: List of processed data file names.
@@ -244,14 +251,16 @@ def processed_file_names(self) -> List[str]:
         ) # estimate size
         train_samples -= self.val_batch_size + self.test_batch_size
         train_batches = (
-            ["train.pt"]
+            {"train": "train.pt"}
             if train_samples <= self.train_batch_size
-            else [
-                f"train_{i}.pt"
+            else {
+                f"train" if i == self.pc_train_batch_idx else f"train_{i}": f"train_{i}.pt"
                 for i in range((train_samples // self.train_batch_size) + 1)
-            ]
+            }
         )
-        return train_batches + ["test.pt", "validation.pt"]
+        train_batches["test"] = "test.pt"
+        train_batches["validation"] = "validation.pt"
+        return train_batches
 
     def setup_processed(self):
         """
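A standalone sketch of the batched naming scheme above, using illustrative numbers (the 25m sample estimate and batch index 1 are assumptions for this example only):

train_samples = 25_000_000      # hypothetical estimate after subtracting val/test
train_batch_size = 10_000_000
pc_train_batch_idx = 1          # which batch the plain "train" key should point at

if train_samples <= train_batch_size:
    train_batches = {"train": "train.pt"}
else:
    train_batches = {
        ("train" if i == pc_train_batch_idx else f"train_{i}"): f"train_{i}.pt"
        for i in range((train_samples // train_batch_size) + 1)
    }
train_batches["test"] = "test.pt"
train_batches["validation"] = "validation.pt"
# -> {'train_0': 'train_0.pt', 'train': 'train_1.pt', 'train_2': 'train_2.pt',
#     'test': 'test.pt', 'validation': 'validation.pt'}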
@@ -266,8 +275,8 @@ def setup_processed(self):
         )
         del data
         test, val = train_test_split(test, train_size=self.test_batch_size)
-        torch.save(test, os.path.join(self.processed_dir, "test.pt"))
-        torch.save(val, os.path.join(self.processed_dir, "validation.pt"))
+        torch.save(test, os.path.join(self.processed_dir, self.processed_file_names_dict["test"]))
+        torch.save(val, os.path.join(self.processed_dir, self.processed_file_names_dict["validation"]))
 
         # batch training if necessary
         if len(train) > self.train_batch_size: