# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

| 16 | +"""Extract hidden states from an HF-compatible LLM.""" |
| 17 | + |
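# Example invocation (a sketch; the script name, model, and paths are illustrative):
#
#   python extract_hidden_states.py \
#       --model meta-llama/Llama-3.1-8B-Instruct \
#       --input-file conversations.jsonl \
#       --output-dir ./hidden_states
#
# Each processed conversation ends up in <output-dir>/<conversation_id>.pt; it can be
# reloaded with `data = torch.load(path)`, and the captured activations are stored under
# `data["hidden_states"]` (see _post_process_trtllm_dumped below).
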
import os

# Quiet TensorRT-LLM logging; set before the tensorrt_llm import below so it takes effect.
os.environ["TLLM_LOG_LEVEL"] = "error"
import argparse
import asyncio
import json
from pathlib import Path

import torch
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, SaveHiddenStatesDecodingConfig
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer

# Jinja fragment that some chat templates use to drop everything up to "</think>" in a
# message. It is stripped from the tokenizer's chat template in main() so that full
# reasoning content is kept when conversations are rendered.
REMOVE_THINK_CHAT_TEMPLATE = (
    "{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}"
)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="""Collect hidden states by running full conversations
        through a Hugging Face-compatible model."""
    )

    ## Model & Generation Parameters ##
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Name or path of the Hugging Face model to load.",
    )

    ## Client Parameters ##
    parser.add_argument(
        "--max-seq-len",
        type=int,
        default=3072,
        help="""Maximum number of tokens in a conversation. Longer conversations will be skipped.
        Defaults to 3072 tokens.""",
    )

    ## I/O Parameters ##
    parser.add_argument(
        "--input-file",
        type=Path,
        required=True,
        help="""Path to the input `jsonl` file containing conversations.
        Each entry must have a unique `conversation_id` field and a `conversations` field
        containing a list of messages.""",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        required=True,
        help="""Root directory in which to save the hidden states.
        The data will be saved as a torch (`.pt`) dump file for each conversation.""",
    )
    parser.add_argument(
        "--debug-max-num-conversations",
        type=int,
        default=None,
        help="""For debugging purposes, limit the number of conversations processed.
        Default is None, meaning no limit.""",
    )
    parser.add_argument(
        "--dp-rank",
        type=int,
        default=0,
        help="""Data parallel rank.""",
    )
    parser.add_argument(
        "--use-cuda-graph",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="""Whether to use CUDA graph. Disable with --no-use-cuda-graph.""",
    )
    parser.add_argument(
        "--tp",
        type=int,
        default=1,
        help="""tensor_parallel_size for TRTLLM.""",
    )
    # moe_ep * moe_tp * moe_cp should be equal to tp
    # REF: https://nvidia.github.io/TensorRT-LLM/advanced/expert-parallelism.html
    parser.add_argument(
        "--moe_ep",
        type=int,
        default=1,
        help="""moe_expert_parallel_size for TRTLLM.""",
    )
    parser.add_argument(
        "--moe_tp",
        type=int,
        default=1,
        help="""moe_tensor_parallel_size for TRTLLM.""",
    )
    parser.add_argument(
        "--moe_cp",
        type=int,
        default=1,
        help="""moe_cluster_parallel_size for TRTLLM.""",
    )
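
    # Optional sanity check: TRT-LLM's expert parallelism expects
    # moe_ep * moe_tp * moe_cp == tp (see the REF link above), so fail fast here
    # if the flags are inconsistent.
    args = parser.parse_args()
    if args.moe_ep * args.moe_tp * args.moe_cp != args.tp:
        parser.error("--moe_ep * --moe_tp * --moe_cp must equal --tp")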

    return args


def main(args: argparse.Namespace) -> None:
    # Load conversations
    all_conversations = []
    with args.input_file.open("r", encoding="utf-8") as f:
        all_conversations.extend([json.loads(line) for line in f if line.strip()])
    print("Loaded", len(all_conversations), "conversations from", args.input_file)

    # Remove conversations whose output file already exists
    filtered_conversations = []
    for entry in all_conversations:
        conversation_id = entry.get("conversation_id", None)
        if conversation_id is None:
            filtered_conversations.append(entry)
            continue
        output_file = args.output_dir / f"{conversation_id}.pt"
        if output_file.exists():
            continue
        filtered_conversations.append(entry)
    print(
        "Removed",
        len(all_conversations) - len(filtered_conversations),
        "conversations due to existing output files",
    )
    all_conversations = filtered_conversations

    # Get model config and tokenizer
    model_config = AutoConfig.from_pretrained(args.model)
    num_hidden_layers = getattr(model_config, "num_hidden_layers", None)
    if num_hidden_layers is None:
        raise ValueError(f"Model config for {args.model} does not define num_hidden_layers.")
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.chat_template is not None:
        tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")

    # Set up LLM
    llm_common_config = {
        "model": args.model,
        "attn_backend": "TRTLLM",
        "disable_overlap_scheduler": False,
        "cuda_graph_config": CudaGraphConfig(batch_sizes=[1, 2, 4])
        if args.use_cuda_graph
        else None,
        "max_batch_size": 16,
        "kv_cache_config": KvCacheConfig(enable_block_reuse=False, free_gpu_memory_fraction=0.5),
        "enable_chunked_prefill": False,
        "tensor_parallel_size": args.tp,
        "moe_expert_parallel_size": args.moe_ep,
        "moe_tensor_parallel_size": args.moe_tp,
        "moe_cluster_parallel_size": args.moe_cp,
    }
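    # Hidden-state capture settings: with write_interval=1, each request's capture is
    # written to output_directory as "{file_prefix}_<request index>.pt". The layer set
    # (one early, one middle, one late layer) follows the usual EAGLE-3 feature-capture
    # recipe; adjust it if your draft-model setup expects different layers.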
    spec_config = {
        "output_directory": str(args.output_dir),
        "write_interval": 1,
        "file_prefix": f"dp_{args.dp_rank}",
        "eagle3_layers_to_capture": {1, num_hidden_layers // 2 - 1, num_hidden_layers - 4},
    }
    sampling_params = SamplingParams(max_tokens=32, temperature=0)

    llm_spec = LLM(
        **llm_common_config, speculative_config=SaveHiddenStatesDecodingConfig(**spec_config)
    )
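    # With SaveHiddenStatesDecodingConfig attached, each generate call also dumps the
    # captured hidden states for that request into args.output_dir; the raw dumps are
    # renamed and normalized by _post_process_trtllm_dumped below.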

    args.output_dir.mkdir(parents=True, exist_ok=True)
    num_skipped_too_long = 0
    num_invalid = 0
    num_success = 0
    num_total_conversations = min(
        len(all_conversations), args.debug_max_num_conversations or len(all_conversations)
    )
    pbar = tqdm(total=num_total_conversations, desc=f"DP#{args.dp_rank} Processing conversations")

    def _post_process_trtllm_dumped(trtllm_dumped_file: Path, conversation_id: int):
        """Post-process a TRT-LLM dump file into the same format as the HF-dumped files:

        1. Remove the `id` field and replace it with `conversation_id`.
        2. Rename the `hidden_state` field to `hidden_states`.
        3. Unwrap the single-element list into a dict.
        4. Rename the file to `<conversation_id>.pt`.
        """
        with open(trtllm_dumped_file, "rb") as f:
            trtllm_dumped = torch.load(f)
        assert isinstance(trtllm_dumped, list) and len(trtllm_dumped) == 1, (
            "TRT-LLM dump should be a list with exactly one element"
        )
        assert (
            isinstance(trtllm_dumped[0], dict)
            and "id" in trtllm_dumped[0]
            and "hidden_state" in trtllm_dumped[0]
        ), "TRT-LLM dump should have an 'id' and a 'hidden_state' field"
        trtllm_dumped = trtllm_dumped[0]
        trtllm_dumped.pop("id")
        trtllm_dumped["conversation_id"] = conversation_id
        trtllm_dumped["hidden_states"] = trtllm_dumped.pop("hidden_state")
        output_file = args.output_dir / f"{conversation_id}.pt"
        with open(output_file, "wb") as f:
            torch.save(trtllm_dumped, f)

        if trtllm_dumped_file.exists():
            trtllm_dumped_file.unlink()

    async def dump_hidden_states(idx: int, conversation_id: int, input_ids: list[int]):
        nonlocal num_success
        await llm_spec.generate_async(input_ids, sampling_params)
        # The TRT-LLM API numbers dumped files starting from 1.
        # ref: https://github.com/NVIDIA/TensorRT-LLM/pull/7012
        trtllm_dumped_file = args.output_dir / f"{spec_config['file_prefix']}_{idx + 1}.pt"
        _post_process_trtllm_dumped(trtllm_dumped_file, conversation_id)
        num_success += 1
        pbar.update(1)

    async def submit_generates():
        nonlocal num_skipped_too_long
        nonlocal num_invalid
        tasks = []
        for idx, entry in enumerate(all_conversations[: args.debug_max_num_conversations]):
            conversation_id = entry.get("conversation_id", "{:08d}".format(idx))

            conversations = entry["conversations"]
            if not conversations or not isinstance(conversations, list):
                num_invalid += 1
                continue

            input_ids = tokenizer.apply_chat_template(conversations, add_generation_prompt=False)
            num_input_tokens = (
                input_ids.shape[1] if isinstance(input_ids, torch.Tensor) else len(input_ids)
            )
            # Skip degenerate (very short) or overly long conversations.
            if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len:
                num_skipped_too_long += 1
                continue

            tasks.append(dump_hidden_states(idx, conversation_id, input_ids))
        await asyncio.gather(*tasks)

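    # All eligible conversations are submitted concurrently; TRT-LLM schedules in-flight
    # requests internally (up to max_batch_size above), and each request's dump is
    # post-processed as soon as its generation finishes.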
    asyncio.run(submit_generates())
    llm_spec.shutdown()
    print("LLM shutdown")

    if num_skipped_too_long > 0:
        print(f"Skipped {num_skipped_too_long} conversations due to length constraints.")
    if num_invalid > 0:
        print(f"Skipped {num_invalid} invalid conversations without proper fields.")

    if num_success == num_total_conversations:
        print(f"Successfully processed all {num_success} conversations.")
    else:
        print(
            f"Successfully processed {num_success} out of {num_total_conversations} conversations."
        )


if __name__ == "__main__":
    cli_args = parse_args()
    main(cli_args)