Skip to content
This repository was archived by the owner on Sep 22, 2023. It is now read-only.

Commit 94d4b75

Browse files
committed
Clean up download API implementations
* Use aiohttp-provided multipart readers consistently for both vfolder download and kernel download. * Allow "gzip" content-encoding in vfolder download API for legacy, though we read the content *WITHOUT* decompression. T_T * Use event-loop executors for file writes to avoid latency jitters for those who embed this SDK in their async applications. - TODO: We also need to improve the upload implementations...
1 parent 851c0b9 commit 94d4b75

File tree

2 files changed

+63
-63
lines changed

2 files changed

+63
-63
lines changed

src/ai/backend/client/kernel.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,13 @@
99
from pathlib import Path
1010
import uuid
1111

12+
import aiohttp
13+
from aiohttp import hdrs
1214
from tqdm import tqdm
1315

1416
from .base import api_function
17+
from .compat import current_loop
18+
from .config import DEFAULT_CHUNK_SIZE
1519
from .exceptions import BackendClientError
1620
from .request import (
1721
Request, AttachedFile,
@@ -457,40 +461,37 @@ async def download(self, files: Sequence[Union[str, Path]],
457461
rqst.set_json({
458462
'files': [*map(str, files)],
459463
})
464+
file_names = []
460465
async with rqst.fetch() as resp:
461-
chunk_size = 1 * 1024
462-
file_names = None
466+
loop = current_loop()
463467
tqdm_obj = tqdm(desc='Downloading files',
464468
unit='bytes', unit_scale=True,
465469
total=resp.content.total_bytes,
466470
disable=not show_progress)
471+
reader = aiohttp.MultipartReader.from_response(resp.raw_response)
467472
with tqdm_obj as pbar:
468-
fp = None
469473
while True:
470-
chunk = await resp.aread(chunk_size)
471-
if not chunk:
474+
part = await reader.next()
475+
if part is None:
472476
break
473-
pbar.update(len(chunk))
474-
# TODO: more elegant parsing of multipart response?
475-
for part in chunk.split(b'\r\n'):
476-
if part.startswith(b'--'):
477-
if fp:
478-
fp.close()
479-
with tarfile.open(fp.name) as tarf:
480-
tarf.extractall(path=dest)
481-
file_names = tarf.getnames()
482-
os.unlink(fp.name)
483-
fp = tempfile.NamedTemporaryFile(suffix='.tar',
484-
delete=False)
485-
elif part.startswith(b'Content-') or part == b'':
486-
continue
487-
else:
488-
fp.write(part)
489-
if fp:
477+
assert part.headers.get(hdrs.CONTENT_ENCODING, 'identity').lower() == 'identity'
478+
assert part.headers.get(hdrs.CONTENT_TRANSFER_ENCODING, 'binary').lower() in (
479+
'binary', '8bit', '7bit',
480+
)
481+
fp = tempfile.NamedTemporaryFile(suffix='.tar',
482+
delete=False)
483+
while True:
484+
chunk = await part.read_chunk(DEFAULT_CHUNK_SIZE)
485+
if not chunk:
486+
break
487+
await loop.run_in_executor(None, lambda: fp.write(chunk))
488+
pbar.update(len(chunk))
490489
fp.close()
490+
with tarfile.open(fp.name) as tarf:
491+
tarf.extractall(path=dest)
492+
file_names.extend(tarf.getnames())
491493
os.unlink(fp.name)
492-
result = {'file_names': file_names}
493-
return result
494+
return {'file_names': file_names}
494495

495496
@api_function
496497
async def list_files(self, path: Union[str, Path] = '.'):

src/ai/backend/client/vfolder.py

Lines changed: 38 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import asyncio
21
from pathlib import Path
32
from typing import Sequence, Union
43

@@ -7,8 +6,9 @@
76
from tqdm import tqdm
87

98
from .base import api_function
9+
from .compat import current_loop
1010
from .config import DEFAULT_CHUNK_SIZE
11-
from .exceptions import BackendAPIError, BackendClientError
11+
from .exceptions import BackendAPIError
1212
from .request import Request, AttachedFile
1313
from .cli.pretty import ProgressReportingReader
1414

@@ -155,43 +155,42 @@ async def download(self, files: Sequence[Union[str, Path]],
155155
rqst.set_json({
156156
'files': files,
157157
})
158-
try:
159-
async with rqst.fetch() as resp:
160-
if resp.status // 100 != 2:
161-
raise BackendAPIError(resp.status, resp.reason,
162-
await resp.text())
163-
total_bytes = int(resp.headers['X-TOTAL-PAYLOADS-LENGTH'])
164-
tqdm_obj = tqdm(desc='Downloading files',
165-
unit='bytes', unit_scale=True,
166-
total=total_bytes,
167-
disable=not show_progress)
168-
reader = aiohttp.MultipartReader.from_response(resp.raw_response)
169-
with tqdm_obj as pbar:
170-
acc_bytes = 0
171-
while True:
172-
part = await reader.next()
173-
if part is None:
174-
break
175-
assert part.headers.get(hdrs.CONTENT_ENCODING, 'identity').lower() == 'identity'
176-
assert part.headers.get(hdrs.CONTENT_TRANSFER_ENCODING, 'binary').lower() in (
177-
'binary', '8bit', '7bit',
178-
)
179-
with open(part.filename, 'wb') as fp:
180-
while True:
181-
chunk = await part.read_chunk(DEFAULT_CHUNK_SIZE)
182-
if not chunk:
183-
break
184-
fp.write(chunk)
185-
acc_bytes += len(chunk)
186-
pbar.update(len(chunk))
187-
pbar.update(total_bytes - acc_bytes)
188-
except (asyncio.CancelledError, asyncio.TimeoutError):
189-
# These exceptions must be bubbled up.
190-
raise
191-
except aiohttp.ClientError as e:
192-
msg = 'Request to the API endpoint has failed.\n' \
193-
'Check your network connection and/or the server status.'
194-
raise BackendClientError(msg) from e
158+
file_names = []
159+
async with rqst.fetch() as resp:
160+
if resp.status // 100 != 2:
161+
raise BackendAPIError(resp.status, resp.reason,
162+
await resp.text())
163+
total_bytes = int(resp.headers['X-TOTAL-PAYLOADS-LENGTH'])
164+
tqdm_obj = tqdm(desc='Downloading files',
165+
unit='bytes', unit_scale=True,
166+
total=total_bytes,
167+
disable=not show_progress)
168+
reader = aiohttp.MultipartReader.from_response(resp.raw_response)
169+
with tqdm_obj as pbar:
170+
loop = current_loop()
171+
acc_bytes = 0
172+
while True:
173+
part = await reader.next()
174+
if part is None:
175+
break
176+
assert part.headers.get(hdrs.CONTENT_ENCODING, 'identity').lower() in (
177+
'identity',
178+
'gzip', # Prior to v19.09.4, the server had a bug to set this incorrectly.
179+
# This legacy handling will be removed in v19.12 release.
180+
)
181+
assert part.headers.get(hdrs.CONTENT_TRANSFER_ENCODING, 'binary').lower() in (
182+
'binary', '8bit', '7bit',
183+
)
184+
with open(part.filename, 'wb') as fp:
185+
while True:
186+
chunk = await part.read_chunk(DEFAULT_CHUNK_SIZE)
187+
if not chunk:
188+
break
189+
await loop.run_in_executor(None, lambda: fp.write(chunk))
190+
acc_bytes += len(chunk)
191+
pbar.update(len(chunk))
192+
pbar.update(total_bytes - acc_bytes)
193+
return {'file_names': file_names}
195194

196195
@api_function
197196
async def list_files(self, path: Union[str, Path] = '.'):

0 commit comments

Comments
 (0)