Skip to content

Commit f6748dd

Browse files
committed
fix torch compile log act (#23)
* fix renaming logic for key
* fix stuff
* fix exploding norm
* remove print
1 parent f7ef006 commit f6748dd

File tree

1 file changed

+24
-1
lines changed

1 file changed

+24
-1
lines changed

open_diloco/train_fsdp.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"""
88

99
from functools import partial
10+
import math
1011
import os
1112
import time
1213
from contextlib import nullcontext
@@ -26,7 +27,6 @@
2627
DataCollatorForLanguageModeling,
2728
LlamaConfig,
2829
LlamaForCausalLM,
29-
get_cosine_schedule_with_warmup,
3030
)
3131
from torch.distributed.fsdp import (
3232
FullyShardedDataParallel as FSDP,
@@ -46,6 +46,7 @@
4646
)
4747
from open_diloco.hivemind_diloco import AllReduceStrategy, DiLoCoOptimizer
4848
from open_diloco.utils import WandbLogger, DummyLogger
49+
from torch.optim.lr_scheduler import LambdaLR
4950

5051
from hivemind.dht.dht import DHT
5152
from hivemind.utils.networking import log_visible_maddrs
@@ -173,6 +174,27 @@ def get_model(config: Config) -> LlamaForCausalLM:
173174
return LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path=config.path_model, config=config_model)
174175

175176

177+
def _get_cosine_schedule_with_warmup_lr_lambda(
178+
current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float, min_lr_rate: float = 0.0
179+
):
180+
if current_step < num_warmup_steps:
181+
return float(current_step) / float(max(1, num_warmup_steps))
182+
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
183+
factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
184+
factor = factor * (1 - min_lr_rate) + min_lr_rate
185+
return max(0, factor)
186+
187+
188+
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_inner_steps, num_cycles=0.5):
    """Create a ``LambdaLR`` scheduler: linear warmup followed by cosine decay.

    Args:
        optimizer: Wrapped optimizer whose learning rate is scheduled.
        num_warmup_steps: Steps over which the LR ramps linearly from 0 to the base LR.
        num_training_steps: Total steps across which the cosine decay runs.
        num_inner_steps: NOTE(review): accepted for call-site compatibility but
            currently unused by the schedule — confirm whether it should scale
            the step count (e.g. for DiLoCo inner/outer step accounting).
        num_cycles: Number of cosine half-waves over the decay phase. The
            default 0.5 (previously hard-coded) decays once from the base LR to 0.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` applying the schedule.
    """
    lambda_lr = partial(
        _get_cosine_schedule_with_warmup_lr_lambda,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        num_cycles=num_cycles,
    )
    # last_epoch=-1 starts the schedule fresh from step 0.
    return LambdaLR(optimizer, lambda_lr, last_epoch=-1)
196+
197+
176198
def train(config: Config):
177199
sharding_strategy = get_sharding_strategy(config.sharding_strategy)
178200
local_rank = int(os.environ["LOCAL_RANK"])
@@ -254,6 +276,7 @@ def scheduler_fn(opt):
254276
opt,
255277
num_warmup_steps=config.warmup_steps,
256278
num_training_steps=config.total_steps,
279+
num_inner_steps=config.hv.local_steps,
257280
)
258281

259282
if config.hv is not None:

0 commit comments

Comments (0)