Skip to content

Commit 0afd65a

Browse files
committed
add max_length in data collator
Signed-off-by: h-guo18 <[email protected]>
1 parent 5c1fea8 commit 0afd65a

File tree

3 files changed

+26
-7
lines changed

3 files changed

+26
-7
lines changed

examples/speculative_decoding/eagle_utils.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,10 @@ def __getitem__(self, i) -> dict[str, torch.Tensor]:
236236

237237

238238
def make_eagle_supervised_data_module(
239-
tokenizer: transformers.PreTrainedTokenizer, data_args, use_offline_training: bool
239+
tokenizer: transformers.PreTrainedTokenizer,
240+
data_args,
241+
use_offline_training: bool,
242+
max_length=None,
240243
) -> dict:
241244
"""Make dataset and collator for supervised fine-tuning.
242245
@@ -295,15 +298,15 @@ def make_eagle_supervised_data_module(
295298
train_dataset = dataset_cls(valid_entries[:num_train], tokenizer=tokenizer)
296299
eval_dataset = dataset_cls(valid_entries[num_train:], tokenizer=tokenizer)
297300

298-
data_collator = DataCollatorForOffline()
301+
data_collator = DataCollatorForOffline(max_length=max_length)
299302
else:
300303
print_rank_0("Loading input conversations...")
301304
dataset_cls = LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset
302305

303306
train_dataset = dataset_cls(data_json[: int(len(data_json) * 0.95)], tokenizer=tokenizer)
304307
eval_dataset = dataset_cls(data_json[int(len(data_json) * 0.95) :], tokenizer=tokenizer)
305308

306-
data_collator = DataCollatorWithPadding()
309+
data_collator = DataCollatorWithPadding(max_length=max_length)
307310

308311
return {
309312
"train_dataset": train_dataset,
@@ -313,6 +316,9 @@ def make_eagle_supervised_data_module(
313316

314317

315318
class DataCollatorWithPadding:
319+
def __init__(self, max_length=None):
320+
self.max_length = max_length
321+
316322
def paddingtensor2d(self, intensors, length):
317323
n, dim = intensors.shape
318324
padding_tensor = torch.zeros(length - n, dim, dtype=intensors.dtype)
@@ -325,7 +331,11 @@ def paddingtensor(self, intensors, length):
325331
return outtensors
326332

327333
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
328-
max_length = max(item["input_ids"].shape[0] for item in features)
334+
max_length = (
335+
self.max_length
336+
if self.max_length is not None
337+
else max(item["input_ids"].shape[0] for item in features)
338+
)
329339
batch_input_ids = torch.stack(
330340
[self.paddingtensor(item["input_ids"], max_length) for item in features]
331341
)
@@ -351,13 +361,20 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
351361

352362

353363
class DataCollatorForOffline(DataCollatorWithPadding):
364+
def __init__(self, max_length=None):
365+
super().__init__(max_length=max_length)
366+
354367
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
355368
base_batch = super().__call__(features)
356369
if "kwargs" not in features[0]:
357370
raise ValueError("No kwargs found in batch features. Offline data required.")
358371

359372
features = [item["kwargs"]["base_model_outputs"] for item in features]
360-
max_hs_length = max(item["base_model_hidden_states"].shape[0] for item in features)
373+
max_hs_length = (
374+
self.max_length
375+
if self.max_length is not None
376+
else max(item["base_model_hidden_states"].shape[0] for item in features)
377+
)
361378

362379
batch_hidden_states = torch.stack(
363380
[

examples/speculative_decoding/main.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,9 @@ def train():
229229
if training_args.mode == "medusa":
230230
data_module = make_medusa_supervised_data_module(tokenizer, data_args)
231231
elif training_args.mode in ["eagle1", "eagle3"]:
232-
data_module = make_eagle_supervised_data_module(tokenizer, data_args, use_offline_training)
232+
data_module = make_eagle_supervised_data_module(
233+
tokenizer, data_args, use_offline_training, max_length=training_args.training_seq_len
234+
)
233235

234236
class ARValidationCallback(TrainerCallback):
235237
def __init__(self, ar_validate_steps: int = 500):

modelopt/torch/speculative/plugins/transformers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ def forward(
316316
hidden_states,
317317
attention_mask=attention_mask,
318318
position_ids=position_ids,
319-
past_key_values=past_key_values,
319+
past_key_value=past_key_values,
320320
output_attentions=output_attentions,
321321
use_cache=use_cache,
322322
position_embeddings=position_embeddings,

0 commit comments

Comments (0)