 import shutil
 import tarfile
 import traceback
+import multiprocessing as mp
 from abc import ABC, abstractmethod
 from functools import lru_cache
 from typing import Dict, List
@@ -15,9 +16,10 @@
 from prefect.futures import as_completed
 from prefect_dask.task_runners import DaskTaskRunner

-from jupyter_scheduler.models import DescribeJob, JobFeature, Status
+from jupyter_scheduler.models import CreateJob, DescribeJob, JobFeature, Status
 from jupyter_scheduler.orm import Job, Workflow, create_session
 from jupyter_scheduler.parameterize import add_parameters
+from jupyter_scheduler.scheduler import Scheduler
 from jupyter_scheduler.utils import get_utc_timestamp
 from jupyter_scheduler.workflows import DescribeWorkflow

@@ -186,36 +188,44 @@ def on_complete_workflow(self):
 class DefaultExecutionManager(ExecutionManager):
     """Default execution manager that executes notebooks"""

-    @task(task_run_name="{task_id}")
-    def execute_task(self, task_id: str):
-        print(f"Task {task_id} executed")
-        return task_id
+    @task
+    def execute_task(self, job: Job):
+        with self.db_session() as session:
+            staging_paths = Scheduler.get_staging_paths(DescribeJob.from_orm(job))
+
+            execution_manager = DefaultExecutionManager(
+                job_id=job.job_id,
+                staging_paths=staging_paths,
+                root_dir=self.root_dir,
+                db_url=self.db_url,
+            )
+            execution_manager.process()
+
+            job.pid = 1  # TODO: fix pid hardcode
+            job_id = job.job_id
+
+        return job_id

     @task
-    def get_task_data(self, task_ids: List[str] = []):
-        # TODO: get orm objects from Task table of the db, create DescribeTask for each
-        tasks_data_obj = [
-            {"id": "task0", "dependsOn": ["task3"]},
-            {"id": "task4", "dependsOn": ["task0", "task1", "task2", "task3"]},
-            {"id": "task1", "dependsOn": []},
-            {"id": "task2", "dependsOn": ["task1"]},
-            {"id": "task3", "dependsOn": ["task1", "task2"]},
-        ]
-
-        return tasks_data_obj
+    def get_tasks_records(self, task_ids: List[str]) -> List[Job]:
+        with self.db_session() as session:
+            tasks = session.query(Job).filter(Job.job_id.in_(task_ids)).all()
+
+        return tasks

     @flow
     def execute_workflow(self):
+        tasks_info: List[Job] = self.get_tasks_records(self.model.tasks)
+        tasks = {task.job_id: task for task in tasks_info}

-        tasks_info = self.get_task_data()
-        tasks = {task["id"]: task for task in tasks_info}
-
-        # create Prefect tasks, use caching to ensure Prefect tasks are created before wait_for is called on them
         @lru_cache(maxsize=None)
         def make_task(task_id):
-            deps = tasks[task_id]["dependsOn"]
+            """Create a delayed object for the given task recursively creating delayed objects for all tasks it depends on"""
+            deps = tasks[task_id].depends_on or []
+            name = tasks[task_id].name
+            job_id = tasks[task_id].job_id
             return self.execute_task.submit(
-                task_id, wait_for=[make_task(dep_id) for dep_id in deps]
+                tasks[task_id], wait_for=[make_task(dep_id) for dep_id in deps]
             )

         final_tasks = [make_task(task_id) for task_id in tasks]
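
For orientation, below is a minimal, self-contained sketch of the dependency-resolution pattern the new execute_workflow flow relies on: lru_cache memoizes make_task, so each task is submitted to Prefect exactly once and every dependency's future already exists when it is passed to wait_for. The task ids and the plain depends_on dict here are hypothetical stand-ins for the Job ORM records used in the actual change.

from functools import lru_cache
from typing import Dict, List

from prefect import flow, task


@task
def run_task(task_id: str) -> str:
    # Stand-in for the real per-task work (e.g. executing a staged notebook).
    print(f"running {task_id}")
    return task_id


@flow
def run_workflow(depends_on: Dict[str, List[str]]):
    @lru_cache(maxsize=None)
    def make_task(task_id: str):
        # Memoization guarantees each task is submitted exactly once,
        # even when several downstream tasks depend on it.
        deps = depends_on.get(task_id, [])
        return run_task.submit(task_id, wait_for=[make_task(d) for d in deps])

    # Submitting every task id also covers roots and isolated tasks.
    return [make_task(task_id) for task_id in depends_on]


if __name__ == "__main__":
    # "b" waits for "a"; "c" waits for both "a" and "b".
    run_workflow({"a": [], "b": ["a"], "c": ["a", "b"]})

This mirrors the recursive make_task helper in the diff, except that the diff submits the Job record itself rather than a bare id.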