
Commit 8049367

add outer lr scheduler
1 parent 720d929

2 files changed: 39 additions & 8 deletions


open_diloco/hivemind_diloco.py

Lines changed: 7 additions & 1 deletion
@@ -334,6 +334,7 @@ def __init__(
         inner_optimizer: OptimizerFactory,
         params: Optional[Union[Parameters, ParamGroups]] = None,
         scheduler: Optional[SchedulerFactory] = None,
+        outer_scheduler: Optional[SchedulerFactory] = None,
         averager_opts: Optional[dict] = None,
         grad_compression: CompressionBase = NoCompression(),
         tracker_opts: Optional[dict] = None,
@@ -365,7 +366,7 @@ def __init__(
         # since we have two optimizers, we need to persist the params to a list
         self.num_inner_steps = num_inner_steps

-        for opt_or_scheduler in [outer_optimizer, scheduler]:
+        for opt_or_scheduler in [outer_optimizer, scheduler, outer_scheduler]:
             if not (callable(opt_or_scheduler) or opt_or_scheduler is None):
                 raise TypeError("You need to pass inner and outer optimizer as well as scheduler as callable")

@@ -405,6 +406,8 @@ def __init__(
         )
         self.diloco_grad_averager = self._make_gradient_averager(compression=grad_compression)

+        self.outer_scheduler = outer_scheduler(self.state_averager.optimizer)
+
     def _check_kwargs(self, kwargs) -> None:
         """DiLoCo Optimizer only support a subset of Hivemind Optimizer kwargs.
         This function raise an error if some kwargs are not supported"""
@@ -555,6 +558,9 @@ def step(
         if self.tracker.ready_to_update_epoch:
             self._update_global_epoch()

+        if self.outer_scheduler is not None:
+            self.outer_scheduler.step()
+
         return loss

     def _compute_schema_hash(self) -> int:
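
The new outer_scheduler argument follows the same factory convention as the existing scheduler argument: the constructor only checks that it is callable or None, builds it on the outer optimizer held by the state averager, and, as the last hunk shows, steps it near the end of step() when it is not None. A minimal sketch of such a factory (the schedule and the abbreviated constructor call are illustrative, not taken from this commit):

from torch.optim.lr_scheduler import LambdaLR

def outer_scheduler_factory(outer_opt):
    # illustrative schedule only: hold the outer LR at its base value for 1_000 steps,
    # then run at half the base value for the rest of training
    return LambdaLR(outer_opt, lambda step: 1.0 if step < 1_000 else 0.5)

# passed like the other factories; a pre-built scheduler instance would fail the
# callable-or-None check above:
# opt = DiLoCoOptimizer(..., outer_scheduler=outer_scheduler_factory, ...)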

open_diloco/train_fsdp.py

Lines changed: 32 additions & 7 deletions
@@ -92,6 +92,8 @@ class HvConfig(BaseConfig):
     galaxy_size: int
     fail_rank_drop: bool = False  # fail if we lose a diloco worker
     warmup_outerstep: int = 10
+    outer_lr_min: float = 0.3
+    outer_scheduler: bool = False

     @model_validator(mode="before")
     def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
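
Both knobs are opt-in: outer_scheduler defaults to False, in which case the wiring further down passes outer_scheduler=None to the DiLoCoOptimizer, and outer_lr_min is presumably the floor for the outer LR multiplier (the new lambda's min_lr_rate), although this hunk only declares the field. A self-contained sketch of just these knobs, assuming pydantic v2 as implied by model_validator(mode="before"); OuterSchedulerSettings is an illustrative stand-in, not the project's HvConfig:

from pydantic import BaseModel

class OuterSchedulerSettings(BaseModel):
    warmup_outerstep: int = 10     # pre-existing field, kept for context
    outer_lr_min: float = 0.3      # presumably the floor for the outer LR multiplier
    outer_scheduler: bool = False  # opt-in switch; False means no outer scheduler is built

settings = OuterSchedulerSettings(outer_scheduler=True)
print(settings.model_dump())
# {'warmup_outerstep': 10, 'outer_lr_min': 0.3, 'outer_scheduler': True}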
@@ -179,17 +181,12 @@ def _get_cosine_schedule_with_warmup_lr_lambda(
     *,
     num_warmup_steps: int,
     num_training_steps: int,
-    num_inner_steps: int,
-    warmup_outerstep: int | None,
     num_cycles: float,
     min_lr_rate: float = 0.0,
 ):
     if current_step < num_warmup_steps:
         return float(current_step) / float(max(1, num_warmup_steps))

-    if warmup_outerstep is not None and current_step % num_inner_steps < warmup_outerstep:
-        return 0
-
     progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
     factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
     factor = factor * (1 - min_lr_rate) + min_lr_rate
@@ -201,13 +198,36 @@ def get_cosine_schedule_with_warmup(optimizer, config: Config):
         _get_cosine_schedule_with_warmup_lr_lambda,
         num_warmup_steps=config.warmup_steps,
         num_training_steps=config.total_steps,
-        num_inner_steps=config.hv.local_steps,
-        warmup_outerstep=config.hv.warmup_outerstep,
         num_cycles=0.5,
     )
     return LambdaLR(optimizer, lambda_lr, -1)


+def _get_lr_outer(
+    current_step: int,
+    *,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    min_lr_rate: float = 0.0,
+):
+    if current_step < num_warmup_steps:
+        return 1
+
+    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+    factor = 0.5 * (1.0 + math.cos(math.pi * 2.0 * progress))
+    factor = factor * (1 - min_lr_rate) + min_lr_rate
+    return max(0, factor)
+
+
+def get_lr_outer(optimizer, config: Config):
+    lambda_lr = partial(
+        _get_lr_outer,
+        num_warmup_steps=config.warmup_steps,
+        num_training_steps=config.total_steps,
+    )
+    return LambdaLR(optimizer, lambda_lr, -1)
+
+
 def train(config: Config):
     sharding_strategy = get_sharding_strategy(config.sharding_strategy)
     local_rank = int(os.environ["LOCAL_RANK"])
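
A quick sanity check of the new lambda (illustrative, not part of the commit): get_lr_outer does not pass min_lr_rate, so it stays at 0.0, and the returned value is the multiplier that LambdaLR applies to the outer optimizer's base LR. The multiplier is held at 1 during warmup and then follows 0.5 * (1 + cos(2π · progress)), a full cosine cycle that comes back to 1.0 by the final step. Assuming warmup_steps=1000, total_steps=10_000, and that the module is importable as open_diloco.train_fsdp:

from functools import partial

from open_diloco.train_fsdp import _get_lr_outer

lr = partial(_get_lr_outer, num_warmup_steps=1000, num_training_steps=10_000)

print(lr(500))     # 1    -> still in warmup, multiplier held at its base value
print(lr(1000))    # 1.0  -> progress 0.00, cos(0) = 1
print(lr(3250))    # 0.5  -> progress 0.25, cos(pi/2) = 0
print(lr(5500))    # 0    -> progress 0.50, cos(pi) = -1
print(lr(10000))   # 1.0  -> progress 1.00, back to the full multiplier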
@@ -298,6 +318,9 @@ def scheduler_fn(opt):
             config=config,
         )

+    def outer_scheduler_fn(opt):
+        return get_lr_outer(opt, config=config)
+
     if config.hv is not None:
         if config.ckpt.resume:
             # We need to load with a fake optimizer to set the model parameters correctly before initializing the DiLoCoOptimizer
@@ -323,6 +346,7 @@ def scheduler_fn(opt):
             outer_optimizer=outer_optimizer,
             inner_optimizer=inner_optimizer,
             scheduler=None,
+            outer_scheduler=outer_scheduler_fn if config.hv.outer_scheduler else None,
             params=model.parameters(),
             delay_optimizer_step=False,
             delay_grad_averaging=False,
@@ -438,6 +462,7 @@ def scheduler_fn(opt):
             scaler.update()

             scheduler.step()
+
             optimizer.zero_grad()

             if logging_activations_steps:
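
Taken together: when hv.outer_scheduler is enabled, outer_scheduler_fn builds the cosine schedule above on the outer optimizer, and DiLoCoOptimizer.step() advances it internally, so the outer LR is driven by the same per-inner-step clock as the existing scheduler (both use config.warmup_steps and config.total_steps). A minimal sketch of the resulting loop shape, with placeholder names and none of the grad-scaler, accumulation, or logging logic from train_fsdp.py:

def train_epoch(dataloader, model, loss_fn, optimizer, scheduler):
    # optimizer is the DiLoCoOptimizer built with outer_scheduler=outer_scheduler_fn;
    # scheduler is the existing inner cosine schedule returned by scheduler_fn
    for batch in dataloader:
        loss = loss_fn(model(**batch))
        loss.backward()
        optimizer.step()       # inner step; also advances the outer scheduler when one was built
        scheduler.step()       # inner LR schedule, unchanged by this commit
        optimizer.zero_grad()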