os.environ["TLLM_LOG_LEVEL"] = "error"
import argparse
import asyncio
- import json
from pathlib import Path

import torch
+ from datasets import load_dataset
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig, SaveHiddenStatesDecodingConfig
from tqdm import tqdm as tqdm
@@ -59,12 +59,10 @@ def parse_args() -> argparse.Namespace:

    ## I/O Parameters ##
    parser.add_argument(
-         "--input-file",
+         "--input-data",
        type=Path,
        required=True,
-         help="""Path to the input `jsonl` file containing conversations.
-         Each entry must have a unique `conversation_id` field and a `conversations` field
-         containing a list of messages.""",
+         help="""Path to the `jsonl` file or directory containing `jsonl` files.""",
    )
    parser.add_argument(
        "--output-dir",
@@ -84,7 +82,13 @@ def parse_args() -> argparse.Namespace:
        "--dp-rank",
        type=int,
        default=0,
-         help="""Data parallel rank.""",
+         help="""Data parallel rank. TASK_ID on SLURM.""",
+     )
+     parser.add_argument(
+         "--dp-world-size",
+         type=int,
+         default=1,
+         help="""Data parallel world size. Number of tasks on SLURM.""",
    )
    parser.add_argument(
        "--use-cuda-graph",
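Per the help strings above, the new `--dp-rank` / `--dp-world-size` pair is meant to come from a SLURM job array. A minimal, hypothetical sketch of the launcher glue: `SLURM_ARRAY_TASK_ID` and `SLURM_ARRAY_TASK_COUNT` are standard SLURM array variables, while the script name and data paths below are placeholders, not from this commit.

# Hypothetical SLURM glue: one array task per data-parallel rank.
import os
import subprocess

dp_rank = int(os.environ.get("SLURM_ARRAY_TASK_ID", "0"))
dp_world_size = int(os.environ.get("SLURM_ARRAY_TASK_COUNT", "1"))
subprocess.run(
    [
        "python", "collect_hidden_states.py",  # placeholder script name
        "--input-data", "data/conversations/",  # placeholder paths
        "--output-dir", "data/hidden_states/",
        "--dp-rank", str(dp_rank),
        "--dp-world-size", str(dp_world_size),
    ],
    check=True,
)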
@@ -101,21 +105,21 @@ def parse_args() -> argparse.Namespace:
    # moe_ep * moe_tp * moe_cp should be equal to tp
    # REF: https://nvidia.github.io/TensorRT-LLM/advanced/expert-parallelism.html
    parser.add_argument(
-         "--moe_ep",
+         "--moe-ep",
        type=int,
-         default=1,
+         default=None,
        help="""moe_expert_parallel_size for TRTLLM.""",
    )
    parser.add_argument(
-         "--moe_tp",
+         "--moe-tp",
        type=int,
-         default=1,
+         default=None,
        help="""moe_tensor_parallel_size for TRTLLM.""",
    )
    parser.add_argument(
-         "--moe_cp",
+         "--moe-cp",
        type=int,
-         default=1,
+         default=None,
        help="""moe_cluster_parallel_size for TRTLLM.""",
    )

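With the three MoE sizes now defaulting to `None`, the `moe_ep * moe_tp * moe_cp == tp` constraint quoted in the comment above can only be enforced when all three are set. A minimal sketch of such a check, assuming an `args.tp` field for `tensor_parallel_size` that is not shown in this diff:

# Hypothetical validation; args.tp is assumed to hold tensor_parallel_size.
moe_sizes = (args.moe_ep, args.moe_tp, args.moe_cp)
if all(size is not None for size in moe_sizes):
    product = args.moe_ep * args.moe_tp * args.moe_cp
    assert product == args.tp, (
        f"moe_ep * moe_tp * moe_cp must equal tp, got {product} != {args.tp}"
    )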
@@ -124,28 +128,43 @@

def main(args: argparse.Namespace) -> None:
    # Load conversations
-     all_conversations = []
-     with args.input_file.open("r", encoding="utf-8") as f:
-         all_conversations.extend([json.loads(line) for line in f if line.strip()])
-     print("Loaded", len(all_conversations), "conversations from", args.input_file)
-
-     # Remove conversations whose output file already exists
-     filtered_conversations = []
-     for entry in all_conversations:
-         conversation_id = entry.get("conversation_id", None)
-         if conversation_id is None:
-             filtered_conversations.append(entry)
-             continue
+     if args.input_data.is_file() and str(args.input_data).endswith(".jsonl"):
+         dataset = load_dataset("json", data_files=str(args.input_data), split="train")
+     elif args.input_data.is_dir():
+         dataset = load_dataset(
+             "json", data_files={"train": f"{args.input_data}/*.jsonl"}, split="train"
+         )
+     else:
+         raise ValueError(
+             f"input_data must be a .jsonl file or directory containing .jsonl files, got: {args.input_data}"
+         )
+     print(f"Loaded {len(dataset)} conversations from {args.input_data}")
+
+     # Shard data
+     if args.dp_world_size > 1:
+         dataset = dataset.shard(num_shards=args.dp_world_size, index=args.dp_rank)
+     print(
+         f"Sharded dataset to {len(dataset)} conversations for DP#{args.dp_rank}/{args.dp_world_size}"
+     )
+
+     # Remove already dumped conversations
+     def keep_conversation(entry):
+         conversation_id = entry.get("conversation_id", entry.get("uuid", None))
+         assert conversation_id is not None, "conversation_id is required"
        output_file = args.output_dir / f"{conversation_id}.pt"
-         if output_file.exists():
-             continue
-         filtered_conversations.append(entry)
+         return not output_file.exists()
+
+     original_num = len(dataset)
+     dataset = dataset.filter(keep_conversation)
    print(
        "Removed",
-         len(all_conversations) - len(filtered_conversations),
+         original_num - len(dataset),
        "conversations due to existing output files",
    )
-     all_conversations = filtered_conversations
+
+     # For debugging
+     if args.debug_max_num_conversations is not None:
+         dataset = dataset.select(range(args.debug_max_num_conversations))

    # Get model config and tokenizer
    model_config = AutoConfig.from_pretrained(args.model)
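For reference, `datasets.Dataset.shard(num_shards, index)` returns the `index`-th of `num_shards` roughly equal slices, so each DP rank ends up with a disjoint subset of the conversations. A small self-contained illustration (toy data, not from this repo):

from datasets import Dataset

# Ten fake conversation ids split across four ranks.
ds = Dataset.from_dict({"conversation_id": [f"{i:08d}" for i in range(10)]})
for rank in range(4):
    shard = ds.shard(num_shards=4, index=rank)
    print(rank, shard["conversation_id"])
# The four shards partition the ten rows with no overlap.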
@@ -187,10 +206,7 @@ def main(args: argparse.Namespace) -> None:
    num_skipped_too_long = 0
    num_invalid = 0
    num_success = 0
-     num_total_conversations = min(
-         len(all_conversations), args.debug_max_num_conversations or len(all_conversations)
-     )
-     pbar = tqdm(total=num_total_conversations, desc=f"DP#{args.dp_rank} Processing conversations")
+     pbar = tqdm(total=len(dataset), desc=f"DP#{args.dp_rank} Processing conversations")

    def _post_process_trtllm_dumped(trtllm_dumped_file: str, conversation_id: int):
        """Post-process the TRTLLM dumped file to same format as HF dumped:
@@ -234,15 +250,17 @@ async def submit_generates():
        nonlocal num_skipped_too_long
        nonlocal num_invalid
        tasks = []
-         for idx, entry in enumerate(all_conversations[: args.debug_max_num_conversations]):
-             conversation_id = entry.get("conversation_id", "{:08d}".format(idx))
+         for idx, entry in enumerate(dataset):
+             conversation_id = entry.get("conversation_id", entry.get("uuid"))

            conversations = entry["conversations"]
            if not conversations or not isinstance(conversations, list):
                num_invalid += 1
                continue

-             input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)
+             input_ids = tokenizer.apply_chat_template(conversations, add_generation_template=False)[
+                 :256
+             ]
            num_input_tokens = (
                input_ids.shape[1] if isinstance(input_ids, torch.Tensor) else len(input_ids)
            )
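The `num_input_tokens` expression above handles both return types of `apply_chat_template`: a `[1, seq_len]` tensor when tensors are requested, or a plain list of token ids. A minimal sketch of the same logic in isolation (toy token values):

import torch

def count_tokens(input_ids):
    # Tensor input is [1, seq_len]; list input is a flat list of token ids.
    return input_ids.shape[1] if isinstance(input_ids, torch.Tensor) else len(input_ids)

assert count_tokens([101, 2023, 102]) == 3
assert count_tokens(torch.ones(1, 5, dtype=torch.long)) == 5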
@@ -262,12 +280,10 @@ async def submit_generates():
    if num_invalid > 0:
        print(f"Skipped {num_invalid} invalid conversations without proper fields.")

-     if num_success == num_total_conversations:
+     if num_success == len(dataset):
        print(f"Successfully processed all {num_success} conversations.")
    else:
-         print(
-             f"Successfully processed {num_success} out of {num_total_conversations} conversations."
-         )
+         print(f"Successfully processed {num_success} out of {len(dataset)} conversations.")


if __name__ == "__main__":