@@ -7,6 +7,7 @@
 """

 from functools import partial
+import math
 import os
 import time
 from contextlib import nullcontext
@@ -28,7 +29,6 @@
     DataCollatorForLanguageModeling,
     LlamaConfig,
     LlamaForCausalLM,
-    get_cosine_schedule_with_warmup,
 )
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
@@ -39,7 +39,7 @@
 from torch.distributed import broadcast_object_list
 from open_diloco.ckpt_utils import load_checkpoint, save_checkpoint
 from open_diloco.hivemind_diloco import AllReduceStrategy, DiLoCoOptimizer
-
+from torch.optim.lr_scheduler import LambdaLR

 from hivemind.dht.dht import DHT
 from hivemind.utils.networking import log_visible_maddrs
@@ -189,6 +189,27 @@ def get_model(config: Config) -> LlamaForCausalLM:
     return LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path=config.path_model, config=config_model)


+def _get_cosine_schedule_with_warmup_lr_lambda(
+    current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float, min_lr_rate: float = 0.0
+):
+    if current_step < num_warmup_steps:
+        return float(current_step) / float(max(1, num_warmup_steps))
+    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+    factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
+    factor = factor * (1 - min_lr_rate) + min_lr_rate
+    return max(0, factor)
+
+
+def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_inner_steps):
+    lambda_lr = partial(
+        _get_cosine_schedule_with_warmup_lr_lambda,
+        num_warmup_steps=num_warmup_steps,
+        num_training_steps=num_training_steps,
+        num_cycles=0.5,
+    )
+    return LambdaLR(optimizer, lambda_lr, -1)
+
+
 def train(config: Config):
     sharding_strategy = get_sharding_strategy(config.sharding_strategy)
     local_rank = int(os.environ["LOCAL_RANK"])
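The two helpers above replace the `get_cosine_schedule_with_warmup` previously imported from `transformers` with a local implementation built on `torch.optim.lr_scheduler.LambdaLR`: the lambda ramps the LR factor linearly from 0 to 1 over `num_warmup_steps`, then, with `num_cycles=0.5`, decays it as `0.5 * (1 + cos(pi * progress))`, reaching `min_lr_rate` (0 by default) at `progress = 1`. Below is a minimal usage sketch, assuming the helpers above are in scope; the toy model, optimizer, learning rate, and step counts are illustrative only, not values from this repository.

```python
import torch

# Illustrative stand-ins: the real script schedules the optimizer of an
# FSDP-wrapped LlamaForCausalLM, not a tiny Linear layer.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)

# `get_cosine_schedule_with_warmup` refers to the helper defined in the hunk above.
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=10,
    num_training_steps=100,
    num_inner_steps=50,  # accepted by the new signature; not consumed by the lambda
)

for step in range(100):
    # A real loop would run forward/backward and produce gradients before stepping.
    optimizer.step()
    scheduler.step()
    if step in (0, 9, 49, 99):
        # Linear warmup for the first 10 steps, then cosine decay toward 0.
        print(step, scheduler.get_last_lr())
```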
@@ -282,6 +303,7 @@ def scheduler_fn(opt):
             opt,
             num_warmup_steps=config.warmup_steps,
             num_training_steps=config.total_steps,
+            num_inner_steps=config.hv.local_steps,
         )

     if config.hv is not None:
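For readability, here is how the patched call site reads once this hunk is applied; the `def scheduler_fn(opt):` wrapper and its placement inside `train()` are reconstructed from the hunk header and surrounding context rather than shown in the diff. Note that `num_inner_steps` is accepted by the new helper but not forwarded into the LR lambda, so it does not affect the shape of the schedule in this patch.

```python
# Reconstructed from the hunk header and visible arguments; assumed to be
# defined inside train(), where `config` is in scope.
def scheduler_fn(opt):
    return get_cosine_schedule_with_warmup(
        opt,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=config.total_steps,
        num_inner_steps=config.hv.local_steps,  # new argument added by this change
    )
```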