|
14 | 14 | from dataclasses import dataclass, field |
15 | 15 | from datetime import datetime |
16 | 16 | from shutil import copy2, rmtree |
17 | | -from typing import Any, cast, Dict, Iterable, List, Optional, Tuple # noqa |
| 17 | +from typing import Any, cast, Dict, Final, Iterable, List, Optional, Tuple # noqa |
| 18 | + |
| 19 | +import urllib3 |
18 | 20 |
|
19 | 21 | from torchx.schedulers.api import ( |
20 | 22 | AppDryRunInfo, |
@@ -148,9 +150,35 @@ class RayScheduler(TmpDirWorkspaceMixin, Scheduler[RayOpts]): |
148 | 150 |
|
149 | 151 | """ |
150 | 152 |
|
151 | | - def __init__(self, session_name: str) -> None: |
| 153 | + def __init__( |
| 154 | + self, session_name: str, ray_client: Optional[JobSubmissionClient] = None |
| 155 | + ) -> None: |
152 | 156 | super().__init__("ray", session_name) |
153 | 157 |
|
| 158 | + # w/o Final None check in _get_ray_client does not work as it pyre assumes mutability |
| 159 | + self._ray_client: Final[Optional[JobSubmissionClient]] = ray_client |
| 160 | + |
| 161 | + def _get_ray_client( |
| 162 | + self, job_submission_netloc: Optional[str] = None |
| 163 | + ) -> JobSubmissionClient: |
| 164 | + if self._ray_client is not None: |
| 165 | + client_netloc = urllib3.util.parse_url( |
| 166 | + self._ray_client.get_address() |
| 167 | + ).netloc |
| 168 | + if job_submission_netloc and job_submission_netloc != client_netloc: |
| 169 | + raise ValueError( |
| 170 | + f"client netloc ({client_netloc}) does not match job netloc ({job_submission_netloc})" |
| 171 | + ) |
| 172 | + return self._ray_client |
| 173 | + elif os.getenv("RAY_ADDRESS"): |
| 174 | + return JobSubmissionClient(os.getenv("RAY_ADDRESS")) |
| 175 | + elif not job_submission_netloc: |
| 176 | + raise Exception( |
| 177 | + "RAY_ADDRESS env variable or a scheduler with an attached Ray JobSubmissionClient is expected." |
| 178 | + " See https://docs.ray.io/en/latest/cluster/jobs-package-ref.html#job-submission-sdk for more info" |
| 179 | + ) |
| 180 | + return JobSubmissionClient(f"http://{job_submission_netloc}") |
| 181 | + |
154 | 182 | # TODO: Add address as a potential CLI argument after writing ray.status() or passing in config file |
155 | 183 | def _run_opts(self) -> runopts: |
156 | 184 | opts = runopts() |
@@ -196,9 +224,7 @@ def schedule(self, dryrun_info: AppDryRunInfo[RayJob]) -> str: |
196 | 224 | ) |
197 | 225 |
|
198 | 226 | # 0. Create Job Client |
199 | | - client: JobSubmissionClient = JobSubmissionClient( |
200 | | - f"http://{job_submission_addr}" |
201 | | - ) |
| 227 | + client = self._get_ray_client(job_submission_netloc=job_submission_addr) |
202 | 228 |
|
203 | 229 | # 1. Copy Ray driver utilities |
204 | 230 | current_directory = os.path.dirname(os.path.abspath(__file__)) |
@@ -341,12 +367,12 @@ def _parse_app_id(self, app_id: str) -> Tuple[str, str]: |
341 | 367 |
|
342 | 368 | def _cancel_existing(self, app_id: str) -> None: # pragma: no cover |
343 | 369 | addr, app_id = self._parse_app_id(app_id) |
344 | | - client = JobSubmissionClient(f"http://{addr}") |
| 370 | + client = self._get_ray_client(job_submission_netloc=addr) |
345 | 371 | client.stop_job(app_id) |
346 | 372 |
|
347 | 373 | def _get_job_status(self, app_id: str) -> JobStatus: |
348 | 374 | addr, app_id = self._parse_app_id(app_id) |
349 | | - client = JobSubmissionClient(f"http://{addr}") |
| 375 | + client = self._get_ray_client(job_submission_netloc=addr) |
350 | 376 | status = client.get_job_status(app_id) |
351 | 377 | if isinstance(status, str): |
352 | 378 | return cast(JobStatus, status) |
@@ -393,26 +419,22 @@ def log_iter( |
393 | 419 | ) -> Iterable[str]: |
394 | 420 | # TODO: support tailing, streams etc.. |
395 | 421 | addr, app_id = self._parse_app_id(app_id) |
396 | | - client: JobSubmissionClient = JobSubmissionClient(f"http://{addr}") |
| 422 | + client: JobSubmissionClient = self._get_ray_client( |
| 423 | + job_submission_netloc=addr |
| 424 | + ) |
397 | 425 | logs: str = client.get_job_logs(app_id) |
398 | 426 | iterator = split_lines(logs) |
399 | 427 | if regex: |
400 | 428 | return filter_regex(regex, iterator) |
401 | 429 | return iterator |
402 | 430 |
|
403 | 431 | def list(self) -> List[ListAppResponse]: |
404 | | - address = os.getenv("RAY_ADDRESS") |
405 | | - if not address: |
406 | | - raise Exception( |
407 | | - "RAY_ADDRESS env variable is expected to be set to list jobs on ray scheduler." |
408 | | - " See https://docs.ray.io/en/latest/cluster/jobs-package-ref.html#job-submission-sdk for more info" |
409 | | - ) |
410 | | - client = JobSubmissionClient(address) |
| 432 | + client = self._get_ray_client() |
411 | 433 | jobs = client.list_jobs() |
412 | | - ip = address.split("http://", 1)[-1] |
| 434 | + netloc = urllib3.util.parse_url(client.get_address()).netloc |
413 | 435 | return [ |
414 | 436 | ListAppResponse( |
415 | | - app_id=f"{ip}-{details.submission_id}", |
| 437 | + app_id=f"{netloc}-{details.submission_id}", |
416 | 438 | state=_ray_status_to_torchx_appstate[details.status], |
417 | 439 | ) |
418 | 440 | for details in jobs |
|
0 commit comments