Skip to content

Commit b9a1d0d

Browse files
authored
add option for programmatically defining ray job client (#762)
Signed-off-by: Kevin <[email protected]>
1 parent dd2db49 commit b9a1d0d

File tree

2 files changed

+75
-21
lines changed

2 files changed

+75
-21
lines changed

torchx/schedulers/ray_scheduler.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
from dataclasses import dataclass, field
1515
from datetime import datetime
1616
from shutil import copy2, rmtree
17-
from typing import Any, cast, Dict, Iterable, List, Optional, Tuple # noqa
17+
from typing import Any, cast, Dict, Final, Iterable, List, Optional, Tuple # noqa
18+
19+
import urllib3
1820

1921
from torchx.schedulers.api import (
2022
AppDryRunInfo,
@@ -148,9 +150,35 @@ class RayScheduler(TmpDirWorkspaceMixin, Scheduler[RayOpts]):
148150
149151
"""
150152

151-
def __init__(self, session_name: str) -> None:
153+
def __init__(
    self,
    session_name: str,
    ray_client: Optional[JobSubmissionClient] = None,
) -> None:
    """Create a Ray scheduler, optionally bound to a caller-supplied job client.

    Args:
        session_name: forwarded to the base scheduler constructor.
        ray_client: pre-built ``JobSubmissionClient`` to use for Ray calls.
            When ``None``, ``_get_ray_client`` constructs a client on demand.
    """
    super().__init__("ray", session_name)

    # Declared Final so that pyre treats the attribute as immutable; without
    # it, the ``is not None`` check in _get_ray_client does not narrow the type.
    self._ray_client: Final[Optional[JobSubmissionClient]] = ray_client
160+
161+
def _get_ray_client(
    self, job_submission_netloc: Optional[str] = None
) -> JobSubmissionClient:
    """Return the job-submission client to use for a Ray request.

    Resolution order:
      1. the client attached to this scheduler, if any;
      2. a new client for the ``RAY_ADDRESS`` environment variable, if it is
         set to a non-empty value;
      3. a new client for ``http://{job_submission_netloc}``.

    Args:
        job_submission_netloc: ``host[:port]`` the job is associated with.
            Optional; when given together with an attached client it must
            equal that client's netloc.

    Raises:
        ValueError: an attached client exists and its netloc differs from
            ``job_submission_netloc``.
        Exception: no attached client, no ``RAY_ADDRESS`` and no
            ``job_submission_netloc`` were provided.
    """
    attached_client = self._ray_client
    if attached_client is not None:
        # Guard against silently talking to a different cluster than the
        # one the job id / config refers to.
        client_netloc = urllib3.util.parse_url(attached_client.get_address()).netloc
        if job_submission_netloc and job_submission_netloc != client_netloc:
            raise ValueError(
                f"client netloc ({client_netloc}) does not match job netloc ({job_submission_netloc})"
            )
        return attached_client

    env_address = os.getenv("RAY_ADDRESS")
    if env_address:
        return JobSubmissionClient(env_address)

    if job_submission_netloc:
        return JobSubmissionClient(f"http://{job_submission_netloc}")

    raise Exception(
        "RAY_ADDRESS env variable or a scheduler with an attached Ray JobSubmissionClient is expected."
        " See https://docs.ray.io/en/latest/cluster/jobs-package-ref.html#job-submission-sdk for more info"
    )
181+
154182
# TODO: Add address as a potential CLI argument after writing ray.status() or passing in config file
155183
def _run_opts(self) -> runopts:
156184
opts = runopts()
@@ -196,9 +224,7 @@ def schedule(self, dryrun_info: AppDryRunInfo[RayJob]) -> str:
196224
)
197225

198226
# 0. Create Job Client
199-
client: JobSubmissionClient = JobSubmissionClient(
200-
f"http://{job_submission_addr}"
201-
)
227+
client = self._get_ray_client(job_submission_netloc=job_submission_addr)
202228

203229
# 1. Copy Ray driver utilities
204230
current_directory = os.path.dirname(os.path.abspath(__file__))
@@ -341,12 +367,12 @@ def _parse_app_id(self, app_id: str) -> Tuple[str, str]:
341367

342368
def _cancel_existing(self, app_id: str) -> None: # pragma: no cover
343369
addr, app_id = self._parse_app_id(app_id)
344-
client = JobSubmissionClient(f"http://{addr}")
370+
client = self._get_ray_client(job_submission_netloc=addr)
345371
client.stop_job(app_id)
346372

347373
def _get_job_status(self, app_id: str) -> JobStatus:
348374
addr, app_id = self._parse_app_id(app_id)
349-
client = JobSubmissionClient(f"http://{addr}")
375+
client = self._get_ray_client(job_submission_netloc=addr)
350376
status = client.get_job_status(app_id)
351377
if isinstance(status, str):
352378
return cast(JobStatus, status)
@@ -393,26 +419,22 @@ def log_iter(
393419
) -> Iterable[str]:
394420
# TODO: support tailing, streams etc..
395421
addr, app_id = self._parse_app_id(app_id)
396-
client: JobSubmissionClient = JobSubmissionClient(f"http://{addr}")
422+
client: JobSubmissionClient = self._get_ray_client(
423+
job_submission_netloc=addr
424+
)
397425
logs: str = client.get_job_logs(app_id)
398426
iterator = split_lines(logs)
399427
if regex:
400428
return filter_regex(regex, iterator)
401429
return iterator
402430

403431
def list(self) -> List[ListAppResponse]:
404-
address = os.getenv("RAY_ADDRESS")
405-
if not address:
406-
raise Exception(
407-
"RAY_ADDRESS env variable is expected to be set to list jobs on ray scheduler."
408-
" See https://docs.ray.io/en/latest/cluster/jobs-package-ref.html#job-submission-sdk for more info"
409-
)
410-
client = JobSubmissionClient(address)
432+
client = self._get_ray_client()
411433
jobs = client.list_jobs()
412-
ip = address.split("http://", 1)[-1]
434+
netloc = urllib3.util.parse_url(client.get_address()).netloc
413435
return [
414436
ListAppResponse(
415-
app_id=f"{ip}-{details.submission_id}",
437+
app_id=f"{netloc}-{details.submission_id}",
416438
state=_ray_status_to_torchx_appstate[details.status],
417439
)
418440
for details in jobs

torchx/schedulers/test/ray_scheduler_test.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from shutil import copy2
1212
from typing import Any, cast, Iterable, Iterator, List, Optional, Type
1313
from unittest import TestCase
14-
from unittest.mock import patch
14+
from unittest.mock import MagicMock, patch
1515

1616
from torchx.schedulers import get_scheduler_factories
1717
from torchx.schedulers.api import AppDryRunInfo, DescribeAppResponse, ListAppResponse
@@ -22,6 +22,7 @@
2222
if has_ray():
2323
import ray
2424
from ray.cluster_utils import Cluster
25+
from ray.dashboard.modules.job.sdk import JobSubmissionClient
2526
from ray.util.placement_group import remove_placement_group
2627
from torchx.schedulers.ray import ray_driver
2728
from torchx.schedulers.ray_scheduler import (
@@ -83,6 +84,9 @@ def setUp(self) -> None:
8384
}
8485
)
8586

87+
# mock validation step so that instantiation doesn't fail due to inability to reach dashboard
88+
JobSubmissionClient._check_connection_and_version = MagicMock()
89+
8690
self._scheduler = RayScheduler("test_session")
8791

8892
self._isfile_patch = patch("torchx.schedulers.ray_scheduler.os.path.isfile")
@@ -320,11 +324,15 @@ def test_parse_app_id(self) -> None:
320324
def test_list_throws_without_address(self) -> None:
321325
if "RAY_ADDRESS" in os.environ:
322326
del os.environ["RAY_ADDRESS"]
323-
with self.assertRaisesRegex(
324-
Exception, "RAY_ADDRESS env variable is expected"
325-
):
327+
with self.assertRaisesRegex(Exception, "RAY_ADDRESS env variable"):
326328
self._scheduler.list()
327329

330+
def test_list_doesnt_throw_with_client(self) -> None:
    """list() must not raise when the scheduler was built with an explicit client."""
    client = JobSubmissionClient(address="https://test.com")
    # Stub the listing call so no dashboard has to be reachable.
    client.list_jobs = MagicMock(return_value=[])

    scheduler = RayScheduler("client_session", client)

    # Success here is simply the absence of an exception.
    scheduler.list()
335+
328336
def test_min_replicas(self) -> None:
329337
app = AppDef(
330338
name="app",
@@ -358,6 +366,30 @@ def test_min_replicas(self) -> None:
358366
):
359367
self._scheduler._submit_dryrun(app, cfg={})
360368

369+
def test_nonmatching_address(self) -> None:
    """submit() raises ValueError when the attached client's netloc and the job's netloc disagree."""
    scheduler = RayScheduler(
        "client_session",
        JobSubmissionClient(address="https://test.address.com"),
    )
    single_role_app = AppDef(name="app", roles=[Role(name="role", image=".")])

    expected_message = "client netloc .* does not match job netloc .*"
    with self.assertRaisesRegex(ValueError, expected_message):
        scheduler.submit(app=single_role_app, cfg={})
382+
383+
def test_client_with_headers(self) -> None:
    """The scheduler hands back the supplied client with its custom headers intact."""
    # This tests only one option for the client. Different versions may have more options available.
    headers = {"Authorization": "Bearer: token"}
    ray_client = JobSubmissionClient(
        address="https://test.com", headers=headers, verify=False
    )
    _scheduler_with_client = RayScheduler("client_session", ray_client)
    scheduler_client = _scheduler_with_client._get_ray_client()
    # assertDictContainsSubset is deprecated (removed in Python 3.12) and its
    # signature is (subset, dictionary); the previous call passed the
    # arguments the other way round. Comparing item views checks the intended
    # direction: every expected header must be present on the client, while
    # the client is free to carry additional headers.
    self.assertLessEqual(headers.items(), scheduler_client._headers.items())
392+
361393
class RayClusterSetup:
362394
_instance = None # pyre-ignore
363395
_cluster = None # pyre-ignore

0 commit comments

Comments
 (0)