@@ -236,7 +236,10 @@ def __getitem__(self, i) -> dict[str, torch.Tensor]:


def make_eagle_supervised_data_module(
-    tokenizer: transformers.PreTrainedTokenizer, data_args, use_offline_training: bool
+    tokenizer: transformers.PreTrainedTokenizer,
+    data_args,
+    use_offline_training: bool,
+    max_length=None,
) -> dict:
    """Make dataset and collator for supervised fine-tuning.
@@ -295,15 +298,15 @@ def make_eagle_supervised_data_module(
        train_dataset = dataset_cls(valid_entries[:num_train], tokenizer=tokenizer)
        eval_dataset = dataset_cls(valid_entries[num_train:], tokenizer=tokenizer)

-        data_collator = DataCollatorForOffline()
+        data_collator = DataCollatorForOffline(max_length=max_length)
    else:
        print_rank_0("Loading input conversations...")
        dataset_cls = LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset

        train_dataset = dataset_cls(data_json[: int(len(data_json) * 0.95)], tokenizer=tokenizer)
        eval_dataset = dataset_cls(data_json[int(len(data_json) * 0.95) :], tokenizer=tokenizer)

-        data_collator = DataCollatorWithPadding()
+        data_collator = DataCollatorWithPadding(max_length=max_length)

    return {
        "train_dataset": train_dataset,
@@ -313,6 +316,9 @@ def make_eagle_supervised_data_module(


class DataCollatorWithPadding:
+    def __init__(self, max_length):
+        self.max_length = max_length
+
    def paddingtensor2d(self, intensors, length):
        n, dim = intensors.shape
        padding_tensor = torch.zeros(length - n, dim, dtype=intensors.dtype)
@@ -325,19 +331,18 @@ def paddingtensor(self, intensors, length):
        return outtensors

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
-        max_length = max(item["input_ids"].shape[0] for item in features)
        batch_input_ids = torch.stack(
-            [self.paddingtensor(item["input_ids"], max_length) for item in features]
+            [self.paddingtensor(item["input_ids"], self.max_length) for item in features]
        )
        batch_attention_mask = torch.stack(
-            [self.paddingtensor(item["attention_mask"], max_length) for item in features]
+            [self.paddingtensor(item["attention_mask"], self.max_length) for item in features]
        )
        batch_loss_mask = torch.stack(
-            [self.paddingtensor(item["loss_mask"], max_length) for item in features]
+            [self.paddingtensor(item["loss_mask"], self.max_length) for item in features]
        )

        batch_labels = torch.stack(
-            [self.paddingtensor(item["labels"], max_length) for item in features]
+            [self.paddingtensor(item["labels"], self.max_length) for item in features]
        )

        batch = {
@@ -357,16 +362,15 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
            raise ValueError("No kwargs found in batch features. Offline data required.")

        features = [item["kwargs"]["base_model_outputs"] for item in features]
-        max_hs_length = max(item["base_model_hidden_states"].shape[0] for item in features)

        batch_hidden_states = torch.stack(
            [
-                self.paddingtensor2d(item["base_model_hidden_states"], max_hs_length)
+                self.paddingtensor2d(item["base_model_hidden_states"], self.max_length)
                for item in features
            ]
        )
        batch_aux_hidden_states = torch.stack(
-            [self.paddingtensor2d(item["aux_hidden_states"], max_hs_length) for item in features]
+            [self.paddingtensor2d(item["aux_hidden_states"], self.max_length) for item in features]
        )

        batch = {
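Taken together, the change threads a single `max_length` from `make_eagle_supervised_data_module` into both collators, which then pad every sample (and, in the offline path, the dumped hidden states) to that fixed length instead of the per-batch maximum, so batch shapes stay constant across steps. Below is a minimal, self-contained sketch of that fixed-length padding behavior; the `FixedLengthCollator` name and example values are hypothetical, and the real collators also pad attention masks, loss masks, labels, and 2-D hidden states.

```python
import torch


# Condensed illustration of fixed-length collation: the padding target is a
# max_length fixed at construction time rather than the longest sequence in
# the current batch.
class FixedLengthCollator:
    def __init__(self, max_length):
        self.max_length = max_length

    def paddingtensor(self, intensors, length):
        # Zero-pad a 1-D tensor up to `length`.
        padding = torch.zeros(length - intensors.shape[0], dtype=intensors.dtype)
        return torch.cat((intensors, padding))

    def __call__(self, features):
        return {
            "input_ids": torch.stack(
                [self.paddingtensor(f["input_ids"], self.max_length) for f in features]
            )
        }


collator = FixedLengthCollator(max_length=8)
batch = collator(
    [
        {"input_ids": torch.tensor([1, 2, 3])},
        {"input_ids": torch.tensor([4, 5, 6, 7, 8])},
    ]
)
print(batch["input_ids"].shape)  # torch.Size([2, 8]), regardless of batch contents
```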