diff --git a/examples/speculative_decoding/.gitignore b/examples/speculative_decoding/.gitignore
index 7095ed0a..e2226783 100644
--- a/examples/speculative_decoding/.gitignore
+++ b/examples/speculative_decoding/.gitignore
@@ -1 +1,4 @@
 Daring-Anteater
+input_conversations
+synthetic_conversations
+ckpts
\ No newline at end of file
diff --git a/examples/speculative_decoding/collect_hidden_states/__init__.py b/examples/speculative_decoding/collect_hidden_states/__init__.py
new file mode 100644
index 00000000..78352afe
--- /dev/null
+++ b/examples/speculative_decoding/collect_hidden_states/__init__.py
@@ -0,0 +1,16 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Collect hidden states from a dataset of conversations."""
diff --git a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
new file mode 100644
index 00000000..82b40e9a
--- /dev/null
+++ b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
@@ -0,0 +1,180 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Extract hidden states from an HF-compatible LLM."""
+
+import argparse
+import asyncio
+import json
+from pathlib import Path
+
+import torch
+from tqdm import tqdm
+from transformers import AutoModel, AutoTokenizer
+
+REMOVE_THINK_CHAT_TEMPLATE = (
+    "{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}"
+)
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(
+        description="""Collect hidden states from conversations
+        by running full conversations through a Hugging Face model."""
+    )
+
+    ## Model & Generation Parameters ##
+    parser.add_argument(
+        "--model",
+        type=str,
+        required=True,
+        help="Name of the served model.",
+    )
+
+    ## Client Parameters ##
+    parser.add_argument(
+        "--max-seq-len",
+        type=int,
+        default=3072,
+        help="""Maximum number of tokens in a conversation. Longer conversations will be skipped.
+        Defaults to 3072 tokens.""",
+    )
+
+    ## I/O Parameters ##
+    parser.add_argument(
+        "--input-file",
+        type=Path,
+        required=True,
+        help="""Path to the input `jsonl` file containing conversations.
+ Each entry must have a unique `conversation_id` field and a `conversations` field + containing a list of messages.""", + ) + parser.add_argument( + "--output-dir", + type=Path, + required=True, + help="""Root directory in which to save the hidden states. + The data will be saved as a torch (`.pt`) dump file for each conversation.""", + ) + parser.add_argument( + "--debug-max-num-conversations", + type=int, + default=None, + help="""For debugging purposes, limit the number of conversations processed. + Default is None, meaning no limit.""", + ) + + return parser.parse_args() + + +async def main(args: argparse.Namespace) -> None: + all_conversations = [] + with args.input_file.open("r", encoding="utf-8") as f: + all_conversations.extend([json.loads(line) for line in f if line.strip()]) + + if any(not entry.get("conversation_id") for entry in all_conversations): + msg = "All conversations must have a 'conversation_id' field." + raise ValueError(msg) + + print("Loaded", len(all_conversations), "conversations from", args.input_file) + + model = AutoModel.from_pretrained(args.model, torch_dtype="auto", device_map="auto") + num_hidden_layers = getattr(model.config, "num_hidden_layers", None) + + tokenizer = AutoTokenizer.from_pretrained(args.model) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "") + + output_dir = args.output_dir + output_dir.mkdir(parents=True, exist_ok=True) + num_skipped_too_long = 0 + num_invalid = 0 + num_success = 0 + num_total_conversations = min( + len(all_conversations), args.debug_max_num_conversations or len(all_conversations) + ) + for idx, entry in enumerate( + tqdm( + all_conversations[: args.debug_max_num_conversations], + desc="Processing conversations", + total=num_total_conversations, + ) + ): + conversation_id = entry.get("conversation_id", "{:08d}".format(idx)) + conversations = entry["conversations"] + if not conversations or not isinstance(conversations, list): + num_invalid += 1 + continue + + # Tokenize and check length + input_ids = tokenizer.apply_chat_template( + conversations, return_tensors="pt", add_generation_template=False + ) + num_input_tokens = input_ids.shape[1] + if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len: + num_skipped_too_long += 1 + continue + + # Get hidden states + with torch.inference_mode(): + outputs = model(input_ids=input_ids.to(model.device), output_hidden_states=True) + if num_hidden_layers is None: + num_hidden_layers = len(outputs.hidden_states) - 1 + else: + assert num_hidden_layers + 1 == len(outputs.hidden_states), ( + f"Expected {num_hidden_layers}+1 layers of hidden states, but got {len(outputs.hidden_states)}." 
+ ) + # Extract hidden states from layers with index (2, N/2, N-3), and the output hidden states + hidden_states = outputs.hidden_states + selected_layer_indices = [ + 2, + max(0, num_hidden_layers // 2), + max(1, num_hidden_layers - 3), + ] + selected_layer_indices = sorted(set(selected_layer_indices)) + aux_hidden_states = torch.cat( + [hidden_states[i].squeeze(0).cpu() for i in selected_layer_indices], dim=-1 + ) + output_hidden_states = outputs.last_hidden_state.squeeze(0).cpu() + output_file = output_dir / f"{conversation_id}.pt" + num_success += 1 + with open(output_file, "wb") as f: + torch.save( + { + "input_ids": input_ids.squeeze(0).cpu(), + "hidden_states": output_hidden_states, + "aux_hidden_states": aux_hidden_states, + "conversation_id": conversation_id, + }, + f, + ) + + if num_skipped_too_long > 0: + print(f"Skipped {num_skipped_too_long} conversations due to length constraints.") + if num_invalid > 0: + print(f"Skipped {num_invalid} invalid conversations without proper fields.") + + if num_success == num_total_conversations: + print(f"Successfully processed all {num_success} conversations.") + else: + print( + f"Successfully processed {num_success} out of {num_total_conversations} conversations." + ) + + +if __name__ == "__main__": + cli_args = parse_args() + asyncio.run(main(cli_args)) diff --git a/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh b/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh new file mode 100644 index 00000000..48d12aeb --- /dev/null +++ b/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Example usage of the script to compute the hidden states for a conversation dataset +# This script computes hidden states using a Hugging Face model and saves them to +# the specified output directory. + +python3 collect_hidden_states/compute_hidden_states_hf.py \ + --model meta-llama/Llama-3.2-1B-Instruct \ + --input-file synthetic_conversations/daring-anteater.jsonl \ + --output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/ \ No newline at end of file diff --git a/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh b/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh new file mode 100644 index 00000000..e9c3a2cd --- /dev/null +++ b/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Example usage of the script to compute the hidden states for a conversation dataset +# This script computes hidden states using a Hugging Face model and saves them to +# the specified output directory. It does so in a data-parallel manner across 8 GPUs, by splitting +# the input file into 8 parts and running 8 processes in parallel, one on each GPU. + +# Note: depending on the write-throughput of the destination disk, this is not guaranteed +# to yield a speed improvement compared to running the model-parallel version. Consider +# benchmarking on a smaller dataset before launching a large run. + +INPUT_FILE=synthetic_conversations/daring-anteater.jsonl +OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/ + +split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part- + +for i in $(seq 0 7) +do +CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR & +done +wait + +rm /tmp/part-*.jsonl \ No newline at end of file diff --git a/examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh b/examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh new file mode 100644 index 00000000..9da3907a --- /dev/null +++ b/examples/speculative_decoding/collect_hidden_states/run_send_conversations.sh @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Example usage of the script to send conversations for hidden state collection +# This script sends conversations to a (local) OpenAI-compatible server for processing and collects hidden states. + +python3 collect_hidden_states/send_conversations_for_hiddens.py \ + --model meta-llama/Llama-3.2-1B-Instruct \ + --input-file synthetic_conversations/mtbench.jsonl \ + --output-dir /mnt/md0/eagle-hidden-states/llama1b/mtbench/ +# --debug-max-num-conversations-per-split 1000 diff --git a/examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py b/examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py new file mode 100644 index 00000000..75a88969 --- /dev/null +++ b/examples/speculative_decoding/collect_hidden_states/sample_hidden_states.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility script to print a sample of hidden states extracted from a dataset.""" + +import argparse +import random +from pathlib import Path + +import torch + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Print a sample of hidden states from a dataset." + "This script will crawl the provided directory for hidden state files," + " and print a small number of samples." + ) + + parser.add_argument( + "--input-path", + type=Path, + required=True, + help="Path to the base input directory containing hidden states." + "Alternatively, this can be a path to a specific `.pt` file.", + ) + parser.add_argument( + "--num-samples", + type=int, + default=1, + help="Number of samples to print per split. If input_path is a file, this is ignored. " + "Defaults to 1.", + ) + return parser.parse_args() + + +def main(args: argparse.Namespace) -> None: + # Iterate through the input directory and find all hidden state files + if args.input_path.is_file(): + all_files = [args.input_path] + else: + all_files = list(args.input_path.glob("*.pt")) + + sampled_files = ( + random.sample(all_files, args.num_samples) + if len(all_files) > args.num_samples + else all_files + ) + + for i, file in enumerate(sampled_files): + data = torch.load(file) + expected_keys = [ + "input_ids", + "hidden_states", + "aux_hidden_states", + "conversation_id", + ] + if set(expected_keys) != set(data.keys()): + print(f"File {file} does not contain all expected keys: {expected_keys}") + print(f" Found keys: {list(data.keys())}") + continue + print(f"Sample {i + 1}: {file.name}") + for key in ["input_ids", "hidden_states", "aux_hidden_states"]: + print(f"{key}: {data[key].shape} {data[key].dtype} {data[key].device}") + print(f"conversation_id: {data['conversation_id']}") + input_ids_list = data["input_ids"].tolist() + hidden_states = data["hidden_states"] + print(f"Sample of input_ids (first 10 tokens): \n{input_ids_list[:10]}") + print(f"Sample of input_ids (last 10 tokens): \n{input_ids_list[-10:]}") + print(f"Sample of hidden_states (first 10 positions): \n{hidden_states[:10]}") + + print(f"\n\nDone. Found: {len(all_files)} files in total.") + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py b/examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py new file mode 100644 index 00000000..c3c12a6c --- /dev/null +++ b/examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py @@ -0,0 +1,211 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Send conversations from a dataset to an OpenAI-compatible endpoint.""" + +import argparse +import asyncio +import json +from pathlib import Path + +import httpx +import openai +from openai import AsyncOpenAI +from tqdm import tqdm +from transformers import AutoTokenizer + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="""Collect hidden states from conversations + by sending full conversations as prompts to an OpenAI-compatible endpoint.""" + ) + + ## Model & Generation Parameters ## + parser.add_argument( + "--model", + type=str, + required=True, + help="Name of the served model.", + ) + + ## Client Parameters ## + parser.add_argument( + "--base-url", + type=str, + default="http://localhost:8000/v1", + help="""HTTP URL for the OpenAI-compatible endpoint. + Defaults to `http://localhost:8000/v1`.""", + ) + parser.add_argument( + "--openai-api-key", + default="EMPTY", + help="""Access key required by the OpenAI Python client + (not required for local serving engines like vLLM).""", + ) + + ## I/O Parameters ## + parser.add_argument( + "--max-seq-len", + type=int, + default=3072, + help="""Maximum number of tokens in a conversation. Longer conversations will be skipped. + Defaults to 3072 tokens.""", + ) + parser.add_argument( + "--input-file", + type=Path, + required=True, + help="""Path to the input `jsonl` file containing conversations. + Each entry must have a unique `conversation_id` field and a `conversations` field + containing a list of messages.""", + ) + parser.add_argument( + "--output-dir", + type=Path, + required=True, + help="""Root directory in which to save the hidden states. + The data will be saved as a torch (`.pt`) dump file for each conversation.""", + ) + parser.add_argument( + "--debug-max-num-conversations", + type=int, + default=None, + help="""For debugging purposes, limit the number of conversations processed. + Default is None, meaning no limit.""", + ) + + return parser.parse_args() + + +async def main(args: argparse.Namespace) -> None: + all_conversations = [] + with args.input_file.open("r", encoding="utf-8") as f: + all_conversations.extend([json.loads(line) for line in f if line.strip()]) + + if any(not entry.get("conversation_id") for entry in all_conversations): + msg = "All conversations must have a 'conversation_id' field." 
+ raise ValueError(msg) + + print("Loaded", len(all_conversations), "conversations from", args.input_file) + + client: AsyncOpenAI = AsyncOpenAI( + api_key=args.openai_api_key, + base_url=args.base_url, + ) + + tokenizer = AutoTokenizer.from_pretrained(args.model) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + bos_token_id = tokenizer.bos_token_id + if bos_token_id is None: + raise ValueError("The tokenizer does not have a BOS token.") + + temp_meta_file = Path("/tmp/meta.json") + if temp_meta_file.exists(): + print(f"Temporary meta file {temp_meta_file} found, removing it.") + temp_meta_file.unlink() + + output_dir = args.output_dir + output_dir.mkdir(parents=True, exist_ok=True) + num_invalid = 0 + num_error = 0 + num_too_long = 0 + num_success = 0 + num_total_conversations = min( + len(all_conversations), args.debug_max_num_conversations or len(all_conversations) + ) + for entry in tqdm( + all_conversations[: args.debug_max_num_conversations], + desc="Processing conversations", + total=num_total_conversations, + ): + conversation_id = entry["conversation_id"] + conversations = entry["conversations"] + if not conversations or not isinstance(conversations, list): + num_invalid += 1 + continue + + hidden_states_file = output_dir / f"{conversation_id}.pt" + + # Use /tmp/meta.json to communicate with the local serving engine. + # See usage guide for more details + with temp_meta_file.open("w") as f: + json.dump( + { + "conversation_id": conversation_id, + "output_file": str(hidden_states_file), + }, + f, + ) + + input_ids = tokenizer.apply_chat_template( + conversations, return_tensors=None, add_generation_template=False, tokenize=True + ) + num_input_tokens = len(input_ids) + if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len: + num_too_long += 1 + continue + if input_ids[0] == bos_token_id: + # Remove the leading BOS token. vLLM's completion generation + # endpoint will prepend the BOS token automatically. + input_ids = input_ids[1:] + input_string = tokenizer.decode(input_ids, skip_special_tokens=False) + + try: + # Send the message to the OpenAI-compatible endpoint + await client.completions.create( + model=args.model, + prompt=input_string, + temperature=0.0, + max_tokens=1, + ) + except httpx.HTTPStatusError as e: + print(f"HTTP error for conversation {conversation_id}: {e}") + num_error += 1 + continue + except openai.BadRequestError: + # Most likely the conversation is too long, ignore + num_too_long += 1 + continue + except Exception as e: + num_error += 1 + print(f"Error sending conversation {conversation_id}: {e}") + continue + finally: + # Ensure the meta file is cleaned up after each request + if temp_meta_file.exists(): + temp_meta_file.unlink() + num_success += 1 + continue + + if num_invalid > 0: + print(f"Skipped {num_invalid} invalid conversations without proper fields.") + if num_too_long > 0: + print(f"Skipped {num_too_long} conversations likely due to length constraints.") + if num_error > 0: + print(f"Encountered errors for {num_error} conversations.") + + if num_success == num_total_conversations: + print(f"Successfully processed all {num_success} conversations.") + else: + print( + f"Successfully processed {num_success} out of {num_total_conversations} conversations." 
+ ) + + +if __name__ == "__main__": + cli_args = parse_args() + asyncio.run(main(cli_args)) diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 8ac292a3..7f50178b 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -14,6 +14,7 @@ # limitations under the License. import json +from pathlib import Path from typing import Any import torch @@ -172,8 +173,70 @@ def __getitem__(self, i) -> dict[str, torch.Tensor]: return ret +class OfflineSupervisedDataset(Dataset): + """Lazy offline dataset for supervised fine-tuning. + + This dataset loads data on-the-fly from pre-processed .pt data files as well as + input conversations in JSON format. + + Args: + data_entries (list): A list of tuples (raw_data_example, file_path). + tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing. + """ + + def __init__(self, data_entries, tokenizer: transformers.PreTrainedTokenizer): + super().__init__() + print_rank_0("Formatting inputs...Skip in offline mode") + self.tokenizer = tokenizer + self.data_entries = data_entries + + # Does not cache the hidden states, as those have an extremely large memory footprint. + self.cached_data_dict = {} + + def __len__(self): + return len(self.data_entries) + + def __getitem__(self, i) -> dict[str, torch.Tensor]: + # Load the conversational data, using the cache + raw_data, offline_file_path = self.data_entries[i] + if i in self.cached_data_dict: + preprocessed_base = self.cached_data_dict[i] + else: + ret = preprocess([raw_data], self.tokenizer) + preprocessed_base = { + "input_ids": ret["input_ids"][0], + "labels": ret["labels"][0], + "attention_mask": ret["attention_mask"][0], + "loss_mask": ret["loss_mask"][0], + } + self.cached_data_dict[i] = preprocessed_base + + # Extend the data sample with the hidden states from the .pt file + max_length = self.tokenizer.model_max_length + offline_data = torch.load(offline_file_path) + offline_data["input_ids"] = offline_data["input_ids"][:max_length] + offline_data["hidden_states"] = offline_data["hidden_states"][:max_length, :] + offline_data["aux_hidden_states"] = offline_data["aux_hidden_states"][:max_length, :] + + # Make sure the input_ids have the same shape + if preprocessed_base["input_ids"].shape != offline_data["input_ids"].shape: + msg = f"""Input IDs from offline data do not match the preprocessed input IDs + for offline data sample at {offline_file_path}.""" + raise ValueError(msg) + + ret = {**preprocessed_base} # Shallow copy so we don't accidentally modify the cache + ret["input_ids"] = offline_data["input_ids"] + ret["kwargs"] = { + "base_model_outputs": { + "base_model_hidden_states": offline_data["hidden_states"], + "aux_hidden_states": offline_data["aux_hidden_states"], + } + } + return ret + + def make_eagle_supervised_data_module( - tokenizer: transformers.PreTrainedTokenizer, data_args + tokenizer: transformers.PreTrainedTokenizer, data_args, use_offline_training: bool ) -> dict: """Make dataset and collator for supervised fine-tuning. @@ -184,18 +247,63 @@ def make_eagle_supervised_data_module( Returns: dict: A dictionary containing train and eval datasets. 
""" - dataset_cls = LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset - print_rank_0("Loading data...") - - if data_args.data_path.endswith("jsonl"): - with open(data_args.data_path) as f: + # Load the conversations from the source file + with open(data_args.data_path) as f: + if data_args.data_path.endswith("jsonl"): data_json = [json.loads(line) for line in f] + else: + data_json = json.load(f) + + if use_offline_training: + print_rank_0("Loading pre-processed data for offline training...") + dataset_cls = OfflineSupervisedDataset + + # Glob for all .pt files in the data_path directory + assert data_args.offline_data_path is not None, ( + "offline_data_path must be provided for offline training." + ) + offline_data_path = Path(data_args.offline_data_path) + all_files = {str(p) for p in offline_data_path.glob("*.pt")} + if not all_files: + raise ValueError(f"No .pt files found in {data_args.offline_data_path}") + + # Filter to conversations that exist in the offline data and in the provided json + valid_entries = [] + for entry in data_json: + conv_id = entry.get("conversation_id") or entry.get("id") + if not conv_id: + raise ValueError( + "Each entry in the data json must have a 'conversation_id' or 'id' field." + ) + file_path = str(offline_data_path / f"{conv_id}.pt") + if file_path in all_files: + valid_entries.append((entry, file_path)) + + if len(valid_entries) == 0: + msg = """No valid files found in the offline data path that match the conversation IDs + in the provided data json. Please ensure that the offline data path is correct and + contains .pt files named after the conversation IDs, and that the input conversations + json has the correct format (with 'conversation_id' or 'id' fields).""" + raise ValueError(msg) + elif len(valid_entries) < len(data_json): + print_rank_0( + f"Warning: Only {len(valid_entries)} out of {len(data_json)} conversations" + " have corresponding .pt files in the offline data path. Continuing..." + ) + + num_train = int(len(valid_entries) * 0.95) + train_dataset = dataset_cls(valid_entries[:num_train], tokenizer=tokenizer) + eval_dataset = dataset_cls(valid_entries[num_train:], tokenizer=tokenizer) + + data_collator = DataCollatorForOffline() else: - data_json = json.load(open(data_args.data_path)) - train_dataset = dataset_cls(data_json[: int(len(data_json) * 0.95)], tokenizer=tokenizer) - eval_dataset = dataset_cls(data_json[int(len(data_json) * 0.95) :], tokenizer=tokenizer) + print_rank_0("Loading input conversations...") + dataset_cls = LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset + + train_dataset = dataset_cls(data_json[: int(len(data_json) * 0.95)], tokenizer=tokenizer) + eval_dataset = dataset_cls(data_json[int(len(data_json) * 0.95) :], tokenizer=tokenizer) - data_collator = DataCollatorWithPadding() + data_collator = DataCollatorWithPadding() return { "train_dataset": train_dataset, @@ -240,3 +348,33 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: } return batch + + +class DataCollatorForOffline(DataCollatorWithPadding): + def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: + base_batch = super().__call__(features) + if "kwargs" not in features[0]: + raise ValueError("No kwargs found in batch features. 
Offline data required.") + + features = [item["kwargs"]["base_model_outputs"] for item in features] + max_hs_length = max(item["base_model_hidden_states"].shape[0] for item in features) + + batch_hidden_states = torch.stack( + [ + self.paddingtensor2d(item["base_model_hidden_states"], max_hs_length) + for item in features + ] + ) + batch_aux_hidden_states = torch.stack( + [self.paddingtensor2d(item["aux_hidden_states"], max_hs_length) for item in features] + ) + + batch = { + **base_batch, + "base_model_outputs": { + "base_model_hidden_states": batch_hidden_states, + "aux_hidden_states": batch_aux_hidden_states, + }, + } + + return batch diff --git a/examples/speculative_decoding/launch.sh b/examples/speculative_decoding/launch.sh deleted file mode 100755 index 67d228bf..00000000 --- a/examples/speculative_decoding/launch.sh +++ /dev/null @@ -1,159 +0,0 @@ -#!/bin/bash -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -set -eo pipefail - -while [ $# -gt 0 ]; do - case "$1" in - --training_seq_len*) - if [[ "$1" != *=* ]]; then shift; fi - TRAINING_SEQ_LEN="${1#*=}" - ;; - --model*) - if [[ "$1" != *=* ]]; then shift; fi - MODEL="${1#*=}" - ;; - --data*) - if [[ "$1" != *=* ]]; then shift; fi - DATA="${1#*=}" - ;; - --mode*) - if [[ "$1" != *=* ]]; then shift; fi - MODE="${1#*=}" - ;; - --output_dir*) - if [[ "$1" != *=* ]]; then shift; fi - OUTPUT_DIR="${1#*=}" - ;; - --num_epochs*) - if [[ "$1" != *=* ]]; then shift; fi - NUM_EPOCHS="${1#*=}" - ;; - --save_steps*) - if [[ "$1" != *=* ]]; then shift; fi - SAVE_STEPS="${1#*=}" - ;; - --lr*) - if [[ "$1" != *=* ]]; then shift; fi - LR="${1#*=}" - ;; - --train_bs*) - if [[ "$1" != *=* ]]; then shift; fi - TRAIN_BS="${1#*=}" - ;; - --medusa_num_heads*) - if [[ "$1" != *=* ]]; then shift; fi - MEDUSA_NUM_HEADS="${1#*=}" - ;; - --medusa_num_layers*) - if [[ "$1" != *=* ]]; then shift; fi - MEDUSA_NUM_LAYERS="${1#*=}" - ;; - --eagle_config*) - if [[ "$1" != *=* ]]; then shift; fi - EAGLE_CONFIG="${1#*=}" - ;; - --fsdp_transformer_layer_cls_to_wrap*) - if [[ "$1" != *=* ]]; then shift; fi - FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP="${1#*=}" - ;; - --num_gpu*) - if [[ "$1" != *=* ]]; then shift; fi - NUM_GPU="${1#*=}" - ;; - --do_eval*) - if [[ "$1" != *=* ]]; then shift; fi - DO_EVAL="${1#*=}" - ;; - *) - >&2 printf "Error: Invalid argument ${1#*=}\n" - exit 1 - ;; - esac - shift -done - -set -x - -# Get the default value for save_steps based on the available number of GPUs -GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())") -# Calculate save_steps -DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT)) - -MODEL=${MODEL:-"TinyLlama/TinyLlama-1.1B-Chat-v1.0"} -MODE=${MODE:-"medusa"} -OUTPUT_DIR=${OUTPUT_DIR:-"tinyllama-medusa"} -NUM_EPOCHS=${NUM_EPOCHS:-1} -SAVE_STEPS=${SAVE_STEPS:-$DEFAULT_SAVE_STEPS} -LR=${LR:-"1e-4"} -TRAIN_BS=${TRAIN_BS:-4} -MEDUSA_NUM_HEADS=${MEDUSA_NUM_HEADS:-1} 
-MEDUSA_NUM_LAYERS=${MEDUSA_NUM_LAYERS:-1} -REDRAFTER_TOKENS=${REDRAFTER_TOKENS:-1} -REDRAFTER_NUM_LAYERS=${REDRAFTER_NUM_LAYERS:-1} -FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"} -NUM_GPU=${NUM_GPU:-1} -DO_EVAL=${DO_EVAL:-"True"} -TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048} - -if [[ "$MODE" == "medusa" ]]; then - SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS" -elif [[ "$MODE" == "eagle1" || "$MODE" == "eagle3" ]]; then - if [[ -n "$EAGLE_CONFIG" ]]; then - SPECULATIVE_ARGS="--eagle_config $EAGLE_CONFIG" - else - SPECULATIVE_ARGS="" - fi -else - echo "Only medusa, eagle1, eagle3 supported for now!" - exit 1 -fi - -if [[ "$NUM_GPU" == 1 ]]; then - MULTI_GPU="" -else - MULTI_GPU="--multi_gpu" -fi - -CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \ - --mode $MODE \ - --model_name_or_path $MODEL \ - --training_seq_len $TRAINING_SEQ_LEN \ - --dataloader_drop_last True \ - --bf16 True \ - --output_dir $OUTPUT_DIR \ - --num_train_epochs $NUM_EPOCHS \ - --per_device_train_batch_size $TRAIN_BS \ - --per_device_eval_batch_size $TRAIN_BS \ - --gradient_accumulation_steps 1 \ - --do_eval $DO_EVAL \ - --eval_accumulation_steps 1 \ - --save_strategy steps \ - --save_steps $SAVE_STEPS \ - --learning_rate $LR \ - --weight_decay 0.0 \ - --warmup_steps 100 \ - --lr_scheduler_type linear \ - --logging_steps 100 \ - --tf32 True \ - --data_path $DATA \ - --report_to tensorboard \ - $SPECULATIVE_ARGS -" - -start_time=$(date +%s) -sh -c "$CMD" -echo "Total time taken: $(( $(date +%s) - $start_time )) seconds" diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index cf54f944..3ecd4238 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -30,6 +30,10 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi DATA="${1#*=}" ;; + --offline-data*) + if [[ "$1" != *=* ]]; then shift; fi + OFFLINE_DATA_PATH="${1#*=}" + ;; --mode*) if [[ "$1" != *=* ]]; then shift; fi MODE="${1#*=}" @@ -104,7 +108,8 @@ REDRAFTER_TOKENS=${REDRAFTER_TOKENS:-1} REDRAFTER_NUM_LAYERS=${REDRAFTER_NUM_LAYERS:-1} FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"} NUM_GPU=${NUM_GPU:-1} -TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-512} +TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048} +OFFLINE_DATA_PATH=${OFFLINE_DATA_PATH:-""} if [[ "$MODE" == "medusa" ]]; then SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS" @@ -119,6 +124,17 @@ else exit 1 fi +if [[ "$OFFLINE_DATA_PATH" != "" ]]; then + if [[ ! -d "$OFFLINE_DATA_PATH" ]]; then + echo "Offline data path $OFFLINE_DATA_PATH does not exist or is not a directory." 
+ exit 1 + else + OFFLINE_TRAINING_ARGS="--offline-data-path $OFFLINE_DATA_PATH" + fi +else + OFFLINE_TRAINING_ARGS="" +fi + if [[ "$NUM_GPU" == 1 ]]; then MULTI_GPU="" else @@ -149,6 +165,7 @@ CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \ --logging_steps 100 \ --tf32 True \ --data_path $DATA \ + $OFFLINE_TRAINING_ARGS \ $SPECULATIVE_ARGS " diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index aae4177c..ea711765 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -69,6 +69,16 @@ class DataArguments: metadata={"help": "Path to the training data."}, ) eval_data_path: str = field(default=None, metadata={"help": "Path to the evaluation data."}) + offline_data_path: str = field( + default=None, + metadata={ + "help": """Path to the offline training data. Providing this flag sets + `eagle_offline` in the EagleConfig and enables offline training. + The directory should contain many `.pt` files, each containing a pre-processed + data sample. `data_path` should still point to the original conversations file. + """ + }, + ) lazy_preprocess: bool = True draft_vocab_cache_dir: str = field( default="draft_vocab_cache", @@ -131,13 +141,21 @@ def train(): elif last_checkpoint is not None: checkpoint = last_checkpoint + use_offline_training = data_args.offline_data_path is not None + if checkpoint: model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype="auto") tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint) else: + model_kwargs = {"num_hidden_layers": 0} if use_offline_training else {} model = transformers.AutoModelForCausalLM.from_pretrained( - model_args.model_name_or_path, torch_dtype="auto" + model_args.model_name_or_path, torch_dtype="auto", **model_kwargs ) + if use_offline_training: + # When doing offline training, we need to set num_hidden_layers + # since we override it when loading the model for space savings + model_config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path) + model.config.num_orig_hidden_layers = model_config.num_hidden_layers tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path, model_max_length=training_args.training_seq_len, @@ -167,6 +185,9 @@ def train(): }[training_args.mode]["config"] # overwrite config with custom config + if use_offline_training: + config["eagle_offline"] = True + if eagle_args.eagle_config: with open(eagle_args.eagle_config) as f: custom_config = json.load(f) @@ -206,13 +227,15 @@ def train(): if training_args.mode == "medusa": data_module = make_medusa_supervised_data_module(tokenizer, data_args) elif training_args.mode in ["eagle1", "eagle3"]: - data_module = make_eagle_supervised_data_module(tokenizer, data_args) + data_module = make_eagle_supervised_data_module(tokenizer, data_args, use_offline_training) class ARValidationCallback(TrainerCallback): def __init__(self, ar_validate_steps: int = 500): self.ar_validate_steps = ar_validate_steps def on_step_end(self, args, state, control, **kwargs): + if self.ar_validate_steps <= 0: + return control if state.global_step % self.ar_validate_steps == 0 and state.global_step > 0: print_rank_0("Running AR validation...") ars = validate_ar( diff --git a/examples/speculative_decoding/prepare_input_conversations/__init__.py b/examples/speculative_decoding/prepare_input_conversations/__init__.py new file mode 100644 index 00000000..926974f5 --- /dev/null +++ 
b/examples/speculative_decoding/prepare_input_conversations/__init__.py @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Scripts to add various datasets to a prompt dataset file.""" diff --git a/examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py b/examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py new file mode 100644 index 00000000..c78739ed --- /dev/null +++ b/examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py @@ -0,0 +1,102 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Add Daring-Anteater conversations to a conversation dataset.""" + +import argparse +from pathlib import Path + +from datasets import load_dataset +from tqdm import tqdm +from utils import ( + dataset_splits_explanation, + id_for_conversation, + update_dataset_file_with_conversations, +) + + +def parse_args() -> argparse.Namespace: + """Parse command-line arguments.""" + parser = argparse.ArgumentParser(description="Load Daring-Anteater conversations.") + + parser.add_argument( + "--output-split-name", + type=str, + default="daring-anteater", + help=dataset_splits_explanation("daring-anteater"), + ) + + parser.add_argument( + "--output-dir", + type=Path, + default=Path("input_conversations/"), + help="Path to save the conversations file(s) into. 
Default is 'input_conversations/'.", + ) + + return parser.parse_args() + + +async def main(args: argparse.Namespace) -> None: + ds = load_dataset("nvidia/Daring-Anteater", split="train", streaming=False) + input_conversations = [] + for i in tqdm( + range(len(ds)), + desc="Loading Daring-Anteater dataset", + total=len(ds), + ): + conversations = ds[i]["conversations"] + if conversations and isinstance(conversations, list): + prompt_id = f"daring-anteater-{i:05}_" + id_for_conversation(conversations) + processed_conversations = [] + for msg in conversations: + if "from" in msg: + role = msg["from"].lower() + elif "role" in msg: + role = msg["role"].lower() + else: + continue + if role == "human": + role = "user" + elif role == "gpt": + role = "assistant" + + if "value" in msg: + content = msg["value"] + elif "text" in msg: + content = msg["text"] + elif "content" in msg: + content = msg["content"] + else: + continue + content = content.strip() + if content: + processed_conversations.append({"role": role, "content": content}) + + input_conversations.append( + {"conversation_id": prompt_id, "conversations": processed_conversations} + ) + + print(f"Loaded {len(input_conversations)} prompts from Daring-Anteater.") + + update_dataset_file_with_conversations( + input_conversations, args.output_dir, args.output_split_name + ) + + +if __name__ == "__main__": + import asyncio + + args = parse_args() + asyncio.run(main(args)) diff --git a/examples/speculative_decoding/prepare_input_conversations/add_mtbench.py b/examples/speculative_decoding/prepare_input_conversations/add_mtbench.py new file mode 100644 index 00000000..76f090cd --- /dev/null +++ b/examples/speculative_decoding/prepare_input_conversations/add_mtbench.py @@ -0,0 +1,105 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Add MTBench conversations to a conversation dataset.""" + +import argparse +import json +from pathlib import Path + +from tqdm import tqdm +from utils import ( + dataset_splits_explanation, + download_file, + id_for_conversation, + update_dataset_file_with_conversations, +) + +MTBENCH_QUESTIONS_URL = "https://raw.githubusercontent.com/lm-sys/FastChat/main/fastchat/llm_judge/data/mt_bench/question.jsonl" + + +def parse_args() -> argparse.Namespace: + """Parse command-line arguments.""" + parser = argparse.ArgumentParser(description="Load MTBench conversations.") + + parser.add_argument( + "--mtbench-questions-file", + type=Path, + required=False, + help="""Path to the MTBench questions.jsonl file. + If not provided, it will be downloaded and saved to ~/.cache/""", + ) + + parser.add_argument( + "--output-split-name", + type=str, + default="mtbench", + help=dataset_splits_explanation("mtbench"), + ) + + parser.add_argument( + "--output-dir", + type=Path, + default=Path("input_conversations/"), + help="Path to save the conversations file(s) into. 
Default is 'input_conversations/'.", + ) + + return parser.parse_args() + + +async def main(args: argparse.Namespace) -> None: + # Download the MTBench questions file if not provided + if not args.mtbench_questions_file: + args.mtbench_questions_file = ( + Path("~/.cache/mtbench_questions.jsonl").expanduser().resolve() + ) + if not args.mtbench_questions_file.exists(): + print("Downloading MTBench questions dataset...") + await download_file(MTBENCH_QUESTIONS_URL, args.mtbench_questions_file) + else: + print(f"Using existing MTBench questions file {args.mtbench_questions_file}") + + # Error if we failed to download the file or if it was provided but does not exist + if not args.mtbench_questions_file.exists(): + err_msg = f"MTBench questions file {args.mtbench_questions_file} does not exist." + raise FileNotFoundError(err_msg) + + with args.mtbench_questions_file.open("r", encoding="utf-8") as f: + mtbench_raw = [json.loads(line) for line in f] + + input_conversations: list[dict] = [] + for entry in tqdm(mtbench_raw, desc="Loading MTBench", total=len(mtbench_raw)): + if not entry: + continue + prompt = entry.get("turns", [""])[0] + if not prompt: + continue + prompt_id = f"mtbench-{entry['question_id']:03}_" + id_for_conversation(prompt) + input_conversations.append( + {"conversation_id": prompt_id, "conversations": [{"role": "user", "content": prompt}]} + ) + + print(f"Loaded {len(input_conversations)} filtered conversations from MTBench.") + + update_dataset_file_with_conversations( + input_conversations, args.output_dir, args.output_split_name + ) + + +if __name__ == "__main__": + import asyncio + + args = parse_args() + asyncio.run(main(args)) diff --git a/examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py b/examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py new file mode 100644 index 00000000..5ea90cfe --- /dev/null +++ b/examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py @@ -0,0 +1,131 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Add ShareGPT conversations to a conversation dataset.""" + +import argparse +import json +from pathlib import Path + +from tqdm import tqdm +from utils import ( + dataset_splits_explanation, + download_file, + id_for_conversation, + update_dataset_file_with_conversations, +) + +SHAREGPT_DATASET_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json" + + +def parse_args() -> argparse.Namespace: + """Parse command-line arguments.""" + parser = argparse.ArgumentParser(description="Load ShareGPT conversations.") + + parser.add_argument( + "--sharegpt-file", + type=Path, + required=False, + help="""Path to the ShareGPT JSON file containing conversations. 
+ If not provided, it will be downloaded and saved to ~/.cache/""", + ) + + parser.add_argument( + "--output-split-name", + type=str, + default="sharegpt", + help=dataset_splits_explanation("sharegpt"), + ) + + parser.add_argument( + "--output-dir", + type=Path, + default=Path("input_conversations/"), + help="Path to save the conversations file(s) into. Default is 'input_conversations/'.", + ) + + return parser.parse_args() + + +def parse_sharegpt_conversation(sharegpt_conv: dict) -> list[dict] | None: + """Parse a ShareGPT conversation into a list of messages.""" + msgs = [] + for turn in sharegpt_conv.get("conversations", []): + if turn.get("from") in ["human", "user"]: + role = "user" + elif turn.get("from") in ["gpt", "chatgpt", "bard"]: + role = "assistant" + elif turn.get("from") == "system": + # ShareGPT system messages are metadata, skip them + continue + elif turn.get("from") == "bing": + # Bing conversations are skipped for training, omit it + return None + else: + err_msg = f"Unknown role in conversation: {turn.get('from')}" + raise ValueError(err_msg) + + value = turn.get("value", "").strip() + if value: + msgs.append({"role": role, "content": value}) + + return msgs + + +async def main(args: argparse.Namespace) -> None: + # Download the ShareGPT dataset if not provided + if not args.sharegpt_file: + args.sharegpt_file = Path("~/.cache/sharegpt.json").expanduser().resolve() + if not args.sharegpt_file.exists(): + print("Downloading ShareGPT dataset...") + await download_file(SHAREGPT_DATASET_URL, args.sharegpt_file) + else: + print(f"Using existing ShareGPT file at {args.sharegpt_file}") + + # Error if we failed to download the file or if it was provided but does not exist + if not args.sharegpt_file.exists(): + err_msg = f"ShareGPT file {args.sharegpt_file} does not exist." + raise FileNotFoundError(err_msg) + + with args.sharegpt_file.open("r", encoding="utf-8") as f: + sharegpt_raw = json.load(f) + + input_conversations: list[dict] = [] + for source_conv in tqdm(sharegpt_raw, desc="Loading ShareGPT", total=len(sharegpt_raw)): + msgs = parse_sharegpt_conversation(source_conv) + if not msgs: + continue + cid = source_conv.get("id") + conv_id = id_for_conversation(msgs) + if cid: + cid = f"{cid}_{conv_id}" + else: + cid = conv_id + cid = f"sharegpt-{cid}" + + input_conversations.append({"conversation_id": cid, "conversations": msgs}) + + print(f"Loaded {len(input_conversations)} filtered conversations from ShareGPT.") + + update_dataset_file_with_conversations( + input_conversations, args.output_dir, args.output_split_name + ) + + +if __name__ == "__main__": + import asyncio + + args = parse_args() + asyncio.run(main(args)) diff --git a/examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py b/examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py new file mode 100644 index 00000000..2c5f5c74 --- /dev/null +++ b/examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Add UltraChat conversations to a conversation dataset.""" + +import argparse +from pathlib import Path + +from datasets import load_dataset +from tqdm import tqdm +from utils import ( + dataset_splits_explanation, + id_for_conversation, + update_dataset_file_with_conversations, +) + + +def parse_args() -> argparse.Namespace: + """Parse command-line arguments.""" + parser = argparse.ArgumentParser(description="Load UltraChat conversations.") + + parser.add_argument( + "--ultrachat-split", + type=str, + default="train_sft", + help="Split of the HuggingFace UltraChat dataset to load. Default is 'train_sft'.", + ) + + parser.add_argument( + "--output-split-name", + type=str, + default="ultrachat", + help=dataset_splits_explanation("ultrachat"), + ) + + parser.add_argument( + "--output-dir", + type=Path, + default=Path("input_conversations/"), + help="Path to save the conversations file(s) into. Default is 'input_conversations/'.", + ) + + return parser.parse_args() + + +async def main(args: argparse.Namespace) -> None: + ds = load_dataset("HuggingFaceH4/ultrachat_200k", split=args.ultrachat_split, streaming=False) + input_conversations = [] + for i in tqdm( + range(len(ds)), + desc=f"Loading UltraChat split {args.ultrachat_split}", + total=len(ds), + ): + prompt = ds[i]["prompt"].strip() + prompt_id = ds[i]["prompt_id"].strip() + if prompt and prompt_id: + msgs = [{"role": "user", "content": prompt}] + prompt_id = ( + f"ultrachat-{args.ultrachat_split}_{i:06}-{prompt_id}_" + id_for_conversation(msgs) + ) + input_conversations.append({"conversation_id": prompt_id, "conversations": msgs}) + + print(f"Loaded {len(input_conversations)} prompts from UltraChat.") + + update_dataset_file_with_conversations( + input_conversations, args.output_dir, args.output_split_name + ) + + +if __name__ == "__main__": + import asyncio + + args = parse_args() + asyncio.run(main(args)) diff --git a/examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh b/examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh new file mode 100644 index 00000000..fa1319b8 --- /dev/null +++ b/examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Example script to prepare a dataset of prompts for generation +# Lines in this script can be uncommented to include specific datasets/splits in the prompt dataset. 
+ +python3 prepare_input_conversations/add_daring_anteater.py --output-split-name train +# python3 prepare_input_conversations/add_sharegpt.py --output-split-name train +# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split train_sft --output-split-name train +# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split train_gen --output-split-name train +# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split test_sft --output-split-name mix_test +# python3 prepare_input_conversations/add_ultrachat.py --ultrachat-split test_gen --output-split-name mix_test +python3 prepare_input_conversations/add_mtbench.py --output-split-name mix_test diff --git a/examples/speculative_decoding/prepare_input_conversations/utils.py b/examples/speculative_decoding/prepare_input_conversations/utils.py new file mode 100644 index 00000000..6a3698f3 --- /dev/null +++ b/examples/speculative_decoding/prepare_input_conversations/utils.py @@ -0,0 +1,172 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions.""" + +import hashlib +import json +import random +from pathlib import Path + +import aiohttp + + +async def download_file(url: str, destination: Path) -> None: + """Download a file from a URL to a specified destination.""" + destination.parent.mkdir(parents=True, exist_ok=True) + async with aiohttp.ClientSession() as session, session.get(url) as response: + if response.status != 200: + msg = f"Failed to download {url}: {response.status}" + raise RuntimeError(msg) + content = await response.read() + destination.write_bytes(content) + print(f"Downloaded {url} to {destination}") + + +def id_for_conversation(conversation: list) -> str: + """Generate a unique ID for a conversation based on its content.""" + json_str = json.dumps(conversation, sort_keys=True, separators=(",", ":"), ensure_ascii=False) + json_bytes = json_str.encode("utf-8") + return hashlib.sha256(json_bytes).hexdigest() + + +def add_conversations_to_split(conversations: list, dataset_dir: Path, split: str) -> None: + """Add conversations to a specific split in the dataset.""" + if len(conversations) == 0: + return + + # Open the dataset file for the specified split, or create it if it doesn't exist + dataset_file = dataset_dir / f"{split}.jsonl" + all_conversations = [] + if dataset_file.exists(): + # load the existing conversations + with dataset_file.open("r", encoding="utf-8") as f: + all_conversations.extend([json.loads(line) for line in f if line.strip()]) + + if any(not entry.get("conversation_id") for entry in all_conversations): + msg = "All existing conversations must have a 'conversation_id' field." 
+ raise ValueError(msg) + + existing_ids = {entry["conversation_id"] for entry in all_conversations} + num_new_entries = 0 + num_duplicates = 0 + for entry in conversations: + if entry.get("conversation_id") is None: + raise ValueError("Each conversation must have a 'conversation_id' field.") + if entry["conversation_id"] not in existing_ids: + all_conversations.append( + { + "conversation_id": entry["conversation_id"], + "conversations": entry["conversations"], + } + ) + num_new_entries += 1 + else: + num_duplicates += 1 + + if num_duplicates > 0: + print( + f"Added {num_new_entries} new conversations to {dataset_file}, " + f"skipped {num_duplicates} existing entries." + ) + else: + print(f"Added {num_new_entries} new conversations to {dataset_file}.") + + dataset_dir.mkdir(parents=True, exist_ok=True) + with dataset_file.open("w", encoding="utf-8") as f: + for entry in all_conversations: + f.write(json.dumps(entry, ensure_ascii=False) + "\n") + + +def mix_conversations_and_add_to_splits( + conversations: list, + dataset_dir: Path, + train_ratio: float, + val_ratio: float, + test_ratio: float, + *, + shuffle: bool = True, + seed: int = 42, +) -> None: + """Mix the conversations and add to the dataset's train, val, and test splits.""" + if train_ratio + val_ratio + test_ratio != 1.0: + msg = "Ratios must sum to 1.0" + raise ValueError(msg) + if any(ratio < 0 for ratio in [train_ratio, val_ratio, test_ratio]): + msg = "Ratios must be non-negative" + raise ValueError(msg) + + total_conversations = len(conversations) + train_count = int(total_conversations * train_ratio) + val_count = int(total_conversations * val_ratio) + + if shuffle: + random.seed(seed) + random.shuffle(conversations) + + train_conversations = conversations[:train_count] + val_conversations = conversations[train_count : train_count + val_count] + test_conversations = conversations[train_count + val_count :] + add_conversations_to_split(train_conversations, dataset_dir, "train") + add_conversations_to_split(val_conversations, dataset_dir, "val") + add_conversations_to_split(test_conversations, dataset_dir, "test") + + +def update_dataset_file_with_conversations( + conversations: list, dataset_dir: Path, dataset_split: str +) -> None: + """ + Update a dataset file with new conversations. The conversations are added to the specified + split in the dataset. If the split is 'mix' or 'mix_test', the conversations are mixed and + distributed into train, val, and test splits according to predefined ratios. + """ + if dataset_split == "mix": + print("Mixing conversations and adding to train, val, and test splits.") + mix_conversations_and_add_to_splits( + conversations, + dataset_dir, + train_ratio=0.8, + val_ratio=0.1, + test_ratio=0.1, + ) + elif dataset_split == "mix_test": + print("Mixing conversations and adding to val and test splits.") + mix_conversations_and_add_to_splits( + conversations, + dataset_dir, + train_ratio=0.0, + val_ratio=0.5, + test_ratio=0.5, + ) + else: + add_conversations_to_split(conversations, dataset_dir, dataset_split) + + +def dataset_splits_explanation(default_split: str) -> str: + """Return an explanation string for the dataset split argument.""" + return f"""Split to assign the processed conversations to. + Can be any name, or one of ['mix', 'mix_test']. + Default is '{default_split}'. + + If the provided split name matches an existing file in the dataset directory, + the new conversations will be added to that file, + avoiding duplicates based on conversation IDs. 
+ + Special split names: + - 'mix': Conversations will be randomly mixed and distributed into + 'train' (80%%), 'val' (10%%), and 'test' (10%%) splits. + - 'mix_test': Conversations will be randomly mixed and distributed into + 'val' (50%%) and 'test' (50%%) splits. + """ diff --git a/examples/speculative_decoding/train_eagle3_and_export.sh b/examples/speculative_decoding/train_eagle3_and_export.sh old mode 100644 new mode 100755 index 042a2765..6af635be --- a/examples/speculative_decoding/train_eagle3_and_export.sh +++ b/examples/speculative_decoding/train_eagle3_and_export.sh @@ -38,6 +38,10 @@ while [[ $# -gt 0 ]]; do DATA="$2" shift; shift ;; + --offline_data) + OFFLINE_DATA_PATH="$2" + shift; shift + ;; *) echo "Unknown argument: $1" exit 1 @@ -54,6 +58,12 @@ else export CUDA_VISIBLE_DEVICES="$devs" fi +if [[ "$OFFLINE_DATA_PATH" != "" ]]; then + OFFLINE_DATA_ARGS="--offline-data $OFFLINE_DATA_PATH" +else + OFFLINE_DATA_ARGS="" +fi + MODEL_BASENAME=$(basename "$BASE_MODEL") echo "==== [1/3] Training draft model ====" @@ -61,6 +71,7 @@ OUTPUT_DIR=ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M) mkdir -p "$(dirname "$OUTPUT_DIR")" ./launch_train.sh --model $BASE_MODEL \ --output_dir $OUTPUT_DIR \ + $OFFLINE_DATA_ARGS \ --data $DATA \ --num_gpu $NUM_GPU \ --num_epochs 2 \ diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index e1da326d..85e5faae 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -311,7 +311,7 @@ def forward( hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, position_embeddings=position_embeddings, @@ -330,11 +330,16 @@ def forward( @EagleDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"}) +@OfflineEagleDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"}) class HFEagleModel(EagleModel): """Eagle Model Class for huggingface models.""" def _set_default_aux_hidden_state_layers(self): + # Read a custom config attribute since we override num_hidden_layers for offline training num_layers = self.config.num_hidden_layers + if self.eagle_offline and (num_layers is None or num_layers <= 0): + num_layers = getattr(self.config, "num_orig_hidden_layers", 0) + self.eagle_config.eagle_aux_hidden_state_layer_ids = [ 1, max(0, num_layers // 2 - 1), @@ -417,7 +422,11 @@ def modify( ) self.eagle_rotary_emb = LlamaRotaryEmbedding(config=self.eagle_config) - if hasattr(self.model.layers[-1].self_attn, "o_proj"): + if eagle_offline: + # For offline training, the base model has no layers. + # Read the device from the lm_head instead. 
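+            # (Offline training feeds precomputed hidden states via --offline-data,
+            # so self.model.layers is empty here; lm_head is assumed to still carry
+            # a weight tensor whose device matches the rest of the loaded model.)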
+ device = self.lm_head.weight.device + elif hasattr(self.model.layers[-1].self_attn, "o_proj"): device = self.model.layers[-1].self_attn.o_proj.weight.device elif hasattr(self.model.layers[-1].self_attn, "q_proj"): device = self.model.layers[-1].self_attn.q_proj.weight.device @@ -799,7 +808,7 @@ def forward( if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: base_model_logits = self._map_logits_to_draft_vocab(base_model_logits) base_model_loss = None - past_key_values = None + past_key_values = DynamicCache() # Dummy cache else: base_model_hidden_states, base_model_logits, base_model_loss, past_key_values = ( @@ -1141,13 +1150,6 @@ def pseudo_speculative_generate( return base_token, draft_tokens -@OfflineEagleDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"}) -class DetachedHFEagleModel(HFEagleModel): - """A wrapper for detached Eagle module.""" - - # TODO: Implement DetachedHFEagleModel class for offline eagle. - - class HFARValidation(AcceptanceRateValidation): """This is the subclass for HF model AR validation."""
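For reference, here is a minimal sketch (not part of the diff above, and purely illustrative) of how an additional add_<dataset>.py script could reuse the helpers introduced in prepare_input_conversations/utils.py. It assumes it is run from the prepare_input_conversations/ directory so that `from utils import ...` resolves; the toy prompts, the file name add_toy_dataset.py, and the 'train' split choice are placeholders.

# add_toy_dataset.py -- hypothetical example following the pattern of add_ultrachat.py
from pathlib import Path

from utils import id_for_conversation, update_dataset_file_with_conversations

prompts = [
    "What is speculative decoding?",
    "Summarize the EAGLE-3 draft-model idea in one paragraph.",
]

input_conversations = []
for i, prompt in enumerate(prompts):
    msgs = [{"role": "user", "content": prompt}]
    # Prefix a per-dataset index, then a content hash, so re-runs deduplicate cleanly.
    conv_id = f"toy_{i:06}-" + id_for_conversation(msgs)
    input_conversations.append({"conversation_id": conv_id, "conversations": msgs})

# 'train' appends to input_conversations/train.jsonl, skipping existing conversation_ids;
# passing 'mix' instead would shuffle and distribute 80/10/10 across train/val/test.
update_dataset_file_with_conversations(input_conversations, Path("input_conversations/"), "train")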