15
15
from coati .trainer .utils import all_reduce_mean
16
16
from transformers import AutoModelForCausalLM , AutoTokenizer
17
17
18
+ from colossalai .nn .lr_scheduler import CosineAnnealingWarmupLR
18
19
from colossalai .nn .optimizer import HybridAdam
19
20
20
21
@@ -34,10 +35,10 @@ def __init__(
34
35
model_config ,
35
36
plugin_config ,
36
37
microbatch_size = 1 ,
37
- num_generations = 4 ,
38
+ num_generations = 8 ,
38
39
use_wandb = True ,
39
- generator_config = None ,
40
- filter_range = None ,
40
+ generate_config = None ,
41
+ training_config = {} ,
41
42
):
42
43
super ().__init__ (
43
44
num_producers ,
@@ -57,7 +58,7 @@ def __init__(
57
58
self .policy_model = AutoModelForCausalLM .from_pretrained (path , ** model_config )
58
59
self .policy_model .train ()
59
60
self .policy_model .gradient_checkpointing_enable ()
60
- self .optimizer = HybridAdam (self .policy_model .parameters (), lr = 1e-6 )
61
+ self .optimizer = HybridAdam (self .policy_model .parameters (), lr = training_config . get ( "lr" , 1e-6 ) )
61
62
self .accum_loss = torch .zeros (1 , device = self .device )
62
63
self .accum_reward = torch .zeros (1 , device = self .device )
63
64
self .accum_kl = torch .zeros (1 , device = self .device )
@@ -66,6 +67,7 @@ def __init__(
66
67
self .accum_advantages = torch .zeros (1 , device = self .device )
67
68
self .accum_response_length = torch .zeros (1 , device = self .device )
68
69
self .accum_count = 0
70
+ self .generate_config = generate_config
69
71
70
72
# Reference model is initialized from policy model.
71
73
self .reference_model = AutoModelForCausalLM .from_pretrained (path , ** model_config )
@@ -74,7 +76,7 @@ def __init__(
74
76
self .tokenizer = AutoTokenizer .from_pretrained (path )
75
77
self .pad_token_id = self .tokenizer .pad_token_id
76
78
self .num_generations = num_generations
77
- self .filter_range = filter_range
79
+ self .filter_range = training_config . get ( " filter_range" , None )
78
80
if self .filter_range is not None :
79
81
assert len (self .filter_range ) == 2 , "Filter range should have 2 values."
80
82
@@ -92,15 +94,21 @@ def __init__(
92
94
self .policy_loss_fn = PolicyLoss ()
93
95
self .global_step = 0
94
96
if use_wandb and self .rank == 0 :
95
- if "repetition_penalty" in generator_config :
96
- name = f"{ generator_config ['backend' ]} _bs_{ self .batch_size * self .world_size } _temp_{ generator_config ['temperature' ]:.01f} _rep_penalty_{ generator_config ['repetition_penalty' ]:.01f} "
97
- else :
98
- name = f"{ generator_config ['backend' ]} _bs_{ self .batch_size * self .world_size } _temp_{ generator_config ['temperature' ]:.01f} "
97
+ name = f"{ generate_config ['backend' ]} _bs_{ self .batch_size * self .world_size } _temp_{ generate_config ['temperature' ]:.01f} _top_p_{ generate_config ['top_p' ]:.02f} "
99
98
self .wandb_run = wandb .init (project = "GRPO-V1" , sync_tensorboard = True , dir = "./wandb" , name = name )
100
99
100
+ self .lr_scheduler = CosineAnnealingWarmupLR (
101
+ optimizer = self .optimizer ,
102
+ total_steps = min (self .num_episodes , 4 ) * self .num_update_per_episode ,
103
+ warmup_steps = 0 ,
104
+ eta_min = 0.1 * training_config .get ("lr" , 1e-6 ),
105
+ )
106
+
101
107
def setup(self):
    """Run the parent setup, then wrap the models with the booster.

    The policy model is boosted together with its optimizer and
    learning-rate scheduler, and the boosted objects replace the attributes
    of the same name. The reference model is boosted on its own.
    """
    super().setup()
    # The result is unpacked as a 5-tuple; only the model, optimizer and
    # lr-scheduler slots are kept, the two middle slots are discarded.
    self.policy_model, self.optimizer, _, _, self.lr_scheduler = self.booster.boost(
        self.policy_model, self.optimizer, lr_scheduler=self.lr_scheduler
    )
    # Only the first returned element (the wrapped model) is needed here.
    self.reference_model, *_ = self.booster.boost(self.reference_model)
105
113
106
114
def step (self , step_idx : int , ** kwargs ) -> Optional [float ]:
@@ -133,7 +141,10 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
133
141
attention_mask = data ["attention_mask" ],
134
142
)["logits" ]
135
143
action_log_probs = calc_action_log_probs (
136
- policy_model_logits / generator_config ["temperature" ], data ["input_ids" ], num_action , self .plugin .shard_config
144
+ policy_model_logits / self .generate_config ["temperature" ],
145
+ data ["input_ids" ],
146
+ num_action ,
147
+ self .plugin .shard_config ,
137
148
)
138
149
139
150
with torch .no_grad ():
@@ -142,7 +153,10 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
142
153
attention_mask = data ["attention_mask" ],
143
154
)["logits" ]
144
155
reference_action_log_probs = calc_action_log_probs (
145
- reference_model_logits / generator_config ["temperature" ], data ["input_ids" ], num_action , self .plugin .shard_config
156
+ reference_model_logits / self .generate_config ["temperature" ],
157
+ data ["input_ids" ],
158
+ num_action ,
159
+ self .plugin .shard_config ,
146
160
)
147
161
148
162
per_token_kl = (
@@ -161,22 +175,24 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
161
175
acc_reward = torch .tensor ([value [2 ] for value in reward_group ]).to (data ["input_ids" ].device )
162
176
163
177
# [batch_size, num_generations]
178
+
179
+ group_reward = reward .view (- 1 , self .num_generations )
180
+ reward_mean = group_reward .mean (dim = 1 )
164
181
# Filter out rewards that are too high (all samples get full score) or too low (all samples get 0 score).
165
182
loss_mask = (
166
183
None
167
184
if self .filter_range is None
168
- else torch .logical_and (reward > self .filter_range [0 ], reward < self .filter_range [1 ])
185
+ else torch .logical_and (
186
+ reward_mean > self .filter_range [0 ], reward_mean < self .filter_range [1 ]
187
+ ).repeat_interleave (self .num_generations , dim = 0 )
169
188
)
170
- group_reward = reward .view (- 1 , self .num_generations )
171
- reward_mean = group_reward .mean (dim = 1 )
172
189
173
190
# [batch_size x num_generations]
174
- reward_mean = group_reward . mean ( dim = 1 ) .repeat_interleave (self .num_generations , dim = 0 )
191
+ reward_mean = reward_mean .repeat_interleave (self .num_generations , dim = 0 )
175
192
reward_std = group_reward .std (dim = 1 ).repeat_interleave (self .num_generations , dim = 0 )
176
193
# [batch_size x num_generations]
177
194
advantages = (reward - reward_mean ) / (reward_std + 1e-4 )
178
195
179
- # Calculate Loss
180
196
loss , skip_update , _ = self .policy_loss_fn (
181
197
action_log_probs ,
182
198
old_action_log_probs ,
0 commit comments