@@ -141,18 +141,32 @@ def test_optimizer_config_propagation(self):
141141 config = get_template_config ()
142142 config .algorithm .optimizer .lr = 1e-4
143143 config .algorithm .optimizer .weight_decay = 0.05
144+ config .algorithm .optimizer .clip_grad = 2.0
145+ config .algorithm .optimizer .lr_decay_steps = 1000
146+ config .algorithm .optimizer .lr_decay_style = "cosine"
147+ config .algorithm .optimizer .lr_warmup_init = 1e-7
148+ config .algorithm .optimizer .min_lr = 1e-6
144149 config .check_and_update ()
145150 self .assertEqual (config .trainer .trainer_config .actor_rollout_ref .actor .optim .lr , 1e-4 )
146151 self .assertEqual (
147152 config .trainer .trainer_config .actor_rollout_ref .actor .optim .weight_decay , 0.05
148153 )
154+ self .assertEqual (config .trainer .trainer_config .actor_rollout_ref .actor .optim .clip_grad , 2.0 )
149155 self .assertEqual (
150- config .trainer .trainer_config .actor_rollout_ref .actor .optim .lr_decay_style , "constant"
151- ) # default value
156+ config .trainer .trainer_config .actor_rollout_ref .actor .optim .lr_decay_steps , 1000
157+ )
158+ self .assertEqual (
159+ config .trainer .trainer_config .actor_rollout_ref .actor .optim .lr_decay_style , "cosine"
160+ )
161+ self .assertEqual (
162+ config .trainer .trainer_config .actor_rollout_ref .actor .optim .lr_warmup_init , 1e-7
163+ )
164+ self .assertEqual (config .trainer .trainer_config .actor_rollout_ref .actor .optim .min_lr , 1e-6 )
152165 # critic optimizer should not be affected
153166 self .assertEqual (config .trainer .trainer_config .critic .optim .lr , 1e-5 )
154167 self .assertEqual (config .trainer .trainer_config .critic .optim .weight_decay , 0.01 )
155168 self .assertEqual (config .trainer .trainer_config .critic .optim .lr_decay_style , "constant" )
169+ self .assertEqual (config .trainer .trainer_config .critic .optim .clip_grad , 1.0 )
156170
157171 def tearDown (self ):
158172 if os .path .exists (CHECKPOINT_ROOT_DIR ):
0 commit comments