|
| 1 | +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +import os |
| 15 | +from unittest.mock import patch |
| 16 | + |
| 17 | +import pytest |
| 18 | + |
| 19 | +import nemo_rl.utils.prefetch_venvs as prefetch_venvs_module |
| 20 | + |
| 21 | +# When NRL_CONTAINER is set, create_frozen_environment_symlinks also calls |
| 22 | +# create_local_venv for each actor, effectively doubling the call count |
| 23 | +CALL_MULTIPLIER = 2 if os.environ.get("NRL_CONTAINER") else 1 |
| 24 | + |
| 25 | + |
| 26 | +@pytest.fixture |
| 27 | +def mock_registry(): |
| 28 | + """Create a mock registry with various actor types.""" |
| 29 | + return { |
| 30 | + "nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker": "uv run --group vllm", |
| 31 | + "nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker": "uv run --group vllm", |
| 32 | + "nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker": "uv run --group mcore", |
| 33 | + "nemo_rl.environments.math_environment.MathEnvironment": "python", |
| 34 | + "nemo_rl.environments.code_environment.CodeEnvironment": "python", |
| 35 | + } |
| 36 | + |
| 37 | + |
| 38 | +@pytest.fixture |
| 39 | +def prefetch_venvs_func(mock_registry): |
| 40 | + """Patch the registry directly in the prefetch_venvs module.""" |
| 41 | + with patch.object( |
| 42 | + prefetch_venvs_module, "ACTOR_ENVIRONMENT_REGISTRY", mock_registry |
| 43 | + ): |
| 44 | + yield prefetch_venvs_module.prefetch_venvs |
| 45 | + |
| 46 | + |
| 47 | +class TestPrefetchVenvs: |
| 48 | + """Tests for the prefetch_venvs function.""" |
| 49 | + |
| 50 | + def test_prefetch_venvs_no_filters(self, prefetch_venvs_func): |
| 51 | + """Test that all uv-based venvs are prefetched when no filters are provided.""" |
| 52 | + with patch( |
| 53 | + "nemo_rl.utils.prefetch_venvs.create_local_venv" |
| 54 | + ) as mock_create_venv: |
| 55 | + mock_create_venv.return_value = "/path/to/venv/bin/python" |
| 56 | + |
| 57 | + prefetch_venvs_func(filters=None) |
| 58 | + |
| 59 | + assert mock_create_venv.call_count == 3 * CALL_MULTIPLIER |
| 60 | + |
| 61 | + # Verify the actors that were called |
| 62 | + call_args = [call[0] for call in mock_create_venv.call_args_list] |
| 63 | + actor_fqns = [args[1] for args in call_args] |
| 64 | + |
| 65 | + assert ( |
| 66 | + "nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker" |
| 67 | + in actor_fqns |
| 68 | + ) |
| 69 | + assert ( |
| 70 | + "nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker" |
| 71 | + in actor_fqns |
| 72 | + ) |
| 73 | + assert ( |
| 74 | + "nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker" |
| 75 | + in actor_fqns |
| 76 | + ) |
| 77 | + |
| 78 | + def test_prefetch_venvs_single_filter(self, prefetch_venvs_func): |
| 79 | + """Test filtering with a single filter string.""" |
| 80 | + with patch( |
| 81 | + "nemo_rl.utils.prefetch_venvs.create_local_venv" |
| 82 | + ) as mock_create_venv: |
| 83 | + mock_create_venv.return_value = "/path/to/venv/bin/python" |
| 84 | + |
| 85 | + prefetch_venvs_func(filters=["vllm"]) |
| 86 | + |
| 87 | + # Should only create venvs for actors containing "vllm" (1 actor) |
| 88 | + assert mock_create_venv.call_count == 1 * CALL_MULTIPLIER |
| 89 | + |
| 90 | + call_args = mock_create_venv.call_args[0] |
| 91 | + assert ( |
| 92 | + call_args[1] |
| 93 | + == "nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker" |
| 94 | + ) |
| 95 | + |
| 96 | + def test_prefetch_venvs_multiple_filters(self, prefetch_venvs_func): |
| 97 | + """Test filtering with multiple filter strings (OR logic).""" |
| 98 | + with patch( |
| 99 | + "nemo_rl.utils.prefetch_venvs.create_local_venv" |
| 100 | + ) as mock_create_venv: |
| 101 | + mock_create_venv.return_value = "/path/to/venv/bin/python" |
| 102 | + |
| 103 | + prefetch_venvs_func(filters=["vllm", "megatron"]) |
| 104 | + |
| 105 | + # Should create venvs for actors containing "vllm" OR "megatron" (2 actors) |
| 106 | + assert mock_create_venv.call_count == 2 * CALL_MULTIPLIER |
| 107 | + |
| 108 | + call_args = [call[0] for call in mock_create_venv.call_args_list] |
| 109 | + actor_fqns = [args[1] for args in call_args] |
| 110 | + |
| 111 | + assert ( |
| 112 | + "nemo_rl.models.generation.vllm.vllm_worker.VllmGenerationWorker" |
| 113 | + in actor_fqns |
| 114 | + ) |
| 115 | + assert ( |
| 116 | + "nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker" |
| 117 | + in actor_fqns |
| 118 | + ) |
| 119 | + |
| 120 | + def test_prefetch_venvs_filter_no_match(self, prefetch_venvs_func): |
| 121 | + """Test that no venvs are created when filter matches nothing.""" |
| 122 | + with patch( |
| 123 | + "nemo_rl.utils.prefetch_venvs.create_local_venv" |
| 124 | + ) as mock_create_venv: |
| 125 | + mock_create_venv.return_value = "/path/to/venv/bin/python" |
| 126 | + |
| 127 | + prefetch_venvs_func(filters=["nonexistent"]) |
| 128 | + |
| 129 | + # Should not create any venvs |
| 130 | + assert mock_create_venv.call_count == 0 |
| 131 | + |
| 132 | + def test_prefetch_venvs_skips_system_python(self, prefetch_venvs_func): |
| 133 | + """Test that system python actors are skipped even if they match filters.""" |
| 134 | + with patch( |
| 135 | + "nemo_rl.utils.prefetch_venvs.create_local_venv" |
| 136 | + ) as mock_create_venv: |
| 137 | + mock_create_venv.return_value = "/path/to/venv/bin/python" |
| 138 | + |
| 139 | + # Filter for "environment" which matches system python actors |
| 140 | + prefetch_venvs_func(filters=["environment"]) |
| 141 | + |
| 142 | + # Should not create any venvs since matching actors use system python |
| 143 | + assert mock_create_venv.call_count == 0 |
| 144 | + |
| 145 | + def test_prefetch_venvs_partial_match(self, prefetch_venvs_func): |
| 146 | + """Test that filter matches partial strings within FQN.""" |
| 147 | + with patch( |
| 148 | + "nemo_rl.utils.prefetch_venvs.create_local_venv" |
| 149 | + ) as mock_create_venv: |
| 150 | + mock_create_venv.return_value = "/path/to/venv/bin/python" |
| 151 | + |
| 152 | + # "policy" should match both dtensor_policy_worker and megatron_policy_worker |
| 153 | + prefetch_venvs_func(filters=["policy"]) |
| 154 | + |
| 155 | + assert mock_create_venv.call_count == 2 * CALL_MULTIPLIER |
| 156 | + |
| 157 | + call_args = [call[0] for call in mock_create_venv.call_args_list] |
| 158 | + actor_fqns = [args[1] for args in call_args] |
| 159 | + |
| 160 | + assert ( |
| 161 | + "nemo_rl.models.policy.workers.dtensor_policy_worker.DTensorPolicyWorker" |
| 162 | + in actor_fqns |
| 163 | + ) |
| 164 | + assert ( |
| 165 | + "nemo_rl.models.policy.workers.megatron_policy_worker.MegatronPolicyWorker" |
| 166 | + in actor_fqns |
| 167 | + ) |
| 168 | + |
| 169 | + def test_prefetch_venvs_empty_filter_list(self, prefetch_venvs_func): |
| 170 | + """Test that empty filter list is treated as no filtering (falsy).""" |
| 171 | + with patch( |
| 172 | + "nemo_rl.utils.prefetch_venvs.create_local_venv" |
| 173 | + ) as mock_create_venv: |
| 174 | + mock_create_venv.return_value = "/path/to/venv/bin/python" |
| 175 | + |
| 176 | + # Empty list should be falsy and prefetch all |
| 177 | + prefetch_venvs_func(filters=[]) |
| 178 | + |
| 179 | + # Should create venvs for all uv-based actors (3 total) |
| 180 | + assert mock_create_venv.call_count == 3 * CALL_MULTIPLIER |
| 181 | + |
| 182 | + def test_prefetch_venvs_continues_on_error(self, prefetch_venvs_func): |
| 183 | + """Test that prefetching continues even if one venv creation fails.""" |
| 184 | + with patch( |
| 185 | + "nemo_rl.utils.prefetch_venvs.create_local_venv" |
| 186 | + ) as mock_create_venv: |
| 187 | + # Provide enough return values for both prefetch and frozen env symlinks |
| 188 | + mock_create_venv.side_effect = [ |
| 189 | + Exception("Test error"), |
| 190 | + "/path/to/venv/bin/python", |
| 191 | + "/path/to/venv/bin/python", |
| 192 | + ] * CALL_MULTIPLIER |
| 193 | + |
| 194 | + # Should not raise, should continue with other venvs |
| 195 | + prefetch_venvs_func(filters=None) |
| 196 | + |
| 197 | + # All 3 uv-based actors should have been attempted |
| 198 | + assert mock_create_venv.call_count == 3 * CALL_MULTIPLIER |
| 199 | + |
| 200 | + def test_prefetch_venvs_case_sensitive_filter(self, prefetch_venvs_func): |
| 201 | + """Test that filters are case-sensitive.""" |
| 202 | + with patch( |
| 203 | + "nemo_rl.utils.prefetch_venvs.create_local_venv" |
| 204 | + ) as mock_create_venv: |
| 205 | + mock_create_venv.return_value = "/path/to/venv/bin/python" |
| 206 | + |
| 207 | + # "VLLM" (uppercase) should not match "vllm" (lowercase) |
| 208 | + prefetch_venvs_func(filters=["VLLM"]) |
| 209 | + |
| 210 | + assert mock_create_venv.call_count == 0 |
| 211 | + |
| 212 | + def test_prefetch_venvs_summary_no_filters(self, prefetch_venvs_func, capsys): |
| 213 | + """Test that summary is printed with correct counts and names when no filters.""" |
| 214 | + with patch( |
| 215 | + "nemo_rl.utils.prefetch_venvs.create_local_venv" |
| 216 | + ) as mock_create_venv: |
| 217 | + mock_create_venv.return_value = "/path/to/venv/bin/python" |
| 218 | + |
| 219 | + prefetch_venvs_func(filters=None) |
| 220 | + |
| 221 | + captured = capsys.readouterr() |
| 222 | + assert "Venv prefetching complete! Summary:" in captured.out |
| 223 | + assert "Prefetched: 3" in captured.out |
| 224 | + assert "Skipped (system Python): 2" in captured.out |
| 225 | + # Verify prefetched env names are listed |
| 226 | + assert "VllmGenerationWorker" in captured.out |
| 227 | + assert "DTensorPolicyWorker" in captured.out |
| 228 | + assert "MegatronPolicyWorker" in captured.out |
| 229 | + # Verify skipped env names are listed |
| 230 | + assert "MathEnvironment" in captured.out |
| 231 | + assert "CodeEnvironment" in captured.out |
| 232 | + # "Skipped (filtered out)" should not appear when no filters |
| 233 | + assert "Skipped (filtered out)" not in captured.out |
| 234 | + |
| 235 | + def test_prefetch_venvs_summary_with_filters(self, prefetch_venvs_func, capsys): |
| 236 | + """Test that summary includes filtered out names when filters are used.""" |
| 237 | + with patch( |
| 238 | + "nemo_rl.utils.prefetch_venvs.create_local_venv" |
| 239 | + ) as mock_create_venv: |
| 240 | + mock_create_venv.return_value = "/path/to/venv/bin/python" |
| 241 | + |
| 242 | + prefetch_venvs_func(filters=["vllm"]) |
| 243 | + |
| 244 | + captured = capsys.readouterr() |
| 245 | + assert "Venv prefetching complete! Summary:" in captured.out |
| 246 | + assert "Prefetched: 1" in captured.out |
| 247 | + assert "Skipped (system Python): 0" in captured.out |
| 248 | + assert "Skipped (filtered out): 4" in captured.out |
| 249 | + # Verify prefetched env name is listed |
| 250 | + assert "VllmGenerationWorker" in captured.out |
| 251 | + # Verify filtered out env names are listed |
| 252 | + assert "DTensorPolicyWorker" in captured.out |
| 253 | + assert "MegatronPolicyWorker" in captured.out |
| 254 | + |
| 255 | + def test_prefetch_venvs_summary_with_failures(self, prefetch_venvs_func, capsys): |
| 256 | + """Test that summary includes failed actor names when errors occur.""" |
| 257 | + with patch( |
| 258 | + "nemo_rl.utils.prefetch_venvs.create_local_venv" |
| 259 | + ) as mock_create_venv: |
| 260 | + # Provide enough return values for both prefetch and frozen env symlinks |
| 261 | + mock_create_venv.side_effect = [ |
| 262 | + Exception("Test error"), |
| 263 | + "/path/to/venv/bin/python", |
| 264 | + "/path/to/venv/bin/python", |
| 265 | + ] * CALL_MULTIPLIER |
| 266 | + |
| 267 | + prefetch_venvs_func(filters=None) |
| 268 | + |
| 269 | + captured = capsys.readouterr() |
| 270 | + assert "Venv prefetching complete! Summary:" in captured.out |
| 271 | + assert "Prefetched: 2" in captured.out |
| 272 | + assert "Failed: 1" in captured.out |
0 commit comments