examples/speculative_decoding/eagle_utils.py (15 additions, 11 deletions)
```diff
@@ -236,7 +236,10 @@ def __getitem__(self, i) -> dict[str, torch.Tensor]:
 
 
 def make_eagle_supervised_data_module(
-    tokenizer: transformers.PreTrainedTokenizer, data_args, use_offline_training: bool
+    tokenizer: transformers.PreTrainedTokenizer,
+    data_args,
+    use_offline_training: bool,
+    max_length=None,
 ) -> dict:
     """Make dataset and collator for supervised fine-tuning.
 
@@ -295,15 +298,15 @@ def make_eagle_supervised_data_module(
         train_dataset = dataset_cls(valid_entries[:num_train], tokenizer=tokenizer)
         eval_dataset = dataset_cls(valid_entries[num_train:], tokenizer=tokenizer)
 
-        data_collator = DataCollatorForOffline()
+        data_collator = DataCollatorForOffline(max_length=max_length)
     else:
         print_rank_0("Loading input conversations...")
         dataset_cls = LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset
 
         train_dataset = dataset_cls(data_json[: int(len(data_json) * 0.95)], tokenizer=tokenizer)
         eval_dataset = dataset_cls(data_json[int(len(data_json) * 0.95) :], tokenizer=tokenizer)
 
-        data_collator = DataCollatorWithPadding()
+        data_collator = DataCollatorWithPadding(max_length=max_length)
 
     return {
         "train_dataset": train_dataset,
@@ -313,6 +316,9 @@
 
 
 class DataCollatorWithPadding:
+    def __init__(self, max_length):
+        self.max_length = max_length
+
     def paddingtensor2d(self, intensors, length):
         n, dim = intensors.shape
         padding_tensor = torch.zeros(length - n, dim, dtype=intensors.dtype)
@@ -325,19 +331,18 @@ def paddingtensor(self, intensors, length):
         return outtensors
 
     def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
-        max_length = max(item["input_ids"].shape[0] for item in features)
         batch_input_ids = torch.stack(
-            [self.paddingtensor(item["input_ids"], max_length) for item in features]
+            [self.paddingtensor(item["input_ids"], self.max_length) for item in features]
         )
         batch_attention_mask = torch.stack(
-            [self.paddingtensor(item["attention_mask"], max_length) for item in features]
+            [self.paddingtensor(item["attention_mask"], self.max_length) for item in features]
        )
         batch_loss_mask = torch.stack(
-            [self.paddingtensor(item["loss_mask"], max_length) for item in features]
+            [self.paddingtensor(item["loss_mask"], self.max_length) for item in features]
         )
 
         batch_labels = torch.stack(
-            [self.paddingtensor(item["labels"], max_length) for item in features]
+            [self.paddingtensor(item["labels"], self.max_length) for item in features]
         )
 
         batch = {
@@ -357,16 +362,15 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
             raise ValueError("No kwargs found in batch features. Offline data required.")
 
         features = [item["kwargs"]["base_model_outputs"] for item in features]
-        max_hs_length = max(item["base_model_hidden_states"].shape[0] for item in features)
 
         batch_hidden_states = torch.stack(
             [
-                self.paddingtensor2d(item["base_model_hidden_states"], max_hs_length)
+                self.paddingtensor2d(item["base_model_hidden_states"], self.max_length)
                 for item in features
             ]
         )
         batch_aux_hidden_states = torch.stack(
-            [self.paddingtensor2d(item["aux_hidden_states"], max_hs_length) for item in features]
+            [self.paddingtensor2d(item["aux_hidden_states"], self.max_length) for item in features]
         )
 
         batch = {
```
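Taken together, the eagle_utils.py changes swap per-batch dynamic padding (the removed `max(...)` computations) for padding to a single caller-supplied `max_length`, so every batch comes out with identical tensor shapes. Below is a minimal, self-contained sketch of that behavior; `FixedLengthCollator` and the toy inputs are hypothetical illustrations, not code from this PR.

```python
# Hypothetical sketch of the fixed-length padding the new collators adopt.
# FixedLengthCollator is a stand-in for the PR's DataCollatorWithPadding.
from typing import Any

import torch


class FixedLengthCollator:
    def __init__(self, max_length: int):
        self.max_length = max_length

    def paddingtensor(self, intensors: torch.Tensor, length: int) -> torch.Tensor:
        # Right-pad a 1-D tensor with zeros up to `length`.
        padding = torch.zeros(length - intensors.shape[0], dtype=intensors.dtype)
        return torch.cat((intensors, padding))

    def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
        # Pad to self.max_length rather than to the longest item in the
        # batch, so shapes stay constant across batches.
        return {
            "input_ids": torch.stack(
                [self.paddingtensor(f["input_ids"], self.max_length) for f in features]
            )
        }


collator = FixedLengthCollator(max_length=8)
batch = collator([{"input_ids": torch.tensor([1, 2, 3])}, {"input_ids": torch.tensor([4, 5])}])
print(batch["input_ids"].shape)  # torch.Size([2, 8]), independent of batch contents
```

Note that with a fixed target length, any item longer than `max_length` would make `torch.zeros(length - n, ...)` fail on a negative size, so upstream preprocessing presumably truncates sequences to the training length first.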
examples/speculative_decoding/main.py (3 additions, 1 deletion)
```diff
@@ -227,7 +227,9 @@ def train():
     if training_args.mode == "medusa":
         data_module = make_medusa_supervised_data_module(tokenizer, data_args)
     elif training_args.mode in ["eagle1", "eagle3"]:
-        data_module = make_eagle_supervised_data_module(tokenizer, data_args, use_offline_training)
+        data_module = make_eagle_supervised_data_module(
+            tokenizer, data_args, use_offline_training, max_length=training_args.training_seq_len
+        )
 
     class ARValidationCallback(TrainerCallback):
         def __init__(self, ar_validate_steps: int = 500):
```
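The `max_length` passed to the data module is read from `training_args.training_seq_len`. As a hypothetical sketch of where such a field could be declared, following the usual HF-style dataclass pattern; the actual definition in main.py may differ:

```python
# Hypothetical: one way a training_seq_len argument could be declared.
# The real TrainingArguments dataclass in main.py may differ.
from dataclasses import dataclass, field

import transformers


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    mode: str = field(default="eagle3", metadata={"help": "medusa, eagle1, or eagle3"})
    training_seq_len: int = field(
        default=2048, metadata={"help": "fixed length every batch is padded to"}
    )
```

Padding every batch to one configured length keeps memory usage predictable and gives the training loop static tensor shapes, at the cost of some wasted compute on short batches; it also assumes preprocessing keeps sequences within `training_seq_len`.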