Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions nemo_rl/distributed/worker_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,18 @@ def __init__(self, ray_actor_class_fqn: str, *args, **kwargs):
self.args = args
self.kwargs = kwargs

# Use the ray_actor_class_fqn_for_env if provided, otherwise use the ray_actor_class_fqn.
# This is useful when using worker extension classes.
self.ray_actor_class_fqn_for_env = ray_actor_class_fqn
if (
"ray_actor_class_fqn_for_env" in kwargs
and kwargs["ray_actor_class_fqn_for_env"] is not None
):
self.ray_actor_class_fqn_for_env = kwargs["ray_actor_class_fqn_for_env"]

if "ray_actor_class_fqn_for_env" in kwargs:
del self.kwargs["ray_actor_class_fqn_for_env"]

def create_worker_async(
self,
placement_group: PlacementGroup,
Expand Down Expand Up @@ -436,7 +448,7 @@ def _create_workers_from_bundle_indices(

# Get the python environment for the actor
actor_python_env = get_actor_python_env(
remote_worker_builder.ray_actor_class_fqn
remote_worker_builder.ray_actor_class_fqn_for_env
)
if actor_python_env.startswith("uv"):
# If the py_executable begins with uv it signals that we need to create a
Expand All @@ -445,7 +457,7 @@ def _create_workers_from_bundle_indices(
# NEMO_RL_VENV_DIR and defaults to $GIT_ROOT/venvs/.
py_executable = create_local_venv_on_each_node(
py_executable=actor_python_env,
venv_name=remote_worker_builder.ray_actor_class_fqn,
venv_name=remote_worker_builder.ray_actor_class_fqn_for_env,
)
else:
py_executable = actor_python_env
Expand Down
49 changes: 49 additions & 0 deletions nemo_rl/models/policy/lm_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,15 @@ def __init__(
optimizer_path: Optional[PathLike] = None,
init_reference_model: bool = True,
processor: Optional[AutoProcessor] = None,
worker_extension_cls: Optional[str] = None,
):
if weights_path:
weights_path = os.path.abspath(weights_path)
if optimizer_path:
optimizer_path = os.path.abspath(optimizer_path)

worker_builder_cls: str
worker_builder_cls_for_env: Optional[str] = None
tp_size = 1
pp_size = 1
cp_size = 1
Expand Down Expand Up @@ -131,6 +133,14 @@ def __init__(

env_vars = config["dtensor_cfg"].get("env_vars", {})

# If a worker extension class is provided, use it instead of the default worker builder class
if worker_extension_cls is not None:
print(
f"Using worker extension class: {worker_extension_cls}, please make sure it is a subclass of {worker_builder_cls}."
)
worker_builder_cls_for_env = worker_builder_cls
worker_builder_cls = worker_extension_cls

# Validate world_size compatibility with parallelism configuration
model_parallel_size = pp_size * cp_size * tp_size
actual_world_size = cluster.world_size()
Expand Down Expand Up @@ -198,6 +208,7 @@ def __init__(
init_reference_model=init_reference_model,
worker_sharding_annotations=self.sharding_annotations,
pre_init_communication_queue=pre_init_queue,
ray_actor_class_fqn_for_env=worker_builder_cls_for_env,
)

if cluster._sorted_bundle_indices is not None:
Expand Down Expand Up @@ -275,6 +286,44 @@ def __init__(

self.cfg = config

def run_all_workers_single_data(self, method_name: str, *args, **kwargs) -> Any:
    """Invoke *method_name* on every worker, broadcasting identical arguments.

    Mainly used for methods added by worker extension classes.

    Args:
        method_name: Name of the worker method to invoke.
        *args: Positional arguments forwarded unchanged to every worker.
        **kwargs: Keyword arguments forwarded unchanged to every worker.

    Returns:
        A list with one result per worker, gathered via ``ray.get``.
    """
    pending = self.worker_group.run_all_workers_single_data(
        method_name, *args, **kwargs
    )
    return ray.get(pending)

def run_all_workers_multiple_data(self, method_name: str, *args, **kwargs) -> Any:
    """Invoke *method_name* on every worker, sharding per-worker arguments.

    Mainly used for methods added by worker extension classes.

    Args:
        method_name: Name of the worker method to invoke.
        *args: Positional arguments distributed to the workers.
        **kwargs: Keyword arguments distributed to the workers.

    Returns:
        A list with one result per worker, gathered via ``ray.get``.
    """
    pending = self.worker_group.run_all_workers_multiple_data(
        method_name, *args, **kwargs
    )
    return ray.get(pending)

def init_collective(
self, ip: str, port: int, world_size: int, *, train_world_size: int
) -> list[ray.ObjectRef]:
Expand Down
14 changes: 10 additions & 4 deletions nemo_rl/models/policy/workers/dtensor_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,9 @@ def get_cpu_state_dict(
return new_state_dict


@ray.remote(
runtime_env=get_runtime_env_for_policy_worker("dtensor_policy_worker")
) # pragma: no cover
class DTensorPolicyWorker(AbstractPolicyWorker, ColocatablePolicyInterface):
# Classes with @ray.remote can't be inherited from, so we split the implementation out.
# This is useful when using worker extension classes.
class DTensorPolicyWorkerImpl(AbstractPolicyWorker, ColocatablePolicyInterface):
def __repr__(self) -> str:
"""Customizes the actor's prefix in the Ray logs.

Expand Down Expand Up @@ -1896,3 +1895,10 @@ def load_checkpoint(
scheduler=self.scheduler if optimizer_path else None,
optimizer_path=optimizer_path,
)


@ray.remote(
    runtime_env=get_runtime_env_for_policy_worker("dtensor_policy_worker")
)  # pragma: no cover
class DTensorPolicyWorker(DTensorPolicyWorkerImpl):
    # Thin @ray.remote wrapper: the implementation lives in
    # DTensorPolicyWorkerImpl because @ray.remote classes cannot be
    # subclassed, and worker extension classes need to inherit from the Impl.
    pass
18 changes: 11 additions & 7 deletions nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@
LogprobOutputSpec,
ScoreOutputSpec,
)
from nemo_rl.models.policy.utils import (
get_runtime_env_for_policy_worker,
)
from nemo_rl.models.policy.utils import get_runtime_env_for_policy_worker
from nemo_rl.models.policy.workers.base_policy_worker import AbstractPolicyWorker
from nemo_rl.models.policy.workers.patches import (
apply_torch_aten_alias_tensor_patch,
Expand Down Expand Up @@ -194,10 +192,9 @@ def get_train_context(
yield


@ray.remote(
runtime_env=get_runtime_env_for_policy_worker("dtensor_policy_worker_v2")
) # pragma: no cover
class DTensorPolicyWorkerV2(AbstractPolicyWorker, ColocatablePolicyInterface):
# Classes with @ray.remote can't be inherited from, so we split the implementation out.
# This is useful when using worker extension classes.
class DTensorPolicyWorkerV2Impl(AbstractPolicyWorker, ColocatablePolicyInterface):
def __repr__(self) -> str:
"""Customizes the actor's prefix in the Ray logs.

Expand Down Expand Up @@ -1127,3 +1124,10 @@ def _init_checkpoint_manager(
config_updates=config_updates,
checkpoint_root=checkpoint_root,
)


@ray.remote(
    runtime_env=get_runtime_env_for_policy_worker("dtensor_policy_worker_v2")
)  # pragma: no cover
class DTensorPolicyWorkerV2(DTensorPolicyWorkerV2Impl):
    # Thin @ray.remote wrapper: the implementation lives in
    # DTensorPolicyWorkerV2Impl because @ray.remote classes cannot be
    # subclassed, and worker extension classes need to inherit from the Impl.
    pass
14 changes: 10 additions & 4 deletions nemo_rl/models/policy/workers/megatron_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,9 @@
TokenizerType = TypeVar("TokenizerType", bound=PreTrainedTokenizerBase)


@ray.remote(
runtime_env=get_runtime_env_for_policy_worker("megatron_policy_worker")
) # pragma: no cover
class MegatronPolicyWorker(AbstractPolicyWorker, ColocatablePolicyInterface):
# Classes with @ray.remote can't be inherited from, so we split the implementation out.
# This is useful when using worker extension classes.
class MegatronPolicyWorkerImpl(AbstractPolicyWorker, ColocatablePolicyInterface):
def __repr__(self):
"""Customizes the actor's prefix in the Ray logs.

Expand Down Expand Up @@ -1518,3 +1517,10 @@ def _percentile(values: list[float], p: float) -> float:
final_result = obj_list[0] # type: ignore

return final_result


@ray.remote(
    runtime_env=get_runtime_env_for_policy_worker("megatron_policy_worker")
)  # pragma: no cover
class MegatronPolicyWorker(MegatronPolicyWorkerImpl):
    # Thin @ray.remote wrapper: the implementation lives in
    # MegatronPolicyWorkerImpl because @ray.remote classes cannot be
    # subclassed, and worker extension classes need to inherit from the Impl.
    pass
29 changes: 25 additions & 4 deletions research/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,31 @@ cp -r research/template_project research/my_new_project
```

The template includes:
- A minimal train-and-generate loop example
- Complete test suite structure (unit, functional, and test suites)
- Configuration examples
- Documentation template
- A minimal train-and-generate loop example:
  - Main loop: [single_update.py](template_project/single_update.py)
  - Utilities used by the main loop: [template_project/](template_project/template_project)
- Configuration examples: [configs/](template_project/configs)
  - The subdirectory [configs/recipes/](template_project/configs/recipes) is used only for test suites
- Documentation template: [README.md](template_project/README.md)
- Complete test suite structure (unit, functional, and test suites): [tests/](template_project/tests)
- Dependency specification: [.python-version](template_project/.python-version) and [pyproject.toml](template_project/pyproject.toml)

## What Needs To Be Provided

A new research project needs to include at least:
- Driver script and main loop
- You can refer to [run_grpo.py](../examples/run_grpo.py) and [grpo.py](../nemo_rl/algorithms/grpo.py) in the core repository, and [single_update.py](template_project/single_update.py) in the research template for examples.
- Configuration
- A runnable `config.yaml` that defines the experiment.
- You can refer to [examples/configs/](../examples/configs) in the core repository and [configs/](template_project/configs) in the research template for examples.
- Documentation
- A `README.md` that describes the project, how to run it, and how to reproduce results.
- Functional test
- An end-to-end test with minimal configuration to ensure that changes elsewhere do not break the research project.

The following are optional:
- Unit tests and test suites (adding these is encouraged).
- Dependency specifications (required if the project’s dependencies differ from the core `nemo_rl` package).

## Expectations for Research Project Authors

Expand Down
12 changes: 7 additions & 5 deletions research/template_project/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ This is a template project for research experiments with NeMo RL.

The `single_update.py` script demonstrates a minimal train-and-generate loop:
1. Sets up a Ray compute cluster
2. Initializes vLLM generation and an LM policy
3. Trains the policy on a small batch using NLL loss
4. Refits the generation engine with the updated policy weights
5. Generates outputs with the new policy
6. Repeats the loop (10 iterations by default)
2. Initializes the vLLM generation
3. Initializes the LM policy with an extension worker class that supports custom functions
4. Executes custom functions provided by the extension worker class
5. Repeats the loop (10 iterations by default)
1. Trains the policy on a small batch using NLL loss
2. Refits the generation engine with the updated policy weights
3. Generates outputs with the new policy

This shows the basic cycle of training a language model and using it for generation.

Expand Down
24 changes: 21 additions & 3 deletions research/template_project/single_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
1) Sets up a RayVirtualCluster
2) Initializes VllmGeneration
3) Initializes LM Policy
4) Trains on a tiny synthetic batch (global batch size = 2) with NLLLossFn
4) Executes custom methods on all workers
5) Refits the generation engine with the latest policy weights
6) Optionally repeats the train→refit cycle in a short loop
6) Trains on a tiny synthetic batch (global batch size = 2) with NLLLossFn
7) Optionally repeats the train→refit cycle in a short loop

Notes:
- The configuration is defined entirely in this file, inspired by examples/configs/grpo_math_1B.yaml
Expand Down Expand Up @@ -88,14 +89,28 @@ def main(config: MasterConfig) -> None:
config=policy_config,
tokenizer=tokenizer,
init_reference_model=False,
worker_extension_cls="research.template_project.template_project.worker_extension.DTensorPolicyWorkerV2Extension",
)
print(" ✓ Policy created")

# 4) Executes custom methods on all workers
# 4.1) Run a method on all workers in parallel with the same data
print("\n▶ Running a method on all workers in parallel with the same data...")
results = policy.run_all_workers_single_data("get_worker_rank")
print(f" ✓ Results for get_worker_rank: {results}")

# 4.2) Run a method on all workers in parallel with different data
print("\n▶ Running a method on all workers in parallel with different data...")
worker_nums = config["cluster"]["gpus_per_node"] * config["cluster"]["num_nodes"]
input_list = [i for i in range(worker_nums)]
results = policy.run_all_workers_multiple_data("return_input", input=input_list)
print(f" ✓ Results for return_input: {results}")

# Prepare refit info once before first refit
state_dict_info = policy.prepare_refit_info()
policy_generation.prepare_refit_info(state_dict_info or {})

# 4) Create tiny numeric batch and train with NLLLossFn
# Create tiny numeric batch and train with NLLLossFn
print("\n▶ Creating tiny numeric batch and training with NLLLossFn...")
train_sentences = ["a b c d e hello", "a d f world"] * config["policy"][
"train_global_batch_size"
Expand Down Expand Up @@ -132,6 +147,7 @@ def main(config: MasterConfig) -> None:
}
)

# 5) Refits the generation engine with the latest policy weights
print(" • Refit generation with latest policy weights...")
refit_policy_generation(
policy=policy,
Expand All @@ -153,6 +169,8 @@ def main(config: MasterConfig) -> None:
)
for i, out_text in enumerate(decoded):
print(f" - prompt: '{generation_prompts[i]}' -> '{out_text}'")

# 6) Trains on a tiny synthetic batch (global batch size = 2) with NLLLossFn
policy.prepare_for_training()
results = policy.train(data, loss_fn)
loss_tensor = results["loss"]
Expand Down
39 changes: 39 additions & 0 deletions research/template_project/template_project/worker_extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import ray
import torch

from nemo_rl.models.policy.utils import get_runtime_env_for_policy_worker
from nemo_rl.models.policy.workers.dtensor_policy_worker_v2 import (
DTensorPolicyWorkerV2Impl,
)


@ray.remote(
    runtime_env=get_runtime_env_for_policy_worker("dtensor_policy_worker_v2")
)  # pragma: no cover
class DTensorPolicyWorkerV2Extension(DTensorPolicyWorkerV2Impl):
    """Example worker extension that adds custom methods on top of the base worker."""

    def get_worker_rank(self) -> int:
        """Return this worker's ``torch.distributed`` rank.

        Used to demonstrate ``run_all_workers_single_data``.
        Assumes the default process group was initialized by the base worker.
        """
        # Fixed return annotation: the method returns the int rank, not a dict.
        return torch.distributed.get_rank()

    def return_input(self, input: Any) -> Any:
        """Return the input unchanged. Used to demonstrate ``run_all_workers_multiple_data``."""
        # The parameter name shadows the builtin ``input`` but is kept as-is:
        # callers pass it by keyword (e.g. ``input=input_list``).
        return input
4 changes: 2 additions & 2 deletions research/template_project/tests/functional/single_update.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ cd $PROJECT_ROOT
SINGLE_UPDATE_ITERS=1 uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \
single_update.py \
--config $PROJECT_ROOT/configs/grpo_math_1B.yaml \
cluster.gpus_per_node=1 \
cluster.gpus_per_node=2 \
cluster.num_nodes=1 \
policy.train_global_batch_size=1 \
policy.train_global_batch_size=2 \
policy.train_micro_batch_size=1 \
$@ \
2>&1 | tee $RUN_LOG
Expand Down