
Commit e91cbda

add outer lr scheduler

1 parent ce6f82b

File tree

open_diloco/hivemind_diloco.py
open_diloco/train_fsdp.py

2 files changed: +39 −8 lines


open_diloco/hivemind_diloco.py

Lines changed: 7 additions & 1 deletion
@@ -334,6 +334,7 @@ def __init__(
         inner_optimizer: OptimizerFactory,
         params: Optional[Union[Parameters, ParamGroups]] = None,
         scheduler: Optional[SchedulerFactory] = None,
+        outer_scheduler: Optional[SchedulerFactory] = None,
         averager_opts: Optional[dict] = None,
         grad_compression: CompressionBase = NoCompression(),
         tracker_opts: Optional[dict] = None,
@@ -365,7 +366,7 @@ def __init__(
         # since we have two optimizers, we need to persist the params to a list
         self.num_inner_steps = num_inner_steps

-        for opt_or_scheduler in [outer_optimizer, scheduler]:
+        for opt_or_scheduler in [outer_optimizer, scheduler, outer_scheduler]:
             if not (callable(opt_or_scheduler) or opt_or_scheduler is None):
                 raise TypeError("You need to pass inner and outer optimizer as well as scheduler as callable")

@@ -405,6 +406,8 @@ def __init__(
         )
         self.diloco_grad_averager = self._make_gradient_averager(compression=grad_compression)

+        self.outer_scheduler = outer_scheduler(self.state_averager.optimizer)
+
     def _check_kwargs(self, kwargs) -> None:
         """DiLoCo Optimizer only support a subset of Hivemind Optimizer kwargs.
         This function raise an error if some kwargs are not supported"""
@@ -555,6 +558,9 @@ def step(
         if self.tracker.ready_to_update_epoch:
            self._update_global_epoch()

+        if self.outer_scheduler is not None:
+            self.outer_scheduler.step()
+
        return loss

    def _compute_schema_hash(self) -> int:
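
In short, outer_scheduler follows the same factory pattern as scheduler: a callable that receives an optimizer and returns a torch LR scheduler, which the DiLoCo optimizer binds to the outer optimizer it builds internally (self.state_averager.optimizer) and then advances inside step(). Below is a minimal sketch of that contract; make_outer_scheduler and the toy SGD outer optimizer (lr=0.7) are illustrative names and values, not repository code, and the None guard around the factory call is an added assumption, since the hunk as committed calls outer_scheduler(...) unconditionally.

# Illustrative sketch of the outer_scheduler factory contract (names and
# values below are placeholders, not from the repository).
import torch
from torch.optim.lr_scheduler import LambdaLR


def make_outer_scheduler(optimizer: torch.optim.Optimizer) -> LambdaLR:
    # Any LambdaLR-style schedule works; the returned multiplier scales the
    # outer optimizer's base learning rate.
    return LambdaLR(optimizer, lambda step: 1.0)


# What __init__ does with the factory, simplified: bind it to the outer
# optimizer. The None guard is an assumption to keep the argument optional;
# the committed hunk calls the factory unconditionally.
outer_scheduler_factory = make_outer_scheduler
outer_optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.7)
outer_scheduler = outer_scheduler_factory(outer_optimizer) if outer_scheduler_factory is not None else None

# What step() does once the scheduler is bound: advance the outer schedule.
if outer_scheduler is not None:
    outer_scheduler.step()

Because the factory is bound to self.state_averager.optimizer, the schedule applies to the outer optimizer that consumes the averaged pseudo-gradients, independently of any inner-optimizer scheduler.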

open_diloco/train_fsdp.py

Lines changed: 32 additions & 7 deletions
@@ -92,6 +92,8 @@ class HvConfig(BaseConfig):
     galaxy_size: int
     fail_rank_drop: bool = False  # fail if we lose a diloco worker
     warmup_outerstep: int = 10
+    outer_lr_min: float = 0.3
+    outer_scheduler: bool = False

     @model_validator(mode="before")
     def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
@@ -180,17 +182,12 @@ def _get_cosine_schedule_with_warmup_lr_lambda(
     *,
     num_warmup_steps: int,
     num_training_steps: int,
-    num_inner_steps: int,
-    warmup_outerstep: int | None,
     num_cycles: float,
     min_lr_rate: float = 0.0,
 ):
     if current_step < num_warmup_steps:
         return float(current_step) / float(max(1, num_warmup_steps))

-    if warmup_outerstep is not None and current_step % num_inner_steps < warmup_outerstep:
-        return 0
-
     progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
     factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
     factor = factor * (1 - min_lr_rate) + min_lr_rate
@@ -202,13 +199,36 @@ def get_cosine_schedule_with_warmup(optimizer, config: Config):
         _get_cosine_schedule_with_warmup_lr_lambda,
         num_warmup_steps=config.warmup_steps,
         num_training_steps=config.total_steps,
-        num_inner_steps=config.hv.local_steps,
-        warmup_outerstep=config.hv.warmup_outerstep,
         num_cycles=0.5,
     )
     return LambdaLR(optimizer, lambda_lr, -1)


+def _get_lr_outer(
+    current_step: int,
+    *,
+    num_warmup_steps: int,
+    num_training_steps: int,
+    min_lr_rate: float = 0.0,
+):
+    if current_step < num_warmup_steps:
+        return 1
+
+    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
+    factor = 0.5 * (1.0 + math.cos(math.pi * 2.0 * progress))
+    factor = factor * (1 - min_lr_rate) + min_lr_rate
+    return max(0, factor)
+
+
+def get_lr_outer(optimizer, config: Config):
+    lambda_lr = partial(
+        _get_lr_outer,
+        num_warmup_steps=config.warmup_steps,
+        num_training_steps=config.total_steps,
+    )
+    return LambdaLR(optimizer, lambda_lr, -1)
+
+
 def train(config: Config):
     sharding_strategy = get_sharding_strategy(config.sharding_strategy)
     local_rank = int(os.environ["LOCAL_RANK"])
@@ -291,6 +311,9 @@ def scheduler_fn(opt):
             config=config,
         )

+    def outer_scheduler_fn(opt):
+        return get_lr_outer(opt, config=config)
+
     if config.hv is not None:
         if config.ckpt.resume:
             # We need to load with a fake optimizer to set the model parameters correctly before initializing the DiLoCoOptimizer
@@ -316,6 +339,7 @@ def scheduler_fn(opt):
             outer_optimizer=outer_optimizer,
             inner_optimizer=inner_optimizer,
             scheduler=None,
+            outer_scheduler=outer_scheduler_fn if config.hv.outer_scheduler else None,
             params=model.parameters(),
             delay_optimizer_step=False,
             delay_grad_averaging=False,
@@ -435,6 +459,7 @@ def scheduler_fn(opt):
             scaler.update()

             scheduler.step()
+
             optimizer.zero_grad()

             if config.hv is not None:
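
Taken together, these hunks drop the per-outer-step warmup hack from the inner cosine schedule and give the outer optimizer its own LambdaLR: a multiplier of 1 during warmup, then a cosine factor floored at min_lr_rate. The new outer_lr_min config field presumably maps to that floor, although get_lr_outer does not forward it in this commit. Below is a standalone sketch of how the multiplier evolves; the step counts, the placeholder outer learning rate of 0.7, and min_lr_rate=0.3 are assumed toy values, not repository defaults.

# Standalone sketch of the new outer schedule with assumed toy values.
import math
from functools import partial

import torch
from torch.optim.lr_scheduler import LambdaLR


def _get_lr_outer(current_step, *, num_warmup_steps, num_training_steps, min_lr_rate=0.0):
    # Same shape as the committed lambda: hold the multiplier at 1 during
    # warmup, then follow a cosine factor floored at min_lr_rate.
    if current_step < num_warmup_steps:
        return 1
    progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
    factor = 0.5 * (1.0 + math.cos(math.pi * 2.0 * progress))
    factor = factor * (1 - min_lr_rate) + min_lr_rate
    return max(0, factor)


outer_opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.7)
outer_sched = LambdaLR(outer_opt, partial(_get_lr_outer, num_warmup_steps=10, num_training_steps=100, min_lr_rate=0.3), -1)

for step in range(100):
    outer_opt.step()     # placeholder optimizer update so the scheduler has something to scale
    outer_sched.step()
    if step in (0, 9, 31, 54):
        print(step, round(outer_opt.param_groups[0]["lr"], 4))

With config.hv.outer_scheduler left at its default of False, no factory is passed and the outer learning rate stays constant; enabling the schedule means setting the hv.outer_scheduler field to true, however the config is supplied. Note also that, unlike the inner schedule's num_cycles=0.5 half-cosine decay, this lambda uses math.pi * 2.0 * progress, i.e. a full cosine period over total_steps.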
