Commit 2e822c6

add length pad for flex attn

Signed-off-by: h-guo18 <[email protected]>
Parent: f5835f9

3 files changed: +22 −6 lines changed
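The motivation, per the commit title, is FlexAttention: its block masks (and the compiled kernels behind them) are shape-specialized, so padding every batch to one fixed length lets a single mask be built once and reused, rather than rebuilt whenever the per-batch maximum length changes. A minimal illustration of that pattern (a sketch against PyTorch's torch.nn.attention.flex_attention API, available in recent PyTorch; the sizes here are invented):

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

SEQ_LEN = 2048  # fixed pad length, e.g. training_args.training_seq_len below

def causal(b, h, q_idx, kv_idx):
    # Standard causal mask: a query attends to itself and earlier positions.
    return q_idx >= kv_idx

# Built once for the fixed length; reusable because every batch is padded to SEQ_LEN.
block_mask = create_block_mask(
    causal, B=None, H=None, Q_LEN=SEQ_LEN, KV_LEN=SEQ_LEN, device="cpu"
)

q = k = v = torch.randn(2, 8, SEQ_LEN, 64)
out = flex_attention(q, k, v, block_mask=block_mask)  # (2, 8, SEQ_LEN, 64)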

examples/speculative_decoding/eagle_utils.py

Lines changed: 18 additions & 4 deletions

@@ -236,7 +236,10 @@ def __getitem__(self, i) -> dict[str, torch.Tensor]:
 
 
 def make_eagle_supervised_data_module(
-    tokenizer: transformers.PreTrainedTokenizer, data_args, use_offline_training: bool
+    tokenizer: transformers.PreTrainedTokenizer,
+    data_args,
+    use_offline_training: bool,
+    pad_length=None,
 ) -> dict:
     """Make dataset and collator for supervised fine-tuning.
 
@@ -303,7 +306,7 @@ def make_eagle_supervised_data_module(
     train_dataset = dataset_cls(data_json[: int(len(data_json) * 0.95)], tokenizer=tokenizer)
    eval_dataset = dataset_cls(data_json[int(len(data_json) * 0.95) :], tokenizer=tokenizer)
 
-    data_collator = DataCollatorWithPadding()
+    data_collator = DataCollatorWithPadding(pad_length=pad_length)
 
     return {
         "train_dataset": train_dataset,
@@ -313,6 +316,9 @@ def make_eagle_supervised_data_module(
 
 
 class DataCollatorWithPadding:
+    def __init__(self, pad_length=None):
+        self.pad_length = pad_length
+
     def paddingtensor2d(self, intensors, length):
         n, dim = intensors.shape
         padding_tensor = torch.zeros(length - n, dim, dtype=intensors.dtype)
@@ -325,7 +331,11 @@ def paddingtensor(self, intensors, length):
         return outtensors
 
     def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
-        max_length = max(item["input_ids"].shape[0] for item in features)
+        max_length = (
+            self.pad_length
+            if self.pad_length is not None
+            else max(item["input_ids"].shape[0] for item in features)
+        )
         batch_input_ids = torch.stack(
             [self.paddingtensor(item["input_ids"], max_length) for item in features]
         )
@@ -357,7 +367,11 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
             raise ValueError("No kwargs found in batch features. Offline data required.")
 
         features = [item["kwargs"]["base_model_outputs"] for item in features]
-        max_hs_length = max(item["base_model_hidden_states"].shape[0] for item in features)
+        max_hs_length = (
+            max(item["base_model_hidden_states"].shape[0] for item in features)
+            if self.pad_length is None
+            else self.pad_length
+        )
 
         batch_hidden_states = torch.stack(
             [
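In effect, the collator now pads to a caller-supplied fixed length instead of the per-batch maximum whenever pad_length is set. A self-contained sketch of that logic (a re-implementation for illustration, not the module's own code; the data is invented):

import torch

def pad_to_length(t: torch.Tensor, length: int) -> torch.Tensor:
    # 1-D analogue of paddingtensor: right-pad with zeros up to `length`.
    return torch.cat([t, torch.zeros(length - t.shape[0], dtype=t.dtype)])

pad_length = 2048  # fixed training length; None restores the old behavior
seqs = [torch.randint(0, 100, (517,)), torch.randint(0, 100, (901,))]

# Old behavior: pad to the batch max (here 901), so shapes vary per batch.
# New behavior: with pad_length set, every batch comes out (B, pad_length).
max_length = pad_length if pad_length is not None else max(s.shape[0] for s in seqs)
batch = torch.stack([pad_to_length(s, max_length) for s in seqs])
assert batch.shape == (2, 2048)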

examples/speculative_decoding/main.py

Lines changed: 3 additions & 1 deletion

@@ -229,7 +229,9 @@ def train():
     if training_args.mode == "medusa":
         data_module = make_medusa_supervised_data_module(tokenizer, data_args)
     elif training_args.mode in ["eagle1", "eagle3"]:
-        data_module = make_eagle_supervised_data_module(tokenizer, data_args, use_offline_training)
+        data_module = make_eagle_supervised_data_module(
+            tokenizer, data_args, use_offline_training, pad_length=training_args.training_seq_len
+        )
 
     class ARValidationCallback(TrainerCallback):
         def __init__(self, ar_validate_steps: int = 500):
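Here pad_length is wired to training_args.training_seq_len, so every batch the collator produces matches the fixed sequence length that the flex-attention mask is presumably built for.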

modelopt/torch/speculative/plugins/transformers.py

Lines changed: 1 addition & 1 deletion

@@ -316,7 +316,7 @@ def forward(
             hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
-            past_key_values=past_key_values,
+            past_key_value=past_key_values,
             output_attentions=output_attentions,
             use_cache=use_cache,
             position_embeddings=position_embeddings,
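This looks like a keyword fix rather than a behavior change: Hugging Face decoder layers (e.g. LlamaDecoderLayer.forward) take the singular past_key_value argument, so the plural keyword would not have been picked up as the cache.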
