Skip to content

Commit 831a503

Browse files
committed
fix error in config manager
1 parent 684bcba commit 831a503

File tree

2 files changed

+41
-10
lines changed

2 files changed

+41
-10
lines changed

trinity/manager/config_manager.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def _init_default_config(self):
107107
"runner_num": 32,
108108
"_grouped_adv_repeat_times": 2,
109109
"_not_grouped_adv_repeat_times": 1,
110-
"n": 1,
110+
"repeat_times": 1,
111111
"tensor_parallel_size": 1,
112112
"enable_prefix_caching": False,
113113
"enforce_eager": True,
@@ -787,6 +787,33 @@ def on_change():
787787
def _set_ppo_epochs(self):
788788
st.number_input("PPO Epochs", key="ppo_epochs", min_value=1)
789789

790+
def _set_repeat_times(self):  # TODO
    """Render the "Repeat Times" number input.

    The widget value lives in ``st.session_state["repeat_times"]``. Two
    separate backing values are kept, one for the algorithms listed in
    ``grouped_adv_algorithms`` and one for every other algorithm; the
    widget is seeded from the matching backing value before it is drawn
    and writes back to it whenever the user edits the number.
    """
    grouped_adv_algorithms = [
        AlgorithmType.GRPO.value,
        AlgorithmType.OPMD.value,  # TODO: may add rloo
    ]

    def backing_key(is_grouped):
        # Name of the session-state slot that remembers the value for this family.
        return "_grouped_adv_repeat_times" if is_grouped else "_not_grouped_adv_repeat_times"

    # Evaluated at render time: decides both the lower bound and the seed value.
    grouped_now = st.session_state["algorithm_type"] in grouped_adv_algorithms
    min_repeat_times = 2 if grouped_now else 1
    st.session_state["repeat_times"] = st.session_state[backing_key(grouped_now)]

    def on_change():
        # Re-check the algorithm when the callback fires, then store the edited value.
        grouped_then = st.session_state["algorithm_type"] in grouped_adv_algorithms
        st.session_state[backing_key(grouped_then)] = st.session_state["repeat_times"]

    st.number_input(
        "Repeat Times",
        key="repeat_times",
        min_value=min_repeat_times,
        help="`repeat_times` is used to set how many experiences each task can generate, "
        "and it must be greater than `1` when `algorithm_type` is `opmd` or `grpo`.",
        on_change=on_change,
    )
816+
790817
def _set_training_strategy(self):
791818
st.selectbox(
792819
"Training Strategy",
@@ -1099,7 +1126,7 @@ def beginner_mode(self):
10991126
self._check_engine_num_and_tp_size()
11001127

11011128
self._set_configs_with_st_columns(
1102-
["total_epochs", "train_batch_size", "ppo_epochs", "n"]
1129+
["total_epochs", "train_batch_size", "ppo_epochs", "repeat_times"]
11031130
if st.session_state["mode"] == "both"
11041131
else ["total_epochs", "train_batch_size", "ppo_epochs"]
11051132
)
@@ -1187,7 +1214,7 @@ def _expert_buffer_part(self):
11871214

11881215
def _expert_explorer_part(self):
11891216
self._set_configs_with_st_columns(
1190-
["engine_type", "engine_num", "tensor_parallel_size", "n"]
1217+
["engine_type", "engine_num", "tensor_parallel_size", "repeat_times"]
11911218
)
11921219
self._check_engine_num_and_tp_size()
11931220

@@ -1332,7 +1359,7 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node
13321359
else:
13331360
fsdp_config = {}
13341361

1335-
ppo_max_token_len_per_gpu = st.session_state["n"] * (
1362+
ppo_max_token_len_per_gpu = st.session_state["repeat_times"] * (
13361363
st.session_state["max_prompt_tokens"] + st.session_state["max_response_tokens"]
13371364
)
13381365

@@ -1349,7 +1376,8 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node
13491376
"prompt_key": "placeholder",
13501377
"max_prompt_length": st.session_state["max_prompt_tokens"],
13511378
"max_response_length": st.session_state["max_response_tokens"],
1352-
"train_batch_size": st.session_state["train_batch_size"] * st.session_state["n"],
1379+
"train_batch_size": st.session_state["train_batch_size"]
1380+
* st.session_state["repeat_times"],
13531381
"val_batch_size": None,
13541382
"return_raw_input_ids": False,
13551383
"return_raw_chat": False,
@@ -1437,7 +1465,7 @@ def _generate_verl_config(self, trainer_nnodes: int = 1, trainer_n_gpus_per_node
14371465
"disable_log_stats": True,
14381466
"enable_chunked_prefill": True,
14391467
"do_sample": True,
1440-
"n": st.session_state["n"],
1468+
"n": st.session_state["repeat_times"],
14411469
},
14421470
},
14431471
"critic": {
@@ -1596,12 +1624,15 @@ def generate_config(self):
15961624
"mode": st.session_state["mode"],
15971625
"project": st.session_state["project"],
15981626
"name": st.session_state["name"],
1599-
"algorithm_type": st.session_state["algorithm_type"],
1627+
"checkpoint_root_dir": st.session_state["checkpoint_path"],
1628+
"algorithm": {
1629+
"algorithm_type": st.session_state["algorithm_type"],
1630+
"repeat_times": st.session_state["repeat_times"],
1631+
},
16001632
"model": {
16011633
"model_path": st.session_state["model_path"],
16021634
"max_prompt_tokens": st.session_state["max_prompt_tokens"],
16031635
"max_response_tokens": st.session_state["max_response_tokens"],
1604-
"checkpoint_path": st.session_state["checkpoint_path"],
16051636
},
16061637
"cluster": {
16071638
"node_num": st.session_state["node_num"],
@@ -1624,7 +1655,7 @@ def generate_config(self):
16241655
"response_key": st.session_state["taskset_response_key"],
16251656
},
16261657
"rollout_args": {
1627-
"n": st.session_state["n"],
1658+
"n": st.session_state["repeat_times"],
16281659
"temperature": st.session_state["temperature"],
16291660
"top_p": st.session_state["top_p"],
16301661
"top_k": st.session_state["top_k"],

trinity/trainer/verl_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def train_sft_step(self, experiences: Experiences) -> Tuple[bool, int]:
304304
if self.sft_warmup_step_num == self.config.trainer.sft_warmup_steps:
305305
self.logger.log(
306306
data={"sft_warmup_steps": self.sft_warmup_step_num},
307-
step=self.global_steps,
307+
step=self.global_steps - 1,
308308
)
309309
with _timer("save_checkpoint", timing_raw):
310310
self._save_checkpoint()

0 commit comments

Comments
 (0)