Skip to content

Commit 852b3c5

Browse files
committed
add warmup steps
1 parent 7a334c1 commit 852b3c5

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

open_diloco/train_fsdp.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ class HvConfig(BaseConfig):
105105
skip_load_from_peers: bool = False
106106
world_rank: int
107107
galaxy_size: int
108+
warmup_outerstep: int = 10
108109

109110
@model_validator(mode="before")
110111
def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
@@ -190,8 +191,18 @@ def get_model(config: Config) -> LlamaForCausalLM:
190191

191192

192193
def _get_cosine_schedule_with_warmup_lr_lambda(
193-
current_step: int, *, num_warmup_steps: int, num_training_steps: int, num_cycles: float, min_lr_rate: float = 0.0
194+
current_step: int,
195+
*,
196+
num_warmup_steps: int,
197+
num_training_steps: int,
198+
num_inner_steps: int,
199+
warmup_outerstep: int | None,
200+
num_cycles: float,
201+
min_lr_rate: float = 0.0,
194202
):
203+
if warmup_outerstep is not None and current_step % num_inner_steps < warmup_outerstep:
204+
return 0
205+
195206
if current_step < num_warmup_steps:
196207
return float(current_step) / float(max(1, num_warmup_steps))
197208
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
@@ -200,11 +211,13 @@ def _get_cosine_schedule_with_warmup_lr_lambda(
200211
return max(0, factor)
201212

202213

203-
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_inner_steps):
214+
def get_cosine_schedule_with_warmup(optimizer, config: Config):
204215
lambda_lr = partial(
205216
_get_cosine_schedule_with_warmup_lr_lambda,
206-
num_warmup_steps=num_warmup_steps,
207-
num_training_steps=num_training_steps,
217+
num_warmup_steps=config.warmup_steps,
218+
num_training_steps=config.total_steps,
219+
num_inner_steps=config.hv.local_steps,
220+
warmup_outerstep=config.hv.warmup_outerstep,
208221
num_cycles=0.5,
209222
)
210223
return LambdaLR(optimizer, lambda_lr, -1)
@@ -301,9 +314,7 @@ def train(config: Config):
301314
def scheduler_fn(opt):
302315
return get_cosine_schedule_with_warmup(
303316
opt,
304-
num_warmup_steps=config.warmup_steps,
305-
num_training_steps=config.total_steps,
306-
num_inner_steps=config.hv.local_steps,
317+
config=config,
307318
)
308319

309320
if config.hv is not None:

0 commit comments

Comments (0)