From 313d3b1a6652764cddf6714d909e4182c6932c57 Mon Sep 17 00:00:00 2001
From: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Date: Tue, 21 Oct 2025 20:30:26 +0000
Subject: [PATCH] fix:eagle3 offline

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
---
 examples/speculative_decoding/README.md       |  2 +-
 .../compute_hidden_states_trtllm.py           | 25 +++++++-------
 examples/speculative_decoding/eagle_utils.py  | 33 +++++++++++--------
 examples/speculative_decoding/launch_train.sh |  6 ++++
 examples/speculative_decoding/main.py         |  1 +
 5 files changed, 41 insertions(+), 26 deletions(-)

diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md
index 2936b1d5a..effdebe02 100644
--- a/examples/speculative_decoding/README.md
+++ b/examples/speculative_decoding/README.md
@@ -312,7 +312,7 @@ trainer.save_model("")
 | LLAMA 3, 3.1 | ✅ | ✅ | ✅ |
 | Mistral | ✅ | ✅ | ✅ |
 | Phi 3 | ✅ | ✅ | ✅ |
-| QWen 1.5,2,2.5 | ✅ | ✅ | ✅ |
+| QWen 1.5,2,2.5,3 | ✅ | ✅ | ✅ |

 ## Speculation Module Checkpoints

diff --git a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py
index 6b088a3a6..0bf68e430 100644
--- a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py
+++ b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py
@@ -208,13 +208,16 @@ def keep_conversation(entry):
     num_success = 0
     pbar = tqdm(total=len(dataset), desc=f"DP#{args.dp_rank} Processing conversations")

-    def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id: int):
-        """Post-process the TRTLLM dumped file to same format as HF dumped:
+    async def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id: int):
+        """
+        Post-process the TRTLLM dumped file to the same format as the HF dump:
         1. Remove id field, replace it with conversation_id
         2. Rename hidden_state field to hidden_states
         3. From list of length 1 to dict
         4. Rename file to conversation_id.pt
         """
+        if not trtllm_dumped_file.exists():
+            return False
         with open(trtllm_dumped_file, "rb") as f:
             trtllm_dumped = torch.load(f)
         assert isinstance(trtllm_dumped, list) and len(trtllm_dumped) == 1, (
@@ -232,9 +235,8 @@ def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id: int):
         output_file = args.output_dir / f"{conversation_id}.pt"
         with open(output_file, "wb") as f:
             torch.save(trtllm_dumped, f)
-
-        if trtllm_dumped_file.exists():
-            trtllm_dumped_file.unlink()
+        trtllm_dumped_file.unlink()
+        return True

     async def dump_hidden_states(idx: int, conversation_id: int, input_ids: list[int]):
         nonlocal num_success
@@ -242,15 +244,16 @@ async def dump_hidden_states(idx: int, conversation_id: int, input_ids: list[int]):
         # TRTLLM API name files starts from 1
         # ref:https://github.com/NVIDIA/TensorRT-LLM/pull/7012
         trtllm_dumped_file = args.output_dir / f"{spec_config['file_prefix']}_{idx + 1}.pt"
-        _post_process_trtllm_dumped(trtllm_dumped_file, conversation_id)
-        num_success += 1
+        dump_success = await _post_process_trtllm_dumped(trtllm_dumped_file, conversation_id)
+        num_success += int(dump_success)
         pbar.update(1)

     async def submit_generates():
         nonlocal num_skipped_too_long
         nonlocal num_invalid
         tasks = []
-        for idx, entry in enumerate(dataset):
+        idx = 0
+        for entry in dataset:
             conversation_id = entry.get("conversation_id", entry.get("uuid"))
             conversations = entry["conversations"]

@@ -258,9 +261,7 @@ async def submit_generates():
                 num_invalid += 1
                 continue

-            input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)[
-                :256
-            ]
+            input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)
             num_input_tokens = (
                 input_ids.shape[1] if isinstance(input_ids, torch.Tensor) else len(input_ids)
             )
@@ -269,6 +270,8 @@ async def submit_generates():
                 continue

             tasks.append(dump_hidden_states(idx, conversation_id, input_ids))
+            # Increment only for valid conversations to match dump file index
+            idx += 1
         await asyncio.gather(*tasks)

     asyncio.run(submit_generates())
diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py
index 576179dd1..7d1820e28 100644
--- a/examples/speculative_decoding/eagle_utils.py
+++ b/examples/speculative_decoding/eagle_utils.py
@@ -47,7 +47,6 @@ def preprocess(examples, tokenizer):
         "loss_mask": [],
         "labels": [],
     }
-    roles = ["user", "assistant"]
     for i in range(len(examples)):
         messages = []
         source = examples[i]["conversations"]
@@ -61,13 +60,8 @@ def get_role_content(item):
             else:
                 raise ValueError(f"Unknown conversation format: {item}")

-        first_role, _ = get_role_content(source[0])
-        if first_role.lower() != "user":
-            # Skip the first one if it is not from human
-            source = source[1:]
-        for j, sentence in enumerate(source):
+        for sentence in source:
             role, content = get_role_content(sentence)
-            assert role.lower() == roles[j % 2], f"{i}"
             messages.append({"role": role.lower(), "content": content})
         conversation = tokenizer.apply_chat_template(
             messages,
@@ -259,11 +253,20 @@ def make_eagle_supervised_data_module(
         dict: A dictionary containing train and eval datasets.
""" # Load the conversations from the source file - with open(data_args.data_path) as f: - if data_args.data_path.endswith("jsonl"): - data_json = [json.loads(line) for line in f] - else: - data_json = json.load(f) + print_rank_0("Loading input conversations...") + data_json = [] + data_path_p = Path(data_args.data_path) + if data_path_p.is_dir(): + # Load all .jsonl files in the directory and combine them + for jsonl_file in sorted(data_path_p.glob("*.jsonl")): + with open(jsonl_file) as f: + data_json.extend(json.loads(line) for line in f) + else: + with open(data_args.data_path) as f: + if data_args.data_path.endswith("jsonl"): + data_json = [json.loads(line) for line in f] + else: + data_json = json.load(f) if use_offline_training: print_rank_0("Loading pre-processed data for offline training...") @@ -280,12 +283,14 @@ def make_eagle_supervised_data_module( # Filter to conversations that exist in the offline data and in the provided json valid_entries = [] - for idx, entry in enumerate(data_json): + for entry in data_json: conv_id = entry.get("conversation_id") + if conv_id is None: + conv_id = entry.get("uuid") if conv_id is None: conv_id = entry.get("id") if conv_id is None: - conv_id = "{:08d}".format(idx) + raise ValueError(f"Conversation ID required but not found for entry {entry}") file_path = str(offline_data_path / f"{conv_id}.pt") if file_path in all_files: valid_entries.append((entry, file_path)) diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index 2d0a4abe7..99fa89b3a 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -78,6 +78,10 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi NUM_GPU="${1#*=}" ;; + --disable_tqdm*) + if [[ "$1" != *=* ]]; then shift; fi + DISABLE_TQDM="${1#*=}" + ;; *) >&2 printf "Error: Invalid argument ${1#*=}\n" exit 1 @@ -110,6 +114,7 @@ FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaD NUM_GPU=${NUM_GPU:-1} TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048} OFFLINE_DATA_PATH=${OFFLINE_DATA_PATH:-""} +DISABLE_TQDM=${DISABLE_TQDM:-False} if [[ "$MODE" == "medusa" ]]; then SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS" @@ -165,6 +170,7 @@ CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \ --logging_steps 100 \ --tf32 True \ --data_path $DATA \ + --disable_tqdm $DISABLE_TQDM \ $OFFLINE_TRAINING_ARGS \ $SPECULATIVE_ARGS " diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 20242c795..c1f59ef88 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -92,6 +92,7 @@ class TrainingArguments(transformers.TrainingArguments): bf16: bool = field(default=True) mode: Literal["eagle1", "eagle3", "medusa"] = "eagle3" ar_validate_steps: int = field(default=1000, metadata={"help": "Steps between AR validation."}) + disable_tqdm: bool = field(default=False, metadata={"help": "Disable tqdm progress bar."}) @dataclass