
Commit 3ae8997

execute notebook as task via DefaultExecutionManager
1 parent: 2f6938b

File tree: 2 files changed (+60, -22 lines)

  jupyter_scheduler/executors.py
  jupyter_scheduler/scheduler.py

jupyter_scheduler/executors.py

Lines changed: 32 additions & 22 deletions
@@ -3,6 +3,7 @@
 import shutil
 import tarfile
 import traceback
+import multiprocessing as mp
 from abc import ABC, abstractmethod
 from functools import lru_cache
 from typing import Dict, List
@@ -15,9 +16,10 @@
 from prefect.futures import as_completed
 from prefect_dask.task_runners import DaskTaskRunner
 
-from jupyter_scheduler.models import DescribeJob, JobFeature, Status
+from jupyter_scheduler.models import CreateJob, DescribeJob, JobFeature, Status
 from jupyter_scheduler.orm import Job, Workflow, create_session
 from jupyter_scheduler.parameterize import add_parameters
+from jupyter_scheduler.scheduler import Scheduler
 from jupyter_scheduler.utils import get_utc_timestamp
 from jupyter_scheduler.workflows import DescribeWorkflow
 
@@ -186,36 +188,44 @@ def on_complete_workflow(self):
 class DefaultExecutionManager(ExecutionManager):
     """Default execution manager that executes notebooks"""
 
-    @task(task_run_name="{task_id}")
-    def execute_task(self, task_id: str):
-        print(f"Task {task_id} executed")
-        return task_id
+    @task
+    def execute_task(self, job: Job):
+        with self.db_session() as session:
+            staging_paths = Scheduler.get_staging_paths(DescribeJob.from_orm(job))
+
+            execution_manager = DefaultExecutionManager(
+                job_id=job.job_id,
+                staging_paths=staging_paths,
+                root_dir=self.root_dir,
+                db_url=self.db_url,
+            )
+            execution_manager.process()
+
+            job.pid = 1  # TODO: fix pid hardcode
+            job_id = job.job_id
+
+        return job_id
 
     @task
-    def get_task_data(self, task_ids: List[str] = []):
-        # TODO: get orm objects from Task table of the db, create DescribeTask for each
-        tasks_data_obj = [
-            {"id": "task0", "dependsOn": ["task3"]},
-            {"id": "task4", "dependsOn": ["task0", "task1", "task2", "task3"]},
-            {"id": "task1", "dependsOn": []},
-            {"id": "task2", "dependsOn": ["task1"]},
-            {"id": "task3", "dependsOn": ["task1", "task2"]},
-        ]
-
-        return tasks_data_obj
+    def get_tasks_records(self, task_ids: List[str]) -> List[Job]:
+        with self.db_session() as session:
+            tasks = session.query(Job).filter(Job.job_id.in_(task_ids)).all()
+
+        return tasks
 
     @flow
     def execute_workflow(self):
+        tasks_info: List[Job] = self.get_tasks_records(self.model.tasks)
+        tasks = {task.job_id: task for task in tasks_info}
 
-        tasks_info = self.get_task_data()
-        tasks = {task["id"]: task for task in tasks_info}
-
-        # create Prefect tasks, use caching to ensure Prefect tasks are created before wait_for is called on them
         @lru_cache(maxsize=None)
         def make_task(task_id):
-            deps = tasks[task_id]["dependsOn"]
+            """Create a delayed object for the given task recursively creating delayed objects for all tasks it depends on"""
+            deps = tasks[task_id].depends_on or []
+            name = tasks[task_id].name
+            job_id = tasks[task_id].job_id
             return self.execute_task.submit(
-                task_id, wait_for=[make_task(dep_id) for dep_id in deps]
+                tasks[task_id], wait_for=[make_task(dep_id) for dep_id in deps]
             )
 
         final_tasks = [make_task(task_id) for task_id in tasks]
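
Note on the pattern in execute_workflow above: each job is submitted as a Prefect task exactly once, and only after all of its dependencies have been submitted, because the recursive make_task helper is memoized with lru_cache. A minimal standalone sketch of that idea, assuming Prefect 2.x and a made-up depends_on mapping rather than the actual jupyter_scheduler Job records:

    from functools import lru_cache

    from prefect import flow, task


    @task
    def run_one(task_id: str) -> str:
        # Stand-in for the real work, e.g. executing a staged notebook.
        print(f"running {task_id}")
        return task_id


    @flow
    def run_dag(depends_on: dict):
        @lru_cache(maxsize=None)
        def make_task(task_id):
            # Submit dependencies first; lru_cache guarantees each task is
            # submitted only once even when several tasks depend on it.
            deps = [make_task(dep) for dep in depends_on[task_id]]
            return run_one.submit(task_id, wait_for=deps)

        return [make_task(task_id) for task_id in depends_on]


    if __name__ == "__main__":
        # task2 starts only after task0 and task1 have finished.
        run_dag({"task0": [], "task1": [], "task2": ["task0", "task1"]})

As in the diff, the mapping is assumed to be acyclic; a cycle in depends_on would make make_task recurse without end.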

jupyter_scheduler/scheduler.py

Lines changed: 28 additions & 0 deletions
@@ -521,6 +521,11 @@ def create_job(self, model: CreateJob, run: bool = True) -> str:
         if not run:
             return job.job_id
 
+        job_id = self.run_job(job=job, staging_paths=staging_paths)
+        return job_id
+
+    def run_job(self, job: Job, staging_paths: Dict[str, str]) -> str:
+        with self.db_session() as session:
         # The MP context forces new processes to not be forked on Linux.
         # This is necessary because `asyncio.get_event_loop()` is bugged in
         # forked processes in Python versions below 3.12. This method is
@@ -556,6 +561,7 @@ def create_workflow(self, model: CreateWorkflow) -> str:
     def run_workflow(self, workflow_id: str) -> str:
         execution_manager = self.execution_manager_class(
             workflow_id=workflow_id,
+            root_dir=self.root_dir,
             db_url=self.db_url,
         )
         execution_manager.process_workflow()
@@ -858,6 +864,28 @@ def get_staging_paths(self, model: Union[DescribeJob, DescribeJobDefinition]) ->
 
         return staging_paths
 
+    @staticmethod
+    def get_staging_paths(model: Union[DescribeJob, DescribeJobDefinition]) -> Dict[str, str]:
+        staging_paths = {}
+        if not model:
+            return staging_paths
+
+        id = model.job_id if isinstance(model, DescribeJob) else model.job_definition_id
+
+        for output_format in model.output_formats:
+            filename = create_output_filename(
+                model.input_filename, model.create_time, output_format
+            )
+            staging_paths[output_format] = os.path.join(
+                os.path.join(jupyter_data_dir(), "scheduler_staging_area"), id, filename
+            )
+
+        staging_paths["input"] = os.path.join(
+            os.path.join(jupyter_data_dir(), "scheduler_staging_area"), id, model.input_filename
+        )
+
+        return staging_paths
+
     async def stop_extension(self):
         """
         Cleanup code to run when the server is stopping.
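
Aside on the comment carried into the new run_job: job processes are started from a "spawn" multiprocessing context rather than the platform default ("fork" on Linux), because asyncio.get_event_loop() is unreliable in forked children on Python versions below 3.12. A minimal sketch of that pattern, with an illustrative target function in place of the scheduler's real per-job entry point:

    import multiprocessing as mp


    def run_process(job_id: str) -> None:
        # Stand-in for the scheduler's real per-job entry point.
        print(f"executing job {job_id}")


    if __name__ == "__main__":
        # "spawn" starts a fresh interpreter instead of forking the parent,
        # avoiding the event-loop issues forked processes hit on
        # Python versions below 3.12.
        ctx = mp.get_context("spawn")
        p = ctx.Process(target=run_process, args=("abc123",))
        p.start()
        p.join()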

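Similarly, the new static get_staging_paths resolves every path under <jupyter_data_dir()>/scheduler_staging_area/<id>/, with one entry per output format plus an "input" entry. A toy sketch of just that layout, using made-up values (the real output filenames come from create_output_filename, which also encodes the job's create_time):

    import os

    # Illustrative values only; the real base path comes from jupyter_data_dir().
    data_dir = "/home/user/.local/share/jupyter"
    job_id = "abc123"
    input_filename = "helloworld.ipynb"
    output_formats = ["ipynb", "html"]

    staging_root = os.path.join(data_dir, "scheduler_staging_area", job_id)
    staging_paths = {
        fmt: os.path.join(staging_root, f"helloworld.{fmt}")  # placeholder filenames
        for fmt in output_formats
    }
    staging_paths["input"] = os.path.join(staging_root, input_filename)
    print(staging_paths)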