Skip to content

Commit 9d03959

Browse files
committed
add warmup steps
1 parent f6748dd commit 9d03959

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

open_diloco/train_fsdp.py

Lines changed: 18 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -91,6 +91,7 @@ class HvConfig(BaseConfig):
9191
world_rank: int
9292
galaxy_size: int
9393
fail_rank_drop: bool = False # fail if we lose a diloco worker
94+
warmup_outerstep: int = 10
9495

9596
@model_validator(mode="before")
9697
def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
@@ -175,8 +176,18 @@ def get_model(config: Config) -> LlamaForCausalLM:
175176

176177

177178
def _get_cosine_schedule_with_warmup_lr_lambda(
178-
current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float, min_lr_rate: float = 0.0
179+
current_step: int,
180+
*,
181+
num_warmup_steps: int,
182+
num_training_steps: int,
183+
num_inner_steps: int,
184+
warmup_outerstep: int | None,
185+
num_cycles: float,
186+
min_lr_rate: float = 0.0,
179187
):
188+
if warmup_outerstep is not None and current_step % num_inner_steps < warmup_outerstep:
189+
return 0
190+
180191
if current_step < num_warmup_steps:
181192
return float(current_step) / float(max(1, num_warmup_steps))
182193
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
@@ -185,11 +196,13 @@ def _get_cosine_schedule_with_warmup_lr_lambda(
185196
return max(0, factor)
186197

187198

188-
def get_cosine_schedule_with_warmup(optimizer, config: Config):
    """Build a ``LambdaLR`` cosine schedule with warmup for *optimizer*.

    All schedule hyper-parameters are read from the run ``config`` instead of
    being threaded through individually: global warmup/total steps plus the
    hivemind-specific inner-step count and per-outer-step warmup window.

    NOTE(review): this dereferences ``config.hv`` unconditionally — looks like
    callers only reach this path when ``config.hv`` is set; confirm against the
    ``if config.hv is not None`` branch in ``train``.
    """
    # Freeze the schedule parameters once, then hand the per-step callable
    # to LambdaLR (last_epoch=-1 starts the schedule from scratch).
    schedule_kwargs = {
        "num_warmup_steps": config.warmup_steps,
        "num_training_steps": config.total_steps,
        "num_inner_steps": config.hv.local_steps,
        "warmup_outerstep": config.hv.warmup_outerstep,
        "num_cycles": 0.5,
    }
    lr_lambda = partial(_get_cosine_schedule_with_warmup_lr_lambda, **schedule_kwargs)
    return LambdaLR(optimizer, lr_lambda, -1)
@@ -274,9 +287,7 @@ def train(config: Config):
274287
def scheduler_fn(opt):
    # Scheduler factory handed to the training loop: builds the
    # cosine-with-warmup schedule for `opt` from the enclosing `config`.
    return get_cosine_schedule_with_warmup(opt, config=config)
281292

282293
if config.hv is not None:

0 commit comments

Comments (0)