@@ -72,7 +72,7 @@ def _init_default_config(self):
7272 "_not_dpo_storage_type" : StorageType .QUEUE .value ,
7373 "storage_type" : StorageType .QUEUE .value ,
7474 "train_dataset_path" : "" ,
75- "max_retry_times " : 3 ,
75+ "buffer_max_retry_times " : 3 ,
7676 "max_retry_interval" : 1 ,
7777 "dpo_dataset_train_split" : "train" ,
7878 "dpo_dataset_prompt_type" : PromptType .MESSAGES .value ,
@@ -88,31 +88,37 @@ def _init_default_config(self):
8888 # Explorer and Sync Configs
8989 "engine_type" : "vllm_async" ,
9090 "engine_num" : 2 ,
91- "tensor_parallel_size " : 1 ,
91+ "runner_num " : 32 ,
9292 "_grouped_adv_repeat_times" : 2 ,
9393 "_not_grouped_adv_repeat_times" : 1 ,
9494 "repeat_times" : 1 ,
95- "_not_dpo_sync_method" : SyncMethod .NCCL .value ,
96- "sync_method" : SyncMethod .NCCL .value ,
97- "sync_interval" : 10 ,
98- "sync_timeout" : 1200 ,
99- "runner_num" : 32 ,
100- "max_pending_requests" : 32 ,
101- "max_waiting_steps" : 4 ,
95+ "eval_interval" : 1000 ,
96+ "tensor_parallel_size" : 1 ,
97+ "enable_prefix_caching" : False ,
98+ "enforce_eager" : True ,
10299 "dtype" : "bfloat16" ,
103- "backend" : "nccl" ,
104100 "temperature" : 1.0 ,
105101 "top_p" : 1.0 ,
106102 "top_k" : - 1 ,
107103 "seed" : 42 ,
108104 "logprobs" : 0 ,
109- "enable_prefix_caching" : False ,
110- "enforce_eager" : True ,
105+ "backend" : "nccl" ,
106+ "use_ray" : False ,
107+ "gpu_memory_utilization" : 0.9 ,
108+ "enable_chunked_prefill" : False ,
109+ "max_pending_requests" : 32 ,
110+ "max_waiting_steps" : 4 ,
111+ "max_timeout" : 900 ,
112+ "explorer_max_retry_times" : 2 ,
113+ # Synchronizer Configs
114+ "_not_dpo_sync_method" : SyncMethod .NCCL .value ,
115+ "sync_method" : SyncMethod .NCCL .value ,
116+ "sync_interval" : 10 ,
117+ "sync_timeout" : 1200 ,
111118 # Trainer Configs
112119 "trainer_type" : "verl" ,
113120 "algorithm_type" : AlgorithmType .PPO .value ,
114121 "sft_warmup_steps" : 0 ,
115- "eval_interval" : 1000 ,
116122 "_nccl_save_interval" : 100 ,
117123 "save_interval" : 100 ,
118124 # veRL Trainer Configs
@@ -370,8 +376,8 @@ def _set_train_dataset_path(self): # TODO
370376 self .unfinished_fields .add ("train_dataset_path" )
371377 st .warning ("Please input train dataset path." )
372378
def _set_buffer_max_retry_times(self):
    """Render the numeric input for the buffer's maximum retry count.

    The widget value is kept by Streamlit under the session-state key
    passed as ``key``; values below 1 are rejected by the widget.
    """
    st.number_input(
        label="Max Retry Times",
        min_value=1,
        key="buffer_max_retry_times ",
    )
375381
376382 def _set_max_retry_interval (self ):
377383 st .number_input ("Max Retry Interval" , key = "max_retry_interval" , min_value = 1 )
@@ -613,11 +619,28 @@ def _set_logprobs(self):
613619 st .number_input ("Logprobs" , key = "logprobs" , min_value = 0 , max_value = 20 )
614620
def _set_enable_prefix_caching(self):
    """Render the "Prefix Caching" checkbox.

    The checkbox state is kept by Streamlit under the
    ``enable_prefix_caching`` session-state key.
    """
    st.checkbox(label="Prefix Caching", key="enable_prefix_caching")
617623
def _set_enforce_eager(self):
    """Render the "Enforce Eager" checkbox.

    The checkbox state is kept by Streamlit under the ``enforce_eager``
    session-state key.
    """
    st.checkbox(label="Enforce Eager", key="enforce_eager")
620626
def _set_use_ray(self):
    """Render the "Use Ray" checkbox.

    The checkbox state is kept by Streamlit under the ``use_ray``
    session-state key.
    """
    st.checkbox(label="Use Ray", key="use_ray")
629+
def _set_gpu_memory_utilization(self):
    """Render the numeric input for GPU memory utilization.

    The widget restricts the value to the closed range [0.0, 1.0] and
    keeps it under the ``gpu_memory_utilization`` session-state key.
    """
    lower_bound, upper_bound = 0.0, 1.0
    st.number_input(
        label="GPU Memory Utilization",
        min_value=lower_bound,
        max_value=upper_bound,
        key="gpu_memory_utilization",
    )
634+
def _set_enable_chunked_prefill(self):
    """Render the "Chunked Prefill" checkbox.

    The checkbox state is kept by Streamlit under the
    ``enable_chunked_prefill`` session-state key.
    """
    st.checkbox(label="Chunked Prefill", key="enable_chunked_prefill")
637+
def _set_max_timeout(self):
    """Render the numeric input for the maximum timeout.

    Negative values are rejected by the widget; the value is kept under
    the ``max_timeout`` session-state key.
    """
    st.number_input(
        label="Max Timeout",
        min_value=0,
        key="max_timeout",
    )
640+
def _set_explorer_max_retry_times(self):
    """Render the numeric input for the explorer's maximum retry count.

    Negative values are rejected by the widget; the value is kept under
    the ``explorer_max_retry_times`` session-state key.
    """
    st.number_input(
        label="Explorer Max Retry Times",
        min_value=0,
        key="explorer_max_retry_times",
    )
643+
def _set_trainer_type(self):
    """Render the trainer-type selector.

    "verl" is the only option offered; the selection is kept under the
    ``trainer_type`` session-state key.
    """
    available_trainers = ["verl"]
    st.selectbox(
        label="Trainer Type",
        options=available_trainers,
        key="trainer_type",
    )
623646
@@ -1079,7 +1102,7 @@ def _expert_buffer_part(self):
10791102
10801103 self .buffer_advanced_tab = st .expander ("Advanced Config" )
10811104 with self .buffer_advanced_tab :
1082- self ._set_configs_with_st_columns (["max_retry_times " , "max_retry_interval" ])
1105+ self ._set_configs_with_st_columns (["buffer_max_retry_times " , "max_retry_interval" ])
10831106
10841107 self ._set_sft_warmup_dataset_path ()
10851108 self ._set_sft_warmup_dataset_args ()
@@ -1094,12 +1117,22 @@ def _expert_explorer_part(self):
10941117
10951118 with st .expander ("Advanced Config" ):
10961119 self ._set_configs_with_st_columns (
1097- ["runner_num" , "max_pending_requests " , "max_waiting_steps " , "dtype " ]
1120+ ["runner_num" , "temperature " , "top_p " , "top_k" , "seed" , "logprobs " ]
10981121 )
10991122
1100- self ._set_configs_with_st_columns (["backend" , "temperature" , "seed" , "logprobs" ])
1123+ self ._set_configs_with_st_columns (["dtype" , "backend" , "gpu_memory_utilization" ])
1124+ self ._set_configs_with_st_columns (
1125+ [
1126+ "max_pending_requests" ,
1127+ "max_waiting_steps" ,
1128+ "max_timeout" ,
1129+ "explorer_max_retry_times" ,
1130+ ]
1131+ )
11011132
1102- self ._set_configs_with_st_columns (["enable_prefix_caching" , "enforce_eager" ])
1133+ self ._set_configs_with_st_columns (
1134+ ["enable_prefix_caching" , "enforce_eager" , "use_ray" , "enable_chunked_prefill" ]
1135+ )
11031136
11041137 def _expert_trainer_part (self ):
11051138 self ._set_configs_with_st_columns ( # TODO: may add `trainer_type`
@@ -1442,6 +1475,12 @@ def generate_config(self):
14421475 else :
14431476 trainer_n_gpus_per_node = st .session_state ["gpu_per_node" ]
14441477
1478+ critic_model_path = (
1479+ st .session_state ["critic_model_path" ].strip ()
1480+ if st .session_state ["critic_model_path" ].strip ()
1481+ else st .session_state ["model_path" ]
1482+ )
1483+
14451484 if st .session_state ["algorithm_type" ] == AlgorithmType .DPO .value :
14461485 train_dataset_path = (
14471486 st .session_state ["train_dataset_path" ].strip ()
@@ -1495,6 +1534,7 @@ def generate_config(self):
14951534 },
14961535 "model" : {
14971536 "model_path" : st .session_state ["model_path" ],
1537+ "critic_model_path" : critic_model_path ,
14981538 "max_prompt_tokens" : st .session_state ["max_prompt_tokens" ],
14991539 "max_response_tokens" : st .session_state ["max_response_tokens" ],
15001540 "checkpoint_path" : st .session_state ["checkpoint_path" ],
@@ -1504,18 +1544,16 @@ def generate_config(self):
15041544 "gpu_per_node" : st .session_state ["gpu_per_node" ],
15051545 },
15061546 "buffer" : {
1507- "max_retry_times" : st .session_state ["max_retry_times " ],
1547+ "max_retry_times" : st .session_state ["buffer_max_retry_times " ],
15081548 "max_retry_interval" : st .session_state ["max_retry_interval" ],
15091549 "train_dataset" : {
15101550 "name" : "experience_buffer" , # TODO
15111551 "storage_type" : st .session_state ["storage_type" ],
1512- "algorithm_type" : st .session_state ["algorithm_type" ],
15131552 "path" : train_dataset_path ,
15141553 },
15151554 "sft_warmup_dataset" : {
15161555 "name" : "sft_warmup_dataset" ,
15171556 "storage_type" : sft_storage_type ,
1518- "algorithm_type" : AlgorithmType .SFT .value ,
15191557 "path" : st .session_state ["sft_warmup_dataset_path" ],
15201558 "kwargs" : {
15211559 "train_split" : st .session_state ["sft_warmup_train_split" ],
@@ -1530,18 +1568,27 @@ def generate_config(self):
15301568 "engine_type" : st .session_state ["engine_type" ],
15311569 "engine_num" : st .session_state ["engine_num" ],
15321570 "runner_num" : st .session_state ["runner_num" ],
1571+ "repeat_times" : st .session_state ["repeat_times" ],
1572+ # "chat_template": None, # TODO: add chat template
15331573 "eval_interval" : st .session_state ["eval_interval" ],
15341574 "tensor_parallel_size" : st .session_state ["tensor_parallel_size" ],
15351575 "enable_prefix_caching" : st .session_state ["enable_prefix_caching" ],
15361576 "enforce_eager" : st .session_state ["enforce_eager" ],
15371577 "dtype" : st .session_state ["dtype" ],
15381578 "temperature" : st .session_state ["temperature" ],
1579+ "top_p" : st .session_state ["top_p" ], # TODO
1580+ "top_k" : st .session_state ["top_k" ], # TODO
15391581 "seed" : st .session_state ["seed" ],
15401582 "logprobs" : st .session_state ["logprobs" ],
1541- "repeat_times" : st .session_state ["repeat_times" ],
15421583 "backend" : st .session_state ["backend" ],
1584+ "use_ray" : st .session_state ["use_ray" ], # TODO
1585+ "gpu_memory_utilization" : st .session_state ["gpu_memory_utilization" ], # TODO
1586+ "enable_chunked_prefill" : st .session_state ["enable_chunked_prefill" ], # TODO
1587+ "use_v1" : True ,
15431588 "max_pending_requests" : st .session_state ["max_pending_requests" ],
15441589 "max_waiting_steps" : st .session_state ["max_waiting_steps" ],
1590+ "max_timeout" : st .session_state ["max_timeout" ], # TODO
1591+ "max_retry_times" : st .session_state ["explorer_max_retry_times" ], # TODO
15451592 },
15461593 "synchronizer" : {
15471594 "sync_method" : st .session_state ["sync_method" ],
0 commit comments