Commit 118a66f

[Fix] Add L2 Regularization (#6372)
* fix no L2 regularization error
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent c782976 commit 118a66f

File tree

2 files changed: +6 additions, -2 deletions


applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 1 addition & 1 deletion
@@ -364,7 +364,7 @@ def __init__(
         self.model = AutoModelForCausalLM.from_pretrained(path, **model_config)
         self.model.train()
         self.model.gradient_checkpointing_enable()
-        self.optimizer = HybridAdam(self.model.parameters(), lr=1e-3)
+        self.optimizer = HybridAdam(self.model.parameters(), lr=1e-3, weight_decay=0.01)
         self.accum_loss = torch.zeros(1, device=self.device)

     def setup(self):
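The change above enables L2 regularization by passing `weight_decay=0.01` to the optimizer. A minimal sketch of what that does per update step, assuming HybridAdam applies decoupled (AdamW-style) decay; `adamw_style_step` is a hypothetical helper, not ColossalAI code, and the moment estimates of real Adam are omitted:

```python
def adamw_style_step(param, grad, lr=1e-3, weight_decay=0.01):
    """One simplified update: a plain gradient step plus decoupled
    weight decay, which shrinks the parameter toward zero each step.
    (Real Adam also maintains first/second moment estimates.)"""
    return param - lr * grad - lr * weight_decay * param

# With a zero gradient, only the decay term acts:
p = adamw_style_step(1.0, grad=0.0)  # 1.0 - 1e-3 * 0.01 * 1.0 = 0.99999
```

With `weight_decay=0` the parameter would be unchanged here, which is the "no L2 regularization" behavior the commit fixes.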

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 5 additions & 1 deletion
@@ -72,7 +72,11 @@ def __init__(
         self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
         self.policy_model.train()
         self.policy_model.gradient_checkpointing_enable()
-        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=grpo_config.get("lr", 1e-6))
+        self.optimizer = HybridAdam(
+            self.policy_model.parameters(),
+            lr=grpo_config.get("lr", 1e-6),
+            weight_decay=grpo_config.get("weight_decay", 0.01),
+        )
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
         self.accum_entropy = torch.zeros(1, device=self.device)
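Unlike the consumer change, the GRPO consumer reads `weight_decay` from `grpo_config` with a default of 0.01, so callers can still tune or disable it. A small sketch of that `dict.get(key, default)` fallback pattern; the config values here are illustrative, not from the commit:

```python
# Example config where the user supplied only a learning rate:
grpo_config = {"lr": 2e-6}

# dict.get(key, default) returns the stored value when present,
# otherwise the default -- so weight decay is on unless overridden.
lr = grpo_config.get("lr", 1e-6)                      # user value wins
weight_decay = grpo_config.get("weight_decay", 0.01)  # falls back to 0.01
```

Passing `{"weight_decay": 0.0}` would restore the old no-regularization behavior explicitly.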
