 os.environ["TLLM_LOG_LEVEL"] = "error"
 import argparse
 import asyncio
-import json
 from pathlib import Path
 
 import torch
+from datasets import load_dataset
 from tensorrt_llm import LLM, SamplingParams
 from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, SaveHiddenStatesDecodingConfig
 from tqdm import tqdm as tqdm
@@ -59,12 +59,10 @@ def parse_args() -> argparse.Namespace:
 
     ## I/O Parameters ##
     parser.add_argument(
-        "--input-file",
+        "--input-data",
         type=Path,
         required=True,
-        help="""Path to the input `jsonl` file containing conversations.
-        Each entry must have a unique `conversation_id` field and a `conversations` field
-        containing a list of messages.""",
+        help="""Path to the `jsonl` file or directory containing `jsonl` files.""",
     )
     parser.add_argument(
         "--output-dir",
@@ -84,7 +82,13 @@ def parse_args() -> argparse.Namespace:
         "--dp-rank",
         type=int,
         default=0,
-        help="""Data parallel rank.""",
+        help="""Data parallel rank. TASK_ID on SLURM.""",
+    )
+    parser.add_argument(
+        "--dp-world-size",
+        type=int,
+        default=1,
+        help="""Data parallel world size. Number of tasks on SLURM.""",
     )
     parser.add_argument(
         "--use-cuda-graph",
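The two new flags above are meant to be filled in by the scheduler. A minimal launcher sketch, assuming the standard `SLURM_PROCID` and `SLURM_NTASKS` environment variables; the script name, model, and paths below are illustrative, not part of this change:

```python
# Hypothetical SLURM launcher: map task id / task count to the new flags.
import os
import subprocess

dp_rank = int(os.environ.get("SLURM_PROCID", "0"))        # this task's id
dp_world_size = int(os.environ.get("SLURM_NTASKS", "1"))  # total number of tasks

subprocess.run(
    [
        "python", "compute_hidden_states_trtllm.py",      # illustrative script name
        "--model", "meta-llama/Llama-3.1-8B-Instruct",    # illustrative model
        "--input-data", "data/conversations/",
        "--output-dir", "hidden_states/",
        "--dp-rank", str(dp_rank),
        "--dp-world-size", str(dp_world_size),
    ],
    check=True,
)
```

Each rank then processes a disjoint shard of the dataset (see the `dataset.shard` call below).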
@@ -101,21 +105,21 @@ def parse_args() -> argparse.Namespace:
     # moe_ep * moe_tp * moe_cp should be equal to tp
     # REF: https://nvidia.github.io/TensorRT-LLM/advanced/expert-parallelism.html
     parser.add_argument(
-        "--moe_ep",
+        "--moe-ep",
         type=int,
-        default=1,
+        default=None,
         help="""moe_expert_parallel_size for TRTLLM.""",
     )
     parser.add_argument(
-        "--moe_tp",
+        "--moe-tp",
         type=int,
-        default=1,
+        default=None,
         help="""moe_tensor_parallel_size for TRTLLM.""",
     )
     parser.add_argument(
-        "--moe_cp",
+        "--moe-cp",
         type=int,
-        default=1,
+        default=None,
         help="""moe_cluster_parallel_size for TRTLLM.""",
     )
 
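With the defaults changed from `1` to `None`, the MoE split is only pinned down when the flags are passed explicitly; per the comment above, an explicit combination must multiply out to the tensor-parallel size. A small illustration of that constraint (the values are made up):

```python
# moe_ep * moe_tp * moe_cp == tp, per the REF comment above.
tp = 8
for moe_ep, moe_tp, moe_cp in [(8, 1, 1), (1, 8, 1), (4, 2, 1)]:
    assert moe_ep * moe_tp * moe_cp == tp
```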
@@ -124,28 +128,43 @@ def parse_args() -> argparse.Namespace:
 
 def main(args: argparse.Namespace) -> None:
     # Load conversations
-    all_conversations = []
-    with args.input_file.open("r", encoding="utf-8") as f:
-        all_conversations.extend([json.loads(line) for line in f if line.strip()])
-    print("Loaded", len(all_conversations), "conversations from", args.input_file)
-
-    # Remove conversations whose output file already exists
-    filtered_conversations = []
-    for entry in all_conversations:
-        conversation_id = entry.get("conversation_id", None)
-        if conversation_id is None:
-            filtered_conversations.append(entry)
-            continue
+    if args.input_data.is_file() and str(args.input_data).endswith(".jsonl"):
+        dataset = load_dataset("json", data_files=str(args.input_data), split="train")
+    elif args.input_data.is_dir():
+        dataset = load_dataset(
+            "json", data_files={"train": f"{args.input_data}/*.jsonl"}, split="train"
+        )
+    else:
+        raise ValueError(
+            f"input_data must be a .jsonl file or directory containing .jsonl files, got: {args.input_data}"
+        )
+    print(f"Loaded {len(dataset)} conversations from {args.input_data}")
+
+    # Shard data
+    if args.dp_world_size > 1:
+        dataset = dataset.shard(num_shards=args.dp_world_size, index=args.dp_rank)
+        print(
+            f"Sharded dataset to {len(dataset)} conversations for DP#{args.dp_rank}/{args.dp_world_size}"
+        )
+
+    # Remove already dumped conversations
+    def keep_conversation(entry):
+        conversation_id = entry.get("conversation_id", entry.get("uuid", None))
+        assert conversation_id is not None, "conversation_id is required"
         output_file = args.output_dir / f"{conversation_id}.pt"
-        if output_file.exists():
-            continue
-        filtered_conversations.append(entry)
+        return not output_file.exists()
+
+    original_num = len(dataset)
+    dataset = dataset.filter(keep_conversation)
     print(
         "Removed",
-        len(all_conversations) - len(filtered_conversations),
+        original_num - len(dataset),
         "conversations due to existing output files",
     )
-    all_conversations = filtered_conversations
+
+    # For debugging
+    if args.debug_max_num_conversations is not None:
+        dataset = dataset.select(range(args.debug_max_num_conversations))
 
     # Get model config and tokenizer
     model_config = AutoConfig.from_pretrained(args.model)
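The loader now accepts either a single `.jsonl` file or a directory of them, and `keep_conversation` makes reruns resumable by skipping any record whose `<conversation_id>.pt` output already exists. Each input line must carry a `conversation_id` (or `uuid`) plus a `conversations` message list that the tokenizer's chat template can consume. A sketch of one record, with illustrative field values:

```python
# Shape of one input record, written as one line of a .jsonl file.
# "uuid" is accepted in place of "conversation_id"; the path is illustrative.
import json

record = {
    "conversation_id": "0001",
    "conversations": [
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi there."},
    ],
}
with open("data/conversations/part-0000.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(record) + "\n")
```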
@@ -187,10 +206,7 @@ def main(args: argparse.Namespace) -> None:
     num_skipped_too_long = 0
     num_invalid = 0
     num_success = 0
-    num_total_conversations = min(
-        len(all_conversations), args.debug_max_num_conversations or len(all_conversations)
-    )
-    pbar = tqdm(total=num_total_conversations, desc=f"DP#{args.dp_rank} Processing conversations")
+    pbar = tqdm(total=len(dataset), desc=f"DP#{args.dp_rank} Processing conversations")
 
     def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id: int):
         """Post-process the TRTLLM dumped file to same format as HF dumped:
@@ -234,15 +250,17 @@ async def submit_generates():
         nonlocal num_skipped_too_long
         nonlocal num_invalid
         tasks = []
-        for idx, entry in enumerate(all_conversations[: args.debug_max_num_conversations]):
-            conversation_id = entry.get("conversation_id", "{:08d}".format(idx))
+        for idx, entry in enumerate(dataset):
+            conversation_id = entry.get("conversation_id", entry.get("uuid"))
 
             conversations = entry["conversations"]
             if not conversations or not isinstance(conversations, list):
                 num_invalid += 1
                 continue
 
-            input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)
+            input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)[
+                :256
+            ]
             num_input_tokens = (
                 input_ids.shape[1] if isinstance(input_ids, torch.Tensor) else len(input_ids)
            )
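The `isinstance` branch above exists because `apply_chat_template` returns a plain `list[int]` by default and a tensor only when `return_tensors="pt"` is requested. A quick check of both shapes (the model name is illustrative):

```python
# Both return types handled by the num_input_tokens expression above.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")  # illustrative
msgs = [{"role": "user", "content": "Hi"}]

ids_list = tok.apply_chat_template(msgs)                     # list[int]
ids_pt = tok.apply_chat_template(msgs, return_tensors="pt")  # torch.Tensor, shape [1, seq_len]
print(len(ids_list), ids_pt.shape[1])
```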
@@ -262,12 +280,10 @@ async def submit_generates():
     if num_invalid > 0:
         print(f"Skipped {num_invalid} invalid conversations without proper fields.")
 
-    if num_success == num_total_conversations:
+    if num_success == len(dataset):
         print(f"Successfully processed all {num_success} conversations.")
     else:
-        print(
-            f"Successfully processed {num_success} out of {num_total_conversations} conversations."
-        )
+        print(f"Successfully processed {num_success} out of {len(dataset)} conversations.")
 
 
 if __name__ == "__main__":