Skip to content

Commit 152f3a6

Browse files
authored
[grpo] fix hang in colocate lora settings (#4451)
1 parent ec74e2b commit 152f3a6

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 20 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -523,7 +523,7 @@ def _infer(self,
523523
request_config: RequestConfig,
524524
is_global_inputs: bool = False) -> OutputsType:
525525
from swift.llm.infer.protocol import ChatCompletionResponse
526-
request_config = copy(self.request_config)
526+
request_config = self._get_request_config()
527527
# keys from InferRequest
528528
per_device_size = len(inputs)
529529
if is_global_inputs:
@@ -590,6 +590,25 @@ def _infer(self,
590590
results = results[start_idx:end_idx]
591591
return results
592592

593+
def _get_request_config(self) -> RequestConfig:
    """Return a per-call copy of the sampling request config.

    In colocate mode with tensor parallelism (TP) > 1, a deterministic seed
    is assigned so that:
      1. engines within one TP group share the same seed — otherwise the
         program may hang;
      2. engines in different TP groups get distinct seeds — otherwise they
         would generate identical completions.
    """
    cfg = copy(self.request_config)
    tp_size = self.vllm_tensor_parallel_size
    if self.args.vllm_mode != 'colocate' or tp_size <= 1:
        # No TP grouping in play: the shared config copy is used unchanged.
        return cfg
    # Effective per-device batch size depends on whether we are training
    # (per-device train batch * grad-accum steps) or evaluating.
    if self.model.training:
        per_device = self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps
    else:
        per_device = self.args.per_device_eval_batch_size
    # The TP group gathers the inputs, so the group-level batch size is the
    # per-device size multiplied by the TP parallel size.
    group_batch = per_device * tp_size
    # Ranks in the same TP group share `process_index // tp_size`, which
    # yields an identical seed within the group and a distinct one across
    # groups (spaced by the group batch size).
    tp_group_rank = self.accelerator.process_index // tp_size
    cfg.seed = group_batch * tp_group_rank
    return cfg
611+
593612
def _set_inputs_system(self, inputs: InputsType) -> InputsType:
594613
if not self.template.template_meta.default_system:
595614
return

0 commit comments

Comments
 (0)