Skip to content

Commit 7a8952c

Browse files
committed
Properly implement wait according to spec (i.e., wait for all states
greater than `target_states` not just for `target_states`). This goes a bit further by also implementing the yet-to-be-merged updates from ExaWorks/job-api-spec#178
1 parent 9e1a777 commit 7a8952c

File tree

5 files changed

+58
-31
lines changed

5 files changed

+58
-31
lines changed

src/psij/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import Callable, TypeVar
88

99
from psij.descriptor import Descriptor
10-
from .exceptions import SubmitException, InvalidJobException, UnreachableStateException
10+
from .exceptions import SubmitException, InvalidJobException
1111
from .job import Job, JobStatusCallback
1212
from .job_attributes import JobAttributes
1313
from .job_executor import JobExecutor

src/psij/exceptions.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -66,26 +66,3 @@ def __init__(self, message: str, exception: Optional[Exception] = None,
6666
conditions such an error would persist across subsequent re-tries until correct credentials
6767
are used.
6868
"""
69-
70-
71-
class UnreachableStateException(Exception):
72-
"""
73-
Indicates that a job state being waited for cannot be reached.
74-
75-
This exception is thrown when the :func:`~psij.Job.wait` method is called with a set of
76-
states that cannot be reached by the job when the call is made.
77-
"""
78-
79-
def __init__(self, status: 'psij.JobStatus') -> None:
80-
"""
81-
Constructs an `UnreachableStateException`.
82-
83-
:param status: The :class:`~psij.JobStatus` that the job was in when
84-
:func:`~psij.Job.wait` was called and which prevents the desired states to be
85-
reached.
86-
"""
87-
self.status = status
88-
"""
89-
Returns the job status that has caused an implementation to determine that the desired
90-
states passed to the :func:`~psij.Job.wait` method cannot be reached.
91-
"""

src/psij/job.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
import threading
33
from abc import ABC, abstractmethod
44
from datetime import timedelta, datetime
5-
from typing import Optional, Sequence, Union, Callable
5+
from typing import Optional, Sequence, Union, Callable, Set
66
from uuid import uuid4
77

88
import psij
9-
from psij.exceptions import SubmitException, UnreachableStateException
9+
from psij.exceptions import SubmitException
1010
from psij.job_spec import JobSpec
1111
from psij.job_state import JobState, JobStateOrder
1212
from psij.job_status import JobStatus
@@ -161,8 +161,23 @@ def cancel(self) -> None:
161161
else:
162162
self.executor.cancel(self)
163163

164+
def _all_greater(self, states: Optional[Union[JobState, Sequence[JobState]]]) \
165+
-> Optional[Set[JobState]]:
166+
if states is None:
167+
return None
168+
if isinstance(states, JobState):
169+
states = [states]
170+
ts = set()
171+
for state1 in states:
172+
ts.add(state1)
173+
for state2 in JobState:
174+
if state2.is_greater_than(state1):
175+
ts.add(state2)
176+
return ts
177+
164178
def wait(self, timeout: Optional[timedelta] = None,
165-
target_states: Optional[Sequence[JobState]] = None) -> Optional[JobStatus]:
179+
target_states: Optional[Union[JobState, Sequence[JobState]]] = None) \
180+
-> Optional[JobStatus]:
166181
"""
167182
Waits for the job to reach certain states.
168183
@@ -186,15 +201,15 @@ def wait(self, timeout: Optional[timedelta] = None,
186201
timeout = LARGE_TIMEOUT
187202
end = start + timeout
188203

204+
ts = self._all_greater(target_states)
205+
189206
while True:
190207
with self._status_cv:
191208
status = self._status
192209
state = status.state
193-
if target_states:
194-
if state in target_states:
210+
if ts:
211+
if state.final or state in ts:
195212
return status
196-
elif state.final:
197-
raise UnreachableStateException(status)
198213
else:
199214
pass # wait
200215
else:

src/psij/job_state.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,10 @@ def __bool__(self) -> bool:
108108
"""All states are consider true-ish."""
109109
return True
110110

111+
def __hash__(self) -> int:
112+
"""Returns a hash for this object."""
113+
return self._value_ # type: ignore
114+
111115

112116
class JobStateOrder:
113117
"""A class that can be used to reconstruct missing states."""

tests/test_wait.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from datetime import timedelta
2+
3+
from psij import Job, JobExecutor, JobSpec, JobState
4+
5+
6+
def _test_wait() -> None:
7+
ex = JobExecutor.get_instance('local')
8+
job = Job(JobSpec('/bin/sleep', ['4']))
9+
ex.submit(job)
10+
status = job.wait(target_states=JobState.ACTIVE)
11+
assert status is not None
12+
assert status.state == JobState.ACTIVE
13+
14+
status = job.wait(target_states=JobState.QUEUED)
15+
assert status is not None
16+
assert status.state == JobState.ACTIVE
17+
18+
status = job.wait(timedelta(milliseconds=100))
19+
assert status is None
20+
21+
status = job.wait()
22+
assert status is not None
23+
assert status.state == JobState.COMPLETED
24+
25+
status = job.wait(target_states=JobState.QUEUED)
26+
assert status is not None
27+
assert status.state == JobState.COMPLETED
28+
29+
status = job.wait(target_states=JobState.FAILED)
30+
assert status is not None
31+
assert status.state == JobState.COMPLETED

0 commit comments

Comments
 (0)