@@ -262,41 +262,77 @@ def processed_file_names_dict(self) -> List[str]:
262262 train_batches ["validation" ] = "validation.pt"
263263 return train_batches
264264
265+ def _tokenize_batched (self , data ):
266+ """
267+ Load data from a file and return a list of dictionaries, batched in 1,000,000 entries.
268+
269+ Args:
270+ path (str): The path to the input file.
271+ batch_size (int): The size of each batch.
272+ batch_idx (int): The index of the batch to load.
273+
274+ Returns:
275+ List: A list of dictionaries containing the features and labels.
276+ """
277+ print (f"Processing { len (data )} lines..." )
278+ batch = []
279+ for i , d in enumerate (tqdm .tqdm (data , total = len (data ))):
280+ if d ["features" ] is not None :
281+ batch .append (self .reader .to_data (d ))
282+ if i % 1_000_000 == 0 and i > 0 :
283+ print (f"Saving batch { i // 1_000_000 } " )
284+ batch = [
285+ b
286+ for b in batch
287+ if b ["features" ] is not None
288+ and self .n_token_limit is None
289+ or len (b ["features" ]) <= self .n_token_limit
290+ ]
291+ yield batch
292+ batch = []
293+ print ("Saving final batch" )
294+ batch = [
295+ b
296+ for b in batch
297+ if b ["features" ] is not None
298+ and self .n_token_limit is None
299+ or len (b ["features" ]) <= self .n_token_limit
300+ ]
301+ yield batch
302+
265303 def setup_processed (self ):
266304 """
267305 Prepares processed data and saves them as Torch tensors.
268306 """
269307 filename = os .path .join (self .raw_dir , self .raw_file_names [0 ])
270308 print ("Load data from file" , filename )
271- data = self ._load_data_from_file (filename )
309+ data_not_tokenized = [ entry for entry in self ._load_dict (filename )]
272310 print ("Create splits" )
273311 train , test = train_test_split (
274- data , test_size = self .test_batch_size + self .val_batch_size
312+ data_not_tokenized , test_size = self .test_batch_size + self .val_batch_size
275313 )
276- del data
314+ del data_not_tokenized
277315 test , val = train_test_split (test , train_size = self .test_batch_size )
278316 torch .save (
279- test ,
317+ self . _tokenize_batched ( test ) ,
280318 os .path .join (self .processed_dir , self .processed_file_names_dict ["test" ]),
281319 )
282320 torch .save (
283- val ,
321+ self . _tokenize_batched ( val ) ,
284322 os .path .join (
285323 self .processed_dir , self .processed_file_names_dict ["validation" ]
286324 ),
287325 )
288326
289327 # batch training if necessary
290328 if len (train ) > self .train_batch_size :
291- train_batches = [
292- train [i : i + self .train_batch_size ]
293- for i in range (0 , len (train ), self .train_batch_size )
294- ]
295- train = [torch .tensor (batch ) for batch in train_batches ]
296- for i , batch in enumerate (train ):
329+ for i , batch in enumerate (self ._tokenize_batched (train )):
297330 torch .save (batch , os .path .join (self .processed_dir , f"train_{ i } .pt" ))
298331 else :
299- torch .save (train , os .path .join (self .processed_dir , "train.pt" ))
332+ torch .save (
333+ self ._tokenize_batched (train ),
334+ os .path .join (self .processed_dir , "train.pt" ),
335+ )
300336
301337 self .reader .on_finish ()
302338
0 commit comments