Skip to content

Commit c7ee190

Browse files
committed
add warmup steps
1 parent b1c0a26 commit c7ee190

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

open_diloco/train_fsdp.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ class HvConfig(BaseConfig):
103103
skip_load_from_peers: bool = False
104104
world_rank: int
105105
galaxy_size: int
106+
warmup_outerstep: int = 10
106107

107108
@model_validator(mode="before")
108109
def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
@@ -188,8 +189,18 @@ def get_model(config: Config) -> LlamaForCausalLM:
188189

189190

190191
def _get_cosine_schedule_with_warmup_lr_lambda(
191-
current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float, min_lr_rate: float = 0.0
192+
current_step: int,
193+
*,
194+
num_warmup_steps: int,
195+
num_training_steps: int,
196+
num_inner_steps: int,
197+
warmup_outerstep: int | None,
198+
num_cycles: float,
199+
min_lr_rate: float = 0.0,
192200
):
201+
if warmup_outerstep is not None and current_step % num_inner_steps < warmup_outerstep:
202+
return 0
203+
193204
if current_step < num_warmup_steps:
194205
return float(current_step) / float(max(1, num_warmup_steps))
195206
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
@@ -198,11 +209,13 @@ def _get_cosine_schedule_with_warmup_lr_lambda(
198209
return max(0, factor)
199210

200211

201-
def get_cosine_schedule_with_warmup(optimizer, config: Config):
    """Build a cosine-decay LR scheduler with linear warmup and optional
    per-outer-step warmup gating.

    Args:
        optimizer: the torch optimizer whose learning rate is scheduled.
        config: full training config; reads ``warmup_steps``, ``total_steps``
            and, when hivemind is enabled, ``hv.local_steps`` /
            ``hv.warmup_outerstep``.

    Returns:
        A ``LambdaLR`` wrapping ``_get_cosine_schedule_with_warmup_lr_lambda``.
    """
    # config.hv is Optional (see the `if config.hv is not None` guard in
    # train()): when running without hivemind there is no outer-step cycle,
    # so disable the outer-step warmup gate and use a trivial inner-step
    # count instead of crashing on `None.local_steps`.
    if config.hv is not None:
        num_inner_steps = config.hv.local_steps
        warmup_outerstep = config.hv.warmup_outerstep
    else:
        num_inner_steps = 1
        warmup_outerstep = None

    lambda_lr = partial(
        _get_cosine_schedule_with_warmup_lr_lambda,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=config.total_steps,
        num_inner_steps=num_inner_steps,
        warmup_outerstep=warmup_outerstep,
        num_cycles=0.5,
    )
    # last_epoch=-1 starts the schedule from step 0.
    return LambdaLR(optimizer, lambda_lr, -1)
@@ -299,9 +312,7 @@ def train(config: Config):
299312
def scheduler_fn(opt):
300313
return get_cosine_schedule_with_warmup(
301314
opt,
302-
num_warmup_steps=config.warmup_steps,
303-
num_training_steps=config.total_steps,
304-
num_inner_steps=config.hv.local_steps,
315+
config=config,
305316
)
306317

307318
if config.hv is not None:

0 commit comments

Comments
 (0)