Skip to content

Commit 362dcfd

Browse files
authored
feat(rollout): add auto scaling scheduler for vLLM, aligning with SGLang (RLinf#253)
Signed-off-by: Bo Dai <daibo@infini-ai.com>
1 parent e963611 commit 362dcfd

File tree

10 files changed

+693
-214
lines changed

10 files changed

+693
-214
lines changed

.github/workflows/scheduler-tests.yml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,12 @@ jobs:
4747
export REPO_PATH=$(pwd)
4848
source switch_env reason
4949
bash tests/e2e_tests/dynamic_scheduler/run.sh qwen2.5-1.5b-grpo-dynamic-mg-sgl
50+
51+
- name: Megatron vLLM
52+
timeout-minutes: 20
53+
run: |
54+
export PYTHONPATH=$(pwd)/Megatron-LM-011:$(pwd)/params_resharding_release
55+
cd rlinf
56+
export REPO_PATH=$(pwd)
57+
source switch_env reason
58+
bash tests/e2e_tests/dynamic_scheduler/run.sh qwen2.5-1.5b-grpo-dynamic-mg-vllm

docs/source-en/rst_source/tutorials/scheduler/dynamic-scheduling.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Dynamic Scheduling
44

55
Dynamic scheduling adjusts and migrates resources among components (actor / rollout / inference)
66
in real time during training to improve overall throughput and resource utilization.
7-
It relies on Megatron-LM's online scaling (second-level elasticity) and SGLang's migrate capability
7+
It relies on Megatron-LM's online scaling (second-level elasticity) and SGLang/vLLM's request-migration capability
88
to reallocate GPU resources without stopping training.
99

1010
What is Dynamic Scheduling?

docs/source-zh/rst_source/tutorials/scheduler/dynamic-scheduling.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
动态调度(Dynamic Scheduling)
66
是在训练运行期根据系统各组件(actor / rollout / inference)的实时状态,
77
对资源进行秒级动态调整与迁移,以提升整体吞吐与资源利用率的机制。
8-
它依托于 Megatron-LM 的在线扩缩容能力(秒级扩缩)与 SGLang 的请求迁移功能,
8+
它依托于 Megatron-LM 的在线扩缩容能力(秒级扩缩)与 SGLang/vLLM 的请求迁移功能,
99
在不终止训练的前提下,对集群中的 GPU 资源进行弹性重分配。
1010

1111
什么是动态调度?

rlinf/data/io_struct.py

Lines changed: 59 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,10 @@ def get_seq_length(
4949
class RolloutRequest:
5050
"""
5151
Attr
52-
input_ids: List of input token IDs for rollout
52+
input_ids: list of input token IDs for rollout
5353
n: Number of completions to generate for each input
5454
image_data: list of image data (bytes or URLs) for multimodal inputs
55-
answers: List of answers for the requests, where each answer can be either a list of strings (for typical tasks) or a dict (for VQA tasks), if available.
55+
answers: list of answers for the requests, where each answer can be either a list of strings (for typical tasks) or a dict (for VQA tasks), if available.
5656
multi_modal_inputs: list of multi-modal inputs for the requests
5757
"""
5858

@@ -66,7 +66,7 @@ def to_seq_group_infos(self) -> list["SeqGroupInfo"]:
6666
"""Convert the RolloutRequest into a list of SeqGroupInfo objects.
6767
6868
Returns:
69-
List[SeqGroupInfo]: A list of SeqGroupInfo objects.
69+
list[SeqGroupInfo]: A list of SeqGroupInfo objects.
7070
"""
7171
return [
7272
SeqGroupInfo(
@@ -102,12 +102,12 @@ class SeqGroupInfo:
102102
103103
Attributes:
104104
id (int): Unique identifier for the sequence group.
105-
input_ids (List[int]): List of input IDs of the original sequence.
106-
answer (Union[List[str], Dict]): List of answers of the original sequence.(One sequence can have multiple equivalent answers), or a dict in case of vqa task.
105+
input_ids (list[int]): list of input IDs of the original sequence.
106+
answer (Union[list[str], dict]): list of answers of the original sequence. (One sequence can have multiple equivalent answers), or a dict in case of vqa task.
107107
group_size (int): Number of sequences in the group.
108108
idx_completed (set[int]): Set of indices for sequences that have completed rollout and are ready for evaluation.
109109
idx_aborted (set[int]): Set of indices for sequences that have been aborted. These sequences need to be re-rolled out before they can be evaluated.
110-
results (List[Optional[Dict]]): List storing result dictionaries for each sequence, or None if not yet available.
110+
results (list[Optional[dict]]): list storing the result for each sequence, or None if not yet available.
111111
"""
112112

113113
id: int
@@ -116,7 +116,9 @@ class SeqGroupInfo:
116116
group_size: int
117117
idx_completed: set[int] = field(init=False, compare=False)
118118
idx_aborted: set[int] = field(init=False, compare=False)
119-
results: list[Optional[dict]] = field(init=False, compare=False)
119+
results: list[Optional[Union[dict, "VllmRequestOutput"]]] = field(
120+
init=False, compare=False
121+
)
120122
image_data: Optional[list] = None
121123
multi_modal_inputs: Optional[dict] = None
122124

@@ -126,6 +128,18 @@ def __post_init__(self):
126128
self.idx_aborted = set()
127129
self.results = [None for _ in range(self.group_size)]
128130

131+
def record_vllm_result(self, idx: int, result: "VllmRequestOutput", logger=None):
132+
finish_reason = result.outputs[0].finish_reason
133+
if finish_reason is None or finish_reason == "abort":
134+
self.idx_aborted.add(idx)
135+
else:
136+
self.idx_completed.add(idx)
137+
138+
if self.results[idx] is None:
139+
self.results[idx] = result
140+
else:
141+
self.results[idx].add(next_output=result, aggregate=True)
142+
129143
def record_sglang_result(self, idx: int, result: dict, logger=None):
130144
"""Record a single sglang execution result and update internal tracking.
131145
@@ -139,7 +153,7 @@ def record_sglang_result(self, idx: int, result: dict, logger=None):
139153
Args:
140154
idx: int
141155
The index of the sequence within the group (0 <= idx < group_size).
142-
result: Dict
156+
result: dict
143157
Result of SGLang. Expected to contain at least:
144158
- "meta_info": {"finish_reason": {"type": FinishReasonEnum}}
145159
- "output_ids": a list (or list-like) of output identifier elements
@@ -300,13 +314,14 @@ def from_vllm_results(
300314
return_logprobs: bool = False,
301315
) -> "RolloutResult":
302316
"""
303-
Create a RolloutResult from the given vLLM results.
317+
Create a RolloutResult from the given vLLM results. Every result is generated with n=1,
318+
so its outputs list has length 1.
304319
305320
Args:
306321
group_size (int): The group size used during rollout.
307322
results (list[VllmRequestOutput]): The rollout results from vLLM.
308323
answers (Optional[Union[list[str], dict]]): The answers corresponding to the inputs, notably, if task type is vqa, answers is a dict.
309-
multi_modal_inputs (Optional[list[Dict]]): The multi-modal inputs corresponding to the inputs.
324+
multi_modal_inputs (Optional[list[dict]]): The multi-modal inputs corresponding to the inputs.
310325
return_logprobs (bool): Whether to return log probabilities.
311326
312327
Returns:
@@ -325,62 +340,36 @@ def get_logprobs(
325340
logprobs.append(logprob[response_ids[i]].logprob)
326341
return logprobs
327342

328-
num_sequences = len(results) * group_size
329-
330-
if multi_modal_inputs:
331-
mm_inputs = []
332-
for mm_input in multi_modal_inputs:
333-
mm_inputs.extend([mm_input] * group_size)
334-
else:
335-
mm_inputs = None
336-
337343
# for VQA task, answers is a dict
338344
if isinstance(answers, dict):
339345
answers = [answers]
340346

341-
prompt_lengths = []
342-
prompt_ids = []
343-
response_lengths = []
344-
response_ids = []
345-
logprobs = []
346-
is_end = []
347-
response_texts = []
348-
rollout_answers = (
349-
[answer for answer in answers for _ in range(group_size)]
350-
if answers
351-
else None
352-
)
353-
for vllm_result in results:
354-
if vllm_result.prompt_token_ids is not None:
355-
prompt_ids.extend([vllm_result.prompt_token_ids] * group_size)
356-
prompt_lengths.extend([len(vllm_result.prompt_token_ids)] * group_size)
357-
else:
358-
raise NotImplementedError("vllm should return tokenized prompt.")
359-
response_ids.extend(
360-
[list(output.token_ids) for output in vllm_result.outputs]
361-
)
362-
response_texts.extend([output.text for output in vllm_result.outputs])
363-
response_lengths.extend(
364-
[len(output.token_ids) for output in vllm_result.outputs]
365-
)
366-
is_end.extend([vllm_result.finished] * group_size)
367-
if return_logprobs:
368-
logprobs.extend(
369-
[
370-
get_logprobs(list(output.token_ids), output)
371-
for output in vllm_result.outputs
372-
]
347+
# here vllm must return prompt ids because we pass input_ids as input
348+
prompt_ids = [vllm_result.prompt_token_ids for vllm_result in results]
349+
prompt_lengths = [len(vllm_result.prompt_token_ids) for vllm_result in results]
350+
response_ids = [vllm_result.outputs[0].token_ids for vllm_result in results]
351+
response_texts = [vllm_result.outputs[0].text for vllm_result in results]
352+
response_lengths = [
353+
len(vllm_result.outputs[0].token_ids) for vllm_result in results
354+
]
355+
is_end = [vllm_result.finished for vllm_result in results]
356+
if return_logprobs:
357+
logprobs = [
358+
get_logprobs(
359+
list(vllm_result.outputs[0].token_ids), vllm_result.outputs[0]
373360
)
361+
for vllm_result in results
362+
]
374363
result: RolloutResult = RolloutResult(
375364
group_size=group_size,
376-
num_sequence=num_sequences,
377-
answers=rollout_answers,
365+
num_sequence=len(results),
366+
answers=answers,
378367
prompt_ids=prompt_ids,
379368
prompt_lengths=prompt_lengths,
380369
response_ids=response_ids,
381370
response_lengths=response_lengths,
382371
response_texts=response_texts,
383-
multi_modal_inputs=mm_inputs,
372+
multi_modal_inputs=multi_modal_inputs,
384373
is_end=is_end,
385374
)
386375
if return_logprobs:
@@ -400,8 +389,8 @@ def from_sglang_results(
400389
"""Create a MathRolloutResult from the given results and input IDs.
401390
402391
Args:
403-
results (List[Dict]): The rollout results from the model.
404-
input_ids (List[List[int]]): The input IDs for the prompts.
392+
results (list[dict]): The rollout results from the model.
393+
input_ids (list[list[int]]): The input IDs for the prompts.
405394
return_logprobs (bool): Whether to return log probabilities.
406395
"""
407396
assert len(results) == len(input_ids), (
@@ -447,6 +436,16 @@ def from_sglang_seq_group(cls, seq_group: SeqGroupInfo, return_logprobs: bool):
447436
return_logprobs=return_logprobs,
448437
)
449438

439+
@classmethod
440+
def from_vllm_seq_group(cls, seq_group: SeqGroupInfo, return_logprobs: bool):
441+
return cls.from_vllm_results(
442+
seq_group.group_size,
443+
seq_group.results,
444+
answers=[seq_group.answer] * seq_group.group_size,
445+
multi_modal_inputs=[seq_group.multi_modal_inputs] * seq_group.group_size,
446+
return_logprobs=return_logprobs,
447+
)
448+
450449
@staticmethod
451450
def merge_result_list(
452451
rollout_results: list["RolloutResult"],
@@ -550,10 +549,10 @@ def split_result_list_by_group(
550549
If input has multiple RolloutResult objects, split each one and merge the results.
551550
552551
Args:
553-
rollout_results: List of input RolloutResult objects
552+
rollout_results: list of input RolloutResult objects
554553
555554
Returns:
556-
List of RolloutResult objects grouped by group_size
555+
list of RolloutResult objects grouped by group_size
557556
"""
558557
assert len(rollout_results) > 0, "No rollout results to split."
559558

@@ -576,7 +575,7 @@ def _split_single_result_by_group(
576575
rollout_result: The RolloutResult to be split
577576
578577
Returns:
579-
List of split RolloutResult objects
578+
list of split RolloutResult objects
580579
"""
581580
group_size = rollout_result.group_size
582581
num_sequence = rollout_result.num_sequence
@@ -710,7 +709,7 @@ def to_actor_batch(
710709
pad_token (int): Token used for padding, e.g., `tokenizer.pad_token_id`.
711710
712711
Returns:
713-
Dict[str, torch.Tensor]: A dictionary with keys:
712+
dict[str, torch.Tensor]: A dictionary with keys:
714713
715714
input_ids (torch.Tensor):
716715
Concatenated prompt and response token IDs,

rlinf/hybrid_engines/vllm/vllm_0_8_5/worker.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def __init__(
6262
self.actor_weight_rank = rank_map[
6363
self._rlinf_worker.get_parent_rank(), self.rank
6464
]
65+
self.is_weight_offloaded = False
6566

6667
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
6768
"""Allocate GPU KV cache with the specified kv_cache_config."""
@@ -76,6 +77,7 @@ def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
7677

7778
def offload_model_weights(self) -> None:
7879
super().sleep(level=2)
80+
self.is_weight_offloaded = True
7981

8082
def sync_hf_weight(self) -> None:
8183
use_cudagraph = not self.rlinf_config.rollout.enforce_eager
@@ -85,20 +87,20 @@ def sync_hf_weight(self) -> None:
8587
state_dict = self._rlinf_worker.recv(
8688
src_group_name=self._actor_group_name, src_rank=self.actor_weight_rank
8789
)
88-
if self.placement_mode == PlacementMode.COLLOCATED:
89-
# in disaggregated mode, rollout backend will never offload weights
90-
# so we don't need to wake up when placement is disaggregated
90+
if self.is_weight_offloaded:
9191
super().wake_up()
92+
self.is_weight_offloaded = False
9293

9394
model = self.model_runner.model
9495
if colocate:
96+
batch_weights = []
9597
for name, handle in state_dict.items():
9698
func, args = handle
9799
list_args = list(args)
98100
list_args[6] = torch.cuda.current_device()
99101
new_weight: torch.Tensor = func(*list_args)
100-
model.load_weights([(name, new_weight)])
101-
del new_weight
102+
batch_weights.append((name, new_weight))
103+
model.load_weights(batch_weights)
102104
else:
103105
model.load_weights(state_dict.items())
104106
super().compile_or_warm_up_model()

rlinf/scheduler/dynamic_scheduler/scheduler_worker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ def __init__(
4343
self.components = self.component_placement._components
4444
self.workflow = workflow
4545

46-
assert self.cfg.rollout.rollout_backend == "sglang", (
47-
"only sglang is supported for dynamic scheduler"
46+
assert self.cfg.rollout.rollout_backend in ["sglang", "vllm"], (
47+
"only sglang and vllm are supported for dynamic scheduler"
4848
)
4949
assert self.cfg.actor.training_backend == "megatron", (
5050
"only megatron is supported for dynamic scheduler"

0 commit comments

Comments
 (0)