|
14 | 14 | # limitations under the License. |
15 | 15 |
|
16 | 16 | import base64 |
| 17 | +import glob |
17 | 18 | import json |
18 | 19 | import logging |
19 | 20 | import os |
20 | | -import queue |
21 | 21 | import subprocess |
22 | 22 | import tempfile |
23 | | -import threading |
24 | 23 | import time |
25 | 24 | from dataclasses import dataclass, field |
26 | 25 | from enum import Enum |
@@ -323,7 +322,8 @@ def launch(self, name: str, cmd: list[str]) -> tuple[str, str]: |
323 | 322 | launch_script = f""" |
324 | 323 | ln -s {self.pvc_job_dir}/ /nemo_run |
325 | 324 | cd /nemo_run/code |
326 | | -{" ".join(cmd)} |
| 325 | +mkdir -p {self.pvc_job_dir}/logs |
| 326 | +{" ".join(cmd)} 2>&1 | tee -a {self.pvc_job_dir}/logs/output-$HOSTNAME.log |
327 | 327 | """ |
328 | 328 | with open(os.path.join(self.job_dir, "launch_script.sh"), "w+") as f: |
329 | 329 | f.write(launch_script) |
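For context, a minimal sketch of what the updated template renders to, assuming a hypothetical PVC job directory and command (both illustrative, not taken from this change): each pod appends its output to its own output-$HOSTNAME.log under the PVC logs directory, which the reworked fetch_logs below picks up with tail.

pvc_job_dir = "/pvc/nemo-run/experiments/job-0"  # hypothetical PVC path, not from this change
cmd = ["python", "train.py"]  # hypothetical command, not from this change
launch_script = f"""
ln -s {pvc_job_dir}/ /nemo_run
cd /nemo_run/code
mkdir -p {pvc_job_dir}/logs
{" ".join(cmd)} 2>&1 | tee -a {pvc_job_dir}/logs/output-$HOSTNAME.log
"""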
@@ -371,91 +371,66 @@ def status(self, job_id: str) -> Optional[DGXCloudState]: |
371 | 371 | r_json = response.json() |
372 | 372 | return DGXCloudState(r_json["phase"]) |
373 | 373 |
|
374 | | - def _stream_url_sync(self, url: str, headers: dict, q: queue.Queue): |
375 | | - """Stream a single URL using requests and put chunks into the queue""" |
376 | | - try: |
377 | | - with requests.get(url, stream=True, headers=headers, verify=False) as response: |
378 | | - for line in response.iter_lines(decode_unicode=True): |
379 | | - q.put((url, f"{line}\n")) |
380 | | - except Exception as e: |
381 | | - logger.error(f"Error streaming URL {url}: {e}") |
382 | | - |
383 | | - finally: |
384 | | - q.put((url, None)) |
385 | | - |
386 | 374 | def fetch_logs( |
387 | 375 | self, |
388 | 376 | job_id: str, |
389 | 377 | stream: bool, |
390 | 378 | stderr: Optional[bool] = None, |
391 | 379 | stdout: Optional[bool] = None, |
392 | 380 | ) -> Iterable[str]: |
393 | | - token = self.get_auth_token() |
394 | | - if not token: |
395 | | - logger.error("Failed to retrieve auth token for fetch logs request.") |
396 | | - yield "" |
397 | | - |
398 | | - response = requests.get( |
399 | | - f"{self.base_url}/workloads", headers=self._default_headers(token=token) |
400 | | - ) |
401 | | - workload_name = next( |
402 | | - ( |
403 | | - workload["name"] |
404 | | - for workload in response.json()["workloads"] |
405 | | - if workload["id"] == job_id |
406 | | - ), |
407 | | - None, |
408 | | - ) |
409 | | - if workload_name is None: |
410 | | - logger.error(f"No workload found with id {job_id}") |
411 | | - yield "" |
| 381 | + while self.status(job_id) != DGXCloudState.RUNNING: |
| 382 | + logger.info("Waiting for job to start...") |
| 383 | + time.sleep(15) |
412 | 384 |
|
413 | | - urls = [ |
414 | | - f"{self.kube_apiserver_url}/api/v1/namespaces/runai-{self.project_name}/pods/{workload_name}-worker-{i}/log?container=pytorch" |
415 | | - for i in range(self.nodes) |
416 | | - ] |
| 385 | + cmd = ["tail"] |
417 | 386 |
|
418 | 387 | if stream: |
419 | | - urls = [url + "&follow=true" for url in urls] |
| 388 | + cmd.append("-f") |
420 | 389 |
|
421 | | - while self.status(job_id) != DGXCloudState.RUNNING: |
422 | | - logger.info("Waiting for job to start...") |
423 | | - time.sleep(15) |
| 390 | + # resolve the job directory path on the linked PVC
| 391 | + nemo_run_home = get_nemorun_home() |
| 392 | + job_subdir = self.job_dir[len(nemo_run_home) + 1 :]  # +1 to drop the leading path separator
| 393 | + self.pvc_job_dir = os.path.join(self.pvc_nemo_run_dir, job_subdir) |
424 | 394 |
|
425 | | - time.sleep(10) |
| 395 | + files = glob.glob(f"{self.pvc_job_dir}/logs/output-*.log")
| 396 | + while len(files) < self.nodes:
| 397 | + logger.info(f"Waiting for {self.nodes - len(files)} log files to be created...")
| 398 | + time.sleep(3)
| 399 | + files = glob.glob(f"{self.pvc_job_dir}/logs/output-*.log")
426 | 400 |
|
427 | | - q = queue.Queue() |
428 | | - active_urls = set(urls) |
| 401 | + cmd.extend(files) |
429 | 402 |
|
430 | | - # Start threads |
431 | | - threads = [ |
432 | | - threading.Thread( |
433 | | - target=self._stream_url_sync, args=(url, self._default_headers(token=token), q) |
434 | | - ) |
435 | | - for url in urls |
436 | | - ] |
437 | | - for t in threads: |
438 | | - t.start() |
439 | | - |
440 | | - # Yield chunks as they arrive |
441 | | - while active_urls: |
442 | | - url, item = q.get() |
443 | | - if item is None or self.status(job_id) in [ |
444 | | - DGXCloudState.DELETING, |
445 | | - DGXCloudState.STOPPED, |
446 | | - DGXCloudState.STOPPING, |
447 | | - DGXCloudState.DEGRADED, |
448 | | - DGXCloudState.FAILED, |
449 | | - DGXCloudState.COMPLETED, |
450 | | - DGXCloudState.TERMINATING, |
451 | | - ]: |
452 | | - active_urls.discard(url) |
453 | | - else: |
454 | | - yield item |
455 | | - |
456 | | - # Wait for threads |
457 | | - for t in threads: |
458 | | - t.join() |
| 403 | + logger.info(f"Attempting to stream logs with command: {cmd}") |
| 404 | + |
| 405 | + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, text=True, bufsize=1) |
| 406 | + |
| 407 | + if stream: |
| 408 | + while True: |
| 409 | + try: |
| 410 | + for line in iter(proc.stdout.readline, ""): |
| 411 | + # skip tail's "==> file <==" headers and the blank lines around them
| 412 | + if (
| 413 | + line.rstrip("\n") != ""
| 414 | + and not line.rstrip("\n").endswith(".log <==")
| 415 | + ):
| 416 | + yield line
| 417 | + if proc.poll() is not None: |
| 418 | + break |
| 419 | + except Exception as e: |
| 420 | + logger.error(f"Error streaming logs: {e}") |
| 421 | + time.sleep(3) |
| 422 | + continue |
| 423 | + |
| 424 | + else: |
| 425 | + try: |
| 426 | + for line in iter(proc.stdout.readline, ""): |
| 427 | + if line: |
| 428 | + yield line.rstrip("\n") |
| 429 | + if proc.poll() is not None: |
| 430 | + break |
| 431 | + finally: |
| 432 | + proc.terminate() |
| 433 | + proc.wait(timeout=2) |
459 | 434 |
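A minimal usage sketch of the reworked fetch_logs, where executor stands for a configured instance of this executor class and job_id for the id returned by launch(); both names are placeholders for illustration, not part of this change.

# executor and job_id are hypothetical placeholders (see note above).
for line in executor.fetch_logs(job_id, stream=True):
    print(line, end="")  # streamed lines keep their trailing newline

With stream=False the generator instead yields the current contents of the per-node log files with newlines stripped, and the finally block terminates the tail process once the caller finishes iterating.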
|
460 | 435 | def cancel(self, job_id: str): |
461 | 436 | # Retrieve the authentication token for the REST calls |
|