@@ -212,6 +212,78 @@ def _perform_data_preparation(self, *args, **kwargs):
212212 print ("Done" )
213213
214214
class PubChemBatched(PubChem):
    """PubChem variant that stores the training split in several files.

    The training data is written in batches of at most ``train_batch_size``
    (10 million) samples, while the validation and test splits are capped at
    100,000 samples each.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the data module and derive the split sizes.

        All arguments are forwarded unchanged to the parent constructor.
        """
        super(PubChemBatched, self).__init__(*args, **kwargs)
        self.train_batch_size = 10_000_000
        # NOTE(review): ``_k`` is the attribute ``processed_file_names`` reads;
        # it is assumed to be set by the parent constructor - confirm.
        if self._k != self.FULL:
            # Take the configured fraction of the subset, but at most 100k.
            self.val_batch_size = min(100_000, int(self.validation_split * self._k))
            self.test_batch_size = min(100_000, int(self.test_split * self._k))
        else:
            self.val_batch_size = 100_000
            self.test_batch_size = 100_000

    @property
    def processed_file_names(self) -> List[str]:
        """
        Returns:
            List[str]: List of processed data file names.
        """
        # The size of the full dataset is not known here, so it is estimated.
        total_samples = self._k if self._k != self.FULL else 120_000_000
        train_samples = total_samples - (self.val_batch_size + self.test_batch_size)

        if train_samples <= self.train_batch_size:
            train_files = ["train.pt"]
        else:
            # Ceiling division: a trailing partial batch gets its own file, but
            # an exact multiple of the batch size does not produce an extra,
            # never-written file name.
            num_batches = -(-train_samples // self.train_batch_size)
            train_files = [f"train_{i}.pt" for i in range(num_batches)]
        return train_files + ["test.pt", "validation.pt"]

    def setup_processed(self):
        """
        Prepares processed data and saves it with ``torch.save``.

        Writes ``test.pt`` and ``validation.pt`` and either a single
        ``train.pt`` or, when the training split exceeds ``train_batch_size``
        samples, numbered ``train_<i>.pt`` files.
        """
        filename = os.path.join(self.raw_dir, self.raw_file_names[0])
        print("Load data from file", filename)
        data = self._load_data_from_file(filename)

        print("Create splits")
        # First carve off everything that is not training data, then divide
        # that remainder into the test and validation splits.
        train, test = train_test_split(
            data, test_size=self.test_batch_size + self.val_batch_size
        )
        del data  # drop the reference to the unsplit data as early as possible
        test, val = train_test_split(test, train_size=self.test_batch_size)
        torch.save(test, os.path.join(self.processed_dir, "test.pt"))
        torch.save(val, os.path.join(self.processed_dir, "validation.pt"))

        # batch training if necessary
        if len(train) > self.train_batch_size:
            # Save one slice at a time instead of materialising all batches
            # first. Slices are stored as-is, exactly like the unbatched
            # training set and the evaluation splits, so every processed file
            # holds the same kind of object.
            batch_starts = range(0, len(train), self.train_batch_size)
            for i, start in enumerate(batch_starts):
                torch.save(
                    train[start : start + self.train_batch_size],
                    os.path.join(self.processed_dir, f"train_{i}.pt"),
                )
        else:
            torch.save(train, os.path.join(self.processed_dir, "train.pt"))

        self.reader.on_finish()
285+
286+
215287class PubChemDissimilar (PubChem ):
216288 """
217289 Subset of PubChem, but choosing the most dissimilar molecules (according to fingerprint)
0 commit comments