From 450c7739376369de7e929309a5dd2c7fde37c8c6 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Tue, 7 Oct 2025 02:32:40 +0000 Subject: [PATCH 1/5] add scripts for trtllm eagle dumper Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../compute_hidden_states_trtllm.py | 234 ++++++++++++++++++ .../run_hf_compute_hiddens_dp.sh | 5 +- .../run_trtllm_compute_hiddens.sh | 23 ++ .../run_trtllm_compute_hiddens_dp.sh | 37 +++ 4 files changed, 297 insertions(+), 2 deletions(-) create mode 100644 examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py create mode 100644 examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh create mode 100644 examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh diff --git a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py new file mode 100644 index 000000000..3fd1e2eeb --- /dev/null +++ b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py @@ -0,0 +1,234 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Extract hidden states from an HF-compatible LLM.""" + +import os + +os.environ["TLLM_LOG_LEVEL"] = "error" +import argparse +import asyncio +import json +from pathlib import Path + +import torch +from tensorrt_llm import LLM, SamplingParams +from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, SaveHiddenStatesDecodingConfig +from tqdm import tqdm as tqdm +from transformers import AutoConfig, AutoTokenizer + +REMOVE_THINK_CHAT_TEMPLATE = ( + "{% if '' in content %}{% set content = content.split('')[-1] %}{% endif %}" +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="""Collect hidden states from conversations + by running full conversations through a Hugging Face model.""" + ) + + ## Model & Generation Parameters ## + parser.add_argument( + "--model", + type=str, + required=True, + help="Name of the served model.", + ) + + ## Client Parameters ## + parser.add_argument( + "--max-seq-len", + type=int, + default=3072, + help="""Maximum number of tokens in a conversation. Longer conversations will be skipped. + Defaults to 3072 tokens.""", + ) + + ## I/O Parameters ## + parser.add_argument( + "--input-file", + type=Path, + required=True, + help="""Path to the input `jsonl` file containing conversations. + Each entry must have a unique `conversation_id` field and a `conversations` field + containing a list of messages.""", + ) + parser.add_argument( + "--output-dir", + type=Path, + required=True, + help="""Root directory in which to save the hidden states. + The data will be saved as a torch (`.pt`) dump file for each conversation.""", + ) + parser.add_argument( + "--debug-max-num-conversations", + type=int, + default=None, + help="""For debugging purposes, limit the number of conversations processed. + Default is None, meaning no limit.""", + ) + parser.add_argument( + "--dp-rank", + type=int, + default=0, + help="""Data parallel rank.""", + ) + parser.add_argument( + "--use-cuda-graph", + type=bool, + default=True, + help="""Whether to use CUDA graph.""", + ) + parser.add_argument( + "--tp-size-per-dp", + type=int, + default=2, + help="""Tensor parallel size per data parallel.""", + ) + + return parser.parse_args() + + +def main(args: argparse.Namespace) -> None: + # Load conversations + all_conversations = [] + with args.input_file.open("r", encoding="utf-8") as f: + all_conversations.extend([json.loads(line) for line in f if line.strip()]) + print("Loaded", len(all_conversations), "conversations from", args.input_file) + + # Get model config and tokenizer + model_config = AutoConfig.from_pretrained(args.model) + num_hidden_layers = getattr(model_config, "num_hidden_layers", None) + tokenizer = AutoTokenizer.from_pretrained(args.model) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "") + + # Set up LLM + llm_common_config = { + "model": args.model, + "attn_backend": "TRTLLM", + "disable_overlap_scheduler": False, + "cuda_graph_config": CudaGraphConfig(batch_sizes=[1, 2, 4]) + if args.use_cuda_graph + else None, + "max_batch_size": 16, + "kv_cache_config": KvCacheConfig(enable_block_reuse=False, free_gpu_memory_fraction=0.5), + "enable_chunked_prefill": False, + "tensor_parallel_size": args.tp_size_per_dp, + } + spec_config = { + "output_directory": str(args.output_dir), + "write_interval": 1, + "file_prefix": f"dp_{args.dp_rank}", + "eagle3_layers_to_capture": {1, num_hidden_layers // 2 - 1, num_hidden_layers - 4}, + } + sampling_params = SamplingParams(max_tokens=32, temperature=0) + + llm_spec = LLM( + **llm_common_config, speculative_config=SaveHiddenStatesDecodingConfig(**spec_config) + ) + + args.output_dir.mkdir(parents=True, exist_ok=True) + num_skipped_too_long = 0 + num_invalid = 0 + num_success = 0 + num_total_conversations = min( + len(all_conversations), args.debug_max_num_conversations or len(all_conversations) + ) + pbar = tqdm(total=num_total_conversations, desc=f"DP#{args.dp_rank} Processing conversations") + + def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id: int): + """Post-process the TRTLLM dumped file to same format as HF dumped: + 1. Remove id field, replace it with conversation_id + 2. Rename hidden_state field to hidden_states + 3. From list of length 1 to dict + 4. Rename file to conversation_id.pt + """ + with open(trtllm_dumped_file, "rb") as f: + trtllm_dumped = torch.load(f) + assert isinstance(trtllm_dumped, list) and len(trtllm_dumped) == 1, ( + "TRTLLM dumped should be a list with one element" + ) + assert ( + isinstance(trtllm_dumped[0], dict) + and "id" in trtllm_dumped[0] + and "hidden_state" in trtllm_dumped[0] + ), "TRTLLM dumped should have an 'id' and 'hidden_states' field" + trtllm_dumped = trtllm_dumped[0] + trtllm_dumped.pop("id") + trtllm_dumped["conversation_id"] = conversation_id + trtllm_dumped["hidden_states"] = trtllm_dumped.pop("hidden_state") + output_file = args.output_dir / f"{conversation_id}.pt" + with open(output_file, "wb") as f: + torch.save(trtllm_dumped, f) + + if trtllm_dumped_file.exists(): + trtllm_dumped_file.unlink() + + async def dump_hidden_states(idx: int, conversation_id: int, input_ids: list[int]): + nonlocal num_success + await llm_spec.generate_async(input_ids, sampling_params) + # TRTLLM API name files starts from 1 + # ref:https://github.com/NVIDIA/TensorRT-LLM/pull/7012 + trtllm_dumped_file = args.output_dir / f"{spec_config['file_prefix']}_{idx + 1}.pt" + _post_process_trtllm_dumped(trtllm_dumped_file, conversation_id) + num_success += 1 + pbar.update(1) + + async def submit_generates(): + nonlocal num_skipped_too_long + nonlocal num_invalid + tasks = [] + for idx, entry in enumerate(all_conversations[: args.debug_max_num_conversations]): + conversation_id = entry.get("conversation_id", "{:08d}".format(idx)) + + conversations = entry["conversations"] + if not conversations or not isinstance(conversations, list): + num_invalid += 1 + continue + + input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False) + num_input_tokens = ( + input_ids.shape[1] if isinstance(input_ids, torch.Tensor) else len(input_ids) + ) + if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len: + num_skipped_too_long += 1 + continue + + tasks.append(dump_hidden_states(idx, conversation_id, input_ids)) + await asyncio.gather(*tasks) + + asyncio.run(submit_generates()) + llm_spec.shutdown() + print("LLM shutdown") + + if num_skipped_too_long > 0: + print(f"Skipped {num_skipped_too_long} conversations due to length constraints.") + if num_invalid > 0: + print(f"Skipped {num_invalid} invalid conversations without proper fields.") + + if num_success == num_total_conversations: + print(f"Successfully processed all {num_success} conversations.") + else: + print( + f"Successfully processed {num_success} out of {num_total_conversations} conversations." + ) + + +if __name__ == "__main__": + cli_args = parse_args() + main(cli_args) diff --git a/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh b/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh index e9c3a2cdb..31e2294d9 100644 --- a/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh +++ b/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh @@ -24,10 +24,11 @@ INPUT_FILE=synthetic_conversations/daring-anteater.jsonl OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/ +DP_SIZE=8 -split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part- +split -n l/$DP_SIZE --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part- -for i in $(seq 0 7) +for i in $(seq 0 $((DP_SIZE-1))) do CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR & done diff --git a/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh b/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh new file mode 100644 index 000000000..eb285087b --- /dev/null +++ b/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh @@ -0,0 +1,23 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Example usage of the script to compute the hidden states for a conversation dataset +# This script computes hidden states using TensorRT-LLM and saves them to +# the specified output directory. + +python3 collect_hidden_states/compute_hidden_states_trtllm.py \ + --model meta-llama/Llama-3.2-1B-Instruct \ + --input-file synthetic_conversations/daring-anteater.jsonl \ + --output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/ \ No newline at end of file diff --git a/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh b/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh new file mode 100644 index 000000000..23bef155c --- /dev/null +++ b/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Example usage of the script to compute the hidden states for a conversation dataset +# This script computes hidden states using a Hugging Face model and saves them to +# the specified output directory. It does so in a data-parallel manner across 8 GPUs, by splitting +# the input file into 8 parts and running 8 processes in parallel, one on each GPU. + +# Note: depending on the write-throughput of the destination disk, this is not guaranteed +# to yield a speed improvement compared to running the model-parallel version. Consider +# benchmarking on a smaller dataset before launching a large run. + +INPUT_FILE=synthetic_conversations/daring-anteater.jsonl +OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/ +DP_SIZE=8 + +split -n l/$DP_SIZE --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part- + +for i in $(seq 0 $((DP_SIZE-1))) +do +CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_trtllm.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR --dp-rank $i & +done +wait + +rm /tmp/part-*.jsonl \ No newline at end of file From 067bc1adfe142d3b0d31bd200c43399648e38ef0 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Wed, 8 Oct 2025 00:27:25 +0000 Subject: [PATCH 2/5] add slurm support in scripts Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../run_trtllm_compute_hiddens.sh | 4 +++- .../run_trtllm_compute_hiddens_dp.sh | 11 +++++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh b/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh index eb285087b..487d0d69d 100644 --- a/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh +++ b/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh @@ -17,7 +17,9 @@ # This script computes hidden states using TensorRT-LLM and saves them to # the specified output directory. +export TLLM_LOG_LEVEL="error"; python3 collect_hidden_states/compute_hidden_states_trtllm.py \ --model meta-llama/Llama-3.2-1B-Instruct \ --input-file synthetic_conversations/daring-anteater.jsonl \ - --output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/ \ No newline at end of file + --output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/ + \ No newline at end of file diff --git a/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh b/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh index 23bef155c..9096f476b 100644 --- a/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh +++ b/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh @@ -25,13 +25,20 @@ INPUT_FILE=synthetic_conversations/daring-anteater.jsonl OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/ DP_SIZE=8 +MODEL=meta-llama/Llama-3.2-1B-Instruct +export TLLM_LOG_LEVEL="error"; split -n l/$DP_SIZE --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part- for i in $(seq 0 $((DP_SIZE-1))) do -CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_trtllm.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR --dp-rank $i & + +export CUDA_VISIBLE_DEVICES=$i; python3 collect_hidden_states/compute_hidden_states_trtllm.py --model $MODEL --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR --dp-rank $i & + +# #On SLURM: +# PORT=$((10012 + i)); export TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR="tcp://127.0.0.1:$PORT"; trtllm-llmapi-launch python3 collect_hidden_states/compute_hidden_states_trtllm.py --model $MODEL --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR --dp-rank $i + done wait -rm /tmp/part-*.jsonl \ No newline at end of file +rm /tmp/part-*.jsonl From dfd0c626fd184ce5d692aef8f587c7fe3783699d Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Wed, 8 Oct 2025 00:29:21 +0000 Subject: [PATCH 3/5] minor Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../collect_hidden_states/run_trtllm_compute_hiddens_dp.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh b/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh index 9096f476b..73289c0df 100644 --- a/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh +++ b/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh @@ -14,7 +14,7 @@ # limitations under the License. # Example usage of the script to compute the hidden states for a conversation dataset -# This script computes hidden states using a Hugging Face model and saves them to +# This script computes hidden states using TensorRT-LLM and saves them to # the specified output directory. It does so in a data-parallel manner across 8 GPUs, by splitting # the input file into 8 parts and running 8 processes in parallel, one on each GPU. From a4342e2908972873b1ada6533d63866b72df6b96 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Wed, 8 Oct 2025 00:52:03 +0000 Subject: [PATCH 4/5] update instructions for trtllm dumper in readme Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- examples/speculative_decoding/README.md | 29 +++++++++++++++++++------ 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md index aeea25adb..31133e5df 100644 --- a/examples/speculative_decoding/README.md +++ b/examples/speculative_decoding/README.md @@ -82,9 +82,22 @@ The saved modelopt checkpoint is similar in architecture to HF models. It can be ## Training Draft Model with Offline Base Model -For large models, you can export intermediate hidden states to disk and train only the draft model. This significantly reduces GPU memory requirements, but requires several to tens of terabytes of storage depending on dataset size. +For large models, you can export intermediate hidden states to disk and train only the draft model. This significantly reduces GPU memory requirements, but requires several to tens of terabytes of disk storage depending on dataset size. -First, dump the base model's hidden states with the following command: +### Dumpping Hidden States to Disk + +We support two backends for generating base model hidden states. For better effciency, it is recommended to use TRT-LLM: + +```bash +python collect_hidden_states/compute_hidden_states_trtllm.py \ + --model $BASE_MODEL \ + --input-file Daring-Anteater/train.jsonl \ + --output-dir $HIDDEN_STATES_DIR +``` + +**NOTE**: TRT-LLM installation needed for the above command. + +Alternatively, you can generate the same hidden states with HF: ```bash python collect_hidden_states/compute_hidden_states_hf.py \ @@ -93,9 +106,11 @@ python collect_hidden_states/compute_hidden_states_hf.py \ --output-dir $HIDDEN_STATES_DIR ``` -See [`run_hf_compute_hiddens_dp.sh`](./collect_hidden_states/run_hf_compute_hiddens_dp.sh) for a simple example using data parallelism (DP) to accelerate hidden state generation. +**NOTE**: See [`run_hf_compute_hiddens_dp.sh`](./collect_hidden_states/run_hf_compute_hiddens_dp.sh) and [`run_trtllm_compute_hiddens_dp.sh`](./collect_hidden_states/run_trtllm_compute_hiddens_dp.sh) for a simple example using data parallelism (DP) to accelerate hidden state generation. + +### Train Draft Model with Dumped Hidden States -Then, train draft model with `--offline-data` argument: +Once we finish dumping hidden states, launch offline training with an extra `--offline-data` argument: ```bash ./launch_train.sh --model $BASE_MODEL \ @@ -109,13 +124,13 @@ Then, train draft model with `--offline-data` argument: ## Model Validation -After training draft model, we can evaluate the saved modelopt checkpoint on MT-bench by: +For online training checkpoints, we can run in-framework evaluation on MT-bench: ```bash -python ar_validate.py --model_path $OUTPUT_DIR +python ar_validate.py --model_path $ONLINE_CKPT ``` -Alternatively, we can export the checkpoint and run evaluation on serving frameworks. See sections below. +Offline checkpoints does not support this evaluation due to missing of base model modules. To evaluate offline checkpoint, please export first and evaluate with serving frameworks. ## Export From c784ad084ef49fe9587bf13777694fc7c0ea4151 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Wed, 8 Oct 2025 01:58:35 +0000 Subject: [PATCH 5/5] add ep; rename tp args; add stop and resume; Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../compute_hidden_states_trtllm.py | 49 +++++++++++++++++-- 1 file changed, 45 insertions(+), 4 deletions(-) diff --git a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py index 3fd1e2eeb..d25bd3dac 100644 --- a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py +++ b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py @@ -93,10 +93,30 @@ def parse_args() -> argparse.Namespace: help="""Whether to use CUDA graph.""", ) parser.add_argument( - "--tp-size-per-dp", + "--tp", type=int, - default=2, - help="""Tensor parallel size per data parallel.""", + default=1, + help="""tensor_parallel_size for TRTLLM.""", + ) + # moe_ep * moe_tp * moe_cp should be equal to tp + # REF: https://nvidia.github.io/TensorRT-LLM/advanced/expert-parallelism.html + parser.add_argument( + "--moe_ep", + type=int, + default=1, + help="""moe_expert_parallel_size for TRTLLM.""", + ) + parser.add_argument( + "--moe_tp", + type=int, + default=1, + help="""moe_tensor_parallel_size for TRTLLM.""", + ) + parser.add_argument( + "--moe_cp", + type=int, + default=1, + help="""moe_cluster_parallel_size for TRTLLM.""", ) return parser.parse_args() @@ -109,6 +129,24 @@ def main(args: argparse.Namespace) -> None: all_conversations.extend([json.loads(line) for line in f if line.strip()]) print("Loaded", len(all_conversations), "conversations from", args.input_file) + # Remove conversations whose output file already exists + filtered_conversations = [] + for entry in all_conversations: + conversation_id = entry.get("conversation_id", None) + if conversation_id is None: + filtered_conversations.append(entry) + continue + output_file = args.output_dir / f"{conversation_id}.pt" + if output_file.exists(): + continue + filtered_conversations.append(entry) + print( + "Removed", + len(all_conversations) - len(filtered_conversations), + "conversations due to existing output files", + ) + all_conversations = filtered_conversations + # Get model config and tokenizer model_config = AutoConfig.from_pretrained(args.model) num_hidden_layers = getattr(model_config, "num_hidden_layers", None) @@ -128,7 +166,10 @@ def main(args: argparse.Namespace) -> None: "max_batch_size": 16, "kv_cache_config": KvCacheConfig(enable_block_reuse=False, free_gpu_memory_fraction=0.5), "enable_chunked_prefill": False, - "tensor_parallel_size": args.tp_size_per_dp, + "tensor_parallel_size": args.tp, + "moe_expert_parallel_size": args.moe_ep, + "moe_tensor_parallel_size": args.moe_tp, + "moe_cluster_parallel_size": args.moe_cp, } spec_config = { "output_directory": str(args.output_dir),