@@ -316,7 +316,7 @@ def make_eagle_supervised_data_module(
 
 
 class DataCollatorWithPadding:
-    def __init__(self, max_length=None):
+    def __init__(self, max_length):
         self.max_length = max_length
 
     def paddingtensor2d(self, intensors, length):
@@ -331,23 +331,18 @@ def paddingtensor(self, intensors, length):
         return outtensors
 
     def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
-        max_length = (
-            self.max_length
-            if self.max_length is not None
-            else max(item["input_ids"].shape[0] for item in features)
-        )
         batch_input_ids = torch.stack(
-            [self.paddingtensor(item["input_ids"], max_length) for item in features]
+            [self.paddingtensor(item["input_ids"], self.max_length) for item in features]
         )
         batch_attention_mask = torch.stack(
-            [self.paddingtensor(item["attention_mask"], max_length) for item in features]
+            [self.paddingtensor(item["attention_mask"], self.max_length) for item in features]
         )
         batch_loss_mask = torch.stack(
-            [self.paddingtensor(item["loss_mask"], max_length) for item in features]
+            [self.paddingtensor(item["loss_mask"], self.max_length) for item in features]
         )
 
         batch_labels = torch.stack(
-            [self.paddingtensor(item["labels"], max_length) for item in features]
+            [self.paddingtensor(item["labels"], self.max_length) for item in features]
         )
 
         batch = {
@@ -367,20 +362,15 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
             raise ValueError("No kwargs found in batch features. Offline data required.")
 
         features = [item["kwargs"]["base_model_outputs"] for item in features]
-        max_hs_length = (
-            self.max_length
-            if self.max_length is not None
-            else max(item["base_model_hidden_states"].shape[0] for item in features)
-        )
 
         batch_hidden_states = torch.stack(
             [
-                self.paddingtensor2d(item["base_model_hidden_states"], max_hs_length)
+                self.paddingtensor2d(item["base_model_hidden_states"], self.max_length)
                 for item in features
             ]
         )
         batch_aux_hidden_states = torch.stack(
-            [self.paddingtensor2d(item["aux_hidden_states"], max_hs_length) for item in features]
+            [self.paddingtensor2d(item["aux_hidden_states"], self.max_length) for item in features]
         )
 
         batch = {
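Net effect of the patch: max_length becomes a required constructor argument, and both the token-level fields (input_ids, attention_mask, loss_mask, labels) and the offline hidden states are padded to that one fixed length. The old fallback padded to the per-batch maximum, computed separately for the token fields and for base_model_hidden_states, so the two paths could end up padded to different lengths; with a single self.max_length they cannot, and every batch gets the same static shape. Below is a minimal, self-contained sketch of the resulting contract, assuming paddingtensor zero-pads along dim 0; the helper name pad_to and the value 2048 are illustrative, not from the patch.

import torch


def pad_to(t: torch.Tensor, length: int) -> torch.Tensor:
    # Zero-pad a 1-D tensor up to `length` (hypothetical helper mirroring
    # what paddingtensor is used for in the diff).
    out = torch.zeros(length, dtype=t.dtype)
    out[: t.shape[0]] = t
    return out


max_length = 2048  # now required; no per-batch fallback
features = [torch.arange(n) for n in (17, 512)]  # ragged sequence lengths
batch = torch.stack([pad_to(f, max_length) for f in features])
assert batch.shape == (2, 2048)  # every batch shares the same fixed shape

Fixed shapes across batches are what you want when stacking precomputed base-model hidden states next to freshly tokenized fields, at the cost of always paying for padding up to max_length even on short batches.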