Skip to content

Commit e989c2f

Browse files
jfmlimamxab
authored andcommitted
Wait for a running task instead of a running job
1 parent 40d1952 commit e989c2f

File tree

3 files changed

+70
-10
lines changed

3 files changed

+70
-10
lines changed

jupyterhub_nomad_spawner/nomad/nomad_model.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1571,6 +1571,20 @@ class AllocatedResources(BaseModel):
15711571
Shared: Optional[AllocatedSharedResources] = None
15721572
Tasks: Optional[Dict[str, AllocatedTaskResources]] = None
15731573

1574+
class TaskEvent(BaseModel):
1575+
Type: str
1576+
Time: int
1577+
DisplayMessage: str
1578+
Details: Dict[str, Any]
1579+
FailsTask: bool
1580+
DriverMessage: str
1581+
1582+
class TaskState(BaseModel):
1583+
State: str
1584+
Failed: bool
1585+
StartedAt: Optional[str]
1586+
FinishedAt: Optional[str]
1587+
Events: List[TaskEvent]
15741588

15751589
class AllocationListStub(BaseModel):
15761590
AllocatedResources: Optional[AllocatedResources] = None

jupyterhub_nomad_spawner/nomad/nomad_service.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
CSIVolumeCapability,
1212
CSIVolumeCreateRequest,
1313
JobsParseRequest,
14+
TaskState,
1415
)
1516

1617

@@ -125,6 +126,43 @@ async def job_status(self, job_id) -> str:
125126

126127
job_detail = response.json()
127128
return job_detail.get("Status", "")
129+
130+
async def task_status(self, job_name: str) -> str:
131+
"""Get detailed task status from most recent allocation"""
132+
allocs = await self.client.get(f"/v1/job/{job_name}/allocations")
133+
if not allocs:
134+
return "pending"
135+
136+
allocs = allocs.json()
137+
latest_alloc = max(allocs, key=lambda x: x["CreateTime"])
138+
if not latest_alloc:
139+
return "pending"
140+
141+
task_states = latest_alloc.get("TaskStates", {}) or {}
142+
task_states = {name: TaskState(**state) for name, state in task_states.items()}
143+
144+
if not task_states:
145+
return "pending"
146+
147+
for task in task_states.values():
148+
if task.State == "dead" and task.Failed:
149+
return "dead"
150+
if task.State != "running":
151+
return self._get_task_state_from_event(task)
152+
153+
return "running"
154+
155+
def _get_task_state_from_event(self, task: TaskState) -> str:
156+
"""Determine task state from latest event"""
157+
events = task.Events
158+
if not events:
159+
return "pending"
160+
161+
latest_event = events[-1]
162+
if latest_event.Type in ["Driver", "Task Setup"]:
163+
return "starting"
164+
return "pending"
165+
128166

129167
async def job_allocations(self, job_id) -> list[dict[str, Any]]:
130168
response = await self.client.get(f"/v1/job/{job_id}/allocations")

jupyterhub_nomad_spawner/spawner.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -544,16 +544,24 @@ def _get_csi_extra_parameters(self) -> Optional[Dict]:
544544
async def _ensure_running(self, nomad_service: NomadService):
545545
while True:
546546
try:
547-
status = await nomad_service.job_status(self.job_name)
548-
except Exception:
549-
self.log.exception("Failed to get job status")
550-
if status == "running":
551-
break
552-
elif status == "dead":
553-
raise Exception(f"Job (name={self.job_name}) is dead already")
554-
else:
555-
self.log.info("Waiting for %s...", self.job_name)
556-
await asyncio.sleep(5)
547+
job_status = await nomad_service.job_status(self.job_name)
548+
if job_status == "dead":
549+
raise Exception(f"Job (name={self.job_name}) is dead")
550+
elif job_status == "running":
551+
task_status = await nomad_service.task_status(self.job_name)
552+
if task_status == "running":
553+
break
554+
elif task_status == "dead":
555+
raise Exception(f"Task for job (name={self.job_name}) is dead")
556+
else:
557+
self.log.info("Task for %s is %s, waiting...", self.job_name, task_status)
558+
else:
559+
self.log.info("Job %s is %s, waiting...", self.job_name, job_status)
560+
except Exception as e:
561+
self.log.exception("Failed to get job/task status")
562+
raise e
563+
564+
await asyncio.sleep(5)
557565

558566
@retry(wait=wait_fixed(3), stop=stop_after_attempt(5))
559567
async def address_and_port_from_consul(

0 commit comments

Comments
 (0)