Skip to content

Commit 47483e2

Browse files
committed
add custom lr shedule
1 parent 8b7b6a8 commit 47483e2

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

open_diloco/train_fsdp.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"""
88

99
from functools import partial
10+
import math
1011
import os
1112
import time
1213
from contextlib import nullcontext
@@ -27,7 +28,6 @@
2728
DataCollatorForLanguageModeling,
2829
LlamaConfig,
2930
LlamaForCausalLM,
30-
get_cosine_schedule_with_warmup,
3131
)
3232
from torch.distributed.fsdp import (
3333
FullyShardedDataParallel as FSDP,
@@ -46,7 +46,7 @@
4646
save_checkpoint,
4747
)
4848
from open_diloco.hivemind_diloco import AllReduceStrategy, DiLoCoOptimizer
49-
49+
from torch.optim.lr_scheduler import LambdaLR
5050

5151
from hivemind.dht.dht import DHT
5252
from hivemind.utils.networking import log_visible_maddrs
@@ -173,6 +173,27 @@ def get_model(config: Config) -> LlamaForCausalLM:
173173
return LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path=config.path_model, config=config_model)
174174

175175

176+
def _get_cosine_schedule_with_warmup_lr_lambda(
177+
current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float, min_lr_rate: float = 0.0
178+
):
179+
if current_step < num_warmup_steps:
180+
return float(current_step) / float(max(1, num_warmup_steps))
181+
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
182+
factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
183+
factor = factor * (1 - min_lr_rate) + min_lr_rate
184+
return max(0, factor)
185+
186+
187+
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_inner_steps):
188+
lambda_lr = partial(
189+
_get_cosine_schedule_with_warmup_lr_lambda,
190+
num_warmup_steps=num_warmup_steps,
191+
num_training_steps=num_training_steps,
192+
num_cycles=0.5,
193+
)
194+
return LambdaLR(optimizer, lambda_lr, -1)
195+
196+
176197
def train(config: Config):
177198
sharding_strategy = get_sharding_strategy(config.sharding_strategy)
178199
local_rank = int(os.environ["LOCAL_RANK"])
@@ -262,6 +283,7 @@ def scheduler_fn(opt):
262283
opt,
263284
num_warmup_steps=config.warmup_steps,
264285
num_training_steps=config.total_steps,
286+
num_inner_steps=config.hv.local_steps,
265287
)
266288

267289
if config.hv is not None:

0 commit comments

Comments
 (0)