Skip to content

Commit ffdf4ff

Browse files
authored
Add more useful options in OptimizerConfig (#371)
Co-authored-by: 问昊 <[email protected]>
1 parent ba33438 commit ffdf4ff

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

tests/common/config_test.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,18 +141,32 @@ def test_optimizer_config_propagation(self):
141141
config = get_template_config()
142142
config.algorithm.optimizer.lr = 1e-4
143143
config.algorithm.optimizer.weight_decay = 0.05
144+
config.algorithm.optimizer.clip_grad = 2.0
145+
config.algorithm.optimizer.lr_decay_steps = 1000
146+
config.algorithm.optimizer.lr_decay_style = "cosine"
147+
config.algorithm.optimizer.lr_warmup_init = 1e-7
148+
config.algorithm.optimizer.min_lr = 1e-6
144149
config.check_and_update()
145150
self.assertEqual(config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr, 1e-4)
146151
self.assertEqual(
147152
config.trainer.trainer_config.actor_rollout_ref.actor.optim.weight_decay, 0.05
148153
)
154+
self.assertEqual(config.trainer.trainer_config.actor_rollout_ref.actor.optim.clip_grad, 2.0)
149155
self.assertEqual(
150-
config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr_decay_style, "constant"
151-
) # default value
156+
config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr_decay_steps, 1000
157+
)
158+
self.assertEqual(
159+
config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr_decay_style, "cosine"
160+
)
161+
self.assertEqual(
162+
config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr_warmup_init, 1e-7
163+
)
164+
self.assertEqual(config.trainer.trainer_config.actor_rollout_ref.actor.optim.min_lr, 1e-6)
152165
# critic optimizer should not be affected
153166
self.assertEqual(config.trainer.trainer_config.critic.optim.lr, 1e-5)
154167
self.assertEqual(config.trainer.trainer_config.critic.optim.weight_decay, 0.01)
155168
self.assertEqual(config.trainer.trainer_config.critic.optim.lr_decay_style, "constant")
169+
self.assertEqual(config.trainer.trainer_config.critic.optim.clip_grad, 1.0)
156170

157171
def tearDown(self):
158172
if os.path.exists(CHECKPOINT_ROOT_DIR):

trinity/common/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,11 @@ class OptimizerConfig:
9898
optimizer_type: str = "adam"
9999
betas: List[float] = field(default_factory=lambda: [0.9, 0.999])
100100
weight_decay: float = 0.01
101+
clip_grad: float = 1.0
102+
lr_warmup_init: float = 0.0
103+
lr_decay_steps: Optional[int] = None
104+
lr_decay_style: str = "constant"
105+
min_lr: float = 0.0
101106

102107

103108
@dataclass

0 commit comments

Comments
 (0)