
Commit f3d88e6

Trying async download
1 parent 323cddd commit f3d88e6

File tree

4 files changed: +279 -0 lines changed

google/cloud/storage/_experimental/asyncio/json/__init__.py

Whitespace-only changes.
Lines changed: 51 additions & 0 deletions (the package's _helpers module, imported by the files below)

"""Async classes for holding the credentials and connection."""


import os

import google.auth._credentials_async
from google.cloud.client import _ClientProjectMixin
from google.cloud.client import _CREDENTIALS_REFRESH_TIMEOUT
from google.auth.transport import _aiohttp_requests as async_requests
from google.cloud.storage import retry as storage_retry
from google.auth import _default_async
from google.api_core import retry_async


DEFAULT_ASYNC_RETRY = retry_async.AsyncRetry(predicate=storage_retry._should_retry)


class Client:
    SCOPE = None  # Overridden by child classes.

    def __init__(self):
        async_creds, _ = _default_async.default_async(scopes=self.SCOPE)
        self._async_credentials = google.auth._credentials_async.with_scopes_if_required(
            async_creds, scopes=self.SCOPE
        )
        self._async_http_internal = None

    @property
    def _async_http(self):
        if self._async_http_internal is None:
            self._async_http_internal = async_requests.AuthorizedSession(
                self._async_credentials,
                refresh_timeout=_CREDENTIALS_REFRESH_TIMEOUT,
            )
        return self._async_http_internal

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        del exc_type, exc_val, exc_tb
        if self._async_http_internal is not None:
            await self._async_http_internal.close()


class ClientWithProjectAsync(Client, _ClientProjectMixin):
    _SET_PROJECT = True

    def __init__(self, project=None):
        _ClientProjectMixin.__init__(self, project=project)
        Client.__init__(self)
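
A minimal sketch of how these base classes might be exercised, assuming application-default credentials are available in the environment; the subclass, bucket name, and request URL are illustrative, not part of the commit:

import asyncio


class ReadOnlyClient(Client):
    # Illustrative subclass; the commit's real consumer is AsyncClient below.
    SCOPE = ("https://www.googleapis.com/auth/devstorage.read_only",)


async def main():
    # __aenter__/__aexit__ make the client an async context manager, so the
    # lazily created aiohttp-backed AuthorizedSession is closed on exit.
    async with ReadOnlyClient() as client:
        response = await client._async_http.request(
            "GET", "https://storage.googleapis.com/storage/v1/b/my-bucket"
        )
        print(response.status)


asyncio.run(main())
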
Lines changed: 140 additions & 0 deletions

"""Async client for SDK downloads"""

import os
import asyncio


from google.cloud.storage._experimental.asyncio.json import _helpers
from google.cloud.storage._experimental.asyncio.json import download
from google.cloud.storage._helpers import _get_environ_project
from google.cloud.storage._helpers import _DEFAULT_SCHEME
from google.cloud.storage._helpers import _STORAGE_HOST_TEMPLATE
from google.cloud.storage._helpers import _DEFAULT_UNIVERSE_DOMAIN
from google.cloud.storage import blob


class AsyncClient(_helpers.ClientWithProjectAsync):

    SCOPE = (
        "https://www.googleapis.com/auth/devstorage.full_control",
        "https://www.googleapis.com/auth/devstorage.read_only",
        "https://www.googleapis.com/auth/devstorage.read_write",
    )

    @property
    def api_endpoint(self):
        return _DEFAULT_SCHEME + _STORAGE_HOST_TEMPLATE.format(
            universe_domain=_DEFAULT_UNIVERSE_DOMAIN
        )

    def _get_download_url(self, blob_obj):
        return f"{self.api_endpoint}/download/storage/v1/b/{blob_obj.bucket.name}/o/{blob_obj.name}?alt=media"

    async def _perform_download(
        self,
        transport,
        file_obj,
        download_url,
        headers,
        start=None,
        end=None,
        timeout=None,
        checksum="md5",
        retry=_helpers.DEFAULT_ASYNC_RETRY,
        sequential_read=False,
    ):
        download_obj = download.DownloadAsync(
            download_url,
            stream=file_obj,
            headers=headers,
            start=start,
            end=end,
            checksum=checksum,
            retry=retry,
            sequential_read=sequential_read,
        )
        await download_obj.consume(transport, timeout=timeout)

    def _check_if_sliced_download_is_eligible(self, obj_size, checksum):
        if obj_size < 1024 * 1024 * 1024:  # Only objects of at least 1 GiB are sliced.
            return False
        # Checksum validation is not yet supported for parallel downloads, so
        # slicing is only eligible when checksumming is disabled.
        return checksum is None

    async def download_to_file(
        self,
        blob_obj,
        filename,
        start=None,
        end=None,
        timeout=None,
        checksum="md5",
        retry=_helpers.DEFAULT_ASYNC_RETRY,
        sequential_read=False,
    ):
        download_url = self._get_download_url(blob_obj)
        headers = blob._get_encryption_headers(blob_obj._encryption_key)
        headers["accept-encoding"] = "gzip"
        headers = {
            **blob._get_default_headers("testing"),
            **headers,
        }

        transport = self._async_http
        if not blob_obj.size:
            blob_obj.reload()
        obj_size = blob_obj.size
        try:
            if not sequential_read and self._check_if_sliced_download_is_eligible(obj_size, checksum):
                print("Sliced Download Preferred, and Starting...")
                _parts = 5
                chunks_offset = [0] + [obj_size // _parts] * (_parts - 1) + [obj_size - obj_size // _parts * (_parts - 1)]
                for i in range(1, _parts + 1):
                    chunks_offset[i] += chunks_offset[i - 1]
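                # For example, with obj_size = 10 GiB and _parts = 5, chunks_offset
                # starts as [0, 2 GiB, 2 GiB, 2 GiB, 2 GiB, 2 GiB]; the prefix-sum
                # loop turns it into [0, 2 GiB, 4 GiB, 6 GiB, 8 GiB, 10 GiB], so
                # slice idx covers bytes chunks_offset[idx]..chunks_offset[idx + 1] - 1
                # and the final slice absorbs any remainder when obj_size is not
                # divisible by _parts.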

                with open(filename, "wb") as _:
                    pass  # Truncates the file to zero length and keeps it.

                tasks, file_handles = [], []
                try:
                    for idx in range(_parts):
                        file_handle = open(filename, "r+b")
                        file_handle.seek(chunks_offset[idx])
                        tasks.append(
                            self._perform_download(
                                transport,
                                file_handle,
                                download_url,
                                headers,
                                chunks_offset[idx],
                                chunks_offset[idx + 1] - 1,
                                timeout=timeout,
                                checksum=checksum,
                                retry=retry,
                                sequential_read=sequential_read,
                            )
                        )
                        file_handles.append(file_handle)
                    await asyncio.gather(*tasks)
                finally:
                    for file_handle in file_handles:
                        file_handle.close()
            else:
                print("Sequential Download Preferred, and Starting...")
                with open(filename, "wb") as file_obj:
                    await self._perform_download(
                        transport,
                        file_obj,
                        download_url,
                        headers,
                        start,
                        end,
                        timeout=timeout,
                        checksum=checksum,
                        retry=retry,
                        sequential_read=sequential_read,
                    )
        except (blob.DataCorruption, blob.NotFound):
            os.remove(filename)
            raise
        except blob.InvalidResponse as exc:
            blob._raise_from_invalid_response(exc)
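
A minimal sketch of driving this client end to end, not part of the commit: it assumes application-default credentials, and the import path, bucket, object, and output filename are hypothetical. The blob object still comes from the existing synchronous client, since the async path only handles the media download:

import asyncio

from google.cloud import storage
# Hypothetical import path for the AsyncClient module.
from google.cloud.storage._experimental.asyncio.json.client import AsyncClient


async def main():
    sync_client = storage.Client()
    blob_obj = sync_client.bucket("my-bucket").blob("large-object.bin")

    async with AsyncClient() as client:
        # checksum=None makes objects of at least 1 GiB eligible for the sliced
        # (5-way parallel) path; the default checksum="md5" always takes the
        # sequential path.
        await client.download_to_file(blob_obj, "/tmp/large-object.bin", checksum=None)


asyncio.run(main())
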
Lines changed: 88 additions & 0 deletions (the package's download module, imported above)

"""Async based download code"""

import http.client
import aiohttp

from google.cloud.storage._experimental.asyncio.json._helpers import DEFAULT_ASYNC_RETRY
from google.cloud.storage._media.requests import _request_helpers
from google.cloud.storage._media import _download
from google.cloud.storage._media import _helpers
from google.cloud.storage._media.requests import download as storage_download


class DownloadAsync(_request_helpers.RequestsMixin, _download.Download):

    def __init__(
        self,
        media_url,
        stream=None,
        start=None,
        end=None,
        headers=None,
        checksum="md5",
        retry=DEFAULT_ASYNC_RETRY,
        sequential_read=False,
    ):
        super(DownloadAsync, self).__init__(
            media_url, stream=stream, start=start, end=end, headers=headers, checksum=checksum, retry=retry
        )
        self.sequential_read = sequential_read

    async def _write_to_stream(self, response):
        if self._expected_checksum or self._checksum_object:
            # Preserve them across retried calls.
            expected_checksum = self._expected_checksum
            checksum_object = self._checksum_object
        elif not self.sequential_read:
            # The expected checksum has not been set yet and this is not a
            # sequential download, so the API will not return a hash for each
            # chunk. Ideally we would compute a crc32c checksum per chunk and
            # combine them at the end; the prototype does not implement that.
            expected_checksum = None
            checksum_object = _helpers._DoNothingHash()
            self._expected_checksum = expected_checksum
            self._checksum_object = checksum_object
        else:
            # Sequential read, so fetch the hash from the headers.
            expected_checksum, checksum_object = _helpers._get_expected_checksum(
                response, self._get_headers, self.media_url, checksum_type=self.checksum
            )
            self._expected_checksum = expected_checksum
            self._checksum_object = checksum_object

        async with response:
            chunk_size = 4096 * 32
            async for chunk in response.content.iter_chunked(chunk_size):
                self._stream.write(chunk)  # For some reason, aiofiles shows worse performance.
                self._bytes_downloaded += len(chunk)
                checksum_object.update(chunk)

        if (
            expected_checksum is not None
            and response.status != http.client.PARTIAL_CONTENT
        ):
            actual_checksum = _helpers.prepare_checksum_digest(checksum_object.digest())

            if actual_checksum != expected_checksum:
                raise storage_download.DataCorruption("Corrupted download!")

    async def consume(
        self,
        transport,
        timeout=aiohttp.ClientTimeout(total=None, sock_read=300),
    ):
        method, _, payload, headers = self._prepare_request()
        request_kwargs = {
            "data": payload,
            "headers": headers,
            "timeout": timeout,
        }

        async def retriable_request():
            url = self.media_url
            result = await transport.request(method, url, **request_kwargs)
            await self._write_to_stream(result)
            if result.status != 200:
                result.raise_for_status()
            return result

        return await _request_helpers.wait_and_retry(retriable_request, self._retry_strategy)
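
DownloadAsync can also be exercised on its own. A minimal sketch under the same assumptions (application-default credentials; hypothetical bucket and object names), reading a single byte range into memory much as the sliced path above reads each slice:

import asyncio
import io

from google.auth import _default_async
from google.auth.transport import _aiohttp_requests as async_requests


async def fetch_range(media_url, start, end):
    creds, _ = _default_async.default_async(
        scopes=["https://www.googleapis.com/auth/devstorage.read_only"]
    )
    session = async_requests.AuthorizedSession(creds)
    try:
        buffer = io.BytesIO()
        # checksum=None, since a ranged (206) response carries no usable object hash.
        download_obj = DownloadAsync(media_url, stream=buffer, start=start, end=end, checksum=None)
        await download_obj.consume(session)
        return buffer.getvalue()
    finally:
        await session.close()


url = "https://storage.googleapis.com/download/storage/v1/b/my-bucket/o/my-object?alt=media"
first_mib = asyncio.run(fetch_range(url, 0, 1024 * 1024 - 1))
print(len(first_mib))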
