2 changes: 1 addition & 1 deletion examples/speculative_decoding/README.md
@@ -312,7 +312,7 @@ trainer.save_model("<path to the output directory>")
| LLAMA 3, 3.1 | ✅ | ✅ | ✅ |
| Mistral | ✅ | ✅ | ✅ |
| Phi 3 | ✅ | ✅ | ✅ |
-| QWen 1.5,2,2.5 | ✅ | ✅ | ✅ |
+| QWen 1.5,2,2.5,3 | ✅ | ✅ | ✅ |

## Speculation Module Checkpoints

@@ -208,13 +208,16 @@ def keep_conversation(entry):
num_success = 0
pbar = tqdm(total=len(dataset), desc=f"DP#{args.dp_rank} Processing conversations")

-def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id: int):
-"""Post-process the TRTLLM dumped file to same format as HF dumped:
+async def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id: int):
+"""
+Post-process the TRTLLM dumped file to same format as HF dumped:
1. Remove id field, replace it with conversation_id
2. Rename hidden_state field to hidden_states
3. From list of length 1 to dict
4. Rename file to conversation_id.pt
"""
+if not trtllm_dumped_file.exists():
+return False
with open(trtllm_dumped_file, "rb") as f:
trtllm_dumped = torch.load(f)
assert isinstance(trtllm_dumped, list) and len(trtllm_dumped) == 1, (
@@ -232,35 +235,33 @@ def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id: int):
output_file = args.output_dir / f"{conversation_id}.pt"
with open(output_file, "wb") as f:
torch.save(trtllm_dumped, f)

-if trtllm_dumped_file.exists():
-trtllm_dumped_file.unlink()
+trtllm_dumped_file.unlink()
+return True

async def dump_hidden_states(idx: int, conversation_id: int, input_ids: list[int]):
nonlocal num_success
await llm_spec.generate_async(input_ids, sampling_params)
# The TRTLLM API names dump files starting from 1
# ref: https://github.com/NVIDIA/TensorRT-LLM/pull/7012
trtllm_dumped_file = args.output_dir / f"{spec_config['file_prefix']}_{idx + 1}.pt"
-_post_process_trtllm_dumped(trtllm_dumped_file, conversation_id)
-num_success += 1
+dump_success = await _post_process_trtllm_dumped(trtllm_dumped_file, conversation_id)
+num_success += int(dump_success)
pbar.update(1)

async def submit_generates():
nonlocal num_skipped_too_long
nonlocal num_invalid
tasks = []
-for idx, entry in enumerate(dataset):
+idx = 0
+for entry in dataset:
conversation_id = entry.get("conversation_id", entry.get("uuid"))

conversations = entry["conversations"]
if not conversations or not isinstance(conversations, list):
num_invalid += 1
continue

-input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)[
-:256
-]
+input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)
num_input_tokens = (
input_ids.shape[1] if isinstance(input_ids, torch.Tensor) else len(input_ids)
)
@@ -269,6 +270,8 @@ async def submit_generates():
continue

tasks.append(dump_hidden_states(idx, conversation_id, input_ids))
+# Increment only for valid conversations to match dump file index
+idx += 1
await asyncio.gather(*tasks)

asyncio.run(submit_generates())
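
For readers skimming the diff, the concurrency pattern the new code adopts is: fire off one async generation task per conversation, post-process each dump as it completes, and count only the dumps that succeed. Below is a minimal, self-contained sketch of that pattern; fake_generate, post_process, and the dump_ file prefix are stand-ins, not names from the example script.

import asyncio
from pathlib import Path

async def fake_generate(output_dir: Path, idx: int) -> None:
    # Stand-in for the engine call that writes <prefix>_<idx + 1>.pt
    await asyncio.sleep(0)
    (output_dir / f"dump_{idx + 1}.pt").write_bytes(b"placeholder")

async def post_process(dumped_file: Path, conversation_id: int, output_dir: Path) -> bool:
    if not dumped_file.exists():
        return False  # generation produced nothing; report failure
    # Rename the dump so it is keyed by conversation ID, as in the diff above
    dumped_file.rename(output_dir / f"{conversation_id}.pt")
    return True

async def run_all(output_dir: Path, conversation_ids: list[int]) -> int:
    num_success = 0

    async def one(idx: int, conv_id: int) -> None:
        nonlocal num_success
        await fake_generate(output_dir, idx)
        ok = await post_process(output_dir / f"dump_{idx + 1}.pt", conv_id, output_dir)
        num_success += int(ok)

    await asyncio.gather(*(one(i, c) for i, c in enumerate(conversation_ids)))
    return num_success

out = Path("/tmp/dumps")
out.mkdir(exist_ok=True)
print(asyncio.run(run_all(out, [101, 202, 303])))  # -> 3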
33 changes: 19 additions & 14 deletions examples/speculative_decoding/eagle_utils.py
@@ -47,7 +47,6 @@ def preprocess(examples, tokenizer):
"loss_mask": [],
"labels": [],
}
roles = ["user", "assistant"]
for i in range(len(examples)):
messages = []
source = examples[i]["conversations"]
@@ -61,13 +60,8 @@ def get_role_content(item):
else:
raise ValueError(f"Unknown conversation format: {item}")

-first_role, _ = get_role_content(source[0])
-if first_role.lower() != "user":
-# Skip the first one if it is not from human
-source = source[1:]
-for j, sentence in enumerate(source):
+for sentence in source:
role, content = get_role_content(sentence)
-assert role.lower() == roles[j % 2], f"{i}"
messages.append({"role": role.lower(), "content": content})
conversation = tokenizer.apply_chat_template(
messages,
@@ -259,11 +253,20 @@ def make_eagle_supervised_data_module(
dict: A dictionary containing train and eval datasets.
"""
# Load the conversations from the source file
-with open(data_args.data_path) as f:
-if data_args.data_path.endswith("jsonl"):
-data_json = [json.loads(line) for line in f]
-else:
-data_json = json.load(f)
+print_rank_0("Loading input conversations...")
+data_json = []
+data_path_p = Path(data_args.data_path)
+if data_path_p.is_dir():
+# Load all .jsonl files in the directory and combine them
+for jsonl_file in sorted(data_path_p.glob("*.jsonl")):
+with open(jsonl_file) as f:
+data_json.extend(json.loads(line) for line in f)
+else:
+with open(data_args.data_path) as f:
+if data_args.data_path.endswith("jsonl"):
+data_json = [json.loads(line) for line in f]
+else:
+data_json = json.load(f)

if use_offline_training:
print_rank_0("Loading pre-processed data for offline training...")
Expand All @@ -280,12 +283,14 @@ def make_eagle_supervised_data_module(

# Filter to conversations that exist in the offline data and in the provided json
valid_entries = []
-for idx, entry in enumerate(data_json):
+for entry in data_json:
conv_id = entry.get("conversation_id")
if conv_id is None:
conv_id = entry.get("uuid")
if conv_id is None:
conv_id = entry.get("id")
if conv_id is None:
conv_id = "{:08d}".format(idx)
raise ValueError(f"Conversation ID required but not found for entry {entry}")
file_path = str(offline_data_path / f"{conv_id}.pt")
if file_path in all_files:
valid_entries.append((entry, file_path))
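
The two behaviors this file's diff adds — accepting a directory of .jsonl shards as the data path, and failing loudly when a conversation ID is missing instead of inventing a positional one — can be summarized in a short sketch. load_conversations and resolve_conversation_id are illustrative names, not functions in eagle_utils.py.

import json
from pathlib import Path

def load_conversations(data_path: str) -> list[dict]:
    path = Path(data_path)
    if path.is_dir():
        # Combine every .jsonl shard in the directory, in sorted order
        data = []
        for shard in sorted(path.glob("*.jsonl")):
            with open(shard) as f:
                data.extend(json.loads(line) for line in f)
        return data
    with open(path) as f:
        if data_path.endswith("jsonl"):
            return [json.loads(line) for line in f]
        return json.load(f)

def resolve_conversation_id(entry: dict) -> str:
    # Try the known ID fields in order; refuse to invent one, since offline
    # training matches each entry to its <conv_id>.pt dump file by this ID.
    for key in ("conversation_id", "uuid", "id"):
        if entry.get(key) is not None:
            return str(entry[key])
    raise ValueError(f"Conversation ID required but not found for entry {entry}")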
6 changes: 6 additions & 0 deletions examples/speculative_decoding/launch_train.sh
@@ -78,6 +78,10 @@ while [ $# -gt 0 ]; do
if [[ "$1" != *=* ]]; then shift; fi
NUM_GPU="${1#*=}"
;;
+--disable_tqdm*)
+if [[ "$1" != *=* ]]; then shift; fi
+DISABLE_TQDM="${1#*=}"
+;;
*)
>&2 printf "Error: Invalid argument ${1#*=}\n"
exit 1
@@ -110,6 +114,7 @@ FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaD
NUM_GPU=${NUM_GPU:-1}
TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048}
OFFLINE_DATA_PATH=${OFFLINE_DATA_PATH:-""}
+DISABLE_TQDM=${DISABLE_TQDM:-False}

if [[ "$MODE" == "medusa" ]]; then
SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS"
@@ -165,6 +170,7 @@ CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
--logging_steps 100 \
--tf32 True \
--data_path $DATA \
+--disable_tqdm $DISABLE_TQDM \
$OFFLINE_TRAINING_ARGS \
$SPECULATIVE_ARGS
"
1 change: 1 addition & 0 deletions examples/speculative_decoding/main.py
@@ -92,6 +92,7 @@ class TrainingArguments(transformers.TrainingArguments):
bf16: bool = field(default=True)
mode: Literal["eagle1", "eagle3", "medusa"] = "eagle3"
ar_validate_steps: int = field(default=1000, metadata={"help": "Steps between AR validation."})
+disable_tqdm: bool = field(default=False, metadata={"help": "Disable tqdm progress bar."})


@dataclass
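
To see how the new dataclass field surfaces as a command-line flag, here is a minimal sketch using HfArgumentParser. The TrainingArguments below is a stand-in dataclass for illustration, not the transformers.TrainingArguments subclass from main.py.

from dataclasses import dataclass, field

from transformers import HfArgumentParser

@dataclass
class TrainingArguments:
    ar_validate_steps: int = field(default=1000, metadata={"help": "Steps between AR validation."})
    disable_tqdm: bool = field(default=False, metadata={"help": "Disable tqdm progress bar."})

parser = HfArgumentParser(TrainingArguments)
# HfArgumentParser turns each field into a flag; bool fields accept True/False
(args,) = parser.parse_args_into_dataclasses(["--disable_tqdm", "True"])
print(args.disable_tqdm)  # True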