
Commit 7a334c1

add custom lr schedule
1 parent 8500dba commit 7a334c1

File tree

1 file changed (+24 −2 lines)


open_diloco/train_fsdp.py

Lines changed: 24 additions & 2 deletions
@@ -7,6 +7,7 @@
 """
 
 from functools import partial
+import math
 import os
 import time
 from contextlib import nullcontext
@@ -28,7 +29,6 @@
     DataCollatorForLanguageModeling,
     LlamaConfig,
     LlamaForCausalLM,
-    get_cosine_schedule_with_warmup,
 )
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
@@ -39,7 +39,7 @@
 from torch.distributed import broadcast_object_list
 from open_diloco.ckpt_utils import load_checkpoint, save_checkpoint
 from open_diloco.hivemind_diloco import AllReduceStrategy, DiLoCoOptimizer
-
+from torch.optim.lr_scheduler import LambdaLR
 
 from hivemind.dht.dht import DHT
 from hivemind.utils.networking import log_visible_maddrs
@@ -189,6 +189,27 @@ def get_model(config: Config) -> LlamaForCausalLM:
     return LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path=config.path_model, config=config_model)
 
 
+def _get_cosine_schedule_with_warmup_lr_lambda(
+    current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float, min_lr_rate: float = 0.0
+):
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+    factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
+    factor = factor * (1 - min_lr_rate) + min_lr_rate
+    return max(0, factor)
+
+
+def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_inner_steps):
+    lambda_lr = partial(
+        _get_cosine_schedule_with_warmup_lr_lambda,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=num_training_steps,
+        num_cycles=0.5,
+    )
+    return LambdaLR(optimizer, lambda_lr, -1)
+
+
 def train(config: Config):
     sharding_strategy = get_sharding_strategy(config.sharding_strategy)
     local_rank = int(os.environ["LOCAL_RANK"])
@@ -282,6 +303,7 @@ def scheduler_fn(opt):
             opt,
             num_warmup_steps=config.warmup_steps,
            num_training_steps=config.total_steps,
+            num_inner_steps=config.hv.local_steps,
        )
 
    if config.hv is not None:
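
The new helper replaces the `get_cosine_schedule_with_warmup` import from transformers with a local implementation: the LR factor ramps linearly from 0 to 1 over num_warmup_steps, then follows a half-cosine down toward min_lr_rate (0 by default). Note that num_inner_steps is accepted by the wrapper but not yet forwarded to the lambda in this commit. Below is a minimal standalone sketch of how the schedule behaves; the toy AdamW optimizer, base LR, and step counts are illustrative only and not taken from the repository.

import math
from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR


# Copied from the diff above: linear warmup, then cosine decay toward min_lr_rate.
def _get_cosine_schedule_with_warmup_lr_lambda(
    current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float, min_lr_rate: float = 0.0
):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
    factor = factor * (1 - min_lr_rate) + min_lr_rate
    return max(0, factor)


# Toy single-parameter optimizer with an illustrative base LR of 1e-3.
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.AdamW([param], lr=1e-3)

# Same wiring as the wrapper in the diff: LambdaLR over the partial-bound lambda.
scheduler = LambdaLR(
    opt,
    partial(
        _get_cosine_schedule_with_warmup_lr_lambda,
        num_warmup_steps=10,
        num_training_steps=100,
        num_cycles=0.5,
    ),
    -1,
)

for step in range(100):
    opt.step()
    scheduler.step()
    if step in (0, 9, 49, 99):
        # Print the current LR at a few points: mid-warmup, end of warmup, mid-decay, final step.
        print(step, scheduler.get_last_lr()[0])

With num_cycles=0.5 the factor reaches 1.0 at the end of warmup and decays to 0 by num_training_steps (since min_lr_rate defaults to 0), matching the usual cosine-with-warmup shape.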
