|
1 | 1 | #!/usr/bin/env python3 |
2 | 2 | # -*- coding: utf-8 -*- |
| 3 | + |
3 | 4 | from celery.exceptions import NotRegistered |
4 | 5 | from celery.result import AsyncResult |
| 6 | +from starlette.concurrency import run_in_threadpool |
5 | 7 |
|
6 | 8 | from backend.app.task.celery import celery_app |
| 9 | +from backend.app.task.schema.task import RunParam |
| 10 | +from backend.common.dataclasses import TaskResult |
7 | 11 | from backend.common.exception.errors import NotFoundError |
8 | 12 |
|
9 | 13 |
|
10 | 14 | class TaskService: |
11 | 15 | @staticmethod |
12 | | - def get_list(): |
13 | | - filtered_tasks = [] |
14 | | - tasks = celery_app.tasks |
15 | | - for key, value in tasks.items(): |
16 | | - if not key.startswith('celery.'): |
17 | | - filtered_tasks.append({key, value}) |
18 | | - return filtered_tasks |
19 | | - |
20 | | - @staticmethod |
21 | | - def get(): |
22 | | - return celery_app.current_worker_task |
| 16 | + async def get_list(): |
| 17 | + registered_tasks = await run_in_threadpool(celery_app.control.inspect().registered) |
| 18 | + tasks = list(registered_tasks.values())[0] |
| 19 | + return tasks |
23 | 20 |
|
24 | 21 | @staticmethod |
25 | | - def get_status(uid: str): |
| 22 | + def get_detail(*, tid: str): |
26 | 23 | try: |
27 | | - task_result = AsyncResult(id=uid, app=celery_app) |
| 24 | + result = AsyncResult(id=tid, app=celery_app) |
28 | 25 | except NotRegistered: |
29 | 26 | raise NotFoundError(msg='任务不存在') |
30 | | - return task_result.status |
| 27 | + return TaskResult( |
| 28 | + result=result.result, |
| 29 | + traceback=result.traceback, |
| 30 | + status=result.state, |
| 31 | + name=result.name, |
| 32 | + args=result.args, |
| 33 | + kwargs=result.kwargs, |
| 34 | + worker=result.worker, |
| 35 | + retries=result.retries, |
| 36 | + queue=result.queue, |
| 37 | + ) |
31 | 38 |
|
32 | 39 | @staticmethod |
33 | | - def get_result(uid: str): |
| 40 | + def revoke(*, tid: str): |
34 | 41 | try: |
35 | | - task_result = AsyncResult(id=uid, app=celery_app) |
| 42 | + result = AsyncResult(id=tid, app=celery_app) |
36 | 43 | except NotRegistered: |
37 | 44 | raise NotFoundError(msg='任务不存在') |
38 | | - return task_result.result |
| 45 | + result.revoke(terminate=True) |
39 | 46 |
|
40 | 47 | @staticmethod |
41 | | - def run(*, name: str, args: list | None = None, kwargs: dict | None = None): |
42 | | - task = celery_app.send_task(name=name, args=args, kwargs=kwargs) |
43 | | - return task |
| 48 | + def run(*, obj: RunParam): |
| 49 | + task: AsyncResult = celery_app.send_task(name=obj.name, args=obj.args, kwargs=obj.kwargs) |
| 50 | + return task.task_id |
44 | 51 |
|
45 | 52 |
|
46 | 53 | task_service: TaskService = TaskService() |
0 commit comments