Skip to content

Commit 3f7f638

Browse files
committed
have metrics pushing also work for --no-wait jobs
1 parent 72c5541 commit 3f7f638

File tree

2 files changed

+52
-40
lines changed

2 files changed

+52
-40
lines changed

batchtools/br.py

Lines changed: 37 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# pyright: reportUninitializedInstanceVariable=false
21
from typing import cast
32
from typing_extensions import override, Optional
43

@@ -185,10 +184,9 @@ def run(args: argparse.Namespace):
185184
queue_wait = None
186185
total_wall = None
187186

188-
if args.wait:
189-
result_phase, run_elapsed, queue_wait, total_wall = log_job_output(
190-
job_name=job_name, wait=True, timeout=args.timeout
191-
)
187+
result_phase, run_elapsed, queue_wait, total_wall = log_job_output(
188+
job_name=job_name, wait=True, timeout=args.timeout
189+
)
192190

193191
if (
194192
run_elapsed is not None
@@ -270,48 +268,48 @@ def log_job_output(
270268
queue_wait = None
271269
total_wall = None
272270

273-
if wait:
274-
start_poll = time.monotonic()
275-
while True:
276-
phase = get_pod_status(pod_name)
277-
if phase == "Running" and run_start is None:
278-
# time waiting in queue is time from entering the queue to the time it takes to start running
279-
run_start = time.monotonic()
280-
queue_wait = run_start - start_poll # submit -> running
281-
282-
if phase in ("Succeeded", "Failed"):
283-
result_phase = phase.lower()
284-
total_wall = time.monotonic() - start_poll # submit -> terminal
285-
print(f"Pod {pod_name} finished with phase={phase}")
286-
break
287-
288-
if timeout and (time.monotonic() - start_poll) > timeout:
271+
start_poll = time.monotonic()
272+
while True:
273+
phase = get_pod_status(pod_name)
274+
if phase == "Running" and run_start is None:
275+
# time waiting in queue is time from entering the queue to the time it takes to start running
276+
run_start = time.monotonic()
277+
queue_wait = run_start - start_poll # submit -> running
278+
279+
if phase in ("Succeeded", "Failed"):
280+
result_phase = phase.lower()
281+
total_wall = time.monotonic() - start_poll # submit -> terminal
282+
print(f"Pod {pod_name} finished with phase={phase}")
283+
break
284+
285+
if timeout and (time.monotonic() - start_poll) > timeout:
286+
if wait:
289287
print(f"Timeout waiting for pod {pod_name} to complete")
290288
print(f"Deleting job {job_name}")
291289
oc_delete("job", job_name)
292-
total_wall = time.monotonic() - start_poll
293-
# timeout: no run duration (didn't finish), queue_wait may or may not be set
294-
print_timing(queue_wait, None, total_wall)
295-
return ("timeout", None, queue_wait, total_wall)
290+
total_wall = time.monotonic() - start_poll
291+
# timeout: no run duration (didn't finish), queue_wait may or may not be set
292+
print_timing(queue_wait, None, total_wall)
293+
return ("timeout", None, queue_wait, total_wall)
296294

297-
time.sleep(2)
295+
time.sleep(2)
298296

299297
print(pretty_print(pod))
300298

301299
# compute the runtime as the total time (total_wall) minus the time waiting in queue (queue_wait)
302-
if wait:
303-
if run_start is not None:
304-
# running status was reached, run_elapsed = terminal - run_start
305-
if total_wall is None:
306-
total_wall = time.monotonic() - start_poll # fallback
307-
run_elapsed = total_wall - (queue_wait or 0.0)
308-
else:
309-
# Never reached Running; keep convention for failures:
310-
run_elapsed = 0.0 if result_phase == "failed" else None
311-
if total_wall is None:
312-
total_wall = 0.0
313-
if total_wall is not None and queue_wait is None:
314-
queue_wait = total_wall
300+
301+
if run_start is not None:
302+
# running status was reached, run_elapsed = terminal - run_start
303+
if total_wall is None:
304+
total_wall = time.monotonic() - start_poll # fallback
305+
run_elapsed = total_wall - (queue_wait or 0.0)
306+
else:
307+
# Never reached Running; keep convention for failures:
308+
run_elapsed = 0.0 if result_phase == "failed" else None
309+
if total_wall is None:
310+
total_wall = 0.0
311+
if total_wall is not None and queue_wait is None:
312+
queue_wait = total_wall
315313

316314
print_timing(queue_wait, run_elapsed, total_wall)
317315
return (result_phase, run_elapsed, queue_wait, total_wall)

tests/test_br.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def test_invalid_gpu(args: argparse.Namespace):
4444
),
4545
],
4646
)
47+
@mock.patch("batchtools.br.log_job_output", return_value=("succeeded", 1.0, 0.5, 1.5))
4748
@mock.patch("openshift_client.create", name="create")
4849
@mock.patch("openshift_client.selector", name="selector")
4950
@mock.patch("socket.gethostname", name="gethostname")
@@ -53,13 +54,19 @@ def test_create_job_nowait(
5354
mock_gethostname,
5455
mock_selector,
5556
mock_create,
57+
mock_log_job_output,
5658
gpu,
5759
resources,
5860
args: argparse.Namespace,
5961
tempdir,
6062
parser,
6163
subparsers,
6264
):
65+
"""
66+
Even if args.wait is False, CreateJobCommand.run should still build the
67+
correct Job object and call oc.create() with it. We stub out
68+
log_job_output so this test does not depend on pod selectors or timing.
69+
"""
6370
CreateJobCommand.build_parser(subparsers)
6471
args = parser.parse_args(["br"])
6572
args.wait = False
@@ -82,7 +89,7 @@ def test_create_job_nowait(
8289
}
8390
)
8491

85-
mock_result = mock.Mock(spec=["object"])
92+
mock_result = mock.Mock()
8693
mock_result.object.return_value = pod
8794
mock_selector.return_value = mock_result
8895

@@ -121,10 +128,17 @@ def test_create_job_nowait(
121128
},
122129
}
123130

131+
# Make the rsync_script simple and deterministic for the test
124132
batchtools.build_yaml.rsync_script = "testcommand {cmdline}"
133+
125134
CreateJobCommand.run(args)
126135

136+
# Verify we created the expected Job spec
127137
assert mock_create.call_args.args[0] == expected
138+
# And that log_job_output was called exactly once (even with args.wait=False)
139+
mock_log_job_output.assert_called_once()
140+
called_job_name = mock_log_job_output.call_args.kwargs["job_name"]
141+
assert called_job_name == f"job-{gpu}-test"
128142

129143

130144
@mock.patch("openshift_client.create", name="create")

0 commit comments

Comments
 (0)