@@ -236,7 +236,10 @@ def __getitem__(self, i) -> dict[str, torch.Tensor]:
 
 
 def make_eagle_supervised_data_module(
-    tokenizer: transformers.PreTrainedTokenizer, data_args, use_offline_training: bool
+    tokenizer: transformers.PreTrainedTokenizer,
+    data_args,
+    use_offline_training: bool,
+    max_length=None,
 ) -> dict:
     """Make dataset and collator for supervised fine-tuning.
 
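The new `max_length` parameter defaults to `None`, so existing callers are unaffected. A minimal sketch of how a training script might pass it; the `2048` cap and the `Trainer` wiring are illustrative assumptions, not part of this change:

```python
# Hypothetical call site (illustration only): `tokenizer`, `data_args`,
# `model`, and `training_args` are assumed to exist in the caller's scope.
data_module = make_eagle_supervised_data_module(
    tokenizer=tokenizer,
    data_args=data_args,
    use_offline_training=True,
    max_length=2048,  # assumed cap; None keeps the old per-batch dynamic padding
)
trainer = transformers.Trainer(model=model, args=training_args, **data_module)
```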
@@ -295,15 +298,15 @@ def make_eagle_supervised_data_module(
         train_dataset = dataset_cls(valid_entries[:num_train], tokenizer=tokenizer)
         eval_dataset = dataset_cls(valid_entries[num_train:], tokenizer=tokenizer)
 
-        data_collator = DataCollatorForOffline()
+        data_collator = DataCollatorForOffline(max_length=max_length)
     else:
         print_rank_0("Loading input conversations...")
         dataset_cls = LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset
 
         train_dataset = dataset_cls(data_json[: int(len(data_json) * 0.95)], tokenizer=tokenizer)
         eval_dataset = dataset_cls(data_json[int(len(data_json) * 0.95) :], tokenizer=tokenizer)
 
-        data_collator = DataCollatorWithPadding()
+        data_collator = DataCollatorWithPadding(max_length=max_length)
 
     return {
         "train_dataset": train_dataset,
@@ -313,6 +316,9 @@ def make_eagle_supervised_data_module(
 
 
 class DataCollatorWithPadding:
+    def __init__(self, max_length=None):
+        self.max_length = max_length
+
     def paddingtensor2d(self, intensors, length):
         n, dim = intensors.shape
         padding_tensor = torch.zeros(length - n, dim, dtype=intensors.dtype)
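For reference, `paddingtensor2d` zero-pads a 2-D tensor along its first dimension. A standalone sketch of the visible logic; the final concatenation is an assumption, since the hunk cuts off before it:

```python
import torch

def paddingtensor2d_sketch(intensors: torch.Tensor, length: int) -> torch.Tensor:
    # Zero-pad (n, dim) up to (length, dim); assumes the padding rows are
    # appended below the input, as the visible lines suggest.
    n, dim = intensors.shape
    padding_tensor = torch.zeros(length - n, dim, dtype=intensors.dtype)
    return torch.cat((intensors, padding_tensor), dim=0)

x = torch.randn(3, 8)
assert paddingtensor2d_sketch(x, 5).shape == (5, 8)
```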
@@ -325,7 +331,11 @@ def paddingtensor(self, intensors, length):
         return outtensors
 
     def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
-        max_length = max(item["input_ids"].shape[0] for item in features)
+        max_length = (
+            self.max_length
+            if self.max_length is not None
+            else max(item["input_ids"].shape[0] for item in features)
+        )
         batch_input_ids = torch.stack(
             [self.paddingtensor(item["input_ids"], max_length) for item in features]
         )
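The effect of the fallback, sketched with the class above; the tensors and the cap of 16 are illustrative:

```python
import torch

features = [
    {"input_ids": torch.ones(5, dtype=torch.long)},
    {"input_ids": torch.ones(9, dtype=torch.long)},
]

dynamic = DataCollatorWithPadding()             # pads this batch to 9 (old behavior)
fixed = DataCollatorWithPadding(max_length=16)  # pads every batch to 16
```

One caveat: if an item is longer than the fixed `max_length`, `length - n` in the padding helpers goes negative and `torch.zeros` raises, so sequences presumably need to be truncated upstream.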
@@ -351,13 +361,20 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
 
 
 class DataCollatorForOffline(DataCollatorWithPadding):
+    def __init__(self, max_length=None):
+        super().__init__(max_length=max_length)
+
     def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
         base_batch = super().__call__(features)
         if "kwargs" not in features[0]:
             raise ValueError("No kwargs found in batch features. Offline data required.")
 
         features = [item["kwargs"]["base_model_outputs"] for item in features]
-        max_hs_length = max(item["base_model_hidden_states"].shape[0] for item in features)
+        max_hs_length = (
+            self.max_length
+            if self.max_length is not None
+            else max(item["base_model_hidden_states"].shape[0] for item in features)
+        )
 
         batch_hidden_states = torch.stack(
             [
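For orientation, a sketch of the nesting the offline collator reads, based only on the key accesses visible in this hunk; field shapes are assumptions, and any other keys the base collator consumes are elided:

```python
import torch

# Inferred layout of one offline training feature. Items missing
# "kwargs" -> "base_model_outputs" make the collator raise ValueError.
feature = {
    "input_ids": torch.ones(9, dtype=torch.long),  # consumed by the base collator
    "kwargs": {
        "base_model_outputs": {
            # assumed (seq_len, hidden_dim); padded/stacked to max_hs_length
            "base_model_hidden_states": torch.randn(9, 4096),
        },
    },
}
```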