|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +import pytest |
15 | 16 | import mars |
| 17 | + |
16 | 18 | from .... import tensor as mt |
17 | 19 | from .... import dataframe as md |
18 | 20 | from ....tests.core import require_ray |
19 | 21 | from ....utils import lazy_import |
20 | 22 | from ..ray import ( |
21 | 23 | new_cluster_in_ray, |
22 | 24 | new_ray_session, |
| 25 | + _load_config, |
| 26 | + new_cluster, |
23 | 27 | ) |
24 | 28 |
|
25 | 29 | ray = lazy_import("ray") |
@@ -64,3 +68,35 @@ def new_ray_session_test(): |
64 | 68 |
|
65 | 69 | gc.collect() |
66 | 70 | mars.execute(mt.random.RandomState(0).rand(100, 5).sum()) |
| 71 | + |
| 72 | + |
| 73 | +@require_ray |
| 74 | +@pytest.mark.parametrize( |
| 75 | + "test_option", |
| 76 | + [ |
| 77 | + [True, 0, ["ray://test_cluster/1/0", "ray://test_cluster/2/0"]], |
| 78 | + [False, 0, ["ray://test_cluster/0/1", "ray://test_cluster/1/0"]], |
| 79 | + [True, 2, ["ray://test_cluster/1/0", "ray://test_cluster/2/0"]], |
| 80 | + [False, 5, ["ray://test_cluster/0/6", "ray://test_cluster/1/0"]], |
| 81 | + ], |
| 82 | +) |
| 83 | +@pytest.mark.asyncio |
| 84 | +async def test_optional_supervisor_node(ray_start_regular, test_option): |
| 85 | + import logging |
| 86 | + |
| 87 | + logging.basicConfig(level=logging.INFO) |
| 88 | + supervisor_standalone, supervisor_sub_pool_num, worker_addresses = test_option |
| 89 | + config = _load_config() |
| 90 | + config["cluster"]["ray"]["supervisor"]["standalone"] = supervisor_standalone |
| 91 | + config["cluster"]["ray"]["supervisor"]["sub_pool_num"] = supervisor_sub_pool_num |
| 92 | + client = await new_cluster( |
| 93 | + "test_cluster", |
| 94 | + supervisor_mem=1 * 1024**3, |
| 95 | + worker_num=2, |
| 96 | + worker_cpu=2, |
| 97 | + worker_mem=1 * 1024**3, |
| 98 | + config=config, |
| 99 | + ) |
| 100 | + async with client: |
| 101 | + assert client.address == "ray://test_cluster/0/0" |
| 102 | + assert client._cluster._worker_addresses == worker_addresses |
0 commit comments