Commit 94cbb2a

enforce max length in data collator
Signed-off-by: h-guo18 <[email protected]>
1 parent 67cc30b · commit 94cbb2a

1 file changed: +7, -17 lines changed


examples/speculative_decoding/eagle_utils.py

Lines changed: 7 additions & 17 deletions
@@ -316,7 +316,7 @@ def make_eagle_supervised_data_module(
 
 
 class DataCollatorWithPadding:
-    def __init__(self, max_length=None):
+    def __init__(self, max_length):
         self.max_length = max_length
 
     def paddingtensor2d(self, intensors, length):
@@ -331,23 +331,18 @@ def paddingtensor(self, intensors, length):
         return outtensors
 
     def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
-        max_length = (
-            self.max_length
-            if self.max_length is not None
-            else max(item["input_ids"].shape[0] for item in features)
-        )
         batch_input_ids = torch.stack(
-            [self.paddingtensor(item["input_ids"], max_length) for item in features]
+            [self.paddingtensor(item["input_ids"], self.max_length) for item in features]
         )
         batch_attention_mask = torch.stack(
-            [self.paddingtensor(item["attention_mask"], max_length) for item in features]
+            [self.paddingtensor(item["attention_mask"], self.max_length) for item in features]
         )
         batch_loss_mask = torch.stack(
-            [self.paddingtensor(item["loss_mask"], max_length) for item in features]
+            [self.paddingtensor(item["loss_mask"], self.max_length) for item in features]
         )
 
         batch_labels = torch.stack(
-            [self.paddingtensor(item["labels"], max_length) for item in features]
+            [self.paddingtensor(item["labels"], self.max_length) for item in features]
         )
 
         batch = {
@@ -367,20 +362,15 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
             raise ValueError("No kwargs found in batch features. Offline data required.")
 
         features = [item["kwargs"]["base_model_outputs"] for item in features]
-        max_hs_length = (
-            self.max_length
-            if self.max_length is not None
-            else max(item["base_model_hidden_states"].shape[0] for item in features)
-        )
 
         batch_hidden_states = torch.stack(
             [
-                self.paddingtensor2d(item["base_model_hidden_states"], max_hs_length)
+                self.paddingtensor2d(item["base_model_hidden_states"], self.max_length)
                 for item in features
            ]
         )
         batch_aux_hidden_states = torch.stack(
-            [self.paddingtensor2d(item["aux_hidden_states"], max_hs_length) for item in features]
+            [self.paddingtensor2d(item["aux_hidden_states"], self.max_length) for item in features]
         )
 
         batch = {
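
A minimal usage sketch of the post-change behavior (the padding helper, the feature values, and the max_length below are illustrative stand-ins, not code from this repo): with the batch-derived fallback removed, every field is padded to the fixed max_length supplied at construction rather than to the longest sequence in the batch, so batch shapes stay constant across steps and inputs are assumed to already fit within max_length.

import torch

def pad_1d(t: torch.Tensor, length: int) -> torch.Tensor:
    # Stand-in for the collator's paddingtensor: right-pad a 1-D tensor with zeros up to `length`.
    return torch.cat([t, torch.zeros(length - t.shape[0], dtype=t.dtype)])

features = [
    {"input_ids": torch.tensor([1, 2, 3])},
    {"input_ids": torch.tensor([4, 5])},
]
max_length = 8  # fixed value a caller would now pass as DataCollatorWithPadding(max_length=8)
batch_input_ids = torch.stack([pad_1d(f["input_ids"], max_length) for f in features])
print(batch_input_ids.shape)  # torch.Size([2, 8]), regardless of the longest item in the batch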
