examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py
@@ -0,0 +1,234 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Extract hidden states from an HF-compatible LLM."""

import os

os.environ["TLLM_LOG_LEVEL"] = "error"
import argparse
import asyncio
import json
from pathlib import Path

import torch
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, SaveHiddenStatesDecodingConfig
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer

REMOVE_THINK_CHAT_TEMPLATE = (
    "{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}"
)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="""Collect hidden states from conversations
        by running full conversations through a Hugging Face model."""
    )

    ## Model & Generation Parameters ##
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Name of the served model.",
    )

    ## Client Parameters ##
    parser.add_argument(
        "--max-seq-len",
        type=int,
        default=3072,
        help="""Maximum number of tokens in a conversation. Longer conversations will be skipped.
        Defaults to 3072 tokens.""",
    )

    ## I/O Parameters ##
    parser.add_argument(
        "--input-file",
        type=Path,
        required=True,
        help="""Path to the input `jsonl` file containing conversations.
        Each entry must have a unique `conversation_id` field and a `conversations` field
        containing a list of messages.""",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        required=True,
        help="""Root directory in which to save the hidden states.
        The data will be saved as a torch (`.pt`) dump file for each conversation.""",
    )
    parser.add_argument(
        "--debug-max-num-conversations",
        type=int,
        default=None,
        help="""For debugging purposes, limit the number of conversations processed.
        Default is None, meaning no limit.""",
    )
    parser.add_argument(
        "--dp-rank",
        type=int,
        default=0,
        help="""Data parallel rank.""",
    )
    parser.add_argument(
        "--use-cuda-graph",
        type=bool,
        default=True,
        help="""Whether to use CUDA graph.""",
    )
Comment on lines +90 to +94

⚠️ Potential issue | 🟠 Major

Fix boolean argument handling.

Using type=bool for boolean flags in argparse is problematic. Any non-empty string (including "False" or "0") will be interpreted as True. Use action='store_true' or action='store_false' instead.

Apply this diff:

     parser.add_argument(
         "--use-cuda-graph",
-        type=bool,
-        default=True,
+        action='store_true',
+        default=False,
         help="""Whether to use CUDA graph.""",
     )

Or if you want it enabled by default:

     parser.add_argument(
-        "--use-cuda-graph",
-        type=bool,
-        default=True,
-        help="""Whether to use CUDA graph.""",
+        "--no-cuda-graph",
+        action='store_false',
+        dest='use_cuda_graph',
+        help="""Disable CUDA graph (enabled by default).""",
     )
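
To make the failure mode concrete, here is a minimal, self-contained repro, assuming nothing beyond the standard library (the parser below is a throwaway, not the script's real one):

import argparse

# argparse applies bool() to the raw CLI string, and bool() of any
# non-empty string is True, including "False" and "0".
parser = argparse.ArgumentParser()
parser.add_argument("--use-cuda-graph", type=bool, default=True)

print(parser.parse_args([]).use_cuda_graph)  # True (the default)
print(parser.parse_args(["--use-cuda-graph", "False"]).use_cuda_graph)  # True, not False
print(parser.parse_args(["--use-cuda-graph", ""]).use_cuda_graph)  # False: only the empty string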

    parser.add_argument(
        "--tp-size-per-dp",
        type=int,
        default=2,
        help="""Tensor parallel size per data parallel.""",
    )

    return parser.parse_args()


def main(args: argparse.Namespace) -> None:
    # Load conversations
    all_conversations = []
    with args.input_file.open("r", encoding="utf-8") as f:
        all_conversations.extend([json.loads(line) for line in f if line.strip()])
    print("Loaded", len(all_conversations), "conversations from", args.input_file)

    # Get model config and tokenizer
    model_config = AutoConfig.from_pretrained(args.model)
    num_hidden_layers = getattr(model_config, "num_hidden_layers", None)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")

    # Set up LLM
    llm_common_config = {
        "model": args.model,
        "attn_backend": "TRTLLM",
        "disable_overlap_scheduler": False,
        "cuda_graph_config": CudaGraphConfig(batch_sizes=[1, 2, 4])
        if args.use_cuda_graph
        else None,
        "max_batch_size": 16,
        "kv_cache_config": KvCacheConfig(enable_block_reuse=False, free_gpu_memory_fraction=0.5),
        "enable_chunked_prefill": False,
        "tensor_parallel_size": args.tp_size_per_dp,
    }
    spec_config = {
        "output_directory": str(args.output_dir),
        "write_interval": 1,
        "file_prefix": f"dp_{args.dp_rank}",
        "eagle3_layers_to_capture": {1, num_hidden_layers // 2 - 1, num_hidden_layers - 4},
    }
    sampling_params = SamplingParams(max_tokens=32, temperature=0)

    llm_spec = LLM(
        **llm_common_config, speculative_config=SaveHiddenStatesDecodingConfig(**spec_config)
    )

    args.output_dir.mkdir(parents=True, exist_ok=True)
    num_skipped_too_long = 0
    num_invalid = 0
    num_success = 0
    num_total_conversations = min(
        len(all_conversations), args.debug_max_num_conversations or len(all_conversations)
    )
    pbar = tqdm(total=num_total_conversations, desc=f"DP#{args.dp_rank} Processing conversations")

    def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id: int):
        """Post-process the TRT-LLM dump file into the same format as the HF dump:

        1. Remove the `id` field and replace it with `conversation_id`
        2. Rename the `hidden_state` field to `hidden_states`
        3. Convert the list of length 1 to a dict
        4. Rename the file to `<conversation_id>.pt`
        """
        with open(trtllm_dumped_file, "rb") as f:
            trtllm_dumped = torch.load(f)
        assert isinstance(trtllm_dumped, list) and len(trtllm_dumped) == 1, (
            "TRTLLM dump should be a list with one element"
        )
        assert (
            isinstance(trtllm_dumped[0], dict)
            and "id" in trtllm_dumped[0]
            and "hidden_state" in trtllm_dumped[0]
        ), "TRTLLM dump should have an 'id' and a 'hidden_state' field"
        trtllm_dumped = trtllm_dumped[0]
        trtllm_dumped.pop("id")
        trtllm_dumped["conversation_id"] = conversation_id
        trtllm_dumped["hidden_states"] = trtllm_dumped.pop("hidden_state")
        output_file = args.output_dir / f"{conversation_id}.pt"
        with open(output_file, "wb") as f:
            torch.save(trtllm_dumped, f)

        if trtllm_dumped_file.exists():
            trtllm_dumped_file.unlink()
Comment on lines +195 to +221

⚠️ Potential issue | 🟡 Minor

Correct the type hint for trtllm_dumped_file parameter.

The type hint declares trtllm_dumped_file: str, but the function calls .exists() (line 220) and .unlink() (line 221), which are Path methods. At the call site (line 229), a Path object is passed. Update the type hint to Path for consistency.

Apply this diff:

-    def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id: int):
+    def _post_process_trtllm_dumped(trtllm_dumped_file: Path, conversation_id: int):
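
As a minimal sketch of why the annotation matters (the file name below is hypothetical, used only for illustration): str carries none of the filesystem methods the function calls.

from pathlib import Path

dump_as_path = Path("/tmp/dp_0_1.pt")  # what the call site actually passes
dump_as_str = "/tmp/dp_0_1.pt"  # what the annotation claims

print(dump_as_path.exists())  # fine: Path provides .exists() and .unlink()
try:
    dump_as_str.exists()  # type: ignore[attr-defined]
except AttributeError as err:
    print(err)  # 'str' object has no attribute 'exists'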


    async def dump_hidden_states(idx: int, conversation_id: int, input_ids: list[int]):
        nonlocal num_success
        await llm_spec.generate_async(input_ids, sampling_params)
        # The TRT-LLM API numbers dump files starting from 1
        # ref: https://github.com/NVIDIA/TensorRT-LLM/pull/7012
        trtllm_dumped_file = args.output_dir / f"{spec_config['file_prefix']}_{idx + 1}.pt"
        _post_process_trtllm_dumped(trtllm_dumped_file, conversation_id)
        num_success += 1
        pbar.update(1)

    async def submit_generates():
        nonlocal num_skipped_too_long
        nonlocal num_invalid
        tasks = []
        for idx, entry in enumerate(all_conversations[: args.debug_max_num_conversations]):
            conversation_id = entry.get("conversation_id", "{:08d}".format(idx))

            conversations = entry["conversations"]
            if not conversations or not isinstance(conversations, list):
                num_invalid += 1
                continue

            input_ids = tokenizer.apply_chat_template(conversations, add_generation_prompt=False)
            num_input_tokens = (
                input_ids.shape[1] if isinstance(input_ids, torch.Tensor) else len(input_ids)
            )
            if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len:
                num_skipped_too_long += 1
                continue

            tasks.append(dump_hidden_states(idx, conversation_id, input_ids))
        await asyncio.gather(*tasks)

    asyncio.run(submit_generates())
    llm_spec.shutdown()
    print("LLM shutdown")

    if num_skipped_too_long > 0:
        print(f"Skipped {num_skipped_too_long} conversations due to length constraints.")
    if num_invalid > 0:
        print(f"Skipped {num_invalid} invalid conversations without proper fields.")

    if num_success == num_total_conversations:
        print(f"Successfully processed all {num_success} conversations.")
    else:
        print(
            f"Successfully processed {num_success} out of {num_total_conversations} conversations."
        )


if __name__ == "__main__":
    cli_args = parse_args()
    main(cli_args)
@@ -24,10 +24,11 @@

 INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
 OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
+DP_SIZE=8

-split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part-
+split -n l/$DP_SIZE --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part-

-for i in $(seq 0 7)
+for i in $(seq 0 $((DP_SIZE-1)))
 do
 CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR &
 done
examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh
@@ -0,0 +1,25 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example usage of the script to compute the hidden states for a conversation dataset
# This script computes hidden states using TensorRT-LLM and saves them to
# the specified output directory.

export TLLM_LOG_LEVEL="error";
python3 collect_hidden_states/compute_hidden_states_trtllm.py \
--model meta-llama/Llama-3.2-1B-Instruct \
--input-file synthetic_conversations/daring-anteater.jsonl \
--output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
Comment on lines +1 to +24

🛠️ Refactor suggestion | 🟠 Major

Add a shebang to define the interpreter.

Without a shebang the script runs under whatever /bin/sh points to, which can diverge from the Bash semantics used here and also trips ShellCheck (SC2148). Please make Bash explicit.

+#!/usr/bin/env bash
+


examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh
@@ -0,0 +1,44 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example usage of the script to compute the hidden states for a conversation dataset
# This script computes hidden states using TensorRT-LLM and saves them to
# the specified output directory. It does so in a data-parallel manner across 8 GPUs, by splitting
# the input file into 8 parts and running 8 processes in parallel, one on each GPU.

# Note: depending on the write-throughput of the destination disk, this is not guaranteed
# to yield a speed improvement compared to running the model-parallel version. Consider
# benchmarking on a smaller dataset before launching a large run.

INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
DP_SIZE=8
MODEL=meta-llama/Llama-3.2-1B-Instruct
export TLLM_LOG_LEVEL="error";

Comment on lines +1 to +30

🛠️ Refactor suggestion | 🟠 Major

Add a shebang to define the interpreter.

Same as the single-run script: make Bash explicit so the runtime shell is unambiguous and SC2148 is silenced.

+#!/usr/bin/env bash
+

split -n l/$DP_SIZE --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part-

for i in $(seq 0 $((DP_SIZE-1)))
do

export CUDA_VISIBLE_DEVICES=$i; python3 collect_hidden_states/compute_hidden_states_trtllm.py --model $MODEL --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR --dp-rank $i &

# #On SLURM:
# PORT=$((10012 + i)); export TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR="tcp://127.0.0.1:$PORT"; trtllm-llmapi-launch python3 collect_hidden_states/compute_hidden_states_trtllm.py --model $MODEL --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR --dp-rank $i

done
Comment on lines +36 to +41

⚠️ Potential issue | 🟠 Major

Fix suffix formatting for DP ranks ≥10.

Line 36 hardcodes /tmp/part-0${i}.jsonl, which works only for single-digit ranks; once DP_SIZE reaches 10 the actual file is /tmp/part-10.jsonl, so the launch fails. Format the suffix with printf to stay aligned with split output.

-for i in $(seq 0 $((DP_SIZE-1)))
-do
-
-export CUDA_VISIBLE_DEVICES=$i;  python3 collect_hidden_states/compute_hidden_states_trtllm.py --model $MODEL --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR --dp-rank $i &
+for i in $(seq 0 $((DP_SIZE-1)))
+do
+  suffix=$(printf "%02d" "$i")
+  CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_trtllm.py \
+    --model "$MODEL" \
+    --input-file "/tmp/part-${suffix}.jsonl" \
+    --output-dir "$OUTPUT_DIR" \
+    --dp-rank "$i" &
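
The padding rule is easy to verify in isolation. A small sketch, assuming split's default two-character numeric suffixes (the part_path helper is hypothetical, added only for illustration):

def part_path(rank: int) -> str:
    # split -d pads numeric suffixes to two digits by default: part-00, part-01, ..., part-10
    return f"/tmp/part-{rank:02d}.jsonl"

assert part_path(3) == "/tmp/part-03.jsonl"  # matches the hardcoded "part-0${i}" for ranks < 10
assert part_path(12) == "/tmp/part-12.jsonl"  # "part-012" (from "part-0${i}") would not exist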

wait

rm /tmp/part-*.jsonl