Skip to content

Commit b01d118

Browse files
committed
cleanup and testing
Signed-off-by: oliver könig <[email protected]>
1 parent 69f2b14 commit b01d118

File tree

2 files changed

+172
-164
lines changed

2 files changed

+172
-164
lines changed

nemo_run/core/execution/dgxcloud.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import json
1919
import logging
2020
import os
21-
import queue
2221
import subprocess
2322
import tempfile
2423
import time
@@ -372,18 +371,6 @@ def status(self, job_id: str) -> Optional[DGXCloudState]:
372371
r_json = response.json()
373372
return DGXCloudState(r_json["phase"])
374373

375-
def _stream_url_sync(self, url: str, headers: dict, q: queue.Queue):
376-
"""Stream a single URL using requests and put chunks into the queue"""
377-
try:
378-
with requests.get(url, stream=True, headers=headers, verify=False) as response:
379-
for line in response.iter_lines(decode_unicode=True):
380-
q.put((url, f"{line}\n"))
381-
except Exception as e:
382-
logger.error(f"Error streaming URL {url}: {e}")
383-
384-
finally:
385-
q.put((url, None))
386-
387374
def fetch_logs(
388375
self,
389376
job_id: str,

test/core/execution/test_dgxcloud.py

Lines changed: 172 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,12 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import logging
1716
import os
1817
import subprocess
1918
import tempfile
2019
from unittest.mock import MagicMock, mock_open, patch
2120

2221
import pytest
23-
import requests
2422

2523
from nemo_run.config import set_nemorun_home
2624
from nemo_run.core.execution.dgxcloud import DGXCloudExecutor, DGXCloudState
@@ -83,11 +81,12 @@ def test_get_auth_token_success(self, mock_post):
8381
)
8482

8583
@patch("requests.post")
86-
def test_get_auth_token_failure(self, mock_post):
84+
@patch("time.sleep")
85+
def test_get_auth_token_failure(self, mock_sleep, mock_post):
8786
mock_response = MagicMock()
8887
mock_response.text = '{"error": "Invalid credentials"}'
8988
mock_post.return_value = mock_response
90-
89+
mock_sleep.return_value = None
9190
executor = DGXCloudExecutor(
9291
base_url="https://dgxapi.example.com",
9392
kube_apiserver_url="https://127.0.0.1:443",
@@ -102,171 +101,193 @@ def test_get_auth_token_failure(self, mock_post):
102101

103102
assert token is None
104103

105-
def test_fetch_no_token(self, caplog):
106-
with (
107-
patch.object(DGXCloudExecutor, "get_auth_token", return_value=None),
108-
caplog.at_level(logging.ERROR),
109-
):
110-
executor = DGXCloudExecutor(
111-
base_url="https://dgxapi.example.com",
112-
kube_apiserver_url="https://127.0.0.1:443",
113-
app_id="test_app_id",
114-
app_secret="test_app_secret",
115-
project_name="test_project",
116-
container_image="nvcr.io/nvidia/test:latest",
117-
pvc_nemo_run_dir="/workspace/nemo_run",
118-
)
104+
@patch("glob.glob")
105+
@patch("subprocess.Popen")
106+
@patch("time.sleep")
107+
def test_fetch_logs_streaming(self, mock_sleep, mock_popen, mock_glob):
108+
"""Test fetch_logs in streaming mode."""
109+
set_nemorun_home("/nemo_home")
119110

120-
logs_iter = executor.fetch_logs("123", stream=True)
121-
assert next(logs_iter) == ""
122-
assert (
123-
caplog.records[-1].message
124-
== "Failed to retrieve auth token for fetch logs request."
125-
)
126-
assert caplog.records[-1].levelname == "ERROR"
127-
caplog.clear()
128-
129-
@patch("nemo_run.core.execution.dgxcloud.requests.get")
130-
def test_fetch_no_workload_with_name(self, mock_requests_get, caplog):
131-
mock_workloads_response = MagicMock(spec=requests.Response)
132-
mock_workloads_response.json.return_value = {
133-
"workloads": [{"name": "hello-world", "id": "123"}]
134-
}
135-
136-
mock_requests_get.side_effect = [mock_workloads_response]
137-
138-
with (
139-
patch.object(DGXCloudExecutor, "get_auth_token", return_value="test_token"),
140-
caplog.at_level(logging.ERROR),
141-
):
142-
executor = DGXCloudExecutor(
143-
base_url="https://dgxapi.example.com",
144-
kube_apiserver_url="https://127.0.0.1:443",
145-
app_id="test_app_id",
146-
app_secret="test_app_secret",
147-
project_name="test_project",
148-
container_image="nvcr.io/nvidia/test:latest",
149-
pvc_nemo_run_dir="/workspace/nemo_run",
150-
)
111+
# Mock log files
112+
mock_glob.return_value = [
113+
"/workspace/nemo_run/experiments/exp1/task1/logs/output-worker-0.log",
114+
"/workspace/nemo_run/experiments/exp1/task1/logs/output-worker-1.log",
115+
]
151116

152-
logs_iter = executor.fetch_logs("this-workload-does-not-exist", stream=True)
153-
assert next(logs_iter) == ""
154-
assert (
155-
caplog.records[-1].message
156-
== "No workload found with id this-workload-does-not-exist"
157-
)
158-
assert caplog.records[-1].levelname == "ERROR"
159-
caplog.clear()
160-
161-
@patch("nemo_run.core.execution.dgxcloud.requests.get")
162-
@patch("nemo_run.core.execution.dgxcloud.time.sleep")
163-
@patch("nemo_run.core.execution.dgxcloud.threading.Thread")
164-
def test_fetch_logs(self, mock_threading_Thread, mock_sleep, mock_requests_get):
165-
# --- 1. Setup Primitives for the *live* test ---
166-
mock_log_response = MagicMock(spec=requests.Response)
167-
168-
mock_log_response.iter_lines.return_value = iter(
169-
["this is a static log", "this is the last static log"]
117+
# Mock process that yields log lines
118+
mock_process = MagicMock()
119+
mock_process.stdout.readline.side_effect = [
120+
"Log line 1\n",
121+
"Log line 2\n",
122+
"", # End of stream
123+
]
124+
mock_process.poll.return_value = None
125+
mock_popen.return_value = mock_process
126+
mock_sleep.return_value = None
127+
128+
executor = DGXCloudExecutor(
129+
base_url="https://dgxapi.example.com",
130+
kube_apiserver_url="https://127.0.0.1:443",
131+
app_id="test_app_id",
132+
app_secret="test_app_secret",
133+
project_name="test_project",
134+
container_image="nvcr.io/nvidia/test:latest",
135+
pvc_nemo_run_dir="/workspace/nemo_run",
136+
nodes=2,
170137
)
171-
mock_log_response.__enter__.return_value = mock_log_response
172-
173-
# Mock for the '/workloads' call
174-
mock_workloads_response = MagicMock(spec=requests.Response)
175-
mock_workloads_response.json.return_value = {
176-
"workloads": [{"name": "hello-world", "id": "123"}]
177-
}
178-
179-
mock_queue_instance = MagicMock()
180-
mock_queue_instance.get.side_effect = [
181-
(
182-
"https://127.0.0.1:443/api/v1/namespaces/runai-test_project/pods/hello-world-worker-0/log?container=pytorch&follow=true",
183-
"this is a static log\n",
184-
),
185-
(
186-
"https://127.0.0.1:443/api/v1/namespaces/runai-test_project/pods/hello-world-worker-0/log?container=pytorch&follow=true",
187-
None,
188-
),
189-
(
190-
"https://127.0.0.1:443/api/v1/namespaces/runai-test_project/pods/hello-world-worker-1/log?container=pytorch&follow=true",
191-
None,
192-
),
138+
executor.job_dir = "/nemo_home/experiments/exp1/task1"
139+
140+
with patch.object(executor, "status", return_value=DGXCloudState.RUNNING):
141+
logs_iter = executor.fetch_logs("job123", stream=True)
142+
143+
# Consume first two log lines
144+
log1 = next(logs_iter)
145+
log2 = next(logs_iter)
146+
147+
assert "Log line 1" in log1
148+
assert "Log line 2" in log2
149+
150+
# Verify subprocess was called with tail -f
151+
mock_popen.assert_called_once()
152+
call_args = mock_popen.call_args[0][0]
153+
assert "tail" in call_args
154+
assert "-f" in call_args
155+
156+
@patch("glob.glob")
157+
@patch("subprocess.Popen")
158+
@patch("time.sleep")
159+
def test_fetch_logs_non_streaming(self, mock_sleep, mock_popen, mock_glob):
160+
"""Test fetch_logs in non-streaming mode."""
161+
set_nemorun_home("/nemo_home")
162+
163+
# Mock log files
164+
mock_glob.return_value = [
165+
"/workspace/nemo_run/experiments/exp1/task1/logs/output-worker-0.log",
193166
]
194167

195-
mock_requests_get.side_effect = [mock_workloads_response, mock_log_response]
168+
# Mock process that yields log lines
169+
mock_process = MagicMock()
170+
mock_process.stdout.readline.side_effect = [
171+
"Log line 1\n",
172+
"Log line 2\n",
173+
"", # End of stream
174+
]
175+
mock_process.poll.return_value = None
176+
mock_popen.return_value = mock_process
177+
mock_sleep.return_value = None
196178

197-
# --- 4. Setup Executor (inside the patch) ---
198-
with (
199-
patch.object(DGXCloudExecutor, "get_auth_token", return_value="test_token"),
200-
patch.object(DGXCloudExecutor, "status", return_value=DGXCloudState.RUNNING),
201-
patch("nemo_run.core.execution.dgxcloud.queue.Queue", return_value=mock_queue_instance),
202-
):
203-
executor = DGXCloudExecutor(
204-
base_url="https://dgxapi.example.com",
205-
kube_apiserver_url="https://127.0.0.1:443",
206-
app_id="test_app_id",
207-
app_secret="test_app_secret",
208-
project_name="test_project",
209-
container_image="nvcr.io/nvidia/test:latest",
210-
pvc_nemo_run_dir="/workspace/nemo_run",
211-
nodes=2,
212-
)
179+
executor = DGXCloudExecutor(
180+
base_url="https://dgxapi.example.com",
181+
kube_apiserver_url="https://127.0.0.1:443",
182+
app_id="test_app_id",
183+
app_secret="test_app_secret",
184+
project_name="test_project",
185+
container_image="nvcr.io/nvidia/test:latest",
186+
pvc_nemo_run_dir="/workspace/nemo_run",
187+
nodes=1,
188+
)
189+
executor.job_dir = "/nemo_home/experiments/exp1/task1"
213190

214-
logs_iter = executor.fetch_logs("123", stream=True)
191+
with patch.object(executor, "status", return_value=DGXCloudState.RUNNING):
192+
logs_iter = executor.fetch_logs("job123", stream=False)
215193

216-
assert next(logs_iter) == "this is a static log\n"
194+
# Consume log lines
195+
logs = list(logs_iter)
217196

218-
mock_sleep.assert_called_once_with(10)
197+
assert len(logs) == 2
198+
assert logs[0] == "Log line 1"
199+
assert logs[1] == "Log line 2"
219200

220-
mock_threading_Thread.assert_any_call(
221-
target=executor._stream_url_sync,
222-
args=(
223-
"https://127.0.0.1:443/api/v1/namespaces/runai-test_project/pods/hello-world-worker-0/log?container=pytorch&follow=true",
224-
executor._default_headers(token="test_token"),
225-
mock_queue_instance,
226-
),
227-
)
228-
mock_threading_Thread.assert_any_call(
229-
target=executor._stream_url_sync,
230-
args=(
231-
"https://127.0.0.1:443/api/v1/namespaces/runai-test_project/pods/hello-world-worker-1/log?container=pytorch&follow=true",
232-
executor._default_headers(token="test_token"),
233-
mock_queue_instance,
234-
),
235-
)
236-
with pytest.raises(StopIteration):
237-
next(logs_iter)
201+
# Verify subprocess was called with tail (no -f)
202+
mock_popen.assert_called_once()
203+
call_args = mock_popen.call_args[0][0]
204+
assert "tail" in call_args
205+
assert "-f" not in call_args
238206

239-
@patch("nemo_run.core.execution.dgxcloud.requests.get")
240-
def test__stream_url_sync(self, mock_requests_get):
241-
# --- 1. Setup Primitives for the *live* test ---
242-
mock_log_response = MagicMock(spec=requests.Response)
207+
# Verify process was terminated
208+
mock_process.terminate.assert_called_once()
209+
mock_process.wait.assert_called_once()
243210

244-
mock_log_response.iter_lines.return_value = iter(
245-
["this is a static log", "this is the last static log"]
211+
@patch("time.sleep")
212+
@patch("glob.glob")
213+
def test_fetch_logs_waits_for_running_status(self, mock_glob, mock_sleep):
214+
"""Test that fetch_logs waits for job to be RUNNING."""
215+
set_nemorun_home("/nemo_home")
216+
217+
executor = DGXCloudExecutor(
218+
base_url="https://dgxapi.example.com",
219+
kube_apiserver_url="https://127.0.0.1:443",
220+
app_id="test_app_id",
221+
app_secret="test_app_secret",
222+
project_name="test_project",
223+
container_image="nvcr.io/nvidia/test:latest",
224+
pvc_nemo_run_dir="/workspace/nemo_run",
225+
nodes=1,
246226
)
247-
mock_log_response.__enter__.return_value = mock_log_response
227+
executor.job_dir = "/nemo_home/experiments/exp1/task1"
248228

249-
mock_requests_get.side_effect = [mock_log_response]
229+
# Mock status to return PENDING then RUNNING
230+
status_values = [DGXCloudState.PENDING, DGXCloudState.PENDING, DGXCloudState.RUNNING]
231+
mock_sleep.return_value = None
250232

251-
mock_queue_instance = MagicMock()
233+
with patch.object(executor, "status", side_effect=status_values):
234+
# Mock glob to prevent it from blocking
235+
mock_glob.return_value = ["/workspace/nemo_run/logs/output.log"]
252236

253-
with patch(
254-
"nemo_run.core.execution.dgxcloud.queue.Queue", return_value=mock_queue_instance
255-
):
256-
executor = DGXCloudExecutor(
257-
base_url="https://dgxapi.example.com",
258-
kube_apiserver_url="https://127.0.0.1:443",
259-
app_id="test_app_id",
260-
app_secret="test_app_secret",
261-
project_name="test_project",
262-
container_image="nvcr.io/nvidia/test:latest",
263-
pvc_nemo_run_dir="/workspace/nemo_run",
264-
nodes=2,
265-
)
237+
with patch("subprocess.Popen") as mock_popen:
238+
mock_process = MagicMock()
239+
mock_process.stdout.readline.return_value = ""
240+
mock_process.poll.return_value = 0
241+
mock_popen.return_value = mock_process
242+
243+
logs_iter = executor.fetch_logs("job123", stream=False)
244+
# Consume the iterator to trigger the logic
245+
list(logs_iter)
246+
247+
# Should have slept while waiting for RUNNING status
248+
assert mock_sleep.call_count >= 2
249+
250+
@patch("time.sleep")
251+
@patch("glob.glob")
252+
@patch("subprocess.Popen")
253+
def test_fetch_logs_waits_for_log_files(self, mock_popen, mock_glob, mock_sleep):
254+
"""Test that fetch_logs waits for all log files to be created."""
255+
set_nemorun_home("/nemo_home")
256+
257+
# Mock glob to return incomplete files first, then all files
258+
mock_glob.side_effect = [
259+
[], # No files yet
260+
["/workspace/nemo_run/experiments/exp1/task1/logs/output-worker-0.log"], # 1 of 2
261+
[ # All 2 files
262+
"/workspace/nemo_run/experiments/exp1/task1/logs/output-worker-0.log",
263+
"/workspace/nemo_run/experiments/exp1/task1/logs/output-worker-1.log",
264+
],
265+
]
266+
267+
mock_process = MagicMock()
268+
mock_process.stdout.readline.return_value = ""
269+
mock_process.poll.return_value = 0
270+
mock_popen.return_value = mock_process
271+
mock_sleep.return_value = None
272+
273+
executor = DGXCloudExecutor(
274+
base_url="https://dgxapi.example.com",
275+
kube_apiserver_url="https://127.0.0.1:443",
276+
app_id="test_app_id",
277+
app_secret="test_app_secret",
278+
project_name="test_project",
279+
container_image="nvcr.io/nvidia/test:latest",
280+
pvc_nemo_run_dir="/workspace/nemo_run",
281+
nodes=2, # Expecting 2 log files
282+
)
283+
executor.job_dir = "/nemo_home/experiments/exp1/task1"
266284

267-
executor._stream_url_sync("123", "some-headers", mock_queue_instance)
285+
with patch.object(executor, "status", return_value=DGXCloudState.RUNNING):
286+
logs_iter = executor.fetch_logs("job123", stream=False)
287+
list(logs_iter) # Consume the iterator
268288

269-
mock_queue_instance.put.assert_any_call(("123", "this is a static log\n"))
289+
# Should have called glob multiple times waiting for files
290+
assert mock_glob.call_count == 3
270291

271292
@patch("requests.get")
272293
def test_get_project_and_cluster_id_success(self, mock_get):

0 commit comments

Comments
 (0)