|
1 |
| -import asyncio |
2 | 1 | from pathlib import Path
|
3 | 2 | from typing import Sequence, Union
|
4 | 3 |
|
|
7 | 6 | from tqdm import tqdm
|
8 | 7 |
|
9 | 8 | from .base import api_function
|
| 9 | +from .compat import current_loop |
10 | 10 | from .config import DEFAULT_CHUNK_SIZE
|
11 |
| -from .exceptions import BackendAPIError, BackendClientError |
| 11 | +from .exceptions import BackendAPIError |
12 | 12 | from .request import Request, AttachedFile
|
13 | 13 | from .cli.pretty import ProgressReportingReader
|
14 | 14 |
|
@@ -155,43 +155,42 @@ async def download(self, files: Sequence[Union[str, Path]],
|
155 | 155 | rqst.set_json({
|
156 | 156 | 'files': files,
|
157 | 157 | })
|
158 |
| - try: |
159 |
| - async with rqst.fetch() as resp: |
160 |
| - if resp.status // 100 != 2: |
161 |
| - raise BackendAPIError(resp.status, resp.reason, |
162 |
| - await resp.text()) |
163 |
| - total_bytes = int(resp.headers['X-TOTAL-PAYLOADS-LENGTH']) |
164 |
| - tqdm_obj = tqdm(desc='Downloading files', |
165 |
| - unit='bytes', unit_scale=True, |
166 |
| - total=total_bytes, |
167 |
| - disable=not show_progress) |
168 |
| - reader = aiohttp.MultipartReader.from_response(resp.raw_response) |
169 |
| - with tqdm_obj as pbar: |
170 |
| - acc_bytes = 0 |
171 |
| - while True: |
172 |
| - part = await reader.next() |
173 |
| - if part is None: |
174 |
| - break |
175 |
| - assert part.headers.get(hdrs.CONTENT_ENCODING, 'identity').lower() == 'identity' |
176 |
| - assert part.headers.get(hdrs.CONTENT_TRANSFER_ENCODING, 'binary').lower() in ( |
177 |
| - 'binary', '8bit', '7bit', |
178 |
| - ) |
179 |
| - with open(part.filename, 'wb') as fp: |
180 |
| - while True: |
181 |
| - chunk = await part.read_chunk(DEFAULT_CHUNK_SIZE) |
182 |
| - if not chunk: |
183 |
| - break |
184 |
| - fp.write(chunk) |
185 |
| - acc_bytes += len(chunk) |
186 |
| - pbar.update(len(chunk)) |
187 |
| - pbar.update(total_bytes - acc_bytes) |
188 |
| - except (asyncio.CancelledError, asyncio.TimeoutError): |
189 |
| - # These exceptions must be bubbled up. |
190 |
| - raise |
191 |
| - except aiohttp.ClientError as e: |
192 |
| - msg = 'Request to the API endpoint has failed.\n' \ |
193 |
| - 'Check your network connection and/or the server status.' |
194 |
| - raise BackendClientError(msg) from e |
| 158 | + file_names = [] |
| 159 | + async with rqst.fetch() as resp: |
| 160 | + if resp.status // 100 != 2: |
| 161 | + raise BackendAPIError(resp.status, resp.reason, |
| 162 | + await resp.text()) |
| 163 | + total_bytes = int(resp.headers['X-TOTAL-PAYLOADS-LENGTH']) |
| 164 | + tqdm_obj = tqdm(desc='Downloading files', |
| 165 | + unit='bytes', unit_scale=True, |
| 166 | + total=total_bytes, |
| 167 | + disable=not show_progress) |
| 168 | + reader = aiohttp.MultipartReader.from_response(resp.raw_response) |
| 169 | + with tqdm_obj as pbar: |
| 170 | + loop = current_loop() |
| 171 | + acc_bytes = 0 |
| 172 | + while True: |
| 173 | + part = await reader.next() |
| 174 | + if part is None: |
| 175 | + break |
| 176 | + assert part.headers.get(hdrs.CONTENT_ENCODING, 'identity').lower() in ( |
| 177 | + 'identity', |
| 178 | + 'gzip', # Prior to v19.09.4, the server had a bug to set this incorrectly. |
| 179 | + # This legacy handling will be removed in v19.12 release. |
| 180 | + ) |
| 181 | + assert part.headers.get(hdrs.CONTENT_TRANSFER_ENCODING, 'binary').lower() in ( |
| 182 | + 'binary', '8bit', '7bit', |
| 183 | + ) |
| 184 | + with open(part.filename, 'wb') as fp: |
| 185 | + while True: |
| 186 | + chunk = await part.read_chunk(DEFAULT_CHUNK_SIZE) |
| 187 | + if not chunk: |
| 188 | + break |
| 189 | + await loop.run_in_executor(None, lambda: fp.write(chunk)) |
| 190 | + acc_bytes += len(chunk) |
| 191 | + pbar.update(len(chunk)) |
| 192 | + pbar.update(total_bytes - acc_bytes) |
| 193 | + return {'file_names': file_names} |
195 | 194 |
|
196 | 195 | @api_function
|
197 | 196 | async def list_files(self, path: Union[str, Path] = '.'):
|
|
0 commit comments