161 changes: 159 additions & 2 deletions storage-app/src/shared/archive_helpers.py
@@ -68,9 +68,9 @@ def run(self):

         buffer, read_size = task
         buffer.seek(0)
-        v = buffer.read(read_size)
+        data = buffer.read(read_size)
 
-        dest.write(v)
+        dest.write(data)
 
         del buffer

@@ -248,6 +248,7 @@ def produce_sync(self):
         for file in self.object_set:
             self.queue.put((self._get_file_name(file), file.read()))
             self.iter_count += 1
+            file.close()
             if not self.iter_count % GC_FREQ: gc_collect()
 
         self.queue.put(None)
@@ -283,3 +284,159 @@ def _get_file_name(self, file: GridOut) -> str:
        if extension: name += f".{extension}"

        return name


class SyncZipping:
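    """Synchronous ZIP writer that streams into a temporary GridFS bucket.

    Files from ``object_set`` (plus the pre-built ``additional`` entries)
    are compressed through an in-memory ``ZipFile``; whenever the local
    buffer grows past ``DUMP_THRESHOLD`` the finished segment is flushed to
    the upload stream, so memory use stays bounded for archives of any size.
    """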
    DUMP_THRESHOLD: int = 10 << 20  # flush the in-memory buffer to GridFS every 10 MiB

    def __init__(
        self,
        dest_name: str,
        object_set: Cursor,
        additional: list[tuple[str, bytes]]
    ):
        self.object_set = object_set
        self.additional = additional
        self.file_list = []
        self._local_dir_end = 0
        self._archive_id = None
        self.dest = SyncDataBase \
            .get_fs_bucket(TEMP_BUCKET) \
            .open_upload_stream(
                dest_name,
                metadata={"created_at": datetime.now().isoformat()}
            )

    def dest_write(self, buffer, read_size):
        # Rewind and copy read_size bytes from the segment buffer into the
        # GridFS upload stream.
        buffer.seek(0)
        data = buffer.read(read_size)
        self.dest.write(data)

    def tell(self) -> int: return self._local_dir_end

    def result(self) -> Optional[str]: return self._archive_id

    def _dump_buffer(self, buffer: BytesIO, zip_buffer: ZipFile):
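        """Flush one finished ZIP segment into the destination stream.

        Entries in ``zip_buffer`` recorded header offsets relative to the
        start of the local buffer, so each offset is rebased by the number
        of bytes already written to the destination before the entries are
        appended to ``file_list`` for the final central directory.
        """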
        dest_offset = self.tell()

        new_list = zip_buffer.filelist
        for zinfo in new_list: zinfo.header_offset += dest_offset

        self.file_list += new_list
        self._local_dir_end += buffer.tell()

        self.dest_write(buffer, buffer.tell())

        # Closing writes a throwaway central directory into the local buffer,
        # which has already been copied out and is discarded after this call.
        zip_buffer.close()

    def _finalize(self):
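        """Append the central directory and the end-of-central-directory
        records, remember the GridFS id of the archive, and close the
        upload stream.
        """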
        self._write_end_record(end_buffer := BytesIO())
        self.dest_write(end_buffer, end_buffer.tell())

        # dest_write does not advance _local_dir_end, so tell() still marks
        # the offset at which the central directory begins.
        self._write_cent_dir(
            self.tell() + end_buffer.tell(),
            self.tell(),
            len(self.file_list),
            cent_dir_buffer := BytesIO()
        )
        self.dest_write(cent_dir_buffer, cent_dir_buffer.tell())

        self._archive_id = self.dest._id

        self.dest.close()

        SyncDataBase.close_connection()

    def _write_end_record(self, buffer: BytesIO):
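        """Emit one central-directory file header per archived entry.

        The loop closely follows the central-directory part of CPython's
        ``zipfile.ZipFile._write_end_record``: ZIP64 handling is kept for
        header offsets, while oversized file entries are rejected outright.
        """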
        for zinfo in self.file_list:
            dt = zinfo.date_time

            dosdate = (dt[0] - 1980) << 9 | dt[1] << 5 | dt[2]
            dostime = dt[3] << 11 | dt[4] << 5 | (dt[5] // 2)
            extra = []

            # Oversized entries are not supported here: only the header
            # offset may need the ZIP64 extra field.
            assert zinfo.file_size <= ZIP64_LIMIT and zinfo.compress_size <= ZIP64_LIMIT

            file_size = zinfo.file_size
            compress_size = zinfo.compress_size

            if zinfo.header_offset > ZIP64_LIMIT:
                extra.append(zinfo.header_offset)
                header_offset = 0xffffffff
            else: header_offset = zinfo.header_offset

            extra_data = zinfo.extra
            min_version = 0

            if extra:
                extra_data = _Extra.strip(extra_data, (1,))
                extra_data = pack_data("<HH" + "Q" * len(extra), 1, 8 * len(extra), *extra) + extra_data

                min_version = ZIP64_VERSION

            extract_version = max(min_version, zinfo.extract_version)
            create_version = max(min_version, zinfo.create_version)

            filename, flag_bits = zinfo._encodeFilenameFlags()

            centdir = pack_data(
                CENTRAL_STRUCT,
                CENTRAL_STRING,
                create_version,
                zinfo.create_system,
                extract_version,
                zinfo.reserved,
                flag_bits,
                zinfo.compress_type,
                dostime,
                dosdate,
                zinfo.CRC,
                compress_size,
                file_size,
                len(filename),
                len(extra_data),
                len(zinfo.comment),
                0,
                zinfo.internal_attr,
                zinfo.external_attr,
                header_offset
            )

            buffer.write(centdir + filename + extra_data + zinfo.comment)

    def _write_cent_dir(self, pos: int, start_dir: int, d_size: int, buffer: BytesIO):
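        """Write the end-of-central-directory record(s) into ``buffer``.

        ``pos`` is the offset just past the central directory, ``start_dir``
        the offset where it begins, and ``d_size`` the entry count; when any
        of these exceed the classic format's limits, the ZIP64 end record and
        locator are emitted first and the overflowing fields are clamped to
        their sentinel values.
        """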
        cent_dir = pos - start_dir

        if d_size > ZIP_FILECOUNT_LIMIT or pos > ZIP64_LIMIT:
            # Field order per the ZIP64 spec (and CPython's zipfile): entry
            # counts, then the central directory's size and start offset.
            pack = (END_64_STRUCT, END_64_STRING, 44, 45, 45, 0, 0, d_size, d_size, cent_dir, start_dir)
            buffer.write(pack_data(*pack))
            # The locator points at the ZIP64 end record, which sits at `pos`,
            # immediately after the central directory.
            buffer.write(pack_data(END_64_STRUCT_LOC, END_64_STRING_LOC, 0, pos, 1))
            cent_dir = min(cent_dir, 0xFFFFFFFF)
            start_dir = min(start_dir, 0xFFFFFFFF)
            d_size = min(d_size, 0xFFFF)

        buffer.write(pack_data(END_STRUCT, END_STRING, 0, 0, d_size, d_size, cent_dir, start_dir, 0))

    def run(self):
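        """Compress every file from the cursor, then the extra in-memory
        entries, into the archive, flushing a segment whenever the local
        buffer grows past ``DUMP_THRESHOLD``.
        """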
        buffer = BytesIO()
        zip_buffer: ZipFile = ZipFile(buffer, "w", ZIP_DEFLATED)

        for file in self.object_set:
            f_name, ext = str(file._id), file.metadata.get("file_extension", "")
            if ext: f_name += f".{ext}"

            f_data = file.read()

            zip_buffer.writestr(f_name, f_data)

            if buffer.tell() > self.DUMP_THRESHOLD:
                # Flush the finished segment and start a fresh one; its entry
                # headers are rebased inside _dump_buffer.
                self._dump_buffer(buffer, zip_buffer)
                buffer = BytesIO()
                zip_buffer = ZipFile(buffer, "w", ZIP_DEFLATED)

        for f_name, f_data in self.additional: zip_buffer.writestr(f_name, f_data)

        if buffer.tell(): self._dump_buffer(buffer, zip_buffer)

        self.object_set.close()
        self._finalize()
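
A minimal usage sketch for the new class, mirroring how archive_objects drives it in worker_services.py below (the bucket name and the object_ids list are illustrative):

    object_set = SyncDataBase \
        .get_fs_bucket("some_bucket") \
        .find({"_id": {"$in": object_ids}}, no_cursor_timeout=True) \
        .batch_size(200)

    zipper = SyncZipping("some_bucket_dataset", object_set, [("annotation.json", b"{}")])
    zipper.run()                  # blocks until the archive is fully uploaded
    archive_id = zipper.result()  # GridFS id of the archive in TEMP_BUCKET, or None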
26 changes: 18 additions & 8 deletions storage-app/src/shared/worker_services.py
@@ -15,7 +15,7 @@
 import requests
 from .hasher import VHash, IHash
 from queue import Queue
-from .archive_helpers import FileProducer, ZipConsumer, ZipWriter
+from .archive_helpers import FileProducer, ZipConsumer, ZipWriter, SyncZipping
 from celery import Task


@@ -33,7 +33,6 @@ def to_value(self):


 class Zipper:
-    written: bool = False
     archive_extension: str = "zip"

     def __init__(
@@ -48,9 +47,24 @@ def __init__(

         self._get_annotation(bucket_name, file_ids)
 
-    def archive_objects(self) -> Optional[bool]:
+    def archive_objects(self):
+        # Build the archive synchronously via SyncZipping; the threaded
+        # producer/consumer pipeline is kept below as _archive_objects.
+        json_data: Any = ("annotation.json", dumps(self.annotation, indent=4).encode("utf-8"))
+        object_set = SyncDataBase \
+            .get_fs_bucket(self.bucket_name) \
+            .find(
+                {"_id": {"$in": [get_object_id(str(object_id)) for object_id in self.file_ids]}},
+                no_cursor_timeout=True
+            ) \
+            .batch_size(200)
+
+        zipper = SyncZipping(f"{self.bucket_name}_dataset", object_set, [json_data])
+        zipper.run()
+
+        assert (result_id := zipper.result()), "Archive was not written"
+        self._archive_id = result_id
+
+    def _archive_objects(self):
         json_data: Any = ("annotation.json", dumps(self.annotation, indent=4).encode("utf-8"))
-        # object_set = Bucket(self.bucket_name).get_download_objects(self.file_ids)
         object_set = SyncDataBase \
             .get_fs_bucket(self.bucket_name) \
             .find(
@@ -86,13 +100,9 @@ def archive_objects(self) -> Optional[bool]:
         consumer.join()
         writer.join()
 
-        self.written = True
-
         assert (result_id := writer.result()), "Archive was not written"
         self._archive_id = result_id
 
-        return self.written
-
     def _get_annotation(self, bucket_name: str, file_ids: list[str]) -> Any:
         url: str = APP_BACKEND_URL + "/api/files/annotation/"
         payload_token = emit_token({"minutes": 1}, SECRET_KEY, SECRET_ALGO)