diff --git a/.vscode/settings.json b/.vscode/settings.json index 49ae8905..64f5414c 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -19,6 +19,7 @@ "python.testing.pytestArgs": [ "test" ], + "flake8.enabled": false, "python.testing.unittestEnabled": false, "python.testing.pytestEnabled": true } diff --git a/nemo_run/core/execution/dgxcloud.py b/nemo_run/core/execution/dgxcloud.py index 1973d7f8..69a7e29c 100644 --- a/nemo_run/core/execution/dgxcloud.py +++ b/nemo_run/core/execution/dgxcloud.py @@ -17,13 +17,15 @@ import json import logging import os +import queue import subprocess import tempfile +import threading import time from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from typing import Any, Optional, Type +from typing import Any, Iterable, Optional import requests from invoke.context import Context @@ -65,6 +67,7 @@ class DGXCloudExecutor(Executor): """ base_url: str + kube_apiserver_url: str app_id: str app_secret: str project_name: str @@ -359,6 +362,92 @@ def status(self, job_id: str) -> Optional[DGXCloudState]: r_json = response.json() return DGXCloudState(r_json["phase"]) + def _stream_url_sync(self, url: str, headers: dict, q: queue.Queue): + """Stream a single URL using requests and put chunks into the queue""" + try: + with requests.get(url, stream=True, headers=headers, verify=False) as response: + for line in response.iter_lines(decode_unicode=True): + q.put((url, f"{line}\n")) + except Exception as e: + logger.error(f"Error streaming URL {url}: {e}") + + finally: + q.put((url, None)) + + def fetch_logs( + self, + job_id: str, + stream: bool, + stderr: Optional[bool] = None, + stdout: Optional[bool] = None, + ) -> Iterable[str]: + token = self.get_auth_token() + if not token: + logger.error("Failed to retrieve auth token for fetch logs request.") + yield "" + + response = requests.get( + f"{self.base_url}/workloads", headers=self._default_headers(token=token) + ) + workload_name = next( + ( + workload["name"] + for workload in response.json()["workloads"] + if workload["id"] == job_id + ), + None, + ) + if workload_name is None: + logger.error(f"No workload found with id {job_id}") + yield "" + + urls = [ + f"{self.kube_apiserver_url}/api/v1/namespaces/runai-{self.project_name}/pods/{workload_name}-worker-{i}/log?container=pytorch" + for i in range(self.nodes) + ] + + if stream: + urls = [url + "&follow=true" for url in urls] + + while self.status(job_id) != DGXCloudState.RUNNING: + logger.info("Waiting for job to start...") + time.sleep(15) + + time.sleep(10) + + q = queue.Queue() + active_urls = set(urls) + + # Start threads + threads = [ + threading.Thread( + target=self._stream_url_sync, args=(url, self._default_headers(token=token), q) + ) + for url in urls + ] + for t in threads: + t.start() + + # Yield chunks as they arrive + while active_urls: + url, item = q.get() + if item is None or self.status(job_id) in [ + DGXCloudState.DELETING, + DGXCloudState.STOPPED, + DGXCloudState.STOPPING, + DGXCloudState.DEGRADED, + DGXCloudState.FAILED, + DGXCloudState.COMPLETED, + DGXCloudState.TERMINATING, + ]: + active_urls.discard(url) + else: + yield item + + # Wait for threads + for t in threads: + t.join() + def cancel(self, job_id: str): # Retrieve the authentication token for the REST calls token = self.get_auth_token() @@ -385,12 +474,6 @@ def cancel(self, job_id: str): response.text, ) - @classmethod - def logs(cls: Type["DGXCloudExecutor"], app_id: str, fallback_path: Optional[str]): - logger.warning( - "Logs not available for DGXCloudExecutor based jobs. Please visit the cluster UI to view the logs." - ) - def cleanup(self, handle: str): ... def assign( diff --git a/nemo_run/run/logs.py b/nemo_run/run/logs.py index a2091fad..cda7c6f9 100644 --- a/nemo_run/run/logs.py +++ b/nemo_run/run/logs.py @@ -30,9 +30,7 @@ from nemo_run.core.execution.base import LogSupportedExecutor from nemo_run.core.frontend.console.api import CONSOLE from nemo_run.run.torchx_backend.runner import Runner, get_runner -from nemo_run.run.torchx_backend.schedulers.api import ( - REVERSE_EXECUTOR_MAPPING, -) +from nemo_run.run.torchx_backend.schedulers.api import REVERSE_EXECUTOR_MAPPING logger: logging.Logger = logging.getLogger(__name__) @@ -60,6 +58,8 @@ def print_log_lines( role_name, replica_id, regex, + None, + None, should_tail=should_tail, streams=streams, ): diff --git a/nemo_run/run/torchx_backend/schedulers/dgxcloud.py b/nemo_run/run/torchx_backend/schedulers/dgxcloud.py index 0865da58..4377ec71 100644 --- a/nemo_run/run/torchx_backend/schedulers/dgxcloud.py +++ b/nemo_run/run/torchx_backend/schedulers/dgxcloud.py @@ -19,8 +19,9 @@ import shutil import tempfile from dataclasses import dataclass +from datetime import datetime from pathlib import Path -from typing import Any, Optional +from typing import Any, Iterable, Optional import fiddle as fdl import fiddle._src.experimental.dataclasses as fdl_dc @@ -29,15 +30,10 @@ DescribeAppResponse, ListAppResponse, Scheduler, + Stream, + split_lines, ) -from torchx.specs import ( - AppDef, - AppState, - ReplicaStatus, - Role, - RoleStatus, - runopts, -) +from torchx.specs import AppDef, AppState, ReplicaStatus, Role, RoleStatus, runopts from nemo_run.config import get_nemorun_home from nemo_run.core.execution.base import Executor @@ -189,6 +185,36 @@ def describe(self, app_id: str) -> Optional[DescribeAppResponse]: ui_url=f"{executor.base_url}/workloads/distributed/{job_id}", ) + def log_iter( + self, + app_id: str, + role_name: str, + k: int = 0, + regex: Optional[str] = None, + since: Optional[datetime] = None, + until: Optional[datetime] = None, + should_tail: bool = False, + streams: Optional[Stream] = None, + ) -> Iterable[str]: + stored_data = _get_job_dirs() + job_info = stored_data.get(app_id) + _, _, job_id = app_id.split("___") + executor: Optional[DGXCloudExecutor] = job_info.get("executor", None) # type: ignore + if not executor: + return [""] + + logs = executor.fetch_logs( + job_id=job_id, + stream=should_tail, + ) # type: ignore + if isinstance(logs, str): + if len(logs) == 0: + logs = [] + else: + logs = split_lines(logs) + + return logs + def _cancel_existing(self, app_id: str) -> None: """ Cancels the job by calling the DGXExecutor's cancel method. diff --git a/test/core/execution/test_dgxcloud.py b/test/core/execution/test_dgxcloud.py index 4d431e3c..265c3377 100644 --- a/test/core/execution/test_dgxcloud.py +++ b/test/core/execution/test_dgxcloud.py @@ -13,12 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os import subprocess import tempfile from unittest.mock import MagicMock, mock_open, patch import pytest +import requests from nemo_run.config import set_nemorun_home from nemo_run.core.execution.dgxcloud import DGXCloudExecutor, DGXCloudState @@ -29,6 +31,7 @@ class TestDGXCloudExecutor: def test_init(self): executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -58,6 +61,7 @@ def test_get_auth_token_success(self, mock_post): executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -86,6 +90,7 @@ def test_get_auth_token_failure(self, mock_post): executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -97,6 +102,172 @@ def test_get_auth_token_failure(self, mock_post): assert token is None + def test_fetch_no_token(self, caplog): + with ( + patch.object(DGXCloudExecutor, "get_auth_token", return_value=None), + caplog.at_level(logging.ERROR), + ): + executor = DGXCloudExecutor( + base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", + app_id="test_app_id", + app_secret="test_app_secret", + project_name="test_project", + container_image="nvcr.io/nvidia/test:latest", + pvc_nemo_run_dir="/workspace/nemo_run", + ) + + logs_iter = executor.fetch_logs("123", stream=True) + assert next(logs_iter) == "" + assert ( + caplog.records[-1].message + == "Failed to retrieve auth token for fetch logs request." + ) + assert caplog.records[-1].levelname == "ERROR" + caplog.clear() + + @patch("nemo_run.core.execution.dgxcloud.requests.get") + def test_fetch_no_workload_with_name(self, mock_requests_get, caplog): + mock_workloads_response = MagicMock(spec=requests.Response) + mock_workloads_response.json.return_value = { + "workloads": [{"name": "hello-world", "id": "123"}] + } + + mock_requests_get.side_effect = [mock_workloads_response] + + with ( + patch.object(DGXCloudExecutor, "get_auth_token", return_value="test_token"), + caplog.at_level(logging.ERROR), + ): + executor = DGXCloudExecutor( + base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", + app_id="test_app_id", + app_secret="test_app_secret", + project_name="test_project", + container_image="nvcr.io/nvidia/test:latest", + pvc_nemo_run_dir="/workspace/nemo_run", + ) + + logs_iter = executor.fetch_logs("this-workload-does-not-exist", stream=True) + assert next(logs_iter) == "" + assert ( + caplog.records[-1].message + == "No workload found with id this-workload-does-not-exist" + ) + assert caplog.records[-1].levelname == "ERROR" + caplog.clear() + + @patch("nemo_run.core.execution.dgxcloud.requests.get") + @patch("nemo_run.core.execution.dgxcloud.time.sleep") + @patch("nemo_run.core.execution.dgxcloud.threading.Thread") + def test_fetch_logs(self, mock_threading_Thread, mock_sleep, mock_requests_get): + # --- 1. Setup Primitives for the *live* test --- + mock_log_response = MagicMock(spec=requests.Response) + + mock_log_response.iter_lines.return_value = iter( + ["this is a static log", "this is the last static log"] + ) + mock_log_response.__enter__.return_value = mock_log_response + + # Mock for the '/workloads' call + mock_workloads_response = MagicMock(spec=requests.Response) + mock_workloads_response.json.return_value = { + "workloads": [{"name": "hello-world", "id": "123"}] + } + + mock_queue_instance = MagicMock() + mock_queue_instance.get.side_effect = [ + ( + "https://127.0.0.1:443/api/v1/namespaces/runai-test_project/pods/hello-world-worker-0/log?container=pytorch&follow=true", + "this is a static log\n", + ), + ( + "https://127.0.0.1:443/api/v1/namespaces/runai-test_project/pods/hello-world-worker-0/log?container=pytorch&follow=true", + None, + ), + ( + "https://127.0.0.1:443/api/v1/namespaces/runai-test_project/pods/hello-world-worker-1/log?container=pytorch&follow=true", + None, + ), + ] + + mock_requests_get.side_effect = [mock_workloads_response, mock_log_response] + + # --- 4. Setup Executor (inside the patch) --- + with ( + patch.object(DGXCloudExecutor, "get_auth_token", return_value="test_token"), + patch.object(DGXCloudExecutor, "status", return_value=DGXCloudState.RUNNING), + patch("nemo_run.core.execution.dgxcloud.queue.Queue", return_value=mock_queue_instance), + ): + executor = DGXCloudExecutor( + base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", + app_id="test_app_id", + app_secret="test_app_secret", + project_name="test_project", + container_image="nvcr.io/nvidia/test:latest", + pvc_nemo_run_dir="/workspace/nemo_run", + nodes=2, + ) + + logs_iter = executor.fetch_logs("123", stream=True) + + assert next(logs_iter) == "this is a static log\n" + + mock_sleep.assert_called_once_with(10) + + mock_threading_Thread.assert_any_call( + target=executor._stream_url_sync, + args=( + "https://127.0.0.1:443/api/v1/namespaces/runai-test_project/pods/hello-world-worker-0/log?container=pytorch&follow=true", + executor._default_headers(token="test_token"), + mock_queue_instance, + ), + ) + mock_threading_Thread.assert_any_call( + target=executor._stream_url_sync, + args=( + "https://127.0.0.1:443/api/v1/namespaces/runai-test_project/pods/hello-world-worker-1/log?container=pytorch&follow=true", + executor._default_headers(token="test_token"), + mock_queue_instance, + ), + ) + with pytest.raises(StopIteration): + next(logs_iter) + + @patch("nemo_run.core.execution.dgxcloud.requests.get") + def test__stream_url_sync(self, mock_requests_get): + # --- 1. Setup Primitives for the *live* test --- + mock_log_response = MagicMock(spec=requests.Response) + + mock_log_response.iter_lines.return_value = iter( + ["this is a static log", "this is the last static log"] + ) + mock_log_response.__enter__.return_value = mock_log_response + + mock_requests_get.side_effect = [mock_log_response] + + mock_queue_instance = MagicMock() + + with patch( + "nemo_run.core.execution.dgxcloud.queue.Queue", return_value=mock_queue_instance + ): + executor = DGXCloudExecutor( + base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", + app_id="test_app_id", + app_secret="test_app_secret", + project_name="test_project", + container_image="nvcr.io/nvidia/test:latest", + pvc_nemo_run_dir="/workspace/nemo_run", + nodes=2, + ) + + executor._stream_url_sync("123", "some-headers", mock_queue_instance) + + mock_queue_instance.put.assert_any_call(("123", "this is a static log\n")) + @patch("requests.get") def test_get_project_and_cluster_id_success(self, mock_get): mock_response = MagicMock() @@ -105,6 +276,7 @@ def test_get_project_and_cluster_id_success(self, mock_get): executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -131,6 +303,7 @@ def test_get_project_and_cluster_id_not_found(self, mock_get): executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -151,6 +324,7 @@ def test_copy_directory_data_command_success(self, mock_file, mock_subprocess): executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -180,6 +354,7 @@ def test_copy_directory_data_command_fails(self, mock_tempdir): executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -202,6 +377,7 @@ def test_create_data_mover_workload_success(self, mock_command, mock_post): executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -236,6 +412,7 @@ def test_delete_workload(self, mock_delete): executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -269,6 +446,7 @@ def test_move_data_success(self, mock_delete, mock_status, mock_create, mock_sle executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -295,6 +473,7 @@ def test_move_data_data_mover_fail(self, mock_create, mock_sleep): executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -319,6 +498,7 @@ def test_move_data_failed(self, mock_status, mock_create, mock_sleep): executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -343,6 +523,7 @@ def test_create_training_job_single_node(self, mock_post): executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -400,6 +581,7 @@ def test_create_training_job_multi_node(self, mock_post): executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -467,6 +649,7 @@ def test_launch_single_node( with tempfile.TemporaryDirectory() as tmp_dir: executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -513,6 +696,7 @@ def test_launch_multi_node(self, mock_create_job, mock_move_data, mock_get_ids, with tempfile.TemporaryDirectory() as tmp_dir: executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -551,6 +735,7 @@ def test_launch_no_token(self, mock_get_token): executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -569,6 +754,7 @@ def test_launch_no_project_id(self, mock_get_ids, mock_get_token): executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -596,6 +782,7 @@ def test_launch_job_creation_failed( with tempfile.TemporaryDirectory() as tmp_dir: executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -610,6 +797,7 @@ def test_launch_job_creation_failed( def test_nnodes(self): executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -623,6 +811,7 @@ def test_nnodes(self): def test_nproc_per_node_with_gpus(self): executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -636,6 +825,7 @@ def test_nproc_per_node_with_gpus(self): def test_nproc_per_node_with_nprocs(self): executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -650,6 +840,7 @@ def test_nproc_per_node_with_nprocs(self): def test_nproc_per_node_default(self): executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -671,6 +862,7 @@ def test_status(self, mock_get): with patch.object(DGXCloudExecutor, "get_auth_token", return_value="test_token"): executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -691,6 +883,7 @@ def test_status_no_token(self, mock_get): with patch.object(DGXCloudExecutor, "get_auth_token", return_value=None): executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -712,6 +905,7 @@ def test_status_error_response(self, mock_get): with patch.object(DGXCloudExecutor, "get_auth_token", return_value="test_token"): executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -732,6 +926,7 @@ def test_cancel(self, mock_get): with patch.object(DGXCloudExecutor, "get_auth_token", return_value="test_token"): executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -751,6 +946,7 @@ def test_cancel_no_token(self, mock_get): with patch.object(DGXCloudExecutor, "get_auth_token", return_value=None): executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -762,17 +958,12 @@ def test_cancel_no_token(self, mock_get): mock_get.assert_not_called() - def test_logs(self): - with patch("logging.Logger.warning") as mock_warning: - DGXCloudExecutor.logs("app123", "/path/to/fallback") - mock_warning.assert_called_once() - assert "Logs not available" in mock_warning.call_args[0][0] - def test_assign(self): set_nemorun_home("/nemo_home") executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -802,6 +993,7 @@ def test_assign_no_pvc(self): with tempfile.TemporaryDirectory() as tmp_dir: executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -823,6 +1015,7 @@ def test_assign_no_pvc(self): def test_package_configs(self, mock_file, mock_makedirs): executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -855,6 +1048,7 @@ def test_package_git_packager(self, mock_subprocess_run, mock_context_run): with tempfile.TemporaryDirectory() as tmp_dir: executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -881,6 +1075,7 @@ def test_package_git_packager(self, mock_subprocess_run, mock_context_run): def test_macro_values(self): executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -895,6 +1090,7 @@ def test_macro_values(self): def test_default_headers_without_token(self): executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -911,6 +1107,7 @@ def test_default_headers_without_token(self): def test_default_headers_with_token(self): executor = DGXCloudExecutor( base_url="https://dgxapi.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_app_secret", project_name="test_project", @@ -925,3 +1122,4 @@ def test_default_headers_with_token(self): assert headers["Content-Type"] == "application/json" assert "Authorization" in headers assert headers["Authorization"] == "Bearer test_token" + assert headers["Authorization"] == "Bearer test_token" diff --git a/test/run/torchx_backend/schedulers/test_dgxcloud.py b/test/run/torchx_backend/schedulers/test_dgxcloud.py index 5a3f9ae9..ca25b92b 100644 --- a/test/run/torchx_backend/schedulers/test_dgxcloud.py +++ b/test/run/torchx_backend/schedulers/test_dgxcloud.py @@ -15,16 +15,14 @@ import tempfile from unittest import mock +from unittest.mock import MagicMock import pytest from torchx.schedulers.api import AppDryRunInfo from torchx.specs import AppDef, Role from nemo_run.core.execution.dgxcloud import DGXCloudExecutor -from nemo_run.run.torchx_backend.schedulers.dgxcloud import ( - DGXCloudScheduler, - create_scheduler, -) +from nemo_run.run.torchx_backend.schedulers.dgxcloud import DGXCloudScheduler, create_scheduler @pytest.fixture @@ -38,6 +36,7 @@ def mock_app_def(): def dgx_cloud_executor(): return DGXCloudExecutor( base_url="https://dgx.example.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_app_id", app_secret="test_secret", project_name="test_project", @@ -147,6 +146,7 @@ def test_save_and_get_job_dirs(): executor = DGXCloudExecutor( base_url="https://test.com", + kube_apiserver_url="https://127.0.0.1:443", app_id="test_id", app_secret="test_secret", project_name="test_project", @@ -160,3 +160,47 @@ def test_save_and_get_job_dirs(): assert "test_app_id" in job_dirs assert isinstance(job_dirs["test_app_id"]["executor"], DGXCloudExecutor) + + +def test_log_iter(dgx_cloud_scheduler, dgx_cloud_executor): + with mock.patch( + "nemo_run.run.torchx_backend.schedulers.dgxcloud._get_job_dirs" + ) as mock_get_job_dirs: + mock_get_job_dirs.return_value = { + "test_session___test_role___test_container_id": { + "job_status": "RUNNING", + "executor": dgx_cloud_executor, + } + } + + dgx_cloud_executor.fetch_logs = MagicMock() + dgx_cloud_executor.fetch_logs.return_value = ["log2", "log3"] + + logs = list( + dgx_cloud_scheduler.log_iter( + "test_session___test_role___test_container_id", "test_role" + ) + ) + assert logs == ["log2", "log3"] + + +def test_log_iter_str(dgx_cloud_scheduler, dgx_cloud_executor): + with mock.patch( + "nemo_run.run.torchx_backend.schedulers.dgxcloud._get_job_dirs" + ) as mock_get_job_dirs: + mock_get_job_dirs.return_value = { + "test_session___test_role___test_container_id": { + "job_status": "RUNNING", + "executor": dgx_cloud_executor, + } + } + + dgx_cloud_executor.fetch_logs = MagicMock() + dgx_cloud_executor.fetch_logs.return_value = "log2\nlog3" + + logs = list( + dgx_cloud_scheduler.log_iter( + "test_session___test_role___test_container_id", "test_role" + ) + ) + assert logs == ["log2\n", "log3"]