@@ -7,6 +7,7 @@
 """
 
 from functools import partial
+import math
 import os
 import time
 from contextlib import nullcontext
@@ -26,7 +27,6 @@
     DataCollatorForLanguageModeling,
     LlamaConfig,
     LlamaForCausalLM,
-    get_cosine_schedule_with_warmup,
 )
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
@@ -46,6 +46,7 @@
 )
 from open_diloco.hivemind_diloco import AllReduceStrategy, DiLoCoOptimizer
 from open_diloco.utils import WandbLogger, DummyLogger
+from torch.optim.lr_scheduler import LambdaLR
 
 from hivemind.dht.dht import DHT
 from hivemind.utils.networking import log_visible_maddrs
@@ -173,6 +174,27 @@ def get_model(config: Config) -> LlamaForCausalLM:
     return LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path=config.path_model, config=config_model)
 
 
+def _get_cosine_schedule_with_warmup_lr_lambda(
+    current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float, min_lr_rate: float = 0.0
+):
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+    factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
+    factor = factor * (1 - min_lr_rate) + min_lr_rate
+    return max(0, factor)
+
+
+def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_inner_steps):
+    lambda_lr = partial(
+        _get_cosine_schedule_with_warmup_lr_lambda,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=num_training_steps,
+        num_cycles=0.5,
+    )
+    return LambdaLR(optimizer, lambda_lr, -1)
+
+
 def train(config: Config):
     sharding_strategy = get_sharding_strategy(config.sharding_strategy)
     local_rank = int(os.environ["LOCAL_RANK"])
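The two helpers above vendor the cosine-with-warmup schedule that was previously imported from transformers: the learning-rate factor ramps linearly from 0 to 1 over num_warmup_steps, then follows half a cosine period (num_cycles=0.5) down to zero at num_training_steps (min_lr_rate is never overridden by the factory). The extra num_inner_steps argument is accepted for the DiLoCo call site but is not used by the lambda shown here. A minimal sketch of driving the factory with a throwaway optimizer, assuming the file above is importable as open_diloco.train_fsdp (hypothetical import path):

import torch

# Hypothetical import path for the training script shown in the diff above.
from open_diloco.train_fsdp import get_cosine_schedule_with_warmup

model = torch.nn.Linear(8, 8)
opt = torch.optim.AdamW(model.parameters(), lr=4e-4)
# 10 warmup steps, 100 total steps; num_inner_steps=50 is a placeholder value.
sched = get_cosine_schedule_with_warmup(opt, num_warmup_steps=10, num_training_steps=100, num_inner_steps=50)

for step in range(100):
    opt.step()   # no gradients needed just to inspect the schedule
    sched.step()
    if step in (0, 9, 54, 99):
        # Expect roughly lr/10 during warmup, lr at the end of warmup,
        # lr/2 halfway through the cosine decay, and ~0 at the last step.
        print(step, sched.get_last_lr()[0])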
@@ -254,6 +276,7 @@ def scheduler_fn(opt):
             opt,
             num_warmup_steps=config.warmup_steps,
             num_training_steps=config.total_steps,
+            num_inner_steps=config.hv.local_steps,
         )
 
     if config.hv is not None:
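The hunk above only touches the keyword arguments of the call inside scheduler_fn. Pieced together with the new factory, the closure presumably reads roughly as follows; this is a sketch, not the exact repository code, using only the fields visible in the diff (config.warmup_steps, config.total_steps, config.hv.local_steps):

def scheduler_fn(opt):
    # Defined inside train() as a closure over config: builds the LR scheduler
    # for the inner optimizer, forwarding DiLoCo's inner-step count alongside
    # the usual warmup/total step counts.
    return get_cosine_schedule_with_warmup(
        opt,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=config.total_steps,
        num_inner_steps=config.hv.local_steps,
    )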