|
| 1 | +import asyncio |
| 2 | +import json |
| 3 | +import math |
| 4 | +import random |
| 5 | +import shutil |
| 6 | +import string |
| 7 | +from pathlib import Path |
| 8 | +from typing import Any, Iterator, List, Optional, Tuple, Union |
| 9 | + |
| 10 | +import httpx |
| 11 | +from httpx import AsyncClient, Response |
| 12 | +from osparc_client import ( |
| 13 | + BodyCompleteMultipartUploadV0FilesFileIdCompletePost, |
| 14 | + ClientFile, |
| 15 | + ClientFileUploadSchema, |
| 16 | +) |
| 17 | +from osparc_client import FilesApi as _FilesApi |
| 18 | +from osparc_client import FileUploadCompletionBody, FileUploadLinks, UploadedPart |
| 19 | +from tqdm.asyncio import tqdm_asyncio |
| 20 | + |
| 21 | +from . import ApiClient, File |
| 22 | +from ._http_client import AsyncHttpClient |
| 23 | +from ._utils import _file_chunk_generator |
| 24 | + |
| 25 | + |
class FilesApi(_FilesApi):
    """Class for interacting with files"""

    def __init__(self, api_client: Optional[ApiClient] = None):
        """Construct object

        Args:
            api_client (ApiClient, optional): osparc.ApiClient object
        """
        super().__init__(api_client)
        # Bound proxy to the parent class; the upload coroutine uses it to
        # call the parent's `get_upload_links`.
        self._super = super(FilesApi, self)
        user: Optional[str] = self.api_client.configuration.username
        passwd: Optional[str] = self.api_client.configuration.password
        # Basic auth is only built when both credentials are configured.
        self._auth: Optional[httpx.BasicAuth] = (
            httpx.BasicAuth(username=user, password=passwd)
            if (user is not None and passwd is not None)
            else None
        )

    def download_file(
        self, file_id: str, *, destination_folder: Optional[Path] = None
    ) -> str:
        """Download a file and optionally move it into a folder.

        Args:
            file_id: identifier forwarded to the parent ``download_file``.
            destination_folder: existing directory the downloaded file is
                moved into. When ``None`` the file stays where the parent
                call placed it.

        Returns:
            The resolved path of the downloaded file, as a string.

        Raises:
            RuntimeError: if ``destination_folder`` is given but is not a
                directory.
        """
        if destination_folder is not None and not destination_folder.is_dir():
            raise RuntimeError(
                f"destination_folder: {destination_folder} must be a directory"
            )
        downloaded_file: Path = Path(super().download_file(file_id))
        if destination_folder is not None:
            dest_file: Path = destination_folder / downloaded_file.name
            # Never overwrite an existing file: keep drawing a random
            # 8-letter suffix for the stem until the name is free.
            while dest_file.is_file():
                new_name = (
                    downloaded_file.stem
                    + "".join(random.choices(string.ascii_letters, k=8))
                    + downloaded_file.suffix
                )
                dest_file = destination_folder / new_name
            shutil.move(downloaded_file, dest_file)
            downloaded_file = dest_file
        return str(downloaded_file.resolve())

    def upload_file(self, file: Union[str, Path]) -> File:
        """Upload a file, blocking until the upload has finished.

        Runs ``upload_file_async`` through ``asyncio.run``, so it cannot be
        called from inside an already running event loop; await
        ``upload_file_async`` directly in that case.

        Args:
            file: path of the file to upload.

        Returns:
            The ``File`` produced by ``upload_file_async``.
        """
        return asyncio.run(self.upload_file_async(file=file))

    async def upload_file_async(self, file: Union[str, Path]) -> File:
        """Upload a file in chunks and complete the multipart upload.

        Args:
            file: path of the file to upload.

        Returns:
            The ``File`` built from the response to the completion request.

        Raises:
            RuntimeError: if ``file`` is not a file, or if the server
                returned fewer upload URLs than there are chunks.
        """
        if isinstance(file, str):
            file = Path(file)
        if not file.is_file():
            raise RuntimeError(f"{file} is not a file")
        # Read the size once so the value announced to the server and the
        # value used for the URL-count check cannot disagree.
        file_size: int = file.stat().st_size
        client_file: ClientFile = ClientFile(filename=file.name, filesize=file_size)
        client_upload_schema: ClientFileUploadSchema = self._super.get_upload_links(
            client_file=client_file
        )
        chunk_size: int = client_upload_schema.upload_schema.chunk_size
        links: FileUploadLinks = client_upload_schema.upload_schema.links
        # Parts are numbered starting at 1.
        url_iter: Iterator[Tuple[int, str]] = enumerate(
            iter(client_upload_schema.upload_schema.urls), start=1
        )
        if len(client_upload_schema.upload_schema.urls) < math.ceil(
            file_size / chunk_size
        ):
            raise RuntimeError(
                "Did not receive sufficient number of upload URLs from the server."
            )

        tasks: list = []
        # NOTE(review): the `exception_*` arguments presumably make the client
        # call the abort link when an exception escapes this block - confirm
        # against AsyncHttpClient.
        async with AsyncHttpClient(
            exception_request_type="post",
            exception_url=links.abort_upload,
            exception_auth=self._auth,
        ) as session:
            async for chunck, size in _file_chunk_generator(file, chunk_size):
                # following https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
                index, url = next(url_iter)
                task = asyncio.create_task(
                    self._upload_chunck(
                        http_client=session,
                        chunck=chunck,
                        chunck_size=size,
                        upload_link=url,
                        index=index,
                    )
                )
                tasks.append(task)

            uploaded_parts: List[UploadedPart] = await tqdm_asyncio.gather(*tasks)

            # The completion request reuses the session, so it has to be
            # issued before the `async with` block is left.
            return await self._complete_multipart_upload(
                session, links.complete_upload, client_file, uploaded_parts
            )

    async def _complete_multipart_upload(
        self,
        http_client: AsyncClient,
        complete_link: str,
        client_file: ClientFile,
        uploaded_parts: List[UploadedPart],
    ) -> File:
        """POST the list of uploaded parts to the completion link.

        Args:
            http_client: client used to send the request.
            complete_link: URL the completion payload is posted to.
            client_file: description of the uploaded file.
            uploaded_parts: one entry per uploaded chunk.

        Returns:
            A ``File`` built from the JSON body of the response.

        Raises:
            httpx.HTTPStatusError: if the response has an error status
                (raised by ``raise_for_status``).
        """
        complete_payload = BodyCompleteMultipartUploadV0FilesFileIdCompletePost(
            client_file=client_file,
            uploaded_parts=FileUploadCompletionBody(parts=uploaded_parts),
        )
        response: Response = await http_client.post(
            complete_link,
            json=complete_payload.to_dict(),
            auth=self._auth,
        )
        response.raise_for_status()
        payload: dict[str, Any] = response.json()
        return File(**payload)

    async def _upload_chunck(
        self,
        http_client: AsyncClient,
        chunck: bytes,
        chunck_size: int,
        upload_link: str,
        index: int,
    ) -> UploadedPart:
        """PUT a single chunk to its upload link.

        Args:
            http_client: client used to send the request.
            chunck: content of the chunk.
            chunck_size: value sent as the ``Content-Length`` header.
            upload_link: URL the chunk is uploaded to.
            index: part number recorded in the returned ``UploadedPart``.

        Returns:
            An ``UploadedPart`` holding ``index`` and the ETag returned by
            the server.

        Raises:
            httpx.HTTPStatusError: if the response has an error status
                (raised by ``raise_for_status``).
            RuntimeError: if the response carries no ``Etag`` header.
        """
        response: Response = await http_client.put(
            upload_link, content=chunck, headers={"Content-Length": f"{chunck_size}"}
        )
        response.raise_for_status()
        # Explicit check instead of `assert`: assertions are stripped when
        # Python runs with -O, and a missing header is a server-side
        # condition rather than a programming error.
        etag_header: Optional[str] = response.headers.get("Etag")
        if etag_header is None:
            raise RuntimeError(
                f"Upload of part {index} did not return an Etag header"
            )
        # The header value is decoded as a JSON string, which removes the
        # surrounding double quotes.
        etag: str = json.loads(etag_header)
        return UploadedPart(number=index, e_tag=etag)
0 commit comments