Skip to content

Commit b1c0a26

Browse files
committed
add custom lr schedule
1 parent e4e4f86 commit b1c0a26

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

open_diloco/train_fsdp.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"""
88

99
from functools import partial
10+
import math
1011
import os
1112
import time
1213
from contextlib import nullcontext
@@ -28,7 +29,6 @@
2829
DataCollatorForLanguageModeling,
2930
LlamaConfig,
3031
LlamaForCausalLM,
31-
get_cosine_schedule_with_warmup,
3232
)
3333
from torch.distributed.fsdp import (
3434
FullyShardedDataParallel as FSDP,
@@ -39,7 +39,7 @@
3939
from torch.distributed import broadcast_object_list
4040
from open_diloco.ckpt_utils import load_checkpoint, save_checkpoint
4141
from open_diloco.hivemind_diloco import AllReduceStrategy, DiLoCoOptimizer
42-
42+
from torch.optim.lr_scheduler import LambdaLR
4343

4444
from hivemind.dht.dht import DHT
4545
from hivemind.utils.networking import log_visible_maddrs
@@ -187,6 +187,27 @@ def get_model(config: Config) -> LlamaForCausalLM:
187187
return LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path=config.path_model, config=config_model)
188188

189189

190+
def _get_cosine_schedule_with_warmup_lr_lambda(
191+
current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float, min_lr_rate: float = 0.0
192+
):
193+
if current_step < num_warmup_steps:
194+
return float(current_step) / float(max(1, num_warmup_steps))
195+
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
196+
factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
197+
factor = factor * (1 - min_lr_rate) + min_lr_rate
198+
return max(0, factor)
199+
200+
201+
def get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps, num_inner_steps, num_cycles=0.5, min_lr_rate=0.0
):
    """Create a ``LambdaLR`` schedule: linear warmup then cosine decay.

    Args:
        optimizer: torch optimizer whose base learning rate(s) the
            per-step factor multiplies.
        num_warmup_steps: number of steps of linear warmup from 0 to the
            base LR.
        num_training_steps: total step count over which the cosine decay
            completes.
        num_inner_steps: NOTE(review): accepted but currently unused by
            the schedule — callers pass ``config.hv.local_steps``; confirm
            whether it was meant to offset or rescale the decay before
            relying on it. Kept to preserve the call signature.
        num_cycles: number of cosine half-waves in the decay phase; the
            default 0.5 (the previously hard-coded value) gives a single
            monotone decay.
        min_lr_rate: floor for the LR multiplier as a fraction of the
            base LR (default 0.0, matching the previous behavior).

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` starting at ``last_epoch=-1``.
    """
    lambda_lr = partial(
        _get_cosine_schedule_with_warmup_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        num_cycles=num_cycles,
        min_lr_rate=min_lr_rate,
    )
    return LambdaLR(optimizer, lambda_lr, -1)
209+
210+
190211
def train(config: Config):
191212
sharding_strategy = get_sharding_strategy(config.sharding_strategy)
192213
local_rank = int(os.environ["LOCAL_RANK"])
@@ -280,6 +301,7 @@ def scheduler_fn(opt):
280301
opt,
281302
num_warmup_steps=config.warmup_steps,
282303
num_training_steps=config.total_steps,
304+
num_inner_steps=config.hv.local_steps,
283305
)
284306

285307
if config.hv is not None:

0 commit comments

Comments
 (0)