@@ -1,4 +1,6 @@
+import math
 import os
+import sys
 from typing import Dict, List
 
 import ray
@@ -36,7 +38,7 @@ def __init__(self, config: Config):
 
     def _init_algorithm(self):
         self.algorithm = ALGORITHM_TYPE.get(self.config.algorithm.algorithm_type)
-        algorithm_config = self.config.algorithm
+        self.algorithm_config = algorithm_config = self.config.algorithm
         if self.algorithm.compute_advantage_in_trainer:
             self.advantage_fn = ADVANTAGE_FN.get(algorithm_config.advantage_fn)(
                 **algorithm_config.advantage_fn_args
@@ -63,12 +65,60 @@ def _init_algorithm(self):
             and (self.loss_agg_mode == "token-mean")
         )
 
-        self.adam_params = types.AdamParams(
-            learning_rate=algorithm_config.optimizer.lr,
-            beta1=algorithm_config.optimizer.betas[0],
-            beta2=algorithm_config.optimizer.betas[1],
+        self.lr_scheduler_type = algorithm_config.optimizer.lr_scheduler_type
+        self.total_steps = self.config.trainer.total_steps or sys.maxsize
+        self.num_warmup_steps = algorithm_config.optimizer.lr_warmup_steps
+        if self.num_warmup_steps < 0:
+            self.num_warmup_steps = int(
+                algorithm_config.optimizer.lr_warmup_steps_ratio * self.total_steps
+            )
+        self.min_lr_ratio = algorithm_config.optimizer.min_lr_ratio
+        assert 0.0 <= self.min_lr_ratio <= 1.0
+        self.logger.info(
+            f"Total steps: {self.total_steps}, num_warmup_steps: {self.num_warmup_steps}"
+        )
+
+        if self.lr_scheduler_type not in {"constant", "cosine"}:
+            raise NotImplementedError(
+                f"LR scheduler type {self.lr_scheduler_type} is not supported"
+            )
+
+    @property
+    def _current_lr_factor(self):
+        train_step_num = self._train_step_num
+        # warmup
+        if train_step_num < self.num_warmup_steps:
+            factor = float(train_step_num) / float(max(1.0, self.num_warmup_steps))
+            factor = self.min_lr_ratio + (1.0 - self.min_lr_ratio) * factor
+            return factor
+
+        # decay
+        if train_step_num >= self.total_steps:
+            progress = 1.0
+        else:
+            progress = float(train_step_num - self.num_warmup_steps) / float(
+                max(1.0, self.total_steps - self.num_warmup_steps)
+            )
+        if self.lr_scheduler_type == "constant":
+            factor = 1.0
+        elif self.lr_scheduler_type == "cosine":
+            num_cycles = 0.5  # TODO: may add to config
+            factor = 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
+            factor = self.min_lr_ratio + (1.0 - self.min_lr_ratio) * factor
+        return max(self.min_lr_ratio, factor)
+
+    @property
+    def current_learning_rate(self):
+        return self._current_lr_factor * self.algorithm_config.optimizer.lr
+
+    @property
+    def adam_params(self):
+        return types.AdamParams(
+            learning_rate=self.current_learning_rate,
+            beta1=self.algorithm_config.optimizer.betas[0],
+            beta2=self.algorithm_config.optimizer.betas[1],
             # eps is currently not in config
-            weight_decay=algorithm_config.optimizer.weight_decay,
+            weight_decay=self.algorithm_config.optimizer.weight_decay,
             grad_clip_norm=self.config.trainer.grad_clip,
         )
 
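
To make the new schedule easier to check, here is a minimal standalone sketch of the factor that _current_lr_factor computes: linear warmup from min_lr_ratio up to 1.0, then constant or half-cosine decay back down to the min_lr_ratio floor. The free function lr_factor and its signature are illustrative only, not part of this patch; the arithmetic mirrors the diff above.

import math

# Illustrative sketch, not part of the patch; mirrors _current_lr_factor.
def lr_factor(step, total_steps, num_warmup_steps,
              min_lr_ratio=0.0, scheduler_type="cosine"):
    """Multiplier applied to the base learning rate at a given step."""
    # Linear warmup from min_lr_ratio up to 1.0.
    if step < num_warmup_steps:
        warmup = float(step) / float(max(1.0, num_warmup_steps))
        return min_lr_ratio + (1.0 - min_lr_ratio) * warmup

    # Decay phase: progress runs from 0.0 (warmup end) to 1.0 (total_steps).
    if step >= total_steps:
        progress = 1.0
    else:
        progress = float(step - num_warmup_steps) / float(
            max(1.0, total_steps - num_warmup_steps)
        )

    if scheduler_type == "constant":
        factor = 1.0
    elif scheduler_type == "cosine":
        # num_cycles = 0.5 is half a cosine period: a monotonic decay
        # from 1.0 down to the min_lr_ratio floor.
        num_cycles = 0.5
        factor = 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress))
        factor = min_lr_ratio + (1.0 - min_lr_ratio) * factor
    else:
        raise NotImplementedError(scheduler_type)
    return max(min_lr_ratio, factor)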
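For example, with total_steps=1000, num_warmup_steps=100, and min_lr_ratio=0.1, the cosine factor is 0.1 at step 0, 0.55 at step 50, 1.0 at step 100 (warmup end), 0.55 at step 550 (decay midpoint), and 0.1 from step 1000 on; multiplying by optimizer.lr gives the learning rate handed to AdamParams. Turning adam_params into a property means a fresh types.AdamParams is built with the scheduled learning_rate on every read, so the LR tracks _train_step_num without a separate scheduler object. One caveat implied by the diff: when trainer.total_steps is unset, total_steps falls back to sys.maxsize, so the lr_warmup_steps_ratio path and the cosine decay effectively degenerate; an explicit lr_warmup_steps (and a finite total_steps for cosine) is needed in that case.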