[BREAKING] Refactor Scheduler and GRPOTrainer for Flexible Multi-Turn Training #5307

Draft · wants to merge 32 commits into main from hjh0119:grpo-use-ids

Commits (32)
b27fb9b
wip
hjh0119 Jul 29, 2025
3a9f23c
wip
hjh0119 Jul 29, 2025
061a2d5
revert prompt_ids
hjh0119 Jul 30, 2025
b1ac613
Merge remote-tracking branch 'origin' into grpo-use-ids
hjh0119 Jul 31, 2025
8d2b170
remove tokenizer in reward
hjh0119 Aug 1, 2025
3da665b
Merge remote-tracking branch 'origin' into grpo-use-ids
hjh0119 Aug 1, 2025
0dea74d
encode ids
hjh0119 Aug 1, 2025
1114cea
wip replace ids
hjh0119 Aug 3, 2025
ebf1b35
fix adv
hjh0119 Aug 4, 2025
67d124a
Merge remote-tracking branch 'origin' into grpo-use-ids
hjh0119 Aug 4, 2025
ffcd9b4
wip
hjh0119 Aug 4, 2025
05ae143
wip
hjh0119 Aug 4, 2025
fdb1ae8
wip
hjh0119 Aug 4, 2025
725e1f6
wip
hjh0119 Aug 5, 2025
dcc9052
wip
hjh0119 Aug 5, 2025
f72712e
wip
hjh0119 Aug 6, 2025
a8aeb73
refactor v1
hjh0119 Aug 8, 2025
8555f31
rename completion id
hjh0119 Aug 8, 2025
4b30574
Merge remote-tracking branch 'origin' into grpo-use-ids
hjh0119 Aug 8, 2025
4c0ada9
fix typo & bugs
hjh0119 Aug 8, 2025
b9e4c04
compute loss for dynamic batch size
hjh0119 Aug 10, 2025
6ad2700
Merge remote-tracking branch 'origin' into grpo-use-ids
hjh0119 Aug 10, 2025
eea4485
fix tiny bugs
hjh0119 Aug 11, 2025
5b927de
dynamic rollout advantages
hjh0119 Aug 11, 2025
b0c52b7
fix score_completions
hjh0119 Aug 11, 2025
e25c2e4
fix split mini batch
hjh0119 Aug 11, 2025
bf035b3
docstring for split mini batches
hjh0119 Aug 11, 2025
60fb903
fix gather device
hjh0119 Aug 11, 2025
5108fce
fix rollout async infer
hjh0119 Aug 12, 2025
7deeec3
Merge remote-tracking branch 'origin' into grpo-use-ids
hjh0119 Aug 13, 2025
69f6e43
Merge branch 'grpo-use-ids' of github.com:hjh0119/swift into grpo-use…
hjh0119 Aug 13, 2025
0812129
thinking tips scheduler
hjh0119 Aug 13, 2025
184 changes: 18 additions & 166 deletions swift/llm/infer/infer_engine/grpo_vllm_engine.py
@@ -6,14 +6,14 @@

import torch
from tqdm.asyncio import tqdm_asyncio
from vllm.outputs import RequestOutput

from swift.llm import InferRequest, RolloutInferRequest, Template, VllmEngine
from swift.plugin import Metric, multi_turns
from swift.plugin.context_manager import ContextManager, context_managers
from swift.plugin.env import Env, envs
from swift.plugin.multi_turn import MultiTurnScheduler
from ..protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage, GymRolloutResponseChoice,
RequestConfig, RolloutResponseChoice)
from ..protocol import ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage, RequestConfig, RolloutOutput
from .utils import AdapterRequest

try:
@@ -150,19 +150,27 @@ def infer(
template: Optional[Template] = None,
use_tqdm: Optional[bool] = None,
adapter_request: Optional[AdapterRequest] = None,
) -> List[ChatCompletionResponse]:
assert not self.use_async_engine, 'for Async Engine, use infer_async instead'
return super().infer(
) -> List[RolloutOutput]:
res = super().infer(
infer_requests,
request_config,
metrics,
template=template,
use_tqdm=use_tqdm,
adapter_request=adapter_request,
)
if not isinstance(res, list):
res = [res]
for i, result in enumerate(res):
if not isinstance(result, RolloutOutput):
if not isinstance(result, ChatCompletionResponse):
raise TypeError('Result must be a ChatCompletionResponse or RolloutOutput instance.')
res[i] = RolloutOutput(response=result)

return res

async def async_infer(self,
infer_requests: List[Union[RolloutInferRequest, Dict[str, Any]]],
infer_requests: List[InferRequest],
request_config: Optional[RequestConfig] = None,
metrics: Optional[List[Metric]] = None,
*,
@@ -172,150 +180,10 @@ async def async_infer(self,
request_config = RequestConfig()
assert request_config.n == 1

async def _infer_async_single(infer_request: Union[RolloutInferRequest, Dict[str, Any]],
request_config: Optional[RequestConfig] = None,
**kwargs):
if isinstance(infer_request, Dict):
infer_request = RolloutInferRequest(**infer_request)

# Route to appropriate sampling controller
if self.use_gym_env:
return await self._gym_sampling_controller(infer_request, request_config, **kwargs)
else:
return await self._multi_turn_sampling_controller(infer_request, request_config, **kwargs)

tasks = [_infer_async_single(infer_request, request_config, **kwargs) for infer_request in infer_requests]
tasks = [self.infer_async(infer_request, request_config, **kwargs) for infer_request in infer_requests]
if use_tqdm is None:
use_tqdm = len(infer_requests) > 1
return await self._batch_infer_stream(tasks, request_config.stream, use_tqdm, metrics)

async def _gym_sampling_controller(self, infer_request: RolloutInferRequest, request_config: RequestConfig,
**kwargs) -> ChatCompletionResponse:
"""Gym environment-based sampling controller."""
# Create environment and context manager
env_config = infer_request.data_dict.get('env_config', {})
env = self._create_env(env_config)
ctx_config = infer_request.data_dict.get('ctx_config', {})
context_manager = self._create_context_manager(ctx_config)

try:
# Environment reset
observation, info, system_message = await env.reset(infer_request)

# Initialize conversation
messages = []
if system_message:
messages.append({'role': 'system', 'content': system_message})
messages.append({'role': 'user', 'content': observation})

current_request = deepcopy(infer_request)
current_turn = 1
done = False
total_reward = 0.0
step_rewards = []
trajectory_id = f'{id(infer_request)}_{hash(str(infer_request))}'
trajectory_info = [info]

while True:
# Apply context management
messages = context_manager.manage_context(messages, trajectory_id)
current_request.messages = messages
# Remove any previous assistant response for generation
InferRequest.remove_response(current_request.messages)

# Generate LLM response
result: ChatCompletionResponse = await self.infer_async(current_request, request_config, **kwargs)
result_choice: RolloutResponseChoice = result.choices[0]

completion = result_choice.message.content
messages.append({'role': 'assistant', 'content': completion})

# Environment step
next_observation, reward, done, step_info = await env.step(deepcopy(messages))

# Accumulate rewards
total_reward += reward
step_rewards.append(reward)
trajectory_info.append(step_info)

if done or current_turn > self.max_turns:
break

messages.append({'role': 'user', 'content': next_observation})
current_request.messages = messages
current_turn += 1

# Create final result with gym-specific information
final_choice = GymRolloutResponseChoice(
index=result_choice.index,
message=result_choice.message,
finish_reason=result_choice.finish_reason,
logprobs=result_choice.logprobs,
messages=messages,
trajectory_id=trajectory_id,
total_reward=total_reward,
step_rewards=step_rewards,
trajectory_info=trajectory_info)

return ChatCompletionResponse(
model=self.model_name, choices=[final_choice], usage=result.usage, id=f'gym_{trajectory_id}')

finally:
await self._close_env_async(env)

async def _multi_turn_sampling_controller(self, infer_request: RolloutInferRequest, request_config: RequestConfig,
**kwargs) -> ChatCompletionResponse:
"""Multi-turn scheduler-based sampling controller."""
current_request = infer_request
current_turn = 1
info_dict = {}
while True:
messages = current_request.messages
if current_turn == 1 or not messages[-1]['content']:
# If it's the first turn or the last message content is empty(dummy), remove the response
InferRequest.remove_response(messages)

result: ChatCompletionResponse = await self.infer_async(current_request, request_config, **kwargs)
result_choice: RolloutResponseChoice = result.choices[0]

completion = result_choice.message.content
if messages[-1]['role'] == 'assistant':
messages[-1]['content'] += completion
else:
messages.append({'role': 'assistant', 'content': completion})

if self.multi_turn_scheduler:
should_stop = self.multi_turn_scheduler.check_finished(current_request, result_choice, current_turn)
else:
should_stop = True

if self.max_turns:
should_stop = should_stop or (current_turn >= self.max_turns)

if should_stop:
result_choice.messages = messages
info_dict['num_turns'] = current_turn
for key, value in info_dict.items():
if key in ['images', 'audios', 'videos']:
value = MultiModalRequestMixin.to_base64(value)
if hasattr(result_choice, key):
setattr(result_choice, key, value)
else:
result_choice.multi_turn_infos[key] = value
return result

ret = self.multi_turn_scheduler.step(current_request, result_choice, current_turn)
if isinstance(ret, tuple):
current_request, info_dict = ret
else:
current_request = ret
info_dict = {}
assert isinstance(current_request, RolloutInferRequest)
if current_request.messages[-1]['role'] == 'assistant':
# Add a dummy response to allow engine to continue generating
current_request.messages.append({'role': 'assistant', 'content': None})

current_turn += 1
return self._batch_infer_stream(tasks, request_config.stream, use_tqdm, metrics)

async def _batch_infer_stream(self,
tasks,
@@ -339,17 +207,7 @@ async def _new_run(task):
new_tasks = [_new_run(task) for task in tasks]
return await self.batch_run(new_tasks)

async def _close_env_async(self, env: Env):
"""Asynchronously close environment."""
try:
if hasattr(env, 'close') and asyncio.iscoroutinefunction(env.close):
await env.close()
elif hasattr(env, 'close'):
env.close()
except Exception:
pass

def _create_chat_completion_response(self, result, template: Template, request_config,
def _create_chat_completion_response(self, result: 'RequestOutput', template: Template, request_config,
request_id) -> ChatCompletionResponse:
assert result is not None
num_generated_tokens = sum(len(output.token_ids) for output in result.outputs)
@@ -360,13 +218,7 @@ def _create_chat_completion_response(self, result, template: Template, request_c
response = template.decode(output.token_ids)
logprobs = self._get_logprobs(output.logprobs, output.token_ids, request_config.top_logprobs)
toolcall = self._get_toolcall(response, template)

if self.use_gym_env:
choice_cls = GymRolloutResponseChoice
elif self.use_async_engine:
choice_cls = RolloutResponseChoice
else:
choice_cls = ChatCompletionResponseChoice
choice_cls = ChatCompletionResponseChoice

token_ids = template.skip_stop_tokens(output.token_ids) if request_config.return_details else None
choice = choice_cls(
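
Note: with this refactor, GRPOVllmEngine.infer returns List[RolloutOutput] instead of List[ChatCompletionResponse]. Below is a minimal sketch (not part of this diff) of how a caller might consume that list; field access follows the RolloutOutput definition added to swift/llm/infer/protocol.py later in this PR, and the 'num_turns' key in rollout_infos is only a hypothetical example of extra metadata.

from typing import List, Optional, Tuple

from swift.llm.infer.protocol import RolloutOutput


def extract_completions(outputs: List[RolloutOutput]) -> List[Tuple[str, Optional[int]]]:
    results = []
    for out in outputs:
        # The underlying single-turn response is always available on `response`.
        completion = out.response.choices[0].message.content
        # Multi-turn rollouts may additionally return the full message history,
        # which should end with the latest assistant reply.
        if out.messages is not None:
            assert out.messages[-1]['role'] == 'assistant'
        # 'num_turns' is an assumed metadata key, not guaranteed by this PR.
        num_turns = out.rollout_infos.get('num_turns')
        results.append((completion, num_turns))
    return results
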
7 changes: 6 additions & 1 deletion swift/llm/infer/infer_engine/vllm_engine.py
@@ -512,8 +512,10 @@ async def _infer_full_async(
generation_config: SamplingParams,
adapter_request: Optional[AdapterRequest],
request_config: RequestConfig,
request_id: Optional[str] = None,
) -> Union[ChatCompletionResponse, EmbeddingResponse]:
request_id = random_uuid()
if request_id is None:
request_id = random_uuid()
result_generator = self._add_request(inputs, generation_config, request_id, adapter_request=adapter_request)
result = None
async for result in result_generator:
@@ -641,6 +643,9 @@ async def infer_async(
'adapter_request': adapter_request,
'request_config': request_config,
}
if hasattr(infer_request, 'uuid') and infer_request.uuid:
# RolloutInferRequest
kwargs.update({'request_id': infer_request.uuid})
if pre_infer_hook:
kwargs = pre_infer_hook(kwargs)
if request_config.stream:
67 changes: 50 additions & 17 deletions swift/llm/infer/protocol.py
@@ -10,7 +10,7 @@

import json
from PIL import Image
from pydantic import BaseModel
from pydantic import BaseModel, Field

from ..template import InferRequest
from ..utils import Messages, Tool
@@ -299,21 +299,6 @@ class EmbeddingResponse:
created: int = field(default_factory=lambda: int(time.time()))


@dataclass
class RolloutResponseChoice(ChatCompletionResponseChoice):
messages: Optional[Messages] = None
images: Optional[List[str]] = None
multi_turn_infos: Dict[str, Any] = field(default_factory=dict)


@dataclass
class GymRolloutResponseChoice(RolloutResponseChoice):
trajectory_id: str = None
total_reward: float = 0.0
step_rewards: List[float] = None
trajectory_info: List[Dict[str, Any]] = None


@dataclass
class CompletionResponseChoice:
index: int
@@ -325,7 +310,7 @@
@dataclass
class ChatCompletionResponse:
model: str
choices: List[Union[ChatCompletionResponseChoice, RolloutResponseChoice, GymRolloutResponseChoice]]
choices: List[ChatCompletionResponseChoice]
usage: UsageInfo
id: str = field(default_factory=lambda: f'chatcmpl-{random_uuid()}')
object: str = 'chat.completion'
@@ -339,6 +324,54 @@ def to_cmpl_response(self) -> 'CompletionResponse':
return CompletionResponse(self.model, choices, self.usage, id_, created=self.created)


class RolloutOutput(BaseModel):
"""
Output structure for rollout.

Attributes:
response (ChatCompletionResponse):
The model's response

messages (Optional[Messages]):
(Optional) Conversation history for the final rollout; required for multi-turn scenarios.
NOTE:
- If provided, this messages sequence will overwrite the original messages.
- If not provided, 'response' will be appended as the latest turn in the original messages.
- For multi-turn training, you need to manually return the updated messages, including the full history.
- The messages should include the latest assistant response as the final message.

response_token_ids (Optional[List[List[int]]]):
(Optional) Token IDs generated at each rollout turn.
If provided, the training process will skip tokenizing the response.

response_loss_mask (Optional[List[List[int]]]):
(Optional) Loss masks corresponding to each rollout turn.
If provided, the training process will skip computing loss masks for the response (as controlled by the `loss_scale` parameter). # noqa

rollout_infos (Dict[str, Any]):
(Optional) Additional rollout information. This must be JSON-serializable.
"""
response: ChatCompletionResponse
# multi turn
messages: Optional[Messages] = None
response_token_ids: List[List[int]] = Field(default_factory=list)
response_loss_mask: List[List[int]] = Field(default_factory=list)
rollout_infos: Dict[str, Any] = Field(default_factory=dict)

def model_post_init(self, __context):
# Ensure multimodal data in rollout_infos is serializable (e.g., images to base64)
super().model_post_init(__context)
self.mminfo_to_serializable()

def mminfo_to_serializable(self):
mm_keys = ['images', 'audios', 'videos']

for key, value in self.rollout_infos.items():
if key in mm_keys:
# Convert multimodal content to base64 for serialization
self.rollout_infos[key] = MultiModalRequestMixin.to_base64(value)


@dataclass
class CompletionResponse:
model: str
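
For context, here is a minimal sketch (not part of this diff) of how a multi-turn rollout backend might populate the new RolloutOutput described in the docstring above. It assumes the existing ChatMessage, ChatCompletionResponseChoice, and UsageInfo dataclasses in swift/llm/infer/protocol.py keep their current constructor fields; the conversation content, token ids, and loss masks are made-up placeholders.

from swift.llm.infer.protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage,
                                      RolloutOutput, UsageInfo)

final_answer = 'The answer is 42.'
response = ChatCompletionResponse(
    model='dummy-model',
    choices=[
        ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(role='assistant', content=final_answer),
            finish_reason='stop')
    ],
    usage=UsageInfo(prompt_tokens=0, completion_tokens=0, total_tokens=0))

output = RolloutOutput(
    response=response,
    # Full conversation history, ending with the latest assistant message;
    # when provided, it overwrites the original messages on the training side.
    messages=[
        {'role': 'user', 'content': 'What is 6 x 7?'},
        {'role': 'assistant', 'content': 'Let me check with the calculator tool.'},
        {'role': 'tool', 'content': '42'},
        {'role': 'assistant', 'content': final_answer},
    ],
    # One entry per rollout turn; training will skip re-tokenizing these responses.
    response_token_ids=[[1001, 1002, 1003], [2001, 2002]],
    # Per-turn loss masks aligned with response_token_ids (1 = include token in the loss).
    response_loss_mask=[[1, 1, 1], [1, 1]],
    # Extra JSON-serializable info; multimodal values are converted to base64 automatically.
    rollout_infos={'num_turns': 2})
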