 import ray.util.collective as cc
 import torch
 import tqdm
+import wandb
 from coati.dataset.loader import RawConversationDataset
 from coati.distributed.reward.reward_fn import boxed_math_reward_fn, math_reward_fn
+from ray.util.collective import allreduce
+from ray.util.collective.types import Backend, ReduceOp
 from torch.utils.data import DataLoader, DistributedSampler
 from transformers import AutoTokenizer

 from colossalai.utils import get_current_device

 from .comm import ray_broadcast_tensor_dict
 from .inference_backend import BACKEND_MAP
-from .utils import pre_send, safe_write_jsonl
+from .utils import pre_send, safe_append_to_jsonl_file

 try:
     from vllm import SamplingParams
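The hunk above swaps `safe_write_jsonl` for `safe_append_to_jsonl_file`: with several producer processes contributing records to the same eval file, appending under a lock avoids one rank overwriting another's results. The real helper lives in `.utils` and its body is not part of this diff; the following is only a minimal sketch of what an append-safe writer could look like (the `fcntl` locking strategy is an assumption, not the library's confirmed implementation):

```python
import json
import os


def safe_append_to_jsonl_file(file_path: str, data: list) -> None:
    # Hypothetical sketch -- the actual helper is defined in .utils.
    # Append (never truncate) so concurrent producers can all contribute;
    # an exclusive advisory lock keeps their writes from interleaving.
    import fcntl

    os.makedirs(os.path.dirname(file_path) or ".", exist_ok=True)
    with open(file_path, "a", encoding="utf8") as f:
        fcntl.flock(f, fcntl.LOCK_EX)  # block until this process holds the lock
        try:
            for entry in data:
                f.write(json.dumps(entry, ensure_ascii=False) + "\n")
        finally:
            fcntl.flock(f, fcntl.LOCK_UN)
```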
@@ -43,6 +46,9 @@ def __init__(
         eval_interval=-1,  # disable evaluation
         evaluation_function_type="think_answer_tags",
         eval_save_dir: str = "./eval",
+        project_name: str = None,
+        run_name: str = None,
+        wandb_group_name: str = None,
     ):
         self.producer_idx = producer_idx
         self.num_producers = num_producers
@@ -61,6 +67,14 @@ def __init__(
         self.eval_interval = eval_interval
         self.eval_save_dir = eval_save_dir
         self.consumer_global_step = 0
+        if self.producer_idx == 0:
+            self.wandb_run = wandb.init(
+                project=project_name,
+                sync_tensorboard=True,
+                dir="./wandb",
+                name=run_name + "_eval",
+                group=wandb_group_name,
+            )

         if os.path.exists(self.eval_save_dir):
             raise ValueError(f"Eval save dir {self.eval_save_dir} already exists. Please delete it or change the name.")
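Only producer 0 creates the wandb run, so a multi-producer job produces one eval run instead of one per process; every later `self.wandb_run.log(...)` call in this diff is guarded by the same `producer_idx == 0` check. A small illustrative sketch of that rank-0-only pattern (the wrapper class is hypothetical, not code from this PR):

```python
import wandb


class Rank0Logger:
    """Illustrative sketch: route wandb calls through rank 0 only."""

    def __init__(self, rank: int, project: str, run_name: str, group: str):
        self.run = None
        if rank == 0:  # one run per job, not one per producer process
            self.run = wandb.init(project=project, name=run_name + "_eval", group=group)

    def log(self, metrics: dict, step: int) -> None:
        if self.run is not None:  # non-zero ranks silently no-op
            self.run.log(metrics, step=step)
```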
@@ -132,13 +146,18 @@ def __init__(
         self.consumer_pp_size = consumer_plugin_config.get("pp_size", 1)  # consumer pp size

     def setup(self) -> None:
+        cc.init_collective_group(
+            world_size=self.num_producers,
+            rank=self.producer_idx,
+            backend=Backend.NCCL,
+            group_name="producer_group",
+        )
         cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
         if self.consumer_pp_size > 1:
             for i in range(self.consumer_pp_size):
                 cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}")
         else:
             cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
-        cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_eval_statistics_{self.producer_idx}")

     def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         raise NotImplementedError
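The new `producer_group` spans all producers (world size `num_producers`, each ranked by its `producer_idx`), so eval statistics can be combined with a single NCCL all-reduce among producers instead of being broadcast to the consumer over the removed per-producer `sync_eval_statistics_*` groups. A standalone sketch of the `ray.util.collective` API in the same style, assuming two GPU actors are available (the `Worker` actor is illustrative):

```python
import ray
import ray.util.collective as cc
import torch
from ray.util.collective.types import Backend, ReduceOp


@ray.remote(num_gpus=1)
class Worker:
    def setup(self, world_size: int, rank: int) -> None:
        # Every member must join with the same group_name before any collective op.
        cc.init_collective_group(
            world_size=world_size, rank=rank, backend=Backend.NCCL, group_name="producer_group"
        )

    def reduce(self) -> torch.Tensor:
        t = torch.ones(2, device="cuda")
        cc.allreduce(t, group_name="producer_group", op=ReduceOp.SUM)  # in-place sum
        return t.cpu()


ray.init()
workers = [Worker.remote() for _ in range(2)]
ray.get([w.setup.remote(2, r) for r, w in enumerate(workers)])
print(ray.get([w.reduce.remote() for w in workers]))  # both ranks report tensor([2., 2.])
```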
@@ -160,13 +179,14 @@ def loop(self) -> None:
                     break
                 if self.eval_interval > 0 and self.eval_dataset_config is not None:
                     if i % self.eval_interval == 0:
-                        eval_statistics = {}
+                        to_log_msg = {}
                         for eval_task_name in self.eval_dataloaders:
-                            print(
-                                f"[P{self.producer_idx}] Evaluate episode {episode} step {i} on task {eval_task_name}"
-                            )
+                            if self.producer_idx == 0:
+                                print(
+                                    f"[P{self.producer_idx}] Evaluate episode {episode} step {i} on task {eval_task_name}"
+                                )
                             eval_results = []
-                            eval_statistics[eval_task_name] = torch.zeros(2, device=self.device)
+                            eval_statistics_tensor = torch.zeros((2,), dtype=torch.float32).to(self.device)
                             for eval_batch in tqdm.tqdm(
                                 self.eval_dataloaders[eval_task_name], disable=self.producer_idx != 0
                             ):
@@ -182,24 +202,27 @@ def loop(self) -> None:
                                     for m in range(eval_outputs["input_ids"].size(0))
                                     for n in range(eval_outputs["input_ids"].size(1))
                                 ]
-                            eval_statistics[eval_task_name][0] += len(
-                                [res for res in eval_results if res["ans_valid"] == 1]
+                            eval_statistics_tensor[0] += len([res for res in eval_results if res["ans_valid"] == 1])
+                            eval_statistics_tensor[1] += len(eval_results)
+                            allreduce(eval_statistics_tensor, op=ReduceOp.SUM, group_name="producer_group")
+                            to_log_msg[f"eval/{eval_task_name}"] = (
+                                eval_statistics_tensor[0].item() / eval_statistics_tensor[1].item()
                             )
-                            eval_statistics[eval_task_name][1] += len(eval_results)
+                            if self.producer_idx == 0:
+                                print(
+                                    f"[P{self.producer_idx}]: Accuracy on {eval_task_name}: {to_log_msg[f'eval/{eval_task_name}']}"
+                                )
                             # save eval results
-                            result_file_name = os.path.join(
-                                self.eval_save_dir,
-                                f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
+                            safe_append_to_jsonl_file(
+                                os.path.join(
+                                    self.eval_save_dir,
+                                    f"{eval_task_name}_episode_{episode}_step_{self.consumer_global_step}.jsonl",
+                                ),
+                                eval_results,
                             )
-                            # delete the file if it exists
-                            safe_write_jsonl(result_file_name, eval_results)
-                        print(f"[P{self.producer_idx}] Send eval statistics episode {episode} step {i}")
-                        ray_broadcast_tensor_dict(
-                            eval_statistics,
-                            src=0,
-                            device=self.device,
-                            group_name=f"sync_eval_statistics_{self.producer_idx}",
-                        )
+
+                        if self.producer_idx == 0:
+                            self.wandb_run.log(to_log_msg, step=self.consumer_global_step)
                 outputs = self.rollout(**batch)

                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
@@ -248,12 +271,11 @@ def loop(self) -> None:
                 # linear annealing for 1 episode, temperature from initial to 0.9
                 if episode <= 0:
                     ratio = 1 - (len(self.train_dataloader) - i) / len(self.train_dataloader)
-                    if isinstance(self.model.generate_config.temperature, dict):
-                        self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
-                            "temperature"
-                        ] + ratio * 0.9
-                    else:
-                        self.model.generate_config.temperature = (1 - ratio) * self.generate_config[
+                    self.model.generate_config["temperature"] = (1 - ratio) * self.generate_config[
+                        "temperature"
+                    ] + ratio * 0.9
+                    if isinstance(self.model, BACKEND_MAP["vllm"]):
+                        self.model.sample_params.temperature = (1 - ratio) * self.generate_config[
                             "temperature"
                         ] + ratio * 0.9

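This hunk also keeps the vLLM backend's `sample_params` in sync with `generate_config`: during episode 0, `ratio` sweeps from 0 to 1 across the dataloader, moving the sampling temperature linearly from its configured starting value to 0.9. A quick check of the schedule, assuming a starting temperature of 1.2 and a 4-step dataloader:

```python
def annealed_temperature(initial_temp: float, step: int, num_steps: int) -> float:
    # Mirrors the diff: ratio = 1 - (num_steps - step) / num_steps == step / num_steps
    ratio = 1 - (num_steps - step) / num_steps
    return (1 - ratio) * initial_temp + ratio * 0.9


for step in range(5):
    print(step, round(annealed_temperature(1.2, step, 4), 3))
# 0 -> 1.2, 1 -> 1.125, 2 -> 1.05, 3 -> 0.975, 4 -> 0.9
```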
@@ -280,6 +302,10 @@ def __init__(
         eval_interval=-1,  # disable evaluation
         evaluation_function_type="think_answer_tags",
         eval_save_dir: str = "./eval",
+        eval_generation_config={},
+        project_name: str = None,
+        run_name: str = None,
+        wandb_group_name: str = None,
     ):
         super().__init__(
             producer_idx,
@@ -299,10 +325,14 @@ def __init__(
             eval_interval=eval_interval,
             evaluation_function_type=evaluation_function_type,
             eval_save_dir=eval_save_dir,
+            project_name=project_name,
+            run_name=run_name,
+            wandb_group_name=wandb_group_name,
         )
         self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
         self.eval_generation_config = copy.deepcopy(self.model.generate_config)
         self.eval_generation_config["n"] = 1  # use 1 generation for evaluation
+        self.eval_generation_config.update(eval_generation_config)
         self.eval_sample_params = SamplingParams(**self.eval_generation_config)

     @torch.no_grad()
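The new `eval_generation_config` parameter lets callers override generation settings for evaluation only: the eval config starts as a deep copy of the training `generate_config`, is forced to one sample per prompt (`n=1`), then merged with the user overrides before being frozen into vLLM `SamplingParams`. A sketch of the merge order, with the override values purely illustrative:

```python
import copy

generate_config = {"temperature": 1.0, "top_p": 0.95, "n": 8}  # training-time settings
eval_generation_config = {"temperature": 0.6}  # illustrative user override

eval_cfg = copy.deepcopy(generate_config)
eval_cfg["n"] = 1                        # evaluation always uses a single generation
eval_cfg.update(eval_generation_config)  # user overrides are applied last, so they win
print(eval_cfg)  # {'temperature': 0.6, 'top_p': 0.95, 'n': 1}
```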