Commit 83c6bfc

refactor: split sync/async vllm worker ([1/2] of refactor vllm worker) (#900)
Signed-off-by: Yuki Huang <[email protected]>
1 parent 9f7825e commit 83c6bfc

11 files changed, +2179 -2057 lines changed

docs/design-docs/generation.md

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ A key design principle for generation backends is that they process tokens direc
 
 ## VLLM Backend
 
-The VLLM backend (`models/generation/vllm.py`) implements the {py:class}`GenerationInterface <nemo_rl.models.generation.interfaces.GenerationInterface>` to provide efficient text generation using the VLLM library, which is optimized for large language models.
+The VLLM backend (`models/generation/vllm/vllm_generation.py`) implements the {py:class}`GenerationInterface <nemo_rl.models.generation.interfaces.GenerationInterface>` to provide efficient text generation using the VLLM library, which is optimized for large language models.
 
 ### VllmGeneration Class
 

docs/guides/grpo.md

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ This Policy object holds a [RayWorkerGroup](../../nemo_rl/distributed/worker_gro
 
 ## Fast Generation
 
-We support vLLM through the [VllmGeneration](../../nemo_rl/models/generation/vllm.py) class right now.
+We support vLLM through the [VllmGeneration](../../nemo_rl/models/generation/vllm/vllm_generation.py) class right now.
 
 The function [grpo_train](../../nemo_rl/algorithms/grpo.py) contains the core GRPO training loop.
 

nemo_rl/distributed/ray_actor_environment_registry.py

Lines changed: 2 additions & 1 deletion
@@ -15,7 +15,8 @@
 from nemo_rl.distributed.virtual_cluster import PY_EXECUTABLES
 
 ACTOR_ENVIRONMENT_REGISTRY: dict[str, str] = {
-    "nemo_rl.models.generation.vllm.VllmGenerationWorker": PY_EXECUTABLES.VLLM,
+    "nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker": PY_EXECUTABLES.VLLM,
+    "nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker": PY_EXECUTABLES.VLLM,
     # Temporary workaround for the coupled implementation of DTensorPolicyWorker and vLLM.
     # This will be reverted to PY_EXECUTABLES.BASE once https://github.com/NVIDIA-NeMo/RL/issues/501 is resolved.
     "nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker": PY_EXECUTABLES.VLLM,
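Note on the registry change: the keys are fully qualified class names, so moving the workers into `vllm_worker` and `vllm_worker_async` means the old key no longer matches. The sketch below illustrates that lookup behavior under stated assumptions. `SKETCH_REGISTRY`, `lookup_executable`, and the placeholder strings `"vllm-env"` and `"base-env"` are hypothetical and are not part of nemo_rl; whether the real code falls back to a default or raises on a missing key is not shown in this diff.

```python
# Hypothetical stand-in for ACTOR_ENVIRONMENT_REGISTRY. The placeholder
# values replace PY_EXECUTABLES.VLLM, whose actual value is not shown here.
SKETCH_REGISTRY: dict[str, str] = {
    "nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker": "vllm-env",
    "nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker": "vllm-env",
}


def lookup_executable(module: str, class_name: str, default: str = "base-env") -> str:
    """Look up an executable by fully qualified class name.

    Falling back to `default` is an assumption made for this sketch.
    """
    fully_qualified_name = f"{module}.{class_name}"
    return SKETCH_REGISTRY.get(fully_qualified_name, default)


# The new module path resolves to the vLLM environment.
new_path = lookup_executable(
    "nemo_rl.models.generation.vllm.vllm_worker", "VllmGenerationWorker"
)

# The pre-refactor path is no longer a key, so the lookup misses.
old_path = lookup_executable("nemo_rl.models.generation.vllm", "VllmGenerationWorker")
```

Under these assumptions, a registry keyed by string must be updated in the same commit as any module move, which is what this hunk does.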

nemo_rl/models/generation/vllm.py

Lines changed: 0 additions & 2053 deletions
This file was deleted.
Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from nemo_rl.models.generation.vllm.config import VllmConfig
+from nemo_rl.models.generation.vllm.vllm_generation import VllmGeneration
+from nemo_rl.models.generation.vllm.vllm_worker import VllmGenerationWorker
+from nemo_rl.models.generation.vllm.vllm_worker_async import VllmAsyncGenerationWorker
+
+__all__ = [
+    "VllmConfig",
+    "VllmGeneration",
+    "VllmGenerationWorker",
+    "VllmAsyncGenerationWorker",
+]
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, NotRequired, TypedDict
+
+from nemo_rl.models.generation.interfaces import GenerationConfig
+
+
+class VllmSpecificArgs(TypedDict):
+    tensor_parallel_size: int
+    pipeline_parallel_size: int
+    gpu_memory_utilization: float
+    max_model_len: int
+    # Additional arguments for vLLM inserted by nemo rl based on the context of when vllm is used
+    skip_tokenizer_init: bool
+    async_engine: bool
+    load_format: NotRequired[str]
+    precision: NotRequired[str]
+    enforce_eager: NotRequired[bool]
+
+
+class VllmConfig(GenerationConfig):
+    vllm_cfg: VllmSpecificArgs
+    vllm_kwargs: NotRequired[dict[str, Any]]
File renamed without changes.