Skip to content
9 changes: 9 additions & 0 deletions open_diloco/ckpt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def save_checkpoint(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler.LambdaLR,
outer_scheduler: torch.optim.lr_scheduler.LambdaLR | None = None,
outer_optimizer: torch.optim.Optimizer | None = None,
scaler: torch.cuda.amp.GradScaler | None = None,
loss: float | None = None,
Expand Down Expand Up @@ -81,6 +82,8 @@ def save_checkpoint(

# 2. Save global states
global_state_dict = {"scheduler": scheduler.state_dict(), "loss": loss if loss is not None else 0}
if outer_scheduler is not None:
global_state_dict["outer_scheduler"] = outer_scheduler.state_dict()
if outer_optimizer is not None:
global_state_dict["outer_optimizer"] = outer_optimizer.state_dict()
if scaler is not None:
Expand All @@ -95,6 +98,7 @@ def load_checkpoint(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler.LambdaLR | None = None,
outer_scheduler: torch.optim.lr_scheduler.LambdaLR | None = None,
outer_optimizer: torch.optim.Optimizer | None = None,
scaler: torch.cuda.amp.GradScaler | None = None,
data_loader: StatefulDataLoader | None = None,
Expand Down Expand Up @@ -139,8 +143,13 @@ def load_checkpoint(
if scheduler is not None:
scheduler.load_state_dict(global_state_dict["scheduler"])
optimizer.param_groups[0]["lr"] = scheduler.get_last_lr()[0]

if outer_optimizer is not None:
outer_optimizer.load_state_dict(global_state_dict["outer_optimizer"])
if outer_scheduler is not None:
outer_scheduler.load_state_dict(global_state_dict["outer_scheduler"])
outer_optimizer.param_groups[0]["lr"] = outer_scheduler.get_last_lr()[0]

if scaler is not None:
scaler.load_state_dict(global_state_dict["scaler"])
return global_state_dict["loss"]
Expand Down
8 changes: 7 additions & 1 deletion open_diloco/hivemind_diloco.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ def __init__(
inner_optimizer: OptimizerFactory,
params: Optional[Union[Parameters, ParamGroups]] = None,
scheduler: Optional[SchedulerFactory] = None,
outer_scheduler: Optional[SchedulerFactory] = None,
averager_opts: Optional[dict] = None,
grad_compression: CompressionBase = NoCompression(),
tracker_opts: Optional[dict] = None,
Expand Down Expand Up @@ -365,7 +366,7 @@ def __init__(
# since we have two optimizers, we need to persist the params to a list
self.num_inner_steps = num_inner_steps

for opt_or_scheduler in [outer_optimizer, scheduler]:
for opt_or_scheduler in [outer_optimizer, scheduler, outer_scheduler]:
if not (callable(opt_or_scheduler) or opt_or_scheduler is None):
raise TypeError("You need to pass inner and outer optimizer as well as scheduler as callable")

Expand Down Expand Up @@ -405,6 +406,8 @@ def __init__(
)
self.diloco_grad_averager = self._make_gradient_averager(compression=grad_compression)

self.outer_scheduler = outer_scheduler(self.state_averager.optimizer) if outer_scheduler else None

def _check_kwargs(self, kwargs) -> None:
"""DiLoCo Optimizer only support a subset of Hivemind Optimizer kwargs.
This function raise an error if some kwargs are not supported"""
Expand Down Expand Up @@ -555,6 +558,9 @@ def step(
if self.tracker.ready_to_update_epoch:
self._update_global_epoch()

if self.outer_scheduler is not None:
self.outer_scheduler.step()

return loss

def _compute_schema_hash(self) -> int:
Expand Down
70 changes: 67 additions & 3 deletions open_diloco/train_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

from functools import partial
import math
import os
import time
from contextlib import nullcontext
Expand All @@ -26,7 +27,6 @@
DataCollatorForLanguageModeling,
LlamaConfig,
LlamaForCausalLM,
get_cosine_schedule_with_warmup,
)
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
Expand All @@ -46,6 +46,7 @@
)
from open_diloco.hivemind_diloco import AllReduceStrategy, DiLoCoOptimizer
from open_diloco.utils import WandbLogger, DummyLogger
from torch.optim.lr_scheduler import LambdaLR

from hivemind.dht.dht import DHT
from hivemind.utils.networking import log_visible_maddrs
Expand Down Expand Up @@ -90,6 +91,8 @@ class HvConfig(BaseConfig):
world_rank: int
galaxy_size: int
fail_rank_drop: bool = False # fail if we lose a diloco worker
warmup_outerstep: int = 10
outer_scheduler: bool = False

@model_validator(mode="before")
def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
Expand Down Expand Up @@ -173,6 +176,61 @@ def get_model(config: Config) -> LlamaForCausalLM:
return LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path=config.path_model, config=config_model)


def _get_cosine_schedule_with_warmup_lr_lambda(
current_step: int,
*,
num_warmup_steps: int,
num_training_steps: int,
num_cycles: float,
min_lr_rate: float = 0.0,
):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))

progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
factor = factor * (1 - min_lr_rate) + min_lr_rate
return max(0, factor)


def get_cosine_schedule_with_warmup(optimizer, config: Config):
    """Build the inner-optimizer LambdaLR: linear warmup then cosine decay.

    Warmup/total step counts are taken from ``config.warmup_steps`` and
    ``config.total_steps``; ``last_epoch=-1`` starts the schedule fresh.
    """
    return LambdaLR(
        optimizer,
        partial(
            _get_cosine_schedule_with_warmup_lr_lambda,
            num_warmup_steps=config.warmup_steps,
            num_training_steps=config.total_steps,
            num_cycles=0.5,
        ),
        -1,
    )


def _get_lr_outer(
current_step: int,
*,
num_warmup_steps: int,
num_training_steps: int,
num_cycles: float,
min_lr_rate: float = 0.0,
):
if current_step < num_warmup_steps:
return 1

progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
factor = factor * (1 - min_lr_rate) + min_lr_rate
return max(0, factor)


def get_lr_outer(optimizer, config: Config):
    """Build the outer-optimizer LambdaLR: constant LR during warmup, cosine decay after.

    NOTE(review): this reads ``config.warmup_steps`` (the inner warmup), while
    ``HvConfig`` defines a ``warmup_outerstep`` field that is not referenced
    anywhere visible — confirm which value the outer schedule should use.
    """
    lambda_lr = partial(
        _get_lr_outer,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=config.total_steps,
        num_cycles=0.5,
    )
    # last_epoch=-1: start the schedule from scratch.
    return LambdaLR(optimizer, lambda_lr, -1)


def train(config: Config):
sharding_strategy = get_sharding_strategy(config.sharding_strategy)
local_rank = int(os.environ["LOCAL_RANK"])
Expand Down Expand Up @@ -252,10 +310,12 @@ def train(config: Config):
def scheduler_fn(opt):
return get_cosine_schedule_with_warmup(
opt,
num_warmup_steps=config.warmup_steps,
num_training_steps=config.total_steps,
config=config,
)

def outer_scheduler_fn(opt):
return get_lr_outer(opt, config=config)

if config.hv is not None:
if config.ckpt.resume:
# We need to load with a fake optimizer to set the model parameters correctly before initializing the DiLoCoOptimizer
Expand All @@ -281,6 +341,7 @@ def scheduler_fn(opt):
outer_optimizer=outer_optimizer,
inner_optimizer=inner_optimizer,
scheduler=None,
outer_scheduler=outer_scheduler_fn if config.hv.outer_scheduler else None,
params=model.parameters(),
delay_optimizer_step=False,
delay_grad_averaging=False,
Expand Down Expand Up @@ -311,6 +372,7 @@ def scheduler_fn(opt):
model=model,
optimizer=optimizer.inner_optimizer,
scheduler=scheduler,
outer_scheduler=optimizer.outer_scheduler,
outer_optimizer=optimizer.state_averager.optimizer,
scaler=scaler,
data_loader=train_dataloader,
Expand Down Expand Up @@ -400,6 +462,7 @@ def scheduler_fn(opt):
scaler.update()

scheduler.step()

optimizer.zero_grad()

if config.hv is not None:
Expand Down Expand Up @@ -476,6 +539,7 @@ def scheduler_fn(opt):
model=model,
optimizer=optimizer.inner_optimizer,
scheduler=scheduler,
outer_scheduler=optimizer.outer_scheduler,
outer_optimizer=optimizer.state_averager.optimizer,
loss=loss_batch.item(),
scaler=scaler,
Expand Down
Loading