diff --git a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py
index d25bd3dac..6b088a3a6 100644
--- a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py
+++ b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py
@@ -20,10 +20,10 @@ os.environ["TLLM_LOG_LEVEL"] = "error"
 import argparse
 import asyncio
-import json
 from pathlib import Path
 
 import torch
+from datasets import load_dataset
 from tensorrt_llm import LLM, SamplingParams
 from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, SaveHiddenStatesDecodingConfig
 from tqdm import tqdm as tqdm
@@ -59,12 +59,10 @@ def parse_args() -> argparse.Namespace:
 
     ## I/O Parameters ##
     parser.add_argument(
-        "--input-file",
+        "--input-data",
         type=Path,
         required=True,
-        help="""Path to the input `jsonl` file containing conversations.
-        Each entry must have a unique `conversation_id` field and a `conversations` field
-        containing a list of messages.""",
+        help="""Path to a `jsonl` file or a directory containing `jsonl` files.""",
     )
     parser.add_argument(
         "--output-dir",
@@ -84,7 +82,13 @@ def parse_args() -> argparse.Namespace:
         "--dp-rank",
         type=int,
         default=0,
-        help="""Data parallel rank.""",
+        help="""Data parallel rank (SLURM_ARRAY_TASK_ID on SLURM).""",
+    )
+    parser.add_argument(
+        "--dp-world-size",
+        type=int,
+        default=1,
+        help="""Data parallel world size (SLURM_ARRAY_TASK_COUNT on SLURM).""",
     )
     parser.add_argument(
         "--use-cuda-graph",
@@ -101,21 +105,21 @@ def parse_args() -> argparse.Namespace:
     # moe_ep * moe_tp * moe_cp should be equal to tp
     # REF: https://nvidia.github.io/TensorRT-LLM/advanced/expert-parallelism.html
     parser.add_argument(
-        "--moe_ep",
+        "--moe-ep",
         type=int,
-        default=1,
+        default=None,
         help="""moe_expert_parallel_size for TRTLLM.""",
     )
     parser.add_argument(
-        "--moe_tp",
+        "--moe-tp",
         type=int,
-        default=1,
+        default=None,
         help="""moe_tensor_parallel_size for TRTLLM.""",
     )
     parser.add_argument(
-        "--moe_cp",
+        "--moe-cp",
         type=int,
-        default=1,
+        default=None,
         help="""moe_cluster_parallel_size for TRTLLM.""",
     )
 
@@ -124,28 +128,43 @@ def main(args: argparse.Namespace) -> None:
     # Load conversations
-    all_conversations = []
-    with args.input_file.open("r", encoding="utf-8") as f:
-        all_conversations.extend([json.loads(line) for line in f if line.strip()])
-    print("Loaded", len(all_conversations), "conversations from", args.input_file)
-
-    # Remove conversations whose output file already exists
-    filtered_conversations = []
-    for entry in all_conversations:
-        conversation_id = entry.get("conversation_id", None)
-        if conversation_id is None:
-            filtered_conversations.append(entry)
-            continue
+    if args.input_data.is_file() and str(args.input_data).endswith(".jsonl"):
+        dataset = load_dataset("json", data_files=str(args.input_data), split="train")
+    elif args.input_data.is_dir():
+        dataset = load_dataset(
+            "json", data_files={"train": f"{args.input_data}/*.jsonl"}, split="train"
+        )
+    else:
+        raise ValueError(
+            f"input_data must be a .jsonl file or a directory containing .jsonl files, got: {args.input_data}"
+        )
+    print(f"Loaded {len(dataset)} conversations from {args.input_data}")
+
+    # Shard data
+    if args.dp_world_size > 1:
+        dataset = dataset.shard(num_shards=args.dp_world_size, index=args.dp_rank)
+        print(
+            f"Sharded dataset to {len(dataset)} conversations for DP#{args.dp_rank}/{args.dp_world_size}"
+        )
+
+    # Remove already-dumped conversations
+    def keep_conversation(entry):
+        conversation_id = entry.get("conversation_id", entry.get("uuid", None))
+        assert conversation_id is not None, "Each entry needs a `conversation_id` or `uuid` field."
         output_file = args.output_dir / f"{conversation_id}.pt"
-        if output_file.exists():
-            continue
-        filtered_conversations.append(entry)
+        return not output_file.exists()
+
+    original_num = len(dataset)
+    dataset = dataset.filter(keep_conversation)
     print(
         "Removed",
-        len(all_conversations) - len(filtered_conversations),
+        original_num - len(dataset),
         "conversations due to existing output files",
     )
-    all_conversations = filtered_conversations
+
+    # For debugging: optionally cap the number of conversations
+    if args.debug_max_num_conversations is not None:
+        dataset = dataset.select(range(min(args.debug_max_num_conversations, len(dataset))))
 
     # Get model config and tokenizer
     model_config = AutoConfig.from_pretrained(args.model)
@@ -187,10 +206,7 @@ def main(args: argparse.Namespace) -> None:
     num_skipped_too_long = 0
     num_invalid = 0
     num_success = 0
-    num_total_conversations = min(
-        len(all_conversations), args.debug_max_num_conversations or len(all_conversations)
-    )
-    pbar = tqdm(total=num_total_conversations, desc=f"DP#{args.dp_rank} Processing conversations")
+    pbar = tqdm(total=len(dataset), desc=f"DP#{args.dp_rank} Processing conversations")
 
     def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id: int):
         """Post-process the TRTLLM dumped file to same format as HF dumped:
@@ -234,15 +250,15 @@
     async def submit_generates():
         nonlocal num_skipped_too_long
         nonlocal num_invalid
         tasks = []
-        for idx, entry in enumerate(all_conversations[: args.debug_max_num_conversations]):
-            conversation_id = entry.get("conversation_id", "{:08d}".format(idx))
+        for idx, entry in enumerate(dataset):
+            conversation_id = entry.get("conversation_id", entry.get("uuid"))
             conversations = entry["conversations"]
             if not conversations or not isinstance(conversations, list):
                 num_invalid += 1
                 continue
 
             input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)
             num_input_tokens = (
                 input_ids.shape[1] if isinstance(input_ids, torch.Tensor) else len(input_ids)
             )
@@ -262,12 +278,10 @@
     if num_invalid > 0:
         print(f"Skipped {num_invalid} invalid conversations without proper fields.")
 
-    if num_success == num_total_conversations:
+    if num_success == len(dataset):
         print(f"Successfully processed all {num_success} conversations.")
     else:
-        print(
-            f"Successfully processed {num_success} out of {num_total_conversations} conversations."
-        )
+        print(f"Successfully processed {num_success} out of {len(dataset)} conversations.")
 
 
 if __name__ == "__main__":
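# A minimal usage sketch of the updated dumper, assuming a placeholder model
# name and placeholder data paths (--input-data may be a single .jsonl file or
# a directory of .jsonl files; --dp-rank/--dp-world-size select this worker's
# shard of the dataset):
python3 collect_hidden_states/compute_hidden_states_trtllm.py \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --input-data /data/conversations \
    --output-dir /data/hidden_states \
    --dp-rank 0 --dp-world-size 4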
diff --git a/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh b/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh
index 73289c0df..4b0fd1060 100644
--- a/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh
+++ b/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh
@@ -35,9 +35,6 @@
 do
 export CUDA_VISIBLE_DEVICES=$i;
-python3 collect_hidden_states/compute_hidden_states_trtllm.py --model $MODEL --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR --dp-rank $i &
-# #On SLURM:
-# PORT=$((10012 + i)); export TLLM_SPAWN_PROXY_PROCESS_IPC_ADDR="tcp://127.0.0.1:$PORT"; trtllm-llmapi-launch python3 collect_hidden_states/compute_hidden_states_trtllm.py --model $MODEL --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR --dp-rank $i
-
+python3 collect_hidden_states/compute_hidden_states_trtllm.py --model $MODEL --input-data /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR --dp-rank $i &
 done
 wait
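# A minimal submission sketch for slurm_dump.sh below, assuming it is launched
# as a SLURM array job (the script itself declares no --array directive), so
# that SLURM_ARRAY_TASK_ID and SLURM_ARRAY_TASK_COUNT map to --dp-rank and
# --dp-world-size; the array size of 4 is an assumption:
sbatch --array=0-3 collect_hidden_states/slurm_dump.sh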
diff --git a/examples/speculative_decoding/collect_hidden_states/slurm_dump.sh b/examples/speculative_decoding/collect_hidden_states/slurm_dump.sh
new file mode 100644
index 000000000..b8e799d70
--- /dev/null
+++ b/examples/speculative_decoding/collect_hidden_states/slurm_dump.sh
@@ -0,0 +1,47 @@
+#!/bin/bash
+
+# CHANGE THE FOLLOWING TO YOUR ACCOUNT, AND CHANGE THE JOB NAME TO COMPLY WITH
+# THE USAGE. SWITCH TO `-p luna -t 04:00:00` IF YOU HAVE BEEN GRANTED CAPACITY
+# FROM THE BIWEEKLY CAPACITY MEETING. IF YOU DON'T KNOW WHO THE PIC OF YOUR
+# CSRG PPP MANAGEMENT IS, GO WITH `-p backfill -t 00:25:00`.
+
+#SBATCH -A coreai_dlalgo_modelopt
+#SBATCH --job-name=coreai_dlalgo_modelopt-generate_eagle_hidden_states
+#SBATCH --nodes=1 --ntasks-per-node=4 --gpus-per-node=4
+#SBATCH -p batch
+#SBATCH -t 04:00:00
+
+echo "SLURM_ARRAY_TASK_ID: $SLURM_ARRAY_TASK_ID"
+echo "SLURM_ARRAY_TASK_COUNT: $SLURM_ARRAY_TASK_COUNT"
+
+CONTAINER="nvcr.io#nvidia/tensorrt-llm/release:1.2.0rc0"
+
+INPUT_DIR=""
+DUMP_DIR="