Skip to content

Commit 471cda1

Browse files
[Storage] Add Blob download perf test that uses HTTP library directly (#44111)
1 parent cff4c23 commit 471cda1

File tree

3 files changed

+82
-4
lines changed

3 files changed

+82
-4
lines changed

sdk/storage/azure-storage-blob/tests/perfstress_tests/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ The tests currently written for the T2 SDK:
7777
- `UploadFromFileTest` Uploads a local file of `size` bytes to a new Blob.
7878
- `UploadBlockTest` Upload a single block of `size` bytes within a Blob.
7979
- `DownloadTest` Download a stream of `size` bytes.
80+
- `DownloadToFileTest` Downloads a blob of `size` bytes to a local file.
81+
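- `DownloadBasicTest` Downloads a blob of `size` bytes in ranged chunks using the `requests` (sync) and `aiohttp` (async) HTTP libraries directly, discarding the content. Requires Entra ID authentication.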
8082
- `ListBlobsTest` List a specified number of blobs.
8183

8284
### T1 Tests

sdk/storage/azure-storage-blob/tests/perfstress_tests/_test_base.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
class _ServiceTest(PerfStressTest):
1919
service_client = None
2020
async_service_client = None
21+
sync_token_credential = None
22+
async_token_credential = None
2123

2224
def __init__(self, arguments):
2325
super().__init__(arguments)
@@ -42,13 +44,13 @@ def __init__(self, arguments):
4244
use_managed_identity = os.environ.get("AZURE_STORAGE_USE_MANAGED_IDENTITY", "false").lower() == "true"
4345
if self.args.use_entra_id or use_managed_identity:
4446
account_name = self.get_from_env("AZURE_STORAGE_ACCOUNT_NAME")
45-
sync_token_credential = SyncManagedIdentityCredential() if use_managed_identity else self.get_credential(is_async=False)
46-
async_token_credential = AsyncManagedIdentityCredential() if use_managed_identity else self.get_credential(is_async=True)
47+
self.sync_token_credential = SyncManagedIdentityCredential() if use_managed_identity else self.get_credential(is_async=False)
48+
self.async_token_credential = AsyncManagedIdentityCredential() if use_managed_identity else self.get_credential(is_async=True)
4749

4850
# We assume these tests will only be run on the Azure public cloud for now.
4951
url = f"https://{account_name}.blob.core.windows.net"
50-
_ServiceTest.service_client = SyncBlobServiceClient(account_url=url, credential=sync_token_credential, **self._client_kwargs)
51-
_ServiceTest.async_service_client = AsyncBlobServiceClient(account_url=url, credential=async_token_credential, **self._client_kwargs)
52+
_ServiceTest.service_client = SyncBlobServiceClient(account_url=url, credential=self.sync_token_credential, **self._client_kwargs)
53+
_ServiceTest.async_service_client = AsyncBlobServiceClient(account_url=url, credential=self.async_token_credential, **self._client_kwargs)
5254
else:
5355
connection_string = self.get_from_env("AZURE_STORAGE_CONNECTION_STRING")
5456
_ServiceTest.service_client = SyncBlobServiceClient.from_connection_string(conn_str=connection_string, **self._client_kwargs)
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# --------------------------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See License.txt in the project root for license information.
4+
# --------------------------------------------------------------------------------------------
5+
6+
import asyncio
7+
import aiohttp
8+
import requests
9+
from concurrent.futures import ThreadPoolExecutor
10+
11+
from devtools_testutils.perfstress_tests import RandomStream
12+
13+
from ._test_base import _BlobTest
14+
15+
16+
# Scope passed to the token credential's get_token() to build the Bearer
# Authorization header used for the raw HTTP download requests.
TOKEN_SCOPE = "https://storage.azure.com/.default"
17+
18+
class DownloadBasicTest(_BlobTest):
    """Perf test that downloads a blob in ranged chunks with raw HTTP libraries.

    The sync path issues ranged GET requests with ``requests`` from a thread
    pool; the async path does the same with ``aiohttp`` tasks gated by a
    semaphore. Both send a Bearer token obtained once in ``global_setup``,
    and both discard the downloaded content.
    """

    def __init__(self, arguments):
        super().__init__(arguments)

    async def global_setup(self):
        """Upload the test blob and acquire the Bearer token used by the runs.

        :raises NotImplementedError: If no async token credential is
            configured (the test only supports Entra ID authentication).
        """
        await super().global_setup()

        # Fail fast: there is no point uploading the blob if the run cannot
        # authenticate its raw HTTP requests afterwards.
        if not self.async_token_credential:
            raise NotImplementedError("DownloadBasicTest requires Entra ID authentication.")

        data = RandomStream(self.args.size)
        await self.async_blob_client.upload_blob(data, max_concurrency=10)

        # NOTE(review): chunk_size and auth_header are set only on the instance
        # that runs global_setup -- confirm the perf framework calls it on every
        # instance that later executes run_sync/run_async.
        self.chunk_size = self.args.max_block_size or 4 * 1024 * 1024
        token = await self.async_token_credential.get_token(TOKEN_SCOPE)
        self.auth_header = "Bearer " + token.token

    def run_sync(self):
        """Download every chunk of the blob using ``requests`` on a thread pool."""
        chunk_ranges = self._get_chunk_ranges()
        # The session is the OUTER context so it stays open until the executor
        # has shut down and every worker thread has finished using it.
        with requests.sessions.Session() as session:
            with ThreadPoolExecutor(self.args.max_concurrency) as executor:
                # Executor.map only re-raises a worker's exception when its
                # result is retrieved, so drain the iterator to surface failures.
                list(executor.map(
                    lambda r: self.download_chunk_requests(session, r[0], r[1]),
                    chunk_ranges,
                ))

    async def run_async(self):
        """Download every chunk of the blob using ``aiohttp`` tasks."""
        chunk_ranges = self._get_chunk_ranges()
        # Caps the number of in-flight requests at max_concurrency.
        semaphore = asyncio.Semaphore(self.args.max_concurrency)

        async with aiohttp.ClientSession() as session:
            tasks = [
                self.download_chunk_aiohttp(session, offset, end, semaphore)
                for offset, end in chunk_ranges
            ]
            await asyncio.gather(*tasks)

    def _get_chunk_ranges(self):
        """Return the inclusive ``(offset, end)`` byte ranges covering the blob.

        Each range spans at most ``self.chunk_size`` bytes; the final range is
        clipped to the last byte of the blob. An empty list is returned when
        ``self.args.size`` is 0.
        """
        size = self.args.size
        step = self.chunk_size
        return [
            (offset, min(offset + step - 1, size - 1))
            for offset in range(0, size, step)
        ]

    def download_chunk_requests(self, session: requests.sessions.Session, offset: int, end: int):
        """Download bytes ``offset``-``end`` (inclusive) of the blob via ``requests``.

        :param session: Shared ``requests`` session used to issue the GET.
        :param offset: First byte of the range.
        :param end: Last byte of the range (inclusive).
        :raises Exception: If the response status is neither 200 nor 206.
        """
        headers = {
            'x-ms-version': self.blob_client.api_version,
            'Range': f'bytes={offset}-{end}',
            'Authorization': self.auth_header,
        }
        response = session.get(self.blob_client.url, headers=headers)

        if response.status_code not in (200, 206):
            raise Exception(f"Download failed with status code {response.status_code}")

    async def download_chunk_aiohttp(self, session: aiohttp.ClientSession, offset: int, end: int, semaphore: asyncio.Semaphore):
        """Download bytes ``offset``-``end`` (inclusive) of the blob via ``aiohttp``.

        :param session: Shared ``aiohttp`` client session used to issue the GET.
        :param offset: First byte of the range.
        :param end: Last byte of the range (inclusive).
        :param semaphore: Limits how many downloads run concurrently.
        :raises Exception: If the response status is neither 200 nor 206.
        """
        async with semaphore:
            headers = {
                'x-ms-version': self.blob_client.api_version,
                'Range': f'bytes={offset}-{end}',
                'Authorization': self.auth_header,
            }
            async with session.get(self.blob_client.url, headers=headers) as response:
                if response.status not in (200, 206):
                    raise Exception(f"Download failed with status code {response.status}")
                # Read the body so the transfer is actually performed; the
                # content itself is discarded.
                await response.read()

0 commit comments

Comments
 (0)