Skip to content

Commit b7ca813

Browse files
refactor arch task logic
1 parent 3669a1b commit b7ca813

File tree

4 files changed

+22
-14
lines changed

4 files changed

+22
-14
lines changed

storage-app/src/shared/archive_helpers.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -238,13 +238,13 @@ def __init__(
238238
self.queue = queue
239239
self.max_concurrent = max_concurrent
240240
self._done = False
241+
self.iter_count = 0
241242

242243
@property
243244
def ready(self) -> bool: return self._done
244245

245246
async def produce(self):
246247
tasks = []
247-
iter_count = 0
248248

249249
while await self.object_set.fetch_next:
250250
file = self.object_set.next_object()
@@ -254,8 +254,8 @@ async def produce(self):
254254
await wait(tasks, return_when=FIRST_COMPLETED)
255255
tasks = [task for task in tasks if not task.done()]
256256

257-
iter_count += 1
258-
if not iter_count % GC_FREQ: gc_collect()
257+
self.iter_count += 1
258+
if not self.iter_count % GC_FREQ: gc_collect()
259259

260260
await gather(*tasks)
261261
self.queue.put(None)

storage-app/src/shared/settings.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,8 +19,8 @@
1919

2020
assert STORAGE_PORT
2121

22-
ASYNC_PRODUCER_MAX_CONCURRENT: int = 1_000
23-
ASYNC_PRODUCER_GC_FREQ: int = 100
22+
ASYNC_PRODUCER_MAX_CONCURRENT: int = 256
23+
ASYNC_PRODUCER_GC_FREQ: int = 256
2424
APP_BACKEND_URL: str = "http://" + getenv("APP_BACKEND_URL", "127.0.0.1")
2525
SECRET_KEY: str = getenv("SECRET_KEY", "")
2626
SECRET_ALGO: str = getenv("SECRET_ALGO", "HS256")

storage-app/src/shared/worker_services.py

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515
from .hasher import VHash, IHash
1616
from queue import Queue
1717
from .archive_helpers import FileProducer, ZipConsumer, ZipWriter
18+
from celery import Task
1819

1920

2021
class EmbeddingStatus(Enum):
@@ -34,13 +35,18 @@ class Zipper:
3435
written: bool = False
3536
archive_extension: str = "zip"
3637

37-
def __init__(self, bucket_name: str, file_ids: list[str]) -> None:
38+
def __init__(
39+
self,
40+
bucket_name: str,
41+
file_ids: list[str],
42+
task: Task
43+
) -> None:
3844
self.object_set = Bucket(bucket_name).get_download_objects(file_ids)
45+
self.bucket_name = bucket_name
46+
self._task = task
3947

4048
self._get_annotation(bucket_name, file_ids)
4149

42-
self.bucket_name = bucket_name
43-
4450
async def archive_objects(self) -> Optional[bool]:
4551
json_data: Any = ("annotation.json", dumps(self.annotation, indent=4).encode("utf-8"))
4652

@@ -61,7 +67,9 @@ async def archive_objects(self) -> Optional[bool]:
6167
if wait_list.task.ready:
6268
wait_list = wait_list.next
6369
continue
64-
await async_stall_for(1)
70+
71+
self._task.update_state(state="PROGRESS")
72+
await async_stall_for(5)
6573

6674
await producer_task
6775
await self.object_set.close()

storage-app/src/worker.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
from celery import Celery
22
from shared.settings import BROKER_URL, RESULT_URL, CELERY_CONFIG
33
from shared.worker_services import Zipper, Hasher, EmbeddingStatus
4-
from asyncio import get_event_loop
4+
from asyncio import run
55
from typing import Optional, Any
66
from json import JSONEncoder, loads, dumps
77
from kombu.serialization import register
@@ -26,10 +26,10 @@ def default(self, o) -> Any: return getattr(o, "__json__", super().default)(o)
2626
)
2727

2828

29-
@worker.task(name="produce_download_task")
30-
def produce_download_task(bucket_name: str, file_ids: list[str]) -> str | None:
31-
task = Zipper(bucket_name, file_ids)
32-
get_event_loop().run_until_complete(task.archive_objects())
29+
@worker.task(bind=True, name="produce_download_task")
30+
def produce_download_task(self, bucket_name: str, file_ids: list[str]) -> str | None:
31+
task = Zipper(bucket_name, file_ids, self)
32+
run(task.archive_objects())
3333
return task.archive_id
3434

3535

0 commit comments

Comments
 (0)