File tree Expand file tree Collapse file tree 2 files changed +18
-0
lines changed
Expand file tree Collapse file tree 2 files changed +18
-0
lines changed Original file line number Diff line number Diff line change @@ -137,6 +137,23 @@ def test_max_token_len_per_gpu_set_correctly(self):
137137 expected_max_token_len ,
138138 )
139139
140+ def test_optimizer_config_propagation (self ):
141+ config = get_template_config ()
142+ config .algorithm .optimizer .lr = 1e-4
143+ config .algorithm .optimizer .weight_decay = 0.05
144+ config .check_and_update ()
145+ self .assertEqual (config .trainer .trainer_config .actor_rollout_ref .actor .optim .lr , 1e-4 )
146+ self .assertEqual (
147+ config .trainer .trainer_config .actor_rollout_ref .actor .optim .weight_decay , 0.05
148+ )
149+ self .assertEqual (
150+ config .trainer .trainer_config .actor_rollout_ref .actor .optim .lr_decay_style , "constant"
151+ ) # default value
152+ # critic optimizer should not be affected
153+ self .assertEqual (config .trainer .trainer_config .critic .optim .lr , 1e-5 )
154+ self .assertEqual (config .trainer .trainer_config .critic .optim .weight_decay , 0.01 )
155+ self .assertEqual (config .trainer .trainer_config .critic .optim .lr_decay_style , "constant" )
156+
140157 def tearDown (self ):
141158 if os .path .exists (CHECKPOINT_ROOT_DIR ):
142159 shutil .rmtree (CHECKPOINT_ROOT_DIR )
Original file line number Diff line number Diff line change @@ -97,6 +97,7 @@ class OptimizerConfig:
9797 warmup_style : str = "constant"
9898 optimizer_type : str = "adam"
9999 betas : List [float ] = field (default_factory = lambda : [0.9 , 0.999 ])
100+ weight_decay : float = 0.01
100101
101102
102103@dataclass
You can’t perform that action at this time.
0 commit comments