1414from nemo_reinforcer .distributed .virtual_cluster import (
1515 _get_node_ip_and_free_port ,
1616 PY_EXECUTABLES ,
17+ RayVirtualCluster ,
18+ ResourceInsufficientError ,
1719)
1820import ray
21+ import pytest
22+ import os
23+ from unittest .mock import patch , MagicMock
24+ import importlib
1925
2026
2127def test_get_node_ip_and_free_port_does_not_start_with_zero ():
@@ -30,3 +36,78 @@ def test_get_node_ip_and_free_port_does_not_start_with_zero():
3036 ).remote ()
3137 )
3238 assert not node_ip .startswith ("0." ), "Node IP should not start with 0.*.*.*"
39+
40+
def test_env_max_retries_invalid_value():
    """A retry limit of "0" in NRL_VIRTUAL_CLUSTER_MAX_RETRIES is rejected.

    The process environment is replaced wholesale (``clear=True``) so the
    invalid retry count is the only variable visible while the cluster is
    built; construction is expected to fail with ``AssertionError``.
    """
    invalid_env = {"NRL_VIRTUAL_CLUSTER_MAX_RETRIES": "0"}

    with (
        patch.dict(os.environ, invalid_env, clear=True),
        pytest.raises(AssertionError),
    ):
        RayVirtualCluster(bundle_ct_per_node_list=[1])
50+
51+
def test_env_max_retries_non_integer():
    """A non-numeric NRL_VIRTUAL_CLUSTER_MAX_RETRIES surfaces as ValueError.

    With the environment cleared down to just the malformed setting,
    constructing a ``RayVirtualCluster`` must raise ``ValueError``.
    """
    with (
        patch.dict(
            os.environ,
            {"NRL_VIRTUAL_CLUSTER_MAX_RETRIES": "not_a_number"},
            clear=True,
        ),
        pytest.raises(ValueError),
    ):
        RayVirtualCluster(bundle_ct_per_node_list=[1])
61+
62+
def test_env_max_retries_default_value():
    """Construction succeeds when NRL_VIRTUAL_CLUSTER_MAX_RETRIES is unset.

    The environment is emptied and ``_init_placement_groups`` is replaced by
    a stub that succeeds immediately. The default retry limit itself is not
    observable from here (the original author's note says it is 6 -- not
    verified by this test); what is checked is that placement-group
    initialisation is attempted exactly once.
    """
    init_target = "nemo_reinforcer.distributed.virtual_cluster.RayVirtualCluster._init_placement_groups"

    # Build both patchers up front; they are only activated by the `with`.
    empty_env = patch.dict(os.environ, {}, clear=True)
    stub_init = patch(init_target, return_value=[MagicMock()])

    with empty_env, stub_init as mock_init:
        # Keep a reference so the cluster object lives until the test ends.
        _cluster = RayVirtualCluster(bundle_ct_per_node_list=[1])

        # A first-try success means no retry should have happened.
        assert mock_init.call_count == 1
82+
83+
def test_env_max_retries_exhausted():
    """All retries failing must end in ResourceInsufficientError.

    NRL_VIRTUAL_CLUSTER_MAX_RETRIES is set to ``retry_count`` and
    ``_init_placement_groups`` is forced to raise
    ``ResourceInsufficientError`` on every call. ``time.sleep`` is patched so
    the back-off costs no wall-clock time and its arguments can be inspected.

    The test checks that:
      * the constructor ultimately raises ``ResourceInsufficientError``;
      * initialisation was attempted exactly ``retry_count`` times;
      * the sleeps were issued, in order, as 2**0, 2**1, ... seconds.
    """
    retry_count = 4
    env_vars = {"NRL_VIRTUAL_CLUSTER_MAX_RETRIES": str(retry_count)}

    with (
        patch.dict(os.environ, env_vars, clear=True),
        patch(
            "nemo_reinforcer.distributed.virtual_cluster.RayVirtualCluster._init_placement_groups"
        ) as mock_init,
        patch("time.sleep") as mock_sleep,
    ):
        # Every placement attempt fails, so the retry budget gets exhausted.
        mock_init.side_effect = ResourceInsufficientError("Not enough resources")

        with pytest.raises(ResourceInsufficientError):
            RayVirtualCluster(bundle_ct_per_node_list=[1])

        assert mock_init.call_count == retry_count

        # Compare the full ordered sequence of sleep arguments instead of
        # using order-insensitive assert_any_call checks: this pins both the
        # number of sleeps and that the delays really grow exponentially.
        # Deriving the schedule from retry_count keeps the expectation in
        # sync if the retry count above is ever changed.
        expected_delays = [(2**attempt,) for attempt in range(retry_count)]
        observed_delays = [recorded.args for recorded in mock_sleep.call_args_list]
        assert observed_delays == expected_delays
0 commit comments