@@ -107,7 +107,7 @@ def _init_default_config(self):
107107 "runner_num" : 32 ,
108108 "_grouped_adv_repeat_times" : 2 ,
109109 "_not_grouped_adv_repeat_times" : 1 ,
110- "n " : 1 ,
110+ "repeat_times " : 1 ,
111111 "tensor_parallel_size" : 1 ,
112112 "enable_prefix_caching" : False ,
113113 "enforce_eager" : True ,
@@ -787,6 +787,33 @@ def on_change():
def _set_ppo_epochs(self):
    """Render the "PPO Epochs" number input.

    The widget is bound to the ``ppo_epochs`` session-state key and does
    not accept values below 1.
    """
    st.number_input(label="PPO Epochs", key="ppo_epochs", min_value=1)
789789
def _set_repeat_times(self):  # TODO
    """Render the "Repeat Times" number input.

    Two separate values are remembered in session state, one for the
    algorithms listed below (GRPO, OPMD) and one for every other
    algorithm:

    * ``_grouped_adv_repeat_times`` -- used when ``algorithm_type`` is one
      of the listed algorithms; the widget minimum is 2.
    * ``_not_grouped_adv_repeat_times`` -- used otherwise; the widget
      minimum is 1.

    Before the widget is drawn, ``repeat_times`` is seeded from the value
    remembered for the current algorithm type. When the user edits the
    widget, the new value is written back to the matching remembered key.
    """
    # TODO: may add rloo
    grouped_algos = [AlgorithmType.GRPO.value, AlgorithmType.OPMD.value]

    def uses_grouped_adv():
        # Evaluated lazily so the callback sees the algorithm type that is
        # current at the moment it fires.
        return st.session_state["algorithm_type"] in grouped_algos

    def remembered_key(grouped):
        if grouped:
            return "_grouped_adv_repeat_times"
        return "_not_grouped_adv_repeat_times"

    def remember_choice():
        # Persist the edited value under the key for the active algorithm.
        target = remembered_key(uses_grouped_adv())
        st.session_state[target] = st.session_state["repeat_times"]

    grouped_now = uses_grouped_adv()
    lower_bound = 2 if grouped_now else 1
    st.session_state["repeat_times"] = st.session_state[remembered_key(grouped_now)]

    st.number_input(
        "Repeat Times",
        key="repeat_times",
        min_value=lower_bound,
        help="`repeat_times` is used to set how many experiences each task can generate, "
        "and it must be greater than `1` when `algorithm_type` is `opmd` or `grpo`.",
        on_change=remember_choice,
    )
816+
790817 def _set_training_strategy (self ):
791818 st .selectbox (
792819 "Training Strategy" ,
@@ -1099,7 +1126,7 @@ def beginner_mode(self):
10991126 self ._check_engine_num_and_tp_size ()
11001127
11011128 self ._set_configs_with_st_columns (
1102- ["total_epochs" , "train_batch_size" , "ppo_epochs" , "n " ]
1129+ ["total_epochs" , "train_batch_size" , "ppo_epochs" , "repeat_times " ]
11031130 if st .session_state ["mode" ] == "both"
11041131 else ["total_epochs" , "train_batch_size" , "ppo_epochs" ]
11051132 )
@@ -1187,7 +1214,7 @@ def _expert_buffer_part(self):
11871214
11881215 def _expert_explorer_part (self ):
11891216 self ._set_configs_with_st_columns (
1190- ["engine_type" , "engine_num" , "tensor_parallel_size" , "n " ]
1217+ ["engine_type" , "engine_num" , "tensor_parallel_size" , "repeat_times " ]
11911218 )
11921219 self ._check_engine_num_and_tp_size ()
11931220
@@ -1332,7 +1359,7 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node
13321359 else :
13331360 fsdp_config = {}
13341361
1335- ppo_max_token_len_per_gpu = st .session_state ["n " ] * (
1362+ ppo_max_token_len_per_gpu = st .session_state ["repeat_times " ] * (
13361363 st .session_state ["max_prompt_tokens" ] + st .session_state ["max_response_tokens" ]
13371364 )
13381365
@@ -1349,7 +1376,8 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node
13491376 "prompt_key" : "placeholder" ,
13501377 "max_prompt_length" : st .session_state ["max_prompt_tokens" ],
13511378 "max_response_length" : st .session_state ["max_response_tokens" ],
1352- "train_batch_size" : st .session_state ["train_batch_size" ] * st .session_state ["n" ],
1379+ "train_batch_size" : st .session_state ["train_batch_size" ]
1380+ * st .session_state ["repeat_times" ],
13531381 "val_batch_size" : None ,
13541382 "return_raw_input_ids" : False ,
13551383 "return_raw_chat" : False ,
@@ -1437,7 +1465,7 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node
14371465 "disable_log_stats" : True ,
14381466 "enable_chunked_prefill" : True ,
14391467 "do_sample" : True ,
1440- "n" : st .session_state ["n " ],
1468+ "n" : st .session_state ["repeat_times " ],
14411469 },
14421470 },
14431471 "critic" : {
@@ -1596,12 +1624,15 @@ def generate_config(self):
15961624 "mode" : st .session_state ["mode" ],
15971625 "project" : st .session_state ["project" ],
15981626 "name" : st .session_state ["name" ],
1599- "algorithm_type" : st .session_state ["algorithm_type" ],
1627+ "checkpoint_root_dir" : st .session_state ["checkpoint_path" ],
1628+ "algorithm" : {
1629+ "algorithm_type" : st .session_state ["algorithm_type" ],
1630+ "repeat_times" : st .session_state ["repeat_times" ],
1631+ },
16001632 "model" : {
16011633 "model_path" : st .session_state ["model_path" ],
16021634 "max_prompt_tokens" : st .session_state ["max_prompt_tokens" ],
16031635 "max_response_tokens" : st .session_state ["max_response_tokens" ],
1604- "checkpoint_path" : st .session_state ["checkpoint_path" ],
16051636 },
16061637 "cluster" : {
16071638 "node_num" : st .session_state ["node_num" ],
@@ -1624,7 +1655,7 @@ def generate_config(self):
16241655 "response_key" : st .session_state ["taskset_response_key" ],
16251656 },
16261657 "rollout_args" : {
1627- "n" : st .session_state ["n " ],
1658+ "n" : st .session_state ["repeat_times " ],
16281659 "temperature" : st .session_state ["temperature" ],
16291660 "top_p" : st .session_state ["top_p" ],
16301661 "top_k" : st .session_state ["top_k" ],
0 commit comments