Skip to content

Commit f017fd8

Browse files
authored
feat: Support prefetching of specific envs (#1692)
Signed-off-by: Hemil Desai <[email protected]>
1 parent 433eaa1 commit f017fd8

File tree

2 files changed

+339
-4
lines changed

2 files changed

+339
-4
lines changed

nemo_rl/utils/prefetch_venvs.py

Lines changed: 67 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import argparse
1415
import os
1516
import sys
1617
from pathlib import Path
@@ -21,16 +22,35 @@
2122
from nemo_rl.utils.venvs import create_local_venv
2223

2324

24-
def prefetch_venvs():
25-
"""Prefetch all virtual environments that will be used by workers."""
25+
def prefetch_venvs(filters=None):
26+
"""Prefetch all virtual environments that will be used by workers.
27+
28+
Args:
29+
filters: List of strings to match against actor FQNs. If provided, only
30+
actors whose FQN contains at least one of the filter strings will
31+
be prefetched. If None, all venvs are prefetched.
32+
"""
2633
print("Prefetching virtual environments...")
34+
if filters:
35+
print(f"Filtering for: {filters}")
36+
37+
# Track statistics for summary
38+
skipped_by_filter = []
39+
skipped_system_python = []
40+
prefetched = []
41+
failed = []
2742

2843
# Group venvs by py_executable to avoid duplicating work
2944
venv_configs = {}
3045
for actor_fqn, py_executable in ACTOR_ENVIRONMENT_REGISTRY.items():
46+
# Apply filters if provided
47+
if filters and not any(f in actor_fqn for f in filters):
48+
skipped_by_filter.append(actor_fqn)
49+
continue
3150
# Skip system python as it doesn't need a venv
3251
if py_executable == "python" or py_executable == sys.executable:
3352
print(f"Skipping {actor_fqn} (uses system Python)")
53+
skipped_system_python.append(actor_fqn)
3454
continue
3555

3656
# Only create venvs for uv-based executables
@@ -47,12 +67,31 @@ def prefetch_venvs():
4767
try:
4868
python_path = create_local_venv(py_executable, actor_fqn)
4969
print(f" Success: {python_path}")
70+
prefetched.append(actor_fqn)
5071
except Exception as e:
5172
print(f" Error: {e}")
73+
failed.append(actor_fqn)
5274
# Continue with other venvs even if one fails
5375
continue
5476

55-
print("\nVenv prefetching complete!")
77+
# Print summary
78+
print("\n" + "=" * 50)
79+
print("Venv prefetching complete! Summary:")
80+
print("=" * 50)
81+
print(f" Prefetched: {len(prefetched)}")
82+
for actor_fqn in prefetched:
83+
print(f" - {actor_fqn}")
84+
print(f" Skipped (system Python): {len(skipped_system_python)}")
85+
for actor_fqn in skipped_system_python:
86+
print(f" - {actor_fqn}")
87+
if filters:
88+
print(f" Skipped (filtered out): {len(skipped_by_filter)}")
89+
for actor_fqn in skipped_by_filter:
90+
print(f" - {actor_fqn}")
91+
if failed:
92+
print(f" Failed: {len(failed)}")
93+
for actor_fqn in failed:
94+
print(f" - {actor_fqn}")
5695

5796
# Create convenience python wrapper scripts for frozen environment support (container-only)
5897
create_frozen_environment_symlinks(venv_configs)
@@ -150,4 +189,28 @@ def create_frozen_environment_symlinks(venv_configs):
150189

151190

152191
if __name__ == "__main__":
    # Command-line entry point: zero or more positional filter strings select
    # which actors' venvs to prefetch (substring match against the actor FQN).
    #
    # RawDescriptionHelpFormatter keeps the epilog's line breaks so the usage
    # examples are printed exactly as written.
    # NOTE(review): leading whitespace inside the epilog text could not be
    # recovered from the rendered diff - confirm against the repository copy.
    parser = argparse.ArgumentParser(
        description="Prefetch virtual environments for Ray actors.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Prefetch all venvs
python -m nemo_rl.utils.prefetch_venvs

# Prefetch only vLLM-related venvs
python -m nemo_rl.utils.prefetch_venvs vllm

# Prefetch multiple specific venvs
python -m nemo_rl.utils.prefetch_venvs vllm policy environment
""",
    )
    parser.add_argument(
        "filters",
        nargs="*",
        help="Filter strings to match against actor FQNs. Only actors whose FQN "
        "contains at least one of these strings will be prefetched. "
        "If not provided, all venvs are prefetched.",
    )
    args = parser.parse_args()

    # nargs="*" yields an empty list when no positionals are given; normalise
    # that to None so prefetch_venvs() takes its "no filtering" path.
    prefetch_venvs(filters=args.filters or None)
Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
from unittest.mock import patch
16+
17+
import pytest
18+
19+
import nemo_rl.utils.prefetch_venvs as prefetch_venvs_module
20+
21+
# When NRL_CONTAINER is set, create_frozen_environment_symlinks also calls
# create_local_venv for each actor, effectively doubling the call count.
# Expected call counts in the tests below are scaled by this factor so the
# suite passes both inside and outside the container.
CALL_MULTIPLIER = 2 if os.environ.get("NRL_CONTAINER") else 1
24+
25+
26+
@pytest.fixture
def mock_registry():
    """Create a mock registry with various actor types.

    Returns:
        A dict mapping actor FQN to its py_executable string: three actors on
        ``uv run`` executables (two sharing ``--group vllm``, one on
        ``--group mcore``) and two environment actors on plain ``python``.
    """
    return {
        "nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker": "uv run --group vllm",
        "nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker": "uv run --group vllm",
        "nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker": "uv run --group mcore",
        "nemo_rl.environments.math_environment.MathEnvironment": "python",
        "nemo_rl.environments.code_environment.CodeEnvironment": "python",
    }
36+
37+
38+
@pytest.fixture
def prefetch_venvs_func(mock_registry):
    """Patch the registry directly in the prefetch_venvs module.

    Replaces the ``ACTOR_ENVIRONMENT_REGISTRY`` name bound inside
    ``nemo_rl.utils.prefetch_venvs`` with ``mock_registry`` for the duration
    of the test.

    Yields:
        The module's ``prefetch_venvs`` function, callable while the patch is
        active; the original registry is restored on teardown.
    """
    with patch.object(
        prefetch_venvs_module, "ACTOR_ENVIRONMENT_REGISTRY", mock_registry
    ):
        yield prefetch_venvs_module.prefetch_venvs
45+
46+
47+
class TestPrefetchVenvs:
    """Tests for the prefetch_venvs function.

    Every test replaces ``create_local_venv`` with a mock, so no real virtual
    environment is built. Expected call counts are scaled by
    ``CALL_MULTIPLIER``.
    """

    # Dotted path of the function that is mocked out in every test.
    CREATE_VENV_TARGET = "nemo_rl.utils.prefetch_venvs.create_local_venv"
    # Value the mocked create_local_venv returns on success.
    FAKE_PYTHON = "/path/to/venv/bin/python"

    # FQNs of the uv-based actors present in the mock registry.
    VLLM_WORKER = "nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker"
    DTENSOR_WORKER = (
        "nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker"
    )
    MEGATRON_WORKER = (
        "nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker"
    )

    @staticmethod
    def _requested_actor_fqns(mock_create_venv):
        """Return the second positional argument (actor FQN) of each recorded call."""
        return [call[0][1] for call in mock_create_venv.call_args_list]

    def test_prefetch_venvs_no_filters(self, prefetch_venvs_func):
        """Test that all uv-based venvs are prefetched when no filters are provided."""
        with patch(self.CREATE_VENV_TARGET) as mock_create_venv:
            mock_create_venv.return_value = self.FAKE_PYTHON

            prefetch_venvs_func(filters=None)

            assert mock_create_venv.call_count == 3 * CALL_MULTIPLIER

            # Verify the actors that were called
            actor_fqns = self._requested_actor_fqns(mock_create_venv)
            assert self.VLLM_WORKER in actor_fqns
            assert self.DTENSOR_WORKER in actor_fqns
            assert self.MEGATRON_WORKER in actor_fqns

    def test_prefetch_venvs_single_filter(self, prefetch_venvs_func):
        """Test filtering with a single filter string."""
        with patch(self.CREATE_VENV_TARGET) as mock_create_venv:
            mock_create_venv.return_value = self.FAKE_PYTHON

            prefetch_venvs_func(filters=["vllm"])

            # Should only create venvs for actors containing "vllm" (1 actor)
            assert mock_create_venv.call_count == 1 * CALL_MULTIPLIER

            # call_args holds the most recent call; its actor must be the vLLM worker
            assert mock_create_venv.call_args[0][1] == self.VLLM_WORKER

    def test_prefetch_venvs_multiple_filters(self, prefetch_venvs_func):
        """Test filtering with multiple filter strings (OR logic)."""
        with patch(self.CREATE_VENV_TARGET) as mock_create_venv:
            mock_create_venv.return_value = self.FAKE_PYTHON

            prefetch_venvs_func(filters=["vllm", "megatron"])

            # Should create venvs for actors containing "vllm" OR "megatron" (2 actors)
            assert mock_create_venv.call_count == 2 * CALL_MULTIPLIER

            actor_fqns = self._requested_actor_fqns(mock_create_venv)
            assert self.VLLM_WORKER in actor_fqns
            assert self.MEGATRON_WORKER in actor_fqns

    def test_prefetch_venvs_filter_no_match(self, prefetch_venvs_func):
        """Test that no venvs are created when filter matches nothing."""
        with patch(self.CREATE_VENV_TARGET) as mock_create_venv:
            mock_create_venv.return_value = self.FAKE_PYTHON

            prefetch_venvs_func(filters=["nonexistent"])

            # Should not create any venvs
            assert mock_create_venv.call_count == 0

    def test_prefetch_venvs_skips_system_python(self, prefetch_venvs_func):
        """Test that system python actors are skipped even if they match filters."""
        with patch(self.CREATE_VENV_TARGET) as mock_create_venv:
            mock_create_venv.return_value = self.FAKE_PYTHON

            # Filter for "environment" which matches system python actors
            prefetch_venvs_func(filters=["environment"])

            # Should not create any venvs since matching actors use system python
            assert mock_create_venv.call_count == 0

    def test_prefetch_venvs_partial_match(self, prefetch_venvs_func):
        """Test that filter matches partial strings within FQN."""
        with patch(self.CREATE_VENV_TARGET) as mock_create_venv:
            mock_create_venv.return_value = self.FAKE_PYTHON

            # "policy" should match both dtensor_policy_worker and megatron_policy_worker
            prefetch_venvs_func(filters=["policy"])

            assert mock_create_venv.call_count == 2 * CALL_MULTIPLIER

            actor_fqns = self._requested_actor_fqns(mock_create_venv)
            assert self.DTENSOR_WORKER in actor_fqns
            assert self.MEGATRON_WORKER in actor_fqns

    def test_prefetch_venvs_empty_filter_list(self, prefetch_venvs_func):
        """Test that empty filter list is treated as no filtering (falsy)."""
        with patch(self.CREATE_VENV_TARGET) as mock_create_venv:
            mock_create_venv.return_value = self.FAKE_PYTHON

            # Empty list should be falsy and prefetch all
            prefetch_venvs_func(filters=[])

            # Should create venvs for all uv-based actors (3 total)
            assert mock_create_venv.call_count == 3 * CALL_MULTIPLIER

    def test_prefetch_venvs_continues_on_error(self, prefetch_venvs_func):
        """Test that prefetching continues even if one venv creation fails."""
        with patch(self.CREATE_VENV_TARGET) as mock_create_venv:
            # Provide enough return values for both prefetch and frozen env symlinks
            mock_create_venv.side_effect = [
                Exception("Test error"),
                self.FAKE_PYTHON,
                self.FAKE_PYTHON,
            ] * CALL_MULTIPLIER

            # Should not raise, should continue with other venvs
            prefetch_venvs_func(filters=None)

            # All 3 uv-based actors should have been attempted
            assert mock_create_venv.call_count == 3 * CALL_MULTIPLIER

    def test_prefetch_venvs_case_sensitive_filter(self, prefetch_venvs_func):
        """Test that filters are case-sensitive."""
        with patch(self.CREATE_VENV_TARGET) as mock_create_venv:
            mock_create_venv.return_value = self.FAKE_PYTHON

            # "VLLM" (uppercase) should not match "vllm" (lowercase)
            prefetch_venvs_func(filters=["VLLM"])

            assert mock_create_venv.call_count == 0

    def test_prefetch_venvs_summary_no_filters(self, prefetch_venvs_func, capsys):
        """Test that summary is printed with correct counts and names when no filters."""
        with patch(self.CREATE_VENV_TARGET) as mock_create_venv:
            mock_create_venv.return_value = self.FAKE_PYTHON

            prefetch_venvs_func(filters=None)

            captured = capsys.readouterr()
            assert "Venv prefetching complete! Summary:" in captured.out
            assert "Prefetched: 3" in captured.out
            assert "Skipped (system Python): 2" in captured.out
            # Verify prefetched env names are listed
            assert "VllmGenerationWorker" in captured.out
            assert "DTensorPolicyWorker" in captured.out
            assert "MegatronPolicyWorker" in captured.out
            # Verify skipped env names are listed
            assert "MathEnvironment" in captured.out
            assert "CodeEnvironment" in captured.out
            # "Skipped (filtered out)" should not appear when no filters
            assert "Skipped (filtered out)" not in captured.out

    def test_prefetch_venvs_summary_with_filters(self, prefetch_venvs_func, capsys):
        """Test that summary includes filtered out names when filters are used."""
        with patch(self.CREATE_VENV_TARGET) as mock_create_venv:
            mock_create_venv.return_value = self.FAKE_PYTHON

            prefetch_venvs_func(filters=["vllm"])

            captured = capsys.readouterr()
            assert "Venv prefetching complete! Summary:" in captured.out
            assert "Prefetched: 1" in captured.out
            # The filter is applied before the system-Python check, so the
            # environment actors are counted as filtered out instead.
            assert "Skipped (system Python): 0" in captured.out
            assert "Skipped (filtered out): 4" in captured.out
            # Verify prefetched env name is listed
            assert "VllmGenerationWorker" in captured.out
            # Verify filtered out env names are listed
            assert "DTensorPolicyWorker" in captured.out
            assert "MegatronPolicyWorker" in captured.out

    def test_prefetch_venvs_summary_with_failures(self, prefetch_venvs_func, capsys):
        """Test that summary reports prefetched and failed counts when errors occur."""
        with patch(self.CREATE_VENV_TARGET) as mock_create_venv:
            # Provide enough return values for both prefetch and frozen env symlinks
            mock_create_venv.side_effect = [
                Exception("Test error"),
                self.FAKE_PYTHON,
                self.FAKE_PYTHON,
            ] * CALL_MULTIPLIER

            prefetch_venvs_func(filters=None)

            captured = capsys.readouterr()
            assert "Venv prefetching complete! Summary:" in captured.out
            assert "Prefetched: 2" in captured.out
            assert "Failed: 1" in captured.out

0 commit comments

Comments
 (0)