Skip to content

Commit da191b4

Browse files
authored
feat: introduce a debug API for backoff and retries for RayVirtualCluster (#234)
Signed-off-by: Terry Kong <[email protected]>
1 parent 8780093 commit da191b4

File tree

2 files changed

+107
-3
lines changed

2 files changed

+107
-3
lines changed

nemo_reinforcer/distributed/virtual_cluster.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import os
2020
import ray
2121
import logging
22+
import time
2223
from ray.util.placement_group import placement_group, remove_placement_group
2324
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
2425

@@ -101,6 +102,10 @@ def init_ray(log_dir: Optional[str] = None):
101102
logger.info(f"Started local cluster with: {ray.cluster_resources()}")
102103

103104

105+
class ResourceInsufficientError(Exception):
106+
"""Exception raised when the cluster does not have enough resources to satisfy the requested configuration."""
107+
108+
104109
class RayVirtualCluster:
105110
"""Creates a virtual distributed cluster using Ray placement groups.
106111
@@ -146,7 +151,25 @@ def __init__(
146151
)
147152
self.max_colocated_worker_groups = max_colocated_worker_groups
148153
self.name = name
149-
self._init_placement_groups(placement_group_strategy)
154+
max_retries = int(os.environ.get("NRL_VIRTUAL_CLUSTER_MAX_RETRIES", 6))
155+
assert max_retries > 0, (
156+
f"NRL_VIRTUAL_CLUSTER_MAX_RETRIES={max_retries} must be an integer greater than 0"
157+
)
158+
for i in range(max_retries):
159+
try:
160+
self._init_placement_groups(placement_group_strategy)
161+
# Reaching here means we were successful
162+
break
163+
except ResourceInsufficientError:
164+
print(
165+
f"Retrying placement group creation... {i + 1}/{max_retries}. Next retry in {2**i} seconds."
166+
)
167+
time.sleep(2**i)
168+
continue
169+
else:
170+
raise ResourceInsufficientError(
171+
f"Maximum number of retries reached ({max_retries}). Cluster resources may be insufficient or cluster itself is highly unstable. Please check your cluster configuration and your cluster logs."
172+
)
150173

151174
def _init_placement_groups(self, strategy: str):
152175
"""Creates placement groups for each node in the cluster. Has empty groups for nodes that don't have any bundles.
@@ -175,12 +198,12 @@ def _init_placement_groups(self, strategy: str):
175198

176199
# Validate resources
177200
if self.use_gpus and total_requested_gpus > total_available_gpus:
178-
raise ValueError(
201+
raise ResourceInsufficientError(
179202
f"Not enough GPUs available. Requested {total_requested_gpus} GPUs, but only {total_available_gpus} are available in the cluster."
180203
)
181204

182205
if total_requested_cpus > total_available_cpus:
183-
raise ValueError(
206+
raise ResourceInsufficientError(
184207
f"Not enough CPUs available. Requested {total_requested_cpus} CPUs, but only {total_available_cpus} are available in the cluster."
185208
)
186209

tests/unit/distributed/test_virtual_cluster.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,14 @@
1414
from nemo_reinforcer.distributed.virtual_cluster import (
1515
_get_node_ip_and_free_port,
1616
PY_EXECUTABLES,
17+
RayVirtualCluster,
18+
ResourceInsufficientError,
1719
)
1820
import ray
21+
import pytest
22+
import os
23+
from unittest.mock import patch, MagicMock
24+
import importlib
1925

2026

2127
def test_get_node_ip_and_free_port_does_not_start_with_zero():
@@ -30,3 +36,78 @@ def test_get_node_ip_and_free_port_does_not_start_with_zero():
3036
).remote()
3137
)
3238
assert not node_ip.startswith("0."), "Node IP should not start with 0.*.*.*"
39+
40+
41+
def test_env_max_retries_invalid_value():
42+
"""Test that NRL_VIRTUAL_CLUSTER_MAX_RETRIES rejects invalid values (less than or equal to zero)."""
43+
44+
# Mock environment with invalid max_retries value
45+
env_vars = {"NRL_VIRTUAL_CLUSTER_MAX_RETRIES": "0"}
46+
47+
with patch.dict(os.environ, env_vars, clear=True):
48+
with pytest.raises(AssertionError):
49+
RayVirtualCluster(bundle_ct_per_node_list=[1])
50+
51+
52+
def test_env_max_retries_non_integer():
53+
"""Test that NRL_VIRTUAL_CLUSTER_MAX_RETRIES handles non-integer values properly."""
54+
55+
# Mock environment with non-integer max_retries value
56+
env_vars = {"NRL_VIRTUAL_CLUSTER_MAX_RETRIES": "not_a_number"}
57+
58+
with patch.dict(os.environ, env_vars, clear=True):
59+
with pytest.raises(ValueError):
60+
RayVirtualCluster(bundle_ct_per_node_list=[1])
61+
62+
63+
def test_env_max_retries_default_value():
64+
"""Test that default value for NRL_VIRTUAL_CLUSTER_MAX_RETRIES is used when not set."""
65+
66+
# Ensure environment variable is not set
67+
with (
68+
patch.dict(os.environ, {}, clear=True),
69+
patch(
70+
"nemo_reinforcer.distributed.virtual_cluster.RayVirtualCluster._init_placement_groups"
71+
) as mock_init,
72+
):
73+
# Mock successful initialization
74+
mock_init.return_value = [MagicMock()]
75+
76+
# Create cluster
77+
cluster = RayVirtualCluster(bundle_ct_per_node_list=[1])
78+
79+
# Default value should be 6 (as seen in the code)
80+
# We can't directly verify this, but we can check that initialization was attempted
81+
assert mock_init.call_count == 1
82+
83+
84+
def test_env_max_retries_exhausted():
85+
"""Test that NRL_VIRTUAL_CLUSTER_MAX_RETRIES correctly handles the case where all retries fail."""
86+
87+
# Set specific retry count to 4
88+
retry_count = 4
89+
env_vars = {"NRL_VIRTUAL_CLUSTER_MAX_RETRIES": str(retry_count)}
90+
91+
with (
92+
patch.dict(os.environ, env_vars, clear=True),
93+
patch(
94+
"nemo_reinforcer.distributed.virtual_cluster.RayVirtualCluster._init_placement_groups"
95+
) as mock_init,
96+
patch("time.sleep") as mock_sleep,
97+
):
98+
# Make _init_placement_groups raise ResourceInsufficientError each time
99+
mock_init.side_effect = ResourceInsufficientError("Not enough resources")
100+
101+
# Create cluster - should retry retry_count times and then fail
102+
with pytest.raises(ResourceInsufficientError):
103+
RayVirtualCluster(bundle_ct_per_node_list=[1])
104+
105+
# Verify _init_placement_groups was called exactly retry_count times
106+
assert mock_init.call_count == retry_count
107+
108+
# Verify time.sleep was called with exponentially increasing values
109+
assert mock_sleep.call_count == retry_count
110+
mock_sleep.assert_any_call(1) # 2^0
111+
mock_sleep.assert_any_call(2) # 2^1
112+
mock_sleep.assert_any_call(4) # 2^2
113+
mock_sleep.assert_any_call(8) # 2^3

0 commit comments

Comments
 (0)