diff --git a/swift/llm/argument/rlhf_args.py b/swift/llm/argument/rlhf_args.py index 6fbb7b03fb..48e6496bed 100644 --- a/swift/llm/argument/rlhf_args.py +++ b/swift/llm/argument/rlhf_args.py @@ -254,12 +254,14 @@ def _init_external_vllm(self): return from swift.trainers.rlhf_trainer.vllm_client import VLLMClient if is_master(): + logger.info('Start connecting to vLLM server') self.vllm_client = VLLMClient( base_urls=self.vllm_server_base_url, hosts=self.vllm_server_host, server_ports=self.vllm_server_port, connection_timeout=self.vllm_server_timeout) self.vllm_client.init_communicator(device=get_current_device()) + logger.info('Connected to vLLM server') def _set_default(self): if self.beta is None: diff --git a/swift/llm/infer/infer_engine/grpo_vllm_engine.py b/swift/llm/infer/infer_engine/grpo_vllm_engine.py index 2fe7a310d4..7660f10531 100644 --- a/swift/llm/infer/infer_engine/grpo_vllm_engine.py +++ b/swift/llm/infer/infer_engine/grpo_vllm_engine.py @@ -6,14 +6,14 @@ import torch from tqdm.asyncio import tqdm_asyncio +from vllm.outputs import RequestOutput from swift.llm import InferRequest, RolloutInferRequest, Template, VllmEngine from swift.plugin import Metric, multi_turns from swift.plugin.context_manager import ContextManager, context_managers from swift.plugin.env import Env, envs from swift.plugin.multi_turn import MultiTurnScheduler -from ..protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage, GymRolloutResponseChoice, - RequestConfig, RolloutResponseChoice) +from ..protocol import ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage, RequestConfig, RolloutOutput from .utils import AdapterRequest try: @@ -150,9 +150,8 @@ def infer( template: Optional[Template] = None, use_tqdm: Optional[bool] = None, adapter_request: Optional[AdapterRequest] = None, - ) -> List[ChatCompletionResponse]: - assert not self.use_async_engine, 'for Async Engine, use infer_async instead' - return super().infer( + ) -> List[RolloutOutput]: + res = super().infer( infer_requests, request_config, metrics, @@ -160,9 +159,18 @@ def infer( use_tqdm=use_tqdm, adapter_request=adapter_request, ) + if not isinstance(res, list): + res = [res] + for i, result in enumerate(res): + if not isinstance(result, RolloutOutput): + if not isinstance(result, ChatCompletionResponse): + raise TypeError('Result must be a ChatCompletionResponse or RolloutOutput instance.') + res[i] = RolloutOutput(response=result) + + return res async def async_infer(self, - infer_requests: List[Union[RolloutInferRequest, Dict[str, Any]]], + infer_requests: List[InferRequest], request_config: Optional[RequestConfig] = None, metrics: Optional[List[Metric]] = None, *, @@ -172,150 +180,10 @@ async def async_infer(self, request_config = RequestConfig() assert request_config.n == 1 - async def _infer_async_single(infer_request: Union[RolloutInferRequest, Dict[str, Any]], - request_config: Optional[RequestConfig] = None, - **kwargs): - if isinstance(infer_request, Dict): - infer_request = RolloutInferRequest(**infer_request) - - # Route to appropriate sampling controller - if self.use_gym_env: - return await self._gym_sampling_controller(infer_request, request_config, **kwargs) - else: - return await self._multi_turn_sampling_controller(infer_request, request_config, **kwargs) - - tasks = [_infer_async_single(infer_request, request_config, **kwargs) for infer_request in infer_requests] + tasks = [self.infer_async(infer_request, request_config, **kwargs) for infer_request in infer_requests] if use_tqdm is 
None: use_tqdm = len(infer_requests) > 1 - return await self._batch_infer_stream(tasks, request_config.stream, use_tqdm, metrics) - - async def _gym_sampling_controller(self, infer_request: RolloutInferRequest, request_config: RequestConfig, - **kwargs) -> ChatCompletionResponse: - """Gym environment-based sampling controller.""" - # Create environment and context manager - env_config = infer_request.data_dict.get('env_config', {}) - env = self._create_env(env_config) - ctx_config = infer_request.data_dict.get('ctx_config', {}) - context_manager = self._create_context_manager(ctx_config) - - try: - # Environment reset - observation, info, system_message = await env.reset(infer_request) - - # Initialize conversation - messages = [] - if system_message: - messages.append({'role': 'system', 'content': system_message}) - messages.append({'role': 'user', 'content': observation}) - - current_request = deepcopy(infer_request) - current_turn = 1 - done = False - total_reward = 0.0 - step_rewards = [] - trajectory_id = f'{id(infer_request)}_{hash(str(infer_request))}' - trajectory_info = [info] - - while True: - # Apply context management - messages = context_manager.manage_context(messages, trajectory_id) - current_request.messages = messages - # Remove any previous assistant response for generation - InferRequest.remove_response(current_request.messages) - - # Generate LLM response - result: ChatCompletionResponse = await self.infer_async(current_request, request_config, **kwargs) - result_choice: RolloutResponseChoice = result.choices[0] - - completion = result_choice.message.content - messages.append({'role': 'assistant', 'content': completion}) - - # Environment step - next_observation, reward, done, step_info = await env.step(deepcopy(messages)) - - # Accumulate rewards - total_reward += reward - step_rewards.append(reward) - trajectory_info.append(step_info) - - if done or current_turn > self.max_turns: - break - - messages.append({'role': 'user', 'content': next_observation}) - current_request.messages = messages - current_turn += 1 - - # Create final result with gym-specific information - final_choice = GymRolloutResponseChoice( - index=result_choice.index, - message=result_choice.message, - finish_reason=result_choice.finish_reason, - logprobs=result_choice.logprobs, - messages=messages, - trajectory_id=trajectory_id, - total_reward=total_reward, - step_rewards=step_rewards, - trajectory_info=trajectory_info) - - return ChatCompletionResponse( - model=self.model_name, choices=[final_choice], usage=result.usage, id=f'gym_{trajectory_id}') - - finally: - await self._close_env_async(env) - - async def _multi_turn_sampling_controller(self, infer_request: RolloutInferRequest, request_config: RequestConfig, - **kwargs) -> ChatCompletionResponse: - """Multi-turn scheduler-based sampling controller.""" - current_request = infer_request - current_turn = 1 - info_dict = {} - while True: - messages = current_request.messages - if current_turn == 1 or not messages[-1]['content']: - # If it's the first turn or the last message content is empty(dummy), remove the response - InferRequest.remove_response(messages) - - result: ChatCompletionResponse = await self.infer_async(current_request, request_config, **kwargs) - result_choice: RolloutResponseChoice = result.choices[0] - - completion = result_choice.message.content - if messages[-1]['role'] == 'assistant': - messages[-1]['content'] += completion - else: - messages.append({'role': 'assistant', 'content': completion}) - - if self.multi_turn_scheduler: 
- should_stop = self.multi_turn_scheduler.check_finished(current_request, result_choice, current_turn) - else: - should_stop = True - - if self.max_turns: - should_stop = should_stop or (current_turn >= self.max_turns) - - if should_stop: - result_choice.messages = messages - info_dict['num_turns'] = current_turn - for key, value in info_dict.items(): - if key in ['images', 'audios', 'videos']: - value = MultiModalRequestMixin.to_base64(value) - if hasattr(result_choice, key): - setattr(result_choice, key, value) - else: - result_choice.multi_turn_infos[key] = value - return result - - ret = self.multi_turn_scheduler.step(current_request, result_choice, current_turn) - if isinstance(ret, tuple): - current_request, info_dict = ret - else: - current_request = ret - info_dict = {} - assert isinstance(current_request, RolloutInferRequest) - if current_request.messages[-1]['role'] == 'assistant': - # Add a dummy response to allow engine to continue generating - current_request.messages.append({'role': 'assistant', 'content': None}) - - current_turn += 1 + return self._batch_infer_stream(tasks, request_config.stream, use_tqdm, metrics) async def _batch_infer_stream(self, tasks, @@ -339,17 +207,7 @@ async def _new_run(task): new_tasks = [_new_run(task) for task in tasks] return await self.batch_run(new_tasks) - async def _close_env_async(self, env: Env): - """Asynchronously close environment.""" - try: - if hasattr(env, 'close') and asyncio.iscoroutinefunction(env.close): - await env.close() - elif hasattr(env, 'close'): - env.close() - except Exception: - pass - - def _create_chat_completion_response(self, result, template: Template, request_config, + def _create_chat_completion_response(self, result: 'RequestOutput', template: Template, request_config, request_id) -> ChatCompletionResponse: assert result is not None num_generated_tokens = sum(len(output.token_ids) for output in result.outputs) @@ -360,13 +218,7 @@ def _create_chat_completion_response(self, result, template: Template, request_c response = template.decode(output.token_ids) logprobs = self._get_logprobs(output.logprobs, output.token_ids, request_config.top_logprobs) toolcall = self._get_toolcall(response, template) - - if self.use_gym_env: - choice_cls = GymRolloutResponseChoice - elif self.use_async_engine: - choice_cls = RolloutResponseChoice - else: - choice_cls = ChatCompletionResponseChoice + choice_cls = ChatCompletionResponseChoice token_ids = template.skip_stop_tokens(output.token_ids) if request_config.return_details else None choice = choice_cls( diff --git a/swift/llm/infer/infer_engine/vllm_engine.py b/swift/llm/infer/infer_engine/vllm_engine.py index a345953051..9e12267c94 100644 --- a/swift/llm/infer/infer_engine/vllm_engine.py +++ b/swift/llm/infer/infer_engine/vllm_engine.py @@ -512,8 +512,10 @@ async def _infer_full_async( generation_config: SamplingParams, adapter_request: Optional[AdapterRequest], request_config: RequestConfig, + request_id: Optional[str] = None, ) -> Union[ChatCompletionResponse, EmbeddingResponse]: - request_id = random_uuid() + if request_id is None: + request_id = random_uuid() result_generator = self._add_request(inputs, generation_config, request_id, adapter_request=adapter_request) result = None async for result in result_generator: @@ -641,6 +643,9 @@ async def infer_async( 'adapter_request': adapter_request, 'request_config': request_config, } + if hasattr(infer_request, 'uuid') and infer_request.uuid: + # RolloutInferRequest + kwargs.update({'request_id': infer_request.uuid}) if 
pre_infer_hook: kwargs = pre_infer_hook(kwargs) if request_config.stream: diff --git a/swift/llm/infer/protocol.py b/swift/llm/infer/protocol.py index 067e4f5e10..440c915e7d 100644 --- a/swift/llm/infer/protocol.py +++ b/swift/llm/infer/protocol.py @@ -10,7 +10,7 @@ import json from PIL import Image -from pydantic import BaseModel +from pydantic import BaseModel, Field, field_validator from ..template import InferRequest from ..utils import Messages, Tool @@ -299,21 +299,6 @@ class EmbeddingResponse: created: int = field(default_factory=lambda: int(time.time())) -@dataclass -class RolloutResponseChoice(ChatCompletionResponseChoice): - messages: Optional[Messages] = None - images: Optional[List[str]] = None - multi_turn_infos: Dict[str, Any] = field(default_factory=dict) - - -@dataclass -class GymRolloutResponseChoice(RolloutResponseChoice): - trajectory_id: str = None - total_reward: float = 0.0 - step_rewards: List[float] = None - trajectory_info: List[Dict[str, Any]] = None - - @dataclass class CompletionResponseChoice: index: int @@ -325,7 +310,7 @@ class CompletionResponseChoice: @dataclass class ChatCompletionResponse: model: str - choices: List[Union[ChatCompletionResponseChoice, RolloutResponseChoice, GymRolloutResponseChoice]] + choices: List[ChatCompletionResponseChoice] usage: UsageInfo id: str = field(default_factory=lambda: f'chatcmpl-{random_uuid()}') object: str = 'chat.completion' @@ -339,6 +324,61 @@ def to_cmpl_response(self) -> 'CompletionResponse': return CompletionResponse(self.model, choices, self.usage, id_, created=self.created) +class RolloutOutput(BaseModel): + """ + Output structure for rollout. + + Attributes: + response (ChatCompletionResponse): + The model's response + + messages (Optional[Messages]): + (Optional) Conversation history for the final rollout; required for multi-turn scenarios. + NOTE: + - If provided, this messages sequence will overwrite the original messages. + - If not provided, 'response' will be appended as the latest turn in the original messages. + - For multi-turn training, you need to manually return the updated messages, including the full history. + - The messages should include the latest assistant response as the final message. + + response_token_ids (Optional[List[List[int]]]): + (Optional) Token IDs generated at each rollout turn. + If provided, the training process will skip tokenizing the response. + + response_loss_mask (Optional[List[List[int]]]): + (Optional) Loss masks corresponding to each rollout turn. + If provided, the training process will skip computing loss masks for the response (as controlled by the `loss_scale` parameter). # noqa + + rollout_infos (Dict[str, Any]): + (Optional) Additional rollout information. This must be JSON-serializable. 
+ """ + response: ChatCompletionResponse + # multi turn + messages: Optional[Messages] = None + response_token_ids: List[List[int]] = Field(default_factory=list) + response_loss_mask: List[List[int]] = Field(default_factory=list) + rollout_infos: Dict[str, Any] = Field(default_factory=dict) + + @field_validator('response_token_ids', 'response_loss_mask', mode='before') + @classmethod + def _wrap_flat_list(cls, v): + if isinstance(v, list) and v and isinstance(v[0], int): + return [v] + return v + + def model_post_init(self, __context): + # Ensure multimodal data in rollout_infos is serializable (e.g., images to base64) + super().model_post_init(__context) + self.mminfo_to_serializable() + + def mminfo_to_serializable(self): + mm_keys = ['images', 'audios', 'videos'] + + for key, value in self.rollout_infos.items(): + if key in mm_keys: + # Convert multimodal content to base64 for serialization + self.rollout_infos[key] = MultiModalRequestMixin.to_base64(value) + + @dataclass class CompletionResponse: model: str diff --git a/swift/llm/infer/rollout.py b/swift/llm/infer/rollout.py index 0a02d0f038..ad871dd055 100644 --- a/swift/llm/infer/rollout.py +++ b/swift/llm/infer/rollout.py @@ -22,6 +22,7 @@ from swift.llm import RolloutArguments, SwiftPipeline from swift.llm.template.template_inputs import RolloutInferRequest +from swift.plugin.multi_turn import RolloutScheduler, multi_turns from swift.utils import get_logger from .infer_engine import GRPOVllmEngine, InferClient from .protocol import InitCommunicatorRequest, RequestConfig, UpdateWeightsRequest @@ -42,7 +43,6 @@ --vllm_tensor_parallel_size xxx \ --vllm_data_parallel_size xxx \ --vllm_use_async_engine true/false \ - --use_gym_env true/false \ --other_vllm_arguments Note: @@ -66,6 +66,20 @@ def llm_worker(args: RolloutArguments, data_parallel_rank: int, master_port: int os.environ['VLLM_DP_SIZE'] = str(args.vllm_data_parallel_size) os.environ['VLLM_DP_MASTER_PORT'] = str(master_port) engine = SwiftRolloutDeploy.get_infer_engine(args, template=args.get_template(None)) + if args.multi_turn_scheduler: + if args.multi_turn_scheduler not in multi_turns: + raise ValueError(f"Multi-turn scheduler '{args.multi_turn_scheduler}' not found in multi_turns.") + scheduler_cls = multi_turns[args.multi_turn_scheduler] + + kwargs = {} + if 'tokenizer' in list(inspect.signature(scheduler_cls.__init__).parameters): + kwargs['tokenizer'] = engine.default_template.tokenizer + + rollout_engine: RolloutScheduler = scheduler_cls(engine, args.max_turns, **kwargs) + if not rollout_engine: + raise ValueError(f"Failed to initialize multi-turn scheduler '{args.multi_turn_scheduler}'.") + else: + rollout_engine = engine # Send ready signal to parent process connection.send({'status': 'ready'}) @@ -81,7 +95,7 @@ def llm_worker(args: RolloutArguments, data_parallel_rank: int, master_port: int if command['type'] in ['call', 'fire_and_forget']: method_name = command['method'] args, kwargs = command.get('args', ()), command.get('kwargs', {}) - method = getattr(engine, method_name, None) or getattr(engine.engine, method_name, None) + method = getattr(rollout_engine, method_name, None) or getattr(rollout_engine.engine, method_name, None) result = method(*args, **kwargs) if command['type'] == 'call': connection.send(result) @@ -91,7 +105,25 @@ def llm_worker(args: RolloutArguments, data_parallel_rank: int, master_port: int async def async_llm_worker(args: RolloutArguments, data_parallel_rank: int, master_port: int, connection: Connection) -> None: - engine = 
SwiftRolloutDeploy.get_infer_engine(args) + # Set required environment variables for DP to work with vLLM + args._import_external_plugins() + engine = SwiftRolloutDeploy.get_infer_engine(args, template=args.get_template(None)) + + if args.multi_turn_scheduler: + if args.multi_turn_scheduler not in multi_turns: + raise ValueError(f"Multi-turn scheduler '{args.multi_turn_scheduler}' not found in multi_turns.") + scheduler_cls = multi_turns[args.multi_turn_scheduler] + + kwargs = {} + if 'tokenizer' in list(inspect.signature(scheduler_cls.__init__).parameters): + kwargs['tokenizer'] = engine.default_template.tokenizer + + rollout_engine: RolloutScheduler = scheduler_cls(engine, args.max_turns, **kwargs) + if not rollout_engine: + raise ValueError(f"Failed to initialize multi-turn scheduler '{args.multi_turn_scheduler}'.") + else: + rollout_engine = engine + # Send ready signal to parent process connection.send({'status': 'ready'}) @@ -108,7 +140,7 @@ async def async_llm_worker(args: RolloutArguments, data_parallel_rank: int, mast import traceback method_name = command['method'] args, kwargs = command.get('args', ()), command.get('kwargs', {}) - method = getattr(engine, method_name, None) or getattr(engine.engine, method_name, None) + method = getattr(rollout_engine, method_name, None) or getattr(rollout_engine.engine, method_name, None) try: result = await method(*args, **kwargs) except Exception: @@ -122,8 +154,6 @@ async def async_llm_worker(args: RolloutArguments, data_parallel_rank: int, mast def llm_worker_entry(*args, **kwargs): - rollout_args: RolloutArguments = args[0] - rollout_args._import_external_plugins() asyncio.run(async_llm_worker(*args, **kwargs)) @@ -193,7 +223,6 @@ def get_infer_engine(args: RolloutArguments, template=None, **kwargs): 'torch_dtype': args.torch_dtype, 'template': template, 'use_async_engine': args.vllm_use_async_engine, - 'multi_turn_scheduler': args.multi_turn_scheduler, 'max_turns': args.max_turns, 'use_gym_env': args.use_gym_env, 'gym_env': args.gym_env, @@ -335,7 +364,7 @@ async def infer( if request_config.seed: request_config.seed += i * len(requests) kwargs = {'infer_requests': requests, 'request_config': request_config, 'use_tqdm': use_tqdm} - method = 'async_infer' if self.use_async_engine else 'infer' + method = 'infer' if not self.use_async_engine else 'async_infer' connection.send({'type': 'call', 'method': method, 'kwargs': kwargs}) all_outputs = [connection.recv() for connection in self.connections] diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index d495506952..90bc0df300 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -1040,6 +1040,23 @@ def _get_system(self, inputs) -> Optional[str]: return system def _swift_prepare_inputs(self, inputs): + """ + Preprocesses the list of messages in the input by merging and formatting consecutive messages + according to their roles. + + Specifically, this method: + - Merges consecutive messages from the same role ('assistant' or 'user') to prevent downstream errors. + - Detects consecutive tool-related messages following an assistant message, then formats and + combines them using `agent_template._format_tool_responses` for structured output. + - Updates the messages list in-place for further processing. + + Args: + inputs: An object containing a 'messages' attribute, which is a list of dictionaries. + Each message dictionary should have at least the keys 'role' and 'content'. + + Returns: + None. The input messages list is updated in-place. 
+ """ messages = inputs.messages if len(messages) < 2: return diff --git a/swift/llm/template/template_inputs.py b/swift/llm/template/template_inputs.py index 88de356bea..fb4e603704 100644 --- a/swift/llm/template/template_inputs.py +++ b/swift/llm/template/template_inputs.py @@ -14,21 +14,44 @@ @dataclass class InferRequest: """ - messages: Input in messages format. - Examples: [{ - "role": "user", # or assistant/system/role - "content": [ # str or List[Dict[str, Any]] - { - "type": "image", # or audio/video - "image": "", - }, - {"type": "text", "text": "Please describe the picture."}, - ], - }] - The above content is equivalent to: - [{"role": "user", "content": "Please describe the picture."}] - and additionally passing in images: [""]. - tools: Organize tools into the format of agent_template for system. for example, 'react_en'. + Data structure for inference requests. + + Attributes: + messages (Messages): + The input conversation in messages format. Each message is a dict containing at least + a "role" field (e.g., "user", "assistant", "system") and a "content" field. + Example: + [{ + "role": "user", + "content": [ + { + "type": "image", # can also be audio/video + "image": "", + }, + {"type": "text", "text": "Please describe the picture."}, + ], + }] + The above is equivalent to: + [{"role": "user", "content": "Please describe the picture."}] + with an additional argument: + images = [""] + + images (List[Union[str, Image.Image]]): + Optional, a list of images associated with the request. + Each image can be a URL, local path, base64 string, or PIL.Image object. + + audios (List[str]): + Optional, a list of audio resources associated with the request. + + videos (List[str]): + Optional, a list of video resources associated with the request. + + tools (Optional[List[Tool]]): + An optional list of tools. These should be organized in the agent_template format for + tools requested by the system, for example 'react_en'. + + objects (Dict[str, List[Any]]): + Container for additional multimodal objects, grouped by type (key). """ messages: Messages @@ -75,12 +98,35 @@ def to_printable(self): @dataclass class RolloutInferRequest(InferRequest): """ - A request class that modifies the 'images' attribute - to be a list of strings for compatibility with POST requests. - The strings can represent image URLs or Base64 encoded images. + An inference request class for rollout scenarios. + + This class extends `InferRequest` and specifically overrides the `images` attribute + to be a list of strings for compatibility with POST requests. Each string may + represent an image URL or a Base64-encoded image. + + Inherits all fields from `InferRequest`: + messages (Messages): + Input conversation messages, supporting multimodal content. + audios (List[str]): + List of audio resources associated with the request. + videos (List[str]): + List of video resources associated with the request. + tools (Optional[List[Tool]]): + List of tools, organized by the agent template (e.g. 'react_en'). + objects (Dict[str, List[Any]]): + Optional container for additional multimodal objects. + + Additional / Overridden fields: + images (List[str]): + List of image resources, each as a string (URL or base64). + data_dict (Dict): + Optional dictionary for extra request data. + uuid (Optional[str]): + Optional unique identifier for this request instance. 
""" images: List[str] = field(default_factory=list) data_dict: Dict = field(default_factory=dict) + uuid: Optional[str] = None @dataclass diff --git a/swift/llm/utils.py b/swift/llm/utils.py index 18f21f8b97..faab9e333f 100644 --- a/swift/llm/utils.py +++ b/swift/llm/utils.py @@ -28,7 +28,7 @@ Tool = Dict[str, Union[str, Dict]] History = List[Union[Tuple[str, str], List[str]]] -Message = Dict[str, Union[str, List[Dict[str, Any]]]] +Message = Dict[str, Union[str, List[Dict[str, Any]], List[int]]] Messages = List[Message] diff --git a/swift/plugin/__init__.py b/swift/plugin/__init__.py index 7503784b10..561ef3b036 100644 --- a/swift/plugin/__init__.py +++ b/swift/plugin/__init__.py @@ -15,8 +15,8 @@ from .orm import orms, ORM from .multi_turn import multi_turns from .rm_plugin import rm_plugins - from .env import envs - from .context_manager import context_managers + from .env import envs, Env + from .context_manager import context_managers, ContextManager else: _import_structure = { @@ -31,8 +31,8 @@ 'orm': ['orms', 'ORM'], 'multi_turn': ['multi_turns'], 'rm_plugin': ['rm_plugins'], - 'env': ['env'], - 'context_manager': ['context_managers'] + 'env': ['envs', 'Env'], + 'context_manager': ['context_managers', 'ContextManager'], } import sys diff --git a/swift/plugin/multi_turn.py b/swift/plugin/multi_turn.py index faa1009780..6faecb10f5 100644 --- a/swift/plugin/multi_turn.py +++ b/swift/plugin/multi_turn.py @@ -1,39 +1,451 @@ -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union +import asyncio +from abc import ABC +from copy import deepcopy +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from swift.plugin import ContextManager, Env, context_managers, envs +from swift.utils import remove_response if TYPE_CHECKING: - from swift.llm.infer.protocol import RolloutResponseChoice + from swift.llm.infer.protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, RequestConfig, + RolloutOutput) from swift.llm.template import RolloutInferRequest + from swift.llm.infer.infer_engine import GRPOVllmEngine + from swift.llm.utils import Messages -class MultiTurnScheduler(ABC): - - def __init__(self, max_turns: Optional[int] = None, *args, **kwargs): +class RolloutScheduler(ABC): + # Single Turn Rollout Scheduler + def __init__(self, + infer_engine: Optional['GRPOVllmEngine'] = None, + max_turns: Optional[int] = None, + *args, + **kwargs): + self.infer_engine = infer_engine self.max_turns = max_turns - @abstractmethod - def step(self, infer_request: 'RolloutInferRequest', result: 'RolloutResponseChoice', - current_turn: int) -> Union['RolloutInferRequest', Tuple['RolloutInferRequest', Dict]]: - pass + async def async_infer(self, + infer_requests: List[Union['RolloutInferRequest', Dict[str, Any]]], + request_config: 'RequestConfig', + *, + use_tqdm: Optional[bool] = None, + **kwargs) -> List['ChatCompletionResponse']: + assert request_config.n == 1 + + async def _infer_async_single(infer_request: Union['RolloutInferRequest', Dict[str, Any]], + request_config: 'RequestConfig', **kwargs): + from swift.llm.template import RolloutInferRequest + if isinstance(infer_request, Dict): + infer_request = RolloutInferRequest(**infer_request) + + return await self.run(infer_request, request_config, **kwargs) + + tasks = [_infer_async_single(infer_request, request_config, **kwargs) for infer_request in infer_requests] + if use_tqdm is None: + use_tqdm = len(infer_requests) > 1 + # Execute all tasks and flatten the results + results = await 
self.infer_engine._batch_infer_stream(tasks, request_config.stream, use_tqdm, None) + # Flatten the results since each task may return a list + flattened_results = [] + for result in results: + if isinstance(result, list): + flattened_results.extend(result) + else: + flattened_results.append(result) + return flattened_results + + async def run(self, infer_request: 'RolloutInferRequest', request_config: 'RequestConfig', + **kwargs) -> 'RolloutOutput': + response: 'ChatCompletionResponse' = await self.infer_engine.infer_async(infer_request, request_config, + **kwargs) + response_token_ids = response.choices[0].token_ids + response_loss_mask = [1] * len(response_token_ids) + return RolloutOutput( + response=response, + messages=infer_request.messages, + response_token_ids=[response_token_ids], + response_loss_mask=[response_loss_mask], + rollout_infos={'num_turns': 1}) + + def __getattr__(self, key: str): + try: + return object.__getattribute__(self, key) + except AttributeError: + pass + + try: + infer_engine = object.__getattribute__(self, 'infer_engine') + if hasattr(infer_engine, key): + return getattr(infer_engine, key) + if hasattr(infer_engine.engine, key): + return getattr(infer_engine.engine, key) + + except AttributeError: + raise AttributeError(f'{type(self).__name__} object has no attribute {key}') + + @property + def engine(self): + return self.infer_engine + + +class MultiTurnScheduler(RolloutScheduler, ABC): + """ + Abstract base class for multi-turn rollout scheduling. + + Provides default implementation for multi-turn conversation management with two customization approaches: + + 1. FULL CUSTOMIZATION: + Override the `run()` method to implement completely custom multi-turn logic. + - Gives full control over the rollout process + - Must handle all turn management and termination logic + + 2. PARTIAL CUSTOMIZATION: + Implement the required `step()` method and optionally override `check_finished()` + - Uses MultiTurnScheduler's run() method infrastructure + - Only need to implement turn transition logic in step() + - Optionally customize termination conditions + + Note: You must implement at least one of these approaches in your subclass. + + Options: + - If each round's response_token_ids are included in the RolloutOutput, + the Trainer can skip encoding the completion text into token_ids when calculating loss. + This avoids potential training inconsistencies due to asymmetric encode/decode behavior. + See: https://github.com/0russwest0/Agent-R1/issues/30#issuecomment-2826155367 + + - If both response_token_ids and response_loss_mask are returned in the RolloutOutput, + you can manually control the loss mask for each token. + The Trainer will use the provided loss_mask values directly when computing the loss. + Note: Returning response_loss_mask requires that response_token_ids are also returned, + as the two must be aligned in length for correct loss computation. + + You can refer to MathTipsScheduler as an example of how to use response_token_ids and response_loss_mask. + + Loss mask configuration: + During rollout, some parts of the completion (e.g., environment observations embedded in completion) + may need to be masked out from loss computation. + There are two supported strategies: + + 1. Use the built-in `loss_scale` parameter in ms-swift and do not return response token ids. + 2. Return response_token_ids along with a corresponding response_loss_mask (of equal length) to indicate the loss mask for each token. 
# noqa
+    """
+
+    async def run(self, infer_request: 'RolloutInferRequest', request_config: 'RequestConfig',
+                  **kwargs) -> Union['RolloutOutput', List['RolloutOutput']]:
+        """Execute multi-turn conversation rollout with built-in turn management logic.
+
+        This implements the default multi-turn interaction flow that can be overridden
+        to customize conversation handling behavior. The default logic provides:
+
+        1. Automatic conversation turn management and stopping conditions
+        2. Seamless message accumulation across multiple turns
+        3. Response token tracking and loss mask management
+        4. Configurable early stopping mechanisms
+
+        Args:
+            infer_request: The initial inference request containing conversation messages
+            request_config: Configuration parameters for the inference request
+            **kwargs: Additional inference parameters passed to the engine
+
+        Returns:
+            RolloutOutput containing the complete conversation history and metadata,
+            or a list of outputs for batched requests
+
+        Customization Approaches:
+        - Override check_finished() to implement custom stopping criteria
+        - Override step() to customize turn-to-turn transition logic
+        - Override this entire run() method for completely custom multi-turn behavior
+
+        Important Notes:
+        - Method overriding is only supported when using server mode (swift rollout)
+          with vllm_use_async_engine=True
+        - Custom implementations must maintain async/await compatibility
+        - Ensure proper handling of conversation state across turns
+
+        Example:
+            class CustomScheduler(MultiTurnScheduler):
+                async def run(self, infer_request, request_config, **kwargs):
+                    # Implement custom multi-turn conversation logic
+                    # Must return RolloutOutput or List[RolloutOutput]
+                    ...
+        """
+
+        current_request = infer_request
+        current_turn = 1
+        rollout_infos = {}
+        total_response_ids = []
+        total_response_loss_mask = []
+        while True:
+            messages = current_request.messages
+            if current_turn == 1 or not messages[-1]['content']:
+                # If it's the first turn or the last message content is empty (dummy), remove the response
+                remove_response(messages)
+
+            # Get model response
+            response: 'ChatCompletionResponse' = await self.infer_engine.infer_async(
+                current_request, request_config, **kwargs)
+            response_choice: 'ChatCompletionResponseChoice' = response.choices[0]
+
+            # Update conversation history
+            completion = response_choice.message.content
+            if messages[-1]['role'] == 'assistant':
+                messages[-1]['content'] += completion
+            else:
+                messages.append({'role': 'assistant', 'content': completion})
 
-    def check_finished(self, infer_request: 'RolloutInferRequest', result: 'RolloutResponseChoice',
+            # Check stopping conditions
+            should_stop = self.check_finished(current_request, response_choice, current_turn)
+
+            if self.max_turns:
+                should_stop = should_stop or (current_turn >= self.max_turns)
+
+            if should_stop:
+                return RolloutOutput(
+                    response=response,
+                    messages=messages,
+                    response_token_ids=total_response_ids,
+                    response_loss_mask=total_response_loss_mask,
+                    rollout_infos=rollout_infos,
+                )
+
+            # Prepare next turn
+            ret = self.step(current_request, response_choice, current_turn)
+            current_request: 'RolloutInferRequest' = ret['infer_request']
+
+            # Track response tokens and masks
+            return_token_id = False
+            if 'response_token_ids' in ret:
+                total_response_ids.append(ret['response_token_ids'])
+                return_token_id = True
+
+            if 'response_loss_mask' in ret:
+                assert return_token_id, 'You must return response_token_ids if you want to return response_loss_mask'
+                assert len(ret['response_loss_mask']) == len(ret['response_token_ids']), \
+                    'response_loss_mask must have the same length as response_token_ids'
+                total_response_loss_mask.append(ret['response_loss_mask'])
+
+            if 'rollout_infos' in ret:
+                rollout_infos = {**rollout_infos, **ret['rollout_infos']}
+
+            if current_request.messages[-1]['role'] == 'assistant':
+                # Add a dummy response to allow engine to continue generating
+                current_request.messages.append({'role': 'assistant', 'content': None})
+
+            current_turn += 1
+
+    def step(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompletionResponseChoice',
+             current_turn: int) -> Dict:
+        """
+        Handles the transition between conversation turns.
+
+        Args:
+            infer_request: Current inference request
+            response_choice: Response from the current turn
+            current_turn: Current turn number
+
+        Returns:
+            Dict[str, Any]: A dictionary containing inference results with the following structure:
+                - infer_request (required): Main inference request object
+                - response_token_ids (Optional[List[List[int]]]): Token IDs of the responses for each rollout turn
+                - response_loss_mask (Optional[List[List[int]]]): Loss masks for the responses in each rollout turn # noqa
+                - rollout_infos (Optional[Dict[str, Any]]): Additional metadata (must be serializable)
+
+        """
+        raise NotImplementedError(
+            'Please implement the `step` method in your MultiTurnScheduler subclass, or override the `run` method.')
+
+    def check_finished(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompletionResponseChoice',
                        current_turn: int) -> bool:
-        if result.finish_reason == 'length':
+        """
+        Default termination logic for checking whether a multi-turn rollout should end.
+
+        This method is invoked by:
+        - The base class MultiTurnScheduler.run() method, OR
+        - Custom run() methods when explicitly called
+
+        Note: This is the default implementation and can be overridden by subclasses for custom termination logic.
+
+        Termination Conditions:
+        1. When the response hits the length limit (finish_reason == 'length')
+        2. When the conversation reaches max_turns (if max_turns is set)
+
+        Args:
+            infer_request: The inference request object
+            response_choice: Contains generation results, including finish_reason
+            current_turn: Current conversation turn count
+
+        Returns:
+            bool: True to terminate the conversation, False to continue
+        """
+        if response_choice.finish_reason == 'length':
             return True
         if self.max_turns and current_turn >= self.max_turns:
             return True
         return False
 
+
+class ThinkingModelTipsScheduler(MultiTurnScheduler):
+    """
+    Scheduler for multi-turn reasoning with Thinking-style models.
+
+    Key Features:
+    1. Parses both "think" and "answer" content from each assistant response.
+    2. For each round, only the "think" content from the last round is retained in the message history.
+    3. Each round's conversation history is processed independently.
+    4. Returns a list of RolloutOutput objects, one for each round.
+    5. Please set `--loss_scale last_round` to train only on the last-round response.
+
+    The scheduler will automatically inject a tip prompt if the answer is incorrect, encouraging the model to recheck its reasoning. # noqa
+    """
+    from .orm import MathAccuracy
+    tips_prompt = 'The answer is not correct, It seems You made a mistake, you need to recheck very carefully.'
+    acc_func = MathAccuracy()
+
+    async def run(self, infer_request: 'RolloutInferRequest', request_config: 'RequestConfig',
+                  **kwargs) -> List['RolloutOutput']:
+        """
+        Execute multi-turn inference for Thinking models.
+ + Args: + infer_request (RolloutInferRequest): The initial inference request containing the conversation history. + request_config (RequestConfig): Configuration for the inference request. + **kwargs: Additional arguments for the inference engine. + + Returns: + List[RolloutOutput]: A list of RolloutOutput objects, one for each reasoning round. + """ + from swift.llm.infer.protocol import RolloutOutput + + current_request = infer_request + current_turn = 1 + rollout_outputs = [] + + while True: + messages = current_request.messages + # Obtain model response for the current turn + response: 'ChatCompletionResponse' = await self.infer_engine.infer_async( + current_request, request_config, **kwargs) + response_choice: 'ChatCompletionResponseChoice' = response.choices[0] + completion = response_choice.message.content + + # Append the assistant's response to the message history + messages.append({'role': 'assistant', 'content': completion}) + + # Construct the message history for this round, keeping only the last "think" content + messages_with_last_think = self._build_messages(messages) + + # Create a RolloutOutput for the current round + round_output = RolloutOutput( + response=response, + messages=messages_with_last_think, + response_token_ids=response_choice.token_ids, + rollout_infos={'num_turns': current_turn}) + # Store the output for this round + rollout_outputs.append(round_output) + + # Determine whether to stop the multi-turn reasoning + should_stop = self.check_finished(current_request, response_choice, current_turn) + + if should_stop: + break + + # Prepare for the next turn by updating the inference request + ret = self.step(current_request, response_choice, current_turn) + current_request: 'RolloutInferRequest' = ret['infer_request'] + current_turn += 1 + + return rollout_outputs + + def check_finished(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompletionResponseChoice', + current_turn: int) -> bool: + + last_query = infer_request.messages[-2]['content'] + # tips once + if self.tips_prompt in last_query: + return True + + completion = response_choice.message.content + solution = infer_request.data_dict['solution'] + acc = self.acc_func([completion], [solution])[0] + if acc == 1: + return True + + return super().check_finished(infer_request, response_choice, current_turn) + + def step(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompletionResponseChoice', + current_turn: int) -> Dict: + infer_request.messages.append({'role': 'user', 'content': self.tips_prompt}) + + return {'infer_request': infer_request} + + def _is_thinking_template(self) -> bool: + if not hasattr(self.infer_engine, 'default_template'): + return False + + template = self.infer_engine.default_template + from swift.llm.template.template.utils import ThinkingTemplate + + return isinstance(template, ThinkingTemplate) + + def _build_messages(self, original_messages: 'Messages') -> 'Messages': + """ + Build history for a specific round, keeping only the think content from the last round. 
+
+        Args:
+            original_messages: Original conversation messages
+
+        Returns:
+            Messages: History for this specific round
+        """
+        from copy import deepcopy
+
+        # If this is a thinking template, use the template's method to prepare messages
+        if self._is_thinking_template():
+            # Create a mock inputs object to use the template's _swift_prepare_inputs method
+            class MockInputs:
+
+                def __init__(self, messages):
+                    self.messages = deepcopy(messages)
+
+            mock_inputs = MockInputs(original_messages)
+
+            # Set up the template for inference mode
+            template = self.infer_engine.default_template
+            # _swift_prepare_inputs will remove historical thinking content when in train mode, patch the mode here
+            original_mode = template.mode
+            template.mode = 'train'
+            # Use the template's method to prepare messages
+            template._swift_prepare_inputs(mock_inputs)
+            # Restore original mode
+            template.mode = original_mode
+
+            return mock_inputs.messages
+        else:
+            # Fallback to manual processing for non-thinking templates
+            round_messages = []
+
+            # Process messages in original order
+            for i, msg in enumerate(original_messages):
+                if msg['role'] == 'assistant' and isinstance(msg['content'], str) and i != len(original_messages) - 1:
+                    # For earlier assistant messages, keep only the content after the closing think tag
+                    assistant_no_think = msg['content'].split('</think>')[-1].strip()
+                    round_messages.append({'role': 'assistant', 'content': assistant_no_think})
+                else:
+                    round_messages.append(deepcopy(msg))
+
+            return round_messages
+
+
 class MathTipsScheduler(MultiTurnScheduler):
     tips_prompt = 'But wait... It seems I made a mistake,'
 
-    def __init__(self, max_turns=None, *args, **kwargs):
+    def __init__(self, tokenizer, *args, **kwargs):
         from .orm import MathAccuracy
-        super().__init__(max_turns, *args, **kwargs)
+        self.tokenizer = tokenizer
+        super().__init__(*args, **kwargs)
         self.acc_func = kwargs.get('acc_function', MathAccuracy())
 
-    def check_finished(self, infer_request: 'RolloutInferRequest', result: 'RolloutResponseChoice',
+    def check_finished(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompletionResponseChoice',
                        current_turn: int) -> bool:
         last_completion = infer_request.messages[-1]['content']
         # we only give tips once
@@ -45,11 +457,11 @@ def check_finished(self, infer_request: 'RolloutInferRequest', result: 'RolloutR
         if acc == 1:
             return True
 
-        return super().check_finished(infer_request, result, current_turn)
+        return super().check_finished(infer_request, response_choice, current_turn)
 
-    def step(self, infer_request: 'RolloutInferRequest', result: 'RolloutResponseChoice',
-             current_turn: int) -> Union['RolloutInferRequest', Tuple['RolloutInferRequest', dict]]:
-        completion = result.message.content
+    def step(self, infer_request: 'RolloutInferRequest', response_choice: 'ChatCompletionResponseChoice',
+             current_turn: int) -> Dict:
+        completion = response_choice.message.content
         if '<answer>' in completion:
             completion = completion[:completion.index('<answer>')]
         if '</think>' in completion:
@@ -63,41 +475,140 @@ def step(self, infer_request: 'RolloutInferRequest', result: 'RolloutResponseCho
         else:
             infer_request.messages.append({'role': 'assistant', 'content': completion})
 
-        return infer_request
+        return {'infer_request': infer_request}
 
 
-class MathTipsMultiTurnScheduler(MultiTurnScheduler):
-    from .orm import MathAccuracy
-    tips_prompt = 'The answer is not correct, It seems You made a mistake, you need to recheck very carefully.'
- acc_func = MathAccuracy() +class GYMScheduler(RolloutScheduler): - def check_finished(self, infer_request: 'RolloutInferRequest', result: 'RolloutResponseChoice', - current_turn: int) -> bool: + def __init__(self, + infer_engine: 'GRPOVllmEngine', + gym_env: Optional[str] = None, + context_manager_name: Optional[str] = None, + max_turns: Optional[int] = None, + **kwargs): + super().__init__(infer_engine, max_turns, **kwargs) + self.gym_env_name = gym_env + self.context_manager_name = context_manager_name - last_query = infer_request.messages[-2]['content'] - # we only give tips once - if self.tips_prompt in last_query: - return True + async def _create_env(self, env_config: Dict) -> Env: + """Create environment instance from configuration.""" + env_name = env_config.get('name', self.gym_env_name) + if env_name not in envs: + raise ValueError(f"Environment '{env_name}' not found. Available: {list(envs.keys())}") + return envs[env_name](env_config) - completion = result.message.content - solution = infer_request.data_dict['solution'] - acc = self.acc_func([completion], [solution])[0] - if acc == 1: - return True + async def _create_context_manager(self, ctx_config: Dict) -> ContextManager: + """Create context manager from configuration.""" + ctx_name = ctx_config.get('name', self.context_manager_name) - return super().check_finished(infer_request, result, current_turn) + if not ctx_name: + ctx_name = 'dummyContextManager' - def step( - self, - infer_request: 'RolloutInferRequest', - result: 'RolloutResponseChoice', - current_turn: int, - ) -> Union['RolloutInferRequest', Tuple['RolloutInferRequest', dict]]: - infer_request.messages.append({'role': 'user', 'content': self.tips_prompt}) - return infer_request + return context_managers[ctx_name](ctx_config) + + async def _close_env_async(self, env: Env): + """Safely close environment with async support.""" + try: + if hasattr(env, 'close') and asyncio.iscoroutinefunction(env.close): + await env.close() + elif hasattr(env, 'close'): + env.close() + except Exception: + # Handle any exceptions during environment closure + pass + + async def run(self, infer_request: 'RolloutInferRequest', request_config: 'RequestConfig', + **kwargs) -> 'ChatCompletionResponse': + from swift.llm.infer.protocol import ChatCompletionResponse, ChatCompletionResponseChoice + """ + Execute the gym environment-based rollout: + 1. Initialize environment and context manager + 2. Run multi-turn interactions between LLM and environment + 3. 
Collect trajectory information and rewards + """ + # Extract configurations from request + env_config = infer_request.data_dict.get('env_config', {}) + ctx_config = infer_request.data_dict.get('ctx_config', {}) + + # Create environment and context manager + env = await self._create_env(env_config) + context_manager = await self._create_context_manager(ctx_config) + + try: + # Initialize environment + observation, info, system_message = await env.reset(infer_request) + + # Build initial messages + messages: 'Messages' = [] + if system_message: + messages.append({'role': 'system', 'content': system_message}) + messages.append({'role': 'user', 'content': observation}) + + current_request = deepcopy(infer_request) + current_turn = 1 + done = False + total_reward = 0.0 + step_rewards = [] + trajectory_id = f'{id(infer_request)}_{hash(str(infer_request))}' + trajectory_info = [info] + + while not done and current_turn <= (self.max_turns or float('inf')): + # Apply context management (e.g., history compression) + messages = context_manager.manage_context(messages, trajectory_id) + current_request.messages = messages + remove_response(current_request.messages) + + response: 'ChatCompletionResponse' = await self.infer_engine.infer_async( + current_request, request_config, **kwargs) + response_choice: 'ChatCompletionResponseChoice' = response.choices[0] + completion = response_choice.message.content + messages.append({'role': 'assistant', 'content': completion}) + + # Execute environment step + next_obs, reward, done, step_info = await env.step(deepcopy(messages)) + + # Update trajectory information + total_reward += reward + step_rewards.append(reward) + trajectory_info.append(step_info) + + # Prepare for next turn + if not done: + messages.append({'role': 'user', 'content': next_obs}) + current_request.messages = messages + current_turn += 1 + + final_choice = ChatCompletionResponseChoice( + index=response_choice.index, + message=response_choice.message, + finish_reason=response_choice.finish_reason, + logprobs=response_choice.logprobs) + + last_response = ChatCompletionResponse( + model=self.infer_engine.model_name, + choices=[final_choice], + usage=response.usage, + id=f'gym_{trajectory_id}') + + return RolloutOutput( + response=last_response, + messages=messages, + rollout_infos={ + 'num_turns': current_turn, + 'trajectory_id': trajectory_id, + 'total_reward': total_reward, + 'step_rewards': step_rewards, + 'trajectory_info': trajectory_info + }) + + finally: + # Ensure environment is properly closed + await self._close_env_async(env) multi_turns = { + 'base_scheduler': RolloutScheduler, 'math_tip_trick': MathTipsScheduler, - 'math_tip_trick_multi_turn': MathTipsMultiTurnScheduler, + 'gym_scheduler': GYMScheduler, + 'thinking_tips_scheduler': ThinkingModelTipsScheduler, } diff --git a/swift/plugin/orm.py b/swift/plugin/orm.py index 53f53745ec..d8f2b30042 100644 --- a/swift/plugin/orm.py +++ b/swift/plugin/orm.py @@ -301,14 +301,12 @@ def __call__(self, completions, **kwargs) -> List[float]: class CosineReward(ORM): # https://arxiv.org/abs/2502.03373 def __init__(self, - tokenizer=None, cosine_min_len_value_wrong: float = -0.5, cosine_max_len_value_wrong: float = 0.0, cosine_min_len_value_correct: float = 1.0, cosine_max_len_value_correct: float = 0.5, cosine_max_len: int = 1000, accuracy_orm=None): - self.tokenizer = tokenizer self.min_len_value_wrong = cosine_min_len_value_wrong self.max_len_value_wrong = cosine_max_len_value_wrong self.min_len_value_correct = cosine_min_len_value_correct @@ 
-323,8 +321,9 @@ def cosfn(t, T, min_value, max_value): def __call__(self, completions, solution, **kwargs) -> List[float]: acc_rewards = self.accuracy_orm(completions, solution, **kwargs) + response_token_ids = kwargs.get('response_token_ids') rewards = [] - for content, acc_reward in zip(completions, acc_rewards): + for ids, acc_reward in zip(response_token_ids, acc_rewards): is_correct = acc_reward >= 1. if is_correct: # Swap min/max for correct answers @@ -333,7 +332,7 @@ def __call__(self, completions, solution, **kwargs) -> List[float]: else: min_value = self.max_len_value_wrong max_value = self.min_len_value_wrong - gen_len = len(self.tokenizer.encode(content)) + gen_len = len(ids) reward = self.cosfn(gen_len, self.max_len, min_value, max_value) rewards.append(reward) return rewards @@ -380,16 +379,16 @@ def __call__(self, completions, **kwargs) -> List[float]: class SoftOverlong(ORM): - def __init__(self, tokenizer, soft_max_length, soft_cache_length): - self.tokenizer = tokenizer + def __init__(self, soft_max_length, soft_cache_length): assert soft_cache_length < soft_max_length self.soft_max_length = soft_max_length self.soft_cache_length = soft_cache_length def __call__(self, completions, **kwargs) -> List[float]: rewards = [] - for completion in completions: - completion_length = len(self.tokenizer.encode(completion)) + response_token_ids = kwargs.get('response_token_ids') + for ids in response_token_ids: + completion_length = len(ids) expected_len = self.soft_max_length - self.soft_cache_length exceed_len = completion_length - expected_len rewards.append(min(-exceed_len / self.soft_cache_length, 0)) diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py index b330ff69a3..0f0c55f2cb 100644 --- a/swift/trainers/arguments.py +++ b/swift/trainers/arguments.py @@ -293,6 +293,7 @@ class GRPOArgumentsMixin(VllmArguments): multi_turn_scheduler: Optional[str] = None max_turns: Optional[int] = None completion_length_limit_scope: Literal['total', 'per_round'] = 'per_round' + vllm_server_pass_dataset: bool = False # DAPO, https://arxiv.org/abs/2503.14476 dynamic_sample: bool = False diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py index 8ca8cce5a0..a04c07c74d 100644 --- a/swift/trainers/rlhf_trainer/grpo_trainer.py +++ b/swift/trainers/rlhf_trainer/grpo_trainer.py @@ -1,10 +1,12 @@ # Copyright (c) Alibaba, Inc. and its affiliates. # Part of the implementation is borrowed from huggingface/trl. 
+import base64 import concurrent.futures import inspect import os import re import time +import uuid from collections import defaultdict, deque from concurrent.futures import Future from contextlib import contextmanager, nullcontext @@ -12,13 +14,15 @@ from dataclasses import asdict, dataclass, field from math import ceil from queue import Queue +from threading import local from types import MethodType -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union import torch import torch.nn as nn import transformers from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed +from dacite import from_dict from packaging import version from torch.nn import ModuleList from torch.utils.data import DataLoader @@ -28,23 +32,24 @@ from trl.models import prepare_deepspeed from trl.trainer import grpo_trainer from trl.trainer.callbacks import SyncRefModelCallback -from trl.trainer.grpo_trainer import RepeatSampler, nanmax, nanmin, nanstd +from trl.trainer.grpo_trainer import nanmax, nanmin, nanstd from trl.trainer.utils import selective_log_softmax from swift.llm import (InferRequest, MultiModelKeys, RequestConfig, RolloutInferRequest, RowPreprocessor, Template, to_device) -from swift.llm.infer.protocol import ChatCompletionResponse +from swift.llm.infer.protocol import ChatCompletionResponse, RolloutOutput from swift.llm.model.utils import get_llm_model from swift.llm.template.base import MaxLengthError from swift.llm.template.template_inputs import StdTemplateInputs from swift.plugin import multi_turns, orms, rm_plugins from swift.plugin.multi_turn import MultiTurnScheduler from swift.utils import (JsonlWriter, empty_cache, get_current_device, get_logger, is_swanlab_available, - is_vllm_available, is_wandb_available, seed_worker, unwrap_model_for_generation) + is_vllm_available, is_wandb_available, remove_response, seed_worker, + unwrap_model_for_generation) from ..mixin import SwiftMixin from .rlhf_mixin import RLHFTrainerMixin from .utils import (_ForwardRedirection, load_pil_img, patch_lora_merge, patch_lora_unmerge, patch_profiling_context, - patch_profiling_decorator) + patch_profiling_decorator, patch_save_last_checkpoint, replace_assistant_response_with_ids) from .vllm_client import VLLMClient try: @@ -62,20 +67,11 @@ if is_swanlab_available(): import swanlab -InputsType = List[Dict[str, Union[torch.Tensor, Any]]] -# tuple: (messages, finish_reason) -OutputsType = List[Tuple[List[Dict], str]] -if not hasattr(RepeatSampler, 'old_len_func'): - origin_len_func = RepeatSampler.__len__ +DataType = List[Dict[str, Union[torch.Tensor, Any]]] +T = TypeVar('T') - def patched_len(self) -> int: - return (self.num_samples // self.batch_size) * self.batch_size * self.mini_repeat_count * self.repeat_count - RepeatSampler.__len__ = patched_len - RepeatSampler.old_len_func = origin_len_func - - -class GRPOCallback(TrainerCallback): +class AsyncGenerateCallback(TrainerCallback): def __init__(self, trainer): self.trainer = trainer @@ -89,8 +85,7 @@ def on_train_begin(self, args, state, control, **kwargs): @dataclass class DataCache: - inputs: List[Dict] = field(default_factory=list) - outputs: List[Dict] = field(default_factory=list) + results: DataType def identity_data_collator(features): @@ -107,6 +102,7 @@ def __init__(self, reward_funcs: Optional[List[Union[str, Callable]]] = None, *_args, **kwargs): + patch_save_last_checkpoint() from swift.trainers.rlhf_arguments import 
GRPOConfig args: GRPOConfig = kwargs['args'] self.args = args @@ -172,7 +168,7 @@ def __init__(self, multi_turn_scheduler = multi_turns[self.args.multi_turn_scheduler](max_turns=self.args.max_turns) self.multi_turn_scheduler: MultiTurnScheduler = multi_turn_scheduler else: - assert isinstance(multi_turn_scheduler, MultiTurnScheduler) + assert isinstance(self.args.multi_turn_scheduler, MultiTurnScheduler) self.multi_turn_scheduler: MultiTurnScheduler = self.args.multi_turn_scheduler self.num_generations = args.num_generations @@ -317,7 +313,7 @@ def __init__(self, top_k=args.top_k, repetition_penalty=args.repetition_penalty, stop=args.stop_words, - ) + return_details=True) # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set @@ -344,7 +340,7 @@ def __init__(self, self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator)) if self.async_generate: - self.add_callback(GRPOCallback(self)) + self.add_callback(AsyncGenerateCallback(self)) if self.args.dynamic_sample or self.template.truncation_strategy == 'raise': self.resample_dataset = deepcopy(self.train_dataset) @@ -389,6 +385,8 @@ def single_sample_context(): self.truncated_resample_iterator = cyclic_iter(self.get_train_dataloader()) # flag indicating whether the evaluation has started self.eval_flag = False + # Record the number of samples that need to be padded for even distribution across processes + self.rollout_pad_count = 0 @patch_profiling_decorator def _prepare_inputs(self, generation_batch: dict[str, Union[torch.Tensor, @@ -610,73 +608,16 @@ def _wait_queue(self): while self._queue.empty(): time.sleep(0.01) - def _infer(self, - inputs: Optional[InputsType], - request_config: RequestConfig, - is_global_inputs: bool = False) -> List[ChatCompletionResponse]: + def _rollout(self, + inputs: Optional[DataType], + request_config: RequestConfig, + is_global_inputs: bool = False) -> List[RolloutOutput]: request_config = self._get_request_config() - # keys from InferRequest - per_device_size = len(inputs) - if is_global_inputs: - per_device_size //= self.accelerator.num_processes if self.vllm_mode == 'server': - # for server mode, we gather all the inputs and send to remote vllm server in main process - if is_global_inputs: - # async generate, pre-gather to avoid potential communicate operator - all_inputs = inputs - all_input_lengths = [per_device_size] + [0] * (self.accelerator.num_processes - 1) - else: - all_inputs = gather_object(inputs) - all_input_lengths = gather_object([len(inputs)]) - - if not any(inputs for inputs in all_inputs): - return [] - - if self.accelerator.is_main_process: - results: List[ChatCompletionResponse] = self._engine_infer( - infer_requests=all_inputs, request_config=request_config) - else: - results = [None] * len(all_inputs) - # Broadcast the results from the main process to all processes, - # ensuring each process receives its corresponding slice. 
- if not is_global_inputs: - results = broadcast_object_list(results, from_process=0) - start_idx = sum(all_input_lengths[:self.accelerator.process_index]) - end_idx = start_idx + all_input_lengths[self.accelerator.process_index] - results = results[start_idx:end_idx] - else: - results = results if self.accelerator.is_main_process else [] + rollout_outputs = self._server_rollout(inputs, request_config, is_global_inputs) else: - # pt / vllm colocate - if self.vllm_tensor_parallel_size > 1: - # Gather prompts from all ranks in the TP group and flatten. - # Each rank starts with its own prompts; after gathering, all ranks see the full group set. - # Note: The input sizes may differ across ranks (e.g., in multi-turn scenarios, - # the amount of data each rank continues to process may vary). - local_rank_in_group = torch.distributed.get_rank(group=self.tp_group) - local_input_length = len(inputs) - all_input_lengths = [None] * self.vllm_tensor_parallel_size - torch.distributed.all_gather_object(all_input_lengths, local_input_length, group=self.tp_group) - start_idx = sum(all_input_lengths[:local_rank_in_group]) - end_idx = start_idx + all_input_lengths[local_rank_in_group] - - # orig_size = len(inputs)/ - gathered_inputs = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_inputs, inputs, group=self.tp_group) - inputs = [p for sublist in gathered_inputs for p in sublist] - # Set request_config.seed - # 1. Ensure that the seed for vLLM Engines within each TP (Tensor Parallelism) group is the same; - # otherwise, the program may hang. - # 2. Ensure that the seed for vLLM Engines across different TP groups is different; - # otherwise, identical completions will be generated. - results: List[ChatCompletionResponse] = self._engine_infer( - infer_requests=inputs, request_config=request_config) - - if self.vllm_tensor_parallel_size > 1: - # Slice completions for this rank within its TP group. - # Each rank generates all outputs — we keep only our share. - results = results[start_idx:end_idx] - return results + rollout_outputs = self._colocate_rollout(inputs, request_config) + return rollout_outputs def _get_request_config(self) -> RequestConfig: request_config = copy(self.request_config) @@ -697,142 +638,66 @@ def _get_request_config(self) -> RequestConfig: return request_config - def _set_inputs_system(self, inputs: InputsType) -> InputsType: + def _set_inputs_system(self, inputs: DataType) -> DataType: + """ + Inserts a default system message at the beginning of each input if specified. + + If a default system message is defined in the template and the first message in + an input is not already a system message, this method inserts the default system + message at the beginning of the messages list for each input. If no default system + message is provided, no modification is made. + + Args: + inputs (DataType): A list of input data entries, each containing a 'messages' field. + + Returns: + DataType: The input list, with the default system message prepended if applicable. 
+ """ if not self.template.template_meta.default_system: - return + return inputs if all(_input['messages'][0]['role'] == 'system' for _input in inputs): - return + return inputs for _input in inputs: messages = _input['messages'] if messages[0]['role'] != 'system': messages.insert(0, {'role': 'system', 'content': self.template.template_meta.default_system}) + return inputs def _infer_single_or_multi_turn(self, - inputs: InputsType, + inputs: DataType, request_config: RequestConfig, - is_global_inputs: bool = False) -> OutputsType: - """Perform multi-turn or single-turn inference + is_global_inputs: bool = False) -> List[DataType]: + """ + Runs inference for either single-turn or multi-turn dialogue. Args: - inputs: list of input requests - request_config: Inference configuration parameters - is_global_inputs: - A boolean indicating whether the inputs are global. When set to True, - the returned results in the main process will be a complete list of - global_outputs, while other processes will return an empty list []. + inputs: Input data for inference. + request_config: Configuration for the inference request. + is_global_inputs: Whether the inputs are from the global process. + Returns: - List of outputs where each entry contains: - - List of responses per prompt - - Each response is a tuple of (message_history, finish_reason) + List of processed outputs. """ - self._set_inputs_system(inputs) - # infer first turn - results: List[ChatCompletionResponse] = self._infer(inputs, request_config, is_global_inputs) - outputs = [] + # for external server, pass the system args which may define in trainer + + # Step 1: Prepare inputs with system prompts (if any) + inputs = self._set_inputs_system(inputs) + + # Step 2: First-turn rollout + rollout_outputs: List[RolloutOutput] = self._rollout(inputs, request_config, is_global_inputs) + + # Step 3: Handle single-turn (no scheduler, no async engine) if not self.multi_turn_scheduler and not self.vllm_use_async_engine: - # message concatenation - for i, output in enumerate(results): - _choices = [] - for choice in output.choices: - _input: Dict = deepcopy(inputs[i]) - InferRequest.remove_response(_input['messages']) - _input['messages'].append({'role': 'assistant', 'content': choice.message.content}) - _choices.append((_input['messages'], choice.finish_reason, {})) - outputs.append(_choices) - outputs = [item for sublist in outputs for item in sublist] - else: - # vLLMAsyncLLMEngine, only server mode is supported right now. - # NOTE: The message concatenation has already been done in the engine. 
- if self.vllm_use_async_engine: - for i, output in enumerate(results): - _choices = [] - for choice in output.choices: - # concated in Engine - if self.use_gym_env: - _choices.append( - (choice.messages, choice.finish_reason, choice.total_reward, choice.trajectory_info)) - else: - _choices.append((choice.messages, choice.finish_reason)) - outputs.append(_choices) - outputs = [item for sublist in outputs for item in sublist] - else: - # PTEngine or vLLMLLMEngine - orig_size = len(inputs) - outputs = [None] * orig_size - # we remove origin response in first turn - current_turn = 1 - while True: - has_local_data = len(inputs) > 0 - has_global_data = gather_object([has_local_data]) - if not any(has_global_data): - break - # inputs for current turn - current_inputs = [] - cnt = 0 - # combine completions from results with messages - for i, output in enumerate(results): - for choice in output.choices: - current_input = deepcopy(inputs[i]) - messages = current_input['messages'] - - if current_turn == 1 or not messages[-1]['content'] or messages[-1]['content'] == '': - # first turn or the last message content is empty(dummy), remove the response - InferRequest.remove_response(messages) - if messages[-1]['role'] == 'assistant': - # If the last message was assistant, concatenate the new content to it - messages[-1]['content'] += choice.message.content - else: - # append a new message from the assistant - messages.append({'role': 'assistant', 'content': choice.message.content}) - - if 'index' not in current_input: - current_input['index'] = cnt - current_input['finish_reason'] = choice.finish_reason - cnt += 1 - current_inputs.append(current_input) - - # Process messages in the multi-turn function - should_stops = [ - self.multi_turn_scheduler.check_finished(request, result.choices[0], current_turn) - for request, result in zip(self.inputs_to_rolloutrequest(current_inputs), results) - ] + return self._postprocess_rollout_outputs(inputs, rollout_outputs) - # Retain messages that are not yet finished for the next round of rollout - pending_inputs = [] - for stop, _input, result in zip(should_stops, current_inputs, results): - index = _input['index'] - if stop: - outputs[index] = (_input['messages'], _input['finish_reason'], - _input.get('multi_turn_infos', {'num_turns': 1})) - else: - current_request = self.inputs_to_rolloutrequest([_input])[0] - ret = self.multi_turn_scheduler.step(current_request, result.choices[0], current_turn) - if isinstance(ret, tuple): - infer_request, info_dict = ret - else: - infer_request = ret - info_dict = {} - info_dict['num_turns'] = current_turn + 1 - pending_input = asdict(infer_request) - if 'multi_turn_infos' not in pending_input: - pending_input['multi_turn_infos'] = {} - for key, value in info_dict.items(): - pending_input['multi_turn_infos'][key] = value - - pending_input['index'] = index - pending_inputs.append(pending_input) - - current_infer_inputs = pending_inputs if has_local_data else [] - results = self._infer(current_infer_inputs, request_config) - - inputs = pending_inputs - current_turn += 1 - assert not any([o is None for o in outputs]) - - # flatten 2D list to 1D list - return outputs + # Step 4: Handle async engine (multi-turn handled inside the engine) + if self.vllm_use_async_engine: + return self._postprocess_rollout_outputs(inputs, rollout_outputs) - def async_infer(self, all_inputs): + # Step 5: Handle multi-turn locally + return self._sync_multi_turn_infer(inputs, rollout_outputs, request_config) + + def async_generate_rollout(self, all_inputs): 
current_queue = self._queue def infer_task(): @@ -850,9 +715,9 @@ def infer_task(): def done(future): try: result = future.result() - current_queue.put(DataCache(all_inputs, result)) + current_queue.put(DataCache(result)) except Exception as e: - logger.error('Error in async_infer callback: %s', str(e)) + logger.error('Error in async_generate_rollout callback: %s', str(e)) future.add_done_callback(done) @@ -862,54 +727,60 @@ def _prefetch(self, dataloader: DataLoader): if self.state.global_step != self._last_loaded_step: self._move_model_to_vllm(skip_async_check=True) self._last_loaded_step = self.state.global_step - outputs = self._infer_single_or_multi_turn(all_inputs, self.request_config, is_global_inputs=True) - self._queue.put(DataCache(all_inputs, outputs)) + results = self._infer_single_or_multi_turn(all_inputs, self.request_config, is_global_inputs=True) + self._queue.put(DataCache(results)) - def _fast_infer(self, inputs: InputsType) -> Tuple[InputsType, OutputsType]: - # Skip the first wake_up to avoid the warning "Executor is not sleeping" + def _fast_infer(self, inputs: DataType) -> DataType: + """ + Efficient inference logic with support for vLLM colocate mode, async generation, + and model weight offloading. + """ + # Step 1: Wake up the engine if it's sleeping (vLLM colocate mode) if self.vllm_mode == 'colocate' and self.args.sleep_level > 0: if self.engine.inner_model_executor.is_sleeping: - # First, load weights only, https://github.com/vllm-project/vllm/pull/15500 - if 'tags' in inspect.signature(self.engine.engine.wake_up).parameters: - self.engine.engine.wake_up(tags=['weights']) - else: - logger.info('We recommend installing vLLM >= 0.8.3, (ideally 0.8.5.post1)' - 'to help reduce memory peaks during engine wake-up.') - self.engine.engine.wake_up() + wake_up_params = inspect.signature(self.engine.engine.wake_up).parameters + # Load weights only (faster and reduces memory peak) + kwargs = {'tags': ['weights']} if 'tags' in wake_up_params else {} + self.engine.engine.wake_up(**kwargs) - # First, have main process load weights if needed + # Step 2: Load model weights if global_step has changed if self.state.global_step != self._last_loaded_step: self._move_model_to_vllm() self._last_loaded_step = self.state.global_step + # Step 3: Offload model/optimizer if enabled context = self.offload_context if self.enable_offload else nullcontext with context(): - if self.vllm_mode == 'colocate' and self.engine.inner_model_executor.is_sleeping and \ - 'tags' in inspect.signature(self.engine.engine.wake_up).parameters: + # Step 4: Wake up kv_cache after offloading (vLLM colocate only) + if (self.vllm_mode == 'colocate' and self.engine.inner_model_executor.is_sleeping + and 'tags' in inspect.signature(self.engine.engine.wake_up).parameters): # Load the kv_cache only after updating and offload the weights. 
self.engine.engine.wake_up(tags=['kv_cache']) + # Step 5: Handle rollout for async generate or sync if self.async_generate: - # send this step data to server - # we gather inputs outside the thread for prevent potential gather deadlock + # Pre-gather inputs to avoid potential gather deadlocks all_inputs = gather_object(inputs) - self.async_infer(all_inputs) - # cached data from last step - data_cache = self._queue.get() - all_inputs = data_cache.inputs - all_outputs = gather_object(data_cache.outputs) + self.async_generate_rollout(all_inputs) + + # Retrieve cached outputs from the last step + data_cache: DataCache = self._queue.get() + all_outputs = gather_object(data_cache.results) + + # Slice inputs/outputs for the current process + per_device_datasize = len(all_outputs) // self.accelerator.num_processes process_slice = slice( - self.accelerator.process_index * len(inputs), - (self.accelerator.process_index + 1) * len(inputs), + self.accelerator.process_index * per_device_datasize, + (self.accelerator.process_index + 1) * per_device_datasize, ) - inputs = all_inputs[process_slice] outputs = all_outputs[process_slice] else: with self.multi_turn_completion_length_context(): outputs = self._infer_single_or_multi_turn(inputs, self.request_config) + # Step 6: Reset prefix cache and sleep to release memory if self.vllm_mode == 'colocate' and self.args.sleep_level > 0: # Reset prefix cache before sleeping to prevent using stale cache upon waking up # https://github.com/modelscope/ms-swift/pull/5143 @@ -917,92 +788,109 @@ def _fast_infer(self, inputs: InputsType) -> Tuple[InputsType, OutputsType]: self.engine.engine.sleep(level=self.args.sleep_level) empty_cache() - return inputs, outputs - - def _generate_completions(self, inputs: InputsType) -> InputsType: - """Generate completions for given inputs using either fast inference or standard PyTorch inference. + return outputs - Args: - inputs: List of input examples containing conversation messages. + def _generate_completions(self, inputs: DataType) -> DataType: + # add prompt ids and system prompts + inputs = self._preprocess_inputs(inputs) - Returns: - Modified inputs with generated completions added to the last message - and truncation flag set in 'is_truncated' field. 
- """ mode = 'train' if self.model.training else 'eval' if self.use_fast_infer: - inputs, outputs = self._fast_infer(inputs) + results = self._fast_infer(inputs) else: with unwrap_model_for_generation( self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation ), self.template.generate_context(), self.multi_turn_completion_length_context(): - outputs = self._infer_single_or_multi_turn(inputs, self.request_config) + results = self._infer_single_or_multi_turn(inputs, self.request_config) if mode == 'train': # In training mode, ensure the model is returned to train() mode after inference # This is necessary as pt engines set the model to eval mode during generation self.model.train() - for i, output in enumerate(outputs): - inputs[i]['messages'] = output[0] - inputs[i]['is_truncated'] = output[1] == 'length' - multi_turn_infos = output[2] if len(output) > 2 else {} - if 'images' in multi_turn_infos: - # override images - inputs[i]['images'] = multi_turn_infos['images'] - inputs[i]['multi_turn_infos'] = multi_turn_infos - if self.use_gym_env: - inputs[i]['total_reward'] = output[2] - inputs[i]['trajectory_info'] = output[3] - - return inputs + return results - def _generate_and_score_completions(self, inputs: InputsType) -> InputsType: + def _generate_and_score_completions(self, inputs: DataType) -> DataType: + # resample for overlong(> max_length) prompt data if self.template.truncation_strategy == 'raise': inputs = self.resample_truncated_inputs(inputs) inputs = self._generate_completions(inputs) - total_rewards_per_func, total_rewards, completions = self._score_completions(inputs) + total_rewards_per_func, total_rewards, completions, total_advantages, rewards_std = self._score_completions( + inputs) mode = 'train' if self.model.training else 'eval' if self.args.dynamic_sample and mode == 'train': # dynamic sampling for std=0 groups - inputs, total_rewards, total_rewards_per_func, completions = \ - self._dynamic_sampling(inputs, total_rewards, total_rewards_per_func, completions) + inputs, total_rewards, total_rewards_per_func, completions, total_advantages = \ + self._dynamic_sampling(inputs, total_rewards, total_rewards_per_func, completions, total_advantages, rewards_std) # noqa + + local_advantages = self.get_even_process_data(total_advantages) + assert len(local_advantages) == len(inputs) + for i, advantage in enumerate(local_advantages): + inputs[i]['advantages'] = advantage + + self._logs['advantages'].extend(total_advantages.tolist()) + if any('images' in data and data['images'] is not None for data in inputs): + self._logs['image'].extend(gather_object([inp['images'] for inp in inputs])) - # Prepare final outputs with advantages and other required fields - batch_encoded_inputs = self._prepare_batch_inputs(inputs, total_rewards) + batch_encoded_inputs = self._prepare_batch_inputs(inputs) # Log metrics messages = [inputs[i]['messages'][:-1] for i in range(len(inputs))] - trajectory_infos = None - if self.use_gym_env: - trajectory_infos = [inputs[i]['trajectory_info'] for i in range(len(inputs))] - self._log_metrics(batch_encoded_inputs, messages, completions, total_rewards, total_rewards_per_func, - trajectory_infos) + + self._log_metrics(batch_encoded_inputs, messages, completions, total_rewards, total_rewards_per_func) return batch_encoded_inputs - def _score_completions(self, inputs: InputsType) -> Tuple[torch.Tensor, torch.Tensor, List[str]]: - """Score completions using all reward functions + def _score_completions(self, inputs: DataType) 
-> Tuple[torch.Tensor, torch.Tensor, List[str], torch.Tensor]: + """Score completions using all reward functions and compute advantages Args: inputs: List of input examples, each containing a 'messages' list with conversation history Returns: Tuple containing: - - rewards_per_func: Tensor of shape (num_examples, num_reward_funcs) with individual rewards + - total_rewards_per_func: Tensor of shape (num_examples, num_reward_funcs) with individual rewards - total_rewards: Tensor of shape (num_examples,) with weighted sum of rewards - completions: List of generated completion strings + - advantages: Tensor of shape (num_examples,) with computed advantages """ device = self.accelerator.device completions = [example['messages'][-1]['content'] for example in inputs] + + # Extract prompt_ids for grouping + prompt_ids = [inp['prompt_id'] for inp in inputs] + # If using gym environment, extract rewards directly from inputs if self.use_gym_env: - total_rewards = torch.tensor([inp['total_reward'] for inp in inputs], dtype=torch.float32, device=device) - # For gym environment, there's only one total reward, so rewards_per_func is just total_rewards reshaped - rewards_per_func = total_rewards.unsqueeze(1) # shape: [num_examples, 1] - total_rewards_per_func = gather(rewards_per_func) - total_rewards_gathered = total_rewards_per_func.squeeze(1) # Recover from gathered data - return total_rewards_per_func, total_rewards_gathered, completions + local_rewards = torch.tensor([inp['total_reward'] for inp in inputs], dtype=torch.float32, device=device) + # For gym environment, there's only one total reward, so rewards_per_func is just local_rewards reshaped + local_rewards_per_func = local_rewards.unsqueeze(1) # shape: [num_examples, 1] + else: + # Compute rewards using reward functions + local_rewards_per_func = self._compute_rewards_per_func(inputs, completions) + local_rewards = (local_rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + + # Gather rewards and prompt_ids across processes with padding + gathered_rewards_per_func, gathered_prompt_ids = self._gather_rewards_and_prompt_ids( + local_rewards_per_func, prompt_ids) + + # Remove dummy data and compute total rewards + total_rewards_per_func, total_prompt_ids = self._remove_dummy_data(gathered_rewards_per_func, + gathered_prompt_ids) + + if self.use_gym_env: + total_rewards = total_rewards_per_func.squeeze(1) # Recover from gathered data + else: + total_rewards = (total_rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + + # Compute advantages based on prompt_id grouping + total_advantages, rewards_std = self._compute_advantages(total_rewards, total_prompt_ids) + + return total_rewards_per_func, total_rewards, completions, total_advantages, rewards_std + + def _compute_rewards_per_func(self, inputs: DataType, completions: List[str]) -> torch.Tensor: + """Compute rewards using all reward functions""" + device = self.accelerator.device rewards_per_func = torch.zeros((len(inputs), len(self.reward_funcs)), device=device) for i, (reward_func, reward_model_plugin, reward_func_name) in enumerate( @@ -1028,12 +916,181 @@ def _score_completions(self, inputs: InputsType) -> Tuple[torch.Tensor, torch.Te logger.warning(f'All reward functions returned None for the following kwargs: {row_reward_kwargs}. 
' 'Please ensure that at least one reward function returns a valid reward.') - total_rewards_per_func = gather(rewards_per_func) - total_rewards = (total_rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) + return rewards_per_func + + def _gather_rewards_and_prompt_ids(self, local_rewards_per_func: torch.Tensor, + local_prompt_ids: List[str]) -> Tuple[torch.Tensor, List[str]]: + """Gather rewards and prompt_ids across processes with padding""" + device = self.accelerator.device + rewards_per_func = local_rewards_per_func + prompt_ids = local_prompt_ids + + # Prepare for gather with padding + if self.rollout_pad_count > 0: + # Pad rewards with NaN + pad_rewards_per_func = torch.full((self.rollout_pad_count, rewards_per_func.shape[1]), + torch.nan, + dtype=torch.float32, + device=device) + rewards_per_func = torch.cat([rewards_per_func, pad_rewards_per_func], dim=0) + + # Pad prompt_ids with special dummy value + dummy_prompt_ids = ['__dummy_pad__'] * self.rollout_pad_count + prompt_ids = prompt_ids + dummy_prompt_ids + + # Gather all data across processes + gathered_rewards_per_func = gather(rewards_per_func) + gathered_prompt_ids = gather_object(prompt_ids) + + return gathered_rewards_per_func, gathered_prompt_ids + + def _remove_dummy_data(self, gathered_rewards_per_func: torch.Tensor, + gathered_prompt_ids: List[str]) -> Tuple[torch.Tensor, List[str]]: + """Remove dummy data (prompt_id == '__dummy_pad__') from gathered data""" + valid_indices = [i for i, pid in enumerate(gathered_prompt_ids) if pid != '__dummy_pad__'] + valid_rewards_per_func = gathered_rewards_per_func[valid_indices] + valid_prompt_ids = [gathered_prompt_ids[i] for i in valid_indices] + + return valid_rewards_per_func, valid_prompt_ids + + def _gather_tensors(self, local_tensors: List[torch.Tensor]) -> List[torch.Tensor]: + """Gather tensors across processes with padding and remove dummy data""" + # Prepare for gather with padding + device = self.accelerator.device + tensors = local_tensors.copy() + if self.rollout_pad_count > 0 and tensors: + # Create dummy tensors with the same shape as the first tensor + dummy_tensor = torch.full_like(tensors[0], torch.nan, device=device) + for _ in range(self.rollout_pad_count): + tensors.append(dummy_tensor) + + # Gather all tensors across processes + gathered_tensors = gather(tensors) + + # Remove padded dummy tensors (NaN tensors) from gathered data + if not gathered_tensors: + return [] + + valid_tensors = [] + for tensor in gathered_tensors: + # Check if tensor is dummy (all NaN) + if not torch.isnan(tensor).all(): + # Ensure tensor is on the correct device + valid_tensors.append(tensor.to(device)) + + return valid_tensors + + def _gather_objects(self, local_objects: List[Any]) -> List[Any]: + """Gather objects across processes with padding and remove dummy data""" + # Prepare for gather with padding + objects = local_objects.copy() + if self.rollout_pad_count > 0: + # Add dummy objects + dummy_object = '__dummy_pad__' + for _ in range(self.rollout_pad_count): + objects.append(dummy_object) + + # Gather all objects across processes + gathered_objects = gather_object(objects) + + # Remove padded dummy objects from gathered data + if not gathered_objects: + return [] + + valid_objects = [obj for obj in gathered_objects if obj != '__dummy_pad__'] + return valid_objects + + def _compute_advantages(self, rewards: torch.Tensor, prompt_ids: List[str]) -> torch.Tensor: + """ + Compute normalized advantages by grouping rewards based on prompt IDs. 
+ + This method performs group-wise advantage computation where rewards are normalized + within each prompt group by subtracting the group mean. Optionally scales advantages + by group standard deviation for variance normalization. + + The computation process: + 1. Groups rewards by unique prompt_id + 2. Computes group-wise mean and standard deviation + 3. Calculates advantages as (reward - group_mean) + 4. Optionally normalizes by group standard deviation if scale_rewards is enabled + 5. Tracks training/evaluation metrics for monitoring + + Args: + rewards (torch.Tensor): Reward values from all processes with shape (num_examples,) + prompt_ids (List[str]): Corresponding prompt identifiers with length num_examples + + Returns: + tuple: A tuple containing: + - advantages (torch.Tensor): Computed advantages with same shape as rewards + - rewards_std (torch.Tensor): Group standard deviations for each sample + """ + assert rewards.shape[0] == len(prompt_ids) + mode = 'train' if self.model.training else 'eval' + advantages = torch.zeros_like(rewards) + # calculate rewards_std for dynamic_sampling + rewards_std = torch.zeros_like(rewards) + + # Group rewards by prompt_id + unique_prompt_ids = list(set(prompt_ids)) + group_rewards_mean = [] + group_rewards_std = [] + + for prompt_id in unique_prompt_ids: + # Find all samples with this prompt_id + indices = [i for i, pid in enumerate(prompt_ids) if pid == prompt_id] + if len(indices) == 0: + continue + + group_rewards = rewards[indices] - return total_rewards_per_func, total_rewards, completions + # Compute group statistics + group_mean = group_rewards.mean() + group_rewards_mean.append(group_mean) + group_advantages = group_rewards - group_mean - def _dynamic_sampling(self, inputs, rewards, rewards_per_func, completions): + group_std = group_rewards.std() + group_rewards_std.append(group_std) + rewards_std[indices] = group_std + + # Optional: scale by standard deviation + if self.args.scale_rewards: + group_advantages /= (group_std + 1e-4) + + # Assign computed advantages back to original positions + for idx, advantage in zip(indices, group_advantages): + advantages[idx] = advantage + + if group_rewards_mean: + # compute metrics + group_rewards_mean = torch.stack(group_rewards_mean) + group_rewards_std = torch.stack(group_rewards_std) + is_std_zero = torch.isclose(group_rewards_std, torch.zeros_like(group_rewards_std)) + + self._metrics[mode]['reward'].append(group_rewards_mean.mean().item()) + self._metrics[mode]['reward_std'].append(group_rewards_std.mean().item()) + self._metrics[mode]['frac_reward_zero_std'].append(is_std_zero.float().mean().item()) + + return advantages, rewards_std + + def _dynamic_sampling(self, inputs, rewards, rewards_per_func, completions, advantages, rewards_std): + """ + Perform dynamic sampling to replace samples with zero-reward-variance groups. + + This method implements DAPO (https://arxiv.org/abs/2503.14476) by replacing + samples from groups with zero reward variance (std=0) through resampling. 
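For readers skimming the diff, the group-wise advantage computation introduced above can be reproduced with a tiny self-contained example; the reward values are made up, and the division by the group std corresponds to the optional `scale_rewards` branch.

import torch

# Two prompt groups: 'p1' has three samples, 'p2' has two.
rewards = torch.tensor([1.0, 0.0, 1.0, 0.5, 0.5])
prompt_ids = ['p1', 'p1', 'p1', 'p2', 'p2']

advantages = torch.zeros_like(rewards)
for pid in set(prompt_ids):
    idx = [i for i, p in enumerate(prompt_ids) if p == pid]
    group = rewards[idx]
    adv = group - group.mean()          # centre rewards within the prompt group
    adv = adv / (group.std() + 1e-4)    # optional scaling, as with scale_rewards
    advantages[idx] = adv

print(advantages)  # ~[0.577, -1.155, 0.577, 0.0, 0.0]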
+ + Args: + inputs: local input data samples + rewards: Tensor of rewards for global data samples + rewards_per_func: Rewards per function/model for global data samples + completions: Generated completions for local inputs + advantages: Computed advantages for global data samples + rewards_std: Group standard deviations for each sample + + Returns: + tuple: (inputs, rewards, rewards_per_func, completions, advantages) + with zero-variance groups replaced by resampled data + """ # DAPO https://arxiv.org/abs/2503.14476 # Replaces samples with zero-reward-variance groups (std=0) resample_count = 0 @@ -1041,28 +1098,25 @@ def _dynamic_sampling(self, inputs, rewards, rewards_per_func, completions): valid_rewards = [] valid_rewards_per_func = [] valid_completions = [] - - origin_data = (inputs, rewards, rewards_per_func, completions) + valid_advantages = [] + origin_data = (inputs, rewards, rewards_per_func, completions, advantages) while resample_count < self.args.max_resample_times: - grouped_rewards = rewards.view(-1, self.num_generations) - group_std = grouped_rewards.std(dim=1) - - valid_mask = (group_std > 0).repeat_interleave(self.num_generations) + valid_mask = (rewards_std > 0) all_inputs = gather_object(inputs) valid_samples.extend([inp for inp, mask in zip(all_inputs, valid_mask) if mask]) valid_rewards.append(rewards[valid_mask]) valid_rewards_per_func.append(rewards_per_func[valid_mask]) valid_completions.extend( [inp['messages'][-1]['content'] for inp, mask in zip(all_inputs, valid_mask) if mask]) - + valid_advantages.append(advantages[valid_mask]) if len(valid_samples) >= self.args.generation_batch_size: break inputs = next(self.dynamic_resample_iterator) inputs = Trainer._prepare_inputs(self, inputs) inputs = self._generate_completions(inputs) - rewards_per_func, rewards, completions = self._score_completions(inputs) + rewards_per_func, rewards, completions, advantages, rewards_std = self._score_completions(inputs) resample_count += 1 if len(valid_samples) >= self.args.generation_batch_size: @@ -1074,70 +1128,79 @@ def _dynamic_sampling(self, inputs, rewards, rewards_per_func, completions): rewards = torch.cat(valid_rewards)[:self.args.generation_batch_size] rewards_per_func = torch.cat(valid_rewards_per_func)[:self.args.generation_batch_size] completions = valid_completions[:self.args.generation_batch_size][process_slice] + advantages = torch.cat(valid_advantages)[:self.args.generation_batch_size] else: logger.warning(f'There are still std=0 groups present after {self.args.max_resample_times} retries.') - inputs, rewards, rewards_per_func, completions = origin_data + inputs, rewards, rewards_per_func, completions, advantages = origin_data - return inputs, rewards, rewards_per_func, completions + return inputs, rewards, rewards_per_func, completions, advantages - def split_by_mini_batches(self, inputs, advantages): - # Slice to keep only the local part of the data + def split_by_mini_batches(self, inputs: DataType) -> List[DataType]: + """ + Split inputs into mini-batches, handling variable generation counts. + + When rollout count differs from expected (bs * spg * num_generations), + we need to adjust the splitting logic to maintain proper batch sizes. + + This method divides the input data into chunks based on the steps per generation (spg). + If the total number of inputs is not evenly divisible by spg, the remainder is + distributed across the first few chunks to ensure all data is included. 
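The remainder-distribution rule described here is equivalent to the following standalone sketch (the helper name is ours): the first `remainder` chunks each receive one extra sample.

from typing import Any, List

def split_with_remainder(items: List[Any], spg: int) -> List[List[Any]]:
    # Divide items into spg chunks; leftover items go to the earliest chunks.
    chunk_size, remainder = divmod(len(items), spg)
    chunks, start = [], 0
    for i in range(spg):
        end = start + chunk_size + (1 if i < remainder else 0)
        chunks.append(items[start:end])
        start = end
    return chunks

print(split_with_remainder(list(range(7)), 3))  # [[0, 1, 2], [3, 4], [5, 6]]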
+ + Args: + inputs (DataType): List of input data samples to be split into mini-batches. + + Returns: + List[DataType]: A list of data chunks, where each chunk represents one step + in the generation process. The number of chunks equals spg. + """ # Slice to keep only the local part of the data - process_slice = slice( - self.accelerator.process_index * len(inputs), - (self.accelerator.process_index + 1) * len(inputs), - ) - advantages = advantages[process_slice] - mode = 'train' if self.model.training else 'eval' - bs = self.args.per_device_train_batch_size if mode == 'train' else self.args.per_device_eval_batch_size - spg = self.args.steps_per_generation if mode == 'train' else 1 + mode: str = 'train' if self.model.training else 'eval' + spg: int = self.args.steps_per_generation if mode == 'train' else 1 + + chunk_size: int = len(inputs) // spg + remainder: int = len(inputs) % spg + spg_chunks: List[DataType] = [] - assert len(inputs) == bs * spg, f'Expected {bs * spg} inputs, got {len(inputs)}' - spg_chunks = [inputs[i * bs:(i + 1) * bs] for i in range(spg)] - # Split advantages by spg chunks - advantage_chunks = torch.chunk(advantages, spg) - return spg_chunks, advantage_chunks + start_idx: int = 0 + for i in range(spg): + current_chunk_size: int = chunk_size + (1 if i < remainder else 0) + end_idx: int = start_idx + current_chunk_size + spg_chunks.append(inputs[start_idx:end_idx]) + start_idx = end_idx - def _prepare_batch_inputs(self, inputs: InputsType, rewards: torch.Tensor) -> List[InputsType]: + return spg_chunks + + def _prepare_batch_inputs(self, inputs: DataType) -> List[DataType]: """ - Prepare the final batch inputs with advantages, ref/old_policy logps and other fields for RL training. + Prepare the final batch inputs with ref/old_policy logps and other fields for RL training. Args: - inputs (InputsType): List of input samples. Original shape is [spg*bs] where: - - spg: steps_per_generation - - bs: per-device batch size - rewards (torch.Tensor): Tensor of global rewards corresponding to the inputs. - Shape should match the total number of samples (spg*bs*num_processes*num_generations) + inputs (DataType): List of local input samples. 
Returns: - List[InputsType]: A list of prepared batch inputs, organized as [spg][bs] + List[DataType]: A list of prepared batch inputs, organized as [spg][bs] """ - # Compute advantages - grouped_rewards = rewards.view(-1, self.num_generations) - mean_grouped_rewards = grouped_rewards.mean(dim=1).repeat_interleave(self.num_generations, dim=0) - std_grouped_rewards = grouped_rewards.std(dim=1).repeat_interleave(self.num_generations, dim=0) - - advantages = (rewards - mean_grouped_rewards) - if self.args.scale_rewards: - advantages /= (std_grouped_rewards + 1e-4) - self._logs['advantages'].extend(gather(advantages).tolist()) - if any('images' in data and data['images'] is not None for data in inputs): - self._logs['image'].extend(gather_object([inp['images'] for inp in inputs])) - template = self.template - gas_chunks, advantage_chunks = self.split_by_mini_batches(inputs, advantages) + gas_chunks = self.split_by_mini_batches(inputs) ga_batch_encoded_inputs = [] - for i, (batch, batch_advantages) in enumerate(zip(gas_chunks, advantage_chunks)): + for i, batch in enumerate(gas_chunks): # Encode and process each batch (size=bs) with self._template_context(template): - batch_encoded_inputs = [template.encode(infer_request) for infer_request in batch] + [ + data.update( + {'messages': replace_assistant_response_with_ids(data['messages'], data['response_token_ids'])}) + for data in batch if 'response_token_ids' in data and data['response_token_ids'] + ] + + batch_encoded_inputs = [template.encode(data) for data in batch] batch_encoded_inputs = to_device(template.data_collator(batch_encoded_inputs), self.model.device) # Process labels and masks labels = batch_encoded_inputs.pop('labels') logits_to_keep = (labels.shape[-1] - (torch.ne(labels, -100).int().argmax(-1))).max().item() + batch_encoded_inputs.update({ 'completion_mask': labels[:, -logits_to_keep:] != -100, @@ -1146,7 +1209,7 @@ def _prepare_batch_inputs(self, inputs: InputsType, rewards: torch.Tensor) -> Li 'logits_to_keep': logits_to_keep, 'advantages': - batch_advantages + torch.stack([data['advantages'] for data in batch]) }) with torch.no_grad(): @@ -1168,22 +1231,28 @@ def _prepare_batch_inputs(self, inputs: InputsType, rewards: torch.Tensor) -> Li return ga_batch_encoded_inputs - def _log_metrics(self, inputs, messages, completions, rewards, rewards_per_func, trajectory_infos=None): + def _log_metrics(self, inputs, messages, completions, rewards, rewards_per_func): """Log training/evaluation metrics""" mode = 'train' if self.model.training else 'eval' device = self.accelerator.device - # Calculate completion length metrics - agg_completion_mask = gather(torch.cat([inp['completion_mask'].sum(1) for inp in inputs])) + local_completion_lengths = [inp['completion_mask'].sum(1).to(device) for inp in inputs] + + total_completion_lengths = self._gather_tensors(local_completion_lengths) + total_completion_lengths = torch.cat(total_completion_lengths) + + self._metrics[mode]['completions/mean_length'].append(total_completion_lengths.float().mean().item()) + self._metrics[mode]['completions/min_length'].append(total_completion_lengths.float().min().item()) + self._metrics[mode]['completions/max_length'].append(total_completion_lengths.float().max().item()) - self._metrics[mode]['completions/mean_length'].append(agg_completion_mask.float().mean().item()) - self._metrics[mode]['completions/min_length'].append(agg_completion_mask.float().min().item()) - self._metrics[mode]['completions/max_length'].append(agg_completion_mask.float().max().item()) 
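The `_gather_tensors` / `_gather_objects` pattern used just above (pad locally so every rank contributes the same count, gather, then drop the dummies) can be illustrated with a single-process toy where the collective gather is stubbed out; shapes and names are illustrative only.

import torch

def pad_then_filter(local: torch.Tensor, pad_count: int) -> torch.Tensor:
    # Append NaN dummy rows so ranks with fewer samples still line up in the gather.
    if pad_count > 0:
        dummy = torch.full((pad_count, *local.shape[1:]), float('nan'))
        local = torch.cat([local, dummy], dim=0)
    gathered = local  # single-process stand-in for gather(...) across ranks
    # Drop rows that are entirely NaN, i.e. the padded dummies (2-D input assumed).
    keep = ~torch.isnan(gathered).all(dim=-1)
    return gathered[keep]

print(pad_then_filter(torch.ones(2, 3), pad_count=1).shape)  # torch.Size([2, 3])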
# Calculate clip ratio - agg_truncated_mask = gather(torch.cat([inp['truncated_mask'] for inp in inputs]).to(device)) + local_truncated_masks = [inp['truncated_mask'].to(device) for inp in inputs] + total_truncated_masks = self._gather_tensors(local_truncated_masks) + total_truncated_masks = torch.cat(total_truncated_masks) - term_completion_mask = agg_completion_mask[agg_truncated_mask] - clipped_completions_ratio = len(term_completion_mask) / len(agg_completion_mask) + num_truncated_samples = total_truncated_masks.sum().item() + num_total_samples = total_completion_lengths.shape[0] + clipped_completions_ratio = num_truncated_samples / num_total_samples self._metrics[mode]['completions/clipped_ratio'].append(clipped_completions_ratio) @@ -1193,26 +1262,22 @@ def _log_metrics(self, inputs, messages, completions, rewards, rewards_per_func, std_rewards = nanstd(rewards_per_func[:, i]).item() self._metrics[mode][f'rewards/{reward_func_name}/std'].append(std_rewards) - # Log overall reward stats - grouped_rewards = rewards.view(-1, self.num_generations) - std_grouped_rewards = grouped_rewards.std(dim=1) - is_std_zero = torch.isclose(std_grouped_rewards, torch.zeros_like(std_grouped_rewards)) - - self._metrics[mode]['reward'].append(grouped_rewards.mean().item()) - self._metrics[mode]['reward_std'].append(std_grouped_rewards.mean().item()) - self._metrics[mode]['frac_reward_zero_std'].append(is_std_zero.float().mean().item()) + # Log prompt and completion texts with padding and remove dummy data + valid_messages = self._gather_objects(messages) + valid_completions = self._gather_objects(completions) - # Log prompt and completion texts - self._logs['prompt'].extend(self._apply_chat_template_to_messages_list(gather_object(messages))) - self._logs['completion'].extend(gather_object(completions)) + self._logs['prompt'].extend(self._apply_chat_template_to_messages_list(valid_messages)) + self._logs['completion'].extend(valid_completions) if self.use_gym_env: - self._logs['trajectory_infos'].extend(gather_object(trajectory_infos)) + pass + # TODO: extra from rollout_infos + # self._logs['trajectory_infos'].extend(gather_object(trajectory_infos)) for i, name in enumerate(self.reward_func_names): self._logs['rewards'][name].extend(rewards_per_func[:, i].tolist()) - def _apply_chat_template_to_messages_list(self, messages_list: InputsType): + def _apply_chat_template_to_messages_list(self, messages_list: DataType): prompts_text = [] for messages in messages_list: InferRequest.remove_response(messages) @@ -1249,6 +1314,21 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N def _compute_loss(self, model, inputs): mode = 'train' if self.model.training else 'eval' + # Check batch size and decide processing strategy + batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs else len(inputs.get('completion_mask', [])) + expected_bs = self.args.per_device_train_batch_size if mode == 'train' else self.args.per_device_eval_batch_size + + # If batch size matches expected, use normal processing + if batch_size == expected_bs: + return self._compute_loss_single(model, inputs) + else: + assert batch_size > expected_bs + return self._compute_loss_chunked(model, inputs) + + def _compute_loss_single(self, model, inputs): + """Original loss computation logic for single batch processing.""" + mode = 'train' if self.model.training else 'eval' + completion_mask = inputs['completion_mask'] truncated_mask = inputs['truncated_mask'] @@ -1279,8 +1359,7 @@ def _compute_loss(self, model, 
inputs): logger.info('All completions are overlong and truncated, ' 'resulting in NaN some values for some metrics (e.g., KL)') truncated_mask = truncated_mask.unsqueeze(-1).expand_as(completion_mask).to(completion_mask.device) - completion_mask = completion_mask * (~truncated_mask) - + completion_mask.mul_(~truncated_mask) # Compute the KL divergence between the model and the reference model if self.beta != 0.0: ref_per_token_logps = inputs['ref_per_token_logps'] @@ -1366,6 +1445,42 @@ def masked_batch_mean(x): return loss + def _compute_loss_chunked(self, model, inputs): + mode = 'train' if self.model.training else 'eval' + chunk_size = self.args.per_device_train_batch_size if mode == 'train' else self.args.per_device_eval_batch_size + batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs else len(inputs.get('completion_mask', [])) + + logger.debug(f'Computing chunked loss for batch size {batch_size} with chunk size {chunk_size}') + + all_losses = [] + + # TODO: Aggregate metrics across chunks + # aggregated_metrics = {} + + for i in range(0, batch_size, chunk_size): + end_idx = min(i + chunk_size, batch_size) + + # Create chunk inputs + chunk_inputs = {} + for key, value in inputs.items(): + if isinstance(value, torch.Tensor): + chunk_inputs[key] = value[i:end_idx] + else: + chunk_inputs[key] = value + + # Compute loss for this chunk + chunk_loss = self._compute_loss_single(model, chunk_inputs) + + all_losses.append(chunk_loss) + + # Compute average loss + final_loss_tensor = torch.stack(all_losses).mean() + + logger.debug(f'Chunked loss computation completed: {len(all_losses)} chunks -> ' + f'final loss {final_loss_tensor.item():.6f}') + + return final_loss_tensor + @contextmanager def padding_free_context(self, model: torch.nn.Module): ctx = {} @@ -1444,6 +1559,27 @@ def _get_per_token_logps_and_entropies(self, model, inputs, compute_entropy=False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Compute per-token log probabilities and entropies with memory-efficient batching. + + When rollout count is larger than expected, we process in smaller batches + to control memory usage. + """ + # Check if we need to use memory-efficient batching + batch_size = inputs['input_ids'].shape[0] + mode = 'train' if self.model.training else 'eval' + expected_bs = self.args.per_device_train_batch_size if mode == 'train' else self.args.per_device_eval_batch_size + + # If batch is larger than threshold and adaptive batching is enabled, use chunked processing + if batch_size > expected_bs: + return self._get_per_token_logps_and_entropies_chunked(model, inputs, compute_entropy) + else: + return self._get_per_token_logps_and_entropies_single(model, inputs, compute_entropy) + + def _get_per_token_logps_and_entropies_single(self, + model, + inputs, + compute_entropy=False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: logits_to_keep = inputs['logits_to_keep'] input_ids = inputs['input_ids'] unwrapped_model = self.accelerator.unwrap_model(model) @@ -1488,6 +1624,54 @@ def _get_per_token_logps_and_entropies(self, return logps, entropies + def _get_per_token_logps_and_entropies_chunked(self, + model, + inputs, + compute_entropy=False + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Memory-efficient chunked processing for large batches. + + Splits the batch into smaller chunks based on per_device_batch_size + to control memory usage when rollout count is larger than expected. 
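The chunked variants of `_compute_loss` and `_get_per_token_logps_and_entropies` follow the same slicing pattern, roughly the sketch below; the helper and the lambda are illustrative, and a chunked loss would be stacked and averaged rather than concatenated.

import torch
from typing import Callable, Dict

def run_in_chunks(inputs: Dict[str, torch.Tensor], chunk_size: int,
                  fn: Callable[[Dict[str, torch.Tensor]], torch.Tensor]) -> torch.Tensor:
    # Slice every tensor entry along the batch dimension (only tensor entries shown
    # here for brevity; scalars such as logits_to_keep would be passed through).
    batch_size = inputs['input_ids'].shape[0]
    chunks = []
    for start in range(0, batch_size, chunk_size):
        end = min(start + chunk_size, batch_size)
        chunks.append(fn({k: v[start:end] for k, v in inputs.items()}))
    return torch.cat(chunks, dim=0)

print(run_in_chunks({'input_ids': torch.arange(10).unsqueeze(-1)}, 4,
                    lambda c: c['input_ids'].float()).shape)  # torch.Size([10, 1])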
+ """ + batch_size = inputs['input_ids'].shape[0] + mode = 'train' if self.model.training else 'eval' + chunk_size = self.args.per_device_train_batch_size if mode == 'train' else self.args.per_device_eval_batch_size + + logger.debug(f'Processing batch of size {batch_size} in chunks of {chunk_size}') + + all_logps = [] + all_entropies = [] if compute_entropy else None + + # Process in chunks + for i in range(0, batch_size, chunk_size): + end_idx = min(i + chunk_size, batch_size) + + # Create chunk inputs + chunk_inputs = {} + for key, value in inputs.items(): + if isinstance(value, torch.Tensor): + chunk_inputs[key] = value[i:end_idx] + else: + chunk_inputs[key] = value # Non-tensor values (like logits_to_keep) are scalars + + # Process this chunk + chunk_logps, chunk_entropies = self._get_per_token_logps_and_entropies_single( + model, chunk_inputs, compute_entropy) + + all_logps.append(chunk_logps) + if compute_entropy and chunk_entropies is not None: + all_entropies.append(chunk_entropies) + + # Concatenate results + final_logps = torch.cat(all_logps, dim=0) + final_entropies = torch.cat(all_entropies, dim=0) if all_entropies else None + + logger.debug(f'Chunked processing completed: {len(all_logps)} chunks -> ' f'final shape {final_logps.shape}') + + return final_logps, final_entropies + @patch_profiling_decorator def _get_last_hidden_state(self, unwrapped_model, inputs, logits_to_keep): # unwrap the model to access the model.model @@ -1556,7 +1740,7 @@ def evaluation_loop(self, dataloader, *args, **kwargs): self.eval_flag = True return output - def training_step(self, model: nn.Module, inputs: InputsType, num_items_in_batch=None) -> torch.Tensor: + def training_step(self, model: nn.Module, inputs: DataType, num_items_in_batch=None) -> torch.Tensor: if self.args.async_generate: # Wait for the eval rollout to complete while not self.is_async_generate_eval_rollout_done(): @@ -1565,46 +1749,35 @@ def training_step(self, model: nn.Module, inputs: InputsType, num_items_in_batch def _engine_infer( self, - infer_requests: InputsType, + infer_requests: List[RolloutInferRequest], request_config: Optional[RequestConfig] = None, *, use_tqdm: Optional[bool] = False, - ) -> List[ChatCompletionResponse]: + ) -> List[RolloutOutput]: + """ + Perform inference using the configured engine (VLLM server or colocate engine). 
+ + Args: + infer_requests: List of rollout inference requests to process + request_config: Optional configuration for the requests + use_tqdm: Whether to show progress bar during inference + + Returns: + List of RolloutOutput objects containing the inference results + """ with patch_profiling_context(self, 'generate'): if self.vllm_mode == 'server': - request_keys = ['messages', 'images', 'audios', 'videos', 'tools', 'objects'] - - infer_requests = [{ - **{k: request[k] - for k in request_keys if k in request}, - **({ - 'data_dict': {k: request[k] - for k in request if k not in request_keys} - } if ( - (self.multi_turn_scheduler and self.vllm_use_async_engine) or - (self.vllm_use_async_engine and self.use_gym_env) - ) else {}) # use gym infer - } for request in infer_requests] - - self._process_infer_requests_images(infer_requests) - return self.vllm_client.infer(infer_requests, asdict(request_config), use_tqdm=use_tqdm) + return self.vllm_client.infer([asdict(req) for req in infer_requests], + asdict(request_config), + use_tqdm=use_tqdm) else: - return self.engine.infer(infer_requests, request_config, use_tqdm=use_tqdm) - - def _process_infer_requests_images(self, infer_requests: InputsType): - # Process image format into a format that session.post can accept - import base64 - if not any('images' in request for request in infer_requests): - return - for request in infer_requests: - if 'images' not in request: - continue - for i, img in enumerate(request['images']): - if 'bytes' in img and img['bytes']: - request['images'][i] = base64.b64encode(img['bytes']).decode('utf-8') - elif 'path' in img and img['path']: - request['images'][i] = img['path'] - return + res = self.engine.infer(infer_requests, request_config, use_tqdm=use_tqdm) + if all(isinstance(r, RolloutOutput) for r in res): + return res + else: + # PT Eninge + assert all(isinstance(r, ChatCompletionResponse) for r in res) + return [RolloutOutput(response=r) for r in res] def old_policy(self): return self.num_iterations > 1 or self.args.gradient_accumulation_steps % self.args.steps_per_generation != 0 @@ -1695,7 +1868,7 @@ def set_default_max_tokens(_self, request_config: RequestConfig, inputs: Dict[st self.engine.max_model_len = original_max_len del self.engine.set_grpo_max_model_len - def resample_truncated_inputs(self, inputs: InputsType, n_try_fetch: int = 10) -> InputsType: + def resample_truncated_inputs(self, inputs: DataType, n_try_fetch: int = 10) -> DataType: template = self.template for i, data in enumerate(inputs): n_try = 0 @@ -1784,43 +1957,6 @@ def is_async_generate_eval_rollout_done(self): def is_async_generate_train_rollout_done(self): return not self.train_queue.empty() - def inputs_to_rolloutrequest(self, inputs: InputsType) -> List[RolloutInferRequest]: - """Convert a list of inputs to a list of RolloutInferRequest objects - - If the input contains a 'data_dict' key, it will be used as the base for the new data_dict. - For other keys, if they overlap with keys in data_dict, the values from data_dict will be used. - Non-overlapping keys will be added to data_dict. 
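The data_dict merge rule referenced here (extra, non-request fields are collected, and keys already present in an existing data_dict take priority) boils down to this small sketch; the key list is abbreviated and the function name is ours.

REQUEST_KEYS = ['messages', 'images', 'audios', 'videos', 'objects']

def build_data_dict(sample: dict) -> dict:
    # Existing data_dict entries win over extra top-level fields on key conflicts.
    base = sample.get('data_dict') or {}
    extra = {k: v for k, v in sample.items() if k not in REQUEST_KEYS and k != 'data_dict'}
    return {**extra, **base}

print(build_data_dict({'messages': [], 'solution': '42', 'data_dict': {'env': 'math'}}))
# {'solution': '42', 'env': 'math'}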
- - Args: - inputs: List of input dictionaries - - Returns: - List of RolloutInferRequest objects - """ - request_keys = ['messages', 'images', 'audios', 'videos', 'tools', 'objects'] - infer_requests = [] - - for request in inputs: - # Get the base data_dict if it exists in the input - base_data_dict = {} - if 'data_dict' in request: - if isinstance(request['data_dict'], dict): - base_data_dict = request['data_dict'] - else: - raise ValueError('data_dict exists but is not a dictionary') - - # Collect all non-request_keys items as extra fields - extra_data = {k: request[k] for k in request if k not in request_keys and k != 'data_dict'} - - # Merge the data_dict, keeping keys from base_data_dict as priority - final_data_dict = {**extra_data, **base_data_dict} - - # Create RolloutInferRequest instance - req_args = {k: request[k] for k in request_keys if k in request} - infer_requests.append(RolloutInferRequest(**req_args, data_dict=final_data_dict)) - - return infer_requests - @contextmanager def offload_context(self): if self.args.offload_model: @@ -1842,3 +1978,448 @@ def offload_context(self): if getattr(self, 'optimizer', None) and self.args.offload_optimizer: self.load_optimizer() empty_cache() + + def _add_prompt_id_to_inputs(self, inputs: DataType) -> DataType: + """ + Adds a unique `prompt_id` to each input based on their `messages` content. + + Inputs with identical `messages` (assumed to be adjacent) will share the same `prompt_id`. + Each input also gets a unique `request_id` for vLLM request tracking. + + Args: + inputs (DataType): A list of dictionaries, each containing a 'messages' key. + + Returns: + DataType: The input list with each item containing new 'prompt_id' and 'request_id' fields. + + Example: + >>> inputs = [ + ... {"messages": [{"role": "user", "content": "hello"}], "data": 1}, + ... {"messages": [{"role": "user", "content": "hello"}], "data": 2}, + ... {"messages": [{"role": "assistant", "content": "hi"}], "data": 3}, + ... ] + >>> self._add_prompt_id_to_inputs(inputs) + [ + {"messages": [...], "data": 1, "prompt_id": "a1b2c3...", "request_id": "req1"}, + {"messages": [...], "data": 2, "prompt_id": "a1b2c3...", "request_id": "req2"}, + {"messages": [...], "data": 3, "prompt_id": "d4e5f6...", "request_id": "req3"}, + ] + """ + if not inputs: + return inputs + + prev_messages = inputs[0].get('messages') + current_prompt_id = str(uuid.uuid4()) + inputs[0]['prompt_id'] = current_prompt_id + inputs[0]['request_id'] = str(uuid.uuid4()) # Each request gets a unique ID + + for i in range(1, len(inputs)): + messages = inputs[i]['messages'] + if messages == prev_messages: + inputs[i]['prompt_id'] = current_prompt_id + else: + prev_messages = messages + current_prompt_id = str(uuid.uuid4()) + inputs[i]['prompt_id'] = current_prompt_id + # Each request always gets a unique request_id, regardless of prompt_id + inputs[i]['request_id'] = str(uuid.uuid4()) + + return inputs + + def _server_rollout(self, inputs: DataType, request_config: RequestConfig, + is_global_inputs: bool) -> List[RolloutOutput]: + """ + Perform rollout inference using vLLM server mode. + + Args: + inputs: List of input data to be processed + request_config: Configuration dictionary for the inference request + is_global_inputs: Flag indicating whether inputs are shared across all processes (async-generate) + + Returns: + List of RolloutOutput objects containing inference results + For non-global inputs(async-generate), returns only the portion assigned to this process. 
+ + Notes: + - async engine with multi-turn scenarios, the outputs count may exceed inputs count + - For distributed inputs, outputs are scattered to processes + - Main process coordinates inference and broadcasts outputs to other processes + """ + # Convert inputs to inference requests + infer_requests = self.inputs2requests(inputs) + + if is_global_inputs: + per_device_size = len(infer_requests) // self.accelerator.num_processes + # for async generate, data have been pre-gathered to avoid potential communicate operator + all_requests = infer_requests + all_requests_lengths = [per_device_size] + [0] * (self.accelerator.num_processes - 1) + else: + all_requests = gather_object(infer_requests) + all_requests_lengths = gather_object([len(infer_requests)]) + + if not any(requests for requests in all_requests): + return [] + + # TODO: Check flatten + if self.accelerator.is_main_process: + all_outputs: List[RolloutOutput] = self._engine_infer( + infer_requests=all_requests, request_config=request_config) + + # Handle async engine the outputs count may exceed inputs count + if self.vllm_use_async_engine: + outputs_count = [len(all_outputs)] if self.accelerator.is_main_process else [0] + outputs_count = gather_object(outputs_count)[0] # Broadcast count to all processes + + # Initialize empty outputs for non-main processes + if not self.accelerator.is_main_process: + all_outputs = [None] * outputs_count + + # Distribute outputs to all processes for non-global inputs + if not is_global_inputs: + all_outputs = broadcast_object_list(all_outputs, from_process=0) + + # Calculate slice for this process's outputs + if not self.vllm_use_async_engine and self.multi_turn_scheduler: + # Special handling for colocated + multi-turn inference with varying request counts + start_idx = sum(all_requests_lengths[:self.accelerator.process_index]) + end_idx = start_idx + all_requests_lengths[self.accelerator.process_index] + process_slice = slice(start_idx, end_idx) + outputs = all_outputs[process_slice] + else: + # Standard equal distribution case + outputs = self.get_even_process_data(all_outputs) + + else: + # For global inputs, only main process keeps outputs + outputs = outputs if self.accelerator.is_main_process else [] + + return outputs + + def _colocate_rollout(self, inputs: DataType, request_config: RequestConfig) -> List[RolloutOutput]: + """ + Perform co-located rollout inference with PTEngine or vLLMEngine(TP supported). + + Args: + inputs: Input data for the current process + request_config: Configuration parameters for the inference request + + Returns: + List[RolloutOutput]: Inference results for this process's portion of inputs + + Notes: + - For tensor parallel groups (vllm_tensor_parallel_size > 1): + * Gathers inputs from all ranks in the tensor parallel group + * Each rank processes the full input set but keeps only its assigned portion + * Ensures consistent seeds within TP groups for synchronization + - In single-process mode, directly processes the inputs + """ + # Handle tensor parallel group processing + if self.vllm_tensor_parallel_size > 1: + # Gather prompts from all ranks in the TP group and flatten. + # Each rank starts with its own prompts; after gathering, all ranks see the full group set. + # Note: The input sizes may differ across ranks (e.g., in multi-turn scenarios, + # the amount of data each rank continues to process may vary). 
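Because ranks in a TP group may hold different numbers of requests, the post-gather slicing used in both server and colocate modes reduces to the following standalone sketch (names are ours, the distributed gather itself is assumed to have already produced the flat list).

from typing import List, Sequence

def slice_for_rank(flat_outputs: Sequence, lengths: List[int], rank: int) -> Sequence:
    # Each rank keeps only the span matching the inputs it originally contributed.
    start = sum(lengths[:rank])
    return flat_outputs[start:start + lengths[rank]]

print(slice_for_rank(list('abcdef'), lengths=[2, 1, 3], rank=2))  # ['d', 'e', 'f']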
+ + # Step 1: Gather input lengths from all ranks in the TP group + local_rank_in_group = torch.distributed.get_rank(group=self.tp_group) + local_input_length = len(inputs) + all_input_lengths = [None] * self.vllm_tensor_parallel_size + torch.distributed.all_gather_object(all_input_lengths, local_input_length, group=self.tp_group) + + # Calculate slice indices for this rank's outputs + start_idx = sum(all_input_lengths[:local_rank_in_group]) + end_idx = start_idx + all_input_lengths[local_rank_in_group] + + # Step 2: Gather actual inputs from all TP group ranks + gathered_inputs = [None for _ in range(self.vllm_tensor_parallel_size)] + torch.distributed.all_gather_object(gathered_inputs, inputs, group=self.tp_group) + + # Flatten the gathered inputs + inputs = [p for sublist in gathered_inputs for p in sublist] + + # Critical seed configuration for TP groups: + # 1. Same seed within TP group - ensures synchronization and prevents hangs + # 2. Different seeds across TP groups - avoids duplicate generations + outputs: List[RolloutOutput] = self._engine_infer(infer_requests=inputs, request_config=request_config) + + # For TP groups, each rank keeps only its assigned portion of outputs + if self.vllm_tensor_parallel_size > 1: + outputs = outputs[start_idx:end_idx] + + return outputs + + def inputs2requests(self, inputs: DataType) -> List[RolloutInferRequest]: + """ + Convert raw input data into RolloutInferRequest objects with proper data processing. + + Args: + inputs: List of raw input dictionaries containing messages and multimedia data + + Returns: + List[RolloutInferRequest]: Processed inference request objects ready for engine + + Processing includes: + - Image data conversion (bytes to base64, path handling) + - Field filtering based on request metadata requirements + - UUID assignment using unique request_id for vLLM request tracking + - Optional preservation of additional fields for multi-turn async scenarios + """ + + def _process_image_data(image_data: Union[dict, str]) -> str: + """Convert image data from various formats into standardized representation. 
+ + Args: + image_data: Either a dict with 'bytes' or 'path', or a direct string path + + Returns: + str: Base64 encoded image data or original file path + """ + if isinstance(image_data, dict): + if image_data.get('bytes'): + return base64.b64encode(image_data['bytes']).decode('utf-8') + if image_data.get('path'): + return image_data['path'] + return image_data + + if not inputs: + return [] + + # Define core metadata fields required for all requests + REQUEST_METADATA_FIELDS = ['messages', 'images', 'audios', 'videos', 'objects', 'uuid'] + requests_dicts = [] + + for data in inputs: + # Extract required metadata fields + request_data = {key: data[key] for key in REQUEST_METADATA_FIELDS if key in data} + request_data['uuid'] = data['request_id'] # Use unique request_id for vLLM + # Preserve additional fields for multi-turn async scenarios + if self.args.vllm_server_pass_dataset: + # data_dict is already concatenated inside async engine + extra_fields = {k: v for k, v in data.items() if k not in REQUEST_METADATA_FIELDS} + if extra_fields: + request_data['data_dict'] = extra_fields + elif self.multi_turn_scheduler: + # Concatenate data_dict here + base_data_dict = {} + if 'data_dict' in data: + if isinstance(data['data_dict'], dict): + base_data_dict = data['data_dict'] + else: + raise ValueError('data_dict exists but is not a dictionary') + # Add fields that are not in metadata fields and not 'data_dict' + extra_data = {k: v for k, v in data.items() if k not in REQUEST_METADATA_FIELDS and k != 'data_dict'} + # Merge additional fields and existing data_dict + final_data_dict = {**extra_data, **base_data_dict} + request_data['data_dict'] = final_data_dict if final_data_dict else {} + + requests_dicts.append(request_data) + + # Process image data in each request + for request in requests_dicts: + if 'images' in request and request['images']: + request['images'] = ([_process_image_data(img) for img in request['images']] if isinstance( + request['images'], list) else _process_image_data(request['images'])) + + # Convert dictionaries to formal request objects + return [from_dict(RolloutInferRequest, request_data) for request_data in requests_dicts] + + def _preprocess_inputs(self, inputs: DataType) -> DataType: + """Preprocess input data before inference. + + Args: + inputs: List of input dictionaries containing conversation messages + + Returns: + Processed inputs with: + - Added prompt IDs for grouping (same messages share same prompt_id) + - Added unique request IDs for vLLM request tracking + - Removed existing assistant responses from messages + + Processing Steps: + 1. Adds prompt IDs and unique request IDs to each input + 2. Cleans each message sequence by removing existing assistant responses + """ + processed_inputs = self._add_prompt_id_to_inputs(inputs) + + for input_item in processed_inputs: + remove_response(input_item['messages']) + + return processed_inputs + + def _postprocess_rollout_outputs(self, inputs: DataType, outputs: List[RolloutOutput]) -> DataType: + """ + Postprocess rollout outputs by merging them back into the input data structures. + + Depending on the mode (async or sync), it either matches inputs by request_id + or assumes a one-to-one correspondence. 
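Before the implementation below, a minimal sketch of the request-id join that `_postprocess_rollout_outputs` performs in async mode; plain dicts stand in for the trainer's input rows and `RolloutOutput` objects, and all records here are made up:

```python
# Minimal sketch of joining rollout outputs back onto their inputs by request id
# (async mode); sync mode is a plain zip over equal-length lists instead.
from copy import deepcopy

inputs = [
    {'request_id': 'a', 'messages': [{'role': 'user', 'content': 'hi'}]},
    {'request_id': 'b', 'messages': [{'role': 'user', 'content': 'yo'}]},
]
outputs = [  # with the async engine, several outputs may share one request id
    {'id': 'b', 'completion': 'hello b'},
    {'id': 'a', 'completion': 'hello a'},
    {'id': 'a', 'completion': 'hello a, turn 2'},
]

id2input = {row['request_id']: row for row in inputs}
results = []
for out in outputs:
    row = deepcopy(id2input[out['id']])  # copy so outputs sharing an id do not share state
    row['messages'].append({'role': 'assistant', 'content': out['completion']})
    results.append(row)

print(len(results))  # 3 -> the output count can exceed the input count
```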
+        """
+
+        def merge_output_input_data(input_data: Dict[str, Union[torch.Tensor, Any]], output: RolloutOutput):
+            response = output.response
+            choice = response.choices[0]
+
+            # Step 1: Update or append assistant message
+            if output.messages:
+                input_data['messages'] = output.messages  # Override full message history
+            else:
+                # Full history not provided: append the new assistant turn
+                messages = input_data['messages']
+                remove_response(messages)
+                messages.append({'role': 'assistant', 'content': choice.message.content})
+
+            # Step 2: Add token IDs and loss mask
+            if output.response_token_ids:
+                input_data['response_token_ids'] = output.response_token_ids
+                if output.response_loss_mask:
+                    input_data['response_loss_mask'] = output.response_loss_mask
+            else:
+                if not self.multi_turn_scheduler:
+                    # Single turn: take token ids directly from the engine output instead of re-tokenizing
+                    input_data['response_token_ids'] = output.response.choices[0].token_ids
+
+            # Step 3: Attach rollout extra info
+            if output.rollout_infos:
+                input_data['rollout_infos'] = output.rollout_infos
+
+            # Step 4: Store finish reason (used for truncation filters etc.)
+            input_data['finish_reason'] = choice.finish_reason
+            input_data['is_truncated'] = choice.finish_reason == 'length'
+
+            return input_data
+
+        # Async engine mode: match by request_id
+        if self.vllm_use_async_engine:
+            results = []
+            id2inputs = {}
+            for input_data in inputs:
+                request_id = input_data['request_id']
+                if request_id not in id2inputs:
+                    id2inputs[request_id] = deepcopy(input_data)
+
+            for output in outputs:
+                request_id = output.response.id
+                assert request_id in id2inputs, f'Request ID {request_id} not found in inputs'
+                input_data = deepcopy(id2inputs[request_id])
+                results.append(merge_output_input_data(input_data, output))
+
+            return results
+        else:
+            # Sync mode: simple zip merge
+            assert len(inputs) == len(outputs)
+            return [
+                merge_output_input_data(deepcopy(input_data), output) for input_data, output in zip(inputs, outputs)
+            ]
+
+    def _sync_multi_turn_infer(self, inputs: DataType, first_turn_rollout_outputs: List[RolloutOutput],
+                               request_config: RequestConfig) -> List[RolloutOutput]:
+        """
+        Handles multi-turn inference when not using the async engine.
+
+        This method iteratively rolls out turns until all dialogues are finished
+        according to the multi_turn_scheduler.
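Before the method body below, a schematic of that turn-by-turn loop. The `generate` stub and `check_finished` stopping rule are purely illustrative stand-ins, not the trainer's engine or `MultiTurnScheduler` API:

```python
# Schematic of a multi-turn rollout loop: keep rolling out pending dialogs,
# park finished ones in their original slots, stop when nothing is pending.
from typing import Dict, List


def generate(messages: List[Dict[str, str]]) -> str:
    # Stand-in for one engine rollout turn.
    return f'answer after {len(messages)} messages'


def check_finished(messages: List[Dict[str, str]], current_turn: int, max_turns: int = 3) -> bool:
    # Toy stopping rule: stop after max_turns assistant replies.
    return current_turn >= max_turns


dialogs = [[{'role': 'user', 'content': 'q1'}], [{'role': 'user', 'content': 'q2'}]]
finished = [None] * len(dialogs)  # preallocated to preserve input order
pending = list(enumerate(dialogs))
current_turn = 1
while pending:
    next_pending = []
    for idx, messages in pending:
        messages.append({'role': 'assistant', 'content': generate(messages)})
        if check_finished(messages, current_turn):
            finished[idx] = messages
        else:
            messages.append({'role': 'user', 'content': 'follow-up'})
            next_pending.append((idx, messages))
    pending = next_pending
    current_turn += 1

print([len(m) for m in finished])  # every finished dialog ends with an assistant turn
```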
+ """ + orig_size = len(inputs) + rollout_outputs: List[RolloutOutput] = [None] * orig_size # Preallocate to preserve order + + # Attach index to inputs for tracking + for i, input_data in enumerate(inputs): + input_data['index'] = i + + current_turn = 1 + outputs = first_turn_rollout_outputs + while True: + has_local_data = bool(len(inputs) > 0) + has_global_data = gather_object([has_local_data]) + if not any(has_global_data): + break + + for i, output in enumerate(outputs): + input_data = deepcopy(inputs[i]) + if output and output.messages: + messages = output.messages + else: + response = output.response + choice = response.choices[0] + messages = input_data['messages'] + if (current_turn == 1 or not messages[-1]['content'] or messages[-1]['content'] == ''): + remove_response(messages) + messages.append({'role': 'assistant', 'content': choice.message.content}) + + input_data['messages'] = messages + index = input_data['index'] + rollout_outputs[index] = output + rollout_outputs[index].messages = messages + + # Determine which dialogues are finished + should_stops = [ + self.multi_turn_scheduler.check_finished(req, output.response.choices[0], current_turn) + for req, output in zip(self.inputs2requests(inputs), outputs) + ] + + # Prepare pending inputs for next turn + pending_inputs = [] + for stop, _input, output in zip(should_stops, inputs, outputs): + if stop: + continue + index = _input['index'] + step_result = self.multi_turn_scheduler.step( + self.inputs2requests([_input])[0], output.response.choices[0], current_turn) + + if step_result['response_token_ids']: + rollout_outputs[index].response_token_ids.append(step_result['response_token_ids']) + if step_result['response_loss_mask']: + rollout_outputs[index].response_loss_mask.append(step_result['response_loss_mask']) + + if step_result['rollout_infos']: + rollout_outputs[index].rollout_infos.update(step_result['rollout_infos']) + + pending_input = {**asdict(step_result['infer_request']), 'index': index} + pending_inputs.append(pending_input) + + inputs = pending_inputs + current_turn += 1 + + # Rollout for the next turn + outputs = self._rollout(inputs if has_local_data else [], request_config) + + assert all(o is not None for o in rollout_outputs) + return rollout_outputs + + def get_even_process_data(self, global_data: List[T]) -> List[T]: + """ + Evenly splits `global_data` among all processes. + + Each process receives a contiguous chunk of data. If `len(global_data)` is not + perfectly divisible by the number of processes, the first `remainder` processes + will receive one additional item. + + Args: + global_data (List[T]): The full list of data to be distributed. + + Returns: + List[T]: The subset of `global_data` assigned to this process. 
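The implementation follows below. For a concrete feel of the split: with 10 items across 4 processes, `base_size` is 2 and `remainder` is 2, so the chunk sizes are 3, 3, 2, 2. A standalone sketch of the same index arithmetic (outside the trainer, so without the padding bookkeeping):

```python
# Standalone sketch of the even, contiguous split used by get_even_process_data:
# the first `remainder` processes receive one extra item.
from typing import List, TypeVar

T = TypeVar('T')


def even_chunk(global_data: List[T], proc_idx: int, num_procs: int) -> List[T]:
    base_size, remainder = divmod(len(global_data), num_procs)
    if proc_idx < remainder:
        start = proc_idx * (base_size + 1)
        end = start + base_size + 1
    else:
        start = remainder * (base_size + 1) + (proc_idx - remainder) * base_size
        end = start + base_size
    return global_data[start:end]


data = list(range(10))
print([even_chunk(data, p, 4) for p in range(4)])
# [[0, 1, 2], [3, 4, 5], [6, 7], [8, 9]]
```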
+ """ + num_procs = self.accelerator.num_processes + proc_idx = self.accelerator.process_index + total = len(global_data) + + base_size = total // num_procs + remainder = total % num_procs + + # Calculate the number of samples that need to be padded + # This ensures all processes have the same number of samples for gather operations + self.rollout_pad_count = 0 + if remainder > 0 and proc_idx >= remainder: + # Processes with extra samples need padding + self.rollout_pad_count = 1 + + if proc_idx < remainder: + start = proc_idx * (base_size + 1) + end = start + base_size + 1 + else: + start = remainder * (base_size + 1) + (proc_idx - remainder) * base_size + end = start + base_size + + return global_data[start:end] diff --git a/swift/trainers/rlhf_trainer/utils.py b/swift/trainers/rlhf_trainer/utils.py index ebd00096df..7a12391ba7 100644 --- a/swift/trainers/rlhf_trainer/utils.py +++ b/swift/trainers/rlhf_trainer/utils.py @@ -4,7 +4,7 @@ from contextlib import contextmanager from io import BytesIO from types import MethodType -from typing import Any, Dict, Optional +from typing import TYPE_CHECKING, Any, List, Optional, Union import torch import torch.nn.functional as F @@ -20,6 +20,9 @@ if is_swanlab_available(): import swanlab +if TYPE_CHECKING: + from swift.llm.utils import Messages + def round_robin(num_reqs, num_workers): """Distribute requests evenly across workers using round-robin algorithm. @@ -253,3 +256,67 @@ def load_pil_img(img) -> Image: return Image.open(img['path']) else: raise ValueError("Image dictionary must contain either 'bytes' or 'path' key.") + + +def replace_assistant_response_with_ids(messages: 'Messages', + completion_ids: List[Union[int, List[int]]]) -> 'Messages': # noqa + """ + Replaces the content of assistant messages with the provided completion IDs. + + This function processes messages in reverse order and replaces the content of + assistant messages with the given completion IDs. If completion_ids is a flat + list of integers, it will be treated as a single completion sequence. + + Args: + messages: List of message dictionaries containing conversation history. + completion_ids: Either: + - A single list of token IDs (e.g., [1, 2, 3]) + - A list of completion sequences (e.g., [[1, 2], [3, 4]]) + + Returns: + The modified messages list with assistant responses replaced by token IDs. + + Example: + >>> messages = [{'role': 'user', 'content': 'Hello'}, + ... 
{'role': 'assistant', 'content': 'Hi there'}] + >>> replace_assistant_response_with_ids(messages, [1, 2, 3]) + [{'role': 'user', 'content': 'Hello'}, + {'role': 'assistant', 'content': [1, 2, 3]}] + """ + # Normalize input to always be list of lists + if isinstance(completion_ids[0], int): + completion_ids = [completion_ids] + + remaining_completions = len(completion_ids) + completion_index = 0 + + for message in reversed(messages): + if message['role'] != 'assistant': + continue + + if completion_index >= remaining_completions: + break + + # Assign completion IDs (starting from last) + message['content'] = completion_ids[-1 - completion_index] + completion_index += 1 + + return messages + + +def patch_save_last_checkpoint(): + import trl + from packaging import version + if version.parse(trl.__version__) >= version.parse('0.20'): + return + + # patch to fix save last_checkpoint https://github.com/modelscope/ms-swift/pull/4969 + from trl.trainer.grpo_trainer import RepeatSampler + if not hasattr(RepeatSampler, 'old_len_func'): + origin_len_func = RepeatSampler.__len__ + + def patched_len(self) -> int: + return (self.num_samples // self.batch_size) * self.batch_size * self.mini_repeat_count * self.repeat_count + + RepeatSampler.__len__ = patched_len + RepeatSampler.old_len_func = origin_len_func diff --git a/swift/trainers/rlhf_trainer/vllm_client.py b/swift/trainers/rlhf_trainer/vllm_client.py index 1f56921cf4..89c0dcfecd 100644 --- a/swift/trainers/rlhf_trainer/vllm_client.py +++ b/swift/trainers/rlhf_trainer/vllm_client.py @@ -9,15 +9,13 @@ import requests import torch -from dacite import from_dict from packaging import version from requests import ConnectionError from torch import nn from transformers.utils import is_torch_cuda_available from swift.llm import AdapterRequest, RolloutInferRequest, Template -from swift.llm.infer.protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, GymRolloutResponseChoice, - RequestConfig, RolloutResponseChoice) +from swift.llm.infer.protocol import RequestConfig, RolloutOutput from swift.plugin import Metric from swift.utils import is_trl_available, is_vllm_ascend_available, is_vllm_available @@ -151,7 +149,7 @@ def process_chunk(i, chunk): return resp_data = response.json() - results[i] = self.parse_resp_data(resp_data) + results[i] = [RolloutOutput.parse_obj(resp) for resp in resp_data] except Exception as e: errors[i] = e @@ -278,19 +276,3 @@ def close_communicator(self): logger.warning(f'Server {i} close failed: {response.text}') except Exception as e: logger.warning(f'Error closing server {i} communicator: {str(e)}') - - def parse_resp_data(self, resp_data): - if self.use_gym_env: - choice_cls = GymRolloutResponseChoice - elif self.use_async_engine: - choice_cls = RolloutResponseChoice - else: - choice_cls = ChatCompletionResponseChoice - result = [ - ChatCompletionResponse( - choices=[from_dict(data_class=choice_cls, data=c) for c in resp['choices']], - **{k: v - for k, v in resp.items() if k != 'choices'}) for resp in resp_data - ] - - return result diff --git a/swift/trainers/sequence_parallel/utils.py b/swift/trainers/sequence_parallel/utils.py index d6cb7e90ff..32a2fe89a3 100644 --- a/swift/trainers/sequence_parallel/utils.py +++ b/swift/trainers/sequence_parallel/utils.py @@ -19,7 +19,7 @@ if TYPE_CHECKING: try: from ..rlhf_trainer import GRPOTrainer - from ..rlhf_trainer.grpo_trainer import InputsType + from ..rlhf_trainer.grpo_trainer import DataType except ImportError: pass # Conditional import for profiling decorator @@ 
-513,7 +513,7 @@ def _padding_free_output_hook(module, args, kwargs, result): def _get_per_token_logps_and_entropies_grpo( self: 'GRPOTrainer', model: torch.nn.Module, - inputs: 'InputsType', + inputs: 'DataType', sp_instance: SequenceParallel, compute_entropy: bool = False) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Get per token logps for GRPO sequence parallel training""" diff --git a/swift/utils/__init__.py b/swift/utils/__init__.py index 1ebad29dfd..df96119a93 100644 --- a/swift/utils/__init__.py +++ b/swift/utils/__init__.py @@ -16,4 +16,5 @@ show_layers, time_synchronize, unwrap_model_for_generation) from .utils import (add_version_to_work_dir, check_json_format, copy_files_by_pattern, deep_getattr, find_free_port, format_time, get_env_args, import_external_file, json_parse_to_dict, lower_bound, parse_args, - patch_getattr, read_multi_line, seed_everything, split_list, subprocess_run, test_time, upper_bound) + patch_getattr, read_multi_line, remove_response, seed_everything, split_list, subprocess_run, + test_time, upper_bound) diff --git a/swift/utils/utils.py b/swift/utils/utils.py index 9287f123e0..f4fc995a60 100644 --- a/swift/utils/utils.py +++ b/swift/utils/utils.py @@ -371,3 +371,21 @@ def json_parse_to_dict(value: Union[str, Dict, None], strict: bool = True) -> Un logger.error(f"Unable to parse string: '{value}'") raise return value + + +def remove_response(messages) -> Optional[str]: + """ + Removes and returns the content of the last message if its role is 'assistant'. + + Args: + messages (List[Dict]): + A list of message dictionaries, each typically containing a 'role' and 'content' key. + + Returns: + Optional[str]: + The content of the removed 'assistant' message if present; + otherwise, returns None. The original messages list is modified in place. + """ + last_role = messages[-1]['role'] if messages else None + if last_role == 'assistant': + return messages.pop()['content']
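A quick usage note on `remove_response`: it pops the trailing assistant turn in place and returns its content, or returns None and leaves the list untouched. A doctest-style illustration, assuming the `swift.utils` export added above (the messages themselves are made up):

```python
# Illustration of remove_response: in-place pop of a trailing assistant turn.
from swift.utils import remove_response

messages = [
    {'role': 'user', 'content': 'Hello'},
    {'role': 'assistant', 'content': 'Hi there'},
]
print(remove_response(messages))  # 'Hi there'
print(messages)                   # [{'role': 'user', 'content': 'Hello'}]
print(remove_response(messages))  # None -- last message is no longer an assistant turn
```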