Skip to content

Commit 59b6c73

Browse files
ref worker
1 parent 0a2c57e commit 59b6c73

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

storage-app/src/shared/worker_services.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,19 @@ def __init__(
4141
file_ids: list[str],
4242
task: Task
4343
) -> None:
44-
self.object_set = Bucket(bucket_name).get_download_objects(file_ids)
44+
self.file_ids = file_ids
4545
self.bucket_name = bucket_name
4646
self._task = task
4747

4848
self._get_annotation(bucket_name, file_ids)
4949

5050
async def archive_objects(self) -> Optional[bool]:
5151
json_data: Any = ("annotation.json", dumps(self.annotation, indent=4).encode("utf-8"))
52+
object_set = Bucket(self.bucket_name).get_download_objects(self.file_ids)
5253

5354
queue = Queue()
5455

55-
producer = FileProducer(self.object_set, queue, MAX_CONCURENT)
56+
producer = FileProducer(object_set, queue, MAX_CONCURENT)
5657
writer = ZipWriter(f"{self.bucket_name}_dataset")
5758
consumer = ZipConsumer(queue, [json_data], writer)
5859

@@ -68,11 +69,12 @@ async def archive_objects(self) -> Optional[bool]:
6869
wait_list = wait_list.next
6970
continue
7071

72+
print(f"ZIP WORK STALL, {producer.iter_count}")
7173
self._task.update_state(state="PROGRESS")
7274
await async_stall_for(5)
7375

7476
await producer_task
75-
await self.object_set.close()
77+
await object_set.close()
7678
consumer.join()
7779
writer.join()
7880

storage-app/src/worker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from celery import Celery
22
from shared.settings import BROKER_URL, RESULT_URL, CELERY_CONFIG
33
from shared.worker_services import Zipper, Hasher, EmbeddingStatus
4-
from asyncio import get_event_loop
4+
from asyncio import run
55
from typing import Optional, Any
66
from json import JSONEncoder, loads, dumps
77
from kombu.serialization import register
@@ -35,7 +35,7 @@ def default(self, o) -> Any: return getattr(o, "__json__", super().default)(o)
3535
)
3636
def produce_download_task(self, bucket_name: str, file_ids: list[str]) -> str | None:
3737
task = Zipper(bucket_name, file_ids, self)
38-
get_event_loop().run_until_complete(task.archive_objects())
38+
run(task.archive_objects())
3939
return task.archive_id
4040

4141

0 commit comments

Comments
 (0)