Skip to content

Commit 615f3c0

Browse files
authored
Efficient Eagle3 training with eagle KV cache and flex attention (NVIDIA#350)
Signed-off-by: h-guo18 <[email protected]>
1 parent 2b13b67 commit 615f3c0

File tree

4 files changed

+109
-442
lines changed

4 files changed

+109
-442
lines changed

examples/speculative_decoding/eagle_utils.py

Lines changed: 15 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -236,7 +236,10 @@ def __getitem__(self, i) -> dict[str, torch.Tensor]:
236236

237237

238238
def make_eagle_supervised_data_module(
239-
tokenizer: transformers.PreTrainedTokenizer, data_args, use_offline_training: bool
239+
tokenizer: transformers.PreTrainedTokenizer,
240+
data_args,
241+
use_offline_training: bool,
242+
max_length=None,
240243
) -> dict:
241244
"""Make dataset and collator for supervised fine-tuning.
242245
@@ -295,15 +298,15 @@ def make_eagle_supervised_data_module(
295298
train_dataset = dataset_cls(valid_entries[:num_train], tokenizer=tokenizer)
296299
eval_dataset = dataset_cls(valid_entries[num_train:], tokenizer=tokenizer)
297300

298-
data_collator = DataCollatorForOffline()
301+
data_collator = DataCollatorForOffline(max_length=max_length)
299302
else:
300303
print_rank_0("Loading input conversations...")
301304
dataset_cls = LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset
302305

303306
train_dataset = dataset_cls(data_json[: int(len(data_json) * 0.95)], tokenizer=tokenizer)
304307
eval_dataset = dataset_cls(data_json[int(len(data_json) * 0.95) :], tokenizer=tokenizer)
305308

306-
data_collator = DataCollatorWithPadding()
309+
data_collator = DataCollatorWithPadding(max_length=max_length)
307310

308311
return {
309312
"train_dataset": train_dataset,
@@ -313,6 +316,9 @@ def make_eagle_supervised_data_module(
313316

314317

315318
class DataCollatorWithPadding:
319+
def __init__(self, max_length):
320+
self.max_length = max_length
321+
316322
def paddingtensor2d(self, intensors, length):
317323
n, dim = intensors.shape
318324
padding_tensor = torch.zeros(length - n, dim, dtype=intensors.dtype)
@@ -325,19 +331,18 @@ def paddingtensor(self, intensors, length):
325331
return outtensors
326332

327333
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
328-
max_length = max(item["input_ids"].shape[0] for item in features)
329334
batch_input_ids = torch.stack(
330-
[self.paddingtensor(item["input_ids"], max_length) for item in features]
335+
[self.paddingtensor(item["input_ids"], self.max_length) for item in features]
331336
)
332337
batch_attention_mask = torch.stack(
333-
[self.paddingtensor(item["attention_mask"], max_length) for item in features]
338+
[self.paddingtensor(item["attention_mask"], self.max_length) for item in features]
334339
)
335340
batch_loss_mask = torch.stack(
336-
[self.paddingtensor(item["loss_mask"], max_length) for item in features]
341+
[self.paddingtensor(item["loss_mask"], self.max_length) for item in features]
337342
)
338343

339344
batch_labels = torch.stack(
340-
[self.paddingtensor(item["labels"], max_length) for item in features]
345+
[self.paddingtensor(item["labels"], self.max_length) for item in features]
341346
)
342347

343348
batch = {
@@ -357,16 +362,15 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
357362
raise ValueError("No kwargs found in batch features. Offline data required.")
358363

359364
features = [item["kwargs"]["base_model_outputs"] for item in features]
360-
max_hs_length = max(item["base_model_hidden_states"].shape[0] for item in features)
361365

362366
batch_hidden_states = torch.stack(
363367
[
364-
self.paddingtensor2d(item["base_model_hidden_states"], max_hs_length)
368+
self.paddingtensor2d(item["base_model_hidden_states"], self.max_length)
365369
for item in features
366370
]
367371
)
368372
batch_aux_hidden_states = torch.stack(
369-
[self.paddingtensor2d(item["aux_hidden_states"], max_hs_length) for item in features]
373+
[self.paddingtensor2d(item["aux_hidden_states"], self.max_length) for item in features]
370374
)
371375

372376
batch = {

examples/speculative_decoding/main.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -227,7 +227,9 @@ def train():
227227
if training_args.mode == "medusa":
228228
data_module = make_medusa_supervised_data_module(tokenizer, data_args)
229229
elif training_args.mode in ["eagle1", "eagle3"]:
230-
data_module = make_eagle_supervised_data_module(tokenizer, data_args, use_offline_training)
230+
data_module = make_eagle_supervised_data_module(
231+
tokenizer, data_args, use_offline_training, max_length=training_args.training_seq_len
232+
)
231233

232234
class ARValidationCallback(TrainerCallback):
233235
def __init__(self, ar_validate_steps: int = 500):

0 commit comments

Comments
 (0)