@@ -6,7 +6,7 @@
     OPMDAdvantageFn,
     PPOAdvantageFn,
 )
-from trinity.algorithm.algorithm import ALGORITHM_TYPE, PPOAlgorithm
+from trinity.algorithm.algorithm import ALGORITHM_TYPE, GRPOAlgorithm
 from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import (
     ENTROPY_LOSS_FN,
     EntropyLossFn,
@@ -23,15 +23,16 @@
 from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY, MixSampleStrategy
 from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
 from trinity.manager.config_registry.model_config_manager import set_trainer_gpu_num
+from trinity.utils.registry import Registry
 
 
 @CONFIG_GENERATORS.register_config(
-    default_value="ppo",
-    other_configs={"mode": "both", "_current_default_config": PPOAlgorithm.default_config()},
+    default_value="grpo",
+    other_configs={"mode": "both", "_current_default_config": GRPOAlgorithm.default_config()},
 )
 def set_algorithm_type(**kwargs):
     def on_change():
-        if st.session_state["algorithm_type"] == "dpo":
+        if st.session_state["algorithm_type"] in ("dpo", "sft"):
             st.session_state["mode"] = "train"
         else:
             st.session_state["mode"] = "both"
@@ -45,25 +46,27 @@ def on_change():
     candidates = list(ALGORITHM_TYPE.modules.keys())
     st.selectbox(
         "Algorithm Type",
-        candidates,
+        options=candidates,
+        format_func=lambda x: x.upper(),
         on_change=on_change,
         **kwargs,
     )
 
 
 @CONFIG_GENERATORS.register_config(
-    default_value=PPOAlgorithm.default_config()["repeat_times"],
+    default_value=GRPOAlgorithm.default_config()["repeat_times"],
     visible=lambda: "repeat_times" in st.session_state["_current_default_config"],
     other_configs={
         "_grouped_adv_repeat_times": 2,
         "_not_grouped_adv_repeat_times": 1,
     },
 )
-def set_repeat_times(**kwargs):  # TODO
+def set_repeat_times(**kwargs):
     key = kwargs.get("key")
     grouped_adv_algorithms = [
         "grpo",
-        "opmd",  # TODO: may add rloo
+        "opmd",
+        "rloo",
     ]
     if st.session_state["algorithm_type"] in grouped_adv_algorithms:
         min_repeat_times = 2
@@ -82,7 +85,7 @@ def on_change():
8285 "Repeat Times" ,
8386 min_value = min_repeat_times ,
8487 help = "`repeat_times` is used to set how many experiences each task can generate, "
85- "and it must be greater than `1` when `algorithm_type` is `opmd` or `grpo `." ,
88+ "and it must be greater than `1` when `algorithm_type` is `grpo`, ` opmd` or 'rloo `." ,
8689 on_change = on_change ,
8790 ** kwargs ,
8891 )
@@ -92,15 +95,17 @@ def on_change():
 
 
 @CONFIG_GENERATORS.register_config(
-    default_value=PPOAlgorithm.default_config()["sample_strategy"],
+    default_value=GRPOAlgorithm.default_config()["sample_strategy"],
     visible=lambda: "sample_strategy" in st.session_state["_current_default_config"],
 )
 def set_sample_strategy(**kwargs):
+    on_change = _create_on_change_callback("sample_strategy", SAMPLE_STRATEGY, **kwargs)
     candidates = list(SAMPLE_STRATEGY.modules.keys())
     st.selectbox(
         "Sample Strategy",
         candidates,
         help="The sample strategy used to obtain experiences.",
+        on_change=on_change,
         **kwargs,
     )
 
@@ -124,15 +129,18 @@ def set_expert_data_ratio_in_sample_strategy(**kwargs):
 
 
 @CONFIG_GENERATORS.register_config(
-    default_value=PPOAlgorithm.default_config()["advantage_fn"],
+    default_value=GRPOAlgorithm.default_config()["advantage_fn"],
     visible=lambda: "advantage_fn" in st.session_state["_current_default_config"],
 )
 def set_advantage_fn(**kwargs):
+    on_change = _create_on_change_callback("advantage_fn", ADVANTAGE_FN, **kwargs)
     candidates = list(ADVANTAGE_FN.modules.keys())
     st.selectbox(
         "Advantage Function",
-        candidates,
+        options=candidates,
+        format_func=lambda x: x.upper(),
         help="The advantage function used to compute advantages.",
+        on_change=on_change,
         **kwargs,
     )
 
@@ -142,22 +150,26 @@ def set_advantage_fn(**kwargs):
     visible=lambda: st.session_state["advantage_fn"] in {"ppo", "reinforceplusplus"},
 )
 def set_gamma_in_advantage_fn(**kwargs):
-    st.number_input(r"Gamma :blue-badge[$\gamma$]", **kwargs)
+    st.number_input(r"Gamma :blue-badge[$\gamma$]", help="Discount factor used in RL", **kwargs)
 
 
 @CONFIG_GENERATORS.register_config(
     default_value=PPOAdvantageFn.default_args()["lam"],
     visible=lambda: st.session_state["advantage_fn"] == "ppo",
 )
 def set_lam_in_advantage_fn(**kwargs):
-    st.number_input(r"Lambda :blue-badge[$\lambda$]", **kwargs)
+    st.number_input(
+        r"Lambda :blue-badge[$\lambda$]",
+        help="Lambda used in Generalized Advantage Estimation (GAE)",
+        **kwargs,
+    )
 
 
 @CONFIG_GENERATORS.register_config(
     default_value=GRPOAdvantageFn.default_args()["epsilon"],
     visible=lambda: st.session_state["advantage_fn"] == "grpo",
 )
-def set_epsilon_in_advantage_fn(**kwargs):  # TODO: update help message
+def set_epsilon_in_advantage_fn(**kwargs):
     st.number_input(
         r"GRPO Epsilon",
         help=r"""
@@ -194,14 +206,17 @@ def set_tau_in_advantage_fn(**kwargs):
 
 
 @CONFIG_GENERATORS.register_config(
-    default_value=PPOAlgorithm.default_config()["kl_loss_fn"],
+    default_value=GRPOAlgorithm.default_config()["kl_loss_fn"],
     visible=lambda: "kl_loss_fn" in st.session_state["_current_default_config"],
 )
 def set_kl_loss_fn(**kwargs):
+    on_change = _create_on_change_callback("kl_loss_fn", KL_FN, **kwargs)
     candidates = list(KL_FN.modules.keys())
     st.selectbox(
         "KL Loss Type",
-        candidates,
+        options=candidates,
+        format_func=lambda x: x.upper(),
+        on_change=on_change,
         **kwargs,
     )
 
@@ -224,14 +239,17 @@ def set_kl_coef_in_kl_loss_fn(**kwargs):
 
 
 @CONFIG_GENERATORS.register_config(
-    default_value=PPOAlgorithm.default_config()["kl_penalty_fn"],
+    default_value=GRPOAlgorithm.default_config()["kl_penalty_fn"],
     visible=lambda: "kl_penalty_fn" in st.session_state["_current_default_config"],
 )
 def set_kl_penalty_fn(**kwargs):
+    on_change = _create_on_change_callback("kl_penalty_fn", KL_FN, **kwargs)
     candidates = list(KL_FN.modules.keys())
     st.selectbox(
         "KL Penalty Type",
-        candidates,
+        options=candidates,
+        format_func=lambda x: x.upper(),
+        on_change=on_change,
         **kwargs,
     )
 
@@ -267,14 +285,17 @@ def set_kl_coef_in_kl_penalty_fn(**kwargs):
 
 
 @CONFIG_GENERATORS.register_config(
-    default_value=PPOAlgorithm.default_config()["policy_loss_fn"],
+    default_value=GRPOAlgorithm.default_config()["policy_loss_fn"],
     visible=lambda: "policy_loss_fn" in st.session_state["_current_default_config"],
 )
 def set_policy_loss_fn(**kwargs):
+    on_change = _create_on_change_callback("policy_loss_fn", POLICY_LOSS_FN, **kwargs)
     candidates = list(POLICY_LOSS_FN.modules.keys())
     st.selectbox(
         "Policy Loss Fn",
-        candidates,
+        options=candidates,
+        format_func=lambda x: x.upper(),
+        on_change=on_change,
         **kwargs,
     )
 
@@ -356,12 +377,18 @@ def set_mu_in_policy_loss_fn(**kwargs):
 
 
 @CONFIG_GENERATORS.register_config(
-    default_value=PPOAlgorithm.default_config()["entropy_loss_fn"],
+    default_value=GRPOAlgorithm.default_config()["entropy_loss_fn"],
     visible=lambda: "entropy_loss_fn" in st.session_state["_current_default_config"],
 )
 def set_entropy_loss_fn(**kwargs):
+    on_change = _create_on_change_callback("entropy_loss_fn", ENTROPY_LOSS_FN, **kwargs)
     candidates = list(ENTROPY_LOSS_FN.modules.keys())
-    st.selectbox("Entropy Loss Function", candidates, **kwargs)
+    st.selectbox(
+        "Entropy Loss Function",
+        options=candidates,
+        on_change=on_change,
+        **kwargs,
+    )
 
 
 @CONFIG_GENERATORS.register_config(
@@ -376,3 +403,19 @@ def set_entropy_coef_in_entropy_loss_fn(**kwargs):
         format="%.1e",
         **kwargs,
     )
+
+
+# Factory for the on_change callbacks used by the selectboxes above.
+def _create_on_change_callback(key_name: str, registry: Registry, **kwargs):
+    """Create an on_change callback that resets dependent configs to their defaults."""
+
+    def on_change():
+        value = st.session_state[kwargs.get("key", key_name)]
+        value_class = registry.get(value)
+        if value_class:
+            default_args = value_class.default_args()
+            for arg_key, arg_value in default_args.items():
+                full_key = f"{arg_key}_in_{key_name}"
+                st.session_state[full_key] = arg_value
+
+    return on_change
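
A note on the new `_create_on_change_callback` factory: the callback it returns writes the selected module's `default_args()` into session-state keys of the form `f"{arg_key}_in_{key_name}"`, so it only works if the dependent widgets were registered under exactly those keys (e.g. `kl_coef_in_kl_loss_fn` for `set_kl_coef_in_kl_loss_fn`). A minimal sketch of that convention, using hypothetical stand-in classes and a plain dict instead of the real Trinity registry and `st.session_state`:

# Hypothetical stand-ins for illustration; not the real Trinity classes.
class FakeKLFn:
    @classmethod
    def default_args(cls):
        return {"kl_coef": 0.001}  # assumed value, for illustration only

class FakeRegistry:
    modules = {"k2": FakeKLFn}

    def get(self, name):
        return self.modules.get(name)

# What the returned on_change effectively does when the user picks "k2":
session_state = {"kl_loss_fn": "k2"}  # stand-in for st.session_state
selected = FakeRegistry().get(session_state["kl_loss_fn"])
if selected:
    for arg_key, arg_value in selected.default_args().items():
        session_state[f"{arg_key}_in_kl_loss_fn"] = arg_value

assert session_state["kl_coef_in_kl_loss_fn"] == 0.001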