|
8 | 8 | import json |
9 | 9 | import logging |
10 | 10 | import os |
| 11 | +import re |
11 | 12 | import tempfile |
12 | 13 | import time |
13 | 14 | from dataclasses import dataclass, field |
14 | 15 | from datetime import datetime |
15 | 16 | from shutil import copy2, rmtree |
16 | | -from typing import Any, cast, Dict, Iterable, List, Mapping, Optional, Set, Type # noqa |
| 17 | +from typing import Any, cast, Dict, Iterable, List, Optional, Tuple # noqa |
17 | 18 |
|
18 | 19 | from torchx.schedulers.api import ( |
19 | 20 | AppDryRunInfo, |
@@ -322,13 +323,25 @@ def wait_until_finish(self, app_id: str, timeout: int = 30) -> None: |
322 | 323 | break |
323 | 324 | time.sleep(1) |
324 | 325 |
|
325 | | - def _cancel_existing(self, app_id: str) -> None: # pragma: no cover |
| 326 | + def _parse_app_id(self, app_id: str) -> Tuple[str, str]: |
| 327 | + # find index of '-' in the first :\d+- |
| 328 | + m = re.search(r":\d+-", app_id) |
| 329 | + if m: |
| 330 | + sep = m.span()[1] |
| 331 | + addr = app_id[: sep - 1] |
| 332 | + app_id = app_id[sep:] |
| 333 | + return addr, app_id |
| 334 | + |
326 | 335 | addr, _, app_id = app_id.partition("-") |
| 336 | + return addr, app_id |
| 337 | + |
| 338 | + def _cancel_existing(self, app_id: str) -> None: # pragma: no cover |
| 339 | + addr, app_id = self._parse_app_id(app_id) |
327 | 340 | client = JobSubmissionClient(f"http://{addr}") |
328 | 341 | client.stop_job(app_id) |
329 | 342 |
|
330 | 343 | def _get_job_status(self, app_id: str) -> JobStatus: |
331 | | - addr, _, app_id = app_id.partition("-") |
| 344 | + addr, app_id = self._parse_app_id(app_id) |
332 | 345 | client = JobSubmissionClient(f"http://{addr}") |
333 | 346 | status = client.get_job_status(app_id) |
334 | 347 | if isinstance(status, str): |
@@ -375,7 +388,7 @@ def log_iter( |
375 | 388 | streams: Optional[Stream] = None, |
376 | 389 | ) -> Iterable[str]: |
377 | 390 | # TODO: support tailing, streams etc.. |
378 | | - addr, _, app_id = app_id.partition("-") |
| 391 | + addr, app_id = self._parse_app_id(app_id) |
379 | 392 | client: JobSubmissionClient = JobSubmissionClient(f"http://{addr}") |
380 | 393 | logs: str = client.get_job_logs(app_id) |
381 | 394 | iterator = split_lines(logs) |
|
0 commit comments