@@ -316,7 +316,7 @@ def make_eagle_supervised_data_module(
 
 
 class DataCollatorWithPadding:
-    def __init__(self, max_length=None):
+    def __init__(self, max_length):
         self.max_length = max_length
 
     def paddingtensor2d(self, intensors, length):
@@ -331,23 +331,18 @@ def paddingtensor(self, intensors, length):
         return outtensors
 
     def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
-        max_length = (
-            self.max_length
-            if self.max_length is not None
-            else max(item["input_ids"].shape[0] for item in features)
-        )
         batch_input_ids = torch.stack(
-            [self.paddingtensor(item["input_ids"], max_length) for item in features]
+            [self.paddingtensor(item["input_ids"], self.max_length) for item in features]
         )
         batch_attention_mask = torch.stack(
-            [self.paddingtensor(item["attention_mask"], max_length) for item in features]
+            [self.paddingtensor(item["attention_mask"], self.max_length) for item in features]
         )
         batch_loss_mask = torch.stack(
-            [self.paddingtensor(item["loss_mask"], max_length) for item in features]
+            [self.paddingtensor(item["loss_mask"], self.max_length) for item in features]
        )
 
         batch_labels = torch.stack(
-            [self.paddingtensor(item["labels"], max_length) for item in features]
+            [self.paddingtensor(item["labels"], self.max_length) for item in features]
         )
 
         batch = {
@@ -367,20 +362,15 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
             raise ValueError("No kwargs found in batch features. Offline data required.")
 
         features = [item["kwargs"]["base_model_outputs"] for item in features]
-        max_hs_length = (
-            self.max_length
-            if self.max_length is not None
-            else max(item["base_model_hidden_states"].shape[0] for item in features)
-        )
 
         batch_hidden_states = torch.stack(
             [
-                self.paddingtensor2d(item["base_model_hidden_states"], max_hs_length)
+                self.paddingtensor2d(item["base_model_hidden_states"], self.max_length)
                 for item in features
             ]
         )
         batch_aux_hidden_states = torch.stack(
-            [self.paddingtensor2d(item["aux_hidden_states"], max_hs_length) for item in features]
+            [self.paddingtensor2d(item["aux_hidden_states"], self.max_length) for item in features]
         )
 
         batch = {
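
The practical effect of this diff is that DataCollatorWithPadding no longer infers a per-batch padding length from the features; max_length is now a required constructor argument and every field is padded to that fixed length. Below is a minimal, self-contained sketch (not from the diff) illustrating the new behavior; the one-dimensional paddingtensor helper and the toy feature tensors are illustrative stand-ins, not the repository's implementation.

import torch

def paddingtensor(intensors, length):
    # Zero-pad a 1-D tensor up to `length` (assumed stand-in for the
    # collator's helper of the same name).
    padding = torch.zeros(length - intensors.shape[0], dtype=intensors.dtype)
    return torch.cat((intensors, padding))

features = [
    {"input_ids": torch.tensor([5, 6, 7])},
    {"input_ids": torch.tensor([8, 9])},
]

# Fixed length supplied by the caller (e.g. from the training config),
# rather than max(item["input_ids"].shape[0] for item in features).
max_length = 4
batch_input_ids = torch.stack(
    [paddingtensor(f["input_ids"], max_length) for f in features]
)
print(batch_input_ids.shape)  # torch.Size([2, 4])

A consequence worth noting: every batch now has the same shape, which keeps hidden-state tensors from offline base-model outputs aligned with the token tensors, at the cost of padding short batches beyond their natural maximum.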