Merged
29 changes: 22 additions & 7 deletions examples/speculative_decoding/README.md
@@ -82,9 +82,22 @@ The saved modelopt checkpoint is similar in architecture to HF models. It can be

## Training Draft Model with Offline Base Model

For large models, you can export intermediate hidden states to disk and train only the draft model. This significantly reduces GPU memory requirements, but requires several to tens of terabytes of storage depending on dataset size.
For large models, you can export intermediate hidden states to disk and train only the draft model. This significantly reduces GPU memory requirements, but requires several to tens of terabytes of disk storage depending on dataset size.

First, dump the base model's hidden states with the following command:
### Dumpping Hidden States to Disk

⚠️ Potential issue | 🟡 Minor

Fix typo in heading.

"Dumpping" should be "Dumping".

Apply this diff:

-### Dumpping Hidden States to Disk
+### Dumping Hidden States to Disk


We support two backends for generating base model hidden states. For better effciency, it is recommended to use TRT-LLM:

⚠️ Potential issue | 🟡 Minor

Fix typo.

"effciency" should be "efficiency".

Apply this diff:

-We support two backends for generating base model hidden states. For better effciency, it is recommended to use TRT-LLM:
+We support two backends for generating base model hidden states. For better efficiency, it is recommended to use TRT-LLM:


```bash
python collect_hidden_states/compute_hidden_states_trtllm.py \
--model $BASE_MODEL \
--input-file Daring-Anteater/train.jsonl \
--output-dir $HIDDEN_STATES_DIR
```
Comment on lines +92 to +96

⚠️ Potential issue | 🟡 Minor

Remove trailing space after line continuation.

Line 93 has a trailing space after the backslash, which breaks bash line continuation. The backslash must be immediately followed by a newline.

Apply this diff:

 python collect_hidden_states/compute_hidden_states_trtllm.py \
-            --model $BASE_MODEL \ 
+            --model $BASE_MODEL \
             --input-file Daring-Anteater/train.jsonl \
             --output-dir $HIDDEN_STATES_DIR


**NOTE**: A TRT-LLM installation is required for the above command.

Alternatively, you can generate the same hidden states with HF:

```bash
python collect_hidden_states/compute_hidden_states_hf.py \
@@ -93,9 +106,11 @@ python collect_hidden_states/compute_hidden_states_hf.py \
--output-dir $HIDDEN_STATES_DIR
```

See [`run_hf_compute_hiddens_dp.sh`](./collect_hidden_states/run_hf_compute_hiddens_dp.sh) for a simple example using data parallelism (DP) to accelerate hidden state generation.
**NOTE**: See [`run_hf_compute_hiddens_dp.sh`](./collect_hidden_states/run_hf_compute_hiddens_dp.sh) and [`run_trtllm_compute_hiddens_dp.sh`](./collect_hidden_states/run_trtllm_compute_hiddens_dp.sh) for simple examples of using data parallelism (DP) to accelerate hidden state generation.
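
For illustration only, here is a hypothetical Python version of the DP launch pattern those scripts use: shard the input `jsonl`, pin one GPU per worker, and run one collection process per shard. The paths, model name, and GPU count below are placeholders, not values taken from this PR; the repo scripts above remain the reference.

```python
# Hypothetical DP launcher sketch; paths, model, and GPU count are placeholders.
import os
import subprocess
from pathlib import Path

INPUT_FILE = Path("Daring-Anteater/train.jsonl")  # placeholder input file
OUTPUT_DIR = Path("/tmp/hidden_states")           # placeholder output directory
DP_SIZE = 8                                       # one shard per GPU

lines = INPUT_FILE.read_text().splitlines()
procs = []
for rank in range(DP_SIZE):
    # Write an interleaved shard of the dataset for this rank.
    shard_file = Path(f"/tmp/part-{rank:02d}.jsonl")
    shard_file.write_text("\n".join(lines[rank::DP_SIZE]) + "\n")
    # Pin one GPU per worker and launch the collection script on the shard.
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(rank))
    procs.append(subprocess.Popen(
        ["python3", "collect_hidden_states/compute_hidden_states_hf.py",
         "--model", "meta-llama/Llama-3.2-1B-Instruct",
         "--input-file", str(shard_file),
         "--output-dir", str(OUTPUT_DIR)],
        env=env,
    ))
for p in procs:
    p.wait()  # wait for all shards to finish
```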

### Train Draft Model with Dumped Hidden States

Then, train draft model with `--offline-data` argument:
Once we finish dumping hidden states, launch offline training with an extra `--offline-data` argument:

```bash
./launch_train.sh --model $BASE_MODEL \
@@ -109,13 +124,13 @@ Then, train draft model with `--offline-data` argument:

## Model Validation

After training draft model, we can evaluate the saved modelopt checkpoint on MT-bench by:
For online training checkpoints, we can run in-framework evaluation on MT-bench:

```bash
python ar_validate.py --model_path $OUTPUT_DIR
python ar_validate.py --model_path $ONLINE_CKPT
```

Alternatively, we can export the checkpoint and run evaluation on serving frameworks. See sections below.
Offline checkpoints does not support this evaluation due to missing of base model modules. To evaluate offline checkpoint, please export first and evaluate with serving frameworks.

⚠️ Potential issue | 🟡 Minor

Fix grammar.

"Offline checkpoints does not support" should be "Offline checkpoints do not support" (plural subject requires plural verb).

Apply this diff:

-Offline checkpoints does not support this evaluation due to missing of base model modules. To evaluate offline checkpoint, please export first and evaluate with serving frameworks.
+Offline checkpoints do not support this evaluation due to missing base model modules. To evaluate offline checkpoint, please export first and evaluate with serving frameworks.


## Export

examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py
@@ -0,0 +1,275 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Extract hidden states from an HF-compatible LLM."""

import os

os.environ["TLLM_LOG_LEVEL"] = "error"
import argparse
import asyncio
import json
from pathlib import Path

import torch
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, SaveHiddenStatesDecodingConfig
from tqdm import tqdm as tqdm
from transformers import AutoConfig, AutoTokenizer

REMOVE_THINK_CHAT_TEMPLATE = (
"{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}"
)


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="""Collect hidden states from conversations
by running full conversations through a Hugging Face model."""
)

## Model & Generation Parameters ##
parser.add_argument(
"--model",
type=str,
required=True,
help="Name of the served model.",
)

## Client Parameters ##
parser.add_argument(
"--max-seq-len",
type=int,
default=3072,
help="""Maximum number of tokens in a conversation. Longer conversations will be skipped.
Defaults to 3072 tokens.""",
)

## I/O Parameters ##
parser.add_argument(
"--input-file",
type=Path,
required=True,
help="""Path to the input `jsonl` file containing conversations.
Each entry must have a unique `conversation_id` field and a `conversations` field
containing a list of messages.""",
)
parser.add_argument(
"--output-dir",
type=Path,
required=True,
help="""Root directory in which to save the hidden states.
The data will be saved as a torch (`.pt`) dump file for each conversation.""",
)
parser.add_argument(
"--debug-max-num-conversations",
type=int,
default=None,
help="""For debugging purposes, limit the number of conversations processed.
Default is None, meaning no limit.""",
)
parser.add_argument(
"--dp-rank",
type=int,
default=0,
help="""Data parallel rank.""",
)
parser.add_argument(
"--use-cuda-graph",
type=bool,
default=True,
help="""Whether to use CUDA graph.""",
)
Comment on lines +90 to +94

⚠️ Potential issue | 🟠 Major

Fix boolean argument handling.

Using type=bool for boolean flags in argparse is problematic. Any non-empty string (including "False" or "0") will be interpreted as True. Use action='store_true' or action='store_false' instead.

Apply this diff:

     parser.add_argument(
         "--use-cuda-graph",
-        type=bool,
-        default=True,
+        action='store_true',
+        default=False,
         help="""Whether to use CUDA graph.""",
     )

Or if you want it enabled by default:

     parser.add_argument(
-        "--use-cuda-graph",
-        type=bool,
-        default=True,
-        help="""Whether to use CUDA graph.""",
+        "--no-cuda-graph",
+        action='store_false',
+        dest='use_cuda_graph',
+        help="""Disable CUDA graph (enabled by default).""",
     )
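
As a quick standalone illustration of the pitfall (a sketch for this review, not code from the PR): `bool("False")` is `True`, so the flag can never be switched off from the CLI, whereas a `store_false` flag behaves as expected.

```python
import argparse

# With type=bool, argparse passes the raw string to bool(), so any
# non-empty value -- including "False" or "0" -- comes back as True.
buggy = argparse.ArgumentParser()
buggy.add_argument("--use-cuda-graph", type=bool, default=True)
print(buggy.parse_args(["--use-cuda-graph", "False"]).use_cuda_graph)  # True

# A store_false flag keeps the feature on by default but can actually
# be disabled from the command line.
fixed = argparse.ArgumentParser()
fixed.add_argument("--no-cuda-graph", action="store_false", dest="use_cuda_graph",
                   help="Disable CUDA graph (enabled by default).")
print(fixed.parse_args([]).use_cuda_graph)                   # True
print(fixed.parse_args(["--no-cuda-graph"]).use_cuda_graph)  # False
```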

parser.add_argument(
"--tp",
type=int,
default=1,
help="""tensor_parallel_size for TRTLLM.""",
)
# moe_ep * moe_tp * moe_cp should be equal to tp
# REF: https://nvidia.github.io/TensorRT-LLM/advanced/expert-parallelism.html
parser.add_argument(
"--moe_ep",
type=int,
default=1,
help="""moe_expert_parallel_size for TRTLLM.""",
)
parser.add_argument(
"--moe_tp",
type=int,
default=1,
help="""moe_tensor_parallel_size for TRTLLM.""",
)
parser.add_argument(
"--moe_cp",
type=int,
default=1,
help="""moe_cluster_parallel_size for TRTLLM.""",
)

return parser.parse_args()


def main(args: argparse.Namespace) -> None:
# Load conversations
all_conversations = []
with args.input_file.open("r", encoding="utf-8") as f:
all_conversations.extend([json.loads(line) for line in f if line.strip()])
print("Loaded", len(all_conversations), "conversations from", args.input_file)

# Remove conversations whose output file already exists
filtered_conversations = []
for entry in all_conversations:
conversation_id = entry.get("conversation_id", None)
if conversation_id is None:
filtered_conversations.append(entry)
continue
output_file = args.output_dir / f"{conversation_id}.pt"
if output_file.exists():
continue
filtered_conversations.append(entry)
print(
"Removed",
len(all_conversations) - len(filtered_conversations),
"conversations due to existing output files",
)
all_conversations = filtered_conversations

# Get model config and tokenizer
model_config = AutoConfig.from_pretrained(args.model)
num_hidden_layers = getattr(model_config, "num_hidden_layers", None)
tokenizer = AutoTokenizer.from_pretrained(args.model)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")

# Set up LLM
llm_common_config = {
"model": args.model,
"attn_backend": "TRTLLM",
"disable_overlap_scheduler": False,
"cuda_graph_config": CudaGraphConfig(batch_sizes=[1, 2, 4])
if args.use_cuda_graph
else None,
"max_batch_size": 16,
"kv_cache_config": KvCacheConfig(enable_block_reuse=False, free_gpu_memory_fraction=0.5),
"enable_chunked_prefill": False,
"tensor_parallel_size": args.tp,
"moe_expert_parallel_size": args.moe_ep,
"moe_tensor_parallel_size": args.moe_tp,
"moe_cluster_parallel_size": args.moe_cp,
}
spec_config = {
"output_directory": str(args.output_dir),
"write_interval": 1,
"file_prefix": f"dp_{args.dp_rank}",
"eagle3_layers_to_capture": {1, num_hidden_layers // 2 - 1, num_hidden_layers - 4},
}
sampling_params = SamplingParams(max_tokens=32, temperature=0)

llm_spec = LLM(
**llm_common_config, speculative_config=SaveHiddenStatesDecodingConfig(**spec_config)
)

args.output_dir.mkdir(parents=True, exist_ok=True)
num_skipped_too_long = 0
num_invalid = 0
num_success = 0
num_total_conversations = min(
len(all_conversations), args.debug_max_num_conversations or len(all_conversations)
)
pbar = tqdm(total=num_total_conversations, desc=f"DP#{args.dp_rank} Processing conversations")

def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id: int):
"""Post-process the TRTLLM dumped file to same format as HF dumped:
1. Remove id field, replace it with conversation_id
2. Rename hidden_state field to hidden_states
3. From list of length 1 to dict
4. Rename file to conversation_id.pt
"""
with open(trtllm_dumped_file, "rb") as f:
trtllm_dumped = torch.load(f)
assert isinstance(trtllm_dumped, list) and len(trtllm_dumped) == 1, (
"TRTLLM dumped should be a list with one element"
)
assert (
isinstance(trtllm_dumped[0], dict)
and "id" in trtllm_dumped[0]
and "hidden_state" in trtllm_dumped[0]
), "TRTLLM dumped should have an 'id' and 'hidden_states' field"
trtllm_dumped = trtllm_dumped[0]
trtllm_dumped.pop("id")
trtllm_dumped["conversation_id"] = conversation_id
trtllm_dumped["hidden_states"] = trtllm_dumped.pop("hidden_state")
output_file = args.output_dir / f"{conversation_id}.pt"
with open(output_file, "wb") as f:
torch.save(trtllm_dumped, f)

if trtllm_dumped_file.exists():
trtllm_dumped_file.unlink()
Comment on lines +195 to +221

⚠️ Potential issue | 🟡 Minor

Correct the type hint for trtllm_dumped_file parameter.

The type hint declares trtllm_dumped_file: str, but the function calls .exists() (line 220) and .unlink() (line 221), which are Path methods. At the call site (line 229), a Path object is passed. Update the type hint to Path for consistency.

Apply this diff:

-    def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id: int):
+    def _post_process_trtllm_dumped(trtllm_dumped_file: Path, conversation_id: int):


async def dump_hidden_states(idx: int, conversation_id: int, input_ids: list[int]):
nonlocal num_success
await llm_spec.generate_async(input_ids, sampling_params)
# TRTLLM API name files starts from 1
# ref:https://github.com/NVIDIA/TensorRT-LLM/pull/7012
trtllm_dumped_file = args.output_dir / f"{spec_config['file_prefix']}_{idx + 1}.pt"
_post_process_trtllm_dumped(trtllm_dumped_file, conversation_id)
num_success += 1
pbar.update(1)

async def submit_generates():
nonlocal num_skipped_too_long
nonlocal num_invalid
tasks = []
for idx, entry in enumerate(all_conversations[: args.debug_max_num_conversations]):
conversation_id = entry.get("conversation_id", "{:08d}".format(idx))

conversations = entry["conversations"]
if not conversations or not isinstance(conversations, list):
num_invalid += 1
continue

input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)
num_input_tokens = (
input_ids.shape[1] if isinstance(input_ids, torch.Tensor) else len(input_ids)
)
if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len:
num_skipped_too_long += 1
continue

tasks.append(dump_hidden_states(idx, conversation_id, input_ids))
await asyncio.gather(*tasks)

asyncio.run(submit_generates())
llm_spec.shutdown()
print("LLM shutdown")

if num_skipped_too_long > 0:
print(f"Skipped {num_skipped_too_long} conversations due to length constraints.")
if num_invalid > 0:
print(f"Skipped {num_invalid} invalid conversations without proper fields.")

if num_success == num_total_conversations:
print(f"Successfully processed all {num_success} conversations.")
else:
print(
f"Successfully processed {num_success} out of {num_total_conversations} conversations."
)


if __name__ == "__main__":
cli_args = parse_args()
main(cli_args)
examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh
@@ -24,10 +24,11 @@

INPUT_FILE=synthetic_conversations/daring-anteater.jsonl
OUTPUT_DIR=/mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
DP_SIZE=8

split -n l/8 --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part-
split -n l/$DP_SIZE --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FILE /tmp/part-

for i in $(seq 0 7)
for i in $(seq 0 $((DP_SIZE-1)))
do
CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR &
done
examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh
@@ -0,0 +1,25 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example usage of the script to compute the hidden states for a conversation dataset
# This script computes hidden states using TensorRT-LLM and saves them to
# the specified output directory.

export TLLM_LOG_LEVEL="error";
python3 collect_hidden_states/compute_hidden_states_trtllm.py \
--model meta-llama/Llama-3.2-1B-Instruct \
--input-file synthetic_conversations/daring-anteater.jsonl \
--output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/
Comment on lines +1 to +24

🛠️ Refactor suggestion | 🟠 Major

Add a shebang to define the interpreter.

Without a shebang the script runs under whatever /bin/sh points to, which can diverge from the Bash semantics used here and also trips ShellCheck (SC2148). Please make Bash explicit.

+#!/usr/bin/env bash
+
🧰 Tools
🪛 Shellcheck (0.11.0)

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)


