Skip to content

Commit aa32213

Browse files
authored
Add weight_decay in OptimizerConfig (#364)
1 parent 43531d6 commit aa32213

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

tests/common/config_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,23 @@ def test_max_token_len_per_gpu_set_correctly(self):
137137
expected_max_token_len,
138138
)
139139

140+
def test_optimizer_config_propagation(self):
141+
config = get_template_config()
142+
config.algorithm.optimizer.lr = 1e-4
143+
config.algorithm.optimizer.weight_decay = 0.05
144+
config.check_and_update()
145+
self.assertEqual(config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr, 1e-4)
146+
self.assertEqual(
147+
config.trainer.trainer_config.actor_rollout_ref.actor.optim.weight_decay, 0.05
148+
)
149+
self.assertEqual(
150+
config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr_decay_style, "constant"
151+
) # default value
152+
# critic optimizer should not be affected
153+
self.assertEqual(config.trainer.trainer_config.critic.optim.lr, 1e-5)
154+
self.assertEqual(config.trainer.trainer_config.critic.optim.weight_decay, 0.01)
155+
self.assertEqual(config.trainer.trainer_config.critic.optim.lr_decay_style, "constant")
156+
140157
def tearDown(self):
141158
if os.path.exists(CHECKPOINT_ROOT_DIR):
142159
shutil.rmtree(CHECKPOINT_ROOT_DIR)

trinity/common/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ class OptimizerConfig:
9797
warmup_style: str = "constant"
9898
optimizer_type: str = "adam"
9999
betas: List[float] = field(default_factory=lambda: [0.9, 0.999])
100+
weight_decay: float = 0.01
100101

101102

102103
@dataclass

0 commit comments

Comments
 (0)