diff --git a/tests/common/config_test.py b/tests/common/config_test.py
index 37b9805ba3..4b29e09f7e 100644
--- a/tests/common/config_test.py
+++ b/tests/common/config_test.py
@@ -141,18 +141,32 @@ def test_optimizer_config_propagation(self):
         config = get_template_config()
         config.algorithm.optimizer.lr = 1e-4
         config.algorithm.optimizer.weight_decay = 0.05
+        config.algorithm.optimizer.clip_grad = 2.0
+        config.algorithm.optimizer.lr_decay_steps = 1000
+        config.algorithm.optimizer.lr_decay_style = "cosine"
+        config.algorithm.optimizer.lr_warmup_init = 1e-7
+        config.algorithm.optimizer.min_lr = 1e-6
         config.check_and_update()
         self.assertEqual(config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr, 1e-4)
         self.assertEqual(
             config.trainer.trainer_config.actor_rollout_ref.actor.optim.weight_decay, 0.05
         )
+        self.assertEqual(config.trainer.trainer_config.actor_rollout_ref.actor.optim.clip_grad, 2.0)
         self.assertEqual(
-            config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr_decay_style, "constant"
-        )  # default value
+            config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr_decay_steps, 1000
+        )
+        self.assertEqual(
+            config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr_decay_style, "cosine"
+        )
+        self.assertEqual(
+            config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr_warmup_init, 1e-7
+        )
+        self.assertEqual(config.trainer.trainer_config.actor_rollout_ref.actor.optim.min_lr, 1e-6)
         # critic optimizer should not be affected
         self.assertEqual(config.trainer.trainer_config.critic.optim.lr, 1e-5)
         self.assertEqual(config.trainer.trainer_config.critic.optim.weight_decay, 0.01)
         self.assertEqual(config.trainer.trainer_config.critic.optim.lr_decay_style, "constant")
+        self.assertEqual(config.trainer.trainer_config.critic.optim.clip_grad, 1.0)
 
     def tearDown(self):
         if os.path.exists(CHECKPOINT_ROOT_DIR):
diff --git a/trinity/common/config.py b/trinity/common/config.py
index c722959b96..3ce55ed1b2 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -98,6 +98,11 @@ class OptimizerConfig:
     optimizer_type: str = "adam"
     betas: List[float] = field(default_factory=lambda: [0.9, 0.999])
     weight_decay: float = 0.01
+    clip_grad: float = 1.0
+    lr_warmup_init: float = 0.0
+    lr_decay_steps: Optional[int] = None
+    lr_decay_style: str = "constant"
+    min_lr: float = 0.0
 
 
 @dataclass
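
Usage sketch (not part of the patch): the new fields can be set on OptimizerConfig like any other dataclass field and are then propagated to the actor optimizer config by config.check_and_update(), while the critic optimizer keeps its defaults, as the test above asserts. This sketch assumes the lr field (used in the test but not shown in this hunk) is also a constructor argument, and the inline comments describe the intended meaning inferred from the field names only.

    from trinity.common.config import OptimizerConfig

    # Hypothetical standalone construction; values mirror the test case above.
    opt = OptimizerConfig(
        lr=1e-4,                  # assumed field, implied by the test's optimizer.lr assignment
        weight_decay=0.05,
        clip_grad=2.0,            # presumably the max gradient norm for clipping
        lr_decay_steps=1000,      # presumably the number of steps over which the LR decays
        lr_decay_style="cosine",  # "constant" remains the default, as asserted for the critic
        lr_warmup_init=1e-7,      # presumably the starting LR during warmup
        min_lr=1e-6,              # presumably the floor the decayed LR is clamped to
    )
    assert opt.optimizer_type == "adam"  # untouched defaults keep their values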