|
3 | 3 | import functools |
4 | 4 | import logging |
5 | 5 | import urllib.parse |
6 | | -from collections.abc import AsyncGenerator, Sequence |
| 6 | +from collections.abc import AsyncGenerator, Callable, Sequence |
7 | 7 | from dataclasses import dataclass, field |
8 | 8 | from pathlib import Path |
9 | 9 | from typing import Any, Final, Protocol, cast |
|
18 | 18 | from pydantic import AnyUrl, ByteSize, TypeAdapter |
19 | 19 | from servicelib.logging_utils import log_catch, log_context |
20 | 20 | from servicelib.utils import limited_gather |
| 21 | +from servicelib.zip_stream import DEFAULT_CHUNK_SIZE, FileStream |
21 | 22 | from settings_library.s3 import S3Settings |
22 | 23 | from types_aiobotocore_s3 import S3Client |
23 | 24 | from types_aiobotocore_s3.literals import BucketLocationConstraintType |
@@ -57,6 +58,14 @@ def __call__(self, total_bytes_copied: int, *, file_name: str) -> None: |
57 | 58 | ... |
58 | 59 |
|
59 | 60 |
|
| 61 | +class AsyncFileProtocol(Protocol): |
| 62 | + async def read(self, chunk_size: int) -> bytes: |
| 63 | + ... |
| 64 | + |
| 65 | + async def write(self, data: bytes) -> None: |
| 66 | + ... |
| 67 | + |
| 68 | + |
60 | 69 | @dataclass(frozen=True) |
61 | 70 | class SimcoreS3API: # pylint: disable=too-many-public-methods |
62 | 71 | _client: S3Client |
@@ -470,6 +479,79 @@ async def copy_objects_recursively( |
470 | 479 | limit=_MAX_CONCURRENT_COPY, |
471 | 480 | ) |
472 | 481 |
|
| 482 | + async def get_object_file_stream( |
| 483 | + self, |
| 484 | + bucket_name: S3BucketName, |
| 485 | + object_key: S3ObjectKey, |
| 486 | + *, |
| 487 | + chunk_size: int = DEFAULT_CHUNK_SIZE, |
| 488 | + ) -> FileStream: |
| 489 | + response = await self._client.head_object(Bucket=bucket_name, Key=object_key) |
| 490 | + file_size = response["ContentLength"] |
| 491 | + |
| 492 | + # Download the file in chunks |
| 493 | + position = 0 |
| 494 | + while position < file_size: |
| 495 | + # Calculate the range for this chunk |
| 496 | + end = min(position + chunk_size - 1, file_size - 1) |
| 497 | + range_header = f"bytes={position}-{end}" |
| 498 | + |
| 499 | + # Download the chunk |
| 500 | + response = await self._client.get_object( |
| 501 | + Bucket=bucket_name, Key=object_key, Range=range_header |
| 502 | + ) |
| 503 | + |
| 504 | + chunk = await response["Body"].read() |
| 505 | + |
| 506 | + # Yield the chunk for processing |
| 507 | + yield chunk |
| 508 | + |
| 509 | + position += chunk_size |
| 510 | + |
| 511 | + @s3_exception_handler(_logger) |
| 512 | + async def upload_object_from_file_stream( |
| 513 | + self, |
| 514 | + bucket_name: S3BucketName, |
| 515 | + object_key: S3ObjectKey, |
| 516 | + file_stream: Callable[[], FileStream], |
| 517 | + ) -> None: |
| 518 | + # Create a multipart upload |
| 519 | + multipart_response = await self._client.create_multipart_upload( |
| 520 | + Bucket=bucket_name, Key=object_key |
| 521 | + ) |
| 522 | + upload_id = multipart_response["UploadId"] |
| 523 | + |
| 524 | + try: |
| 525 | + parts = [] |
| 526 | + part_number = 1 |
| 527 | + |
| 528 | + async for chunk in file_stream(): |
|     | 529 | +                _logger.debug("uploading part %s (%s bytes)", part_number, len(chunk))  |
| 530 | + |
| 531 | + part_response = await self._client.upload_part( |
| 532 | + Bucket=bucket_name, |
| 533 | + Key=object_key, |
| 534 | + PartNumber=part_number, |
| 535 | + UploadId=upload_id, |
| 536 | + Body=chunk, |
| 537 | + ) |
| 538 | + parts.append({"ETag": part_response["ETag"], "PartNumber": part_number}) |
| 539 | + part_number += 1 |
| 540 | + |
| 541 | + # Complete the multipart upload |
| 542 | + await self._client.complete_multipart_upload( |
| 543 | + Bucket=bucket_name, |
| 544 | + Key=object_key, |
| 545 | + UploadId=upload_id, |
| 546 | + MultipartUpload={"Parts": parts}, |
| 547 | + ) |
| 548 | + except Exception: |
| 549 | + # Abort the multipart upload if something goes wrong |
| 550 | + await self._client.abort_multipart_upload( |
| 551 | + Bucket=bucket_name, Key=object_key, UploadId=upload_id |
| 552 | + ) |
| 553 | + raise |
| 554 | + |
473 | 555 | @staticmethod |
474 | 556 | def is_multipart(file_size: ByteSize) -> bool: |
475 | 557 | return file_size >= MULTIPART_UPLOADS_MIN_TOTAL_SIZE |
|
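A minimal usage sketch of the two streaming helpers added above (illustrative only: the `s3_api` instance, bucket names, object keys and the 10 MiB chunk size are assumptions, not part of the diff; note that S3 multipart uploads require every part except the last to be at least 5 MiB, so the chunk size used for the download must respect that):

    import functools

    async def stream_copy_example(s3_api: SimcoreS3API) -> None:
        # functools.partial builds the zero-argument callable expected by
        # upload_object_from_file_stream; calling it returns the async
        # generator produced by get_object_file_stream
        source_stream = functools.partial(
            s3_api.get_object_file_stream,
            "source-bucket",               # hypothetical bucket name
            "path/to/source-object.bin",   # hypothetical object key
            chunk_size=10 * 1024 * 1024,   # also the multipart part size on upload
        )
        # chunks are re-uploaded as multipart parts without buffering the
        # whole object in memory
        await s3_api.upload_object_from_file_stream(
            "destination-bucket",
            "path/to/destination-object.bin",
            source_stream,
        )

As written in the diff, a zero-byte source object would yield no chunks, so the multipart upload would be completed with an empty parts list; callers streaming possibly-empty objects would need to handle that case separately.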