From b27fb9be5961cbc51b17a13a1226835bbb1d4c6c Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 29 Jul 2025 17:19:03 +0800 Subject: [PATCH 01/26] wip --- swift/trainers/rlhf_trainer/grpo_trainer.py | 51 ++++++++++++++++----- 1 file changed, 39 insertions(+), 12 deletions(-) diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 74d2f18060..674e3d23c9 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -62,7 +62,7 @@ InputsType = List[Dict[str, Union[torch.Tensor, Any]]] # tuple: (messages, finish_reason) -OutputsType = List[Tuple[List[Dict], str]] +OutputsType = List[Dict[str, Any]]  # TODO: Check if not hasattr(RepeatSampler, 'old_len_func'): origin_len_func = RepeatSampler.__len__ @@ -72,7 +72,16 @@ def patched_len(self) -> int: RepeatSampler.__len__ = patched_len RepeatSampler.old_len_func = origin_len_func +""" +Refactor: + Rollout returns completion_ids instead of message.content + Set request_config.return_details + Fetch the completion_ids + Change the template.encode logic: encode only the prompt part + Multi-turn and loss_scale handling may be tricky + +""" class GRPOCallback(TrainerCallback): def __init__(self, trainer): @@ -309,6 +318,7 @@ def __init__(self, top_k=args.top_k, repetition_penalty=args.repetition_penalty, stop=args.stop_words, + return_details=True ) # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the @@ -719,6 +729,7 @@ def _infer_single_or_multi_turn(self, - List of responses per prompt - Each response is a tuple of (message_history, finish_reason) """ + # for external server, pass the system args which may be defined in the trainer self._set_inputs_system(inputs) # infer first turn results: List[ChatCompletionResponse] = self._infer(inputs, request_config, is_global_inputs) @@ -729,9 +740,14 @@ _choices = [] for choice in output.choices: _input: Dict = deepcopy(inputs[i]) + # origin messages may contain response, we should remove it InferRequest.remove_response(_input['messages']) _input['messages'].append({'role': 'assistant', 'content': choice.message.content}) - _choices.append((_input['messages'], choice.finish_reason)) + output_dict = { + 'messages': _input['messages'], + 'finish_reason': choice.finish_reason, + 'completion_ids': choice.token_ids } + _choices.append(output_dict) outputs.append(_choices) outputs = [item for sublist in outputs for item in sublist] else: @@ -742,11 +758,17 @@ _choices = [] for choice in output.choices: # concated in Engine + _choice = { + 'messages': choice.messages, + 'finish_reason': choice.finish_reason, + 'completion_ids': choice.token_ids + } if self.use_gym_env: - _choices.append( - (choice.messages, choice.finish_reason, choice.total_reward, choice.trajectory_info)) - else: - _choices.append((choice.messages, choice.finish_reason)) + _choice.update( + { + 'total_reward': choice.total_reward, + 'trajectory_info': choice.trajectory_info + }) outputs.append(_choices) outputs = [item for sublist in outputs for item in sublist] else: @@ -796,7 +818,11 @@ for stop, _input, result in zip(should_stops, current_inputs, results): index = _input['index'] if stop: - outputs[index] = (_input['messages'], _input['finish_reason']) + outputs[index] = { + 'messages': _input['messages'], + 'finish_reason': result.choices[0].finish_reason, + 'completion_ids': result.choices[0].token_ids + } else: current_request = 
self.inputs_to_rolloutrequest([_input])[0] infer_request = self.multi_turn_scheduler.step(current_request, result.choices[0], @@ -926,15 +952,17 @@ def _generate_completions(self, inputs: InputsType) -> InputsType: self.model.train() for i, output in enumerate(outputs): - inputs[i]['messages'] = output[0] - inputs[i]['is_truncated'] = output[1] == 'length' + inputs[i]['messages'] = output['messages'] + inputs[i]['is_truncated'] = output['finish_reason'] == 'length' + inputs[i]['completion_ids'] = output['completion_ids'] if self.use_gym_env: - inputs[i]['total_reward'] = output[2] - inputs[i]['trajectory_info'] = output[3] + inputs[i]['total_reward'] = output['total_reward'] + inputs[i]['trajectory_info'] = output['trajectory_info'] if 'trajectory_info' in output else None return inputs def _generate_and_score_completions(self, inputs: InputsType) -> InputsType: + # resample for overlong(> max_length) prompt data if self.template.truncation_strategy == 'raise': inputs = self.resample_truncated_inputs(inputs) @@ -1059,7 +1087,6 @@ def _dynamic_sampling(self, inputs, rewards, rewards_per_func, completions): return inputs, rewards, rewards_per_func, completions def split_by_mini_batches(self, inputs, advantages): - # Slice to keep only the local part of the data # Slice to keep only the local part of the data process_slice = slice( self.accelerator.process_index * len(inputs), From 3a9f23cef45cf0ff6192843f6e1aa7c10b7ba7d3 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 29 Jul 2025 21:01:48 +0800 Subject: [PATCH 02/26] wip --- swift/trainers/rlhf_trainer/grpo_trainer.py | 30 ++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 674e3d23c9..702e77d498 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -746,7 +746,8 @@ def _infer_single_or_multi_turn(self, output_dict = { 'messages': _input['messages'], 'finish_reason': choice.finish_reason, - 'completion_ids': choice.token_ids } + 'completion_ids': choice.token_ids, + 'prompt_ids': choice.prompt_ids,} _choices.append(output_dict) outputs.append(_choices) outputs = [item for sublist in outputs for item in sublist] @@ -761,7 +762,8 @@ def _infer_single_or_multi_turn(self, _choice = { 'messages': choice.messages, 'finish_reason': choice.finish_reason, - 'completion_ids': choice.token_ids + 'completion_ids': choice.token_ids, + 'prompt_ids': choice.prompt_ids, } if self.use_gym_env: _choice.update( @@ -821,7 +823,8 @@ def _infer_single_or_multi_turn(self, outputs[index] = { 'messages': _input['messages'], 'finish_reason': result.choices[0].finish_reason, - 'completion_ids': result.choices[0].token_ids + 'completion_ids': result.choices[0].token_ids, + 'prompt_ids': result.choices[0].prompt_ids, } else: current_request = self.inputs_to_rolloutrequest([_input])[0] @@ -950,11 +953,32 @@ def _generate_completions(self, inputs: InputsType) -> InputsType: # In training mode, ensure the model is returned to train() mode after inference # This is necessary as pt engines set the model to eval mode during generation self.model.train() + device = self.accelerator.device + + prompt_ids_list = [input['prompt_ids'] for input in inputs] + prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] + completion_ids_list = [output['completion_ids'] for output in outputs] + completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] + 
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) + prompt_completion_ids = pad(prompt_completion_ids, padding_value=self.pad_token_id, padding_side="right") + for i, _input in enumerate(inputs): + _input['prompt_ids'] = prompt_ids[i] + _input['completion_ids'] = completion_ids[i] + _input['prompt_completion_ids'] = prompt_completion_ids[i] + + if completion_mask not in inputs[0]: + is_eos = completion_ids == self.eos_token_id + eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) + eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] + sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) + completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() for i, output in enumerate(outputs): inputs[i]['messages'] = output['messages'] inputs[i]['is_truncated'] = output['finish_reason'] == 'length' inputs[i]['completion_ids'] = output['completion_ids'] + inputs[i]['prompt_ids'] = output['prompt_ids'] + inputs[i]['prompt_completion_ids'] = output['prompt_completion_ids'] if self.use_gym_env: inputs[i]['total_reward'] = output['total_reward'] inputs[i]['trajectory_info'] = output['trajectory_info'] if 'trajectory_info' in output else None From 061a2d5a360a872ce5ff010515a0e9b2e73b0543 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Wed, 30 Jul 2025 18:06:40 +0800 Subject: [PATCH 03/26] revert prompt_ids --- swift/trainers/rlhf_trainer/grpo_trainer.py | 26 +-------------------- 1 file changed, 1 insertion(+), 25 deletions(-) diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 702e77d498..896d8e1a7c 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -746,8 +746,7 @@ def _infer_single_or_multi_turn(self, output_dict = { 'messages': _input['messages'], 'finish_reason': choice.finish_reason, - 'completion_ids': choice.token_ids, - 'prompt_ids': choice.prompt_ids,} + 'completion_ids': choice.token_ids} _choices.append(output_dict) outputs.append(_choices) outputs = [item for sublist in outputs for item in sublist] @@ -763,7 +762,6 @@ def _infer_single_or_multi_turn(self, 'messages': choice.messages, 'finish_reason': choice.finish_reason, 'completion_ids': choice.token_ids, - 'prompt_ids': choice.prompt_ids, } if self.use_gym_env: _choice.update( @@ -824,7 +822,6 @@ def _infer_single_or_multi_turn(self, 'messages': _input['messages'], 'finish_reason': result.choices[0].finish_reason, 'completion_ids': result.choices[0].token_ids, - 'prompt_ids': result.choices[0].prompt_ids, } else: current_request = self.inputs_to_rolloutrequest([_input])[0] @@ -953,32 +950,11 @@ def _generate_completions(self, inputs: InputsType) -> InputsType: # In training mode, ensure the model is returned to train() mode after inference # This is necessary as pt engines set the model to eval mode during generation self.model.train() - device = self.accelerator.device - - prompt_ids_list = [input['prompt_ids'] for input in inputs] - prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] - completion_ids_list = [output['completion_ids'] for output in outputs] - completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] - prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) - prompt_completion_ids = pad(prompt_completion_ids, padding_value=self.pad_token_id, padding_side="right") - for i, _input in enumerate(inputs): - _input['prompt_ids'] = 
prompt_ids[i] - _input['completion_ids'] = completion_ids[i] - _input['prompt_completion_ids'] = prompt_completion_ids[i] - - if completion_mask not in inputs[0]: - is_eos = completion_ids == self.eos_token_id - eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) - eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] - sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) - completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() for i, output in enumerate(outputs): inputs[i]['messages'] = output['messages'] inputs[i]['is_truncated'] = output['finish_reason'] == 'length' inputs[i]['completion_ids'] = output['completion_ids'] - inputs[i]['prompt_ids'] = output['prompt_ids'] - inputs[i]['prompt_completion_ids'] = output['prompt_completion_ids'] if self.use_gym_env: inputs[i]['total_reward'] = output['total_reward'] inputs[i]['trajectory_info'] = output['trajectory_info'] if 'trajectory_info' in output else None From 8d2b170a40a4bfc581f78722d05d483336b58e83 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 1 Aug 2025 12:10:18 +0800 Subject: [PATCH 04/26] remove tokenizer in reward --- swift/plugin/orm.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/swift/plugin/orm.py b/swift/plugin/orm.py index 53f53745ec..9b5e4947b7 100644 --- a/swift/plugin/orm.py +++ b/swift/plugin/orm.py @@ -301,14 +301,12 @@ def __call__(self, completions, **kwargs) -> List[float]: class CosineReward(ORM): # https://arxiv.org/abs/2502.03373 def __init__(self, - tokenizer=None, cosine_min_len_value_wrong: float = -0.5, cosine_max_len_value_wrong: float = 0.0, cosine_min_len_value_correct: float = 1.0, cosine_max_len_value_correct: float = 0.5, cosine_max_len: int = 1000, accuracy_orm=None): - self.tokenizer = tokenizer self.min_len_value_wrong = cosine_min_len_value_wrong self.max_len_value_wrong = cosine_max_len_value_wrong self.min_len_value_correct = cosine_min_len_value_correct @@ -323,8 +321,9 @@ def cosfn(t, T, min_value, max_value): def __call__(self, completions, solution, **kwargs) -> List[float]: acc_rewards = self.accuracy_orm(completions, solution, **kwargs) + completion_ids = kwargs.get('completion_ids') rewards = [] - for content, acc_reward in zip(completions, acc_rewards): + for ids, acc_reward in zip(completion_ids, acc_rewards): is_correct = acc_reward >= 1. 
if is_correct: # Swap min/max for correct answers @@ -333,7 +332,7 @@ def __call__(self, completions, solution, **kwargs) -> List[float]: else: min_value = self.max_len_value_wrong max_value = self.min_len_value_wrong - gen_len = len(self.tokenizer.encode(content)) + gen_len = len(ids) reward = self.cosfn(gen_len, self.max_len, min_value, max_value) rewards.append(reward) return rewards @@ -380,16 +379,16 @@ def __call__(self, completions, **kwargs) -> List[float]: class SoftOverlong(ORM): - def __init__(self, tokenizer, soft_max_length, soft_cache_length): - self.tokenizer = tokenizer + def __init__(self, soft_max_length, soft_cache_length): assert soft_cache_length < soft_max_length self.soft_max_length = soft_max_length self.soft_cache_length = soft_cache_length def __call__(self, completions, **kwargs) -> List[float]: rewards = [] - for completion in completions: - completion_length = len(self.tokenizer.encode(completion)) + completion_ids = kwargs.get('completion_ids') + for ids in completion_ids: + completion_length = len(ids) expected_len = self.soft_max_length - self.soft_cache_length exceed_len = completion_length - expected_len rewards.append(min(-exceed_len / self.soft_cache_length, 0)) From 0dea74da892f6dbf0008b2f0e60e02b92a97ed04 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 1 Aug 2025 15:17:04 +0800 Subject: [PATCH 05/26] encode ids --- swift/trainers/rlhf_trainer/grpo_trainer.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 1d616ec687..55e2fa0e30 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -1153,7 +1153,12 @@ def _prepare_batch_inputs(self, inputs: InputsType, rewards: torch.Tensor) -> Li for i, (batch, batch_advantages) in enumerate(zip(gas_chunks, advantage_chunks)): # Encode and process each batch (size=bs) with self._template_context(template): - batch_encoded_inputs = [template.encode(infer_request) for infer_request in batch] + processed_assistant_batch = [] + for data in batch: + InferRequest.remove_response(data['messages']) + data['messages'].append({'role': 'assistant', 'content': data['completion_ids']}) + processed_assistant_batch.append(data) + batch_encoded_inputs = [template.encode(infer_request) for infer_request in processed_assistant_batch] batch_encoded_inputs = to_device(template.data_collator(batch_encoded_inputs), self.model.device) # Process labels and masks From 1114cea83e071c1686ce2c44c6c710ad39d269cd Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Sun, 3 Aug 2025 12:45:49 +0800 Subject: [PATCH 06/26] wip replace ids --- swift/llm/template/base.py | 4 +- swift/llm/utils.py | 2 +- swift/trainers/rlhf_trainer/grpo_trainer.py | 20 ++++---- swift/trainers/rlhf_trainer/utils.py | 51 ++++++++++++++++++++- 4 files changed, 66 insertions(+), 11 deletions(-) diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index f6cd5c2073..9b072b9940 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -928,7 +928,9 @@ def _pre_tokenize(self, context_list: List[Context], loss_scale_list: List[float @staticmethod def _add_default_tags(inputs: StdTemplateInputs): - total_content = '\n'.join([message['content'] or '' for message in inputs.messages]) + total_content = '\n'.join( + (message['content'] if isinstance(message['content'], str) else str(message['content']) or '') + for message in inputs.messages) if inputs.rejected_response: if 
isinstance(inputs.rejected_response, str): total_content += inputs.rejected_response diff --git a/swift/llm/utils.py b/swift/llm/utils.py index bfdebc90ef..fbfcf90a8e 100644 --- a/swift/llm/utils.py +++ b/swift/llm/utils.py @@ -28,7 +28,7 @@ Tool = Dict[str, Union[str, Dict]] History = List[Union[Tuple[str, str], List[str]]] -Message = Dict[str, Union[str, List[Dict[str, Any]]]] +Message = Dict[str, Union[str, List[Dict[str, Any]], List[int]]] Messages = List[Message] diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 55e2fa0e30..4a516f79ef 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -43,7 +43,7 @@ from ..mixin import SwiftMixin from .rlhf_mixin import RLHFTrainerMixin from .utils import (_ForwardRedirection, patch_lora_merge, patch_lora_unmerge, patch_profiling_context, - patch_profiling_decorator) + patch_profiling_decorator, replace_assistant_response_with_ids) from .vllm_client import VLLMClient try: @@ -748,6 +748,8 @@ def _infer_single_or_multi_turn(self, output_dict = { 'messages': _input['messages'], 'finish_reason': choice.finish_reason, + # NOTE: for training, we use rollout token_ids to calculate loss + # because the tokenizer encode/decode may change the token ids 'completion_ids': choice.token_ids } _choices.append(output_dict) @@ -774,7 +776,7 @@ def _infer_single_or_multi_turn(self, outputs.append(_choices) outputs = [item for sublist in outputs for item in sublist] else: - # PTEngine or vLLMLLMEngine + # multi turn for PTEngine or vLLMLLMEngine orig_size = len(inputs) outputs = [None] * orig_size # we remove origin response in first turn @@ -806,6 +808,9 @@ def _infer_single_or_multi_turn(self, if 'index' not in current_input: current_input['index'] = cnt current_input['finish_reason'] = choice.finish_reason + if 'completion_ids' not in current_input: + current_input['completion_ids'] = [] + current_input['completion_ids'].append(choice.token_ids) cnt += 1 current_inputs.append(current_input) @@ -1153,12 +1158,11 @@ def _prepare_batch_inputs(self, inputs: InputsType, rewards: torch.Tensor) -> Li for i, (batch, batch_advantages) in enumerate(zip(gas_chunks, advantage_chunks)): # Encode and process each batch (size=bs) with self._template_context(template): - processed_assistant_batch = [] - for data in batch: - InferRequest.remove_response(data['messages']) - data['messages'].append({'role': 'assistant', 'content': data['completion_ids']}) - processed_assistant_batch.append(data) - batch_encoded_inputs = [template.encode(infer_request) for infer_request in processed_assistant_batch] + processed_batch = [ + replace_assistant_response_with_ids(data['messages'], data['completion_ids']) + if 'completion_ids' in data else data for data in batch + ] + batch_encoded_inputs = [template.encode(data) for data in processed_batch] batch_encoded_inputs = to_device(template.data_collator(batch_encoded_inputs), self.model.device) # Process labels and masks diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py index 15a4b7301e..79989fde68 100644 --- a/swift/trainers/rlhf_trainer/utils.py +++ b/swift/trainers/rlhf_trainer/utils.py @@ -3,7 +3,7 @@ import time from contextlib import contextmanager from types import MethodType -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, List, Optional, Union import torch import torch.nn.functional as F @@ -18,6 +18,9 @@ if is_swanlab_available(): import swanlab +if 
TYPE_CHECKING: + from swift.llm.utils import Messages + def round_robin(num_reqs, num_workers): """Distribute requests evenly across workers using round-robin algorithm. @@ -228,3 +231,49 @@ def entropy_from_logits(logits, chunk_size: int = 1) -> torch.Tensor: chunk_entropy = -(torch.exp(logps) * logps).sum(-1) per_token_entropies.append(chunk_entropy) return torch.cat(per_token_entropies, dim=0) + + +def replace_assistant_response_with_ids(messages: 'Messages', completion_ids: List[Union[int, + List[int]]]) -> 'Messages': + """ + Replaces the content of assistant messages with the provided completion IDs. + + This function processes messages in reverse order and replaces the content of + assistant messages with the given completion IDs. If completion_ids is a flat + list of integers, it will be treated as a single completion sequence. + + Args: + messages: List of message dictionaries containing conversation history. + completion_ids: Either: + - A single list of token IDs (e.g., [1, 2, 3]) + - A list of completion sequences (e.g., [[1, 2], [3, 4]]) + + Returns: + The modified messages list with assistant responses replaced by token IDs. + + Example: + >>> messages = [{'role': 'user', 'content': 'Hello'}, + ... {'role': 'assistant', 'content': 'Hi there'}] + >>> replace_assistant_response_with_ids(messages, [1, 2, 3]) + [{'role': 'user', 'content': 'Hello'}, + {'role': 'assistant', 'content': [1, 2, 3]}] + """ + # Normalize input to always be list of lists + if isinstance(completion_ids[0], int): + completion_ids = [completion_ids] + + remaining_completions = len(completion_ids) + completion_index = 0 + + for message in reversed(messages): + if message['role'] != 'assistant': + continue + + if completion_index >= remaining_completions: + break + + # Assign completion IDs (starting from last) + message['content'] = completion_ids[-1 - completion_index] + completion_index += 1 + + return messages From ebf1b35f05994f1d4fd856754cd60374f474c977 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 4 Aug 2025 11:52:32 +0800 Subject: [PATCH 07/26] fix adv --- swift/trainers/rlhf_trainer/grpo_trainer.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 4a516f79ef..9b14dd100c 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -71,16 +71,6 @@ def patched_len(self) -> int: RepeatSampler.__len__ = patched_len RepeatSampler.old_len_func = origin_len_func -""" -Refactor: - Rollout 返回结果 message.content -> completion_ids - 设置 request_config.return_details - 获取 completion_ids - 修改template.encode 逻辑:只encode prompt 部分 - 多轮和 loss_scale 可能会比较麻烦 - - -""" class GRPOCallback(TrainerCallback): @@ -1770,6 +1760,7 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non 'completion': list(self._textual_logs['completion'])[:seen_nums], **{k: list(v)[:seen_nums] for k, v in self._textual_logs['rewards'].items()}, + 'advantage': list(self._logs['advantages'])[:seen_nums], } if self.use_gym_env: table['trajactory_info'] = self._textual_logs['trajactory_info'] From ffcd9b4d84566edfa0a8bfbb1e3f9394efc886bb Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 4 Aug 2025 19:30:07 +0800 Subject: [PATCH 08/26] wip --- swift/llm/infer/protocol.py | 5 + swift/plugin/__init__.py | 8 +- swift/plugin/multi_turn.py | 318 +++++++++++++++++++- swift/trainers/rlhf_trainer/grpo_trainer.py | 1 + 4 files changed, 321 insertions(+), 11 
deletions(-) diff --git a/swift/llm/infer/protocol.py b/swift/llm/infer/protocol.py index 154f767eed..148d197b20 100644 --- a/swift/llm/infer/protocol.py +++ b/swift/llm/infer/protocol.py @@ -321,6 +321,11 @@ class CompletionResponseChoice: logprobs: Optional[Dict[str, List[Dict[str, Any]]]] = None +class RolloutOutput: + results: List[ChatCompletionResponse] + extra_info: Dict[str, Any] + + @dataclass class ChatCompletionResponse: model: str diff --git a/swift/plugin/__init__.py b/swift/plugin/__init__.py index 4dc40bcf58..d396a5690d 100644 --- a/swift/plugin/__init__.py +++ b/swift/plugin/__init__.py @@ -15,8 +15,8 @@ from .orm import orms, ORM from .multi_turn import multi_turns from .rm_plugin import rm_plugins - from .env import envs - from .context_manager import context_managers + from .env import envs, Env + from .context_manager import context_managers, ContextManager else: _import_structure = { @@ -31,8 +31,8 @@ 'orm': ['orms', 'ORM'], 'multi_turn': ['multi_turns'], 'rm_plugin': ['rm_plugins'], - 'env': ['env'], - 'context_manager': ['context_managers'] + 'env': ['envs', 'Env'], + 'context_manager': ['context_managers', 'ContextManager'], } import sys diff --git a/swift/plugin/multi_turn.py b/swift/plugin/multi_turn.py index faa1009780..fcb3aa4431 100644 --- a/swift/plugin/multi_turn.py +++ b/swift/plugin/multi_turn.py @@ -1,23 +1,204 @@ -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union +from abc import ABC +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +from swift.plugin import ContextManager, Env, context_managers, envs +import asyncio +from copy import deepcopy if TYPE_CHECKING: - from swift.llm.infer.protocol import RolloutResponseChoice + from swift.llm.infer.protocol import ChatCompletionResponse, RolloutResponseChoice, RequestConfig, GymRolloutResponseChoice from swift.llm.template import RolloutInferRequest + from swift.llm.infer.infer_engine import GRPOVllmEngine + from swift.llm.utils import Messages + + +def remove_response(messages: 'Messages') -> Optional[str]: + last_role = messages[-1]['role'] if messages else None + if last_role == 'assistant': + return messages.pop()['content'] + +class RolloutOutput: + results: 'ChatCompletionResponse' + # multi turn rollout + messages: Optional['Messages'] # history of messages for the rollout + response_token_ids: Optional[List[List[int]]] # response token ids for each rollout turn + response_loss_scale: Optional[List[List[int]]] # response loss scale for each rollout turn + extra_info: Optional[Dict[str, Any]] # make sure serializable -class MultiTurnScheduler(ABC): - def __init__(self, max_turns: Optional[int] = None, *args, **kwargs): +class RolloutScheduler(ABC): + # Single Turn Rollout Scheduler + def __init__(self, infer_engine: 'GRPOVllmEngine', max_turns: Optional[int] = None, *args, **kwargs): + self.infer_engine = infer_engine self.max_turns = max_turns - @abstractmethod + async def async_infer(self, + infer_requests: List[Union['RolloutInferRequest', Dict[str, Any]]], + request_config: 'RequestConfig', + *, + use_tqdm: Optional[bool] = None, + **kwargs) -> List['ChatCompletionResponse']: + assert request_config.n == 1 + + async def _infer_async_single(infer_request: Union['RolloutInferRequest', Dict[str, Any]], + request_config: 'RequestConfig', **kwargs): + from swift.llm.infer.protocol import RolloutInferRequest + if isinstance(infer_request, Dict): + infer_request = RolloutInferRequest(**infer_request) + + return await 
self.run(infer_request, request_config, **kwargs) + + tasks = [_infer_async_single(infer_request, request_config, **kwargs) for infer_request in infer_requests] + if use_tqdm is None: + use_tqdm = len(infer_requests) > 1 + return await self.infer_engine._batch_infer_stream(tasks, request_config.stream, use_tqdm) + + async def run(self, infer_request: 'RolloutInferRequest', request_config: 'RequestConfig', + **kwargs) -> 'RolloutOutput': + result: 'ChatCompletionResponse' = await self.infer_engine.infer_async(infer_request, request_config, **kwargs) + response_token_ids = result.choices[0].token_ids + response_loss_scale = [1] * len(response_token_ids) + return RolloutOutput( + results=result, + messages=infer_request.messages, + response_token_ids=[response_token_ids], + response_loss_scale=[response_loss_scale], + extra_info={'num_turns': 1}) + + +class MultiTurnScheduler(RolloutScheduler, ABC): + """ + Abstract base class for multi-turn rollout scheduling. + + Provides default implementation for multi-turn conversation management with two customization approaches: + + 1. FULL CUSTOMIZATION: + Override the `run()` method to implement completely custom multi-turn logic. + - Gives full control over the rollout process + - Must handle all turn management and termination logic + + 2. PARTIAL CUSTOMIZATION: + Implement the required `step()` method and optionally override `check_finished()` + - Uses MultiTurnScheduler's run() method infrastructure + - Only need to implement turn transition logic in step() + - Optionally customize termination conditions + + Note: You must implement at least one of these approaches in your subclass. + + Options: + - If each round's response token ids are included in the RolloutOutput, + the Trainer can skip encoding the completion text into token_ids when calculating loss. + This avoids potential training inconsistencies due to asymmetric encode/decode behavior. + See: https://github.com/0russwest0/Agent-R1/issues/30#issuecomment-2826155367 + + - If both response token ids and response loss_scale are returned in the RolloutOutput, + you can manually control the loss mask by specifying a loss_scale value for each token. + The Trainer will use the provided loss_scale values directly when computing the loss. + Note: Returning loss_scale requires that response token ids are also returned, + as the two must be aligned in length for correct loss computation. + + Loss mask configuration: + During rollout, some parts of the completion (e.g., environment observations embedded in completion) + may need to be masked out from loss computation. + There are two supported strategies: + + 1. Use the built-in `loss_scale` parameter in ms-swift and do not return response token ids. + 2. Return response token ids along with a corresponding response loss scale tensor (of equal length) to indicate the loss mask for each token. 
+ """ + + async def run(self, infer_request: 'RolloutInferRequest', request_config: 'RequestConfig', + **kwargs) -> Union['RolloutOutput', List['RolloutOutput']]: + current_request = infer_request + current_turn = 1 + info_dict = {} + while True: + messages = current_request.messages + if current_turn == 1 or not messages[-1]['content']: + # If it's the first turn or the last message content is empty(dummy), remove the response + remove_response(messages) + + result: 'ChatCompletionResponse' = await self.infer_engine.infer_async(current_request, request_config, + **kwargs) + result_choice: 'RolloutResponseChoice' = result.choices[0] + + completion = result_choice.message.content + if messages[-1]['role'] == 'assistant': + messages[-1]['content'] += completion + else: + messages.append({'role': 'assistant', 'content': completion}) + + should_stop = self.check_finished(current_request, result_choice, current_turn) + + if self.max_turns: + should_stop = should_stop or (current_turn >= self.max_turns) + + if should_stop: + result_choice.messages = messages + info_dict['num_turns'] = current_turn + for key, value in info_dict.items(): + if hasattr(result_choice, key): + setattr(result_choice, key, value) + else: + result_choice.multi_turn_infos[key] = value + result_choice.process_images() + return result + + ret = self.step(current_request, result_choice, current_turn) + if isinstance(ret, tuple): + current_request, info_dict = ret + else: + current_request = ret + info_dict = {} + assert isinstance(current_request, RolloutInferRequest) + if current_request.messages[-1]['role'] == 'assistant': + # Add a dummy response to allow engine to continue generating + current_request.messages.append({'role': 'assistant', 'content': None}) + + current_turn += 1 + def step(self, infer_request: 'RolloutInferRequest', result: 'RolloutResponseChoice', current_turn: int) -> Union['RolloutInferRequest', Tuple['RolloutInferRequest', Dict]]: - pass + """ + Handles transition between conversation turns. + + Args: + infer_request: Current inference request + result: Response from current turn + current_turn: Current turn number + + Returns: + Either: + - The next inference request, OR + - A tuple of (next_request, info_dict) where info_dict contains + additional metadata to be stored in the final result + """ + raise NotImplementedError( + 'Please implement the `step` method in your MultiTurnScheduler subclass, or override the `run` method.') def check_finished(self, infer_request: 'RolloutInferRequest', result: 'RolloutResponseChoice', current_turn: int) -> bool: + """ + Default termination logic for checking if a multi-turn rollout should end. + + This method is invoked by: + - The base class MultiTurnScheduler.run() method, OR + - Custom run() methods when explicitly called + + Note: This is the default implementation that can be overridden by subclasses for custom termination logic. + + Termination Conditions: + 1. When response hits length limit (finish_reason == 'length') + 2. 
When conversation reaches max_turns (if max_turns is set) + + Args: + infer_request: The inference request object + result: Contains generation results including finish_reason + current_turn: Current conversation turn count + + Returns: + bool: True to terminate conversation, False to continue + """ if result.finish_reason == 'length': return True if self.max_turns and current_turn >= self.max_turns: @@ -25,6 +206,128 @@ def check_finished(self, infer_request: 'RolloutInferRequest', result: 'RolloutR return False +class GYMScheduler(RolloutScheduler): + + def __init__(self, + infer_engine: 'GRPOVllmEngine', + gym_env: Optional[str] = None, + context_manager_name: Optional[str] = None, + max_turns: Optional[int] = None, + **kwargs): + from swift.llm.infer.protocol import ChatCompletionResponse, RolloutResponseChoice, GymRolloutResponseChoice + super().__init__(infer_engine, max_turns, **kwargs) + self.gym_env_name = gym_env + self.context_manager_name = context_manager_name + + async def _create_env(self, env_config: Dict) -> Env: + """Create environment instance from configuration.""" + env_name = env_config.get('name', self.gym_env_name) + if env_name not in envs: + raise ValueError(f"Environment '{env_name}' not found. Available: {list(envs.keys())}") + return envs[env_name](env_config) + + async def _create_context_manager(self, ctx_config: Dict) -> ContextManager: + """Create context manager from configuration.""" + ctx_name = ctx_config.get('name', self.context_manager_name) + + if not ctx_name: + ctx_name = 'dummyContextManager' + + return context_managers[ctx_name](ctx_config) + + async def _close_env_async(self, env: Env): + """Safely close environment with async support.""" + try: + if hasattr(env, 'close') and asyncio.iscoroutinefunction(env.close): + await env.close() + elif hasattr(env, 'close'): + env.close() + except Exception: + # Handle any exceptions during environment closure + pass + + async def run(self, infer_request: 'RolloutInferRequest', request_config: 'RequestConfig', + **kwargs) -> 'ChatCompletionResponse': + """ + Execute the gym environment-based rollout: + 1. Initialize environment and context manager + 2. Run multi-turn interactions between LLM and environment + 3. 
Collect trajectory information and rewards + """ + # Extract configurations from request + env_config = infer_request.data_dict.get('env_config', {}) + ctx_config = infer_request.data_dict.get('ctx_config', {}) + + # Create environment and context manager + env = await self._create_env(env_config) + context_manager = await self._create_context_manager(ctx_config) + + try: + # Initialize environment + observation, info, system_message = await env.reset(infer_request) + + # Build initial messages + messages: 'Messages' = [] + if system_message: + messages.append({'role': 'system', 'content': system_message}) + messages.append({'role': 'user', 'content': observation}) + + current_request = deepcopy(infer_request) + current_turn = 1 + done = False + total_reward = 0.0 + step_rewards = [] + trajectory_id = f'{id(infer_request)}_{hash(str(infer_request))}' + trajectory_info = [info] + + while not done and current_turn <= (self.max_turns or float('inf')): + # Apply context management (e.g., history compression) + messages = context_manager.manage_context(messages, trajectory_id) + current_request.messages = messages + remove_response(current_request.messages) + + result: ChatCompletionResponse = await self.infer_async(current_request, request_config, **kwargs) + result_choice: RolloutResponseChoice = result.choices[0] + completion = result_choice.message.content + messages.append({'role': 'assistant', 'content': completion}) + + # Execute environment step + next_obs, reward, done, step_info = await env.step(deepcopy(messages)) + + # Update trajectory information + total_reward += reward + step_rewards.append(reward) + trajectory_info.append(step_info) + + # Prepare for next turn + if not done: + messages.append({'role': 'user', 'content': next_obs}) + current_request.messages = messages + current_turn += 1 + + # Build final response with gym-specific information + final_choice = GymRolloutResponseChoice( + index=result_choice.index, + message=result_choice.message, + finish_reason=result_choice.finish_reason, + logprobs=result_choice.logprobs, + messages=messages, + trajectory_id=trajectory_id, + total_reward=total_reward, + step_rewards=step_rewards, + trajectory_info=trajectory_info) + + return ChatCompletionResponse( + model=self.infer_engine.model_name, + choices=[final_choice], + usage=result.usage, + id=f'gym_{trajectory_id}') + + finally: + # Ensure environment is properly closed + await self._close_env_async(env) + + class MathTipsScheduler(MultiTurnScheduler): tips_prompt = 'But wait... 
It seems I made a mistake,' @@ -98,6 +401,7 @@ def step( multi_turns = { + 'base_scheduler': RolloutScheduler, 'math_tip_trick': MathTipsScheduler, 'math_tip_trick_multi_turn': MathTipsMultiTurnScheduler, } diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 9b14dd100c..74666725ce 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -801,6 +801,7 @@ def _infer_single_or_multi_turn(self, if 'completion_ids' not in current_input: current_input['completion_ids'] = [] current_input['completion_ids'].append(choice.token_ids) + cnt += 1 current_inputs.append(current_input) From 05ae14337205cc40c598d64e32243b3b015da1fa Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 4 Aug 2025 19:31:51 +0800 Subject: [PATCH 09/26] wip --- swift/plugin/multi_turn.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/swift/plugin/multi_turn.py b/swift/plugin/multi_turn.py index fcb3aa4431..3664ad4450 100644 --- a/swift/plugin/multi_turn.py +++ b/swift/plugin/multi_turn.py @@ -109,6 +109,7 @@ class MultiTurnScheduler(RolloutScheduler, ABC): async def run(self, infer_request: 'RolloutInferRequest', request_config: 'RequestConfig', **kwargs) -> Union['RolloutOutput', List['RolloutOutput']]: + from swift.llm.template import RolloutInferRequest current_request = infer_request current_turn = 1 info_dict = {} @@ -214,7 +215,6 @@ def __init__(self, context_manager_name: Optional[str] = None, max_turns: Optional[int] = None, **kwargs): - from swift.llm.infer.protocol import ChatCompletionResponse, RolloutResponseChoice, GymRolloutResponseChoice super().__init__(infer_engine, max_turns, **kwargs) self.gym_env_name = gym_env self.context_manager_name = context_manager_name @@ -248,6 +248,7 @@ async def _close_env_async(self, env: Env): async def run(self, infer_request: 'RolloutInferRequest', request_config: 'RequestConfig', **kwargs) -> 'ChatCompletionResponse': + from swift.llm.infer.protocol import ChatCompletionResponse, GymRolloutResponseChoice """ Execute the gym environment-based rollout: 1. 
Initialize environment and context manager @@ -286,8 +287,8 @@ async def run(self, infer_request: 'RolloutInferRequest', request_config: 'Reque current_request.messages = messages remove_response(current_request.messages) - result: ChatCompletionResponse = await self.infer_async(current_request, request_config, **kwargs) - result_choice: RolloutResponseChoice = result.choices[0] + result: 'ChatCompletionResponse' = await self.infer_async(current_request, request_config, **kwargs) + result_choice: 'RolloutResponseChoice' = result.choices[0] completion = result_choice.message.content messages.append({'role': 'assistant', 'content': completion}) From fdb1ae8568977b41ead2d3dc8bbceef78c775065 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 4 Aug 2025 20:15:40 +0800 Subject: [PATCH 10/26] wip --- swift/llm/infer/rollout.py | 13 +++++++++++-- swift/plugin/multi_turn.py | 1 - 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/swift/llm/infer/rollout.py b/swift/llm/infer/rollout.py index 0a02d0f038..323dbda461 100644 --- a/swift/llm/infer/rollout.py +++ b/swift/llm/infer/rollout.py @@ -22,6 +22,8 @@ from swift.llm import RolloutArguments, SwiftPipeline from swift.llm.template.template_inputs import RolloutInferRequest +from swift.plugin.multi_turn import multi_turns, RolloutScheduler + from swift.utils import get_logger from .infer_engine import GRPOVllmEngine, InferClient from .protocol import InitCommunicatorRequest, RequestConfig, UpdateWeightsRequest @@ -66,6 +68,14 @@ def llm_worker(args: RolloutArguments, data_parallel_rank: int, master_port: int os.environ['VLLM_DP_SIZE'] = str(args.vllm_data_parallel_size) os.environ['VLLM_DP_MASTER_PORT'] = str(master_port) engine = SwiftRolloutDeploy.get_infer_engine(args, template=args.get_template(None)) + if args.multi_turn_scheduler: + if args.multi_turn_scheduler not in multi_turns: + raise ValueError(f"Multi-turn scheduler '{args.multi_turn_scheduler}' not found in multi_turns.") + rollout_engine: RolloutScheduler = multi_turns[args.multi_turn_scheduler](engine, args.max_turns) + if not rollout_engine: + raise ValueError(f"Failed to initialize multi-turn scheduler '{args.multi_turn_scheduler}'.") + else: + rollout_engine = engine # Send ready signal to parent process connection.send({'status': 'ready'}) @@ -81,7 +91,7 @@ def llm_worker(args: RolloutArguments, data_parallel_rank: int, master_port: int if command['type'] in ['call', 'fire_and_forget']: method_name = command['method'] args, kwargs = command.get('args', ()), command.get('kwargs', {}) - method = getattr(engine, method_name, None) or getattr(engine.engine, method_name, None) + method = getattr(rollout_engine, method_name, None) or getattr(rollout_engine.engine, method_name, None) or result = method(*args, **kwargs) if command['type'] == 'call': connection.send(result) @@ -193,7 +203,6 @@ def get_infer_engine(args: RolloutArguments, template=None, **kwargs): 'torch_dtype': args.torch_dtype, 'template': template, 'use_async_engine': args.vllm_use_async_engine, - 'multi_turn_scheduler': args.multi_turn_scheduler, 'max_turns': args.max_turns, 'use_gym_env': args.use_gym_env, 'gym_env': args.gym_env, diff --git a/swift/plugin/multi_turn.py b/swift/plugin/multi_turn.py index 3664ad4450..001ca57725 100644 --- a/swift/plugin/multi_turn.py +++ b/swift/plugin/multi_turn.py @@ -328,7 +328,6 @@ async def run(self, infer_request: 'RolloutInferRequest', request_config: 'Reque # Ensure environment is properly closed await self._close_env_async(env) - class 
MathTipsScheduler(MultiTurnScheduler): tips_prompt = 'But wait... It seems I made a mistake,' From 725e1f6d16ea7fec522ec87d0e82f220fb40ec18 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 5 Aug 2025 17:21:01 +0800 Subject: [PATCH 11/26] wip --- .../infer/infer_engine/grpo_vllm_engine.py | 178 +---------- swift/llm/infer/protocol.py | 31 +- swift/llm/infer/rollout.py | 15 +- swift/plugin/multi_turn.py | 292 +++++++++++------- swift/trainers/rlhf_trainer/vllm_client.py | 17 +- 5 files changed, 207 insertions(+), 326 deletions(-) diff --git a/swift/llm/infer/infer_engine/grpo_vllm_engine.py b/swift/llm/infer/infer_engine/grpo_vllm_engine.py index 9831b3301b..100a4126fd 100644 --- a/swift/llm/infer/infer_engine/grpo_vllm_engine.py +++ b/swift/llm/infer/infer_engine/grpo_vllm_engine.py @@ -12,8 +12,7 @@ from swift.plugin.context_manager import ContextManager, context_managers from swift.plugin.env import Env, envs from swift.plugin.multi_turn import MultiTurnScheduler -from ..protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage, GymRolloutResponseChoice, - RequestConfig, RolloutResponseChoice) +from ..protocol import ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage, RequestConfig from .utils import AdapterRequest try: @@ -161,161 +160,6 @@ def infer( adapter_request=adapter_request, ) - async def async_infer(self, - infer_requests: List[Union[RolloutInferRequest, Dict[str, Any]]], - request_config: Optional[RequestConfig] = None, - metrics: Optional[List[Metric]] = None, - *, - use_tqdm: Optional[bool] = None, - **kwargs) -> List[ChatCompletionResponse]: - if request_config is None: - request_config = RequestConfig() - assert request_config.n == 1 - - async def _infer_async_single(infer_request: Union[RolloutInferRequest, Dict[str, Any]], - request_config: Optional[RequestConfig] = None, - **kwargs): - if isinstance(infer_request, Dict): - infer_request = RolloutInferRequest(**infer_request) - - # Route to appropriate sampling controller - if self.use_gym_env: - return await self._gym_sampling_controller(infer_request, request_config, **kwargs) - else: - return await self._multi_turn_sampling_controller(infer_request, request_config, **kwargs) - - tasks = [_infer_async_single(infer_request, request_config, **kwargs) for infer_request in infer_requests] - if use_tqdm is None: - use_tqdm = len(infer_requests) > 1 - return await self._batch_infer_stream(tasks, request_config.stream, use_tqdm, metrics) - - async def _gym_sampling_controller(self, infer_request: RolloutInferRequest, request_config: RequestConfig, - **kwargs) -> ChatCompletionResponse: - """Gym environment-based sampling controller.""" - # Create environment and context manager - env_config = infer_request.data_dict.get('env_config', {}) - env = self._create_env(env_config) - ctx_config = infer_request.data_dict.get('ctx_config', {}) - context_manager = self._create_context_manager(ctx_config) - - try: - # Environment reset - observation, info, system_message = await env.reset(infer_request) - - # Initialize conversation - messages = [] - if system_message: - messages.append({'role': 'system', 'content': system_message}) - messages.append({'role': 'user', 'content': observation}) - - current_request = deepcopy(infer_request) - current_turn = 1 - done = False - total_reward = 0.0 - step_rewards = [] - trajectory_id = f'{id(infer_request)}_{hash(str(infer_request))}' - trajectory_info = [info] - - while True: - # Apply context management - messages = 
context_manager.manage_context(messages, trajectory_id) - current_request.messages = messages - # Remove any previous assistant response for generation - InferRequest.remove_response(current_request.messages) - - # Generate LLM response - result: ChatCompletionResponse = await self.infer_async(current_request, request_config, **kwargs) - result_choice: RolloutResponseChoice = result.choices[0] - - completion = result_choice.message.content - messages.append({'role': 'assistant', 'content': completion}) - - # Environment step - next_observation, reward, done, step_info = await env.step(deepcopy(messages)) - - # Accumulate rewards - total_reward += reward - step_rewards.append(reward) - trajectory_info.append(step_info) - - if done or current_turn > self.max_turns: - break - - messages.append({'role': 'user', 'content': next_observation}) - current_request.messages = messages - current_turn += 1 - - # Create final result with gym-specific information - final_choice = GymRolloutResponseChoice( - index=result_choice.index, - message=result_choice.message, - finish_reason=result_choice.finish_reason, - logprobs=result_choice.logprobs, - messages=messages, - trajectory_id=trajectory_id, - total_reward=total_reward, - step_rewards=step_rewards, - trajectory_info=trajectory_info) - - return ChatCompletionResponse( - model=self.model_name, choices=[final_choice], usage=result.usage, id=f'gym_{trajectory_id}') - - finally: - await self._close_env_async(env) - - async def _multi_turn_sampling_controller(self, infer_request: RolloutInferRequest, request_config: RequestConfig, - **kwargs) -> ChatCompletionResponse: - """Multi-turn scheduler-based sampling controller.""" - current_request = infer_request - current_turn = 1 - info_dict = {} - while True: - messages = current_request.messages - if current_turn == 1 or not messages[-1]['content']: - # If it's the first turn or the last message content is empty(dummy), remove the response - InferRequest.remove_response(messages) - - result: ChatCompletionResponse = await self.infer_async(current_request, request_config, **kwargs) - result_choice: RolloutResponseChoice = result.choices[0] - - completion = result_choice.message.content - if messages[-1]['role'] == 'assistant': - messages[-1]['content'] += completion - else: - messages.append({'role': 'assistant', 'content': completion}) - - if self.multi_turn_scheduler: - should_stop = self.multi_turn_scheduler.check_finished(current_request, result_choice, current_turn) - else: - should_stop = True - - if self.max_turns: - should_stop = should_stop or (current_turn >= self.max_turns) - - if should_stop: - result_choice.messages = messages - info_dict['num_turns'] = current_turn - for key, value in info_dict.items(): - if hasattr(result_choice, key): - setattr(result_choice, key, value) - else: - result_choice.multi_turn_infos[key] = value - result_choice.process_images() - return result - - ret = self.multi_turn_scheduler.step(current_request, result_choice, current_turn) - if isinstance(ret, tuple): - current_request, info_dict = ret - else: - current_request = ret - info_dict = {} - assert isinstance(current_request, RolloutInferRequest) - if current_request.messages[-1]['role'] == 'assistant': - # Add a dummy response to allow engine to continue generating - current_request.messages.append({'role': 'assistant', 'content': None}) - - current_turn += 1 - async def _batch_infer_stream(self, tasks, stream: bool = True, @@ -338,17 +182,7 @@ async def _new_run(task): new_tasks = [_new_run(task) for task in 
tasks] return await self.batch_run(new_tasks) - async def _close_env_async(self, env: Env): - """Asynchronously close environment.""" - try: - if hasattr(env, 'close') and asyncio.iscoroutinefunction(env.close): - await env.close() - elif hasattr(env, 'close'): - env.close() - except Exception: - pass - - def _create_chat_completion_response(self, result, template: Template, request_config, + def _create_chat_completion_response(self, result: 'ChatCompletionResponse', template: Template, request_config, request_id) -> ChatCompletionResponse: assert result is not None num_generated_tokens = sum(len(output.token_ids) for output in result.outputs) @@ -359,13 +193,7 @@ def _create_chat_completion_response(self, result, template: Template, request_c response = template.decode(output.token_ids) logprobs = self._get_logprobs(output.logprobs, output.token_ids, request_config.top_logprobs) toolcall = self._get_toolcall(response, template) - - if self.use_gym_env: - choice_cls = GymRolloutResponseChoice - elif self.use_async_engine: - choice_cls = RolloutResponseChoice - else: - choice_cls = ChatCompletionResponseChoice + choice_cls = ChatCompletionResponseChoice token_ids = template.skip_stop_tokens(output.token_ids) if request_config.return_details else None choice = choice_cls( diff --git a/swift/llm/infer/protocol.py b/swift/llm/infer/protocol.py index 148d197b20..5d4ed9c827 100644 --- a/swift/llm/infer/protocol.py +++ b/swift/llm/infer/protocol.py @@ -298,19 +298,18 @@ class EmbeddingResponse: created: int = field(default_factory=lambda: int(time.time())) -@dataclass -class RolloutResponseChoice(ChatCompletionResponseChoice): - messages: Optional[Messages] = None - images: Optional[List[str]] = None - multi_turn_infos: Dict[str, Any] = field(default_factory=dict) - +# @dataclass +# class RolloutResponseChoice(ChatCompletionResponseChoice): +# messages: Optional[Messages] = None +# images: Optional[List[str]] = None +# multi_turn_infos: Dict[str, Any] = field(default_factory=dict) -@dataclass -class GymRolloutResponseChoice(RolloutResponseChoice): - trajectory_id: str = None - total_reward: float = 0.0 - step_rewards: List[float] = None - trajectory_info: List[Dict[str, Any]] = None +# @dataclass +# class GymRolloutResponseChoice(RolloutResponseChoice): +# trajectory_id: str = None +# total_reward: float = 0.0 +# step_rewards: List[float] = None +# trajectory_info: List[Dict[str, Any]] = None @dataclass @@ -321,15 +320,15 @@ class CompletionResponseChoice: logprobs: Optional[Dict[str, List[Dict[str, Any]]]] = None -class RolloutOutput: - results: List[ChatCompletionResponse] - extra_info: Dict[str, Any] +# class RolloutOutput: +# results: List[ChatCompletionResponse] +# extra_info: Dict[str, Any] @dataclass class ChatCompletionResponse: model: str - choices: List[Union[ChatCompletionResponseChoice, RolloutResponseChoice, GymRolloutResponseChoice]] + choices: List[ChatCompletionResponseChoice] usage: UsageInfo id: str = field(default_factory=lambda: f'chatcmpl-{random_uuid()}') object: str = 'chat.completion' diff --git a/swift/llm/infer/rollout.py b/swift/llm/infer/rollout.py index 323dbda461..bfbb6a6964 100644 --- a/swift/llm/infer/rollout.py +++ b/swift/llm/infer/rollout.py @@ -22,8 +22,7 @@ from swift.llm import RolloutArguments, SwiftPipeline from swift.llm.template.template_inputs import RolloutInferRequest -from swift.plugin.multi_turn import multi_turns, RolloutScheduler - +from swift.plugin.multi_turn import RolloutScheduler, multi_turns from swift.utils import get_logger from 
.infer_engine import GRPOVllmEngine, InferClient from .protocol import InitCommunicatorRequest, RequestConfig, UpdateWeightsRequest @@ -44,7 +43,6 @@ --vllm_tensor_parallel_size xxx \ --vllm_data_parallel_size xxx \ --vllm_use_async_engine true/false \ - --use_gym_env true/false \ --other_vllm_arguments Note: @@ -71,7 +69,13 @@ def llm_worker(args: RolloutArguments, data_parallel_rank: int, master_port: int if args.multi_turn_scheduler: if args.multi_turn_scheduler not in multi_turns: raise ValueError(f"Multi-turn scheduler '{args.multi_turn_scheduler}' not found in multi_turns.") - rollout_engine: RolloutScheduler = multi_turns[args.multi_turn_scheduler](engine, args.max_turns) + scheduler_cls = multi_turns[args.multi_turn_scheduler] + + kwargs = {} + if 'tokenizer' in list(inspect.signature(scheduler_cls.__init__).parameters): + kwargs['tokenizer'] = engine.default_template.tokenizer + + rollout_engine: RolloutScheduler = scheduler_cls(engine, args.max_turns, **kwargs) if not rollout_engine: raise ValueError(f"Failed to initialize multi-turn scheduler '{args.multi_turn_scheduler}'.") else: rollout_engine = engine @@ -91,7 +95,7 @@ def llm_worker(args: RolloutArguments, data_parallel_rank: int, master_port: int if command['type'] in ['call', 'fire_and_forget']: method_name = command['method'] args, kwargs = command.get('args', ()), command.get('kwargs', {}) - method = getattr(rollout_engine, method_name, None) or getattr(rollout_engine.engine, method_name, None) or + method = getattr(rollout_engine, method_name, None) or getattr(rollout_engine.engine, method_name, None) result = method(*args, **kwargs) if command['type'] == 'call': connection.send(result) @@ -119,6 +123,7 @@ async def async_llm_worker(args: RolloutArguments, data_parallel_rank: int, mast method_name = command['method'] args, kwargs = command.get('args', ()), command.get('kwargs', {}) method = getattr(engine, method_name, None) or getattr(engine.engine, method_name, None) + try: result = await method(*args, **kwargs) except Exception: diff --git a/swift/plugin/multi_turn.py b/swift/plugin/multi_turn.py index 001ca57725..fad32df66b 100644 --- a/swift/plugin/multi_turn.py +++ b/swift/plugin/multi_turn.py @@ -1,15 +1,24 @@ +import asyncio from abc import ABC +from copy import deepcopy from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from pydantic import BaseModel, Field + from swift.plugin import ContextManager, Env, context_managers, envs -import asyncio -from copy import deepcopy if TYPE_CHECKING: - from swift.llm.infer.protocol import ChatCompletionResponse, RolloutResponseChoice, RequestConfig, GymRolloutResponseChoice + from swift.llm.infer.protocol import ChatCompletionResponse, ChatCompletionResponseChoice, RequestConfig from swift.llm.template import RolloutInferRequest from swift.llm.infer.infer_engine import GRPOVllmEngine from swift.llm.utils import Messages +""" + + 1. Change the return value of the step method: unify it as a Dict; the run method needs matching changes + 2. response_id and response_loss_scale are returned in the step method's dict; lists of them are maintained in run + 3. 
+ +""" def remove_response(messages: 'Messages') -> Optional[str]: @@ -18,13 +27,28 @@ def remove_response(messages: 'Messages') -> Optional[str]: return messages.pop()['content'] -class RolloutOutput: +class RolloutOutput(BaseModel): results: 'ChatCompletionResponse' # multi turn rollout - messages: Optional['Messages'] # history of messages for the rollout - response_token_ids: Optional[List[List[int]]] # response token ids for each rollout turn - response_loss_scale: Optional[List[List[int]]] # response loss scale for each rollout turn - extra_info: Optional[Dict[str, Any]] # make sure serializable + messages: Optional['Messages'] = None # Conversation history for the final rollout (required for multi-turn) + response_token_ids: Optional[List[List[int]]] = None # (optional) Token IDs generated at each rollout turn + response_loss_mask: Optional[List[List[int]]] = None # (optional) Loss mask for each rollout turn + extra_info: Dict[str, + Any] = Field(default_factory=dict) # Additional rollout information; must be JSON-serializable + + def model_post_init(self, __context): + # Ensure multimodal data in extra_info is serializable (e.g., images to base64) + super().model_post_init(__context) + self.mminfo_to_serializable() + + def mminfo_to_serializable(self): + mm_keys = ['images', 'audios', 'videos'] + + for key, value in self.extra_info.items(): + if key in mm_keys: + from swift.llm.infer.protocol import MultiModalRequestMixin + # Convert multimodal content to base64 for serialization + self.extra_info[key] = MultiModalRequestMixin.to_base64(value) class RolloutScheduler(ABC): @@ -43,7 +67,7 @@ async def async_infer(self, async def _infer_async_single(infer_request: Union['RolloutInferRequest', Dict[str, Any]], request_config: 'RequestConfig', **kwargs): - from swift.llm.infer.protocol import RolloutInferRequest + from swift.llm.template import RolloutInferRequest if isinstance(infer_request, Dict): infer_request = RolloutInferRequest(**infer_request) @@ -58,14 +82,32 @@ async def run(self, infer_request: 'RolloutInferRequest', request_config: 'Reque **kwargs) -> 'RolloutOutput': result: 'ChatCompletionResponse' = await self.infer_engine.infer_async(infer_request, request_config, **kwargs) response_token_ids = result.choices[0].token_ids - response_loss_scale = [1] * len(response_token_ids) + response_loss_mask = [1] * len(response_token_ids) return RolloutOutput( results=result, messages=infer_request.messages, response_token_ids=[response_token_ids], - response_loss_scale=[response_loss_scale], + response_loss_mask=[response_loss_mask], extra_info={'num_turns': 1}) + def __getattr__(self, key: str): + try: + return object.__getattribute__(self, key) + except AttributeError: + pass + + try: + infer_engine = object.__getattribute__(self, 'infer_engine') + if hasattr(infer_engine, key): + return getattr(infer_engine, key) + + except AttributeError: + raise AttributeError(f'{type(self).__name__} object has no attribute {key}') + + @property + def engine(self): + return self.infer_engine + class MultiTurnScheduler(RolloutScheduler, ABC): """ @@ -87,32 +129,35 @@ class MultiTurnScheduler(RolloutScheduler, ABC): Note: You must implement at least one of these approaches in your subclass. Options: - - If each round's response token ids are included in the RolloutOutput, + - If each round's response_token_ids are included in the RolloutOutput, the Trainer can skip encoding the completion text into token_ids when calculating loss. 
This avoids potential training inconsistencies due to asymmetric encode/decode behavior. See: https://github.com/0russwest0/Agent-R1/issues/30#issuecomment-2826155367 - - If both response token ids and response loss_scale are returned in the RolloutOutput, - you can manually control the loss mask by specifying a loss_scale value for each token. - The Trainer will use the provided loss_scale values directly when computing the loss. - Note: Returning loss_scale requires that response token ids are also returned, + - If both response_token_ids and response_loss_mask are returned in the RolloutOutput, + you can manually control the loss mask for each token. + The Trainer will use the provided loss_mask values directly when computing the loss. + Note: Returning response_loss_mask requires that response_token_ids are also returned, as the two must be aligned in length for correct loss computation. + You can refer to MathTipsScheduler as an example of how to use response_token_ids and response_loss_mask. + Loss mask configuration: During rollout, some parts of the completion (e.g., environment observations embedded in completion) may need to be masked out from loss computation. There are two supported strategies: 1. Use the built-in `loss_scale` parameter in ms-swift and do not return response token ids. - 2. Return response token ids along with a corresponding response loss scale tensor (of equal length) to indicate the loss mask for each token. + 2. Return response_token_ids along with a corresponding response_loss_mask (of equal length) to indicate the loss mask for each token. # noqa """ async def run(self, infer_request: 'RolloutInferRequest', request_config: 'RequestConfig', **kwargs) -> Union['RolloutOutput', List['RolloutOutput']]: - from swift.llm.template import RolloutInferRequest current_request = infer_request current_turn = 1 info_dict = {} + total_response_ids = [] + total_response_loss_mask = [] while True: messages = current_request.messages if current_turn == 1 or not messages[-1]['content']: @@ -121,7 +166,7 @@ async def run(self, infer_request: 'RolloutInferRequest', request_config: 'Reque result: 'ChatCompletionResponse' = await self.infer_engine.infer_async(current_request, request_config, **kwargs) - result_choice: 'RolloutResponseChoice' = result.choices[0] + result_choice: 'ChatCompletionResponseChoice' = result.choices[0] completion = result_choice.message.content if messages[-1]['role'] == 'assistant': @@ -135,31 +180,39 @@ async def run(self, infer_request: 'RolloutInferRequest', request_config: 'Reque should_stop = should_stop or (current_turn >= self.max_turns) if should_stop: - result_choice.messages = messages info_dict['num_turns'] = current_turn for key, value in info_dict.items(): if hasattr(result_choice, key): setattr(result_choice, key, value) else: result_choice.multi_turn_infos[key] = value - result_choice.process_images() - return result + return RolloutOutput( + results=result, + messages=messages, + response_id=total_response_ids, + response_loss_mask=total_response_loss_mask, + extra_info=info_dict, + ) ret = self.step(current_request, result_choice, current_turn) - if isinstance(ret, tuple): - current_request, info_dict = ret - else: - current_request = ret - info_dict = {} - assert isinstance(current_request, RolloutInferRequest) + current_request: 'RolloutInferRequest' = ret['infer_request'] + return_token_id = False + if 'response_token_ids' in ret: + total_response_ids.append(ret['response_token_ids']) + return_token_id = True + if 
'response_loss_mask' in ret: + assert return_token_id, 'You must return response_token_ids if you want to return response_loss_mask' + assert len(ret['response_loss_mask']) == len(ret['response_token_ids']), \ + 'response_loss_mask must have the same length as response_token_ids' + total_response_loss_mask.append(ret['response_loss_mask']) if current_request.messages[-1]['role'] == 'assistant': # Add a dummy response to allow engine to continue generating current_request.messages.append({'role': 'assistant', 'content': None}) current_turn += 1 - def step(self, infer_request: 'RolloutInferRequest', result: 'RolloutResponseChoice', - current_turn: int) -> Union['RolloutInferRequest', Tuple['RolloutInferRequest', Dict]]: + def step(self, infer_request: 'RolloutInferRequest', result: 'ChatCompletionResponseChoice', + current_turn: int) -> Dict: """ Handles transition between conversation turns. @@ -169,15 +222,17 @@ def step(self, infer_request: 'RolloutInferRequest', result: 'RolloutResponseCho current_turn: Current turn number Returns: - Either: - - The next inference request, OR - - A tuple of (next_request, info_dict) where info_dict contains - additional metadata to be stored in the final result + Dict[str, Any]: A dictionary containing inference results with the following structure: + - infer_request (required): Main inference request object + - response_token_ids (Optional[List[List[int]]]): Token IDs of responses for each rollout turn + - response_loss_scale (Optional[List[List[int]]]): Loss scaling factors for responses in each rollout turn # noqa + - extra_info (Optional[Dict[str, Any]]): Additional metadata (must be serializable) + """ raise NotImplementedError( 'Please implement the `step` method in your MultiTurnScheduler subclass, or override the `run` method.') - def check_finished(self, infer_request: 'RolloutInferRequest', result: 'RolloutResponseChoice', + def check_finished(self, infer_request: 'RolloutInferRequest', result: 'ChatCompletionResponseChoice', current_turn: int) -> bool: """ Default termination logic for checking if a multi-turn rollout should end. @@ -207,6 +262,75 @@ def check_finished(self, infer_request: 'RolloutInferRequest', result: 'RolloutR return False +class MathTipsScheduler(MultiTurnScheduler): + tips_prompt = 'But wait... 
It seems I made a mistake,' + + def __init__(self, tokenizer, *args, **kwargs): + from .orm import MathAccuracy + self.tokenizer = tokenizer + super().__init__(*args, **kwargs) + self.acc_func = kwargs.get('acc_function', MathAccuracy()) + + def check_finished(self, infer_request: 'RolloutInferRequest', result: 'ChatCompletionResponseChoice', + current_turn: int) -> bool: + last_completion = infer_request.messages[-1]['content'] + # we only give tips once + if self.tips_prompt in last_completion: + return True + solution = infer_request.data_dict['solution'] + + acc = self.acc_func([last_completion], [solution])[0] + if acc == 1: + return True + + return super().check_finished(infer_request, result, current_turn) + + def step(self, infer_request: 'RolloutInferRequest', result: 'ChatCompletionResponseChoice', + current_turn: int) -> Dict: + completion = result.message.content + if '' in completion: + completion = completion[:completion.index('')] + if '' in completion: + completion = completion[:completion.index('')] + completion += self.tips_prompt + if infer_request.messages[-1]['role'] == 'assistant': + if not infer_request.messages[-1]['content']: + # Multi-turn continuation: pop the dummy input we add in last turn + infer_request.messages.pop(-1) + infer_request.messages[-1]['content'] = completion + else: + infer_request.messages.append({'role': 'assistant', 'content': completion}) + + return {'infer_request': infer_request} + + +class MathTipsMultiTurnScheduler(MultiTurnScheduler): + from .orm import MathAccuracy + tips_prompt = 'The answer is not correct, It seems You made a mistake, you need to recheck very carefully.' + acc_func = MathAccuracy() + + def check_finished(self, infer_request: 'RolloutInferRequest', result: 'ChatCompletionResponseChoice', + current_turn: int) -> bool: + + last_query = infer_request.messages[-2]['content'] + # we only give tips once + if self.tips_prompt in last_query: + return True + + completion = result.message.content + solution = infer_request.data_dict['solution'] + acc = self.acc_func([completion], [solution])[0] + if acc == 1: + return True + + return super().check_finished(infer_request, result, current_turn) + + def step(self, infer_request: 'RolloutInferRequest', result: 'ChatCompletionResponseChoice', + current_turn: int) -> Dict: + infer_request.messages.append({'role': 'user', 'content': self.tips_prompt}) + return infer_request + + class GYMScheduler(RolloutScheduler): def __init__(self, @@ -248,7 +372,7 @@ async def _close_env_async(self, env: Env): async def run(self, infer_request: 'RolloutInferRequest', request_config: 'RequestConfig', **kwargs) -> 'ChatCompletionResponse': - from swift.llm.infer.protocol import ChatCompletionResponse, GymRolloutResponseChoice + from swift.llm.infer.protocol import ChatCompletionResponse, ChatCompletionResponseChoice """ Execute the gym environment-based rollout: 1. 
Initialize environment and context manager @@ -288,7 +412,7 @@ async def run(self, infer_request: 'RolloutInferRequest', request_config: 'Reque remove_response(current_request.messages) result: 'ChatCompletionResponse' = await self.infer_async(current_request, request_config, **kwargs) - result_choice: 'RolloutResponseChoice' = result.choices[0] + result_choice: 'ChatCompletionResponseChoice' = result.choices[0] completion = result_choice.message.content messages.append({'role': 'assistant', 'content': completion}) @@ -307,101 +431,37 @@ async def run(self, infer_request: 'RolloutInferRequest', request_config: 'Reque current_turn += 1 # Build final response with gym-specific information - final_choice = GymRolloutResponseChoice( + final_choice = ChatCompletionResponseChoice( index=result_choice.index, message=result_choice.message, finish_reason=result_choice.finish_reason, - logprobs=result_choice.logprobs, - messages=messages, - trajectory_id=trajectory_id, - total_reward=total_reward, - step_rewards=step_rewards, - trajectory_info=trajectory_info) + logprobs=result_choice.logprobs) - return ChatCompletionResponse( + result = ChatCompletionResponse( model=self.infer_engine.model_name, choices=[final_choice], usage=result.usage, id=f'gym_{trajectory_id}') + return RolloutOutput( + results=result, + messages=messages, + extra_info={ + 'num_turns': current_turn, + 'trajectory_id': trajectory_id, + 'total_reward': total_reward, + 'step_rewards': step_rewards, + 'trajectory_info': trajectory_info + }) + finally: # Ensure environment is properly closed await self._close_env_async(env) -class MathTipsScheduler(MultiTurnScheduler): - tips_prompt = 'But wait... It seems I made a mistake,' - - def __init__(self, max_turns=None, *args, **kwargs): - from .orm import MathAccuracy - super().__init__(max_turns, *args, **kwargs) - self.acc_func = kwargs.get('acc_function', MathAccuracy()) - - def check_finished(self, infer_request: 'RolloutInferRequest', result: 'RolloutResponseChoice', - current_turn: int) -> bool: - last_completion = infer_request.messages[-1]['content'] - # we only give tips once - if self.tips_prompt in last_completion: - return True - solution = infer_request.data_dict['solution'] - - acc = self.acc_func([last_completion], [solution])[0] - if acc == 1: - return True - - return super().check_finished(infer_request, result, current_turn) - - def step(self, infer_request: 'RolloutInferRequest', result: 'RolloutResponseChoice', - current_turn: int) -> Union['RolloutInferRequest', Tuple['RolloutInferRequest', dict]]: - completion = result.message.content - if '' in completion: - completion = completion[:completion.index('')] - if '' in completion: - completion = completion[:completion.index('')] - completion += self.tips_prompt - if infer_request.messages[-1]['role'] == 'assistant': - if not infer_request.messages[-1]['content']: - # Multi-turn continuation: pop the dummy input we add in last turn - infer_request.messages.pop(-1) - infer_request.messages[-1]['content'] = completion - else: - infer_request.messages.append({'role': 'assistant', 'content': completion}) - - return infer_request - - -class MathTipsMultiTurnScheduler(MultiTurnScheduler): - from .orm import MathAccuracy - tips_prompt = 'The answer is not correct, It seems You made a mistake, you need to recheck very carefully.' 
- acc_func = MathAccuracy() - - def check_finished(self, infer_request: 'RolloutInferRequest', result: 'RolloutResponseChoice', - current_turn: int) -> bool: - - last_query = infer_request.messages[-2]['content'] - # we only give tips once - if self.tips_prompt in last_query: - return True - - completion = result.message.content - solution = infer_request.data_dict['solution'] - acc = self.acc_func([completion], [solution])[0] - if acc == 1: - return True - - return super().check_finished(infer_request, result, current_turn) - - def step( - self, - infer_request: 'RolloutInferRequest', - result: 'RolloutResponseChoice', - current_turn: int, - ) -> Union['RolloutInferRequest', Tuple['RolloutInferRequest', dict]]: - infer_request.messages.append({'role': 'user', 'content': self.tips_prompt}) - return infer_request - multi_turns = { 'base_scheduler': RolloutScheduler, 'math_tip_trick': MathTipsScheduler, 'math_tip_trick_multi_turn': MathTipsMultiTurnScheduler, + 'gym_scheduler': GYMScheduler, } diff --git a/swift/trainers/rlhf_trainer/vllm_client.py b/swift/trainers/rlhf_trainer/vllm_client.py index 1f56921cf4..ae542d6446 100644 --- a/swift/trainers/rlhf_trainer/vllm_client.py +++ b/swift/trainers/rlhf_trainer/vllm_client.py @@ -16,8 +16,7 @@ from transformers.utils import is_torch_cuda_available from swift.llm import AdapterRequest, RolloutInferRequest, Template -from swift.llm.infer.protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, GymRolloutResponseChoice, - RequestConfig, RolloutResponseChoice) +from swift.llm.infer.protocol import ChatCompletionResponse, ChatCompletionResponseChoice, RequestConfig from swift.plugin import Metric from swift.utils import is_trl_available, is_vllm_ascend_available, is_vllm_available @@ -280,17 +279,7 @@ def close_communicator(self): logger.warning(f'Error closing server {i} communicator: {str(e)}') def parse_resp_data(self, resp_data): - if self.use_gym_env: - choice_cls = GymRolloutResponseChoice - elif self.use_async_engine: - choice_cls = RolloutResponseChoice - else: - choice_cls = ChatCompletionResponseChoice - result = [ - ChatCompletionResponse( - choices=[from_dict(data_class=choice_cls, data=c) for c in resp['choices']], - **{k: v - for k, v in resp.items() if k != 'choices'}) for resp in resp_data - ] + choice_cls = ChatCompletionResponseChoice + result = [ChatCompletionResponse(choices=[from_dict(data_class=choice_cls, data=c) for c in resp['choices']])] return result From dcc9052fd976def936190e4a4f66af2a9fd75964 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 5 Aug 2025 18:39:31 +0800 Subject: [PATCH 12/26] wip --- swift/llm/infer/protocol.py | 40 ++++++++++++++++++++----------------- swift/plugin/multi_turn.py | 29 +-------------------------- 2 files changed, 23 insertions(+), 46 deletions(-) diff --git a/swift/llm/infer/protocol.py b/swift/llm/infer/protocol.py index 5d4ed9c827..4899649458 100644 --- a/swift/llm/infer/protocol.py +++ b/swift/llm/infer/protocol.py @@ -10,7 +10,7 @@ import json from PIL import Image -from pydantic import BaseModel +from pydantic import BaseModel, Field from ..template import InferRequest from ..utils import Messages, Tool @@ -298,20 +298,6 @@ class EmbeddingResponse: created: int = field(default_factory=lambda: int(time.time())) -# @dataclass -# class RolloutResponseChoice(ChatCompletionResponseChoice): -# messages: Optional[Messages] = None -# images: Optional[List[str]] = None -# multi_turn_infos: Dict[str, Any] = field(default_factory=dict) - -# @dataclass -# class 
GymRolloutResponseChoice(RolloutResponseChoice): -# trajectory_id: str = None -# total_reward: float = 0.0 -# step_rewards: List[float] = None -# trajectory_info: List[Dict[str, Any]] = None - - @dataclass class CompletionResponseChoice: index: int @@ -320,9 +306,7 @@ class CompletionResponseChoice: logprobs: Optional[Dict[str, List[Dict[str, Any]]]] = None -# class RolloutOutput: -# results: List[ChatCompletionResponse] -# extra_info: Dict[str, Any] + @dataclass @@ -341,6 +325,26 @@ def to_cmpl_response(self) -> 'CompletionResponse': id_ = f'cmpl{self.id[len("chatcmpl"):]}' return CompletionResponse(self.model, choices, self.usage, id_, created=self.created) +class RolloutOutput(BaseModel): + results: ChatCompletionResponse + # multi turn rollout + messages: Optional[Messages] = None # Conversation history for the final rollout (required for multi-turn) + response_token_ids: Optional[List[List[int]]] = None # (optional) Token IDs generated at each rollout turn + response_loss_mask: Optional[List[List[int]]] = None # (optional) Loss mask for each rollout turn + extra_info: Dict[str, Any] = Field(default_factory=dict) # Additional rollout infos; must be JSON-serializable + + def model_post_init(self, __context): + # Ensure multimodal data in extra_info is serializable (e.g., images to base64) + super().model_post_init(__context) + self.mminfo_to_serializable() + + def mminfo_to_serializable(self): + mm_keys = ['images', 'audios', 'videos'] + + for key, value in self.extra_info.items(): + if key in mm_keys: + # Convert multimodal content to base64 for serialization + self.extra_info[key] = MultiModalRequestMixin.to_base64(value) @dataclass class CompletionResponse: diff --git a/swift/plugin/multi_turn.py b/swift/plugin/multi_turn.py index fad32df66b..266c60df24 100644 --- a/swift/plugin/multi_turn.py +++ b/swift/plugin/multi_turn.py @@ -1,9 +1,7 @@ import asyncio from abc import ABC from copy import deepcopy -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union - -from pydantic import BaseModel, Field +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from swift.plugin import ContextManager, Env, context_managers, envs @@ -26,31 +24,6 @@ def remove_response(messages: 'Messages') -> Optional[str]: if last_role == 'assistant': return messages.pop()['content'] - -class RolloutOutput(BaseModel): - results: 'ChatCompletionResponse' - # multi turn rollout - messages: Optional['Messages'] = None # Conversation history for the final rollout (required for multi-turn) - response_token_ids: Optional[List[List[int]]] = None # (optional) Token IDs generated at each rollout turn - response_loss_mask: Optional[List[List[int]]] = None # (optional) Loss mask for each rollout turn - extra_info: Dict[str, - Any] = Field(default_factory=dict) # Additional rollout information; must be JSON-serializable - - def model_post_init(self, __context): - # Ensure multimodal data in extra_info is serializable (e.g., images to base64) - super().model_post_init(__context) - self.mminfo_to_serializable() - - def mminfo_to_serializable(self): - mm_keys = ['images', 'audios', 'videos'] - - for key, value in self.extra_info.items(): - if key in mm_keys: - from swift.llm.infer.protocol import MultiModalRequestMixin - # Convert multimodal content to base64 for serialization - self.extra_info[key] = MultiModalRequestMixin.to_base64(value) - - class RolloutScheduler(ABC): # Single Turn Rollout Scheduler def __init__(self, infer_engine: 'GRPOVllmEngine', max_turns: Optional[int] = None, *args, 
**kwargs): From f72712efffce7472b9856a0fc374169e46a7c051 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Wed, 6 Aug 2025 12:06:24 +0800 Subject: [PATCH 13/26] wip --- .../infer/infer_engine/grpo_vllm_engine.py | 15 ++++++-- swift/trainers/rlhf_trainer/grpo_trainer.py | 38 ++++++++++++++----- 2 files changed, 41 insertions(+), 12 deletions(-) diff --git a/swift/llm/infer/infer_engine/grpo_vllm_engine.py b/swift/llm/infer/infer_engine/grpo_vllm_engine.py index 100a4126fd..03cc656012 100644 --- a/swift/llm/infer/infer_engine/grpo_vllm_engine.py +++ b/swift/llm/infer/infer_engine/grpo_vllm_engine.py @@ -12,7 +12,7 @@ from swift.plugin.context_manager import ContextManager, context_managers from swift.plugin.env import Env, envs from swift.plugin.multi_turn import MultiTurnScheduler -from ..protocol import ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage, RequestConfig +from ..protocol import ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage, RequestConfig, RolloutOutput from .utils import AdapterRequest try: @@ -149,9 +149,9 @@ def infer( template: Optional[Template] = None, use_tqdm: Optional[bool] = None, adapter_request: Optional[AdapterRequest] = None, - ) -> List[ChatCompletionResponse]: + ) -> List[RolloutOutput]: assert not self.use_async_engine, 'for Async Engine, use infer_async instead' - return super().infer( + res = super().infer( infer_requests, request_config, metrics, @@ -159,6 +159,15 @@ def infer( use_tqdm=use_tqdm, adapter_request=adapter_request, ) + if not isinstance(res, list): + res = [res] + for i, result in enumerate(res): + if not isinstance(result, RolloutOutput): + if not isinstance(result, ChatCompletionResponse): + raise TypeError("Result must be a ChatCompletionResponse or RolloutOutput instance.") + res[i] = RolloutOutput(results=result) + + return res async def _batch_infer_stream(self, tasks, diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 74666725ce..e8cd2a23c8 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -32,7 +32,7 @@ from swift.llm import (InferRequest, MultiModelKeys, RequestConfig, RolloutInferRequest, RowPreprocessor, Template, get_model_arch, to_device) -from swift.llm.infer.protocol import ChatCompletionResponse +from swift.llm.infer.protocol import ChatCompletionResponse, RolloutOutput from swift.llm.model.utils import get_llm_model from swift.llm.template.base import MaxLengthError from swift.llm.template.template_inputs import StdTemplateInputs @@ -72,7 +72,16 @@ def patched_len(self) -> int: RepeatSampler.__len__ = patched_len RepeatSampler.old_len_func = origin_len_func - +""" +TODO: + 1. RolloutOutput 统一输入输出解析 + a. docstring 修改 + 2. 增加 prompt id 和 device id,修改获取本地数据的逻辑 + a. 待确认:是均分比较好 还是 获取本地prompt的rollout,前者的话不需要device id + 3. 动态SPG逻辑,修改 _prepare_inputs 跳过 rollout 的逻辑 + 4. 不足一个batch的样本,随机抽取样本进行填充,并且不计算loss + 5. 
优化 server 分发数据的通信逻辑(利好大集群训练) +""" class GRPOCallback(TrainerCallback): def __init__(self, trainer): @@ -609,7 +618,7 @@ def _wait_queue(self): def _infer(self, inputs: Optional[InputsType], request_config: RequestConfig, - is_global_inputs: bool = False) -> List[ChatCompletionResponse]: + is_global_inputs: bool = False) -> List[RolloutOutput]: request_config = self._get_request_config() # keys from InferRequest per_device_size = len(inputs) @@ -629,7 +638,7 @@ def _infer(self, return [] if self.accelerator.is_main_process: - results: List[ChatCompletionResponse] = self._engine_infer( + results: List[RolloutOutput] = self._engine_infer( infer_requests=all_inputs, request_config=request_config) else: results = [None] * len(all_inputs) @@ -665,7 +674,7 @@ def _infer(self, # otherwise, the program may hang. # 2. Ensure that the seed for vLLM Engines across different TP groups is different; # otherwise, identical completions will be generated. - results: List[ChatCompletionResponse] = self._engine_infer( + resltus: List[RolloutOutput] = self._engine_infer( infer_requests=inputs, request_config=request_config) if self.vllm_tensor_parallel_size > 1: @@ -724,11 +733,21 @@ def _infer_single_or_multi_turn(self, # for external server, pass the system args which may define in trainer self._set_inputs_system(inputs) # infer first turn - results: List[ChatCompletionResponse] = self._infer(inputs, request_config, is_global_inputs) + results: List[RolloutOutput] = self._infer(inputs, request_config, is_global_inputs) outputs = [] if not self.multi_turn_scheduler and not self.vllm_use_async_engine: # message concatenation - for i, output in enumerate(results): + for i, result in enumerate(results): + _input: Dict = deepcopy(inputs[i]) + choice = result.results.choices[0] + if result.messages: + messages = result.messages + else: + messages = _inputs[i]['messages'] + messages.append({'role': 'assistant', 'content': choice.message.content}) + _input['messages'] = messages + # TODO: input 和 results 数量不定 + _choices = [] for choice in output.choices: _input: Dict = deepcopy(inputs[i]) @@ -1588,7 +1607,7 @@ def _engine_infer( request_config: Optional[RequestConfig] = None, *, use_tqdm: Optional[bool] = False, - ) -> List[ChatCompletionResponse]: + ) -> List[RolloutOutput]: with patch_profiling_context(self, 'generate'): if self.vllm_mode == 'server': request_keys = ['messages', 'images', 'audios', 'videos', 'tools', 'objects'] @@ -1608,7 +1627,8 @@ def _engine_infer( self._process_infer_requests_images(infer_requests) return self.vllm_client.infer(infer_requests, asdict(request_config), use_tqdm=use_tqdm) else: - return self.engine.infer(infer_requests, request_config, use_tqdm=use_tqdm) + res = self.engine.infer(infer_requests, request_config, use_tqdm=use_tqdm) + return [RolloutOutput(results=r) for r in res] def _process_infer_requests_images(self, infer_requests: InputsType): # Process image format into a format that session.post can accept From a8aeb739766ee4b8cf9503c760eff52de2cefbdc Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 8 Aug 2025 16:05:00 +0800 Subject: [PATCH 14/26] refactor v1 --- .../infer/infer_engine/grpo_vllm_engine.py | 7 +- swift/llm/infer/infer_engine/vllm_engine.py | 7 +- swift/llm/infer/protocol.py | 49 +- swift/llm/template/template_inputs.py | 88 +- swift/plugin/multi_turn.py | 149 +-- swift/trainers/rlhf_trainer/grpo_trainer.py | 862 +++++++++++------- swift/trainers/rlhf_trainer/vllm_client.py | 4 +- swift/trainers/sequence_parallel/utils.py | 4 +- swift/utils/__init__.py | 3 
+- swift/utils/utils.py | 18 + 10 files changed, 765 insertions(+), 426 deletions(-) diff --git a/swift/llm/infer/infer_engine/grpo_vllm_engine.py b/swift/llm/infer/infer_engine/grpo_vllm_engine.py index 03cc656012..a69748a395 100644 --- a/swift/llm/infer/infer_engine/grpo_vllm_engine.py +++ b/swift/llm/infer/infer_engine/grpo_vllm_engine.py @@ -6,6 +6,7 @@ import torch from tqdm.asyncio import tqdm_asyncio +from vllm.outputs import RequestOutput from swift.llm import InferRequest, RolloutInferRequest, Template, VllmEngine from swift.plugin import Metric, multi_turns @@ -164,8 +165,8 @@ def infer( for i, result in enumerate(res): if not isinstance(result, RolloutOutput): if not isinstance(result, ChatCompletionResponse): - raise TypeError("Result must be a ChatCompletionResponse or RolloutOutput instance.") - res[i] = RolloutOutput(results=result) + raise TypeError('Result must be a ChatCompletionResponse or RolloutOutput instance.') + res[i] = RolloutOutput(response=result) return res @@ -191,7 +192,7 @@ async def _new_run(task): new_tasks = [_new_run(task) for task in tasks] return await self.batch_run(new_tasks) - def _create_chat_completion_response(self, result: 'ChatCompletionResponse', template: Template, request_config, + def _create_chat_completion_response(self, result: 'RequestOutput', template: Template, request_config, request_id) -> ChatCompletionResponse: assert result is not None num_generated_tokens = sum(len(output.token_ids) for output in result.outputs) diff --git a/swift/llm/infer/infer_engine/vllm_engine.py b/swift/llm/infer/infer_engine/vllm_engine.py index 79dc4fd03e..9715ee992f 100644 --- a/swift/llm/infer/infer_engine/vllm_engine.py +++ b/swift/llm/infer/infer_engine/vllm_engine.py @@ -432,8 +432,10 @@ async def _infer_full_async( generation_config: SamplingParams, adapter_request: Optional[AdapterRequest], request_config: RequestConfig, + request_id: Optional[str] = None, ) -> Union[ChatCompletionResponse, EmbeddingResponse]: - request_id = random_uuid() + if request_id is None: + request_id = random_uuid() result_generator = self._add_request(inputs, generation_config, request_id, adapter_request=adapter_request) result = None async for result in result_generator: @@ -561,6 +563,9 @@ async def infer_async( 'adapter_request': adapter_request, 'request_config': request_config, } + if hasattr(infer_request, 'uuid') and infer_request.uuid: + # RolloutInferRequest + kwargs.update({'request_id': infer_request.uuid}) if pre_infer_hook: kwargs = pre_infer_hook(kwargs) if request_config.stream: diff --git a/swift/llm/infer/protocol.py b/swift/llm/infer/protocol.py index 4899649458..5047db19e7 100644 --- a/swift/llm/infer/protocol.py +++ b/swift/llm/infer/protocol.py @@ -306,9 +306,6 @@ class CompletionResponseChoice: logprobs: Optional[Dict[str, List[Dict[str, Any]]]] = None - - - @dataclass class ChatCompletionResponse: model: str @@ -325,26 +322,54 @@ def to_cmpl_response(self) -> 'CompletionResponse': id_ = f'cmpl{self.id[len("chatcmpl"):]}' return CompletionResponse(self.model, choices, self.usage, id_, created=self.created) + class RolloutOutput(BaseModel): - results: ChatCompletionResponse - # multi turn rollout - messages: Optional[Messages] = None # Conversation history for the final rollout (required for multi-turn) - response_token_ids: Optional[List[List[int]]] = None # (optional) Token IDs generated at each rollout turn - response_loss_mask: Optional[List[List[int]]] = None # (optional) Loss mask for each rollout turn - extra_info: Dict[str, Any] = 
Field(default_factory=dict) # Additional rollout infos; must be JSON-serializable + """ + Output structure for rollout. + + Attributes: + response (ChatCompletionResponse): + The model's response + + messages (Optional[Messages]): + (Optional) Conversation history for the final rollout; required for multi-turn scenarios. + NOTE: + - If provided, this messages sequence will overwrite the original messages. + - If not provided, 'response' will be appended as the latest turn in the original messages. + - For multi-turn training, you need to manually return the updated messages, including the full history. + - The messages should include the latest assistant response as the final message. + + response_token_ids (Optional[List[List[int]]]): + (Optional) Token IDs generated at each rollout turn. + If provided, the training process will skip tokenizing the response. + + response_loss_mask (Optional[List[List[int]]]): + (Optional) Loss masks corresponding to each rollout turn. + If provided, the training process will skip computing loss masks for the response (as controlled by the `loss_scale` parameter). # noqa + + rollout_infos (Dict[str, Any]): + (Optional) Additional rollout information. This must be JSON-serializable. + """ + response: ChatCompletionResponse + # multi turn + messages: Optional[Messages] = None + response_token_ids: List[List[int]] = Field(default_factory=list) + response_loss_mask: List[List[int]] = Field(default_factory=list) + rollout_infos: Dict[str, Any] = Field(default_factory=dict) def model_post_init(self, __context): - # Ensure multimodal data in extra_info is serializable (e.g., images to base64) + # Ensure multimodal data in rollout_infos is serializable (e.g., images to base64) super().model_post_init(__context) self.mminfo_to_serializable() def mminfo_to_serializable(self): mm_keys = ['images', 'audios', 'videos'] - for key, value in self.extra_info.items(): + for key, value in self.rollout_infos.items(): if key in mm_keys: # Convert multimodal content to base64 for serialization - self.extra_info[key] = MultiModalRequestMixin.to_base64(value) + self.rollout_infos[key] = MultiModalRequestMixin.to_base64(value) + @dataclass class CompletionResponse: diff --git a/swift/llm/template/template_inputs.py b/swift/llm/template/template_inputs.py index be81223711..27922f5acb 100644 --- a/swift/llm/template/template_inputs.py +++ b/swift/llm/template/template_inputs.py @@ -14,21 +14,44 @@ @dataclass class InferRequest: """ - messages: Input in messages format. - Examples: [{ - "role": "user", # or assistant/system/role - "content": [ # str or List[Dict[str, Any]] - { - "type": "image", # or audio/video - "image": "", - }, - {"type": "text", "text": "Please describe the picture."}, - ], - }] - The above content is equivalent to: - [{"role": "user", "content": "Please describe the picture."}] - and additionally passing in images: [""]. - tools: Organize tools into the format of agent_template for system. for example, 'react_en'. + Data structure for inference requests. + + Attributes: + messages (Messages): + The input conversation in messages format. Each message is a dict containing at least + a "role" field (e.g., "user", "assistant", "system") and a "content" field. 
+ Example: + [{ + "role": "user", + "content": [ + { + "type": "image", # can also be audio/video + "image": "", + }, + {"type": "text", "text": "Please describe the picture."}, + ], + }] + The above is equivalent to: + [{"role": "user", "content": "Please describe the picture."}] + with an additional argument: + images = [""] + + images (List[Union[str, Image.Image]]): + Optional, a list of images associated with the request. + Each image can be a URL, local path, base64 string, or PIL.Image object. + + audios (List[str]): + Optional, a list of audio resources associated with the request. + + videos (List[str]): + Optional, a list of video resources associated with the request. + + tools (Optional[List[Tool]]): + An optional list of tools. These should be organized in the agent_template format for + tools requested by the system, for example 'react_en'. + + objects (Dict[str, List[Any]]): + Container for additional multimodal objects, grouped by type (key). """ messages: Messages @@ -75,18 +98,35 @@ def to_printable(self): @dataclass class RolloutInferRequest(InferRequest): """ - A request class that modifies the 'images' attribute - to be a list of strings for compatibility with POST requests. - The strings can represent image URLs or Base64 encoded images. + An inference request class for rollout scenarios. + + This class extends `InferRequest` and specifically overrides the `images` attribute + to be a list of strings for compatibility with POST requests. Each string may + represent an image URL or a Base64-encoded image. + + Inherits all fields from `InferRequest`: + messages (Messages): + Input conversation messages, supporting multimodal content. + audios (List[str]): + List of audio resources associated with the request. + videos (List[str]): + List of video resources associated with the request. + tools (Optional[List[Tool]]): + List of tools, organized by the agent template (e.g. 'react_en'). + objects (Dict[str, List[Any]]): + Optional container for additional multimodal objects. + + Additional / Overridden fields: + images (List[str]): + List of image resources, each as a string (URL or base64). + data_dict (Dict): + Optional dictionary for extra request data. + uuid (Optional[str]): + Optional unique identifier for this request instance. """ images: List[str] = field(default_factory=list) data_dict: Dict = field(default_factory=dict) - - def process_images(self): - """Convert PIL images to base64 strings.""" - self.images = [ - image.convert('RGB').tobytes() if isinstance(image, Image.Image) else image for image in self.images - ] + uuid: Optional[str] = None @dataclass diff --git a/swift/plugin/multi_turn.py b/swift/plugin/multi_turn.py index 266c60df24..5aa2c4317f 100644 --- a/swift/plugin/multi_turn.py +++ b/swift/plugin/multi_turn.py @@ -4,25 +4,15 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from swift.plugin import ContextManager, Env, context_managers, envs +from swift.utils import remove_response if TYPE_CHECKING: - from swift.llm.infer.protocol import ChatCompletionResponse, ChatCompletionResponseChoice, RequestConfig + from swift.llm.infer.protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, RequestConfig, + RolloutOutput) from swift.llm.template import RolloutInferRequest from swift.llm.infer.infer_engine import GRPOVllmEngine from swift.llm.utils import Messages -""" - 1. 修改 step 方法的返回值:统一为 Dict,run 方法需要相应做修改 - 2. response_id 和 response_loss_scale 在 step方法的 dict 返回,在run中维护list - 3. 
- -""" - - -def remove_response(messages: 'Messages') -> Optional[str]: - last_role = messages[-1]['role'] if messages else None - if last_role == 'assistant': - return messages.pop()['content'] class RolloutScheduler(ABC): # Single Turn Rollout Scheduler @@ -53,15 +43,16 @@ async def _infer_async_single(infer_request: Union['RolloutInferRequest', Dict[s async def run(self, infer_request: 'RolloutInferRequest', request_config: 'RequestConfig', **kwargs) -> 'RolloutOutput': - result: 'ChatCompletionResponse' = await self.infer_engine.infer_async(infer_request, request_config, **kwargs) - response_token_ids = result.choices[0].token_ids + response: 'ChatCompletionResponse' = await self.infer_engine.infer_async(infer_request, request_config, + **kwargs) + response_token_ids = response.choices[0].token_ids response_loss_mask = [1] * len(response_token_ids) return RolloutOutput( - results=result, + response=response, messages=infer_request.messages, response_token_ids=[response_token_ids], response_loss_mask=[response_loss_mask], - extra_info={'num_turns': 1}) + rollout_infos={'num_turns': 1}) def __getattr__(self, key: str): try: @@ -126,9 +117,40 @@ class MultiTurnScheduler(RolloutScheduler, ABC): async def run(self, infer_request: 'RolloutInferRequest', request_config: 'RequestConfig', **kwargs) -> Union['RolloutOutput', List['RolloutOutput']]: + """Execute multi-turn conversation rollout with built-in turn management logic. + + This implements the default multi-turn interaction flow, which you can override + to customize the conversation handling behavior. The default logic: + + 1. Manages conversation turns and stopping conditions + 2. Handles message accumulation across turns + 3. Tracks response tokens and loss masks + 4. Supports early stopping conditions + + Args: + infer_request: The initial inference request containing messages + request_config: Configuration for the inference request + **kwargs: Additional inference parameters + + Returns: + RolloutOutput containing the complete conversation history and metadata, + or a list of outputs for batched requests + + Customization Points: + - Override check_finished() to change stopping conditions + - Override step() to customize turn-to-turn transitions + - Subclass to completely change multi-turn behavior + + Example: + class CustomScheduler(MultiTurnScheduler): + async def run(self, *args, **kwargs): + # Custom multi-turn logic here + ... 
+ """ + current_request = infer_request current_turn = 1 - info_dict = {} + rollout_infos = {} total_response_ids = [] total_response_loss_mask = [] while True: @@ -137,61 +159,66 @@ async def run(self, infer_request: 'RolloutInferRequest', request_config: 'Reque # If it's the first turn or the last message content is empty(dummy), remove the response remove_response(messages) - result: 'ChatCompletionResponse' = await self.infer_engine.infer_async(current_request, request_config, - **kwargs) - result_choice: 'ChatCompletionResponseChoice' = result.choices[0] + # Get model response + response: 'ChatCompletionResponse' = await self.infer_engine.infer_async( + current_request, request_config, **kwargs) + response_choice: 'ChatCompletionResponseChoice' = response.choices[0] - completion = result_choice.message.content + # Update conversation history + completion = response_choice.message.content if messages[-1]['role'] == 'assistant': messages[-1]['content'] += completion else: messages.append({'role': 'assistant', 'content': completion}) - should_stop = self.check_finished(current_request, result_choice, current_turn) + # Check stopping conditions + should_stop = self.check_finished(current_request, response_choice, current_turn) if self.max_turns: should_stop = should_stop or (current_turn >= self.max_turns) if should_stop: - info_dict['num_turns'] = current_turn - for key, value in info_dict.items(): - if hasattr(result_choice, key): - setattr(result_choice, key, value) - else: - result_choice.multi_turn_infos[key] = value return RolloutOutput( - results=result, + response=response, messages=messages, response_id=total_response_ids, response_loss_mask=total_response_loss_mask, - extra_info=info_dict, + rollout_infos=rollout_infos, ) - ret = self.step(current_request, result_choice, current_turn) + # Prepare next turn + ret = self.step(current_request, response_choice, current_turn) current_request: 'RolloutInferRequest' = ret['infer_request'] + + # Track response tokens and masks return_token_id = False if 'response_token_ids' in ret: total_response_ids.append(ret['response_token_ids']) return_token_id = True + if 'response_loss_mask' in ret: assert return_token_id, 'You must return response_token_ids if you want to return response_loss_mask' assert len(ret['response_loss_mask']) == len(ret['response_token_ids']), \ 'response_loss_mask must have the same length as response_token_ids' total_response_loss_mask.append(ret['response_loss_mask']) + + if 'rollout_infos' in ret: + rollout_infos = {**rollout_infos, **ret['rollout_infos']} + if current_request.messages[-1]['role'] == 'assistant': # Add a dummy response to allow engine to continue generating current_request.messages.append({'role': 'assistant', 'content': None}) current_turn += 1 - def step(self, infer_request: 'RolloutInferRequest', result: 'ChatCompletionResponseChoice', + def step(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompletionResponseChoice', current_turn: int) -> Dict: """ Handles transition between conversation turns. 
Args: infer_request: Current inference request - result: Response from current turn + response_choice: Response from current turn current_turn: Current turn number Returns: @@ -199,13 +226,13 @@ def step(self, infer_request: 'RolloutInferRequest', result: 'ChatCompletionResp - infer_request (required): Main inference request object - response_token_ids (Optional[List[List[int]]]): Token IDs of responses for each rollout turn - response_loss_scale (Optional[List[List[int]]]): Loss scaling factors for responses in each rollout turn # noqa - - extra_info (Optional[Dict[str, Any]]): Additional metadata (must be serializable) + - rollout_infos (Optional[Dict[str, Any]]): Additional metadata (must be serializable) """ raise NotImplementedError( 'Please implement the `step` method in your MultiTurnScheduler subclass, or override the `run` method.') - def check_finished(self, infer_request: 'RolloutInferRequest', result: 'ChatCompletionResponseChoice', + def check_finished(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompletionResponseChoice', current_turn: int) -> bool: """ Default termination logic for checking if a multi-turn rollout should end. @@ -222,19 +249,25 @@ def check_finished(self, infer_request: 'RolloutInferRequest', result: 'ChatComp Args: infer_request: The inference request object - result: Contains generation results including finish_reason + response_choice: Contains generation results including finish_reason current_turn: Current conversation turn count Returns: bool: True to terminate conversation, False to continue """ - if result.finish_reason == 'length': + if response_choice.finish_reason == 'length': return True if self.max_turns and current_turn >= self.max_turns: return True return False +class ThinkingModelScheduler(MultiTurnScheduler): + # TODO: example for thinking model + # replace history thinking block + pass + + class MathTipsScheduler(MultiTurnScheduler): tips_prompt = 'But wait... It seems I made a mistake,' @@ -244,7 +277,7 @@ def __init__(self, tokenizer, *args, **kwargs): super().__init__(*args, **kwargs) self.acc_func = kwargs.get('acc_function', MathAccuracy()) - def check_finished(self, infer_request: 'RolloutInferRequest', result: 'ChatCompletionResponseChoice', + def check_finished(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompletionResponseChoice', current_turn: int) -> bool: last_completion = infer_request.messages[-1]['content'] # we only give tips once @@ -256,11 +289,11 @@ def check_finished(self, infer_request: 'RolloutInferRequest', result: 'ChatComp if acc == 1: return True - return super().check_finished(infer_request, result, current_turn) + return super().check_finished(infer_request, response_choice, current_turn) - def step(self, infer_request: 'RolloutInferRequest', result: 'ChatCompletionResponseChoice', + def step(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompletionResponseChoice', current_turn: int) -> Dict: - completion = result.message.content + completion = response_choice.message.content if '' in completion: completion = completion[:completion.index('')] if '' in completion: @@ -282,7 +315,7 @@ class MathTipsMultiTurnScheduler(MultiTurnScheduler): tips_prompt = 'The answer is not correct, It seems You made a mistake, you need to recheck very carefully.' 
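# Illustrative sketch (editor's addition, not part of this patch): a minimal
# MultiTurnScheduler subclass whose step() follows the dict contract documented
# above — returning 'infer_request' plus the optional 'response_token_ids',
# 'response_loss_mask' and 'rollout_infos' entries that run() collects.
# The class name, the observation text and the tokenizer argument are assumptions;
# rollout.py passes a tokenizer only to schedulers whose __init__ accepts one.
from swift.plugin.multi_turn import MultiTurnScheduler

class ObservationScheduler(MultiTurnScheduler):
    def __init__(self, tokenizer, *args, **kwargs):
        self.tokenizer = tokenizer
        super().__init__(*args, **kwargs)

    def step(self, infer_request, response_choice, current_turn):
        # Hypothetical environment/tool feedback appended to the assistant turn.
        observation = '\nObservation: (environment feedback goes here)'
        infer_request.messages[-1]['content'] += observation
        completion_ids = list(response_choice.token_ids)
        observation_ids = self.tokenizer.encode(observation, add_special_tokens=False)
        return {
            'infer_request': infer_request,
            # one flat list per turn; run() gathers them into List[List[int]]
            'response_token_ids': completion_ids + observation_ids,
            # keep loss on model-generated tokens, mask out the observation tokens
            'response_loss_mask': [1] * len(completion_ids) + [0] * len(observation_ids),
            'rollout_infos': {'turn': current_turn},
        }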
acc_func = MathAccuracy() - def check_finished(self, infer_request: 'RolloutInferRequest', result: 'ChatCompletionResponseChoice', + def check_finished(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompletionResponseChoice', current_turn: int) -> bool: last_query = infer_request.messages[-2]['content'] @@ -290,15 +323,15 @@ def check_finished(self, infer_request: 'RolloutInferRequest', result: 'ChatComp if self.tips_prompt in last_query: return True - completion = result.message.content + completion = response_choice.message.content solution = infer_request.data_dict['solution'] acc = self.acc_func([completion], [solution])[0] if acc == 1: return True - return super().check_finished(infer_request, result, current_turn) + return super().check_finished(infer_request, response_choice, current_turn) - def step(self, infer_request: 'RolloutInferRequest', result: 'ChatCompletionResponseChoice', + def step(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompletionResponseChoice', current_turn: int) -> Dict: infer_request.messages.append({'role': 'user', 'content': self.tips_prompt}) return infer_request @@ -384,9 +417,10 @@ async def run(self, infer_request: 'RolloutInferRequest', request_config: 'Reque current_request.messages = messages remove_response(current_request.messages) - result: 'ChatCompletionResponse' = await self.infer_async(current_request, request_config, **kwargs) - result_choice: 'ChatCompletionResponseChoice' = result.choices[0] - completion = result_choice.message.content + response: 'ChatCompletionResponse' = await self.infer_engine.infer_async( + current_request, request_config, **kwargs) + response_choice: 'ChatCompletionResponseChoice' = response.choices[0] + completion = response_choice.message.content messages.append({'role': 'assistant', 'content': completion}) # Execute environment step @@ -403,23 +437,22 @@ async def run(self, infer_request: 'RolloutInferRequest', request_config: 'Reque current_request.messages = messages current_turn += 1 - # Build final response with gym-specific information final_choice = ChatCompletionResponseChoice( - index=result_choice.index, - message=result_choice.message, - finish_reason=result_choice.finish_reason, - logprobs=result_choice.logprobs) + index=response_choice.index, + message=response_choice.message, + finish_reason=response_choice.finish_reason, + logprobs=response_choice.logprobs) - result = ChatCompletionResponse( + last_response = ChatCompletionResponse( model=self.infer_engine.model_name, choices=[final_choice], - usage=result.usage, + usage=response.usage, id=f'gym_{trajectory_id}') return RolloutOutput( - results=result, + response=last_response, messages=messages, - extra_info={ + rollout_infos={ 'num_turns': current_turn, 'trajectory_id': trajectory_id, 'total_reward': total_reward, diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index e8cd2a23c8..0186d815cc 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -1,10 +1,12 @@ # Copyright (c) Alibaba, Inc. and its affiliates. # Part of the implementation is borrowed from huggingface/trl. 
+import base64 import concurrent.futures import inspect import os import re import time +import uuid from collections import defaultdict, deque from concurrent.futures import Future from contextlib import contextmanager, nullcontext @@ -19,6 +21,7 @@ import torch.nn as nn import transformers from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed +from dacite import from_dict from packaging import version from torch.nn import ModuleList from torch.utils.data import DataLoader @@ -39,7 +42,8 @@ from swift.plugin import multi_turns, orms, rm_plugins from swift.plugin.multi_turn import MultiTurnScheduler from swift.utils import (JsonlWriter, empty_cache, get_current_device, get_logger, is_swanlab_available, - is_vllm_available, is_wandb_available, seed_worker, unwrap_model_for_generation) + is_vllm_available, is_wandb_available, remove_response, seed_worker, + unwrap_model_for_generation) from ..mixin import SwiftMixin from .rlhf_mixin import RLHFTrainerMixin from .utils import (_ForwardRedirection, patch_lora_merge, patch_lora_unmerge, patch_profiling_context, @@ -60,9 +64,9 @@ if is_swanlab_available(): import swanlab -InputsType = List[Dict[str, Union[torch.Tensor, Any]]] -# tuple: (messages, finish_reason) -OutputsType = List[Dict[List[Dict], str]] # TODO: Check +DataType = List[Dict[str, Union[torch.Tensor, Any]]] + +# patch to fix save last_checkpoint https://github.com/modelscope/ms-swift/pull/4969 if not hasattr(RepeatSampler, 'old_len_func'): origin_len_func = RepeatSampler.__len__ @@ -72,16 +76,7 @@ def patched_len(self) -> int: RepeatSampler.__len__ = patched_len RepeatSampler.old_len_func = origin_len_func -""" -TODO: - 1. RolloutOutput 统一输入输出解析 - a. docstring 修改 - 2. 增加 prompt id 和 device id,修改获取本地数据的逻辑 - a. 待确认:是均分比较好 还是 获取本地prompt的rollout,前者的话不需要device id - 3. 动态SPG逻辑,修改 _prepare_inputs 跳过 rollout 的逻辑 - 4. 不足一个batch的样本,随机抽取样本进行填充,并且不计算loss - 5. 
优化 server 分发数据的通信逻辑(利好大集群训练) -""" + class GRPOCallback(TrainerCallback): def __init__(self, trainer): @@ -96,8 +91,7 @@ def on_train_begin(self, args, state, control, **kwargs): @dataclass class DataCache: - inputs: List[Dict] = field(default_factory=list) - outputs: List[Dict] = field(default_factory=list) + results: DataType def identity_data_collator(features): @@ -615,73 +609,16 @@ def _wait_queue(self): while self._queue.empty(): time.sleep(0.01) - def _infer(self, - inputs: Optional[InputsType], - request_config: RequestConfig, - is_global_inputs: bool = False) -> List[RolloutOutput]: + def _rollout(self, + inputs: Optional[DataType], + request_config: RequestConfig, + is_global_inputs: bool = False) -> List[RolloutOutput]: request_config = self._get_request_config() - # keys from InferRequest - per_device_size = len(inputs) - if is_global_inputs: - per_device_size //= self.accelerator.num_processes if self.vllm_mode == 'server': - # for server mode, we gather all the inputs and send to remote vllm server in main process - if is_global_inputs: - # async generate, pre-gather to avoid potential communicate operator - all_inputs = inputs - all_input_lengths = [per_device_size] + [0] * (self.accelerator.num_processes - 1) - else: - all_inputs = gather_object(inputs) - all_input_lengths = gather_object([len(inputs)]) - - if not any(inputs for inputs in all_inputs): - return [] - - if self.accelerator.is_main_process: - results: List[RolloutOutput] = self._engine_infer( - infer_requests=all_inputs, request_config=request_config) - else: - results = [None] * len(all_inputs) - # Broadcast the results from the main process to all processes, - # ensuring each process receives its corresponding slice. - if not is_global_inputs: - results = broadcast_object_list(results, from_process=0) - start_idx = sum(all_input_lengths[:self.accelerator.process_index]) - end_idx = start_idx + all_input_lengths[self.accelerator.process_index] - results = results[start_idx:end_idx] - else: - results = results if self.accelerator.is_main_process else [] + rollout_outputs = self._server_rollout(inputs, request_config, is_global_inputs) else: - # pt / vllm colocate - if self.vllm_tensor_parallel_size > 1: - # Gather prompts from all ranks in the TP group and flatten. - # Each rank starts with its own prompts; after gathering, all ranks see the full group set. - # Note: The input sizes may differ across ranks (e.g., in multi-turn scenarios, - # the amount of data each rank continues to process may vary). - local_rank_in_group = torch.distributed.get_rank(group=self.tp_group) - local_input_length = len(inputs) - all_input_lengths = [None] * self.vllm_tensor_parallel_size - torch.distributed.all_gather_object(all_input_lengths, local_input_length, group=self.tp_group) - start_idx = sum(all_input_lengths[:local_rank_in_group]) - end_idx = start_idx + all_input_lengths[local_rank_in_group] - - # orig_size = len(inputs)/ - gathered_inputs = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_inputs, inputs, group=self.tp_group) - inputs = [p for sublist in gathered_inputs for p in sublist] - # Set request_config.seed - # 1. Ensure that the seed for vLLM Engines within each TP (Tensor Parallelism) group is the same; - # otherwise, the program may hang. - # 2. Ensure that the seed for vLLM Engines across different TP groups is different; - # otherwise, identical completions will be generated. 
- resltus: List[RolloutOutput] = self._engine_infer( - infer_requests=inputs, request_config=request_config) - - if self.vllm_tensor_parallel_size > 1: - # Slice completions for this rank within its TP group. - # Each rank generates all outputs — we keep only our share. - results = results[start_idx:end_idx] - return results + rollout_outputs = self._colocate_rollout(inputs, request_config) + return rollout_outputs def _get_request_config(self) -> RequestConfig: request_config = copy(self.request_config) @@ -702,7 +639,21 @@ def _get_request_config(self) -> RequestConfig: return request_config - def _set_inputs_system(self, inputs: InputsType) -> InputsType: + def _set_inputs_system(self, inputs: DataType) -> DataType: + """ + Inserts a default system message at the beginning of each input if specified. + + If a default system message is defined in the template and the first message in + an input is not already a system message, this method inserts the default system + message at the beginning of the messages list for each input. If no default system + message is provided, no modification is made. + + Args: + inputs (DataType): A list of input data entries, each containing a 'messages' field. + + Returns: + DataType: The input list, with the default system message prepended if applicable. + """ if not self.template.template_meta.default_system: return if all(_input['messages'][0]['role'] == 'system' for _input in inputs): @@ -713,163 +664,40 @@ def _set_inputs_system(self, inputs: InputsType) -> InputsType: messages.insert(0, {'role': 'system', 'content': self.template.template_meta.default_system}) def _infer_single_or_multi_turn(self, - inputs: InputsType, + inputs: DataType, request_config: RequestConfig, - is_global_inputs: bool = False) -> OutputsType: - """Perform multi-turn or single-turn inference + is_global_inputs: bool = False) -> List[DataType]: + """ + Runs inference for either single-turn or multi-turn dialogue. Args: - inputs: list of input requests - request_config: Inference configuration parameters - is_global_inputs: - A boolean indicating whether the inputs are global. When set to True, - the returned results in the main process will be a complete list of - global_outputs, while other processes will return an empty list []. + inputs: Input data for inference. + request_config: Configuration for the inference request. + is_global_inputs: Whether the inputs are from the global process. + Returns: - List of outputs where each entry contains: - - List of responses per prompt - - Each response is a tuple of (message_history, finish_reason) + List of processed outputs. 
""" # for external server, pass the system args which may define in trainer + + # Step 1: Prepare inputs with system prompts (if any) self._set_inputs_system(inputs) - # infer first turn - results: List[RolloutOutput] = self._infer(inputs, request_config, is_global_inputs) - outputs = [] + + # Step 2: First-turn rollout + rollout_outputs: List[RolloutOutput] = self._rollout(inputs, request_config, is_global_inputs) + + # Step 3: Handle single-turn (no scheduler, no async engine) if not self.multi_turn_scheduler and not self.vllm_use_async_engine: - # message concatenation - for i, result in enumerate(results): - _input: Dict = deepcopy(inputs[i]) - choice = result.results.choices[0] - if result.messages: - messages = result.messages - else: - messages = _inputs[i]['messages'] - messages.append({'role': 'assistant', 'content': choice.message.content}) - _input['messages'] = messages - # TODO: input 和 results 数量不定 - - _choices = [] - for choice in output.choices: - _input: Dict = deepcopy(inputs[i]) - # origin messages may contain response, we should remove it - InferRequest.remove_response(_input['messages']) - _input['messages'].append({'role': 'assistant', 'content': choice.message.content}) - output_dict = { - 'messages': _input['messages'], - 'finish_reason': choice.finish_reason, - # NOTE: for training, we use rollout token_ids to calculate loss - # because the tokenizer encode/decode may change the token ids - 'completion_ids': choice.token_ids - } - _choices.append(output_dict) - outputs.append(_choices) - outputs = [item for sublist in outputs for item in sublist] - else: - # vLLMAsyncLLMEngine, only server mode is supported right now. - # NOTE: The message concatenation has already been done in the engine. - if self.vllm_use_async_engine: - for i, output in enumerate(results): - _choices = [] - for choice in output.choices: - # concated in Engine - _choice = { - 'messages': choice.messages, - 'finish_reason': choice.finish_reason, - 'completion_ids': choice.token_ids, - } - if self.use_gym_env: - _choice.update({ - 'total_reward': choice.total_reward, - 'trajectory_info': choice.trajectory_info - }) - outputs.append(_choices) - outputs = [item for sublist in outputs for item in sublist] - else: - # multi turn for PTEngine or vLLMLLMEngine - orig_size = len(inputs) - outputs = [None] * orig_size - # we remove origin response in first turn - current_turn = 1 - while True: - has_local_data = len(inputs) > 0 - has_global_data = gather_object([has_local_data]) - if not any(has_global_data): - break - # inputs for current turn - current_inputs = [] - cnt = 0 - # combine completions from results with messages - for i, output in enumerate(results): - for choice in output.choices: - current_input = deepcopy(inputs[i]) - messages = current_input['messages'] - - if current_turn == 1 or not messages[-1]['content'] or messages[-1]['content'] == '': - # first turn or the last message content is empty(dummy), remove the response - InferRequest.remove_response(messages) - if messages[-1]['role'] == 'assistant': - # If the last message was assistant, concatenate the new content to it - messages[-1]['content'] += choice.message.content - else: - # append a new message from the assistant - messages.append({'role': 'assistant', 'content': choice.message.content}) - - if 'index' not in current_input: - current_input['index'] = cnt - current_input['finish_reason'] = choice.finish_reason - if 'completion_ids' not in current_input: - current_input['completion_ids'] = [] - 
current_input['completion_ids'].append(choice.token_ids) - - cnt += 1 - current_inputs.append(current_input) - - # Process messages in the multi-turn function - should_stops = [ - self.multi_turn_scheduler.check_finished(request, result.choices[0], current_turn) - for request, result in zip(self.inputs_to_rolloutrequest(current_inputs), results) - ] - - # Retain messages that are not yet finished for the next round of rollout - pending_inputs = [] - for stop, _input, result in zip(should_stops, current_inputs, results): - index = _input['index'] - if stop: - outputs[index] = { - 'messages': _input['messages'], - 'finish_reason': result.choices[0].finish_reason, - 'completion_ids': result.choices[0].token_ids, - 'multi_turn_infos': _input.get('multi_turn_infos', {'num_turns': 1}) - } - else: - current_request = self.inputs_to_rolloutrequest([_input])[0] - ret = self.multi_turn_scheduler.step(current_request, result.choices[0], current_turn) - if isinstance(ret, tuple): - infer_request, info_dict = ret - else: - infer_request = ret - info_dict = {} - info_dict['num_turns'] = current_turn + 1 - pending_input = asdict(infer_request) - if 'multi_turn_infos' not in pending_input: - pending_input['multi_turn_infos'] = {} - for key, value in info_dict.items(): - pending_input['multi_turn_infos'][key] = value - - pending_input['index'] = index - pending_inputs.append(pending_input) - - current_infer_inputs = pending_inputs if has_local_data else [] - results = self._infer(current_infer_inputs, request_config) - - inputs = pending_inputs - current_turn += 1 - assert not any([o is None for o in outputs]) - - # flatten 2D list to 1D list - return outputs + return self._postprocess_rollout_outputs(inputs, rollout_outputs) + + # Step 4: Handle async engine (multi-turn handled inside the engine) + if self.vllm_use_async_engine: + return self._postprocess_rollout_outputs(inputs, rollout_outputs) - def async_infer(self, all_inputs): + # Step 5: Handle multi-turn locally + return self._sync_multi_turn_infer(inputs, rollout_outputs, request_config) + + def async_generate_rollout(self, all_inputs): current_queue = self._queue def infer_task(): @@ -887,9 +715,9 @@ def infer_task(): def done(future): try: result = future.result() - current_queue.put(DataCache(all_inputs, result)) + current_queue.put(DataCache(result)) except Exception as e: - logger.error('Error in async_infer callback: %s', str(e)) + logger.error('Error in async_generate_rollout callback: %s', str(e)) future.add_done_callback(done) @@ -899,54 +727,60 @@ def _prefetch(self, dataloader: DataLoader): if self.state.global_step != self._last_loaded_step: self._move_model_to_vllm(skip_async_check=True) self._last_loaded_step = self.state.global_step - outputs = self._infer_single_or_multi_turn(all_inputs, self.request_config, is_global_inputs=True) - self._queue.put(DataCache(all_inputs, outputs)) + results = self._infer_single_or_multi_turn(all_inputs, self.request_config, is_global_inputs=True) + self._queue.put(DataCache(results)) - def _fast_infer(self, inputs: InputsType) -> Tuple[InputsType, OutputsType]: - # Skip the first wake_up to avoid the warning "Executor is not sleeping" + def _fast_infer(self, inputs: DataType) -> DataType: + """ + Efficient inference logic with support for vLLM colocate mode, async generation, + and model weight offloading. 
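# Editor's note: a small self-contained sketch (not part of the patch) of the async-generate
# flow around async_generate_rollout / DataCache / _prefetch: the rollout for the next step is
# submitted to a background worker, and the trainer consumes the previous step's cached results
# from a queue. ThreadPoolExecutor is only a stand-in for the executor used by the trainer.
import queue
from concurrent.futures import ThreadPoolExecutor

results_queue: 'queue.Queue[list]' = queue.Queue()
executor = ThreadPoolExecutor(max_workers=1)

def rollout(batch):  # stands in for _infer_single_or_multi_turn
    return [f'rollout({x})' for x in batch]

def submit(batch):
    future = executor.submit(rollout, batch)
    future.add_done_callback(lambda f: results_queue.put(f.result()))

submit(['step-0 prompts'])   # prefetch before the first optimizer step
submit(['step-1 prompts'])   # kick off the next step in the background ...
print(results_queue.get())   # ... and train on the cached previous step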
+ """ + # Step 1: Wake up the engine if it's sleeping (vLLM colocate mode) if self.vllm_mode == 'colocate' and self.args.sleep_level > 0: if self.engine.inner_model_executor.is_sleeping: - # First, load weights only, https://github.com/vllm-project/vllm/pull/15500 - if 'tags' in inspect.signature(self.engine.engine.wake_up).parameters: - self.engine.engine.wake_up(tags=['weights']) - else: - logger.info('We recommend installing vLLM >= 0.8.3, (ideally 0.8.5.post1)' - 'to help reduce memory peaks during engine wake-up.') - self.engine.engine.wake_up() + wake_up_params = inspect.signature(self.engine.engine.wake_up).parameters + # Load weights only (faster and reduces memory peak) + kwargs = {'tags': ['weights']} if 'tags' in wake_up_params else {} + self.engine.engine.wake_up(**kwargs) - # First, have main process load weights if needed + # Step 2: Load model weights if global_step has changed if self.state.global_step != self._last_loaded_step: self._move_model_to_vllm() self._last_loaded_step = self.state.global_step + # Step 3: Offload model/optimizer if enabled context = self.offload_context if self.enable_offload else nullcontext with context(): - if self.vllm_mode == 'colocate' and self.engine.inner_model_executor.is_sleeping and \ - 'tags' in inspect.signature(self.engine.engine.wake_up).parameters: + # Step 4: Wake up kv_cache after offloading (vLLM colocate only) + if (self.vllm_mode == 'colocate' and self.engine.inner_model_executor.is_sleeping + and 'tags' in inspect.signature(self.engine.engine.wake_up).parameters): # Load the kv_cache only after updating and offload the weights. self.engine.engine.wake_up(tags=['kv_cache']) + # Step 5: Handle rollout for async generate or sync if self.async_generate: - # send this step data to server - # we gather inputs outside the thread for prevent potential gather deadlock + # Pre-gather inputs to avoid potential gather deadlocks all_inputs = gather_object(inputs) - self.async_infer(all_inputs) - # cached data from last step - data_cache = self._queue.get() - all_inputs = data_cache.inputs - all_outputs = gather_object(data_cache.outputs) + self.async_generate_rollout(all_inputs) + + # Retrieve cached outputs from the last step + data_cache: DataCache = self._queue.get() + all_outputs = gather_object(data_cache.results) + + # Slice inputs/outputs for the current process + per_device_datasize = len(all_outputs) // self.accelerator.num_processes process_slice = slice( - self.accelerator.process_index * len(inputs), - (self.accelerator.process_index + 1) * len(inputs), + self.accelerator.process_index * per_device_datasize, + (self.accelerator.process_index + 1) * per_device_datasize, ) - inputs = all_inputs[process_slice] outputs = all_outputs[process_slice] else: with self.multi_turn_completion_length_context(): outputs = self._infer_single_or_multi_turn(inputs, self.request_config) + # Step 6: Reset prefix cache and sleep to release memory if self.vllm_mode == 'colocate' and self.args.sleep_level > 0: # Reset prefix cache before sleeping to prevent using stale cache upon waking up # https://github.com/modelscope/ms-swift/pull/5143 @@ -954,49 +788,27 @@ def _fast_infer(self, inputs: InputsType) -> Tuple[InputsType, OutputsType]: self.engine.engine.sleep(level=self.args.sleep_level) empty_cache() - return inputs, outputs - - def _generate_completions(self, inputs: InputsType) -> InputsType: - """Generate completions for given inputs using either fast inference or standard PyTorch inference. 
+ return outputs - Args: - inputs: List of input examples containing conversation messages. + def _generate_completions(self, inputs: DataType) -> DataType: + inputs = self._preprocess_inputs(inputs) - Returns: - Modified inputs with generated completions added to the last message - and truncation flag set in 'is_truncated' field. - """ mode = 'train' if self.model.training else 'eval' if self.use_fast_infer: - inputs, outputs = self._fast_infer(inputs) + results = self._fast_infer(inputs) else: with unwrap_model_for_generation( self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation ), self.template.generate_context(), self.multi_turn_completion_length_context(): - outputs = self._infer_single_or_multi_turn(inputs, self.request_config) + results = self._infer_single_or_multi_turn(inputs, self.request_config) if mode == 'train': # In training mode, ensure the model is returned to train() mode after inference # This is necessary as pt engines set the model to eval mode during generation self.model.train() - for i, output in enumerate(outputs): - inputs[i]['messages'] = output['messages'] - inputs[i]['is_truncated'] = output['finish_reason'] == 'length' - inputs[i]['completion_ids'] = output['completion_ids'] - if 'multi_turn_infos' in output: - multi_turn_infos = output['multi_turn_infos'] - - if 'images' in output['multi_turn_infos']: - # override images for 'think with images' scenario - inputs[i]['images'] = multi_turn_infos['images'] - inputs[i]['multi_turn_infos'] = multi_turn_infos - if self.use_gym_env: - inputs[i]['total_reward'] = output['total_reward'] - inputs[i]['trajectory_info'] = output['trajectory_info'] if 'trajectory_info' in output else None - - return inputs + return results - def _generate_and_score_completions(self, inputs: InputsType) -> InputsType: + def _generate_and_score_completions(self, inputs: DataType) -> DataType: # resample for overlong(> max_length) prompt data if self.template.truncation_strategy == 'raise': inputs = self.resample_truncated_inputs(inputs) @@ -1022,7 +834,7 @@ def _generate_and_score_completions(self, inputs: InputsType) -> InputsType: return batch_encoded_inputs - def _score_completions(self, inputs: InputsType) -> Tuple[torch.Tensor, torch.Tensor, List[str]]: + def _score_completions(self, inputs: DataType) -> Tuple[torch.Tensor, torch.Tensor, List[str]]: """Score completions using all reward functions Args: @@ -1139,19 +951,19 @@ def split_by_mini_batches(self, inputs, advantages): advantage_chunks = torch.chunk(advantages, spg) return spg_chunks, advantage_chunks - def _prepare_batch_inputs(self, inputs: InputsType, rewards: torch.Tensor) -> List[InputsType]: + def _prepare_batch_inputs(self, inputs: DataType, rewards: torch.Tensor) -> List[DataType]: """ Prepare the final batch inputs with advantages, ref/old_policy logps and other fields for RL training. Args: - inputs (InputsType): List of input samples. Original shape is [spg*bs] where: + inputs (DataType): List of input samples. Original shape is [spg*bs] where: - spg: steps_per_generation - bs: per-device batch size rewards (torch.Tensor): Tensor of global rewards corresponding to the inputs. 
Shape should match the total number of samples (spg*bs*num_processes*num_generations) Returns: - List[InputsType]: A list of prepared batch inputs, organized as [spg][bs] + List[DataType]: A list of prepared batch inputs, organized as [spg][bs] """ # Compute advantages grouped_rewards = rewards.view(-1, self.num_generations) @@ -1252,7 +1064,7 @@ def _log_metrics(self, inputs, messages, completions, rewards, rewards_per_func, for i, name in enumerate(self.reward_func_names): self._textual_logs['rewards'][name].extend(rewards_per_func[:, i].tolist()) - def _apply_chat_template_to_messages_list(self, messages_list: InputsType): + def _apply_chat_template_to_messages_list(self, messages_list: DataType): prompts_text = [] for messages in messages_list: InferRequest.remove_response(messages) @@ -1594,7 +1406,7 @@ def evaluation_loop(self, dataloader, *args, **kwargs): self.eval_flag = True return output - def training_step(self, model: nn.Module, inputs: InputsType, num_items_in_batch=None) -> torch.Tensor: + def training_step(self, model: nn.Module, inputs: DataType, num_items_in_batch=None) -> torch.Tensor: if self.args.async_generate: # Wait for the eval rollout to complete while not self.is_async_generate_eval_rollout_done(): @@ -1603,47 +1415,17 @@ def training_step(self, model: nn.Module, inputs: InputsType, num_items_in_batch def _engine_infer( self, - infer_requests: InputsType, + infer_requests: List[RolloutInferRequest], request_config: Optional[RequestConfig] = None, *, use_tqdm: Optional[bool] = False, ) -> List[RolloutOutput]: with patch_profiling_context(self, 'generate'): if self.vllm_mode == 'server': - request_keys = ['messages', 'images', 'audios', 'videos', 'tools', 'objects'] - - infer_requests = [{ - **{k: request[k] - for k in request_keys if k in request}, - **({ - 'data_dict': {k: request[k] - for k in request if k not in request_keys} - } if ( - (self.multi_turn_scheduler and self.vllm_use_async_engine) or - (self.vllm_use_async_engine and self.use_gym_env) - ) else {}) # use gym infer - } for request in infer_requests] - - self._process_infer_requests_images(infer_requests) return self.vllm_client.infer(infer_requests, asdict(request_config), use_tqdm=use_tqdm) else: res = self.engine.infer(infer_requests, request_config, use_tqdm=use_tqdm) - return [RolloutOutput(results=r) for r in res] - - def _process_infer_requests_images(self, infer_requests: InputsType): - # Process image format into a format that session.post can accept - import base64 - if not any('images' in request for request in infer_requests): - return - for request in infer_requests: - if 'images' not in request: - continue - for i, img in enumerate(request['images']): - if 'bytes' in img and img['bytes']: - request['images'][i] = base64.b64encode(img['bytes']).decode('utf-8') - elif 'path' in img and img['path']: - request['images'][i] = img['path'] - return + return [RolloutOutput(response=r) for r in res] def old_policy(self): return self.num_iterations > 1 or self.args.gradient_accumulation_steps % self.args.steps_per_generation != 0 @@ -1734,7 +1516,7 @@ def set_default_max_tokens(_self, request_config: RequestConfig, inputs: Dict[st self.engine.max_model_len = original_max_len del self.engine.set_grpo_max_model_len - def resample_truncated_inputs(self, inputs: InputsType, n_try_fetch: int = 10) -> InputsType: + def resample_truncated_inputs(self, inputs: DataType, n_try_fetch: int = 10) -> DataType: template = self.template for i, data in enumerate(inputs): n_try = 0 @@ -1811,7 +1593,7 @@ def 
is_async_generate_eval_rollout_done(self): def is_async_generate_train_rollout_done(self): return not self.train_queue.empty() - def inputs_to_rolloutrequest(self, inputs: InputsType) -> List[RolloutInferRequest]: + def inputs_to_rolloutrequest(self, inputs: DataType) -> List[RolloutInferRequest]: """Convert a list of inputs to a list of RolloutInferRequest objects If the input contains a 'data_dict' key, it will be used as the base for the new data_dict. @@ -1869,3 +1651,435 @@ def offload_context(self): if getattr(self, 'optimizer', None) and self.args.offload_optimizer: self.load_optimizer() empty_cache() + + def _add_prompt_id_to_inputs(self, inputs: DataType) -> DataType: + """ + Adds a unique `prompt_id` to each input based on their `messages` content. + + Inputs with identical `messages` (assumed to be adjacent) will share the same `prompt_id`. + + Args: + inputs (DataType): A list of dictionaries, each containing a 'messages' key. + + + Returns: + DataType: The input list with each item containing a new 'prompt_id' field. + + Example: + >>> inputs = [ + ... {"messages": [{"role": "user", "content": "hello"}], "data": 1}, + ... {"messages": [{"role": "user", "content": "hello"}], "data": 2}, + ... {"messages": [{"role": "assistant", "content": "hi"}], "data": 3}, + ... ] + >>> self._add_prompt_id_to_inputs(inputs) + [ + {"messages": [...], "data": 1, "prompt_id": "a1b2c3..."}, + {"messages": [...], "data": 2, "prompt_id": "a1b2c3..."}, + {"messages": [...], "data": 3, "prompt_id": "d4e5f6..."}, + ] + """ + if not inputs: + return inputs + + prev_messages = inputs[0].get('messages') + current_uuid = str(uuid.uuid4()) + inputs[0]['prompt_id'] = current_uuid + + for i in range(1, len(inputs)): + messages = inputs[i]['messages'] + if messages == prev_messages: + inputs[i]['prompt_id'] = current_id + else: + prev_messages = messages + current_id = str(uuid.uuid4()) + inputs[i]['prompt_id'] = current_id + + return inputs + + def _server_rollout(self, inputs: DataType, request_config: RequestConfig, + is_global_inputs: bool) -> List[RolloutOutput]: + """ + Perform rollout inference using vLLM server mode. + + Args: + inputs: List of input data to be processed + request_config: Configuration dictionary for the inference request + is_global_inputs: Flag indicating whether inputs are shared across all processes (async-generate) + + Returns: + List of RolloutOutput objects containing inference results + For non-global inputs(async-generate), returns only the portion assigned to this process. 
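# Editor's note: standalone illustration (not part of the patch) of what
# _add_prompt_id_to_inputs does: adjacent samples with identical messages share one uuid,
# so rewards can later be grouped per prompt even when the number of generations per prompt
# varies. The helper below is a simplified restatement for the sake of a runnable example.
import uuid

def add_prompt_ids(inputs):
    prev_messages, current_id = None, None
    for item in inputs:
        if item['messages'] != prev_messages:
            prev_messages, current_id = item['messages'], str(uuid.uuid4())
        item['prompt_id'] = current_id
    return inputs

samples = [{'messages': [{'role': 'user', 'content': 'hello'}]},
           {'messages': [{'role': 'user', 'content': 'hello'}]},
           {'messages': [{'role': 'user', 'content': 'hi'}]}]
print([s['prompt_id'] for s in add_prompt_ids(samples)])  # first two ids are identical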
+ + Notes: + - With the async engine in multi-turn scenarios, the outputs count may exceed the inputs count + - For distributed inputs, outputs are scattered to processes + - Main process coordinates inference and broadcasts outputs to other processes + """ + # Convert inputs to inference requests + infer_requests = self.inputs2requests(inputs) + + if is_global_inputs: + per_device_size = len(infer_requests) // self.accelerator.num_processes + # for async generate, data has been pre-gathered to avoid potential communication ops + all_requests = infer_requests + all_requests_lengths = [per_device_size] + [0] * (self.accelerator.num_processes - 1) + else: + all_requests = gather_object(infer_requests) + all_requests_lengths = gather_object([len(infer_requests)]) + + if not any(requests for requests in all_requests): + return [] + + # TODO: Check flatten + if self.accelerator.is_main_process: + all_outputs: List[RolloutOutput] = self._engine_infer( + infer_requests=all_requests, request_config=request_config) + else: + # Placeholder list so broadcast_object_list sees a matching length on non-main processes + all_outputs = [None] * len(all_requests) + + # Handle the async engine case: the outputs count may exceed the inputs count + if self.vllm_use_async_engine: + outputs_count = [len(all_outputs)] if self.accelerator.is_main_process else [0] + outputs_count = gather_object(outputs_count)[0] # Broadcast count to all processes + + # Initialize empty outputs for non-main processes + if not self.accelerator.is_main_process: + all_outputs = [None] * outputs_count + + # Distribute outputs to all processes for non-global inputs + if not is_global_inputs: + all_outputs = broadcast_object_list(all_outputs, from_process=0) + + # Calculate slice for this process's outputs + if not self.vllm_use_async_engine and self.multi_turn_scheduler: + # Special handling for colocated + multi-turn inference with varying request counts + start_idx = sum(all_requests_lengths[:self.accelerator.process_index]) + end_idx = start_idx + all_requests_lengths[self.accelerator.process_index] + process_slice = slice(start_idx, end_idx) + outputs = all_outputs[process_slice] + else: + # Standard equal distribution case + outputs = self.get_even_process_rollout_outputs(all_outputs) + + else: + # For global inputs, only main process keeps outputs + outputs = all_outputs if self.accelerator.is_main_process else [] + + return outputs + + def _colocate_rollout(self, inputs: DataType, request_config: RequestConfig) -> List[RolloutOutput]: + """ + Perform co-located rollout inference with PTEngine or vLLMEngine (TP supported). + + Args: + inputs: Input data for the current process + request_config: Configuration parameters for the inference request + + Returns: + List[RolloutOutput]: Inference results for this process's portion of inputs + + Notes: + - For tensor parallel groups (vllm_tensor_parallel_size > 1): + * Gathers inputs from all ranks in the tensor parallel group + * Each rank processes the full input set but keeps only its assigned portion + * Ensures consistent seeds within TP groups for synchronization + - In single-process mode, directly processes the inputs + """ + # Handle tensor parallel group processing + if self.vllm_tensor_parallel_size > 1: + # Gather prompts from all ranks in the TP group and flatten. + # Each rank starts with its own prompts; after gathering, all ranks see the full group set. + # Note: The input sizes may differ across ranks (e.g., in multi-turn scenarios, + # the amount of data each rank continues to process may vary).
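# Editor's note: a standalone sketch (not part of the patch) of the tensor-parallel pattern
# described in the comments above: every rank in the TP group gathers the full prompt set,
# runs the same generation (hence the shared in-group seed), and keeps only the contiguous
# slice corresponding to its own contribution. `generate` and the sample lists are made up.
def generate(requests):  # stands in for self._engine_infer
    return [f'output({r})' for r in requests]

per_rank_inputs = [['r0-a', 'r0-b'], ['r1-a']]            # uneven local inputs across TP ranks
lengths = [len(x) for x in per_rank_inputs]               # all_gather_object of the lengths
flat_inputs = [x for rank in per_rank_inputs for x in rank]

for rank_idx in range(len(per_rank_inputs)):
    full_outputs = generate(flat_inputs)                  # every rank generates everything
    start = sum(lengths[:rank_idx])
    end = start + lengths[rank_idx]
    print(rank_idx, full_outputs[start:end])              # keep only this rank's share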
+ + # Step 1: Gather input lengths from all ranks in the TP group + local_rank_in_group = torch.distributed.get_rank(group=self.tp_group) + local_input_length = len(inputs) + all_input_lengths = [None] * self.vllm_tensor_parallel_size + torch.distributed.all_gather_object(all_input_lengths, local_input_length, group=self.tp_group) + + # Calculate slice indices for this rank's outputs + start_idx = sum(all_input_lengths[:local_rank_in_group]) + end_idx = start_idx + all_input_lengths[local_rank_in_group] + + # Step 2: Gather actual inputs from all TP group ranks + gathered_inputs = [None for _ in range(self.vllm_tensor_parallel_size)] + torch.distributed.all_gather_object(gathered_inputs, inputs, group=self.tp_group) + + # Flatten the gathered inputs + inputs = [p for sublist in gathered_inputs for p in sublist] + + # Critical seed configuration for TP groups: + # 1. Same seed within TP group - ensures synchronization and prevents hangs + # 2. Different seeds across TP groups - avoids duplicate generations + outputs: List[RolloutOutput] = self._engine_infer(infer_requests=inputs, request_config=request_config) + + # For TP groups, each rank keeps only its assigned portion of outputs + if self.vllm_tensor_parallel_size > 1: + outputs = outputs[start_idx:end_idx] + + return outputs + + def inputs2requests(self, inputs: DataType) -> List[RolloutInferRequest]: + """ + Convert raw input data into RolloutInferRequest objects with proper data processing. + + Args: + inputs: List of raw input dictionaries containing messages and multimedia data + + Returns: + List[RolloutInferRequest]: Processed inference request objects ready for engine + + Processing includes: + - Image data conversion (bytes to base64, path handling) + - Field filtering based on request metadata requirements + - Optional preservation of additional fields for multi-turn async scenarios + """ + + def _process_image_data(image_data: Union[dict, str]) -> str: + """Convert image data from various formats into standardized representation. 
+ + Args: + image_data: Either a dict with 'bytes' or 'path', or a direct string path + + Returns: + str: Base64 encoded image data or original file path + """ + if isinstance(image_data, dict): + if image_data.get('bytes'): + return base64.b64encode(image_data['bytes']).decode('utf-8') + if image_data.get('path'): + return image_data['path'] + return image_data + + if not inputs: + return [] + + # Define core metadata fields required for all requests + REQUEST_METADATA_FIELDS = ['messages', 'images', 'audios', 'videos', 'objects', 'uuid'] + requests_dicts = [] + + for data in inputs: + # Extract required metadata fields + request_data = {key: data[key] for key in REQUEST_METADATA_FIELDS if key in data} + + # Preserve additional fields for multi-turn async scenarios + if self.multi_turn_scheduler and self.vllm_use_async_engine: + # data_dict is already concatenated inside async engine + extra_fields = {k: v for k, v in data.items() if k not in REQUEST_METADATA_FIELDS} + if extra_fields: + request_data['data_dict'] = extra_fields + elif self.multi_turn_scheduler: + # Concatenate data_dict here + base_data_dict = {} + if 'data_dict' in data: + if isinstance(data['data_dict'], dict): + base_data_dict = data['data_dict'] + else: + raise ValueError('data_dict exists but is not a dictionary') + # Add fields that are not in metadata fields and not 'data_dict' + extra_data = {k: v for k, v in data.items() if k not in REQUEST_METADATA_FIELDS and k != 'data_dict'} + # Merge additional fields and existing data_dict + final_data_dict = {**extra_data, **base_data_dict} + request_data['data_dict'] = final_data_dict if final_data_dict else {} + + requests_dicts.append(request_data) + + # Process image data in each request + for request in requests_dicts: + if 'images' in request and request['images']: + request['images'] = ([_process_image_data(img) for img in request['images']] if isinstance( + request['images'], list) else _process_image_data(request['images'])) + + # Convert dictionaries to formal request objects + return [from_dict(RolloutInferRequest, request_data) for request_data in requests_dicts] + + def _preprocess_inputs(self, inputs: DataType) -> DataType: + """Preprocess input data before inference. + + Args: + inputs: List of input dictionaries containing conversation messages + + Returns: + Processed inputs with: + - Added prompt IDs for tracking + - Removed existing assistant responses from messages + + Processing Steps: + 1. Adds unique prompt IDs to each input for request tracking + 2. Cleans each message sequence by removing existing assistant responses + """ + processed_inputs = self._add_prompt_id_to_inputs(inputs) + + for input_item in processed_inputs: + remove_response(input_item['messages']) + + return processed_inputs + + def _postprocess_rollout_outputs(self, inputs: DataType, outputs: List[RolloutOutput]) -> DataType: + """ + Postprocess rollout outputs by merging them back into the input data structures. + + Depending on the mode (async or sync), it either matches inputs by UUID + or assumes a one-to-one correspondence. 
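# Editor's note: a runnable restatement (not part of the patch) of the image normalisation
# performed by _process_image_data before requests are serialized for the rollout server:
# raw bytes become a base64 string, path-style entries pass through unchanged. Sample values
# below are made up.
import base64

def process_image(image):
    if isinstance(image, dict):
        if image.get('bytes'):
            return base64.b64encode(image['bytes']).decode('utf-8')
        if image.get('path'):
            return image['path']
    return image

print(process_image({'bytes': b'\x89PNG...'}))  # 'iVBORy4uLg=='
print(process_image({'path': '/tmp/cat.png'}))  # '/tmp/cat.png'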
+ """ + + def merge_output_input_data(input_data: Dict[str, Union[torch.Tensor, Any]], output: RolloutOutput): + response = output.response + choice = response.choices[0] + + # Step 1: Update or append assistant message + if output.messages: + input_data['messages'] = output.messages # Override full message history + else: + # not provided, append + messages = input_data['messages'] + remove_response(messages) + messages.append({'role': 'assistant', 'content': choice.message.content}) + + input_data['messages'].append({'role': 'assistant', 'content': choice.message.content}) + + # Step 2: Add token IDs and loss mask + if output.response_token_ids: + input_data['response_token_ids'] = output.response_token_ids + if output.response_loss_mask: + input_data['response_loss_mask'] = output.response_loss_mask + else: + if not self.multi_turn_scheduler: + # for single turn, skip tokenizer response + input_data['response_token_ids'] = output.response.choices[0].token_ids + + # Step 3: Attach rollout extra info + if output.rollout_infos: + input_data['rollout_infos'] = output.rollout_infos + + # Step 4: Store finish reason (used for truncation filters etc.) + input_data['finish_reason'] = choice.finish_reason, + + return input_data + + # Async engine mode: match by UUID + if self.vllm_use_async_engine: + results = [] + id2inputs = {} + for input_data in inputs: + uuid = input_data['uuid'] + if uuid not in id2inputs: + id2inputs[uuid] = deepcopy(input_data) + + for output in outputs: + uuid = output.response.id + assert uuid not in id2inputs + input_data = deepcopy(id2inputs[uuid]) + results.append(merge_output_input_data(input_data, output)) + + return results + else: + # Sync mode: simple zip merge + assert len(inputs) == len(outputs) + return [ + merge_output_input_data(deepcopy(input_data), output) for input_data, output in zip(inputs, outputs) + ] + + def _sync_multi_turn_infer(self, inputs: DataType, first_turn_rollout_outputs: List[RolloutOutput], + request_config: RequestConfig) -> List[RolloutOutput]: + """ + Handles multi-turn inference when not using async engine. + + This method iteratively rolls out turns until all dialogues are finished + according to the multi_turn_scheduler. 
+ """ + orig_size = len(inputs) + rollout_outputs = [None] * orig_size # Preallocate to preserve order + + # Attach index to inputs for tracking + for i, input_data in enumerate(inputs): + input_data['index'] = i + + current_turn = 1 + outputs = first_turn_rollout_outputs + while True: + has_local_data = bool(len(inputs) > 0) + has_global_data = gather_object([has_local_data]) + if not any(has_global_data): + break + + for i, output in enumerate(outputs): + input_data = deepcopy(inputs[i]) + if output and output.messages: + messages = output.messages + else: + response = output.response + choice = response.choices[0] + messages = input_data['messages'] + if (current_turn == 1 or not messages[-1]['content'] or messages[-1]['content'] == ''): + remove_response(messages) + messages.append({'role': 'assistant', 'content': choice.message.content}) + + input_data['messages'] = messages + index = input_data['index'] + rollout_outputs[index] = output + rollout_outputs[index].messages = messages + + # Determine which dialogues are finished + should_stops = [ + self.multi_turn_scheduler.check_finished(req, output.response.choices[0], current_turn) + for req, output in zip(self.inputs_to_rolloutrequest(inputs), outputs) + ] + + # Prepare pending inputs for next turn + pending_inputs = [] + for stop, _input, output in zip(should_stops, inputs, outputs): + if stop: + continue + index = _input['index'] + step_result = self.multi_turn_scheduler.step( + self.inputs2requests([_input])[0], output.choices[0], current_turn) + + if step_result['response_token_ids']: + rollout_outputs[index].response_token_ids.append(step_result['response_token_ids']) + if step_result['response_loss_mask']: + rollout_outputs[index].response_loss_mask.append(step_result['response_loss_mask']) + + if step_result['rollout_infos']: + rollout_outputs[index].rollout_infos.update(step_result['rollout_infos']) + + pending_input = {**asdict(step_result['infer_request']), 'index': index} + pending_inputs.append(pending_input) + + inputs = pending_inputs + current_turn += 1 + + # Rollout for the next turn + outputs = self._rollout(inputs if has_local_data else [], request_config) + + assert all(o is not None for o in rollout_outputs) + return rollout_outputs + + def get_even_process_rollout_outputs(self, all_outputs: List[RolloutOutput]) -> List[RolloutOutput]: + """ + Evenly splits `all_outputs` among all processes. + + Each process receives a contiguous chunk of data. If `len(all_outputs)` is not + perfectly divisible by the number of processes, the first `remainder` processes + will receive one additional item. + + Args: + all_outputs (List[RolloutOutput]): The full list of outputs to be distributed. + + Returns: + List[RolloutOutput]: The subset of `all_outputs` assigned to this process. 
+ """ + num_procs = self.accelerator.num_processes + proc_idx = self.accelerator.process_index + total = len(all_outputs) + + base_size = total // num_procs + remainder = total % num_procs + + if proc_idx < remainder: + start = proc_idx * (base_size + 1) + end = start + base_size + 1 + else: + start = remainder * (base_size + 1) + (proc_idx - remainder) * base_size + end = start + base_size + + return all_outputs[start:end] diff --git a/swift/trainers/rlhf_trainer/vllm_client.py b/swift/trainers/rlhf_trainer/vllm_client.py index ae542d6446..e152abb13b 100644 --- a/swift/trainers/rlhf_trainer/vllm_client.py +++ b/swift/trainers/rlhf_trainer/vllm_client.py @@ -280,6 +280,8 @@ def close_communicator(self): def parse_resp_data(self, resp_data): choice_cls = ChatCompletionResponseChoice - result = [ChatCompletionResponse(choices=[from_dict(data_class=choice_cls, data=c) for c in resp['choices']])] + result = [ + ChatCompletionResponse(choices=[from_dict(data_class=choice_cls, data=c) for c in resp_data['choices']]) + ] return result diff --git a/swift/trainers/sequence_parallel/utils.py b/swift/trainers/sequence_parallel/utils.py index 96b20f1bfd..c07fa7e537 100644 --- a/swift/trainers/sequence_parallel/utils.py +++ b/swift/trainers/sequence_parallel/utils.py @@ -19,7 +19,7 @@ if TYPE_CHECKING: try: from ..rlhf_trainer import GRPOTrainer - from ..rlhf_trainer.grpo_trainer import InputsType + from ..rlhf_trainer.grpo_trainer import DataType except ImportError: pass # Conditional import for profiling decorator @@ -497,7 +497,7 @@ def _padding_free_output_hook(module, args, kwargs, result): def _get_per_token_logps_and_entropies_grpo( self: 'GRPOTrainer', model: torch.nn.Module, - inputs: 'InputsType', + inputs: 'DataType', sp_instance: SequenceParallel, compute_entropy: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Get per token logps for GRPO sequence parallel training""" diff --git a/swift/utils/__init__.py b/swift/utils/__init__.py index 9cc0aebd0b..0453913434 100644 --- a/swift/utils/__init__.py +++ b/swift/utils/__init__.py @@ -16,4 +16,5 @@ show_layers, time_synchronize, unwrap_model_for_generation) from .utils import (add_version_to_work_dir, check_json_format, copy_files_by_pattern, deep_getattr, find_free_port, format_time, get_env_args, import_external_file, json_parse_to_dict, lower_bound, parse_args, - patch_getattr, read_multi_line, seed_everything, split_list, subprocess_run, test_time, upper_bound) + patch_getattr, read_multi_line, remove_response, seed_everything, split_list, subprocess_run, + test_time, upper_bound) diff --git a/swift/utils/utils.py b/swift/utils/utils.py index 458921d536..64b433842c 100644 --- a/swift/utils/utils.py +++ b/swift/utils/utils.py @@ -366,3 +366,21 @@ def json_parse_to_dict(value: Union[str, Dict, None], strict: bool = True) -> Un logger.error(f"Unable to parse string: '{value}'") raise return value + + +def remove_response(messages) -> Optional[str]: + """ + Removes and returns the content of the last message if its role is 'assistant'. + + Args: + messages (List[Dict]): + A list of message dictionaries, each typically containing a 'role' and 'content' key. + + Returns: + Optional[str]: + The content of the removed 'assistant' message if present; + otherwise, returns None. The original messages list is modified in place. 
+ """ + last_role = messages[-1]['role'] if messages else None + if last_role == 'assistant': + return messages.pop()['content'] From 8555f31f3f2705e982c946e03bf200e79e42aa22 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 8 Aug 2025 16:10:06 +0800 Subject: [PATCH 15/26] rename completion id --- swift/trainers/rlhf_trainer/grpo_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 0186d815cc..e17b2f68b9 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -981,8 +981,8 @@ def _prepare_batch_inputs(self, inputs: DataType, rewards: torch.Tensor) -> List # Encode and process each batch (size=bs) with self._template_context(template): processed_batch = [ - replace_assistant_response_with_ids(data['messages'], data['completion_ids']) - if 'completion_ids' in data else data for data in batch + replace_assistant_response_with_ids(data['messages'], data['response_token_ids']) + if 'response_token_ids' in data and data['response_token_ids'] else data for data in batch ] batch_encoded_inputs = [template.encode(data) for data in processed_batch] batch_encoded_inputs = to_device(template.data_collator(batch_encoded_inputs), self.model.device) From 4c0ada946e5937ac1aaf58e5e657037d63999326 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Fri, 8 Aug 2025 17:12:03 +0800 Subject: [PATCH 16/26] fix typo & bugs --- swift/plugin/multi_turn.py | 5 +++-- swift/plugin/orm.py | 8 ++++---- swift/trainers/rlhf_trainer/grpo_trainer.py | 6 +++--- swift/trainers/rlhf_trainer/vllm_client.py | 12 ++---------- 4 files changed, 12 insertions(+), 19 deletions(-) diff --git a/swift/plugin/multi_turn.py b/swift/plugin/multi_turn.py index 5aa2c4317f..281a3de11c 100644 --- a/swift/plugin/multi_turn.py +++ b/swift/plugin/multi_turn.py @@ -181,7 +181,7 @@ async def run(self, *args, **kwargs): return RolloutOutput( response=response, messages=messages, - response_id=total_response_ids, + respones_token_ids=total_response_ids, response_loss_mask=total_response_loss_mask, rollout_infos=rollout_infos, ) @@ -334,7 +334,8 @@ def check_finished(self, infer_request: 'RolloutInferRequest', response_choice: def step(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompletionResponseChoice', current_turn: int) -> Dict: infer_request.messages.append({'role': 'user', 'content': self.tips_prompt}) - return infer_request + + return {'infer_request': infer_request} class GYMScheduler(RolloutScheduler): diff --git a/swift/plugin/orm.py b/swift/plugin/orm.py index 9b5e4947b7..d8f2b30042 100644 --- a/swift/plugin/orm.py +++ b/swift/plugin/orm.py @@ -321,9 +321,9 @@ def cosfn(t, T, min_value, max_value): def __call__(self, completions, solution, **kwargs) -> List[float]: acc_rewards = self.accuracy_orm(completions, solution, **kwargs) - completion_ids = kwargs.get('completion_ids') + response_token_ids = kwargs.get('response_token_ids') rewards = [] - for ids, acc_reward in zip(completion_ids, acc_rewards): + for ids, acc_reward in zip(response_token_ids, acc_rewards): is_correct = acc_reward >= 1. 
if is_correct: # Swap min/max for correct answers @@ -386,8 +386,8 @@ def __init__(self, soft_max_length, soft_cache_length): def __call__(self, completions, **kwargs) -> List[float]: rewards = [] - completion_ids = kwargs.get('completion_ids') - for ids in completion_ids: + response_token_ids = kwargs.get('response_token_ids') + for ids in response_token_ids: completion_length = len(ids) expected_len = self.soft_max_length - self.soft_cache_length exceed_len = completion_length - expected_len diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 0cea189337..7932f1970e 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -1700,8 +1700,8 @@ def _add_prompt_id_to_inputs(self, inputs: DataType) -> DataType: return inputs prev_messages = inputs[0].get('messages') - current_uuid = str(uuid.uuid4()) - inputs[0]['prompt_id'] = current_uuid + current_id = str(uuid.uuid4()) + inputs[0]['prompt_id'] = current_id for i in range(1, len(inputs)): messages = inputs[i]['messages'] @@ -1970,7 +1970,7 @@ def merge_output_input_data(input_data: Dict[str, Union[torch.Tensor, Any]], out input_data['rollout_infos'] = output.rollout_infos # Step 4: Store finish reason (used for truncation filters etc.) - input_data['finish_reason'] = choice.finish_reason, + input_data['finish_reason'] = choice.finish_reason return input_data diff --git a/swift/trainers/rlhf_trainer/vllm_client.py b/swift/trainers/rlhf_trainer/vllm_client.py index e152abb13b..f88bb074ce 100644 --- a/swift/trainers/rlhf_trainer/vllm_client.py +++ b/swift/trainers/rlhf_trainer/vllm_client.py @@ -16,7 +16,7 @@ from transformers.utils import is_torch_cuda_available from swift.llm import AdapterRequest, RolloutInferRequest, Template -from swift.llm.infer.protocol import ChatCompletionResponse, ChatCompletionResponseChoice, RequestConfig +from swift.llm.infer.protocol import RequestConfig, RolloutOutput from swift.plugin import Metric from swift.utils import is_trl_available, is_vllm_ascend_available, is_vllm_available @@ -150,7 +150,7 @@ def process_chunk(i, chunk): return resp_data = response.json() - results[i] = self.parse_resp_data(resp_data) + results[i] = [from_dict(data_class=RolloutOutput, data=resp) for resp in resp_data] except Exception as e: errors[i] = e @@ -277,11 +277,3 @@ def close_communicator(self): logger.warning(f'Server {i} close failed: {response.text}') except Exception as e: logger.warning(f'Error closing server {i} communicator: {str(e)}') - - def parse_resp_data(self, resp_data): - choice_cls = ChatCompletionResponseChoice - result = [ - ChatCompletionResponse(choices=[from_dict(data_class=choice_cls, data=c) for c in resp_data['choices']]) - ] - - return result From b9e4c042beaf678c1d1d1ccc4871864d3ac7a216 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 11 Aug 2025 00:37:57 +0800 Subject: [PATCH 17/26] compute loss for dynamic batch size --- swift/plugin/multi_turn.py | 6 +- swift/trainers/rlhf_trainer/grpo_trainer.py | 516 +++++++++++++++++--- 2 files changed, 465 insertions(+), 57 deletions(-) diff --git a/swift/plugin/multi_turn.py b/swift/plugin/multi_turn.py index 281a3de11c..5120fe0c9b 100644 --- a/swift/plugin/multi_turn.py +++ b/swift/plugin/multi_turn.py @@ -16,7 +16,11 @@ class RolloutScheduler(ABC): # Single Turn Rollout Scheduler - def __init__(self, infer_engine: 'GRPOVllmEngine', max_turns: Optional[int] = None, *args, **kwargs): + def __init__(self, + infer_engine: Optional['GRPOVllmEngine'] = 
None, + max_turns: Optional[int] = None, + *args, + **kwargs): self.infer_engine = infer_engine self.max_turns = max_turns diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 7932f1970e..cef25e09cd 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -15,7 +15,7 @@ from math import ceil from queue import Queue from types import MethodType -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union import torch import torch.nn as nn @@ -65,6 +65,7 @@ import swanlab DataType = List[Dict[str, Union[torch.Tensor, Any]]] +T = TypeVar('T') # patch to fix save last_checkpoint https://github.com/modelscope/ms-swift/pull/4969 if not hasattr(RepeatSampler, 'old_len_func'): @@ -173,7 +174,7 @@ def __init__(self, multi_turn_scheduler = multi_turns[self.args.multi_turn_scheduler](max_turns=self.args.max_turns) self.multi_turn_scheduler: MultiTurnScheduler = multi_turn_scheduler else: - assert isinstance(multi_turn_scheduler, MultiTurnScheduler) + assert isinstance(self.args.multi_turn_scheduler, MultiTurnScheduler) self.multi_turn_scheduler: MultiTurnScheduler = self.args.multi_turn_scheduler self.num_generations = args.num_generations @@ -657,13 +658,14 @@ def _set_inputs_system(self, inputs: DataType) -> DataType: DataType: The input list, with the default system message prepended if applicable. """ if not self.template.template_meta.default_system: - return + return inputs if all(_input['messages'][0]['role'] == 'system' for _input in inputs): - return + return inputs for _input in inputs: messages = _input['messages'] if messages[0]['role'] != 'system': messages.insert(0, {'role': 'system', 'content': self.template.template_meta.default_system}) + return inputs def _infer_single_or_multi_turn(self, inputs: DataType, @@ -683,7 +685,7 @@ def _infer_single_or_multi_turn(self, # for external server, pass the system args which may define in trainer # Step 1: Prepare inputs with system prompts (if any) - self._set_inputs_system(inputs) + inputs = self._set_inputs_system(inputs) # Step 2: First-turn rollout rollout_outputs: List[RolloutOutput] = self._rollout(inputs, request_config, is_global_inputs) @@ -793,6 +795,7 @@ def _fast_infer(self, inputs: DataType) -> DataType: return outputs def _generate_completions(self, inputs: DataType) -> DataType: + # add prompt ids and system prompts inputs = self._preprocess_inputs(inputs) mode = 'train' if self.model.training else 'eval' @@ -825,17 +828,43 @@ def _generate_and_score_completions(self, inputs: DataType) -> DataType: self._dynamic_sampling(inputs, total_rewards, total_rewards_per_func, completions) # Prepare final outputs with advantages and other required fields - batch_encoded_inputs = self._prepare_batch_inputs(inputs, total_rewards) + inputs = self._calculate_advantages(inputs, total_rewards) + + batch_encoded_inputs = self._prepare_batch_inputs(inputs) # Log metrics messages = [inputs[i]['messages'][:-1] for i in range(len(inputs))] - trajectory_infos = None - if self.use_gym_env: - trajectory_infos = [inputs[i]['trajectory_info'] for i in range(len(inputs))] - self._log_metrics(batch_encoded_inputs, messages, completions, total_rewards, total_rewards_per_func, - trajectory_infos) + + self._log_metrics(batch_encoded_inputs, messages, completions, total_rewards, total_rewards_per_func) return batch_encoded_inputs + def _calculate_advantages(self, 
inputs: DataType, total_rewards: torch.Tensor) -> DataType: + advantages = self._compute_advantages(inputs, total_rewards) + # Legacy method (kept for backward compatibility) + legacy_advantages = self._compute_advantages_legacy(inputs, total_rewards) + + # DEBUG + assert torch.allclose(advantages, legacy_advantages) + try: + self._validate_advantage_calculation(inputs, total_rewards, advantages) + except Exception as e: + logger.warning(f'Advantage validation failed: {e}') + + # log advantages and image(for VL models) + self._logs['advantages'].extend(advantages.tolist()) + if any('images' in data and data['images'] is not None for data in inputs): + self._logs['image'].extend(gather_object([inp['images'] for inp in inputs])) + + # get local advantages + local_advantages = self.get_even_process_data(advantages) + assert len(local_advantages) == len(inputs) + + # merge advantages to inputs + for i, advantage in enumerate(local_advantages): + inputs[i]['advantage'] = advantage + + return inputs + def _score_completions(self, inputs: DataType) -> Tuple[torch.Tensor, torch.Tensor, List[str]]: """Score completions using all reward functions @@ -935,55 +964,312 @@ def _dynamic_sampling(self, inputs, rewards, rewards_per_func, completions): return inputs, rewards, rewards_per_func, completions - def split_by_mini_batches(self, inputs, advantages): + def split_by_mini_batches(self, inputs): + """ + Split inputs into mini-batches, handling variable generation counts. + + When rollout count differs from expected (bs * spg * num_generations), + we need to adjust the splitting logic to maintain proper batch sizes. + """ # Slice to keep only the local part of the data - process_slice = slice( - self.accelerator.process_index * len(inputs), - (self.accelerator.process_index + 1) * len(inputs), - ) - advantages = advantages[process_slice] mode = 'train' if self.model.training else 'eval' bs = self.args.per_device_train_batch_size if mode == 'train' else self.args.per_device_eval_batch_size spg = self.args.steps_per_generation if mode == 'train' else 1 + # TODO: Check + expected_normal_size = bs * spg + + # Check if we have the expected number of inputs (normal case) + if len(inputs) == expected_normal_size: + # Normal case: rollout returned expected count + # Group by (bs * num_generations) to maintain proper prompt grouping + group_size = bs * self.num_generations + spg_chunks = [inputs[i * group_size:(i + 1) * group_size] for i in range(spg)] + else: + # Variable generation case: split by actual per_device_batch_size to control memory + # Split into chunks of size bs to maintain memory efficiency + num_chunks = (len(inputs) + bs - 1) // bs # Ceiling division + spg_chunks = [] - assert len(inputs) == bs * spg, f'Expected {bs * spg} inputs, got {len(inputs)}' - spg_chunks = [inputs[i * bs:(i + 1) * bs] for i in range(spg)] - # Split advantages by spg chunks - advantage_chunks = torch.chunk(advantages, spg) - return spg_chunks, advantage_chunks + for i in range(num_chunks): + start_idx = i * bs + end_idx = min((i + 1) * bs, len(inputs)) + spg_chunks.append(inputs[start_idx:end_idx]) - def _prepare_batch_inputs(self, inputs: DataType, rewards: torch.Tensor) -> List[DataType]: + return spg_chunks + + def _compute_advantages(self, inputs: DataType, rewards: torch.Tensor) -> torch.Tensor: """ - Prepare the final batch inputs with advantages, ref/old_policy logps and other fields for RL training. + Compute advantages based on prompt_id grouping, handling variable generation counts per prompt. 
Args: - inputs (DataType): List of input samples. Original shape is [spg*bs] where: - - spg: steps_per_generation - - bs: per-device batch size - rewards (torch.Tensor): Tensor of global rewards corresponding to the inputs. - Shape should match the total number of samples (spg*bs*num_processes*num_generations) + inputs (DataType): List of input samples with prompt_id fields + rewards (torch.Tensor): Tensor of rewards corresponding to the inputs Returns: - List[DataType]: A list of prepared batch inputs, organized as [spg][bs] + torch.Tensor: Computed advantages with same shape as rewards """ - # Compute advantages - grouped_rewards = rewards.view(-1, self.num_generations) - mean_grouped_rewards = grouped_rewards.mean(dim=1).repeat_interleave(self.num_generations, dim=0) - std_grouped_rewards = grouped_rewards.std(dim=1).repeat_interleave(self.num_generations, dim=0) + if len(inputs) != len(rewards): + raise ValueError(f'Inputs length ({len(inputs)}) != rewards length ({len(rewards)})') + + # Ensure all inputs have prompt_id + if not all('prompt_id' in inp for inp in inputs): + logger.warning('Some inputs missing prompt_id, adding them...') + inputs = self._add_prompt_id_to_inputs(inputs) + + # Group rewards by prompt_id + prompt_groups = {} + for i, inp in enumerate(inputs): + prompt_id = inp['prompt_id'] + if prompt_id not in prompt_groups: + prompt_groups[prompt_id] = [] + prompt_groups[prompt_id].append((i, rewards[i].item())) + + # Compute advantages for each group + advantages = torch.zeros_like(rewards) + + for prompt_id, group_data in prompt_groups.items(): + indices, group_rewards = zip(*group_data) + group_rewards = torch.tensor(group_rewards, device=rewards.device, dtype=rewards.dtype) + + group_mean = group_rewards.mean() + group_advantages = group_rewards - group_mean + + # Optional: scale by standard deviation + if self.args.scale_rewards: + group_std = group_rewards.std() + group_advantages /= (group_std + 1e-4) + + # Assign computed advantages back to original positions + for idx, advantage in zip(indices, group_advantages): + advantages[idx] = advantage + + # Check for groups with unexpected generation counts + generation_counts = [len(group_data) for group_data in prompt_groups.values()] + if generation_counts and (min(generation_counts) != max(generation_counts)): + logger.warning(f'Variable generation counts detected: min={min(generation_counts)}, ' + f'max={max(generation_counts)}, expected={self.num_generations}') + + return advantages + + def _compute_advantages_legacy(self, inputs: DataType, rewards: torch.Tensor) -> torch.Tensor: + """ + Legacy advantage computation method using reshape. - advantages = (rewards - mean_grouped_rewards) - if self.args.scale_rewards: - advantages /= (std_grouped_rewards + 1e-4) - self._logs['advantages'].extend(gather(advantages).tolist()) - if any('images' in data and data['images'] is not None for data in inputs): - self._logs['image'].extend(gather_object([inp['images'] for inp in inputs])) + This method assumes a fixed number of generations per prompt (self.num_generations). + Kept for backward compatibility, but may fail if rollout count differs from num_generations. 
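# Editor's note: a worked numeric example (not part of the patch) of the grouped advantage
# computation in _compute_advantages: rewards are grouped by prompt_id and each sample's
# advantage is its reward minus the group mean (with scale_rewards, also divided by the group
# std + 1e-4). The grouping works even when prompts have different numbers of generations.
import torch

rewards = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
prompt_ids = ['p1', 'p1', 'p1', 'p2', 'p2']   # 3 generations for p1, 2 for p2

advantages = torch.zeros_like(rewards)
for pid in set(prompt_ids):
    idx = [i for i, p in enumerate(prompt_ids) if p == pid]
    group = rewards[idx]
    advantages[idx] = group - group.mean()
print(advantages)  # p1 -> [-1, 0, 1], p2 -> [-0.5, 0.5]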
+ + Args: + inputs: List of input samples (not used in legacy method) + rewards: Tensor of rewards + + Returns: + torch.Tensor: Computed advantages + + Raises: + RuntimeError: If rewards cannot be reshaped to (-1, num_generations) + """ + try: + # Original logic - assumes fixed num_generations per prompt + grouped_rewards = rewards.view(-1, self.num_generations) + mean_grouped_rewards = grouped_rewards.mean(dim=1).repeat_interleave(self.num_generations, dim=0) + std_grouped_rewards = grouped_rewards.std(dim=1).repeat_interleave(self.num_generations, dim=0) + + advantages = (rewards - mean_grouped_rewards) + if self.args.scale_rewards: + advantages /= (std_grouped_rewards + 1e-4) + + logger.debug(f'Legacy advantage computation: {len(rewards)} rewards, ' + f'{self.num_generations} generations per prompt') + + return advantages + + except RuntimeError as e: + raise + + def _validate_advantage_calculation(self, inputs: DataType, rewards: torch.Tensor, + advantages: torch.Tensor) -> None: + """ + Validate the computed advantages for correctness and consistency. + + This method performs several checks: + 1. Ensures advantages sum to ~0 within each prompt group + 2. Verifies advantage shapes match input/reward shapes + 3. Checks for NaN or infinite values + + Args: + inputs: Original input data with prompt_id fields + rewards: Original reward tensor + advantages: Computed advantage tensor + + Raises: + ValueError: If validation fails + """ + if len(advantages) != len(rewards) or len(advantages) != len(inputs): + raise ValueError(f'Shape mismatch: advantages={len(advantages)}, ' + f'rewards={len(rewards)}, inputs={len(inputs)}') + + # Check for NaN or infinite values + if torch.isnan(advantages).any(): + nan_count = torch.isnan(advantages).sum().item() + logger.warning(f'Found {nan_count} NaN values in advantages') + + if torch.isinf(advantages).any(): + inf_count = torch.isinf(advantages).sum().item() + logger.warning(f'Found {inf_count} infinite values in advantages') + + # Verify advantages sum to ~0 within each prompt group + prompt_groups = {} + for i, inp in enumerate(inputs): + prompt_id = inp['prompt_id'] + if prompt_id not in prompt_groups: + prompt_groups[prompt_id] = [] + prompt_groups[prompt_id].append(advantages[i].item()) + + tolerance = 1e-6 + problematic_groups = [] + for prompt_id, group_advantages in prompt_groups.items(): + if len(group_advantages) > 1: # Only check groups with multiple generations + group_sum = sum(group_advantages) + if abs(group_sum) > tolerance: + problematic_groups.append((prompt_id, group_sum, len(group_advantages))) + + if problematic_groups: + logger.warning(f'Found {len(problematic_groups)} prompt groups where advantages ' + f"don't sum to ~0 (tolerance={tolerance})") + for prompt_id, group_sum, count in problematic_groups[:5]: # Log first 5 + logger.warning(f' Group {prompt_id}: sum={group_sum:.8f}, count={count}') + + def _test_advantage_calculation_edge_cases(self) -> None: + """ + Test the advantage calculation with various edge cases. + This method is useful for debugging and validation. 
+ """ + logger.info('Testing advantage calculation edge cases...') + + # Test case 1: Different generation counts per prompt + test_inputs_1 = [ + { + 'prompt_id': 'prompt_1', + 'messages': [{ + 'role': 'user', + 'content': 'hello' + }] + }, + { + 'prompt_id': 'prompt_1', + 'messages': [{ + 'role': 'user', + 'content': 'hello' + }] + }, + { + 'prompt_id': 'prompt_1', + 'messages': [{ + 'role': 'user', + 'content': 'hello' + }] + }, # 3 generations + { + 'prompt_id': 'prompt_2', + 'messages': [{ + 'role': 'user', + 'content': 'hi' + }] + }, + { + 'prompt_id': 'prompt_2', + 'messages': [{ + 'role': 'user', + 'content': 'hi' + }] + }, # 2 generations + { + 'prompt_id': 'prompt_3', + 'messages': [{ + 'role': 'user', + 'content': 'hey' + }] + }, # 1 generation + ] + test_rewards_1 = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], device=self.accelerator.device) + + try: + advantages_1 = self._compute_advantages(test_inputs_1, test_rewards_1) + logger.info(f'Test 1 passed: advantages shape {advantages_1.shape}') + + # Verify that advantages sum to ~0 for each group + assert abs(advantages_1[0] + advantages_1[1] + advantages_1[2]) < 1e-6, 'Group 1 advantages should sum to 0' + assert abs(advantages_1[3] + advantages_1[4]) < 1e-6, 'Group 2 advantages should sum to 0' + assert advantages_1[5] == 0.0, 'Single generation group should have 0 advantage' + + except Exception as e: + logger.error(f'Test 1 failed: {e}') + + # Test case 2: All same generation counts (should match original behavior) + test_inputs_2 = [ + { + 'prompt_id': 'prompt_a', + 'messages': [{ + 'role': 'user', + 'content': 'test1' + }] + }, + { + 'prompt_id': 'prompt_a', + 'messages': [{ + 'role': 'user', + 'content': 'test1' + }] + }, + { + 'prompt_id': 'prompt_b', + 'messages': [{ + 'role': 'user', + 'content': 'test2' + }] + }, + { + 'prompt_id': 'prompt_b', + 'messages': [{ + 'role': 'user', + 'content': 'test2' + }] + }, + ] + test_rewards_2 = torch.tensor([1.0, 3.0, 2.0, 4.0], device=self.accelerator.device) + + try: + advantages_2 = self._compute_advantages(test_inputs_2, test_rewards_2) + logger.info(f'Test 2 passed: advantages shape {advantages_2.shape}') + # Check that advantages are correctly computed + expected_adv_a = torch.tensor([-1.0, 1.0]) # (1-2, 3-2) + expected_adv_b = torch.tensor([-1.0, 1.0]) # (2-3, 4-3) + + assert torch.allclose(advantages_2[0:2], expected_adv_a, atol=1e-6), 'Group A advantages incorrect' + assert torch.allclose(advantages_2[2:4], expected_adv_b, atol=1e-6), 'Group B advantages incorrect' + + except Exception as e: + logger.error(f'Test 2 failed: {e}') + + logger.info('Advantage calculation edge case testing completed') + + def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]: + """ + Prepare the final batch inputs with advantages, ref/old_policy logps and other fields for RL training. + + Args: + inputs (DataType): List of local input samples. 
+ + Returns: + List[DataType]: A list of prepared batch inputs, organized as [spg][bs] + """ template = self.template - gas_chunks, advantage_chunks = self.split_by_mini_batches(inputs, advantages) + gas_chunks, _ = self.split_by_mini_batches(inputs) ga_batch_encoded_inputs = [] - for i, (batch, batch_advantages) in enumerate(zip(gas_chunks, advantage_chunks)): + for i, batch in enumerate(gas_chunks): # Encode and process each batch (size=bs) with self._template_context(template): processed_batch = [ @@ -996,6 +1282,7 @@ def _prepare_batch_inputs(self, inputs: DataType, rewards: torch.Tensor) -> List # Process labels and masks labels = batch_encoded_inputs.pop('labels') logits_to_keep = (labels.shape[-1] - (torch.ne(labels, -100).int().argmax(-1))).max().item() + batch_encoded_inputs.update({ 'completion_mask': labels[:, -logits_to_keep:] != -100, @@ -1003,8 +1290,6 @@ def _prepare_batch_inputs(self, inputs: DataType, rewards: torch.Tensor) -> List torch.tensor([b['is_truncated'] for b in batch], dtype=torch.bool), 'logits_to_keep': logits_to_keep, - 'advantages': - batch_advantages }) with torch.no_grad(): @@ -1107,6 +1392,21 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N def _compute_loss(self, model, inputs): mode = 'train' if self.model.training else 'eval' + # Check batch size and decide processing strategy + batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs else len(inputs.get('completion_mask', [])) + expected_bs = self.args.per_device_train_batch_size if mode == 'train' else self.args.per_device_eval_batch_size + + # If batch size matches expected, use normal processing + if batch_size == expected_bs: + return self._compute_loss_single(model, inputs) + else: + assert batch_size > expected_bs + return self._compute_loss_chunked(model, inputs) + + def _compute_loss_single(self, model, inputs): + """Original loss computation logic for single batch processing.""" + mode = 'train' if self.model.training else 'eval' + completion_mask = inputs['completion_mask'] truncated_mask = inputs['truncated_mask'] @@ -1137,8 +1437,7 @@ def _compute_loss(self, model, inputs): logger.info('All completions are overlong and truncated, ' 'resulting in NaN some values for some metrics (e.g., KL)') truncated_mask = truncated_mask.unsqueeze(-1).expand_as(completion_mask).to(completion_mask.device) - completion_mask = completion_mask * (~truncated_mask) - + completion_mask.mul_(~truncated_mask) # Compute the KL divergence between the model and the reference model if self.beta != 0.0: ref_per_token_logps = inputs['ref_per_token_logps'] @@ -1219,6 +1518,42 @@ def masked_batch_mean(x): return loss + def _compute_loss_chunked(self, model, inputs): + mode = 'train' if self.model.training else 'eval' + chunk_size = self.args.per_device_train_batch_size if mode == 'train' else self.args.per_device_eval_batch_size + batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs else len(inputs.get('completion_mask', [])) + + logger.debug(f'Computing chunked loss for batch size {batch_size} with chunk size {chunk_size}') + + all_losses = [] + + # TODO: Aggregate metrics across chunks + # aggregated_metrics = {} + + for i in range(0, batch_size, chunk_size): + end_idx = min(i + chunk_size, batch_size) + + # Create chunk inputs + chunk_inputs = {} + for key, value in inputs.items(): + if isinstance(value, torch.Tensor): + chunk_inputs[key] = value[i:end_idx] + else: + chunk_inputs[key] = value + + # Compute loss for this chunk + chunk_loss = 
self._compute_loss_single(model, chunk_inputs) + + all_losses.append(chunk_loss) + + # Compute average loss + final_loss_tensor = torch.stack(all_losses).mean() + + logger.debug(f'Chunked loss computation completed: {len(all_losses)} chunks -> ' + f'final loss {final_loss_tensor.item():.6f}') + + return final_loss_tensor + @contextmanager def padding_free_context(self, model: torch.nn.Module): ctx = {} @@ -1297,6 +1632,27 @@ def _get_per_token_logps_and_entropies(self, model, inputs, compute_entropy=False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Compute per-token log probabilities and entropies with memory-efficient batching. + + When rollout count is larger than expected, we process in smaller batches + to control memory usage. + """ + # Check if we need to use memory-efficient batching + batch_size = inputs['input_ids'].shape[0] + mode = 'train' if self.model.training else 'eval' + expected_bs = self.args.per_device_train_batch_size if mode == 'train' else self.args.per_device_eval_batch_size + + # If batch is larger than threshold and adaptive batching is enabled, use chunked processing + if batch_size > expected_bs: + return self._get_per_token_logps_and_entropies_chunked(model, inputs, compute_entropy) + else: + return self._get_per_token_logps_and_entropies_single(model, inputs, compute_entropy) + + def _get_per_token_logps_and_entropies_single(self, + model, + inputs, + compute_entropy=False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: logits_to_keep = inputs['logits_to_keep'] input_ids = inputs['input_ids'] unwrapped_model = self.accelerator.unwrap_model(model) @@ -1341,6 +1697,54 @@ def _get_per_token_logps_and_entropies(self, return logps, entropies + def _get_per_token_logps_and_entropies_chunked(self, + model, + inputs, + compute_entropy=False + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Memory-efficient chunked processing for large batches. + + Splits the batch into smaller chunks based on per_device_batch_size + to control memory usage when rollout count is larger than expected. 
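
# Generic sketch of the chunking pattern used by the chunked loss / logps paths
# above: tensor entries of the inputs dict are sliced per chunk, non-tensor
# entries (scalars such as logits_to_keep) are passed through unchanged, and the
# per-chunk results are concatenated. The helper name run_in_chunks is
# illustrative only, not part of ms-swift.
from typing import Any, Callable, Dict, List
import torch

def run_in_chunks(inputs: Dict[str, Any], chunk_size: int,
                  fn: Callable[[Dict[str, Any]], torch.Tensor]) -> torch.Tensor:
    batch_size = inputs['input_ids'].shape[0]
    outputs: List[torch.Tensor] = []
    for start in range(0, batch_size, chunk_size):
        end = min(start + chunk_size, batch_size)
        chunk = {k: (v[start:end] if isinstance(v, torch.Tensor) else v)
                 for k, v in inputs.items()}
        outputs.append(fn(chunk))
    return torch.cat(outputs, dim=0)

# e.g. run_in_chunks({'input_ids': torch.arange(12).view(6, 2)}, 2,
#                    lambda c: c['input_ids'].float().sum(dim=-1))
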
+ """ + batch_size = inputs['input_ids'].shape[0] + mode = 'train' if self.model.training else 'eval' + chunk_size = self.args.per_device_train_batch_size if mode == 'train' else self.args.per_device_eval_batch_size + + logger.debug(f'Processing batch of size {batch_size} in chunks of {chunk_size}') + + all_logps = [] + all_entropies = [] if compute_entropy else None + + # Process in chunks + for i in range(0, batch_size, chunk_size): + end_idx = min(i + chunk_size, batch_size) + + # Create chunk inputs + chunk_inputs = {} + for key, value in inputs.items(): + if isinstance(value, torch.Tensor): + chunk_inputs[key] = value[i:end_idx] + else: + chunk_inputs[key] = value # Non-tensor values (like logits_to_keep) are scalars + + # Process this chunk + chunk_logps, chunk_entropies = self._get_per_token_logps_and_entropies_single( + model, chunk_inputs, compute_entropy) + + all_logps.append(chunk_logps) + if compute_entropy and chunk_entropies is not None: + all_entropies.append(chunk_entropies) + + # Concatenate results + final_logps = torch.cat(all_logps, dim=0) + final_entropies = torch.cat(all_entropies, dim=0) if all_entropies else None + + logger.debug(f'Chunked processing completed: {len(all_logps)} chunks -> ' f'final shape {final_logps.shape}') + + return final_logps, final_entropies + @patch_profiling_decorator def _get_last_hidden_state(self, unwrapped_model, inputs, logits_to_keep): # unwrap the model to access the model.model @@ -1775,7 +2179,7 @@ def _server_rollout(self, inputs: DataType, request_config: RequestConfig, outputs = outputs[process_slice] else: # Standard equal distribution case - outputs = self.get_even_process_rollout_outputs(all_outputs) + outputs = self.get_even_process_data(all_outputs) else: # For global inputs, only main process keeps outputs @@ -2006,7 +2410,7 @@ def _sync_multi_turn_infer(self, inputs: DataType, first_turn_rollout_outputs: L according to the multi_turn_scheduler. """ orig_size = len(inputs) - rollout_outputs = [None] * orig_size # Preallocate to preserve order + rollout_outputs: List[RolloutOutput] = [None] * orig_size # Preallocate to preserve order # Attach index to inputs for tracking for i, input_data in enumerate(inputs): @@ -2050,7 +2454,7 @@ def _sync_multi_turn_infer(self, inputs: DataType, first_turn_rollout_outputs: L continue index = _input['index'] step_result = self.multi_turn_scheduler.step( - self.inputs2requests([_input])[0], output.choices[0], current_turn) + self.inputs2requests([_input])[0], output.response.choices[0], current_turn) if step_result['response_token_ids']: rollout_outputs[index].response_token_ids.append(step_result['response_token_ids']) @@ -2072,23 +2476,23 @@ def _sync_multi_turn_infer(self, inputs: DataType, first_turn_rollout_outputs: L assert all(o is not None for o in rollout_outputs) return rollout_outputs - def get_even_process_rollout_outputs(self, all_outputs: List[RolloutOutput]) -> List[RolloutOutput]: + def get_even_process_data(self, global_data: List[T]) -> List[T]: """ - Evenly splits `all_outputs` among all processes. + Evenly splits `global_data` among all processes. - Each process receives a contiguous chunk of data. If `len(all_outputs)` is not + Each process receives a contiguous chunk of data. If `len(global_data)` is not perfectly divisible by the number of processes, the first `remainder` processes will receive one additional item. Args: - all_outputs (List[RolloutOutput]): The full list of outputs to be distributed. + global_data (List[T]): The full list of data to be distributed. 
Returns: - List[RolloutOutput]: The subset of `all_outputs` assigned to this process. + List[T]: The subset of `global_data` assigned to this process. """ num_procs = self.accelerator.num_processes proc_idx = self.accelerator.process_index - total = len(all_outputs) + total = len(global_data) base_size = total // num_procs remainder = total % num_procs @@ -2100,4 +2504,4 @@ def get_even_process_rollout_outputs(self, all_outputs: List[RolloutOutput]) -> start = remainder * (base_size + 1) + (proc_idx - remainder) * base_size end = start + base_size - return all_outputs[start:end] + return global_data[start:end] From eea448535426268ef8e417074666f4f54adeeaec Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 11 Aug 2025 10:27:58 +0800 Subject: [PATCH 18/26] fix tiny bugs --- swift/trainers/rlhf_trainer/grpo_trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index c24b98b954..1ae61c397c 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -35,7 +35,7 @@ from swift.llm import (InferRequest, MultiModelKeys, RequestConfig, RolloutInferRequest, RowPreprocessor, Template, to_device) -from swift.llm.infer.protocol import ChatCompletionResponse +from swift.llm.infer.protocol import ChatCompletionResponse, RolloutOutput from swift.llm.model.utils import get_llm_model from swift.llm.template.base import MaxLengthError from swift.llm.template.template_inputs import StdTemplateInputs @@ -1084,7 +1084,7 @@ def _compute_advantages_legacy(self, inputs: DataType, rewards: torch.Tensor) -> return advantages - except RuntimeError as e: + except RuntimeError: raise def _validate_advantage_calculation(self, inputs: DataType, rewards: torch.Tensor, @@ -2176,7 +2176,7 @@ def _server_rollout(self, inputs: DataType, request_config: RequestConfig, start_idx = sum(all_requests_lengths[:self.accelerator.process_index]) end_idx = start_idx + all_requests_lengths[self.accelerator.process_index] process_slice = slice(start_idx, end_idx) - outputs = outputs[process_slice] + outputs = all_outputs[process_slice] else: # Standard equal distribution case outputs = self.get_even_process_data(all_outputs) From 5b927de664ffd80d091020659d7f346bcc8f6602 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 11 Aug 2025 15:52:16 +0800 Subject: [PATCH 19/26] dynamic rollout advantages --- swift/llm/template/base.py | 17 + swift/plugin/multi_turn.py | 193 ++++++++- swift/trainers/rlhf_trainer/grpo_trainer.py | 449 +++++++------------- 3 files changed, 351 insertions(+), 308 deletions(-) diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index b8d90be095..5354d2ca64 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -1044,6 +1044,23 @@ def _get_system(self, inputs) -> Optional[str]: return system def _swift_prepare_inputs(self, inputs): + """ + Preprocesses the list of messages in the input by merging and formatting consecutive messages + according to their roles. + + Specifically, this method: + - Merges consecutive messages from the same role ('assistant' or 'user') to prevent downstream errors. + - Detects consecutive tool-related messages following an assistant message, then formats and + combines them using `agent_template._format_tool_responses` for structured output. + - Updates the messages list in-place for further processing. 
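
# Minimal illustration of the consecutive same-role merge described in the
# docstring above; the real method also formats tool responses via the agent
# template, and the exact merge rule sketched here (plain string concatenation)
# is an assumption.
def merge_consecutive(messages: list) -> list:
    merged = []
    for msg in messages:
        same_role = bool(merged) and merged[-1]['role'] == msg['role']
        if same_role and msg['role'] in ('user', 'assistant'):
            merged[-1]['content'] += msg['content']
        else:
            merged.append(dict(msg))
    return merged

# merge_consecutive([{'role': 'user', 'content': 'a'}, {'role': 'user', 'content': 'b'}])
# -> [{'role': 'user', 'content': 'ab'}]
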
+ + Args: + inputs: An object containing a 'messages' attribute, which is a list of dictionaries. + Each message dictionary should have at least the keys 'role' and 'content'. + + Returns: + None. The input messages list is updated in-place. + """ messages = inputs.messages if len(messages) < 2: return diff --git a/swift/plugin/multi_turn.py b/swift/plugin/multi_turn.py index 5120fe0c9b..09c5f7925d 100644 --- a/swift/plugin/multi_turn.py +++ b/swift/plugin/multi_turn.py @@ -267,9 +267,195 @@ def check_finished(self, infer_request: 'RolloutInferRequest', response_choice: class ThinkingModelScheduler(MultiTurnScheduler): - # TODO: example for thinking model - # replace history thinking block - pass + """ + Scheduler for Thinking class models that handle multi-turn reasoning. + + For Thinking models, the assistant response format is: + " think content answer content " + + This scheduler: + 1. Parses think and answer content from assistant responses + 2. Only keeps the think content from the last round + 3. Processes each round's history separately + 4. Returns List[RolloutOutput] with one output per round + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _parse_think_answer(self, content: str) -> tuple[str, str]: + """ + Parse think and answer content from assistant response. + + Args: + content: Assistant response content + + Returns: + tuple: (think_content, answer_content) + """ + think_content = '' + answer_content = '' + + # Parse think content + think_start = content.find('') + think_end = content.find('') + if think_start != -1 and think_end != -1: + think_content = content[think_start + 7:think_end].strip() + + # Parse answer content + answer_start = content.find('') + answer_end = content.find('') + if answer_start != -1 and answer_end != -1: + answer_content = content[answer_start + 8:answer_end].strip() + + return think_content, answer_content + + def _is_thinking_template(self) -> bool: + """ + Check if the model's template is a ThinkingTemplate. + + Returns: + bool: True if the template is a ThinkingTemplate or its subclass + """ + if not hasattr(self.infer_engine, 'default_template'): + return False + + template = self.infer_engine.default_template + from swift.llm.template.template.utils import ThinkingTemplate + + return isinstance(template, ThinkingTemplate) + + def _build_round_history(self, original_messages: 'Messages', round_num: int, think_content: str) -> 'Messages': + """ + Build history for a specific round, keeping only the think content from the last round. 
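
# Standalone sketch of the think/answer parsing performed above. The tag
# literals are assumed to be '<think>'/'</think>' and '<answer>'/'</answer>',
# consistent with the +7 / +8 slice offsets used in _parse_think_answer.
def parse_think_answer(content: str) -> tuple:
    think, answer = '', ''
    t_start, t_end = content.find('<think>'), content.find('</think>')
    if t_start != -1 and t_end != -1:
        think = content[t_start + 7:t_end].strip()
    a_start, a_end = content.find('<answer>'), content.find('</answer>')
    if a_start != -1 and a_end != -1:
        answer = content[a_start + 8:a_end].strip()
    return think, answer

# parse_think_answer('<think>step by step</think><answer>42</answer>')
# -> ('step by step', '42')
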
+ + Args: + original_messages: Original conversation messages + round_num: Current round number + think_content: Think content to include + + Returns: + Messages: History for this specific round + """ + from copy import deepcopy + + # If this is a thinking template, use the template's method to prepare messages + if self._is_thinking_template(): + # Create a mock inputs object to use the template's _swift_prepare_inputs method + class MockInputs: + + def __init__(self, messages): + self.messages = deepcopy(messages) + + mock_inputs = MockInputs(original_messages) + + # Set up the template for inference mode + template = self.infer_engine.default_template + original_is_training = getattr(template, 'is_training', False) + template.is_training = False + + # Use the template's method to prepare messages + template._swift_prepare_inputs(mock_inputs) + + # Restore original training state + template.is_training = original_is_training + + return mock_inputs.messages + else: + # Fallback to manual processing for non-thinking templates + round_messages = [] + + # Process messages in original order + for i, msg in enumerate(original_messages): + if msg['role'] == 'assistant' and isinstance(msg['content'], str) and i != len(original_messages) - 1: + # For assistant messages + assistant_no_think = msg['content'].split('')[-1].strip() + round_messages.append(assistant_no_think) + else: + round_messages.append(deepcopy(msg)) + + return round_messages + + async def run(self, infer_request: 'RolloutInferRequest', request_config: 'RequestConfig', + **kwargs) -> List['RolloutOutput']: + """ + Execute multi-turn conversation for Thinking models. + + Args: + infer_request: The initial inference request containing messages + request_config: Configuration for the inference request + **kwargs: Additional inference parameters + + Returns: + List[RolloutOutput]: List of outputs, one for each round + """ + from swift.llm.infer.protocol import RolloutOutput + + current_request = infer_request + current_turn = 1 + rollout_outputs = [] + last_think_content = '' + + while True: + messages = current_request.messages + if current_turn == 1 or not messages[-1]['content']: + # If it's the first turn or the last message content is empty(dummy), remove the response + remove_response(messages) + + # Get model response + response: 'ChatCompletionResponse' = await self.infer_engine.infer_async( + current_request, request_config, **kwargs) + response_choice: 'ChatCompletionResponseChoice' = response.choices[0] + + # Parse think and answer content + completion = response_choice.message.content + think_content, answer_content = self._parse_think_answer(completion) + + # Update last think content + if think_content: + last_think_content = think_content + + # Update conversation history + if messages[-1]['role'] == 'assistant': + messages[-1]['content'] += completion + else: + messages.append({'role': 'assistant', 'content': completion}) + + # Build history for this round + round_history = self._build_round_history(messages, current_turn, last_think_content) + + # Create RolloutOutput for this round + round_output = RolloutOutput( + response=response, + messages=round_history, + rollout_infos={ + 'num_turns': current_turn, + 'think_content': think_content, + 'answer_content': answer_content, + 'round_number': current_turn + }) + rollout_outputs.append(round_output) + + # Check stopping conditions + should_stop = self.check_finished(current_request, response_choice, current_turn) + + if self.max_turns: + should_stop = should_stop or 
(current_turn >= self.max_turns) + + if should_stop: + break + + # Prepare next turn + ret = self.step(current_request, response_choice, current_turn) + current_request: 'RolloutInferRequest' = ret['infer_request'] + + if current_request.messages[-1]['role'] == 'assistant': + # Add a dummy response to allow engine to continue generating + current_request.messages.append({'role': 'assistant', 'content': None}) + + current_turn += 1 + + return rollout_outputs class MathTipsScheduler(MultiTurnScheduler): @@ -475,4 +661,5 @@ async def run(self, infer_request: 'RolloutInferRequest', request_config: 'Reque 'math_tip_trick': MathTipsScheduler, 'math_tip_trick_multi_turn': MathTipsMultiTurnScheduler, 'gym_scheduler': GYMScheduler, + 'thinking_scheduler': ThinkingModelScheduler, } diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 1ae61c397c..c7e3e6b50c 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -391,6 +391,8 @@ def single_sample_context(): self.truncated_resample_iterator = cyclic_iter(self.get_train_dataloader()) # flag indicating whether the evaluation has started self.eval_flag = False + # Record the number of samples that need to be padded for even distribution across processes + self.rollout_pad_count = 0 @patch_profiling_decorator def _prepare_inputs(self, generation_batch: dict[str, Union[torch.Tensor, @@ -819,16 +821,22 @@ def _generate_and_score_completions(self, inputs: DataType) -> DataType: inputs = self.resample_truncated_inputs(inputs) inputs = self._generate_completions(inputs) - total_rewards_per_func, total_rewards, completions = self._score_completions(inputs) + total_rewards_per_func, total_rewards, completions, total_advantages = self._score_completions(inputs) mode = 'train' if self.model.training else 'eval' if self.args.dynamic_sample and mode == 'train': # dynamic sampling for std=0 groups - inputs, total_rewards, total_rewards_per_func, completions = \ - self._dynamic_sampling(inputs, total_rewards, total_rewards_per_func, completions) + inputs, total_rewards, total_rewards_per_func, completions, total_advantages = \ + self._dynamic_sampling(inputs, total_rewards, total_rewards_per_func, completions, total_advantages) - # Prepare final outputs with advantages and other required fields - inputs = self._calculate_advantages(inputs, total_rewards) + local_advantages = self.get_even_process_data(total_advantages) + assert len(local_advantages) == len(inputs) + for i, advantage in enumerate(local_advantages): + inputs[i]['advantage'] = advantage + + self._logs['advantages'].extend(total_advantages.tolist()) + if any('images' in data and data['images'] is not None for data in inputs): + self._logs['image'].extend(gather_object([inp['images'] for inp in inputs])) batch_encoded_inputs = self._prepare_batch_inputs(inputs) # Log metrics @@ -838,35 +846,8 @@ def _generate_and_score_completions(self, inputs: DataType) -> DataType: return batch_encoded_inputs - def _calculate_advantages(self, inputs: DataType, total_rewards: torch.Tensor) -> DataType: - advantages = self._compute_advantages(inputs, total_rewards) - # Legacy method (kept for backward compatibility) - legacy_advantages = self._compute_advantages_legacy(inputs, total_rewards) - - # DEBUG - assert torch.allclose(advantages, legacy_advantages) - try: - self._validate_advantage_calculation(inputs, total_rewards, advantages) - except Exception as e: - logger.warning(f'Advantage validation failed: {e}') 
- - # log advantages and image(for VL models) - self._logs['advantages'].extend(advantages.tolist()) - if any('images' in data and data['images'] is not None for data in inputs): - self._logs['image'].extend(gather_object([inp['images'] for inp in inputs])) - - # get local advantages - local_advantages = self.get_even_process_data(advantages) - assert len(local_advantages) == len(inputs) - - # merge advantages to inputs - for i, advantage in enumerate(local_advantages): - inputs[i]['advantage'] = advantage - - return inputs - - def _score_completions(self, inputs: DataType) -> Tuple[torch.Tensor, torch.Tensor, List[str]]: - """Score completions using all reward functions + def _score_completions(self, inputs: DataType) -> Tuple[torch.Tensor, torch.Tensor, List[str], torch.Tensor]: + """Score completions using all reward functions and compute advantages Args: inputs: List of input examples, each containing a 'messages' list with conversation history @@ -876,9 +857,11 @@ def _score_completions(self, inputs: DataType) -> Tuple[torch.Tensor, torch.Tens - rewards_per_func: Tensor of shape (num_examples, num_reward_funcs) with individual rewards - total_rewards: Tensor of shape (num_examples,) with weighted sum of rewards - completions: List of generated completion strings + - advantages: Tensor of shape (num_examples,) with computed advantages """ device = self.accelerator.device completions = [example['messages'][-1]['content'] for example in inputs] + # If using gym environment, extract rewards directly from inputs if self.use_gym_env: total_rewards = torch.tensor([inp['total_reward'] for inp in inputs], dtype=torch.float32, device=device) @@ -886,7 +869,11 @@ def _score_completions(self, inputs: DataType) -> Tuple[torch.Tensor, torch.Tens rewards_per_func = total_rewards.unsqueeze(1) # shape: [num_examples, 1] total_rewards_per_func = gather(rewards_per_func) total_rewards_gathered = total_rewards_per_func.squeeze(1) # Recover from gathered data - return total_rewards_per_func, total_rewards_gathered, completions + total_prompt_ids = gather_object([inp['prompt_id'] for inp in inputs]) + # flatten + total_prompt_ids = [pid for sublist in total_prompt_ids for pid in sublist] + total_advantages = self._compute_advantages(total_rewards_gathered, total_prompt_ids) + return total_rewards_per_func, total_rewards_gathered, completions, total_advantages rewards_per_func = torch.zeros((len(inputs), len(self.reward_funcs)), device=device) for i, (reward_func, reward_model_plugin, reward_func_name) in enumerate( @@ -912,12 +899,96 @@ def _score_completions(self, inputs: DataType) -> Tuple[torch.Tensor, torch.Tens logger.warning(f'All reward functions returned None for the following kwargs: {row_reward_kwargs}. 
' 'Please ensure that at least one reward function returns a valid reward.') - total_rewards_per_func = gather(rewards_per_func) - total_rewards = (total_rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + # Calculate total rewards + total_rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + + # Extract prompt_ids for grouping + prompt_ids = [inp['prompt_id'] for inp in inputs] + + # Prepare for gather with padding + if self.rollout_pad_count > 0: + # Pad total rewards with NaN + pad_total_rewards = torch.full((self.rollout_pad_count, ), torch.nan, dtype=torch.float32, device=device) + total_rewards = torch.cat([total_rewards, pad_total_rewards], dim=0) + + # Pad prompt_ids with special dummy value + dummy_prompt_ids = ['__dummy_pad__'] * self.rollout_pad_count + prompt_ids = prompt_ids + dummy_prompt_ids + + # Gather all data across processes + gathered_total_rewards = gather(total_rewards) + gathered_prompt_ids = gather_object(prompt_ids) + + # Remove dummy data (prompt_id == "__dummy_pad__") + valid_indices = [i for i, id in enumerate(gathered_prompt_ids) if id != '__dummy_pad__'] + valid_total_rewards = gathered_total_rewards[valid_indices] + valid_prompt_ids = [gathered_prompt_ids[i] for i in valid_indices] + + # Compute advantages based on prompt_id grouping + total_advantages = self._compute_advantages(valid_total_rewards, valid_prompt_ids) + + # Create local advantages by filtering to original local data + assert rewards_per_func.shape[0] == total_rewards.shape[0] == total_advantages.shape[0] + return rewards_per_func, total_rewards, completions, total_advantages + + def _compute_advantages(self, rewards: torch.Tensor, prompt_ids: List[str]) -> torch.Tensor: + """ + Compute advantages based on prompt_id grouping from gathered data. + + Args: + rewards: Tensor of rewards from all processes, shape: (num_examples,) + prompt_ids: List of prompt_ids from all processes, len: num_examples + + Returns: + torch.Tensor: Computed advantages with same shape as rewards + """ + assert rewards.shape[0] == len(prompt_ids) + advantages = torch.zeros_like(rewards) + + # Group rewards by prompt_id + unique_prompt_ids = list(set(prompt_ids)) - return total_rewards_per_func, total_rewards, completions + for prompt_id in unique_prompt_ids: + # Find all samples with this prompt_id + indices = [i for i, pid in enumerate(prompt_ids) if pid == prompt_id] + if len(indices) == 0: + continue + + group_rewards = rewards[indices] - def _dynamic_sampling(self, inputs, rewards, rewards_per_func, completions): + # Compute group statistics + group_mean = group_rewards.mean() + group_advantages = group_rewards - group_mean + + # Optional: scale by standard deviation + if self.args.scale_rewards: + group_std = group_rewards.std() + group_advantages /= (group_std + 1e-4) + + # Assign computed advantages back to original positions + for idx, advantage in zip(indices, group_advantages): + advantages[idx] = advantage + + return advantages + + def _dynamic_sampling(self, inputs, rewards, rewards_per_func, completions, advantages): + """ + Perform dynamic sampling to replace samples with zero-reward-variance groups. + + This method implements DAPO (https://arxiv.org/abs/2503.14476) by replacing + samples from groups with zero reward variance (std=0) through resampling. 
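
# Standalone sketch of the pad -> gather -> drop-dummies pattern above, which
# keeps the all-gather rectangular when processes hold unequal sample counts.
# Helper names are illustrative only.
import torch

def pad_for_gather(rewards: torch.Tensor, prompt_ids: list, pad_count: int):
    if pad_count > 0:
        pad = torch.full((pad_count, ), float('nan'), dtype=rewards.dtype, device=rewards.device)
        rewards = torch.cat([rewards, pad])
        prompt_ids = prompt_ids + ['__dummy_pad__'] * pad_count
    return rewards, prompt_ids

def drop_dummies(rewards: torch.Tensor, prompt_ids: list):
    keep = [i for i, pid in enumerate(prompt_ids) if pid != '__dummy_pad__']
    return rewards[keep], [prompt_ids[i] for i in keep]

# pad_for_gather(torch.tensor([1.0]), ['p1'], pad_count=1)
# -> (tensor([1., nan]), ['p1', '__dummy_pad__'])
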
+ + Args: + inputs: local input data samples + rewards: Tensor of rewards for global data samples + rewards_per_func: Rewards per function/model for global data samples + completions: Generated completions for local inputs + advantages: Computed advantages for global data samples + + Returns: + tuple: (inputs, rewards, rewards_per_func, completions, advantages) + with zero-variance groups replaced by resampled data + """ # DAPO https://arxiv.org/abs/2503.14476 # Replaces samples with zero-reward-variance groups (std=0) resample_count = 0 @@ -925,8 +996,8 @@ def _dynamic_sampling(self, inputs, rewards, rewards_per_func, completions): valid_rewards = [] valid_rewards_per_func = [] valid_completions = [] - - origin_data = (inputs, rewards, rewards_per_func, completions) + valid_advantages = [] + origin_data = (inputs, rewards, rewards_per_func, completions, advantages) while resample_count < self.args.max_resample_times: grouped_rewards = rewards.view(-1, self.num_generations) @@ -939,14 +1010,14 @@ def _dynamic_sampling(self, inputs, rewards, rewards_per_func, completions): valid_rewards_per_func.append(rewards_per_func[valid_mask]) valid_completions.extend( [inp['messages'][-1]['content'] for inp, mask in zip(all_inputs, valid_mask) if mask]) - + valid_advantages.append(advantages[valid_mask]) if len(valid_samples) >= self.args.generation_batch_size: break inputs = next(self.dynamic_resample_iterator) inputs = Trainer._prepare_inputs(self, inputs) inputs = self._generate_completions(inputs) - rewards_per_func, rewards, completions = self._score_completions(inputs) + rewards_per_func, rewards, completions, advantages = self._score_completions(inputs) resample_count += 1 if len(valid_samples) >= self.args.generation_batch_size: @@ -958,11 +1029,12 @@ def _dynamic_sampling(self, inputs, rewards, rewards_per_func, completions): rewards = torch.cat(valid_rewards)[:self.args.generation_batch_size] rewards_per_func = torch.cat(valid_rewards_per_func)[:self.args.generation_batch_size] completions = valid_completions[:self.args.generation_batch_size][process_slice] + advantages = torch.cat(valid_advantages)[:self.args.generation_batch_size] else: logger.warning(f'There are still std=0 groups present after {self.args.max_resample_times} retries.') - inputs, rewards, rewards_per_func, completions = origin_data + inputs, rewards, rewards_per_func, completions, advantages = origin_data - return inputs, rewards, rewards_per_func, completions + return inputs, rewards, rewards_per_func, completions, advantages def split_by_mini_batches(self, inputs): """ @@ -998,263 +1070,6 @@ def split_by_mini_batches(self, inputs): return spg_chunks - def _compute_advantages(self, inputs: DataType, rewards: torch.Tensor) -> torch.Tensor: - """ - Compute advantages based on prompt_id grouping, handling variable generation counts per prompt. 
- - Args: - inputs (DataType): List of input samples with prompt_id fields - rewards (torch.Tensor): Tensor of rewards corresponding to the inputs - - Returns: - torch.Tensor: Computed advantages with same shape as rewards - """ - if len(inputs) != len(rewards): - raise ValueError(f'Inputs length ({len(inputs)}) != rewards length ({len(rewards)})') - - # Ensure all inputs have prompt_id - if not all('prompt_id' in inp for inp in inputs): - logger.warning('Some inputs missing prompt_id, adding them...') - inputs = self._add_prompt_id_to_inputs(inputs) - - # Group rewards by prompt_id - prompt_groups = {} - for i, inp in enumerate(inputs): - prompt_id = inp['prompt_id'] - if prompt_id not in prompt_groups: - prompt_groups[prompt_id] = [] - prompt_groups[prompt_id].append((i, rewards[i].item())) - - # Compute advantages for each group - advantages = torch.zeros_like(rewards) - - for prompt_id, group_data in prompt_groups.items(): - indices, group_rewards = zip(*group_data) - group_rewards = torch.tensor(group_rewards, device=rewards.device, dtype=rewards.dtype) - - group_mean = group_rewards.mean() - group_advantages = group_rewards - group_mean - - # Optional: scale by standard deviation - if self.args.scale_rewards: - group_std = group_rewards.std() - group_advantages /= (group_std + 1e-4) - - # Assign computed advantages back to original positions - for idx, advantage in zip(indices, group_advantages): - advantages[idx] = advantage - - # Check for groups with unexpected generation counts - generation_counts = [len(group_data) for group_data in prompt_groups.values()] - if generation_counts and (min(generation_counts) != max(generation_counts)): - logger.warning(f'Variable generation counts detected: min={min(generation_counts)}, ' - f'max={max(generation_counts)}, expected={self.num_generations}') - - return advantages - - def _compute_advantages_legacy(self, inputs: DataType, rewards: torch.Tensor) -> torch.Tensor: - """ - Legacy advantage computation method using reshape. - - This method assumes a fixed number of generations per prompt (self.num_generations). - Kept for backward compatibility, but may fail if rollout count differs from num_generations. - - Args: - inputs: List of input samples (not used in legacy method) - rewards: Tensor of rewards - - Returns: - torch.Tensor: Computed advantages - - Raises: - RuntimeError: If rewards cannot be reshaped to (-1, num_generations) - """ - try: - # Original logic - assumes fixed num_generations per prompt - grouped_rewards = rewards.view(-1, self.num_generations) - mean_grouped_rewards = grouped_rewards.mean(dim=1).repeat_interleave(self.num_generations, dim=0) - std_grouped_rewards = grouped_rewards.std(dim=1).repeat_interleave(self.num_generations, dim=0) - - advantages = (rewards - mean_grouped_rewards) - if self.args.scale_rewards: - advantages /= (std_grouped_rewards + 1e-4) - - logger.debug(f'Legacy advantage computation: {len(rewards)} rewards, ' - f'{self.num_generations} generations per prompt') - - return advantages - - except RuntimeError: - raise - - def _validate_advantage_calculation(self, inputs: DataType, rewards: torch.Tensor, - advantages: torch.Tensor) -> None: - """ - Validate the computed advantages for correctness and consistency. - - This method performs several checks: - 1. Ensures advantages sum to ~0 within each prompt group - 2. Verifies advantage shapes match input/reward shapes - 3. 
Checks for NaN or infinite values - - Args: - inputs: Original input data with prompt_id fields - rewards: Original reward tensor - advantages: Computed advantage tensor - - Raises: - ValueError: If validation fails - """ - if len(advantages) != len(rewards) or len(advantages) != len(inputs): - raise ValueError(f'Shape mismatch: advantages={len(advantages)}, ' - f'rewards={len(rewards)}, inputs={len(inputs)}') - - # Check for NaN or infinite values - if torch.isnan(advantages).any(): - nan_count = torch.isnan(advantages).sum().item() - logger.warning(f'Found {nan_count} NaN values in advantages') - - if torch.isinf(advantages).any(): - inf_count = torch.isinf(advantages).sum().item() - logger.warning(f'Found {inf_count} infinite values in advantages') - - # Verify advantages sum to ~0 within each prompt group - prompt_groups = {} - for i, inp in enumerate(inputs): - prompt_id = inp['prompt_id'] - if prompt_id not in prompt_groups: - prompt_groups[prompt_id] = [] - prompt_groups[prompt_id].append(advantages[i].item()) - - tolerance = 1e-6 - problematic_groups = [] - for prompt_id, group_advantages in prompt_groups.items(): - if len(group_advantages) > 1: # Only check groups with multiple generations - group_sum = sum(group_advantages) - if abs(group_sum) > tolerance: - problematic_groups.append((prompt_id, group_sum, len(group_advantages))) - - if problematic_groups: - logger.warning(f'Found {len(problematic_groups)} prompt groups where advantages ' - f"don't sum to ~0 (tolerance={tolerance})") - for prompt_id, group_sum, count in problematic_groups[:5]: # Log first 5 - logger.warning(f' Group {prompt_id}: sum={group_sum:.8f}, count={count}') - - def _test_advantage_calculation_edge_cases(self) -> None: - """ - Test the advantage calculation with various edge cases. - This method is useful for debugging and validation. 
- """ - logger.info('Testing advantage calculation edge cases...') - - # Test case 1: Different generation counts per prompt - test_inputs_1 = [ - { - 'prompt_id': 'prompt_1', - 'messages': [{ - 'role': 'user', - 'content': 'hello' - }] - }, - { - 'prompt_id': 'prompt_1', - 'messages': [{ - 'role': 'user', - 'content': 'hello' - }] - }, - { - 'prompt_id': 'prompt_1', - 'messages': [{ - 'role': 'user', - 'content': 'hello' - }] - }, # 3 generations - { - 'prompt_id': 'prompt_2', - 'messages': [{ - 'role': 'user', - 'content': 'hi' - }] - }, - { - 'prompt_id': 'prompt_2', - 'messages': [{ - 'role': 'user', - 'content': 'hi' - }] - }, # 2 generations - { - 'prompt_id': 'prompt_3', - 'messages': [{ - 'role': 'user', - 'content': 'hey' - }] - }, # 1 generation - ] - test_rewards_1 = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], device=self.accelerator.device) - - try: - advantages_1 = self._compute_advantages(test_inputs_1, test_rewards_1) - logger.info(f'Test 1 passed: advantages shape {advantages_1.shape}') - - # Verify that advantages sum to ~0 for each group - assert abs(advantages_1[0] + advantages_1[1] + advantages_1[2]) < 1e-6, 'Group 1 advantages should sum to 0' - assert abs(advantages_1[3] + advantages_1[4]) < 1e-6, 'Group 2 advantages should sum to 0' - assert advantages_1[5] == 0.0, 'Single generation group should have 0 advantage' - - except Exception as e: - logger.error(f'Test 1 failed: {e}') - - # Test case 2: All same generation counts (should match original behavior) - test_inputs_2 = [ - { - 'prompt_id': 'prompt_a', - 'messages': [{ - 'role': 'user', - 'content': 'test1' - }] - }, - { - 'prompt_id': 'prompt_a', - 'messages': [{ - 'role': 'user', - 'content': 'test1' - }] - }, - { - 'prompt_id': 'prompt_b', - 'messages': [{ - 'role': 'user', - 'content': 'test2' - }] - }, - { - 'prompt_id': 'prompt_b', - 'messages': [{ - 'role': 'user', - 'content': 'test2' - }] - }, - ] - test_rewards_2 = torch.tensor([1.0, 3.0, 2.0, 4.0], device=self.accelerator.device) - - try: - advantages_2 = self._compute_advantages(test_inputs_2, test_rewards_2) - logger.info(f'Test 2 passed: advantages shape {advantages_2.shape}') - - # Check that advantages are correctly computed - expected_adv_a = torch.tensor([-1.0, 1.0]) # (1-2, 3-2) - expected_adv_b = torch.tensor([-1.0, 1.0]) # (2-3, 4-3) - - assert torch.allclose(advantages_2[0:2], expected_adv_a, atol=1e-6), 'Group A advantages incorrect' - assert torch.allclose(advantages_2[2:4], expected_adv_b, atol=1e-6), 'Group B advantages incorrect' - - except Exception as e: - logger.error(f'Test 2 failed: {e}') - - logger.info('Advantage calculation edge case testing completed') - def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]: """ Prepare the final batch inputs with advantages, ref/old_policy logps and other fields for RL training. 
@@ -1311,7 +1126,7 @@ def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]: return ga_batch_encoded_inputs - def _log_metrics(self, inputs, messages, completions, rewards, rewards_per_func, trajectory_infos=None): + def _log_metrics(self, inputs, messages, completions, rewards, rewards_per_func): """Log training/evaluation metrics""" mode = 'train' if self.model.training else 'eval' device = self.accelerator.device @@ -1350,7 +1165,9 @@ def _log_metrics(self, inputs, messages, completions, rewards, rewards_per_func, self._logs['completion'].extend(gather_object(completions)) if self.use_gym_env: - self._logs['trajectory_infos'].extend(gather_object(trajectory_infos)) + pass + # TODO: extra from rollout_infos + # self._logs['trajectory_infos'].extend(gather_object(trajectory_infos)) for i, name in enumerate(self.reward_func_names): self._logs['rewards'][name].extend(rewards_per_func[:, i].tolist()) @@ -1830,12 +1647,28 @@ def _engine_infer( *, use_tqdm: Optional[bool] = False, ) -> List[RolloutOutput]: + """ + Perform inference using the configured engine (VLLM server or colocate engine). + + Args: + infer_requests: List of rollout inference requests to process + request_config: Optional configuration for the requests + use_tqdm: Whether to show progress bar during inference + + Returns: + List of RolloutOutput objects containing the inference results + """ with patch_profiling_context(self, 'generate'): if self.vllm_mode == 'server': return self.vllm_client.infer(infer_requests, asdict(request_config), use_tqdm=use_tqdm) else: res = self.engine.infer(infer_requests, request_config, use_tqdm=use_tqdm) - return [RolloutOutput(response=r) for r in res] + if all(isinstance(r, RolloutOutput) for r in res): + return res + else: + # PT Eninge + assert all(isinstance(r, ChatCompletionResponse) for r in res) + return [RolloutOutput(response=r) for r in res] def old_policy(self): return self.num_iterations > 1 or self.args.gradient_accumulation_steps % self.args.steps_per_generation != 0 @@ -2497,6 +2330,12 @@ def get_even_process_data(self, global_data: List[T]) -> List[T]: base_size = total // num_procs remainder = total % num_procs + # Calculate the number of samples that need to be padded + # This ensures all processes have the same number of samples for gather operations + if remainder > 0 and proc_idx >= remainder: + # Processes with extra samples need padding + self.rollout_pad_count = 1 + if proc_idx < remainder: start = proc_idx * (base_size + 1) end = start + base_size + 1 From b0c52b7ef6995dbb8a14a5b3b521b8d92269abeb Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 11 Aug 2025 16:37:43 +0800 Subject: [PATCH 20/26] fix score_completions --- swift/plugin/multi_turn.py | 199 ++++++++++---------- swift/trainers/rlhf_trainer/grpo_trainer.py | 85 ++++++--- 2 files changed, 160 insertions(+), 124 deletions(-) diff --git a/swift/plugin/multi_turn.py b/swift/plugin/multi_turn.py index 09c5f7925d..4e22b7b26c 100644 --- a/swift/plugin/multi_turn.py +++ b/swift/plugin/multi_turn.py @@ -123,32 +123,39 @@ async def run(self, infer_request: 'RolloutInferRequest', request_config: 'Reque **kwargs) -> Union['RolloutOutput', List['RolloutOutput']]: """Execute multi-turn conversation rollout with built-in turn management logic. - This implements the default multi-turn interaction flow, which you can override - to customize the conversation handling behavior. 
The default logic: + This implements the default multi-turn interaction flow that can be overridden + to customize conversation handling behavior. The default logic provides: - 1. Manages conversation turns and stopping conditions - 2. Handles message accumulation across turns - 3. Tracks response tokens and loss masks - 4. Supports early stopping conditions + 1. Automatic conversation turn management and stopping conditions + 2. Seamless message accumulation across multiple turns + 3. Response token tracking and loss mask management + 4. Configurable early stopping mechanisms Args: - infer_request: The initial inference request containing messages - request_config: Configuration for the inference request - **kwargs: Additional inference parameters + infer_request: The initial inference request containing conversation messages + request_config: Configuration parameters for the inference request + **kwargs: Additional inference parameters passed to the engine Returns: RolloutOutput containing the complete conversation history and metadata, or a list of outputs for batched requests - Customization Points: - - Override check_finished() to change stopping conditions - - Override step() to customize turn-to-turn transitions - - Subclass to completely change multi-turn behavior + Customization Approaches: + - Override check_finished() to implement custom stopping criteria + - Override step() to customize turn-to-turn transition logic + - Override this entire run() method for completely custom multi-turn behavior + + Important Notes: + - Method overriding is only supported when using server mode (swift rollout) + with vllm_use_async_engine=True + - Custom implementations must maintain async/await compatibility + - Ensure proper handling of conversation state across turns Example: class CustomScheduler(MultiTurnScheduler): - async def run(self, *args, **kwargs): - # Custom multi-turn logic here + async def run(self, infer_request, request_config, **kwargs): + # Implement custom multi-turn conversation logic + # Must return RolloutOutput or List[RolloutOutput] ... """ @@ -283,6 +290,87 @@ class ThinkingModelScheduler(MultiTurnScheduler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + async def run(self, infer_request: 'RolloutInferRequest', request_config: 'RequestConfig', + **kwargs) -> List['RolloutOutput']: + """ + Execute multi-turn conversation for Thinking models. 
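
# Hedged sketch of the customization points listed in the run() docstring above:
# subclass the scheduler and override check_finished / step (overriding run()
# itself is also possible in async server mode). It assumes the surrounding
# module's MultiTurnScheduler; the class name and stop condition are
# illustrative, not part of ms-swift.
class EarlyStopScheduler(MultiTurnScheduler):

    def check_finished(self, infer_request, response_choice, current_turn) -> bool:
        # stop early once an explicit final-answer marker appears
        if '<answer>' in (response_choice.message.content or ''):
            return True
        return super().check_finished(infer_request, response_choice, current_turn)

    def step(self, infer_request, response_choice, current_turn) -> dict:
        # minimal turn-to-turn transition: hand the same request back unchanged
        return {'infer_request': infer_request}
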
+ + Args: + infer_request: The initial inference request containing messages + request_config: Configuration for the inference request + **kwargs: Additional inference parameters + + Returns: + List[RolloutOutput]: List of outputs, one for each round + """ + from swift.llm.infer.protocol import RolloutOutput + + current_request = infer_request + current_turn = 1 + rollout_outputs = [] + last_think_content = '' + + while True: + messages = current_request.messages + if current_turn == 1 or not messages[-1]['content']: + # If it's the first turn or the last message content is empty(dummy), remove the response + remove_response(messages) + + # Get model response + response: 'ChatCompletionResponse' = await self.infer_engine.infer_async( + current_request, request_config, **kwargs) + response_choice: 'ChatCompletionResponseChoice' = response.choices[0] + + # Parse think and answer content + completion = response_choice.message.content + think_content, answer_content = self._parse_think_answer(completion) + + # Update last think content + if think_content: + last_think_content = think_content + + # Update conversation history + if messages[-1]['role'] == 'assistant': + messages[-1]['content'] += completion + else: + messages.append({'role': 'assistant', 'content': completion}) + + # Build history for this round + round_history = self._build_round_history(messages, current_turn, last_think_content) + + # Create RolloutOutput for this round + round_output = RolloutOutput( + response=response, + messages=round_history, + rollout_infos={ + 'num_turns': current_turn, + 'think_content': think_content, + 'answer_content': answer_content, + 'round_number': current_turn + }) + rollout_outputs.append(round_output) + + # Check stopping conditions + should_stop = self.check_finished(current_request, response_choice, current_turn) + + if self.max_turns: + should_stop = should_stop or (current_turn >= self.max_turns) + + if should_stop: + break + + # Prepare next turn + ret = self.step(current_request, response_choice, current_turn) + current_request: 'RolloutInferRequest' = ret['infer_request'] + + if current_request.messages[-1]['role'] == 'assistant': + # Add a dummy response to allow engine to continue generating + current_request.messages.append({'role': 'assistant', 'content': None}) + + current_turn += 1 + + return rollout_outputs + def _parse_think_answer(self, content: str) -> tuple[str, str]: """ Parse think and answer content from assistant response. @@ -376,87 +464,6 @@ def __init__(self, messages): return round_messages - async def run(self, infer_request: 'RolloutInferRequest', request_config: 'RequestConfig', - **kwargs) -> List['RolloutOutput']: - """ - Execute multi-turn conversation for Thinking models. 
- - Args: - infer_request: The initial inference request containing messages - request_config: Configuration for the inference request - **kwargs: Additional inference parameters - - Returns: - List[RolloutOutput]: List of outputs, one for each round - """ - from swift.llm.infer.protocol import RolloutOutput - - current_request = infer_request - current_turn = 1 - rollout_outputs = [] - last_think_content = '' - - while True: - messages = current_request.messages - if current_turn == 1 or not messages[-1]['content']: - # If it's the first turn or the last message content is empty(dummy), remove the response - remove_response(messages) - - # Get model response - response: 'ChatCompletionResponse' = await self.infer_engine.infer_async( - current_request, request_config, **kwargs) - response_choice: 'ChatCompletionResponseChoice' = response.choices[0] - - # Parse think and answer content - completion = response_choice.message.content - think_content, answer_content = self._parse_think_answer(completion) - - # Update last think content - if think_content: - last_think_content = think_content - - # Update conversation history - if messages[-1]['role'] == 'assistant': - messages[-1]['content'] += completion - else: - messages.append({'role': 'assistant', 'content': completion}) - - # Build history for this round - round_history = self._build_round_history(messages, current_turn, last_think_content) - - # Create RolloutOutput for this round - round_output = RolloutOutput( - response=response, - messages=round_history, - rollout_infos={ - 'num_turns': current_turn, - 'think_content': think_content, - 'answer_content': answer_content, - 'round_number': current_turn - }) - rollout_outputs.append(round_output) - - # Check stopping conditions - should_stop = self.check_finished(current_request, response_choice, current_turn) - - if self.max_turns: - should_stop = should_stop or (current_turn >= self.max_turns) - - if should_stop: - break - - # Prepare next turn - ret = self.step(current_request, response_choice, current_turn) - current_request: 'RolloutInferRequest' = ret['infer_request'] - - if current_request.messages[-1]['role'] == 'assistant': - # Add a dummy response to allow engine to continue generating - current_request.messages.append({'role': 'assistant', 'content': None}) - - current_turn += 1 - - return rollout_outputs - class MathTipsScheduler(MultiTurnScheduler): tips_prompt = 'But wait... 
It seems I made a mistake,' diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index c7e3e6b50c..4efad7d54b 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -14,6 +14,7 @@ from dataclasses import asdict, dataclass, field from math import ceil from queue import Queue +from threading import local from types import MethodType from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union @@ -854,7 +855,7 @@ def _score_completions(self, inputs: DataType) -> Tuple[torch.Tensor, torch.Tens Returns: Tuple containing: - - rewards_per_func: Tensor of shape (num_examples, num_reward_funcs) with individual rewards + - total_rewards_per_func: Tensor of shape (num_examples, num_reward_funcs) with individual rewards - total_rewards: Tensor of shape (num_examples,) with weighted sum of rewards - completions: List of generated completion strings - advantages: Tensor of shape (num_examples,) with computed advantages @@ -862,18 +863,40 @@ def _score_completions(self, inputs: DataType) -> Tuple[torch.Tensor, torch.Tens device = self.accelerator.device completions = [example['messages'][-1]['content'] for example in inputs] + # Extract prompt_ids for grouping + prompt_ids = [inp['prompt_id'] for inp in inputs] + # If using gym environment, extract rewards directly from inputs if self.use_gym_env: - total_rewards = torch.tensor([inp['total_reward'] for inp in inputs], dtype=torch.float32, device=device) - # For gym environment, there's only one total reward, so rewards_per_func is just total_rewards reshaped - rewards_per_func = total_rewards.unsqueeze(1) # shape: [num_examples, 1] - total_rewards_per_func = gather(rewards_per_func) - total_rewards_gathered = total_rewards_per_func.squeeze(1) # Recover from gathered data - total_prompt_ids = gather_object([inp['prompt_id'] for inp in inputs]) - # flatten - total_prompt_ids = [pid for sublist in total_prompt_ids for pid in sublist] - total_advantages = self._compute_advantages(total_rewards_gathered, total_prompt_ids) - return total_rewards_per_func, total_rewards_gathered, completions, total_advantages + local_rewards = torch.tensor([inp['total_reward'] for inp in inputs], dtype=torch.float32, device=device) + # For gym environment, there's only one total reward, so rewards_per_func is just local_rewards reshaped + local_rewards_per_func = local_rewards.unsqueeze(1) # shape: [num_examples, 1] + else: + # Compute rewards using reward functions + local_rewards_per_func = self._compute_rewards_per_func(inputs, completions) + local_rewards = (local_rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + + # Gather rewards and prompt_ids across processes with padding + gathered_rewards_per_func, gathered_prompt_ids = self._gather_rewards_and_prompt_ids( + local_rewards_per_func, prompt_ids) + + # Remove dummy data and compute total rewards + total_rewards_per_func, total_prompt_ids = self._remove_dummy_data(gathered_rewards_per_func, + gathered_prompt_ids) + + if self.use_gym_env: + total_rewards = total_rewards_per_func.squeeze(1) # Recover from gathered data + else: + total_rewards = (total_rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + + # Compute advantages based on prompt_id grouping + total_advantages = self._compute_advantages(total_rewards, total_prompt_ids) + + return total_rewards_per_func, total_rewards, completions, total_advantages + + def _compute_rewards_per_func(self, 
inputs: DataType, completions: List[str]) -> torch.Tensor: + """Compute rewards using all reward functions""" + device = self.accelerator.device rewards_per_func = torch.zeros((len(inputs), len(self.reward_funcs)), device=device) for i, (reward_func, reward_model_plugin, reward_func_name) in enumerate( @@ -899,37 +922,42 @@ def _score_completions(self, inputs: DataType) -> Tuple[torch.Tensor, torch.Tens logger.warning(f'All reward functions returned None for the following kwargs: {row_reward_kwargs}. ' 'Please ensure that at least one reward function returns a valid reward.') - # Calculate total rewards - total_rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + return rewards_per_func - # Extract prompt_ids for grouping - prompt_ids = [inp['prompt_id'] for inp in inputs] + def _gather_rewards_and_prompt_ids(self, local_rewards_per_func: torch.Tensor, + local_prompt_ids: List[str]) -> Tuple[torch.Tensor, List[str]]: + """Gather rewards and prompt_ids across processes with padding""" + device = self.accelerator.device + rewards_per_func = local_rewards_per_func + prompt_ids = local_prompt_ids # Prepare for gather with padding if self.rollout_pad_count > 0: - # Pad total rewards with NaN - pad_total_rewards = torch.full((self.rollout_pad_count, ), torch.nan, dtype=torch.float32, device=device) - total_rewards = torch.cat([total_rewards, pad_total_rewards], dim=0) + # Pad rewards with NaN + pad_rewards_per_func = torch.full((self.rollout_pad_count, rewards_per_func.shape[1]), + torch.nan, + dtype=torch.float32, + device=device) + rewards_per_func = torch.cat([rewards_per_func, pad_rewards_per_func], dim=0) # Pad prompt_ids with special dummy value dummy_prompt_ids = ['__dummy_pad__'] * self.rollout_pad_count prompt_ids = prompt_ids + dummy_prompt_ids # Gather all data across processes - gathered_total_rewards = gather(total_rewards) + gathered_rewards_per_func = gather(rewards_per_func) gathered_prompt_ids = gather_object(prompt_ids) - # Remove dummy data (prompt_id == "__dummy_pad__") - valid_indices = [i for i, id in enumerate(gathered_prompt_ids) if id != '__dummy_pad__'] - valid_total_rewards = gathered_total_rewards[valid_indices] - valid_prompt_ids = [gathered_prompt_ids[i] for i in valid_indices] + return gathered_rewards_per_func, gathered_prompt_ids - # Compute advantages based on prompt_id grouping - total_advantages = self._compute_advantages(valid_total_rewards, valid_prompt_ids) + def _remove_dummy_data(self, gathered_rewards_per_func: torch.Tensor, + gathered_prompt_ids: List[str]) -> Tuple[torch.Tensor, List[str]]: + """Remove dummy data (prompt_id == '__dummy_pad__') from gathered data""" + valid_indices = [i for i, pid in enumerate(gathered_prompt_ids) if pid != '__dummy_pad__'] + valid_rewards_per_func = gathered_rewards_per_func[valid_indices] + valid_prompt_ids = [gathered_prompt_ids[i] for i in valid_indices] - # Create local advantages by filtering to original local data - assert rewards_per_func.shape[0] == total_rewards.shape[0] == total_advantages.shape[0] - return rewards_per_func, total_rewards, completions, total_advantages + return valid_rewards_per_func, valid_prompt_ids def _compute_advantages(self, rewards: torch.Tensor, prompt_ids: List[str]) -> torch.Tensor: """ @@ -2332,6 +2360,7 @@ def get_even_process_data(self, global_data: List[T]) -> List[T]: # Calculate the number of samples that need to be padded # This ensures all processes have the same number of samples for gather operations + self.rollout_pad_count = 0 if 
remainder > 0 and proc_idx >= remainder: # Processes with extra samples need padding self.rollout_pad_count = 1 From e25c2e419ff1481e7d83867410c6409e80a62c65 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 11 Aug 2025 17:09:15 +0800 Subject: [PATCH 21/26] fix split mini batch --- swift/plugin/multi_turn.py | 6 +++ swift/trainers/rlhf_trainer/grpo_trainer.py | 57 ++++++++------------- swift/trainers/rlhf_trainer/utils.py | 18 +++++++ 3 files changed, 44 insertions(+), 37 deletions(-) diff --git a/swift/plugin/multi_turn.py b/swift/plugin/multi_turn.py index 4e22b7b26c..4327fc7e81 100644 --- a/swift/plugin/multi_turn.py +++ b/swift/plugin/multi_turn.py @@ -361,6 +361,7 @@ async def run(self, infer_request: 'RolloutInferRequest', request_config: 'Reque # Prepare next turn ret = self.step(current_request, response_choice, current_turn) + current_request.messages = messages current_request: 'RolloutInferRequest' = ret['infer_request'] if current_request.messages[-1]['role'] == 'assistant': @@ -371,6 +372,11 @@ async def run(self, infer_request: 'RolloutInferRequest', request_config: 'Reque return rollout_outputs + def step(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompletionResponseChoice', + current_turn: int) -> Dict: + # TODO: tool calling example + pass + def _parse_think_answer(self, content: str) -> tuple[str, str]: """ Parse think and answer content from assistant response. diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 4efad7d54b..4b82609df3 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -31,7 +31,7 @@ from trl import GRPOTrainer as HFGRPOTrainer from trl.models import prepare_deepspeed from trl.trainer.callbacks import SyncRefModelCallback -from trl.trainer.grpo_trainer import RepeatSampler, nanmax, nanmin, nanstd +from trl.trainer.grpo_trainer import nanmax, nanmin, nanstd from trl.trainer.utils import selective_log_softmax from swift.llm import (InferRequest, MultiModelKeys, RequestConfig, RolloutInferRequest, RowPreprocessor, Template, @@ -48,7 +48,7 @@ from ..mixin import SwiftMixin from .rlhf_mixin import RLHFTrainerMixin from .utils import (_ForwardRedirection, load_pil_img, patch_lora_merge, patch_lora_unmerge, patch_profiling_context, - patch_profiling_decorator, replace_assistant_response_with_ids) + patch_profiling_decorator, patch_save_last_checkpoint, replace_assistant_response_with_ids) from .vllm_client import VLLMClient try: @@ -68,18 +68,8 @@ DataType = List[Dict[str, Union[torch.Tensor, Any]]] T = TypeVar('T') -# patch to fix save last_checkpoint https://github.com/modelscope/ms-swift/pull/4969 -if not hasattr(RepeatSampler, 'old_len_func'): - origin_len_func = RepeatSampler.__len__ - def patched_len(self) -> int: - return (self.num_samples // self.batch_size) * self.batch_size * self.mini_repeat_count * self.repeat_count - - RepeatSampler.__len__ = patched_len - RepeatSampler.old_len_func = origin_len_func - - -class GRPOCallback(TrainerCallback): +class AsyncGenerateCallback(TrainerCallback): def __init__(self, trainer): self.trainer = trainer @@ -110,6 +100,7 @@ def __init__(self, reward_funcs: Optional[List[Union[str, Callable]]] = None, *_args, **kwargs): + patch_save_last_checkpoint() from swift.trainers.rlhf_arguments import GRPOConfig args: GRPOConfig = kwargs['args'] self.args = args @@ -347,7 +338,7 @@ def __init__(self, self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, 
accelerator=self.accelerator)) if self.async_generate: - self.add_callback(GRPOCallback(self)) + self.add_callback(AsyncGenerateCallback(self)) if self.args.dynamic_sample or self.template.truncation_strategy == 'raise': self.resample_dataset = deepcopy(self.train_dataset) @@ -1076,25 +1067,17 @@ def split_by_mini_batches(self, inputs): mode = 'train' if self.model.training else 'eval' bs = self.args.per_device_train_batch_size if mode == 'train' else self.args.per_device_eval_batch_size spg = self.args.steps_per_generation if mode == 'train' else 1 - # TODO: Check - expected_normal_size = bs * spg - - # Check if we have the expected number of inputs (normal case) - if len(inputs) == expected_normal_size: - # Normal case: rollout returned expected count - # Group by (bs * num_generations) to maintain proper prompt grouping - group_size = bs * self.num_generations - spg_chunks = [inputs[i * group_size:(i + 1) * group_size] for i in range(spg)] - else: - # Variable generation case: split by actual per_device_batch_size to control memory - # Split into chunks of size bs to maintain memory efficiency - num_chunks = (len(inputs) + bs - 1) // bs # Ceiling division - spg_chunks = [] - for i in range(num_chunks): - start_idx = i * bs - end_idx = min((i + 1) * bs, len(inputs)) - spg_chunks.append(inputs[start_idx:end_idx]) + chunk_size = len(inputs) // spg + remainder = len(inputs) % spg + spg_chunks = [] + + start_idx = 0 + for i in range(spg): + current_chunk_size = chunk_size + (1 if i < remainder else 0) + end_idx = start_idx + current_chunk_size + spg_chunks.append(inputs[start_idx:end_idx]) + start_idx = end_idx return spg_chunks @@ -1115,11 +1098,11 @@ def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]: for i, batch in enumerate(gas_chunks): # Encode and process each batch (size=bs) with self._template_context(template): - processed_batch = [ - replace_assistant_response_with_ids(data['messages'], data['response_token_ids']) - if 'response_token_ids' in data and data['response_token_ids'] else data for data in batch - ] - batch_encoded_inputs = [template.encode(data) for data in processed_batch] + if 'response_token_ids' in batch and batch['response_token_ids']: + batch['messages'] = replace_assistant_response_with_ids(batch['messages'], + batch['response_token_ids']) + + batch_encoded_inputs = [template.encode(data) for data in batch] batch_encoded_inputs = to_device(template.data_collator(batch_encoded_inputs), self.model.device) # Process labels and masks diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py index aad195419e..76f2ee4587 100644 --- a/swift/trainers/rlhf_trainer/utils.py +++ b/swift/trainers/rlhf_trainer/utils.py @@ -302,3 +302,21 @@ def replace_assistant_response_with_ids(messages: 'Messages', completion_ids: Li completion_index += 1 return messages + + +def patch_save_last_checkpoint(): + import trl + from packaging import version + if version.parse(trl.__version__) >= version.parse('0.20'): + return + + # patch to fix save last_checkpoint https://github.com/modelscope/ms-swift/pull/4969 + from trl.trainer.grpo_trainer import RepeatSampler + if not hasattr(RepeatSampler, 'old_len_func'): + origin_len_func = RepeatSampler.__len__ + + def patched_len(self) -> int: + return (self.num_samples // self.batch_size) * self.batch_size * self.mini_repeat_count * self.repeat_count + + RepeatSampler.__len__ = patched_len + RepeatSampler.old_len_func = origin_len_func From bf035b326034c164d583bf6f39f8f2c37eed44ad Mon Sep 17 00:00:00 
2001 From: hjh0119 Date: Mon, 11 Aug 2025 17:25:47 +0800 Subject: [PATCH 22/26] docstring for split mini batches --- swift/trainers/rlhf_trainer/grpo_trainer.py | 109 ++++++++++++++++---- 1 file changed, 91 insertions(+), 18 deletions(-) diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 4b82609df3..373ab0eee1 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -950,6 +950,56 @@ def _remove_dummy_data(self, gathered_rewards_per_func: torch.Tensor, return valid_rewards_per_func, valid_prompt_ids + def _gather_tensors_with_padding(self, local_tensors: List[torch.Tensor], + device: torch.device) -> List[torch.Tensor]: + """Gather tensors across processes with padding to ensure consistent sizes""" + # Prepare for gather with padding + tensors = local_tensors.copy() + if self.rollout_pad_count > 0 and tensors: + # Create dummy tensors with the same shape as the first tensor + dummy_tensor = torch.full_like(tensors[0], torch.nan, device=device) + for _ in range(self.rollout_pad_count): + tensors.append(dummy_tensor) + + # Gather all tensors across processes + gathered_tensors = gather_object(tensors) + return gathered_tensors + + def _remove_padded_tensors(self, gathered_tensors: List[torch.Tensor]) -> List[torch.Tensor]: + """Remove padded dummy tensors (NaN tensors) from gathered data""" + if not gathered_tensors: + return [] + + valid_tensors = [] + for tensor in gathered_tensors: + # Check if tensor is dummy (all NaN) + if not torch.isnan(tensor).all(): + valid_tensors.append(tensor) + + return valid_tensors + + def _gather_objects_with_padding(self, local_objects: List[Any]) -> List[Any]: + """Gather objects across processes with padding to ensure consistent sizes""" + # Prepare for gather with padding + objects = local_objects.copy() + if self.rollout_pad_count > 0: + # Add dummy objects + dummy_object = '__dummy_pad__' + for _ in range(self.rollout_pad_count): + objects.append(dummy_object) + + # Gather all objects across processes + gathered_objects = gather_object(objects) + return gathered_objects + + def _remove_padded_objects(self, gathered_objects: List[Any]) -> List[Any]: + """Remove padded dummy objects from gathered data""" + if not gathered_objects: + return [] + + valid_objects = [obj for obj in gathered_objects if obj != '__dummy_pad__'] + return valid_objects + def _compute_advantages(self, rewards: torch.Tensor, prompt_ids: List[str]) -> torch.Tensor: """ Compute advantages based on prompt_id grouping from gathered data. @@ -1055,27 +1105,37 @@ def _dynamic_sampling(self, inputs, rewards, rewards_per_func, completions, adva return inputs, rewards, rewards_per_func, completions, advantages - def split_by_mini_batches(self, inputs): + def split_by_mini_batches(self, inputs: DataType) -> List[DataType]: """ Split inputs into mini-batches, handling variable generation counts. When rollout count differs from expected (bs * spg * num_generations), we need to adjust the splitting logic to maintain proper batch sizes. + + This method divides the input data into chunks based on the steps per generation (spg). + If the total number of inputs is not evenly divisible by spg, the remainder is + distributed across the first few chunks to ensure all data is included. + + Args: + inputs (DataType): List of input data samples to be split into mini-batches. + + Returns: + List[DataType]: A list of data chunks, where each chunk represents one step + in the generation process. 
The number of chunks equals spg. """ # Slice to keep only the local part of the data - mode = 'train' if self.model.training else 'eval' - bs = self.args.per_device_train_batch_size if mode == 'train' else self.args.per_device_eval_batch_size - spg = self.args.steps_per_generation if mode == 'train' else 1 + mode: str = 'train' if self.model.training else 'eval' + spg: int = self.args.steps_per_generation if mode == 'train' else 1 - chunk_size = len(inputs) // spg - remainder = len(inputs) % spg - spg_chunks = [] + chunk_size: int = len(inputs) // spg + remainder: int = len(inputs) % spg + spg_chunks: List[DataType] = [] - start_idx = 0 + start_idx: int = 0 for i in range(spg): - current_chunk_size = chunk_size + (1 if i < remainder else 0) - end_idx = start_idx + current_chunk_size + current_chunk_size: int = chunk_size + (1 if i < remainder else 0) + end_idx: int = start_idx + current_chunk_size spg_chunks.append(inputs[start_idx:end_idx]) start_idx = end_idx @@ -1083,7 +1143,7 @@ def split_by_mini_batches(self, inputs): def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]: """ - Prepare the final batch inputs with advantages, ref/old_policy logps and other fields for RL training. + Prepare the final batch inputs with ref/old_policy logps and other fields for RL training. Args: inputs (DataType): List of local input samples. @@ -1093,7 +1153,7 @@ def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]: """ template = self.template - gas_chunks, _ = self.split_by_mini_batches(inputs) + gas_chunks = self.split_by_mini_batches(inputs) ga_batch_encoded_inputs = [] for i, batch in enumerate(gas_chunks): # Encode and process each batch (size=bs) @@ -1142,14 +1202,21 @@ def _log_metrics(self, inputs, messages, completions, rewards, rewards_per_func) mode = 'train' if self.model.training else 'eval' device = self.accelerator.device - # Calculate completion length metrics - agg_completion_mask = gather(torch.cat([inp['completion_mask'].sum(1) for inp in inputs])) + # Gather completion masks with padding + local_completion_masks = [inp['completion_mask'].sum(1) for inp in inputs] + gathered_completion_masks = self._gather_tensors_with_padding(local_completion_masks, device) + valid_completion_masks = self._remove_padded_tensors(gathered_completion_masks) + agg_completion_mask = torch.cat(valid_completion_masks) self._metrics[mode]['completions/mean_length'].append(agg_completion_mask.float().mean().item()) self._metrics[mode]['completions/min_length'].append(agg_completion_mask.float().min().item()) self._metrics[mode]['completions/max_length'].append(agg_completion_mask.float().max().item()) + # Calculate clip ratio - agg_truncated_mask = gather(torch.cat([inp['truncated_mask'] for inp in inputs]).to(device)) + local_truncated_masks = [inp['truncated_mask'] for inp in inputs] + gathered_truncated_masks = self._gather_tensors_with_padding(local_truncated_masks, device) + valid_truncated_masks = self._remove_padded_tensors(gathered_truncated_masks) + agg_truncated_mask = torch.cat(valid_truncated_masks).to(device) term_completion_mask = agg_completion_mask[agg_truncated_mask] clipped_completions_ratio = len(term_completion_mask) / len(agg_completion_mask) @@ -1171,9 +1238,15 @@ def _log_metrics(self, inputs, messages, completions, rewards, rewards_per_func) self._metrics[mode]['reward_std'].append(std_grouped_rewards.mean().item()) self._metrics[mode]['frac_reward_zero_std'].append(is_std_zero.float().mean().item()) - # Log prompt and completion texts - 
self._logs['prompt'].extend(self._apply_chat_template_to_messages_list(gather_object(messages))) - self._logs['completion'].extend(gather_object(completions)) + # Log prompt and completion texts with padding + gathered_messages = self._gather_objects_with_padding(messages) + gathered_completions = self._gather_objects_with_padding(completions) + + valid_messages = self._remove_padded_objects(gathered_messages) + valid_completions = self._remove_padded_objects(gathered_completions) + + self._logs['prompt'].extend(self._apply_chat_template_to_messages_list(valid_messages)) + self._logs['completion'].extend(valid_completions) if self.use_gym_env: pass From 60fb903bee20c78b741180bacda5443979768b3f Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Mon, 11 Aug 2025 18:01:17 +0800 Subject: [PATCH 23/26] fix gather device --- swift/trainers/rlhf_trainer/grpo_trainer.py | 42 +++++++++------------ 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 373ab0eee1..23f2407674 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -950,10 +950,10 @@ def _remove_dummy_data(self, gathered_rewards_per_func: torch.Tensor, return valid_rewards_per_func, valid_prompt_ids - def _gather_tensors_with_padding(self, local_tensors: List[torch.Tensor], - device: torch.device) -> List[torch.Tensor]: - """Gather tensors across processes with padding to ensure consistent sizes""" + def _gather_tensors(self, local_tensors: List[torch.Tensor]) -> List[torch.Tensor]: + """Gather tensors across processes with padding and remove dummy data""" # Prepare for gather with padding + device = self.accelerator.device tensors = local_tensors.copy() if self.rollout_pad_count > 0 and tensors: # Create dummy tensors with the same shape as the first tensor @@ -962,11 +962,9 @@ def _gather_tensors_with_padding(self, local_tensors: List[torch.Tensor], tensors.append(dummy_tensor) # Gather all tensors across processes - gathered_tensors = gather_object(tensors) - return gathered_tensors + gathered_tensors = gather(tensors) - def _remove_padded_tensors(self, gathered_tensors: List[torch.Tensor]) -> List[torch.Tensor]: - """Remove padded dummy tensors (NaN tensors) from gathered data""" + # Remove padded dummy tensors (NaN tensors) from gathered data if not gathered_tensors: return [] @@ -974,12 +972,13 @@ def _remove_padded_tensors(self, gathered_tensors: List[torch.Tensor]) -> List[t for tensor in gathered_tensors: # Check if tensor is dummy (all NaN) if not torch.isnan(tensor).all(): - valid_tensors.append(tensor) + # Ensure tensor is on the correct device + valid_tensors.append(tensor.to(device)) return valid_tensors - def _gather_objects_with_padding(self, local_objects: List[Any]) -> List[Any]: - """Gather objects across processes with padding to ensure consistent sizes""" + def _gather_objects(self, local_objects: List[Any]) -> List[Any]: + """Gather objects across processes with padding and remove dummy data""" # Prepare for gather with padding objects = local_objects.copy() if self.rollout_pad_count > 0: @@ -990,10 +989,8 @@ def _gather_objects_with_padding(self, local_objects: List[Any]) -> List[Any]: # Gather all objects across processes gathered_objects = gather_object(objects) - return gathered_objects - def _remove_padded_objects(self, gathered_objects: List[Any]) -> List[Any]: - """Remove padded dummy objects from gathered data""" + # Remove padded dummy objects from 
gathered data if not gathered_objects: return [] @@ -1202,10 +1199,8 @@ def _log_metrics(self, inputs, messages, completions, rewards, rewards_per_func) mode = 'train' if self.model.training else 'eval' device = self.accelerator.device - # Gather completion masks with padding local_completion_masks = [inp['completion_mask'].sum(1) for inp in inputs] - gathered_completion_masks = self._gather_tensors_with_padding(local_completion_masks, device) - valid_completion_masks = self._remove_padded_tensors(gathered_completion_masks) + valid_completion_masks = self._gather_tensors(local_completion_masks) agg_completion_mask = torch.cat(valid_completion_masks) self._metrics[mode]['completions/mean_length'].append(agg_completion_mask.float().mean().item()) @@ -1214,9 +1209,8 @@ def _log_metrics(self, inputs, messages, completions, rewards, rewards_per_func) # Calculate clip ratio local_truncated_masks = [inp['truncated_mask'] for inp in inputs] - gathered_truncated_masks = self._gather_tensors_with_padding(local_truncated_masks, device) - valid_truncated_masks = self._remove_padded_tensors(gathered_truncated_masks) - agg_truncated_mask = torch.cat(valid_truncated_masks).to(device) + valid_truncated_masks = self._gather_tensors(local_truncated_masks) + agg_truncated_mask = torch.cat(valid_truncated_masks) term_completion_mask = agg_completion_mask[agg_truncated_mask] clipped_completions_ratio = len(term_completion_mask) / len(agg_completion_mask) @@ -1238,12 +1232,9 @@ def _log_metrics(self, inputs, messages, completions, rewards, rewards_per_func) self._metrics[mode]['reward_std'].append(std_grouped_rewards.mean().item()) self._metrics[mode]['frac_reward_zero_std'].append(is_std_zero.float().mean().item()) - # Log prompt and completion texts with padding - gathered_messages = self._gather_objects_with_padding(messages) - gathered_completions = self._gather_objects_with_padding(completions) - - valid_messages = self._remove_padded_objects(gathered_messages) - valid_completions = self._remove_padded_objects(gathered_completions) + # Log prompt and completion texts with padding and remove dummy data + valid_messages = self._gather_objects(messages) + valid_completions = self._gather_objects(completions) self._logs['prompt'].extend(self._apply_chat_template_to_messages_list(valid_messages)) self._logs['completion'].extend(valid_completions) @@ -2292,6 +2283,7 @@ def merge_output_input_data(input_data: Dict[str, Union[torch.Tensor, Any]], out # Step 4: Store finish reason (used for truncation filters etc.) 
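# (descriptive note: finish_reason == 'length' marks a truncated rollout; the is_truncated flag
#  added in this hunk later becomes the truncated_mask that the completions/clipped_ratio metric consumes)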
input_data['finish_reason'] = choice.finish_reason + input_data['is_truncated'] = choice.finish_reason == 'length' return input_data From 5108fcefebb11bcd1fdc37863ffb85c79e412f19 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Tue, 12 Aug 2025 11:01:02 +0800 Subject: [PATCH 24/26] fix rollout async infer --- .../infer/infer_engine/grpo_vllm_engine.py | 17 ++++- swift/llm/infer/rollout.py | 2 +- swift/plugin/multi_turn.py | 71 ++----------------- swift/trainers/rlhf_trainer/grpo_trainer.py | 34 +++++---- swift/trainers/rlhf_trainer/vllm_client.py | 3 +- 5 files changed, 42 insertions(+), 85 deletions(-) diff --git a/swift/llm/infer/infer_engine/grpo_vllm_engine.py b/swift/llm/infer/infer_engine/grpo_vllm_engine.py index a69748a395..7660f10531 100644 --- a/swift/llm/infer/infer_engine/grpo_vllm_engine.py +++ b/swift/llm/infer/infer_engine/grpo_vllm_engine.py @@ -151,7 +151,6 @@ def infer( use_tqdm: Optional[bool] = None, adapter_request: Optional[AdapterRequest] = None, ) -> List[RolloutOutput]: - assert not self.use_async_engine, 'for Async Engine, use infer_async instead' res = super().infer( infer_requests, request_config, @@ -170,6 +169,22 @@ def infer( return res + async def async_infer(self, + infer_requests: List[InferRequest], + request_config: Optional[RequestConfig] = None, + metrics: Optional[List[Metric]] = None, + *, + use_tqdm: Optional[bool] = None, + **kwargs) -> List[ChatCompletionResponse]: + if request_config is None: + request_config = RequestConfig() + assert request_config.n == 1 + + tasks = [self.infer_async(infer_request, request_config, **kwargs) for infer_request in infer_requests] + if use_tqdm is None: + use_tqdm = len(infer_requests) > 1 + return self._batch_infer_stream(tasks, request_config.stream, use_tqdm, metrics) + async def _batch_infer_stream(self, tasks, stream: bool = True, diff --git a/swift/llm/infer/rollout.py b/swift/llm/infer/rollout.py index bfbb6a6964..db7930b720 100644 --- a/swift/llm/infer/rollout.py +++ b/swift/llm/infer/rollout.py @@ -349,7 +349,7 @@ async def infer( if request_config.seed: request_config.seed += i * len(requests) kwargs = {'infer_requests': requests, 'request_config': request_config, 'use_tqdm': use_tqdm} - method = 'async_infer' if self.use_async_engine else 'infer' + method = 'infer' if not self.use_async_engine else 'async_infer' connection.send({'type': 'call', 'method': method, 'kwargs': kwargs}) all_outputs = [connection.recv() for connection in self.connections] diff --git a/swift/plugin/multi_turn.py b/swift/plugin/multi_turn.py index 4327fc7e81..60278b9d87 100644 --- a/swift/plugin/multi_turn.py +++ b/swift/plugin/multi_turn.py @@ -287,35 +287,16 @@ class ThinkingModelScheduler(MultiTurnScheduler): 4. Returns List[RolloutOutput] with one output per round """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - async def run(self, infer_request: 'RolloutInferRequest', request_config: 'RequestConfig', **kwargs) -> List['RolloutOutput']: - """ - Execute multi-turn conversation for Thinking models. 
- - Args: - infer_request: The initial inference request containing messages - request_config: Configuration for the inference request - **kwargs: Additional inference parameters - - Returns: - List[RolloutOutput]: List of outputs, one for each round - """ from swift.llm.infer.protocol import RolloutOutput current_request = infer_request current_turn = 1 rollout_outputs = [] - last_think_content = '' while True: messages = current_request.messages - if current_turn == 1 or not messages[-1]['content']: - # If it's the first turn or the last message content is empty(dummy), remove the response - remove_response(messages) - # Get model response response: 'ChatCompletionResponse' = await self.infer_engine.infer_async( current_request, request_config, **kwargs) @@ -323,11 +304,6 @@ async def run(self, infer_request: 'RolloutInferRequest', request_config: 'Reque # Parse think and answer content completion = response_choice.message.content - think_content, answer_content = self._parse_think_answer(completion) - - # Update last think content - if think_content: - last_think_content = think_content # Update conversation history if messages[-1]['role'] == 'assistant': @@ -336,18 +312,13 @@ async def run(self, infer_request: 'RolloutInferRequest', request_config: 'Reque messages.append({'role': 'assistant', 'content': completion}) # Build history for this round - round_history = self._build_round_history(messages, current_turn, last_think_content) + messages_with_last_think = self._build_messages(messages) # Create RolloutOutput for this round round_output = RolloutOutput( response=response, - messages=round_history, - rollout_infos={ - 'num_turns': current_turn, - 'think_content': think_content, - 'answer_content': answer_content, - 'round_number': current_turn - }) + messages=messages_with_last_think, + ) rollout_outputs.append(round_output) # Check stopping conditions @@ -372,38 +343,6 @@ async def run(self, infer_request: 'RolloutInferRequest', request_config: 'Reque return rollout_outputs - def step(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompletionResponseChoice', - current_turn: int) -> Dict: - # TODO: tool calling example - pass - - def _parse_think_answer(self, content: str) -> tuple[str, str]: - """ - Parse think and answer content from assistant response. - - Args: - content: Assistant response content - - Returns: - tuple: (think_content, answer_content) - """ - think_content = '' - answer_content = '' - - # Parse think content - think_start = content.find('') - think_end = content.find('') - if think_start != -1 and think_end != -1: - think_content = content[think_start + 7:think_end].strip() - - # Parse answer content - answer_start = content.find('') - answer_end = content.find('') - if answer_start != -1 and answer_end != -1: - answer_content = content[answer_start + 8:answer_end].strip() - - return think_content, answer_content - def _is_thinking_template(self) -> bool: """ Check if the model's template is a ThinkingTemplate. @@ -419,14 +358,12 @@ def _is_thinking_template(self) -> bool: return isinstance(template, ThinkingTemplate) - def _build_round_history(self, original_messages: 'Messages', round_num: int, think_content: str) -> 'Messages': + def _build_messages(self, original_messages: 'Messages') -> 'Messages': """ Build history for a specific round, keeping only the think content from the last round. 
Args: original_messages: Original conversation messages - round_num: Current round number - think_content: Think content to include Returns: Messages: History for this specific round diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 23f2407674..ab28292634 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -824,7 +824,7 @@ def _generate_and_score_completions(self, inputs: DataType) -> DataType: local_advantages = self.get_even_process_data(total_advantages) assert len(local_advantages) == len(inputs) for i, advantage in enumerate(local_advantages): - inputs[i]['advantage'] = advantage + inputs[i]['advantages'] = advantage self._logs['advantages'].extend(total_advantages.tolist()) if any('images' in data and data['images'] is not None for data in inputs): @@ -1173,6 +1173,8 @@ def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]: torch.tensor([b['is_truncated'] for b in batch], dtype=torch.bool), 'logits_to_keep': logits_to_keep, + 'advantages': + torch.stack([data['advantages'] for data in batch]) }) with torch.no_grad(): @@ -1199,21 +1201,23 @@ def _log_metrics(self, inputs, messages, completions, rewards, rewards_per_func) mode = 'train' if self.model.training else 'eval' device = self.accelerator.device - local_completion_masks = [inp['completion_mask'].sum(1) for inp in inputs] - valid_completion_masks = self._gather_tensors(local_completion_masks) - agg_completion_mask = torch.cat(valid_completion_masks) + local_completion_lengths = [inp['completion_mask'].sum(1).to(device) for inp in inputs] - self._metrics[mode]['completions/mean_length'].append(agg_completion_mask.float().mean().item()) - self._metrics[mode]['completions/min_length'].append(agg_completion_mask.float().min().item()) - self._metrics[mode]['completions/max_length'].append(agg_completion_mask.float().max().item()) + total_completion_lengths = self._gather_tensors(local_completion_lengths) + total_completion_lengths = torch.cat(total_completion_lengths) + + self._metrics[mode]['completions/mean_length'].append(total_completion_lengths.float().mean().item()) + self._metrics[mode]['completions/min_length'].append(total_completion_lengths.float().min().item()) + self._metrics[mode]['completions/max_length'].append(total_completion_lengths.float().max().item()) # Calculate clip ratio - local_truncated_masks = [inp['truncated_mask'] for inp in inputs] - valid_truncated_masks = self._gather_tensors(local_truncated_masks) - agg_truncated_mask = torch.cat(valid_truncated_masks) + local_truncated_masks = [inp['truncated_mask'].to(device) for inp in inputs] + total_truncated_masks = self._gather_tensors(local_truncated_masks) + total_truncated_masks = torch.cat(total_truncated_masks) - term_completion_mask = agg_completion_mask[agg_truncated_mask] - clipped_completions_ratio = len(term_completion_mask) / len(agg_completion_mask) + num_truncated_samples = total_truncated_masks.sum().item() + num_total_samples = total_completion_lengths.shape[0] + clipped_completions_ratio = num_truncated_samples / num_total_samples self._metrics[mode]['completions/clipped_ratio'].append(clipped_completions_ratio) @@ -1735,7 +1739,9 @@ def _engine_infer( """ with patch_profiling_context(self, 'generate'): if self.vllm_mode == 'server': - return self.vllm_client.infer(infer_requests, asdict(request_config), use_tqdm=use_tqdm) + return self.vllm_client.infer([asdict(req) for req in infer_requests], + 
asdict(request_config), + use_tqdm=use_tqdm) else: res = self.engine.infer(infer_requests, request_config, use_tqdm=use_tqdm) if all(isinstance(r, RolloutOutput) for r in res): @@ -2190,7 +2196,7 @@ def _process_image_data(image_data: Union[dict, str]) -> str: for data in inputs: # Extract required metadata fields request_data = {key: data[key] for key in REQUEST_METADATA_FIELDS if key in data} - + request_data['uuid'] = data['prompt_id'] # Preserve additional fields for multi-turn async scenarios if self.multi_turn_scheduler and self.vllm_use_async_engine: # data_dict is already concatenated inside async engine diff --git a/swift/trainers/rlhf_trainer/vllm_client.py b/swift/trainers/rlhf_trainer/vllm_client.py index f88bb074ce..89c0dcfecd 100644 --- a/swift/trainers/rlhf_trainer/vllm_client.py +++ b/swift/trainers/rlhf_trainer/vllm_client.py @@ -9,7 +9,6 @@ import requests import torch -from dacite import from_dict from packaging import version from requests import ConnectionError from torch import nn @@ -150,7 +149,7 @@ def process_chunk(i, chunk): return resp_data = response.json() - results[i] = [from_dict(data_class=RolloutOutput, data=resp) for resp in resp_data] + results[i] = [RolloutOutput.parse_obj(resp) for resp in resp_data] except Exception as e: errors[i] = e From 08121296deb99695f57e49ca7e9c7cb35f8d72c6 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Wed, 13 Aug 2025 16:05:41 +0800 Subject: [PATCH 25/26] thinking tips scheduler --- swift/plugin/multi_turn.py | 123 +++++++++++++++++-------------------- 1 file changed, 56 insertions(+), 67 deletions(-) diff --git a/swift/plugin/multi_turn.py b/swift/plugin/multi_turn.py index 60278b9d87..4ca20298f3 100644 --- a/swift/plugin/multi_turn.py +++ b/swift/plugin/multi_turn.py @@ -273,22 +273,36 @@ def check_finished(self, infer_request: 'RolloutInferRequest', response_choice: return False -class ThinkingModelScheduler(MultiTurnScheduler): +class ThinkingModelTipsScheduler(MultiTurnScheduler): """ - Scheduler for Thinking class models that handle multi-turn reasoning. + Scheduler for multi-turn reasoning with Thinking class models. - For Thinking models, the assistant response format is: - " think content answer content " + Key Features: + 1. Parses both "think" and "answer" content from each assistant response. + 2. For each round, only the "think" content from the last round is retained in the message history. + 3. Each round's conversation history is processed independently. + 4. Returns a list of RolloutOutput objects, one for each round. + 5. Please set `--loss_scale last_round` for training last round response. - This scheduler: - 1. Parses think and answer content from assistant responses - 2. Only keeps the think content from the last round - 3. Processes each round's history separately - 4. Returns List[RolloutOutput] with one output per round + The scheduler will automatically inject a tip prompt if the answer is incorrect, encouraging the model to recheck its reasoning. # noqa """ + from .orm import MathAccuracy + tips_prompt = 'The answer is not correct, It seems You made a mistake, you need to recheck very carefully.' + acc_func = MathAccuracy() async def run(self, infer_request: 'RolloutInferRequest', request_config: 'RequestConfig', **kwargs) -> List['RolloutOutput']: + """ + Execute multi-turn inference for Thinking models. + + Args: + infer_request (RolloutInferRequest): The initial inference request containing the conversation history. + request_config (RequestConfig): Configuration for the inference request. 
+ **kwargs: Additional arguments for the inference engine. + + Returns: + List[RolloutOutput]: A list of RolloutOutput objects, one for each reasoning round. + """ from swift.llm.infer.protocol import RolloutOutput current_request = infer_request @@ -297,59 +311,63 @@ async def run(self, infer_request: 'RolloutInferRequest', request_config: 'Reque while True: messages = current_request.messages - # Get model response + # Obtain model response for the current turn response: 'ChatCompletionResponse' = await self.infer_engine.infer_async( current_request, request_config, **kwargs) response_choice: 'ChatCompletionResponseChoice' = response.choices[0] - - # Parse think and answer content completion = response_choice.message.content - # Update conversation history - if messages[-1]['role'] == 'assistant': - messages[-1]['content'] += completion - else: - messages.append({'role': 'assistant', 'content': completion}) + # Append the assistant's response to the message history + messages.append({'role': 'assistant', 'content': completion}) - # Build history for this round + # Construct the message history for this round, keeping only the last "think" content messages_with_last_think = self._build_messages(messages) - # Create RolloutOutput for this round + # Create a RolloutOutput for the current round round_output = RolloutOutput( response=response, messages=messages_with_last_think, - ) + response_token_ids=response_choice.token_ids, + rollout_infos={'num_turns': current_turn}) + # Store the output for this round rollout_outputs.append(round_output) - # Check stopping conditions + # Determine whether to stop the multi-turn reasoning should_stop = self.check_finished(current_request, response_choice, current_turn) - if self.max_turns: - should_stop = should_stop or (current_turn >= self.max_turns) - if should_stop: break - # Prepare next turn + # Prepare for the next turn by updating the inference request ret = self.step(current_request, response_choice, current_turn) - current_request.messages = messages current_request: 'RolloutInferRequest' = ret['infer_request'] - - if current_request.messages[-1]['role'] == 'assistant': - # Add a dummy response to allow engine to continue generating - current_request.messages.append({'role': 'assistant', 'content': None}) - current_turn += 1 return rollout_outputs - def _is_thinking_template(self) -> bool: - """ - Check if the model's template is a ThinkingTemplate. 
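# Rough sketch of the message flow this tips scheduler produces (illustrative only, inferred from the
# run/step/check_finished logic in this patch; not part of the patch itself):
#   turn 1: [..., {'role': 'user', 'content': question},
#                 {'role': 'assistant', 'content': first_attempt}]
#           MathAccuracy marks first_attempt wrong -> step() appends tips_prompt as a new user turn
#   turn 2: [..., {'role': 'user', 'content': tips_prompt},
#                 {'role': 'assistant', 'content': second_attempt}]
#           check_finished() stops here, since the tip is only given once (or earlier if the answer is correct)
# One RolloutOutput (carrying response_token_ids and rollout_infos={'num_turns': ...}) is emitted per turn.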
+ def check_finished(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompletionResponseChoice', + current_turn: int) -> bool: - Returns: - bool: True if the template is a ThinkingTemplate or its subclass - """ + last_query = infer_request.messages[-2]['content'] + # tips once + if self.tips_prompt in last_query: + return True + + completion = response_choice.message.content + solution = infer_request.data_dict['solution'] + acc = self.acc_func([completion], [solution])[0] + if acc == 1: + return True + + return super().check_finished(infer_request, response_choice, current_turn) + + def step(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompletionResponseChoice', + current_turn: int) -> Dict: + infer_request.messages.append({'role': 'user', 'content': self.tips_prompt}) + + return {'infer_request': infer_request} + + def _is_thinking_template(self) -> bool: if not hasattr(self.infer_engine, 'default_template'): return False @@ -450,34 +468,6 @@ def step(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompl return {'infer_request': infer_request} -class MathTipsMultiTurnScheduler(MultiTurnScheduler): - from .orm import MathAccuracy - tips_prompt = 'The answer is not correct, It seems You made a mistake, you need to recheck very carefully.' - acc_func = MathAccuracy() - - def check_finished(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompletionResponseChoice', - current_turn: int) -> bool: - - last_query = infer_request.messages[-2]['content'] - # we only give tips once - if self.tips_prompt in last_query: - return True - - completion = response_choice.message.content - solution = infer_request.data_dict['solution'] - acc = self.acc_func([completion], [solution])[0] - if acc == 1: - return True - - return super().check_finished(infer_request, response_choice, current_turn) - - def step(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompletionResponseChoice', - current_turn: int) -> Dict: - infer_request.messages.append({'role': 'user', 'content': self.tips_prompt}) - - return {'infer_request': infer_request} - - class GYMScheduler(RolloutScheduler): def __init__(self, @@ -609,7 +599,6 @@ async def run(self, infer_request: 'RolloutInferRequest', request_config: 'Reque multi_turns = { 'base_scheduler': RolloutScheduler, 'math_tip_trick': MathTipsScheduler, - 'math_tip_trick_multi_turn': MathTipsMultiTurnScheduler, 'gym_scheduler': GYMScheduler, - 'thinking_scheduler': ThinkingModelScheduler, + 'thinking_tips_scheduler': ThinkingModelTipsScheduler, } From fc9d8edaf30413e77e176eaf9f0264f7a5013c77 Mon Sep 17 00:00:00 2001 From: hjh0119 Date: Thu, 14 Aug 2025 11:07:39 +0800 Subject: [PATCH 26/26] resolve dynamic sampling" --- swift/llm/argument/rlhf_args.py | 2 + swift/llm/infer/protocol.py | 9 +- swift/llm/infer/rollout.py | 25 ++- swift/plugin/multi_turn.py | 24 ++- swift/trainers/arguments.py | 1 + swift/trainers/rlhf_trainer/grpo_trainer.py | 168 ++++++++++---------- swift/trainers/rlhf_trainer/utils.py | 4 +- 7 files changed, 130 insertions(+), 103 deletions(-) diff --git a/swift/llm/argument/rlhf_args.py b/swift/llm/argument/rlhf_args.py index 6fbb7b03fb..48e6496bed 100644 --- a/swift/llm/argument/rlhf_args.py +++ b/swift/llm/argument/rlhf_args.py @@ -254,12 +254,14 @@ def _init_external_vllm(self): return from swift.trainers.rlhf_trainer.vllm_client import VLLMClient if is_master(): + logger.info('Start connecting to vLLM server') self.vllm_client = VLLMClient( 
base_urls=self.vllm_server_base_url, hosts=self.vllm_server_host, server_ports=self.vllm_server_port, connection_timeout=self.vllm_server_timeout) self.vllm_client.init_communicator(device=get_current_device()) + logger.info('Connected to vLLM server') def _set_default(self): if self.beta is None: diff --git a/swift/llm/infer/protocol.py b/swift/llm/infer/protocol.py index 87d4fe808c..440c915e7d 100644 --- a/swift/llm/infer/protocol.py +++ b/swift/llm/infer/protocol.py @@ -10,7 +10,7 @@ import json from PIL import Image -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from ..template import InferRequest from ..utils import Messages, Tool @@ -358,6 +358,13 @@ class RolloutOutput(BaseModel): response_loss_mask: List[List[int]] = Field(default_factory=list) rollout_infos: Dict[str, Any] = Field(default_factory=dict) + @field_validator('response_token_ids', 'response_loss_mask', mode='before') + @classmethod + def _wrap_flat_list(cls, v): + if isinstance(v, list) and v and isinstance(v[0], int): + return [v] + return v + def model_post_init(self, __context): # Ensure multimodal data in rollout_infos is serializable (e.g., images to base64) super().model_post_init(__context) diff --git a/swift/llm/infer/rollout.py b/swift/llm/infer/rollout.py index db7930b720..ad871dd055 100644 --- a/swift/llm/infer/rollout.py +++ b/swift/llm/infer/rollout.py @@ -105,7 +105,25 @@ def llm_worker(args: RolloutArguments, data_parallel_rank: int, master_port: int async def async_llm_worker(args: RolloutArguments, data_parallel_rank: int, master_port: int, connection: Connection) -> None: - engine = SwiftRolloutDeploy.get_infer_engine(args) + # Set required environment variables for DP to work with vLLM + args._import_external_plugins() + engine = SwiftRolloutDeploy.get_infer_engine(args, template=args.get_template(None)) + + if args.multi_turn_scheduler: + if args.multi_turn_scheduler not in multi_turns: + raise ValueError(f"Multi-turn scheduler '{args.multi_turn_scheduler}' not found in multi_turns.") + scheduler_cls = multi_turns[args.multi_turn_scheduler] + + kwargs = {} + if 'tokenizer' in list(inspect.signature(scheduler_cls.__init__).parameters): + kwargs['tokenizer'] = engine.default_template.tokenizer + + rollout_engine: RolloutScheduler = scheduler_cls(engine, args.max_turns, **kwargs) + if not rollout_engine: + raise ValueError(f"Failed to initialize multi-turn scheduler '{args.multi_turn_scheduler}'.") + else: + rollout_engine = engine + # Send ready signal to parent process connection.send({'status': 'ready'}) @@ -122,8 +140,7 @@ async def async_llm_worker(args: RolloutArguments, data_parallel_rank: int, mast import traceback method_name = command['method'] args, kwargs = command.get('args', ()), command.get('kwargs', {}) - method = getattr(engine, method_name, None) or getattr(engine.engine, method_name, None) - + method = getattr(rollout_engine, method_name, None) or getattr(rollout_engine.engine, method_name, None) try: result = await method(*args, **kwargs) except Exception: @@ -137,8 +154,6 @@ async def async_llm_worker(args: RolloutArguments, data_parallel_rank: int, mast def llm_worker_entry(*args, **kwargs): - rollout_args: RolloutArguments = args[0] - rollout_args._import_external_plugins() asyncio.run(async_llm_worker(*args, **kwargs)) diff --git a/swift/plugin/multi_turn.py b/swift/plugin/multi_turn.py index 4ca20298f3..6faecb10f5 100644 --- a/swift/plugin/multi_turn.py +++ b/swift/plugin/multi_turn.py @@ -43,7 +43,16 @@ async def 
_infer_async_single(infer_request: Union['RolloutInferRequest', Dict[s tasks = [_infer_async_single(infer_request, request_config, **kwargs) for infer_request in infer_requests] if use_tqdm is None: use_tqdm = len(infer_requests) > 1 - return await self.infer_engine._batch_infer_stream(tasks, request_config.stream, use_tqdm) + # Execute all tasks and flatten the results + results = await self.infer_engine._batch_infer_stream(tasks, request_config.stream, use_tqdm, None) + # Flatten the results since each task may return a list + flattened_results = [] + for result in results: + if isinstance(result, list): + flattened_results.extend(result) + else: + flattened_results.append(result) + return flattened_results async def run(self, infer_request: 'RolloutInferRequest', request_config: 'RequestConfig', **kwargs) -> 'RolloutOutput': @@ -68,6 +77,8 @@ def __getattr__(self, key: str): infer_engine = object.__getattribute__(self, 'infer_engine') if hasattr(infer_engine, key): return getattr(infer_engine, key) + if hasattr(infer_engine.engine, key): + return getattr(infer_engine.engine, key) except AttributeError: raise AttributeError(f'{type(self).__name__} object has no attribute {key}') @@ -400,14 +411,13 @@ def __init__(self, messages): # Set up the template for inference mode template = self.infer_engine.default_template - original_is_training = getattr(template, 'is_training', False) - template.is_training = False - + # _swift_prepare_inputs will remove historical thinking content when in train mode, patch the mode here + original_mode = template.mode + template.mode = 'train' # Use the template's method to prepare messages template._swift_prepare_inputs(mock_inputs) - - # Restore original training state - template.is_training = original_is_training + # Restore original mode + template.mode = original_mode return mock_inputs.messages else: diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py index b330ff69a3..0f0c55f2cb 100644 --- a/swift/trainers/arguments.py +++ b/swift/trainers/arguments.py @@ -293,6 +293,7 @@ class GRPOArgumentsMixin(VllmArguments): multi_turn_scheduler: Optional[str] = None max_turns: Optional[int] = None completion_length_limit_scope: Literal['total', 'per_round'] = 'per_round' + vllm_server_pass_dataset: bool = False # DAPO, https://arxiv.org/abs/2503.14476 dynamic_sample: bool = False diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 60185bf246..a04c07c74d 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -815,13 +815,14 @@ def _generate_and_score_completions(self, inputs: DataType) -> DataType: inputs = self.resample_truncated_inputs(inputs) inputs = self._generate_completions(inputs) - total_rewards_per_func, total_rewards, completions, total_advantages = self._score_completions(inputs) + total_rewards_per_func, total_rewards, completions, total_advantages, rewards_std = self._score_completions( + inputs) mode = 'train' if self.model.training else 'eval' if self.args.dynamic_sample and mode == 'train': # dynamic sampling for std=0 groups inputs, total_rewards, total_rewards_per_func, completions, total_advantages = \ - self._dynamic_sampling(inputs, total_rewards, total_rewards_per_func, completions, total_advantages) + self._dynamic_sampling(inputs, total_rewards, total_rewards_per_func, completions, total_advantages, rewards_std) # noqa local_advantages = self.get_even_process_data(total_advantages) assert len(local_advantages) 
== len(inputs) @@ -883,9 +884,9 @@ def _score_completions(self, inputs: DataType) -> Tuple[torch.Tensor, torch.Tens total_rewards = (total_rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) # Compute advantages based on prompt_id grouping - total_advantages = self._compute_advantages(total_rewards, total_prompt_ids) + total_advantages, rewards_std = self._compute_advantages(total_rewards, total_prompt_ids) - return total_rewards_per_func, total_rewards, completions, total_advantages + return total_rewards_per_func, total_rewards, completions, total_advantages, rewards_std def _compute_rewards_per_func(self, inputs: DataType, completions: List[str]) -> torch.Tensor: """Compute rewards using all reward functions""" @@ -1001,20 +1002,38 @@ def _gather_objects(self, local_objects: List[Any]) -> List[Any]: def _compute_advantages(self, rewards: torch.Tensor, prompt_ids: List[str]) -> torch.Tensor: """ - Compute advantages based on prompt_id grouping from gathered data. + Compute normalized advantages by grouping rewards based on prompt IDs. + + This method performs group-wise advantage computation where rewards are normalized + within each prompt group by subtracting the group mean. Optionally scales advantages + by group standard deviation for variance normalization. + + The computation process: + 1. Groups rewards by unique prompt_id + 2. Computes group-wise mean and standard deviation + 3. Calculates advantages as (reward - group_mean) + 4. Optionally normalizes by group standard deviation if scale_rewards is enabled + 5. Tracks training/evaluation metrics for monitoring Args: - rewards: Tensor of rewards from all processes, shape: (num_examples,) - prompt_ids: List of prompt_ids from all processes, len: num_examples + rewards (torch.Tensor): Reward values from all processes with shape (num_examples,) + prompt_ids (List[str]): Corresponding prompt identifiers with length num_examples Returns: - torch.Tensor: Computed advantages with same shape as rewards + tuple: A tuple containing: + - advantages (torch.Tensor): Computed advantages with same shape as rewards + - rewards_std (torch.Tensor): Group standard deviations for each sample """ assert rewards.shape[0] == len(prompt_ids) + mode = 'train' if self.model.training else 'eval' advantages = torch.zeros_like(rewards) + # calculate rewards_std for dynamic_sampling + rewards_std = torch.zeros_like(rewards) # Group rewards by prompt_id unique_prompt_ids = list(set(prompt_ids)) + group_rewards_mean = [] + group_rewards_std = [] for prompt_id in unique_prompt_ids: # Find all samples with this prompt_id @@ -1026,20 +1045,34 @@ def _compute_advantages(self, rewards: torch.Tensor, prompt_ids: List[str]) -> t # Compute group statistics group_mean = group_rewards.mean() + group_rewards_mean.append(group_mean) group_advantages = group_rewards - group_mean + group_std = group_rewards.std() + group_rewards_std.append(group_std) + rewards_std[indices] = group_std + # Optional: scale by standard deviation if self.args.scale_rewards: - group_std = group_rewards.std() group_advantages /= (group_std + 1e-4) # Assign computed advantages back to original positions for idx, advantage in zip(indices, group_advantages): advantages[idx] = advantage - return advantages + if group_rewards_mean: + # compute metrics + group_rewards_mean = torch.stack(group_rewards_mean) + group_rewards_std = torch.stack(group_rewards_std) + is_std_zero = torch.isclose(group_rewards_std, torch.zeros_like(group_rewards_std)) + + 
self._metrics[mode]['reward'].append(group_rewards_mean.mean().item()) + self._metrics[mode]['reward_std'].append(group_rewards_std.mean().item()) + self._metrics[mode]['frac_reward_zero_std'].append(is_std_zero.float().mean().item()) + + return advantages, rewards_std - def _dynamic_sampling(self, inputs, rewards, rewards_per_func, completions, advantages): + def _dynamic_sampling(self, inputs, rewards, rewards_per_func, completions, advantages, rewards_std): """ Perform dynamic sampling to replace samples with zero-reward-variance groups. @@ -1052,6 +1085,7 @@ def _dynamic_sampling(self, inputs, rewards, rewards_per_func, completions, adva rewards_per_func: Rewards per function/model for global data samples completions: Generated completions for local inputs advantages: Computed advantages for global data samples + rewards_std: Group standard deviations for each sample Returns: tuple: (inputs, rewards, rewards_per_func, completions, advantages) @@ -1068,10 +1102,7 @@ def _dynamic_sampling(self, inputs, rewards, rewards_per_func, completions, adva origin_data = (inputs, rewards, rewards_per_func, completions, advantages) while resample_count < self.args.max_resample_times: - grouped_rewards = rewards.view(-1, self.num_generations) - group_std = grouped_rewards.std(dim=1) - - valid_mask = (group_std > 0).repeat_interleave(self.num_generations) + valid_mask = (rewards_std > 0) all_inputs = gather_object(inputs) valid_samples.extend([inp for inp, mask in zip(all_inputs, valid_mask) if mask]) valid_rewards.append(rewards[valid_mask]) @@ -1085,7 +1116,7 @@ def _dynamic_sampling(self, inputs, rewards, rewards_per_func, completions, adva inputs = next(self.dynamic_resample_iterator) inputs = Trainer._prepare_inputs(self, inputs) inputs = self._generate_completions(inputs) - rewards_per_func, rewards, completions, advantages = self._score_completions(inputs) + rewards_per_func, rewards, completions, advantages, rewards_std = self._score_completions(inputs) resample_count += 1 if len(valid_samples) >= self.args.generation_batch_size: @@ -1157,9 +1188,11 @@ def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]: for i, batch in enumerate(gas_chunks): # Encode and process each batch (size=bs) with self._template_context(template): - if 'response_token_ids' in batch and batch['response_token_ids']: - batch['messages'] = replace_assistant_response_with_ids(batch['messages'], - batch['response_token_ids']) + [ + data.update( + {'messages': replace_assistant_response_with_ids(data['messages'], data['response_token_ids'])}) + for data in batch if 'response_token_ids' in data and data['response_token_ids'] + ] batch_encoded_inputs = [template.encode(data) for data in batch] batch_encoded_inputs = to_device(template.data_collator(batch_encoded_inputs), self.model.device) @@ -1229,15 +1262,6 @@ def _log_metrics(self, inputs, messages, completions, rewards, rewards_per_func) std_rewards = nanstd(rewards_per_func[:, i]).item() self._metrics[mode][f'rewards/{reward_func_name}/std'].append(std_rewards) - # Log overall reward stats - grouped_rewards = rewards.view(-1, self.num_generations) - std_grouped_rewards = grouped_rewards.std(dim=1) - is_std_zero = torch.isclose(std_grouped_rewards, torch.zeros_like(std_grouped_rewards)) - - self._metrics[mode]['reward'].append(grouped_rewards.mean().item()) - self._metrics[mode]['reward_std'].append(std_grouped_rewards.mean().item()) - self._metrics[mode]['frac_reward_zero_std'].append(is_std_zero.float().mean().item()) - # Log prompt and completion texts 
with padding and remove dummy data valid_messages = self._gather_objects(messages) valid_completions = self._gather_objects(completions) @@ -1933,43 +1957,6 @@ def is_async_generate_eval_rollout_done(self): def is_async_generate_train_rollout_done(self): return not self.train_queue.empty() - def inputs_to_rolloutrequest(self, inputs: DataType) -> List[RolloutInferRequest]: - """Convert a list of inputs to a list of RolloutInferRequest objects - - If the input contains a 'data_dict' key, it will be used as the base for the new data_dict. - For other keys, if they overlap with keys in data_dict, the values from data_dict will be used. - Non-overlapping keys will be added to data_dict. - - Args: - inputs: List of input dictionaries - - Returns: - List of RolloutInferRequest objects - """ - request_keys = ['messages', 'images', 'audios', 'videos', 'tools', 'objects'] - infer_requests = [] - - for request in inputs: - # Get the base data_dict if it exists in the input - base_data_dict = {} - if 'data_dict' in request: - if isinstance(request['data_dict'], dict): - base_data_dict = request['data_dict'] - else: - raise ValueError('data_dict exists but is not a dictionary') - - # Collect all non-request_keys items as extra fields - extra_data = {k: request[k] for k in request if k not in request_keys and k != 'data_dict'} - - # Merge the data_dict, keeping keys from base_data_dict as priority - final_data_dict = {**extra_data, **base_data_dict} - - # Create RolloutInferRequest instance - req_args = {k: request[k] for k in request_keys if k in request} - infer_requests.append(RolloutInferRequest(**req_args, data_dict=final_data_dict)) - - return infer_requests - @contextmanager def offload_context(self): if self.args.offload_model: @@ -1997,13 +1984,13 @@ def _add_prompt_id_to_inputs(self, inputs: DataType) -> DataType: Adds a unique `prompt_id` to each input based on their `messages` content. Inputs with identical `messages` (assumed to be adjacent) will share the same `prompt_id`. + Each input also gets a unique `request_id` for vLLM request tracking. Args: inputs (DataType): A list of dictionaries, each containing a 'messages' key. - Returns: - DataType: The input list with each item containing a new 'prompt_id' field. + DataType: The input list with each item containing new 'prompt_id' and 'request_id' fields. Example: >>> inputs = [ @@ -2013,26 +2000,29 @@ def _add_prompt_id_to_inputs(self, inputs: DataType) -> DataType: ... 
] >>> self._add_prompt_id_to_inputs(inputs) [ - {"messages": [...], "data": 1, "prompt_id": "a1b2c3..."}, - {"messages": [...], "data": 2, "prompt_id": "a1b2c3..."}, - {"messages": [...], "data": 3, "prompt_id": "d4e5f6..."}, + {"messages": [...], "data": 1, "prompt_id": "a1b2c3...", "request_id": "req1"}, + {"messages": [...], "data": 2, "prompt_id": "a1b2c3...", "request_id": "req2"}, + {"messages": [...], "data": 3, "prompt_id": "d4e5f6...", "request_id": "req3"}, ] """ if not inputs: return inputs prev_messages = inputs[0].get('messages') - current_id = str(uuid.uuid4()) - inputs[0]['prompt_id'] = current_id + current_prompt_id = str(uuid.uuid4()) + inputs[0]['prompt_id'] = current_prompt_id + inputs[0]['request_id'] = str(uuid.uuid4()) # Each request gets a unique ID for i in range(1, len(inputs)): messages = inputs[i]['messages'] if messages == prev_messages: - inputs[i]['prompt_id'] = current_id + inputs[i]['prompt_id'] = current_prompt_id else: prev_messages = messages - current_id = str(uuid.uuid4()) - inputs[i]['prompt_id'] = current_id + current_prompt_id = str(uuid.uuid4()) + inputs[i]['prompt_id'] = current_prompt_id + # Each request always gets a unique request_id, regardless of prompt_id + inputs[i]['request_id'] = str(uuid.uuid4()) return inputs @@ -2171,6 +2161,7 @@ def inputs2requests(self, inputs: DataType) -> List[RolloutInferRequest]: Processing includes: - Image data conversion (bytes to base64, path handling) - Field filtering based on request metadata requirements + - UUID assignment using unique request_id for vLLM request tracking - Optional preservation of additional fields for multi-turn async scenarios """ @@ -2200,9 +2191,9 @@ def _process_image_data(image_data: Union[dict, str]) -> str: for data in inputs: # Extract required metadata fields request_data = {key: data[key] for key in REQUEST_METADATA_FIELDS if key in data} - request_data['uuid'] = data['prompt_id'] + request_data['uuid'] = data['request_id'] # Use unique request_id for vLLM # Preserve additional fields for multi-turn async scenarios - if self.multi_turn_scheduler and self.vllm_use_async_engine: + if self.args.vllm_server_pass_dataset: # data_dict is already concatenated inside async engine extra_fields = {k: v for k, v in data.items() if k not in REQUEST_METADATA_FIELDS} if extra_fields: @@ -2240,11 +2231,12 @@ def _preprocess_inputs(self, inputs: DataType) -> DataType: Returns: Processed inputs with: - - Added prompt IDs for tracking + - Added prompt IDs for grouping (same messages share same prompt_id) + - Added unique request IDs for vLLM request tracking - Removed existing assistant responses from messages Processing Steps: - 1. Adds unique prompt IDs to each input for request tracking + 1. Adds prompt IDs and unique request IDs to each input 2. Cleans each message sequence by removing existing assistant responses """ processed_inputs = self._add_prompt_id_to_inputs(inputs) @@ -2258,7 +2250,7 @@ def _postprocess_rollout_outputs(self, inputs: DataType, outputs: List[RolloutOu """ Postprocess rollout outputs by merging them back into the input data structures. - Depending on the mode (async or sync), it either matches inputs by UUID + Depending on the mode (async or sync), it either matches inputs by request_id or assumes a one-to-one correspondence. 
""" @@ -2297,19 +2289,19 @@ def merge_output_input_data(input_data: Dict[str, Union[torch.Tensor, Any]], out return input_data - # Async engine mode: match by UUID + # Async engine mode: match by request_id if self.vllm_use_async_engine: results = [] id2inputs = {} for input_data in inputs: - uuid = input_data['uuid'] - if uuid not in id2inputs: - id2inputs[uuid] = deepcopy(input_data) + request_id = input_data['request_id'] + if request_id not in id2inputs: + id2inputs[request_id] = deepcopy(input_data) for output in outputs: - uuid = output.response.id - assert uuid not in id2inputs - input_data = deepcopy(id2inputs[uuid]) + request_id = output.response.id + assert request_id in id2inputs, f'Request ID {request_id} not found in inputs' + input_data = deepcopy(id2inputs[request_id]) results.append(merge_output_input_data(input_data, output)) return results @@ -2363,7 +2355,7 @@ def _sync_multi_turn_infer(self, inputs: DataType, first_turn_rollout_outputs: L # Determine which dialogues are finished should_stops = [ self.multi_turn_scheduler.check_finished(req, output.response.choices[0], current_turn) - for req, output in zip(self.inputs_to_rolloutrequest(inputs), outputs) + for req, output in zip(self.inputs2requests(inputs), outputs) ] # Prepare pending inputs for next turn diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py index 76f2ee4587..7a12391ba7 100644 --- a/swift/trainers/rlhf_trainer/utils.py +++ b/swift/trainers/rlhf_trainer/utils.py @@ -258,8 +258,8 @@ def load_pil_img(img) -> Image: raise ValueError("Image dictionary must contain either 'bytes' or 'path' key.") -def replace_assistant_response_with_ids(messages: 'Messages', completion_ids: List[Union[int, - List[int]]]) -> 'Messages': +def replace_assistant_response_with_ids(messages: 'Messages', + completion_ids: List[Union[int, List[int]]]) -> 'Messages': # noqa """ Replaces the content of assistant messages with the provided completion IDs.