Skip to content

Commit f104dfa

Browse files
AniZpZjsfanfanfanjsfanfanfan
authored
[sglang, rollout] feat: support sglang as rollout engine in fully async policy (verl-project#4191)
### What does this PR do? Extend the fully async policy recipe by adding SGLang as an alternative rollout engine to vLLM when using FSDP ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). 
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) --------- Co-authored-by: jsfanfanfan <2981866535@qq.com> Co-authored-by: jsfanfanfan <2981856535@qq.com> Co-authored-by: jsfanfanfan <71052636+jsfanfanfan@users.noreply.github.com>
1 parent 5053d42 commit f104dfa

File tree

12 files changed

+636
-129
lines changed

12 files changed

+636
-129
lines changed

verl/experimental/fully_async_policy/agent_loop/agent_loop.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
get_trajectory_info,
3232
)
3333
from verl.experimental.agent_loop.prometheus_utils import update_prometheus_config
34-
from verl.experimental.fully_async_policy.vllm_rollout.vllm_async_server import FullyAsyncvLLMReplica
3534
from verl.protocol import DataProto
3635
from verl.single_controller.ray import RayResourcePool, RayWorkerGroup
3736
from verl.utils.rollout_trace import (
@@ -219,7 +218,21 @@ def __init__(
219218
self.reward_model_manager = None
220219
self.reward_router_address = None
221220
self.agent_loop_workers_class = FullyAsyncAgentLoopWorker
222-
self.rollout_replica_class = FullyAsyncvLLMReplica
221+
222+
# Select rollout replica class based on rollout name
223+
rollout_name = config.actor_rollout_ref.rollout.name
224+
if rollout_name == "sglang":
225+
from verl.experimental.fully_async_policy.sglang_rollout.sglang_async_server import FullyAsyncSGLangReplica
226+
227+
self.rollout_replica_class = FullyAsyncSGLangReplica
228+
print("[FullyAsyncAgentLoopManager] SGLang replica class selected")
229+
elif rollout_name == "vllm":
230+
from verl.experimental.fully_async_policy.vllm_rollout.vllm_async_server import FullyAsyncvLLMReplica
231+
232+
self.rollout_replica_class = FullyAsyncvLLMReplica
233+
print("[FullyAsyncAgentLoopManager] vLLM replica class selected")
234+
else:
235+
raise ValueError(f"Unsupported rollout name: {rollout_name}. Supported values are 'sglang' and 'vllm'.")
223236

224237
self.rm_resource_pool = rm_resource_pool
225238
self.rollout_replicas = None
@@ -331,5 +344,27 @@ async def wake_up(self):
331344
async def sleep(self):
    """Put every rollout replica to sleep, running all sleep calls concurrently."""
    pending = [engine.sleep() for engine in self.rollout_replicas]
    await asyncio.gather(*pending)
333346

347+
async def reset_prefix_cache(self, timeout: float = 5.0):
    """Reset the prefix cache on every rollout replica, in parallel.

    Each replica's reset is bounded by ``timeout`` seconds so that a single
    hung replica cannot stall the whole manager. Timeouts and errors are
    logged and swallowed deliberately: cache reset is best-effort and must
    never abort training.

    Args:
        timeout: Per-replica time budget in seconds (default 5.0, matching
            the previous hard-coded value).
    """
    print("[FullyAsyncAgentLoopManager] Reset prefix cache ...")

    async def _reset_one(idx, replica):
        # Bound each replica individually; a failure on one replica must not
        # prevent the others from resetting.
        try:
            await asyncio.wait_for(replica.reset_prefix_cache(), timeout=timeout)
        except asyncio.TimeoutError:
            print(f"[reset_prefix_cache] TIMEOUT replica={idx} after {timeout}s")
        except Exception as e:
            print(f"[reset_prefix_cache] ERROR replica={idx}: {e!r}")

    # All exceptions are handled inside _reset_one, so a plain gather suffices
    # (the previous return_exceptions=True was redundant debug scaffolding).
    await asyncio.gather(*(_reset_one(i, replica) for i, replica in enumerate(self.rollout_replicas)))
    print("[FullyAsyncAgentLoopManager] Reset prefix cache finished")
334369
async def clear_kv_cache(self):
    """Clear the KV cache on all rollout replicas concurrently."""
    await asyncio.gather(*(replica.clear_kv_cache() for replica in self.rollout_replicas))
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
# Copyright 2025 Bytedance Ltd. and/or its affiliates
2+
# Copyright 2025 Meituan Ltd. and/or its affiliates
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import asyncio
17+
import logging
18+
import os
19+
import threading
20+
21+
import torch
22+
from omegaconf import DictConfig
23+
from ray.util.collective import collective
24+
25+
from verl.single_controller.base.decorator import Dispatch, register
26+
from verl.utils.device import get_torch_device
27+
28+
logger = logging.getLogger(__file__)
29+
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
30+
31+
32+
class BaseDetachNcclSync:
    """Mixin providing NCCL-based weight synchronization between detached
    actor workers and rollout workers (vLLM or SGLang) in the fully-async
    policy recipe.

    NOTE(review): this class reads ``self.config``, ``self._weights_info``,
    ``self._is_actor``, ``self._is_rollout`` and ``self.rollout_device_mesh``
    but never sets them — presumably the concrete worker subclass does;
    confirm against the subclasses.
    """

    # Adaptive bucketing state. These are CLASS attributes, shared by every
    # instance in the process, and are mutated in place by the classmethods
    # below.
    _bucket_size_mb = 1024.0  # current bucket size (MB) used to batch tensors per sync
    _sync_history = []  # (bucket_size_mb, sync_time) samples, newest last
    _max_history_size = 20  # cap on _sync_history length
    _last_avg_bucket_size = 1024.0  # mean bucket size actually flushed in the last sync

    def __init__(self, config: DictConfig, role: str):
        """Start a dedicated background event loop for async engine calls.

        Synchronous code submits coroutines (e.g. SGLang weight updates) to
        this loop via ``_run_async_safely``. The thread is a daemon so it
        never blocks interpreter shutdown.

        NOTE(review): ``config`` and ``role`` are accepted but not stored
        here — presumably consumed by the subclass; confirm.
        """
        self._bg_loop = asyncio.new_event_loop()
        self._bg_thread = threading.Thread(
            target=self._start_background_loop, args=(self._bg_loop,), name="rollout_actor_async_worker", daemon=True
        )
        self._bg_thread.start()
        logger.info(f"[DetachNcclSync] Background thread for SGLang sync started. PID: {os.getpid()}")

    @classmethod
    def get_bucket_size_mb(cls):
        """Return the current adaptive bucket size in MB."""
        return cls._bucket_size_mb

    @classmethod
    def get_last_avg_bucket_size(cls):
        """Return the average bucket size (MB) actually used in the last sync."""
        return cls._last_avg_bucket_size

    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True)
    def get_last_avg_bucket_size_remote(self):
        # Remote-callable variant for the controller; reads the base class
        # attribute explicitly so all subclasses report the shared value.
        return BaseDetachNcclSync._last_avg_bucket_size

    @classmethod
    def record_sync_metrics(cls, bucket_size_mb, sync_time):
        """Dynamically adjust the bucket size based on past synchronization times.

        Records one (bucket size, sync time) sample, then steers
        ``_bucket_size_mb`` with a simple hill-climbing rule: while fewer
        than 4 samples exist, grow aggressively (x1.5); afterwards, compare
        the last two samples against the two before them and nudge the
        bucket up or down, with step size scaled by how much the sync time
        changed. The result is clamped to [512 MB, 8192 MB].
        """
        # bucket_size_mb may arrive as a single-element list from a remote call.
        bucket_size_mb_value = bucket_size_mb[0] if isinstance(bucket_size_mb, list) else bucket_size_mb
        print(f"[DetachNcclSync] sync_metrics: bucket_size_mb={bucket_size_mb_value:.2f}MB, sync_time={sync_time:.2f}s")
        cls._sync_history.append((bucket_size_mb_value, sync_time))
        if len(cls._sync_history) > cls._max_history_size:
            cls._sync_history.pop(0)

        MIN_BUCKET_SIZE_MB = 512
        MAX_BUCKET_SIZE_MB = 8192  # 8GB

        if len(cls._sync_history) < 4:
            # Warm-up: not enough history to compare trends, grow greedily.
            cls._bucket_size_mb = min(MAX_BUCKET_SIZE_MB, cls._bucket_size_mb * 1.5)
        else:
            times = [t for _, t in cls._sync_history]
            buckets = [b for b, _ in cls._sync_history]
            recent_avg_time = sum(times[-2:]) / 2
            previous_avg_time = sum(times[-4:-2]) / 2
            recent_avg_bucket = sum(buckets[-2:]) / 2
            previous_avg_bucket = sum(buckets[-4:-2]) / 2

            performance_improved = recent_avg_time < previous_avg_time
            bucket_increased = recent_avg_bucket > previous_avg_bucket
            time_change_ratio = (
                abs(recent_avg_time - previous_avg_time) / previous_avg_time if previous_avg_time > 0 else 0.0
            )

            # Larger observed time swings warrant larger adjustment steps.
            if time_change_ratio > 0.2:
                increase_step, decrease_step = 1.2, 0.8
            elif time_change_ratio > 0.1:
                increase_step, decrease_step = 1.1, 0.9
            elif time_change_ratio > 0.05:
                increase_step, decrease_step = 1.05, 0.95
            else:
                increase_step, decrease_step = 1.02, 0.98

            # Keep moving in the direction that correlated with improvement:
            # (bigger bucket helped -> grow) or (smaller bucket hurt -> grow);
            # otherwise shrink.
            should_increase = (performance_improved and bucket_increased) or (
                not performance_improved and not bucket_increased
            )
            step = increase_step if should_increase else decrease_step
            new_size = cls._bucket_size_mb * step
            cls._bucket_size_mb = min(MAX_BUCKET_SIZE_MB, max(MIN_BUCKET_SIZE_MB, new_size))

    def _start_background_loop(self, loop):
        """Thread target: install *loop* on this thread and run it forever."""
        asyncio.set_event_loop(loop)
        try:
            loop.run_forever()
        except Exception as e:
            logger.error(f"[DetachNcclSync] Background loop crashed: {e}")

    def _run_async_safely(self, coro):
        """Run *coro* on the background loop and block until it finishes.

        Raises:
            RuntimeError: if the background thread has died, since submitting
                to a dead loop would hang forever.
        """
        if not self._bg_thread.is_alive():
            raise RuntimeError("Background thread for SGLang sync is not running!")

        # future.result() blocks the calling (synchronous) thread until the
        # coroutine completes on the background loop.
        future = asyncio.run_coroutine_threadsafe(coro, self._bg_loop)
        return future.result()

    def __del__(self):
        # Best-effort shutdown of the background loop/thread. hasattr guards
        # cover partially-constructed instances (__init__ may have raised).
        if hasattr(self, "_bg_loop") and self._bg_loop.is_running():
            self._bg_loop.call_soon_threadsafe(self._bg_loop.stop)
        if hasattr(self, "_bg_thread") and self._bg_thread.is_alive():
            self._bg_thread.join(timeout=1.0)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
    def init_checkpoint_engine(self, rank_offset: int, actor_num: int, rollout_num: int):
        """Create this worker's CheckpointEngine over the combined rank space.

        Actor ranks occupy [0, actor_num) and rollout ranks
        [actor_num, actor_num + rollout_num); rank_offset selects which group
        this process belongs to (0 for actors, actor_num for rollouts —
        enforced by the assert below).
        """
        from .checkpoint_engine import CheckpointEngine

        current_rank = torch.distributed.get_rank() + rank_offset
        actor_ranks = list(range(actor_num))
        rollout_ranks = [rank + actor_num for rank in range(rollout_num)]
        assert rank_offset == 0 or rank_offset == actor_num

        self.checkpoint_engine = CheckpointEngine(
            current_rank, actor_ranks, rollout_ranks, self.config.checkpoint_engine.device_buffer_size_M
        )

    @staticmethod
    def get_inference_model(rollout):
        """
        Get models according to different types of inference_engine

        Args:
            rollout: rollout object with an ``inference_engine`` attribute
        Returns:
            model: model object (for vllm) or rollout object itself (for sglang)
        Raises:
            AttributeError: if the engine exposes neither ``llm_engine``
                nor ``worker``.
        """
        inference_engine = rollout.inference_engine
        if hasattr(inference_engine, "llm_engine"):
            # vLLM LLM object: dig down to the underlying model runner's model.
            inference_model = inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
        elif hasattr(inference_engine, "worker"):
            # vLLM WorkerWrapperBase: model is one level shallower.
            inference_model = inference_engine.worker.model_runner.model
        else:
            raise AttributeError(
                f"Unsupported inference_engine type: {type(inference_engine)}. "
                f"Expected LLM (with llm_engine attribute) or WorkerWrapperBase (with worker attribute)."
            )
        return inference_model

    def _sync_sglang_weights(self, inference_model, params, sync_group_name):
        """Broadcast actor weights to SGLang rollout workers in size-bucketed batches.

        Every rank walks ``self._weights_info`` in the same order (NCCL
        broadcast requires identical call sequences on all ranks); actor
        rank 0 fills each tensor from ``params`` before broadcasting.
        Tensors are accumulated into buckets of ~``get_bucket_size_mb()`` MB
        and pushed to the engine per bucket to bound peak memory.
        """
        bucket_size_bytes = int(self.get_bucket_size_mb() * 1024 * 1024)
        actual_bucket_sizes = []  # MB per flushed bucket, for adaptive feedback
        current_batch = []
        current_batch_size = 0

        def flush_batch():
            # Push the accumulated bucket to the engine via the background
            # loop, then synchronize the device before reusing the buffers.
            # Note: this closure cannot rebind current_batch_size (no
            # nonlocal) — the caller resets it after each flush.
            if current_batch:
                actual_bucket_sizes.append(current_batch_size / (1024 * 1024))
                self._run_async_safely(self.update_weights(inference_model, iter(current_batch)))
                get_torch_device().synchronize()
                current_batch.clear()

        for key, shape, dtype in self._weights_info:
            tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())
            if self._is_actor:
                assert key in params
                origin_data = params[key]
                # DTensor-style sharded params expose full_tensor(); gather first.
                if hasattr(origin_data, "full_tensor"):
                    origin_data = origin_data.full_tensor()
                if torch.distributed.get_rank() == 0:
                    tensor.copy_(origin_data)
            # Collective call: must execute on every rank, actor and rollout alike.
            collective.broadcast(tensor, src_rank=0, group_name=sync_group_name)

            tensor_size = tensor.numel() * tensor.element_size()
            current_batch.append((key, tensor))
            current_batch_size += tensor_size

            if current_batch_size >= bucket_size_bytes:
                flush_batch()
                current_batch_size = 0

        # Flush whatever remains below the bucket threshold.
        flush_batch()
        cls = type(self)
        cls._last_avg_bucket_size = (
            sum(actual_bucket_sizes) / len(actual_bucket_sizes) if actual_bucket_sizes else self.get_bucket_size_mb()
        )

        # Resume kv_cache after weights sync to restore GPU memory released during pause
        if self._is_rollout and self.rollout_device_mesh["infer_tp"].get_local_rank() == 0:
            self._run_async_safely(inference_model.resume_memory_occupation(tags=["kv_cache"]))

    def _sync_vllm_weights(self, inference_model, params, sync_group_name):
        """Broadcast actor weights to vLLM rollout workers, one tensor at a time.

        Same broadcast discipline as ``_sync_sglang_weights`` (identical
        iteration order on all ranks) but loads each tensor into the vLLM
        model immediately instead of bucketing.
        """
        for key, shape, dtype in self._weights_info:
            tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())
            if self._is_actor:
                assert key in params
                origin_data = params[key]
                # DTensor-style sharded params expose full_tensor(); gather first.
                if hasattr(origin_data, "full_tensor"):
                    origin_data = origin_data.full_tensor()
                if torch.distributed.get_rank() == 0:
                    tensor.copy_(origin_data)
            collective.broadcast(tensor, src_rank=0, group_name=sync_group_name)
            if self._is_rollout:
                inference_model.load_weights([(key, tensor)])

    async def update_weights(self, inference_engine, params):
        """Push one batch of (name, tensor) pairs into the SGLang engine.

        After the update, the TP-rank-0 worker flushes the engine cache so
        subsequent generations use the new weights.
        """
        # Imported lazily so this module stays importable without sglang.
        from sglang.srt.weight_sync.utils import update_weights as sgl_update_weights

        await sgl_update_weights(
            engine=inference_engine,
            params_batch=params,
            device_mesh_key="infer_tp",
            device_mesh=self.rollout_device_mesh,
        )

        if self.rollout_device_mesh["infer_tp"].get_local_rank() == 0:
            await inference_engine.flush_cache()

verl/experimental/fully_async_policy/detach_utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,14 @@ def prepare_single_generation_data(batch_dict, config) -> DataProto:
7171
batch_keys_to_pop = []
7272
non_tensor_batch_keys_to_pop = []
7373

74-
full_batch.pop(
75-
batch_keys=batch_keys_to_pop,
76-
non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
77-
)
74+
existing_batch_keys = [k for k in batch_keys_to_pop if k in full_batch.batch.keys()]
75+
existing_non_tensor_keys = [k for k in non_tensor_batch_keys_to_pop if k in full_batch.non_tensor_batch.keys()]
76+
77+
if existing_batch_keys or existing_non_tensor_keys:
78+
full_batch.pop(
79+
batch_keys=existing_batch_keys,
80+
non_tensor_batch_keys=existing_non_tensor_keys,
81+
)
7882

7983
# Setting selected agent, that supports partial
8084
if config.actor_rollout_ref.rollout.multi_turn.enable:

0 commit comments

Comments
 (0)