Skip to content

Commit 8c18c8f

Browse files
committed
add unittest
1 parent 5435d10 commit 8c18c8f

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

tests/common/config_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,23 @@ def test_max_token_len_per_gpu_set_correctly(self):
137137
expected_max_token_len,
138138
)
139139

140+
def test_optimizer_config_propagation(self):
141+
config = get_template_config()
142+
config.algorithm.optimizer.lr = 1e-4
143+
config.algorithm.optimizer.weight_decay = 0.05
144+
config.check_and_update()
145+
self.assertEqual(config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr, 1e-4)
146+
self.assertEqual(
147+
config.trainer.trainer_config.actor_rollout_ref.actor.optim.weight_decay, 0.05
148+
)
149+
self.assertEqual(
150+
config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr_decay_style, "constant"
151+
) # default value
152+
# critic optimizer should not be affected
153+
self.assertEqual(config.trainer.trainer_config.critic.optim.lr, 1e-5)
154+
self.assertEqual(config.trainer.trainer_config.critic.optim.weight_decay, 0.01)
155+
self.assertEqual(config.trainer.trainer_config.critic.optim.lr_decay_style, "constant")
156+
140157
def tearDown(self):
141158
if os.path.exists(CHECKPOINT_ROOT_DIR):
142159
shutil.rmtree(CHECKPOINT_ROOT_DIR)

0 commit comments

Comments
 (0)