Skip to content

Commit 40a7d24

Browse files
authored
[Eagle Offline] multinode support for hidden states dumper (#422)
Signed-off-by: h-guo18 <[email protected]>
1 parent 08fb23f commit 40a7d24

File tree

3 files changed

+103
-43
lines changed

3 files changed

+103
-43
lines changed

examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py

Lines changed: 56 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@
2020
os.environ["TLLM_LOG_LEVEL"] = "error"
2121
import argparse
2222
import asyncio
23-
import json
2423
from pathlib import Path
2524

2625
import torch
26+
from datasets import load_dataset
2727
from tensorrt_llm import LLM, SamplingParams
2828
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, SaveHiddenStatesDecodingConfig
2929
from tqdm import tqdm as tqdm
@@ -59,12 +59,10 @@ def parse_args() -> argparse.Namespace:
5959

6060
## I/O Parameters ##
6161
parser.add_argument(
62-
"--input-file",
62+
"--input-data",
6363
type=Path,
6464
required=True,
65-
help="""Path to the input `jsonl` file containing conversations.
66-
Each entry must have a unique `conversation_id` field and a `conversations` field
67-
containing a list of messages.""",
65+
help="""Path to the `jsonl` file or directory containing `jsonl` files.""",
6866
)
6967
parser.add_argument(
7068
"--output-dir",
@@ -84,7 +82,13 @@ def parse_args() -> argparse.Namespace:
8482
"--dp-rank",
8583
type=int,
8684
default=0,
87-
help="""Data parallel rank.""",
85+
help="""Data parallel rank. TASK_ID on SLURM.""",
86+
)
87+
parser.add_argument(
88+
"--dp-world-size",
89+
type=int,
90+
default=1,
91+
help="""Data parallel world size. Number of tasks on SLURM.""",
8892
)
8993
parser.add_argument(
9094
"--use-cuda-graph",
@@ -101,21 +105,21 @@ def parse_args() -> argparse.Namespace:
101105
# moe_ep * moe_tp * moe_cp should be equal to tp
102106
# REF: https://nvidia.github.io/TensorRT-LLM/advanced/expert-parallelism.html
103107
parser.add_argument(
104-
"--moe_ep",
108+
"--moe-ep",
105109
type=int,
106-
default=1,
110+
default=None,
107111
help="""moe_expert_parallel_size for TRTLLM.""",
108112
)
109113
parser.add_argument(
110-
"--moe_tp",
114+
"--moe-tp",
111115
type=int,
112-
default=1,
116+
default=None,
113117
help="""moe_tensor_parallel_size for TRTLLM.""",
114118
)
115119
parser.add_argument(
116-
"--moe_cp",
120+
"--moe-cp",
117121
type=int,
118-
default=1,
122+
default=None,
119123
help="""moe_cluster_parallel_size for TRTLLM.""",
120124
)
121125

@@ -124,28 +128,43 @@ def parse_args() -> argparse.Namespace:
124128

125129
def main(args: argparse.Namespace) -> None:
126130
# Load conversations
127-
all_conversations = []
128-
with args.input_file.open("r", encoding="utf-8") as f:
129-
all_conversations.extend([json.loads(line) for line in f if line.strip()])
130-
print("Loaded", len(all_conversations), "conversations from", args.input_file)
131-
132-
# Remove conversations whose output file already exists
133-
filtered_conversations = []
134-
for entry in all_conversations:
135-
conversation_id = entry.get("conversation_id", None)
136-
if conversation_id is None:
137-
filtered_conversations.append(entry)
138-
continue
131+
if args.input_data.is_file() and str(args.input_data).endswith(".jsonl"):
132+
dataset = load_dataset("json", data_files=str(args.input_data), split="train")
133+
elif args.input_data.is_dir():
134+
dataset = load_dataset(
135+
"json", data_files={"train": f"{args.input_data}/*.jsonl"}, split="train"
136+
)
137+
else:
138+
raise ValueError(
139+
f"input_data must be a .jsonl file or directory containing .jsonl files, got: {args.input_data}"
140+
)
141+
print(f"Loaded {len(dataset)} conversations from {args.input_data}")
142+
143+
# Shard data
144+
if args.dp_world_size > 1:
145+
dataset = dataset.shard(num_shards=args.dp_world_size, index=args.dp_rank)
146+
print(
147+
f"Sharded dataset to {len(dataset)} conversations for DP#{args.dp_rank}/{args.dp_world_size}"
148+
)
149+
150+
# Remove already dumped conversations
151+
def keep_conversation(entry):
152+
conversation_id = entry.get("conversation_id", entry.get("uuid", None))
153+
assert conversation_id is not None, "conversation_id is required"
139154
output_file = args.output_dir / f"{conversation_id}.pt"
140-
if output_file.exists():
141-
continue
142-
filtered_conversations.append(entry)
155+
return not output_file.exists()
156+
157+
original_num = len(dataset)
158+
dataset = dataset.filter(keep_conversation)
143159
print(
144160
"Removed",
145-
len(all_conversations) - len(filtered_conversations),
161+
original_num - len(dataset),
146162
"conversations due to existing output files",
147163
)
148-
all_conversations = filtered_conversations
164+
165+
# For debugging
166+
if args.debug_max_num_conversations is not None:
167+
dataset = dataset.select(range(args.debug_max_num_conversations))
149168

150169
# Get model config and tokenizer
151170
model_config = AutoConfig.from_pretrained(args.model)
@@ -187,10 +206,7 @@ def main(args: argparse.Namespace) -> None:
187206
num_skipped_too_long = 0
188207
num_invalid = 0
189208
num_success = 0
190-
num_total_conversations = min(
191-
len(all_conversations), args.debug_max_num_conversations or len(all_conversations)
192-
)
193-
pbar = tqdm(total=num_total_conversations, desc=f"DP#{args.dp_rank} Processing conversations")
209+
pbar = tqdm(total=len(dataset), desc=f"DP#{args.dp_rank} Processing conversations")
194210

195211
def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id: int):
196212
"""Post-process the TRTLLM dumped file to same format as HF dumped:
@@ -234,15 +250,17 @@ async def submit_generates():
234250
nonlocal num_skipped_too_long
235251
nonlocal num_invalid
236252
tasks = []
237-
for idx, entry in enumerate(all_conversations[: args.debug_max_num_conversations]):
238-
conversation_id = entry.get("conversation_id", "{:08d}".format(idx))
253+
for idx, entry in enumerate(dataset):
254+
conversation_id = entry.get("conversation_id", entry.get("uuid"))
239255

240256
conversations = entry["conversations"]
241257
if not conversations or not isinstance(conversations, list):
242258
num_invalid += 1
243259
continue
244260

245-
input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)
261+
input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)[
262+
:256
263+
]
246264
num_input_tokens = (
247265
input_ids.shape[1] if isinstance(input_ids, torch.Tensor) else len(input_ids)
248266
)
@@ -262,12 +280,10 @@ async def submit_generates():
262280
if num_invalid > 0:
263281
print(f"Skipped {num_invalid} invalid conversations without proper fields.")
264282

265-
if num_success == num_total_conversations:
283+
if num_success == len(dataset):
266284
print(f"Successfully processed all {num_success} conversations.")
267285
else:
268-
print(
269-
f"Successfully processed {num_success} out of {num_total_conversations} conversations."
270-
)
286+
print(f"Successfully processed {num_success} out of {len(dataset)} conversations.")
271287

272288

273289
if __name__ == "__main__":

examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,6 @@ do
3535

3636
export CUDA_VISIBLE_DEVICES=$i; python3 collect_hidden_states/compute_hidden_states_trtllm.py --model $MODEL --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR --dp-rank $i &
3737

38-
# #On SLURM:
39-
# PORT=$((10012 + i)); export TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR="tcp://127.0.0.1:$PORT"; trtllm-llmapi-launch python3 collect_hidden_states/compute_hidden_states_trtllm.py --model $MODEL --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR --dp-rank $i
40-
4138
done
4239
wait
4340

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#!/bin/bash
2+
3+
# CHANGE THE FOLLOWING TO YOUR ACCOUNT AND CHANGE THE JOB NAME TO COMPLY WITH THE
4+
# USAGE. SWITCH TO `-p luna -t 04:00:00` IF YOU HAVE BEEN GRANTED CAPACITY FROM
5+
# THE BIWEEKLY CAPACITY MEETING. IF YOU DON'T KNOW WHO IS THE PIC OF YOUR CSRG PPP
6+
# MANAGEMENT, GO WITH `-p backfill -t 00:25:00`.
7+
8+
#SBATCH -A coreai_dlalgo_modelopt
9+
#SBATCH --job-name=coreai_dlalgo_modelopt-generate_eagle_hidden_states
10+
#SBATCH --nodes=1 --ntasks-per-node=4 --gpus-per-node=4
11+
#SBATCH -p batch
12+
#SBATCH -t 04:00:00
13+
14+
echo "SLURM_ARRAY_TASK_ID: $SLURM_ARRAY_TASK_ID"
15+
echo "SLURM_ARRAY_TASK_COUNT: $SLURM_ARRAY_TASK_COUNT"
16+
17+
CONTAINER="nvcr.io#nvidia/tensorrt-llm/release:1.2.0rc0"
18+
19+
INPUT_DIR="<Can be directory containing the .jsonl files, or path to single .jsonl file>"
20+
DUMP_DIR="<Directory for output hidden states>"
21+
MODELOPT_DIR="<Path to Modelopt repo>"
22+
TEACHER_MODEL="<Path to teacher model>"
23+
24+
if [ ! -d "$DUMP_DIR" ]; then
25+
mkdir -p "$DUMP_DIR"
26+
fi
27+
28+
MOUNTS=$INPUT_DIR:/input,$DUMP_DIR:/output,$MODELOPT_DIR:/modelopt,$TEACHER_MODEL:/model
29+
30+
#By default: TP within each node, and DP across the SLURM job array
31+
#EP is optionally available by setting --moe-ep and --moe-tp. See compute_hidden_states_trtllm.py.
32+
PARALLEL_ARGS="--tp 4 --dp-rank $SLURM_ARRAY_TASK_ID --dp-world-size $SLURM_ARRAY_TASK_COUNT"
33+
34+
RUN_DUMPER="export TLLM_LOG_LEVEL="error";
35+
trtllm-llmapi-launch python3 /modelopt/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py \
36+
--model /model \
37+
--input-data /input/ \
38+
--output-dir /output \
39+
$PARALLEL_ARGS \
40+
"
41+
42+
timeout 235m srun -l \
43+
--mpi=pmix --overlap \
44+
--output=%x_%j_$DATETIME.log \
45+
--container-image ${CONTAINER} \
46+
--container-mounts ${MOUNTS} \
47+
bash -c "$RUN_DUMPER"

0 commit comments

Comments
 (0)