Commit 59a3863

feat: grpo async generate thread-safe queue production (#3821)
* lock
* remove lock
* fix
* fix
* move comment
* fix

---------

Co-authored-by: hjh <[email protected]>
1 parent 3b056b0 commit 59a3863
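The "thread-safe queue production" in the title refers to the async-generate worker publishing rollout results through queues (the train_queue and eval_queue that the diff below polls) rather than through unsynchronized shared state. A minimal sketch of that producer/consumer pattern, assuming Python's standard queue.Queue as the channel; rollout_worker and the string payloads are hypothetical stand-ins for the real generation call:

import queue
import threading
import time

# Thread-safe channel between the async rollout producer and the trainer (consumer).
train_queue = queue.Queue()

def rollout_worker(num_batches):
    # Hypothetical producer: generate completions and publish them on the queue.
    for step in range(num_batches):
        completions = [f'completion-{step}-{i}' for i in range(4)]  # stand-in for model generation
        train_queue.put(completions)  # queue.Queue.put is thread-safe

threading.Thread(target=rollout_worker, args=(3,), daemon=True).start()

# Consumer side: poll until the producer has published a batch, as the trainer does below.
while train_queue.empty():
    time.sleep(0.1)
print(train_queue.get())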

1 file changed (+17 −1 lines changed)
swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 17 additions & 1 deletion
@@ -321,6 +321,8 @@ def cyclic_iter(iterable):
                     yield x
 
         self.resample_iterator = cyclic_iter(self.get_resample_dataloader())
+        # flag indicating whether the evaluation has started
+        self.eval_flag = False
 
     def split_batches(self):
         """Sync weights in batches
@@ -1089,6 +1091,10 @@ def _get_per_token_logps(self, model, inputs):
         return selective_log_softmax(logits, input_ids)  # compute logprobs for the input tokens
 
     def evaluation_loop(self, dataloader, *args, **kwargs):
+        # Wait for the training rollout to complete
+        if self.args.async_generate:
+            while not self.is_async_generate_train_rollout_done():
+                time.sleep(0.1)
         # set mini_batch_size None in evaluation
         mini_batch_size = self.args.mini_batch_size
         self.args.mini_batch_size = None
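This guard, like the matching one added to training_step in the next hunk, uses a poll-and-sleep idiom: spin on a cheap predicate with a short sleep instead of blocking on the queue, so the check never consumes an item. A generic sketch of that idiom; wait_until is a hypothetical helper, not part of the trainer:

import time

def wait_until(predicate, interval=0.1, timeout=None):
    # Poll `predicate` every `interval` seconds; give up after `timeout` seconds if set.
    start = time.monotonic()
    while not predicate():
        if timeout is not None and time.monotonic() - start > timeout:
            return False
        time.sleep(interval)
    return True

# e.g. wait_until(lambda: not q.empty(), interval=0.1) mirrors the loops added here.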
@@ -1099,13 +1105,17 @@ def evaluation_loop(self, dataloader, *args, **kwargs):
         metrics = {f'{metric_key_prefix}_{key}': sum(val) / len(val) for key, val in self._metrics['eval'].items()}
         output.metrics.update(metrics)
         self.args.mini_batch_size = mini_batch_size
+        self.eval_flag = True
         return output
 
     def training_step(self,
                       model: nn.Module,
                       inputs: Dict[str, Union[torch.Tensor, Any]],
                       num_items_in_batch=None) -> torch.Tensor:
-
+        if self.args.async_generate:
+            # Wait for the eval rollout to complete
+            while not self.is_async_generate_eval_rollout_done():
+                time.sleep(0.1)
         if self.args.mini_batch_size is None:
             return super().training_step(model, inputs, num_items_in_batch)
         model.train()
@@ -1326,3 +1336,9 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
         if self.args.wandb_log_unique_prompts:
             df = df.drop_duplicates(subset=['prompt'])
         wandb.log({'completions': wandb.Table(dataframe=df)})
+
+    def is_async_generate_eval_rollout_done(self):
+        return not self.eval_flag or not self.eval_queue.empty()
+
+    def is_async_generate_train_rollout_done(self):
+        return not self.train_queue.empty()
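The eval-side predicate is deliberately vacuous until the first evaluation runs: while eval_flag is False, training_step never blocks; once evaluation_loop has set the flag, training proceeds only when the eval queue already holds a finished rollout. A toy check of that logic, assuming eval_queue is a standard queue.Queue as the attribute name suggests:

import queue

eval_flag = False
eval_queue = queue.Queue()

def is_async_generate_eval_rollout_done():
    # True before the first eval, or once an eval rollout has been queued.
    return not eval_flag or not eval_queue.empty()

assert is_async_generate_eval_rollout_done()        # no eval yet: training proceeds
eval_flag = True                                    # evaluation_loop has run once
assert not is_async_generate_eval_rollout_done()    # eval rollout pending: training waits
eval_queue.put('eval rollout results')
assert is_async_generate_eval_rollout_done()        # rollout queued: training proceeds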
