Skip to content

Commit 4aa3046

Browse files
committed
add create_workflow_task to handler and scheduler
1 parent 5c351ac commit 4aa3046

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

jupyter_scheduler/scheduler.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
create_output_directory,
4747
create_output_filename,
4848
)
49-
from jupyter_scheduler.workflows import CreateWorkflow, DescribeWorkflow
49+
from jupyter_scheduler.workflows import CreateWorkflow, DescribeWorkflow, UpdateWorkflow
5050

5151

5252
class BaseScheduler(LoggingConfigurable):
@@ -124,6 +124,10 @@ def get_workflow(self, workflow_id: str) -> DescribeWorkflow:
124124
"""Returns workflow record for a single workflow."""
125125
raise NotImplementedError("must be implemented by subclass")
126126

127+
def create_workflow_task(self, workflow_id: str, model: CreateJob) -> str:
128+
"""Adds a task to a workflow."""
129+
raise NotImplementedError("must be implemented by subclass")
130+
127131
def update_job(self, job_id: str, model: UpdateJob):
128132
"""Updates job metadata in the persistence store,
129133
for example name, status etc. In case of status
@@ -565,6 +569,14 @@ def get_workflow(self, workflow_id: str) -> DescribeWorkflow:
565569
model = DescribeWorkflow.from_orm(workflow_record)
566570
return model
567571

572+
def create_workflow_task(self, workflow_id: str, model: CreateJob) -> str:
573+
job_id = self.scheduler.create_job(model, run=False)
574+
workflow: DescribeWorkflow = self.scheduler.get_workflow(workflow_id)
575+
updated_tasks = (workflow.tasks or [])[:]
576+
updated_tasks.append(job_id)
577+
self.scheduler.update_workflow(workflow_id, UpdateWorkflow(depends_on=updated_tasks))
578+
return job_id
579+
568580
def update_job(self, job_id: str, model: UpdateJob):
569581
with self.db_session() as session:
570582
session.query(Job).filter(Job.job_id == job_id).update(model.dict(exclude_none=True))

jupyter_scheduler/workflows.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import List
2+
from typing import List, Optional
33

44
from jupyter_server.utils import ensure_async
55
from tornado.web import HTTPError, authenticated
@@ -71,7 +71,11 @@ async def post(self, workflow_id: str):
7171
"Error during workflow job creation. workflow_id in the URL and payload don't match.",
7272
)
7373
try:
74-
job_id = await ensure_async(self.scheduler.create_job(CreateJob(**payload), run=False))
74+
job_id = await ensure_async(
75+
self.scheduler.create_workflow_task(
76+
workflow_id=workflow_id, model=CreateJob(**payload)
77+
)
78+
)
7579
except ValidationError as e:
7680
self.log.exception(e)
7781
raise HTTPError(500, str(e)) from e
@@ -175,3 +179,10 @@ class UpdateWorkflow(BaseModel):
175179

176180
class Config:
177181
orm_mode = True
182+
183+
184+
class UpdateWorkflow(BaseModel):
185+
status: Optional[Status] = None
186+
name: Optional[str] = None
187+
compute_type: Optional[str] = None
188+
depends_on: Optional[str] = None

0 commit comments

Comments
 (0)