
Commit 9f1719e

Sync the config manager with the latest codebase (agentscope-ai#332)
1 parent cec5a68 commit 9f1719e

File tree

8 files changed, +384 -243 lines


docs/sphinx_doc/source_zh/tutorial/example_async_mode.md

Lines changed: 1 addition & 1 deletion

@@ -98,7 +98,7 @@ synchronizer:
 trainer:
   grad_clip: 1.0
   use_dynamic_bsz: true
-  ppo_max_token_len_per_gpu: 16384
+  max_token_len_per_gpu: 16384
   ulysses_sequence_parallel_size: 1
 ```
trinity/common/config.py

Lines changed: 1 addition & 1 deletion

@@ -476,7 +476,7 @@ class TrainerConfig:
     # trainer configs
     grad_clip: float = 1.0
     use_dynamic_bsz: bool = True
-    # if None, automatically set to 2 * model.max_model_len / ulysses_sequence_parallel_size
+    # if None, automatically set to ceil(2 * model.max_model_len / ulysses_sequence_parallel_size)
     max_token_len_per_gpu: Optional[int] = None
     ulysses_sequence_parallel_size: int = 1  # sp size
     # TODO: extract more train-related params from underlying trainer engine
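
The rewritten comment makes the fallback explicit: when `max_token_len_per_gpu` is `None`, it is derived from the model's maximum length and the Ulysses sequence-parallel size. A minimal sketch of that rule (the helper name `resolve_max_token_len_per_gpu` is hypothetical; only the formula comes from the comment above):

```python
import math
from typing import Optional


def resolve_max_token_len_per_gpu(
    max_token_len_per_gpu: Optional[int],
    max_model_len: int,
    ulysses_sequence_parallel_size: int,
) -> int:
    """Apply the documented default: ceil(2 * max_model_len / sp_size)."""
    if max_token_len_per_gpu is not None:
        return max_token_len_per_gpu
    return math.ceil(2 * max_model_len / ulysses_sequence_parallel_size)


# e.g. max_model_len=8192 with sp_size=1 gives 16384, the value used in the
# tutorial YAML above
assert resolve_max_token_len_per_gpu(None, 8192, 1) == 16384
```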

trinity/manager/config_manager.py

Lines changed: 114 additions & 99 deletions
Large diffs are not rendered by default.

trinity/manager/config_registry/algorithm_config_manager.py

Lines changed: 66 additions & 23 deletions
@@ -6,7 +6,7 @@
     OPMDAdvantageFn,
     PPOAdvantageFn,
 )
-from trinity.algorithm.algorithm import ALGORITHM_TYPE, PPOAlgorithm
+from trinity.algorithm.algorithm import ALGORITHM_TYPE, GRPOAlgorithm
 from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import (
     ENTROPY_LOSS_FN,
     EntropyLossFn,
@@ -23,15 +23,16 @@
 from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY, MixSampleStrategy
 from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
 from trinity.manager.config_registry.model_config_manager import set_trainer_gpu_num
+from trinity.utils.registry import Registry


 @CONFIG_GENERATORS.register_config(
-    default_value="ppo",
-    other_configs={"mode": "both", "_current_default_config": PPOAlgorithm.default_config()},
+    default_value="grpo",
+    other_configs={"mode": "both", "_current_default_config": GRPOAlgorithm.default_config()},
 )
 def set_algorithm_type(**kwargs):
     def on_change():
-        if st.session_state["algorithm_type"] == "dpo":
+        if st.session_state["algorithm_type"] in ("dpo", "sft"):
             st.session_state["mode"] = "train"
         else:
             st.session_state["mode"] = "both"
@@ -45,25 +46,27 @@ def on_change():
     candidates = list(ALGORITHM_TYPE.modules.keys())
     st.selectbox(
         "Algorithm Type",
-        candidates,
+        options=candidates,
+        format_func=lambda x: x.upper(),
         on_change=on_change,
         **kwargs,
     )


 @CONFIG_GENERATORS.register_config(
-    default_value=PPOAlgorithm.default_config()["repeat_times"],
+    default_value=GRPOAlgorithm.default_config()["repeat_times"],
     visible=lambda: "repeat_times" in st.session_state["_current_default_config"],
     other_configs={
         "_grouped_adv_repeat_times": 2,
         "_not_grouped_adv_repeat_times": 1,
     },
 )
-def set_repeat_times(**kwargs):  # TODO
+def set_repeat_times(**kwargs):
     key = kwargs.get("key")
     grouped_adv_algorithms = [
         "grpo",
-        "opmd",  # TODO: may add rloo
+        "opmd",
+        "rloo",
     ]
     if st.session_state["algorithm_type"] in grouped_adv_algorithms:
         min_repeat_times = 2
@@ -82,7 +85,7 @@ def on_change():
         "Repeat Times",
         min_value=min_repeat_times,
         help="`repeat_times` is used to set how many experiences each task can generate, "
-        "and it must be greater than `1` when `algorithm_type` is `opmd` or `grpo`.",
+        "and it must be greater than `1` when `algorithm_type` is `grpo`, `opmd`, or `rloo`.",
         on_change=on_change,
         **kwargs,
     )
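
The floor of 2 exists because grouped-advantage algorithms (GRPO, OPMD, RLOO) baseline each reward against the mean of its task group, so with a single rollout per task the advantage is identically zero. A toy illustration in plain Python (not Trinity code):

```python
# With repeat_times == 1, the group-mean baseline cancels the only reward.
rewards = [0.75]                     # one rollout for a task
mean = sum(rewards) / len(rewards)
print([r - mean for r in rewards])   # [0.0] -- no learning signal

# With repeat_times == 2, within-group differences survive.
rewards = [0.75, 0.25]
mean = sum(rewards) / len(rewards)
print([r - mean for r in rewards])   # [0.25, -0.25]
```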
@@ -92,15 +95,17 @@ def on_change():


 @CONFIG_GENERATORS.register_config(
-    default_value=PPOAlgorithm.default_config()["sample_strategy"],
+    default_value=GRPOAlgorithm.default_config()["sample_strategy"],
     visible=lambda: "sample_strategy" in st.session_state["_current_default_config"],
 )
 def set_sample_strategy(**kwargs):
+    on_change = _create_on_change_callback("sample_strategy", SAMPLE_STRATEGY, **kwargs)
     candidates = list(SAMPLE_STRATEGY.modules.keys())
     st.selectbox(
         "Sample Strategy",
         candidates,
         help="The sample strategy used to obtain experiences.",
+        on_change=on_change,
         **kwargs,
     )

@@ -124,15 +129,18 @@ def set_expert_data_ratio_in_sample_strategy(**kwargs):


 @CONFIG_GENERATORS.register_config(
-    default_value=PPOAlgorithm.default_config()["advantage_fn"],
+    default_value=GRPOAlgorithm.default_config()["advantage_fn"],
     visible=lambda: "advantage_fn" in st.session_state["_current_default_config"],
 )
 def set_advantage_fn(**kwargs):
+    on_change = _create_on_change_callback("advantage_fn", ADVANTAGE_FN, **kwargs)
     candidates = list(ADVANTAGE_FN.modules.keys())
     st.selectbox(
         "Advantage Function",
-        candidates,
+        options=candidates,
+        format_func=lambda x: x.upper(),
         help="The advantage function used to compute advantages.",
+        on_change=on_change,
         **kwargs,
     )

@@ -142,22 +150,26 @@ def set_advantage_fn(**kwargs):
     visible=lambda: st.session_state["advantage_fn"] in {"ppo", "reinforceplusplus"},
 )
 def set_gamma_in_advantage_fn(**kwargs):
-    st.number_input(r"Gamma :blue-badge[$\gamma$]", **kwargs)
+    st.number_input(r"Gamma :blue-badge[$\gamma$]", help="Discount factor used in RL", **kwargs)


 @CONFIG_GENERATORS.register_config(
     default_value=PPOAdvantageFn.default_args()["lam"],
     visible=lambda: st.session_state["advantage_fn"] == "ppo",
 )
 def set_lam_in_advantage_fn(**kwargs):
-    st.number_input(r"Lambda :blue-badge[$\lambda$]", **kwargs)
+    st.number_input(
+        r"Lambda :blue-badge[$\lambda$]",
+        help="Lambda value when computing Generalized Advantage Estimation",
+        **kwargs,
+    )


 @CONFIG_GENERATORS.register_config(
     default_value=GRPOAdvantageFn.default_args()["epsilon"],
     visible=lambda: st.session_state["advantage_fn"] == "grpo",
 )
-def set_epsilon_in_advantage_fn(**kwargs):  # TODO: update help message
+def set_epsilon_in_advantage_fn(**kwargs):
     st.number_input(
         r"GRPO Epsilon",
         help=r"""
@@ -194,14 +206,17 @@ def set_tau_in_advantage_fn(**kwargs):


 @CONFIG_GENERATORS.register_config(
-    default_value=PPOAlgorithm.default_config()["kl_loss_fn"],
+    default_value=GRPOAlgorithm.default_config()["kl_loss_fn"],
     visible=lambda: "kl_loss_fn" in st.session_state["_current_default_config"],
 )
 def set_kl_loss_fn(**kwargs):
+    on_change = _create_on_change_callback("kl_loss_fn", KL_FN, **kwargs)
     candidates = list(KL_FN.modules.keys())
     st.selectbox(
         "KL Loss Type",
-        candidates,
+        options=candidates,
+        format_func=lambda x: x.upper(),
+        on_change=on_change,
         **kwargs,
     )

@@ -224,14 +239,17 @@ def set_kl_coef_in_kl_loss_fn(**kwargs):


 @CONFIG_GENERATORS.register_config(
-    default_value=PPOAlgorithm.default_config()["kl_penalty_fn"],
+    default_value=GRPOAlgorithm.default_config()["kl_penalty_fn"],
     visible=lambda: "kl_penalty_fn" in st.session_state["_current_default_config"],
 )
 def set_kl_penalty_fn(**kwargs):
+    on_change = _create_on_change_callback("kl_penalty_fn", KL_FN, **kwargs)
     candidates = list(KL_FN.modules.keys())
     st.selectbox(
         "KL Penalty Type",
-        candidates,
+        options=candidates,
+        format_func=lambda x: x.upper(),
+        on_change=on_change,
         **kwargs,
     )

@@ -267,14 +285,17 @@ def set_kl_coef_in_kl_penalty_fn(**kwargs):


 @CONFIG_GENERATORS.register_config(
-    default_value=PPOAlgorithm.default_config()["policy_loss_fn"],
+    default_value=GRPOAlgorithm.default_config()["policy_loss_fn"],
     visible=lambda: "policy_loss_fn" in st.session_state["_current_default_config"],
 )
 def set_policy_loss_fn(**kwargs):
+    on_change = _create_on_change_callback("policy_loss_fn", POLICY_LOSS_FN, **kwargs)
     candidates = list(POLICY_LOSS_FN.modules.keys())
     st.selectbox(
         "Policy Loss Fn",
-        candidates,
+        options=candidates,
+        format_func=lambda x: x.upper(),
+        on_change=on_change,
         **kwargs,
     )

@@ -356,12 +377,18 @@ def set_mu_in_policy_loss_fn(**kwargs):


 @CONFIG_GENERATORS.register_config(
-    default_value=PPOAlgorithm.default_config()["entropy_loss_fn"],
+    default_value=GRPOAlgorithm.default_config()["entropy_loss_fn"],
     visible=lambda: "entropy_loss_fn" in st.session_state["_current_default_config"],
 )
 def set_entropy_loss_fn(**kwargs):
+    on_change = _create_on_change_callback("entropy_loss_fn", ENTROPY_LOSS_FN, **kwargs)
     candidates = list(ENTROPY_LOSS_FN.modules.keys())
-    st.selectbox("Entropy Loss Function", candidates, **kwargs)
+    st.selectbox(
+        "Entropy Loss Function",
+        options=candidates,
+        on_change=on_change,
+        **kwargs,
+    )


 @CONFIG_GENERATORS.register_config(
@@ -376,3 +403,19 @@ def set_entropy_coef_in_entropy_loss_fn(**kwargs):
         format="%.1e",
         **kwargs,
     )
+
+
+# define on_change
+def _create_on_change_callback(key_name: str, registry: Registry, **kwargs):
+    """Creates an on_change callback to update dependent configs."""
+
+    def on_change():
+        value = st.session_state[kwargs.get("key", key_name)]
+        value_class = registry.get(value)
+        if value_class:
+            default_args = value_class.default_args()
+            for arg_key, arg_value in default_args.items():
+                full_key = f"{arg_key}_in_{key_name}"
+                st.session_state[full_key] = arg_value
+
+    return on_change
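
For intuition about what the new factory buys: when a selectbox value changes, the callback pushes the chosen class's `default_args()` into the per-argument session keys (`f"{arg_key}_in_{key_name}"`) that the dependent widgets read. A minimal self-contained sketch of that flow, where `FakeAdvantageFn`, the plain-dict registry and session state, and the epsilon value are simplified stand-ins, not the real Trinity or Streamlit objects:

```python
class FakeAdvantageFn:
    @staticmethod
    def default_args():
        return {"epsilon": 1e-6}  # illustrative default, not the real value


registry = {"grpo": FakeAdvantageFn}      # stands in for ADVANTAGE_FN
session_state = {"advantage_fn": "grpo"}  # stands in for st.session_state


def create_on_change(key_name, registry):
    def on_change():
        value_class = registry.get(session_state[key_name])
        if value_class:
            for arg_key, arg_value in value_class.default_args().items():
                # e.g. "epsilon_in_advantage_fn": the key the dependent
                # number_input widget reads its default from
                session_state[f"{arg_key}_in_{key_name}"] = arg_value
    return on_change


create_on_change("advantage_fn", registry)()
assert session_state["epsilon_in_advantage_fn"] == 1e-6
```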
