Commit a928398

Merge branch 'master' into pr/giancarloromeo/8141

2 parents: a74ec78 + 87820ae

76 files changed: +1228 additions, −589 deletions

packages/aws-library/src/aws_library/s3/__init__.py

Lines changed: 3 additions & 3 deletions
@@ -22,10 +22,10 @@
 )

 __all__: tuple[str, ...] = (
-    "CopiedBytesTransferredCallback",
-    "MultiPartUploadLinks",
     "PRESIGNED_LINK_MAX_SIZE",
     "S3_MAX_FILE_SIZE",
+    "CopiedBytesTransferredCallback",
+    "MultiPartUploadLinks",
     "S3AccessError",
     "S3BucketInvalidError",
     "S3DestinationNotEmptyError",
@@ -37,8 +37,8 @@
     "S3RuntimeError",
     "S3UploadNotFoundError",
     "SimcoreS3API",
-    "UploadedBytesTransferredCallback",
     "UploadID",
+    "UploadedBytesTransferredCallback",
 )

 # nopycln: file

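Note: the reordering above appears consistent with an isort-style sort of __all__ (SCREAMING_SNAKE_CASE constants first, then a case-sensitive sort of the remaining names), which is what ruff's RUF022 rule produces; the commit itself does not name the tool. A minimal sketch of a key that reproduces the new order:

    def isort_style_key(name: str) -> tuple[int, str]:
        # all-caps constants sort before CamelCase names; ties break case-sensitively
        return (0 if name.isupper() else 1, name)

    names = [
        "PRESIGNED_LINK_MAX_SIZE",
        "S3_MAX_FILE_SIZE",
        "CopiedBytesTransferredCallback",
        "MultiPartUploadLinks",
        "UploadID",
        "UploadedBytesTransferredCallback",
    ]
    assert sorted(names, key=isort_style_key) == names
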
packages/celery-library/src/celery_library/backends/_redis.py renamed to packages/celery-library/src/celery_library/backends/redis.py

Lines changed: 42 additions & 31 deletions
@@ -1,25 +1,23 @@
 import contextlib
 import logging
 from datetime import timedelta
-from typing import Final
+from typing import TYPE_CHECKING, Final

 from models_library.progress_bar import ProgressReport
 from pydantic import ValidationError
 from servicelib.celery.models import (
     Task,
     TaskFilter,
     TaskID,
+    TaskInfoStore,
     TaskMetadata,
-    TaskUUID,
+    Wildcard,
 )
-from servicelib.redis import RedisClientSDK
-
-from ..utils import build_task_id_prefix
+from servicelib.redis import RedisClientSDK, handle_redis_returns_union_types

 _CELERY_TASK_INFO_PREFIX: Final[str] = "celery-task-info-"
 _CELERY_TASK_ID_KEY_ENCODING = "utf-8"
-_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"
-_CELERY_TASK_SCAN_COUNT_PER_BATCH: Final[int] = 10000
+_CELERY_TASK_SCAN_COUNT_PER_BATCH: Final[int] = 1000
 _CELERY_TASK_METADATA_KEY: Final[str] = "metadata"
 _CELERY_TASK_PROGRESS_KEY: Final[str] = "progress"

@@ -41,23 +39,24 @@ async def create_task(
         expiry: timedelta,
     ) -> None:
         task_key = _build_key(task_id)
-        await self._redis_client_sdk.redis.hset(
-            name=task_key,
-            key=_CELERY_TASK_METADATA_KEY,
-            value=task_metadata.model_dump_json(),
-        )  # type: ignore
+        await handle_redis_returns_union_types(
+            self._redis_client_sdk.redis.hset(
+                name=task_key,
+                key=_CELERY_TASK_METADATA_KEY,
+                value=task_metadata.model_dump_json(),
+            )
+        )
         await self._redis_client_sdk.redis.expire(
             task_key,
             expiry,
         )

-    async def exists_task(self, task_id: TaskID) -> bool:
-        n = await self._redis_client_sdk.redis.exists(_build_key(task_id))
-        assert isinstance(n, int)  # nosec
-        return n > 0
-
     async def get_task_metadata(self, task_id: TaskID) -> TaskMetadata | None:
-        raw_result = await self._redis_client_sdk.redis.hget(_build_key(task_id), _CELERY_TASK_METADATA_KEY)  # type: ignore
+        raw_result = await handle_redis_returns_union_types(
+            self._redis_client_sdk.redis.hget(
+                _build_key(task_id), _CELERY_TASK_METADATA_KEY
+            )
+        )
         if not raw_result:
             return None

@@ -70,7 +69,11 @@ async def get_task_metadata(self, task_id: TaskID) -> TaskMetadata | None:
             return None

     async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None:
-        raw_result = await self._redis_client_sdk.redis.hget(_build_key(task_id), _CELERY_TASK_PROGRESS_KEY)  # type: ignore
+        raw_result = await handle_redis_returns_union_types(
+            self._redis_client_sdk.redis.hget(
+                _build_key(task_id), _CELERY_TASK_PROGRESS_KEY
+            )
+        )
         if not raw_result:
             return None

@@ -83,17 +86,14 @@ async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None:
             return None

     async def list_tasks(self, task_filter: TaskFilter) -> list[Task]:
-        search_key = (
-            _CELERY_TASK_INFO_PREFIX
-            + build_task_id_prefix(task_filter)
-            + _CELERY_TASK_ID_KEY_SEPARATOR
+        search_key = _CELERY_TASK_INFO_PREFIX + task_filter.create_task_id(
+            task_uuid=Wildcard()
         )
-        search_key_len = len(search_key)

         keys: list[str] = []
         pipeline = self._redis_client_sdk.redis.pipeline()
         async for key in self._redis_client_sdk.redis.scan_iter(
-            match=search_key + "*", count=_CELERY_TASK_SCAN_COUNT_PER_BATCH
+            match=search_key, count=_CELERY_TASK_SCAN_COUNT_PER_BATCH
         ):
             # fake redis (tests) returns bytes, real redis returns str
             _key = (
@@ -115,7 +115,7 @@ async def list_tasks(self, task_filter: TaskFilter) -> list[Task]:
             task_metadata = TaskMetadata.model_validate_json(raw_metadata)
             tasks.append(
                 Task(
-                    uuid=TaskUUID(key[search_key_len:]),
+                    uuid=TaskFilter.get_task_uuid(key),
                     metadata=task_metadata,
                 )
             )
@@ -126,8 +126,19 @@ async def remove_task(self, task_id: TaskID) -> None:
         await self._redis_client_sdk.redis.delete(_build_key(task_id))

     async def set_task_progress(self, task_id: TaskID, report: ProgressReport) -> None:
-        await self._redis_client_sdk.redis.hset(
-            name=_build_key(task_id),
-            key=_CELERY_TASK_PROGRESS_KEY,
-            value=report.model_dump_json(),
-        )  # type: ignore
+        await handle_redis_returns_union_types(
+            self._redis_client_sdk.redis.hset(
+                name=_build_key(task_id),
+                key=_CELERY_TASK_PROGRESS_KEY,
+                value=report.model_dump_json(),
+            )
+        )
+
+    async def task_exists(self, task_id: TaskID) -> bool:
+        n = await self._redis_client_sdk.redis.exists(_build_key(task_id))
+        assert isinstance(n, int)  # nosec
+        return n > 0
+
+
+if TYPE_CHECKING:
+    _: type[TaskInfoStore] = RedisTaskInfoStore

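Besides the module rename to redis.py and the method rename exists_task → task_exists, two patterns stand out: per-call "# type: ignore" comments are replaced by handle_redis_returns_union_types, and the new TYPE_CHECKING block makes the type checker verify that RedisTaskInfoStore conforms to the TaskInfoStore interface at zero runtime cost. The wrapper is needed because redis-py's async client is annotated as returning Awaitable[T] | T. A minimal sketch of what such a helper presumably does; the real implementation lives in servicelib.redis and may differ:

    import inspect
    from collections.abc import Awaitable
    from typing import TypeVar

    T = TypeVar("T")


    async def handle_redis_returns_union_types(result: "Awaitable[T] | T") -> T:
        # narrow redis-py's Awaitable[T] | T union by awaiting when necessary
        if inspect.isawaitable(result):
            return await result
        return result
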
packages/celery-library/src/celery_library/common.py

Lines changed: 0 additions & 23 deletions
@@ -2,13 +2,9 @@
 from typing import Any

 from celery import Celery  # type: ignore[import-untyped]
-from servicelib.redis import RedisClientSDK
 from settings_library.celery import CelerySettings
 from settings_library.redis import RedisDatabase

-from .backends._redis import RedisTaskInfoStore
-from .task_manager import CeleryTaskManager
-

 def _celery_configure(celery_settings: CelerySettings) -> dict[str, Any]:
     base_config = {
@@ -36,22 +32,3 @@ def create_app(settings: CelerySettings) -> Celery:
         ),
         **_celery_configure(settings),
     )
-
-
-async def create_task_manager(
-    app: Celery, settings: CelerySettings
-) -> CeleryTaskManager:
-    redis_client_sdk = RedisClientSDK(
-        settings.CELERY_REDIS_RESULT_BACKEND.build_redis_dsn(
-            RedisDatabase.CELERY_TASKS
-        ),
-        client_name="celery_tasks",
-    )
-    await redis_client_sdk.setup()
-    # GCR please address https://github.com/ITISFoundation/osparc-simcore/issues/8159
-
-    return CeleryTaskManager(
-        app,
-        settings,
-        RedisTaskInfoStore(redis_client_sdk),
-    )

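create_task_manager is gone from common.py and is no longer called from signals.py (below), so assembling the CeleryTaskManager presumably becomes the responsibility of each service's own startup code. A sketch reconstructed from the deleted helper; the function name and its placement in a service lifespan are assumptions:

    from celery import Celery  # type: ignore[import-untyped]
    from celery_library.backends.redis import RedisTaskInfoStore
    from celery_library.task_manager import CeleryTaskManager
    from servicelib.redis import RedisClientSDK
    from settings_library.celery import CelerySettings
    from settings_library.redis import RedisDatabase


    async def build_task_manager(app: Celery, settings: CelerySettings) -> CeleryTaskManager:
        # same wiring as the removed helper, now owned by the service's startup code
        redis_client_sdk = RedisClientSDK(
            settings.CELERY_REDIS_RESULT_BACKEND.build_redis_dsn(RedisDatabase.CELERY_TASKS),
            client_name="celery_tasks",
        )
        await redis_client_sdk.setup()
        return CeleryTaskManager(app, settings, RedisTaskInfoStore(redis_client_sdk))
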
packages/celery-library/src/celery_library/errors.py

Lines changed: 6 additions & 0 deletions
@@ -1,6 +1,8 @@
 import base64
 import pickle

+from common_library.errors_classes import OsparcErrorMixin
+

 class TransferrableCeleryError(Exception):
     def __repr__(self) -> str:
@@ -22,3 +24,7 @@ def decode_celery_transferrable_error(error: TransferrableCeleryError) -> Exception:
     assert isinstance(error, TransferrableCeleryError)  # nosec
     result: Exception = pickle.loads(base64.b64decode(error.args[0]))  # noqa: S301
     return result
+
+
+class TaskNotFoundError(OsparcErrorMixin, Exception):
+    msg_template = "Task with id '{task_id}' was not found"

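OsparcErrorMixin errors format msg_template from the keyword arguments passed to the constructor, so the new exception carries a structured task_id field. A small usage sketch, assuming that mixin behavior:

    from celery_library.errors import TaskNotFoundError

    try:
        raise TaskNotFoundError(task_id="3fa85f64-5717-4562-b3fc-2c963f66afa6")
    except TaskNotFoundError as err:
        print(err)  # Task with id '3fa85f64-5717-4562-b3fc-2c963f66afa6' was not found
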
packages/celery-library/src/celery_library/rpc/_async_jobs.py

Lines changed: 13 additions & 7 deletions
@@ -13,6 +13,7 @@
 from models_library.api_schemas_rpc_async_jobs.exceptions import (
     JobAbortedError,
     JobError,
+    JobMissingError,
     JobNotDoneError,
     JobSchedulerError,
 )
@@ -22,6 +23,7 @@
 from servicelib.rabbitmq import RPCRouter

 from ..errors import (
+    TaskNotFoundError,
     TransferrableCeleryError,
     decode_celery_transferrable_error,
 )
@@ -30,7 +32,7 @@
 router = RPCRouter()


-@router.expose(reraise_if_error_type=(JobSchedulerError,))
+@router.expose(reraise_if_error_type=(JobSchedulerError, JobMissingError))
 async def cancel(
     task_manager: TaskManager, job_id: AsyncJobId, job_filter: AsyncJobFilter
 ):
@@ -42,11 +44,13 @@ async def cancel(
             task_filter=task_filter,
             task_uuid=job_id,
         )
+    except TaskNotFoundError as exc:
+        raise JobMissingError(job_id=job_id) from exc
     except CeleryError as exc:
         raise JobSchedulerError(exc=f"{exc}") from exc


-@router.expose(reraise_if_error_type=(JobSchedulerError,))
+@router.expose(reraise_if_error_type=(JobSchedulerError, JobMissingError))
 async def status(
     task_manager: TaskManager, job_id: AsyncJobId, job_filter: AsyncJobFilter
 ) -> AsyncJobStatus:
@@ -59,6 +63,8 @@ async def status(
             task_filter=task_filter,
             task_uuid=job_id,
         )
+    except TaskNotFoundError as exc:
+        raise JobMissingError(job_id=job_id) from exc
     except CeleryError as exc:
         raise JobSchedulerError(exc=f"{exc}") from exc

@@ -71,9 +77,10 @@ async def status(

 @router.expose(
     reraise_if_error_type=(
+        JobAbortedError,
         JobError,
+        JobMissingError,
         JobNotDoneError,
-        JobAbortedError,
         JobSchedulerError,
     )
 )
@@ -97,11 +104,11 @@ async def result(
             task_filter=task_filter,
             task_uuid=job_id,
         )
+    except TaskNotFoundError as exc:
+        raise JobMissingError(job_id=job_id) from exc
     except CeleryError as exc:
         raise JobSchedulerError(exc=f"{exc}") from exc

-    if _status.task_state == TaskState.ABORTED:
-        raise JobAbortedError(job_id=job_id)
     if _status.task_state == TaskState.FAILURE:
         # fallback exception to report
         exc_type = type(_result).__name__
@@ -127,9 +134,8 @@
 @router.expose(reraise_if_error_type=(JobSchedulerError,))
 async def list_jobs(
-    task_manager: TaskManager, filter_: str, job_filter: AsyncJobFilter
+    task_manager: TaskManager, job_filter: AsyncJobFilter
 ) -> list[AsyncJobGet]:
-    _ = filter_
     assert task_manager  # nosec
     task_filter = TaskFilter.model_validate(job_filter.model_dump())
     try:

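Net effect for callers: cancel, status, and result now translate the backend's TaskNotFoundError into JobMissingError, so "this job does not exist" is distinguishable from "the scheduler failed" (JobSchedulerError), and result no longer maps the ABORTED task state to JobAbortedError. A hypothetical client-side sketch; the rpc_client object and its status signature are illustrative, not part of this diff:

    from models_library.api_schemas_rpc_async_jobs.exceptions import (
        JobMissingError,
        JobSchedulerError,
    )


    async def get_status_or_none(rpc_client, job_id, job_filter):
        # treat unknown jobs as gone; let scheduler failures propagate
        try:
            return await rpc_client.status(job_id=job_id, job_filter=job_filter)
        except JobMissingError:
            return None
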
packages/celery-library/src/celery_library/signals.py

Lines changed: 5 additions & 15 deletions
@@ -6,18 +6,15 @@
 from celery.worker.worker import WorkController  # type: ignore[import-untyped]
 from servicelib.celery.app_server import BaseAppServer
 from servicelib.logging_utils import log_context
-from settings_library.celery import CelerySettings

-from .common import create_task_manager
 from .utils import get_app_server, set_app_server

 _logger = logging.getLogger(__name__)


 def on_worker_init(
-    app_server: BaseAppServer,
-    celery_settings: CelerySettings,
     sender: WorkController,
+    app_server: BaseAppServer,
     **_kwargs,
 ) -> None:
     startup_complete_event = threading.Event()
@@ -26,21 +23,14 @@ def _init(startup_complete_event: threading.Event) -> None:
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)

-        async def _setup_task_manager():
-            assert sender.app  # nosec
-            assert isinstance(sender.app, Celery)  # nosec
-
-            app_server.task_manager = await create_task_manager(
-                sender.app,
-                celery_settings,
-            )
+        assert sender.app  # nosec
+        assert isinstance(sender.app, Celery)  # nosec

-            set_app_server(sender.app, app_server)
+        set_app_server(sender.app, app_server)

         app_server.event_loop = loop

-        loop.run_until_complete(_setup_task_manager())
-        loop.run_until_complete(app_server.lifespan(startup_complete_event))
+        loop.run_until_complete(app_server.run_until_shutdown(startup_complete_event))

     thread = threading.Thread(
         group=None,

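With the task-manager bootstrap gone, on_worker_init only needs the app server and a running event loop; async startup presumably now happens inside BaseAppServer.run_until_shutdown, which replaces the former lifespan call. A hypothetical wiring sketch using the standard Celery signals API (MyAppServer stands in for a concrete BaseAppServer subclass):

    from functools import partial

    from celery.signals import worker_init  # type: ignore[import-untyped]

    from celery_library.signals import on_worker_init
    from servicelib.celery.app_server import BaseAppServer


    class MyAppServer(BaseAppServer):  # hypothetical subclass
        ...


    app_server = MyAppServer()
    # bind only the app server; Celery passes `sender` (the WorkController) itself
    worker_init.connect(partial(on_worker_init, app_server=app_server), weak=False)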