Commit a99bc26

refactor: Introduce BasePolicyWorker (#1585)
Signed-off-by: ashors1 <ashors@nvidia.com>
1 parent 5e73bfd commit a99bc26

File tree

11 files changed: +193 −345 lines


.github/workflows/_automodel_integration_check.yml

Lines changed: 2 additions & 2 deletions
@@ -134,8 +134,8 @@ jobs:
  echo "Checking if dtensor policy worker files are synchronized..."

  # Define the dtensor policy worker file paths
- DTENSOR_POLICY_WORKER_FILE="nemo_rl/models/policy/dtensor_policy_worker.py"
- DTENSOR_POLICY_WORKER_V2_FILE="nemo_rl/models/policy/dtensor_policy_worker_v2.py"
+ DTENSOR_POLICY_WORKER_FILE="nemo_rl/models/policy/workers/dtensor_policy_worker.py"
+ DTENSOR_POLICY_WORKER_V2_FILE="nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py"

  # Check if dtensor_policy_worker.py was modified in this PR
  if git diff --name-only origin/${{ inputs.base_ref }}..HEAD | grep -q "^${DTENSOR_POLICY_WORKER_FILE}$"; then

docs/fp8.md

Lines changed: 5 additions & 5 deletions
@@ -53,7 +53,7 @@ FP8 generations are recommended to be configured with the following settings:
       use_activation_pow2_scale: False
 ```

-"To train with FP8, you need to set the Megatron path and configure it using the following settings:
+To train with FP8, you need to set the Megatron path and configure it using the following settings:

 ```
 policy:
@@ -68,12 +68,12 @@ FP8 generations are recommended to be configured with the following settings:

 The TransformerEngine implementation for this recipe requires **cuda version ≥ 12.9**. The latest nemo-rl depends on torch 2.8.0 + cuda 12.9 (since this [commit](https://github.com/NVIDIA-NeMo/RL/commit/3f36d14b53e906b27c01c06e36dbbd2b8eb300cd)). Users should check-out code to latest and build container from `docker/Dockerfile` ([instructions](docker.md)).

-If you are using nemo-rl before this [commit](https://github.com/NVIDIA-NeMo/RL/commit/3f36d14b53e906b27c01c06e36dbbd2b8eb300cd), you will see the following error when trying to use fp8 training
+If you are using nemo-rl before this [commit](https://github.com/NVIDIA-NeMo/RL/commit/3f36d14b53e906b27c01c06e36dbbd2b8eb300cd), you will see the following error when trying to use fp8 training:

 ```
-  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/transformer_engine/pytorch/fp8.py", line 646, in fp8_autocast
+  File "/opt/ray_venvs/nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/transformer_engine/pytorch/fp8.py", line 646, in fp8_autocast
     FP8GlobalStateManager.fp8_autocast_enter(
-  File "/opt/ray_venvs/nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/transformer_engine/pytorch/fp8.py", line 465, in fp8_autocast_enter
+  File "/opt/ray_venvs/nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker/lib/python3.12/site-packages/transformer_engine/pytorch/fp8.py", line 465, in fp8_autocast_enter
     assert fp8_block_available, reason_for_no_fp8_block
            ^^^^^^^^^^^^^^^^^^^
 AssertionError: FP8 block scaled GEMM requires Hopper and CUDA >= 12.9.
@@ -88,5 +88,5 @@ The above results are from Llama-3.1-8B-Instruct GRPO experiments. You can run t
 * For BF16: `examples/configs/grpo_math_8B_megatron.yaml`
 * For FP8: `examples/configs/grpo_math_8B_megatron_fp8.yaml`

-In the experiment in this figure, enabling FP8 rollout and training gives 15%-25% decrease in step time, and the validation accuracy curves match up to 1000 step.
+In the experiment in this figure, enabling FP8 rollout and training gives 15%-25% decrease in step time, and the validation accuracy curves match up to 1000 steps.
 Efforts are ongoing to performs longer runs and further optimize performance.

nemo_rl/distributed/ray_actor_environment_registry.py

Lines changed: 3 additions & 3 deletions
@@ -29,9 +29,9 @@
     "nemo_rl.models.generation.vllm.vllm_worker_async.VllmAsyncGenerationWorker": VLLM_EXECUTABLE,
     # Temporary workaround for the coupled implementation of DTensorPolicyWorker and vLLM.
     # This will be reverted to PY_EXECUTABLES.BASE once https://github.com/NVIDIA-NeMo/RL/issues/501 is resolved.
-    "nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker": VLLM_EXECUTABLE,
-    "nemo_rl.models.policy.dtensor_policy_worker_v2.DTensorPolicyWorkerV2": PY_EXECUTABLES.AUTOMODEL,
-    "nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker": MCORE_EXECUTABLE,
+    "nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker": VLLM_EXECUTABLE,
+    "nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2": PY_EXECUTABLES.AUTOMODEL,
+    "nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker": MCORE_EXECUTABLE,
     "nemo_rl.environments.math_environment.MathEnvironment": PY_EXECUTABLES.SYSTEM,
     "nemo_rl.environments.vlm_environment.VLMEnvironment": PY_EXECUTABLES.SYSTEM,
     "nemo_rl.environments.code_environment.CodeEnvironment": PY_EXECUTABLES.SYSTEM,

nemo_rl/models/policy/interfaces.py

Lines changed: 8 additions & 8 deletions
@@ -67,7 +67,9 @@ def get_logprobs(

     @abstractmethod
     def get_reference_policy_logprobs(
-        self, data: BatchedDataDict[GenerationDatumSpec]
+        self,
+        data: BatchedDataDict[GenerationDatumSpec],
+        micro_batch_size: Optional[int] = None,
     ) -> BatchedDataDict[ReferenceLogprobOutputSpec]:
         """Get logprobs of actions from observations.

@@ -100,6 +102,7 @@ def train(
         data: BatchedDataDict,
         loss_fn: LossFunction,
         eval_mode: bool = False,
+        *,
         gbs: Optional[int] = None,
         mbs: Optional[int] = None,
     ) -> dict[str, Any]:
@@ -114,13 +117,6 @@ def train(
         """
         pass

-    @abstractmethod
-    def score(
-        self, data: BatchedDataDict[GenerationDatumSpec]
-    ) -> BatchedDataDict[ScoreOutputSpec]:
-        """Score a batch of data using the policy."""
-        pass
-
     @abstractmethod
     def calibrate_qkv_fp8_scales(
         self,
@@ -191,3 +187,7 @@ def broadcast_weights_for_collective(
         self, kv_scales: Optional[dict[str, float]] = None
     ) -> list[ray.ObjectRef]:
         pass
+
+    @abstractmethod
+    def prepare_for_lp_inference(self) -> None:
+        pass
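
For orientation, here is a minimal caller-side sketch of the updated interface. It is illustrative only: the concrete policy object, the batch contents, and the `nll_loss` function are assumptions, not part of this commit. It shows the two visible signature changes: `get_reference_policy_logprobs` gains an optional `micro_batch_size` override, and `gbs`/`mbs` become keyword-only arguments of `train`.

# Hypothetical usage sketch for the updated PolicyInterface signatures.
# `policy`, `batch`, and `nll_loss` are placeholders supplied by the caller.
from typing import Any

from nemo_rl.distributed.batched_data_dict import BatchedDataDict


def run_step(policy, batch: BatchedDataDict[Any], nll_loss) -> dict[str, Any]:
    # New optional micro_batch_size override for reference-policy logprobs.
    ref = policy.get_reference_policy_logprobs(batch, micro_batch_size=8)
    batch["reference_logprobs"] = ref["reference_logprobs"]

    # gbs/mbs must now be passed by keyword.
    return policy.train(batch, nll_loss, eval_mode=False, gbs=512, mbs=4)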

nemo_rl/models/policy/lm_policy.py

Lines changed: 7 additions & 11 deletions
@@ -87,9 +87,7 @@ def __init__(
                 "DTensor (policy.dtensor_cfg.enabled=true), not both."
             )
         if megatron_enable:
-            worker_builder_cls = (
-                "nemo_rl.models.policy.megatron_policy_worker.MegatronPolicyWorker"
-            )
+            worker_builder_cls = "nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker"
             tp_size = config["megatron_cfg"]["tensor_model_parallel_size"]
             pp_size = config["megatron_cfg"]["pipeline_model_parallel_size"]
             cp_size = config["megatron_cfg"]["context_parallel_size"]
@@ -112,11 +110,9 @@ def __init__(
             # Check if _v2 is enabled in dtensor_cfg (defaults to False for backward compatibility)
             use_v2 = config.get("dtensor_cfg", {}).get("_v2", False)
             if use_v2:
-                worker_builder_cls = "nemo_rl.models.policy.dtensor_policy_worker_v2.DTensorPolicyWorkerV2"
+                worker_builder_cls = "nemo_rl.models.policy.workers.dtensor_policy_worker_v2.DTensorPolicyWorkerV2"
             else:
-                worker_builder_cls = (
-                    "nemo_rl.models.policy.dtensor_policy_worker.DTensorPolicyWorker"
-                )
+                worker_builder_cls = "nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker"

             tp_size = config["dtensor_cfg"]["tensor_parallel_size"]
             cp_size = config["dtensor_cfg"]["context_parallel_size"]
@@ -666,10 +662,6 @@ def invalidate_kv_cache(self, *args: Any, **kwargs: Any) -> bool:
         # We don't need to do anything here
         return True

-    def finish_training(self, *args: Any, **kwargs: Any) -> None:
-        # Placeholder implementation
-        pass
-
     def prepare_refit_info(self) -> Optional[dict[str, Any]]:
         """Prepare the info for refit.

@@ -681,6 +673,10 @@ def prepare_refit_info(self) -> Optional[dict[str, Any]]:
         # Only get the first worker's info since all workers will have the same result
         return results[0]

+    def finish_training(self, *args: Any, **kwargs: Any) -> None:
+        # Placeholder implementation
+        pass
+
     def calibrate_qkv_fp8_scales(
         self,
         data: BatchedDataDict[GenerationDatumSpec],

Lines changed: 154 additions & 0 deletions
@@ -0,0 +1,154 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional

import ray
import torch
import zmq

from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.models.policy.interfaces import ReferenceLogprobOutputSpec
from nemo_rl.utils.nsys import wrap_with_nvtx_name


class AbstractPolicyWorker:
    """Base class for policy workers with shared functionality."""

    def init_collective(
        self, ip: str, port: int, world_size: int, *, train_world_size: int
    ) -> None:
        """Initialize the collective communication.

        Args:
            ip: IP address for the process group
            port: Port for the process group
            world_size: Total world size (train_world_size + inference_world_size)
            train_world_size: Number of training workers (used in inference cluster)
        """
        from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
        from vllm.distributed.utils import StatelessProcessGroup

        pg = StatelessProcessGroup.create(
            host=ip, port=port, rank=self.rank, world_size=world_size
        )
        device = torch.cuda.current_device()
        self.model_update_group = PyNcclCommunicator(pg, device=device)

    def is_alive(self) -> bool:
        """Check if the worker is alive."""
        return True

    def reset_peak_memory_stats(self) -> None:
        """Reset peak memory statistics."""
        torch.cuda.reset_peak_memory_stats()

    def get_gpu_info(self) -> dict[str, Any]:
        """Return information about the GPU being used by this worker."""
        from nemo_rl.models.policy.utils import get_gpu_info

        return get_gpu_info(self.model)

    def report_device_id(self) -> str:
        """Report the UUID of the current CUDA device using NVML.

        Returns:
            str: UUID of the device in the format "GPU-xxxxx"
        """
        from nemo_rl.utils.nvml import get_device_uuid

        # Get current device index from torch
        device_idx = torch.cuda.current_device()
        # Get device UUID using NVML
        return get_device_uuid(device_idx)

    def get_zmq_address(self) -> str:
        """Get the ZMQ address for the current device."""
        return f"ipc:///tmp/{self.report_device_id()}.sock"

    def maybe_init_zmq(self) -> None:
        """Initialize the ZMQ socket if it doesn't exist."""
        if not hasattr(self, "zmq_socket"):
            self.zmq_context = zmq.Context()
            self.zmq_socket = self.zmq_context.socket(zmq.REQ)
            self.zmq_socket.setsockopt(
                zmq.SNDTIMEO, 120000
            )  # set timeout to 120 seconds
            self.zmq_socket.setsockopt(
                zmq.RCVTIMEO, 120000
            )  # set timeout to 120 seconds
            self.zmq_socket.setsockopt(zmq.LINGER, 0)
            self.zmq_socket.bind(self.get_zmq_address())

    def get_free_memory_bytes(self) -> int:
        """Get the available free memory."""
        from nemo_rl.utils.nvml import get_free_memory_bytes

        device_idx = torch.cuda.current_device()
        return get_free_memory_bytes(device_idx)

    def shutdown(self) -> bool:
        """Shutdown the policy."""
        try:
            # Clean up extension resources like ZMQ sockets
            if hasattr(self, "zmq_socket"):
                self.zmq_socket.close()
                self.zmq_context.term()
            return True
        except Exception:
            return False

    def start_gpu_profiling(self) -> None:
        """Start GPU profiling."""
        torch.cuda.profiler.start()

    def stop_gpu_profiling(self) -> None:
        """Stop GPU profiling."""
        torch.cuda.profiler.stop()

    def report_node_ip_and_gpu_id(self) -> tuple[str, int]:
        """Report the node IP and GPU ID of the current worker."""
        ip = ray._private.services.get_node_ip_address()
        gpu_id = ray.get_gpu_ids()[0]
        return (ip, gpu_id)

    # Temporary fix, 'data' is a kwarg due to some sort of ray bug
    @wrap_with_nvtx_name("policy_worker/get_reference_policy_logprobs")
    def get_reference_policy_logprobs(
        self,
        *,
        data: BatchedDataDict[Any],
        micro_batch_size: Optional[int] = None,
    ) -> BatchedDataDict[ReferenceLogprobOutputSpec]:
        """Get the logprobs from the reference policy for a batch of data.

        If micro_batch_size is provided, it will be used instead of the configured
        logprob_batch_size.

        Returns:
            a BatchedDataDict with key "reference_logprobs" and shape [batch_size, sequence_length].
            We use the convention that the logprob of the first token is 0 so that the sequence length is maintained.
            The logprob of input token i is specified at position i in the output logprobs tensor.
        """
        with self.use_reference_model():
            reference_logprobs = self.get_logprobs(
                data=data, micro_batch_size=micro_batch_size
            )

        return_data = BatchedDataDict[ReferenceLogprobOutputSpec]()
        return_data["reference_logprobs"] = reference_logprobs["logprobs"].cpu()
        return return_data

    def finish_training(self, *args: Any, **kwargs: Any) -> None:
        # Placeholder implementation
        pass
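
As a rough illustration of how this new base class is meant to be used, the sketch below shows a concrete worker inheriting the shared plumbing (device/UUID reporting, ZMQ setup, profiling hooks, the reference-logprob path) while supplying only backend-specific pieces. The subclass name, its constructor, the stubbed get_logprobs/use_reference_model, and the import path of AbstractPolicyWorker are assumptions for illustration, not code from this commit.

# Illustrative sketch only: a concrete worker built on AbstractPolicyWorker.
# The import path below is assumed; adjust it to wherever the base class lives.
from contextlib import contextmanager
from typing import Any, Optional

import torch

from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.models.policy.workers.base_policy_worker import (  # assumed module path
    AbstractPolicyWorker,
)


class ToyPolicyWorker(AbstractPolicyWorker):
    """Hypothetical subclass: only backend-specific pieces live here."""

    def __init__(self, rank: int, model: torch.nn.Module):
        # Attributes the inherited methods rely on, e.g. init_collective uses
        # self.rank and get_gpu_info uses self.model.
        self.rank = rank
        self.model = model

    @contextmanager
    def use_reference_model(self):
        # Swap in frozen reference weights here; the inherited
        # get_reference_policy_logprobs wraps get_logprobs in this context.
        yield

    def get_logprobs(
        self, *, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None
    ) -> BatchedDataDict[Any]:
        # Backend-specific forward pass goes here.
        raise NotImplementedError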
