Skip to content

Commit 17ac713

Browse files
committed
add warmup steps
1 parent 47483e2 commit 17ac713

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

open_diloco/train_fsdp.py

Lines changed: 18 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -91,6 +91,7 @@ class HvConfig(BaseConfig):
9191
world_rank: int
9292
galaxy_size: int
9393
fail_rank_drop: bool = False # fail if we lose a diloco worker
94+
warmup_outerstep: int = 10
9495

9596
@model_validator(mode="before")
9697
def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
@@ -174,8 +175,18 @@ def get_model(config: Config) -> LlamaForCausalLM:
174175

175176

176177
def _get_cosine_schedule_with_warmup_lr_lambda(
177-
current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float, min_lr_rate: float = 0.0
178+
current_step: int,
179+
*,
180+
num_warmup_steps: int,
181+
num_training_steps: int,
182+
num_inner_steps: int,
183+
warmup_outerstep: int | None,
184+
num_cycles: float,
185+
min_lr_rate: float = 0.0,
178186
):
187+
if warmup_outerstep is not None and current_step % num_inner_steps < warmup_outerstep:
188+
return 0
189+
179190
if current_step < num_warmup_steps:
180191
return float(current_step) / float(max(1, num_warmup_steps))
181192
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
@@ -184,11 +195,13 @@ def _get_cosine_schedule_with_warmup_lr_lambda(
184195
return max(0, factor)
185196

186197

187-
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_inner_steps):
198+
def get_cosine_schedule_with_warmup(optimizer, config: Config):
188199
lambda_lr = partial(
189200
_get_cosine_schedule_with_warmup_lr_lambda,
190-
num_warmup_steps=num_warmup_steps,
191-
num_training_steps=num_training_steps,
201+
num_warmup_steps=config.warmup_steps,
202+
num_training_steps=config.total_steps,
203+
num_inner_steps=config.hv.local_steps,
204+
warmup_outerstep=config.hv.warmup_outerstep,
192205
num_cycles=0.5,
193206
)
194207
return LambdaLR(optimizer, lambda_lr, -1)
@@ -281,9 +294,7 @@ def train(config: Config):
281294
def scheduler_fn(opt):
282295
return get_cosine_schedule_with_warmup(
283296
opt,
284-
num_warmup_steps=config.warmup_steps,
285-
num_training_steps=config.total_steps,
286-
num_inner_steps=config.hv.local_steps,
297+
config=config,
287298
)
288299

289300
if config.hv is not None:

0 commit comments

Comments (0)