Skip to content

Commit d1ab444

Browse files
[Storage] Support OAuth for import/export managed disks (Azure#22984)
1 parent 9ee41c6 commit d1ab444

12 files changed

+574
-42
lines changed

sdk/storage/azure-storage-blob/CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
# Release History
22

3+
## 12.10.0b4 (2022-02-24)
4+
5+
### Features Added
6+
- Updated clients to support both SAS and OAuth together.
7+
- Updated OAuth implementation to use the AAD scope returned in a Bearer challenge.
8+
9+
### Bugs Fixed
10+
- Addressed a few `mypy` typing hint errors.
11+
312
## 12.10.0b3 (2022-02-08)
413

514
This version and all future versions will require Python 3.6+. Python 2.7 is no longer supported.

sdk/storage/azure-storage-blob/azure/storage/blob/_shared/authentication.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# --------------------------------------------------------------------------
66

77
import logging
8+
import re
89
import sys
910

1011
try:
@@ -140,3 +141,38 @@ def on_request(self, request):
140141

141142
self._add_authorization_header(request, string_to_sign)
142143
#logger.debug("String_to_sign=%s", string_to_sign)
144+
145+
146+
class StorageHttpChallenge(object):
147+
def __init__(self, challenge):
148+
""" Parses an HTTP WWW-Authentication Bearer challenge from the Storage service. """
149+
if not challenge:
150+
raise ValueError("Challenge cannot be empty")
151+
152+
self._parameters = {}
153+
self.scheme, trimmed_challenge = challenge.strip().split(" ", 1)
154+
155+
# name=value pairs either comma or space separated with values possibly being
156+
# enclosed in quotes
157+
for item in re.split('[, ]', trimmed_challenge):
158+
comps = item.split("=")
159+
if len(comps) == 2:
160+
key = comps[0].strip(' "')
161+
value = comps[1].strip(' "')
162+
if key:
163+
self._parameters[key] = value
164+
165+
# Extract and verify required parameters
166+
self.authorization_uri = self._parameters.get('authorization_uri')
167+
if not self.authorization_uri:
168+
raise ValueError("Authorization Uri not found")
169+
170+
self.resource_id = self._parameters.get('resource_id')
171+
if not self.resource_id:
172+
raise ValueError("Resource id not found")
173+
174+
uri_path = urlparse(self.authorization_uri).path.lstrip("/")
175+
self.tenant_id = uri_path.split("/")[0]
176+
177+
def get_value(self, key):
178+
return self._parameters.get(key)

sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,30 +25,30 @@
2525
from azure.core.pipeline import Pipeline
2626
from azure.core.pipeline.transport import RequestsTransport, HttpTransport
2727
from azure.core.pipeline.policies import (
28-
RedirectPolicy,
28+
AzureSasCredentialPolicy,
2929
ContentDecodePolicy,
30-
BearerTokenCredentialPolicy,
31-
ProxyPolicy,
3230
DistributedTracingPolicy,
3331
HttpLoggingPolicy,
32+
RedirectPolicy,
33+
ProxyPolicy,
3434
UserAgentPolicy,
35-
AzureSasCredentialPolicy
3635
)
3736

38-
from .constants import STORAGE_OAUTH_SCOPE, SERVICE_HOST_BASE, CONNECTION_TIMEOUT, READ_TIMEOUT
37+
from .constants import CONNECTION_TIMEOUT, READ_TIMEOUT, SERVICE_HOST_BASE
3938
from .models import LocationMode
4039
from .authentication import SharedKeyCredentialPolicy
4140
from .shared_access_signature import QueryStringConstants
4241
from .request_handlers import serialize_batch_body, _get_batch_request_delimiter
4342
from .policies import (
44-
StorageHeadersPolicy,
43+
ExponentialRetry,
44+
StorageBearerTokenCredentialPolicy,
4545
StorageContentValidation,
46+
StorageHeadersPolicy,
47+
StorageHosts,
48+
StorageLoggingPolicy,
4649
StorageRequestHook,
4750
StorageResponseHook,
48-
StorageLoggingPolicy,
49-
StorageHosts,
5051
QueueMessagePolicy,
51-
ExponentialRetry,
5252
)
5353
from .._version import VERSION
5454
from .response_handlers import process_storage_error, PartialBatchErrorException
@@ -208,18 +208,18 @@ def _format_query_string(self, sas_token, credential, snapshot=None, share_snaps
208208
if sas_token and isinstance(credential, AzureSasCredential):
209209
raise ValueError(
210210
"You cannot use AzureSasCredential when the resource URI also contains a Shared Access Signature.")
211-
if sas_token and not credential:
212-
query_str += sas_token
213-
elif is_credential_sastoken(credential):
211+
if is_credential_sastoken(credential):
214212
query_str += credential.lstrip("?")
215213
credential = None
214+
elif sas_token:
215+
query_str += sas_token
216216
return query_str.rstrip("?&"), credential
217217

218218
def _create_pipeline(self, credential, **kwargs):
219219
# type: (Any, **Any) -> Tuple[Configuration, Pipeline]
220220
self._credential_policy = None
221221
if hasattr(credential, "get_token"):
222-
self._credential_policy = BearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE)
222+
self._credential_policy = StorageBearerTokenCredentialPolicy(credential)
223223
elif isinstance(credential, SharedKeyCredentialPolicy):
224224
self._credential_policy = credential
225225
elif isinstance(credential, AzureSasCredential):

sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,25 @@
1515
from azure.core.async_paging import AsyncList
1616
from azure.core.exceptions import HttpResponseError
1717
from azure.core.pipeline.policies import (
18-
ContentDecodePolicy,
19-
AsyncBearerTokenCredentialPolicy,
2018
AsyncRedirectPolicy,
19+
AzureSasCredentialPolicy,
20+
ContentDecodePolicy,
2121
DistributedTracingPolicy,
2222
HttpLoggingPolicy,
23-
AzureSasCredentialPolicy,
2423
)
2524
from azure.core.pipeline.transport import AsyncHttpTransport
2625

27-
from .constants import STORAGE_OAUTH_SCOPE, CONNECTION_TIMEOUT, READ_TIMEOUT
26+
from .constants import CONNECTION_TIMEOUT, READ_TIMEOUT
2827
from .authentication import SharedKeyCredentialPolicy
2928
from .base_client import create_configuration
3029
from .policies import (
3130
StorageContentValidation,
32-
StorageRequestHook,
33-
StorageHosts,
3431
StorageHeadersPolicy,
32+
StorageHosts,
33+
StorageRequestHook,
3534
QueueMessagePolicy
3635
)
37-
from .policies_async import AsyncStorageResponseHook
36+
from .policies_async import AsyncStorageBearerTokenCredentialPolicy, AsyncStorageResponseHook
3837

3938
from .response_handlers import process_storage_error, PartialBatchErrorException
4039

@@ -70,7 +69,7 @@ def _create_pipeline(self, credential, **kwargs):
7069
# type: (Any, **Any) -> Tuple[Configuration, Pipeline]
7170
self._credential_policy = None
7271
if hasattr(credential, 'get_token'):
73-
self._credential_policy = AsyncBearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE)
72+
self._credential_policy = AsyncStorageBearerTokenCredentialPolicy(credential)
7473
elif isinstance(credential, SharedKeyCredentialPolicy):
7574
self._credential_policy = credential
7675
elif isinstance(credential, AzureSasCredential):

sdk/storage/azure-storage-blob/azure/storage/blob/_shared/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
# 4000MB (max block size)/ 50KB/s (an arbitrarily chosen minimum upload speed)
2323
READ_TIMEOUT = 80000
2424

25+
DEFAULT_OAUTH_SCOPE = "/.default"
2526
STORAGE_OAUTH_SCOPE = "https://storage.azure.com/.default"
2627

2728
SERVICE_HOST_BASE = 'core.windows.net'

sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,17 @@
3030
)
3131

3232
from azure.core.pipeline.policies import (
33+
BearerTokenCredentialPolicy,
3334
HeadersPolicy,
34-
SansIOHTTPPolicy,
35-
NetworkTraceLoggingPolicy,
3635
HTTPPolicy,
37-
RequestHistory
36+
NetworkTraceLoggingPolicy,
37+
RequestHistory,
38+
SansIOHTTPPolicy,
3839
)
3940
from azure.core.exceptions import AzureError, ServiceRequestError, ServiceResponseError
4041

42+
from .authentication import StorageHttpChallenge
43+
from .constants import DEFAULT_OAUTH_SCOPE, STORAGE_OAUTH_SCOPE
4144
from .models import LocationMode
4245

4346
try:
@@ -46,6 +49,7 @@
4649
_unicode_type = str
4750

4851
if TYPE_CHECKING:
52+
from azure.core.credentials import TokenCredential
4953
from azure.core.pipeline import PipelineRequest, PipelineResponse
5054

5155

@@ -292,26 +296,36 @@ def __init__(self, **kwargs): # pylint: disable=unused-argument
292296

293297
def send(self, request):
294298
# type: (PipelineRequest) -> PipelineResponse
295-
data_stream_total = request.context.get('data_stream_total') or \
296-
request.context.options.pop('data_stream_total', None)
297-
download_stream_current = request.context.get('download_stream_current') or \
298-
request.context.options.pop('download_stream_current', None)
299-
upload_stream_current = request.context.get('upload_stream_current') or \
300-
request.context.options.pop('upload_stream_current', None)
299+
# Values could be 0
300+
data_stream_total = request.context.get('data_stream_total')
301+
if data_stream_total is None:
302+
data_stream_total = request.context.options.pop('data_stream_total', None)
303+
download_stream_current = request.context.get('download_stream_current')
304+
if download_stream_current is None:
305+
download_stream_current = request.context.options.pop('download_stream_current', None)
306+
upload_stream_current = request.context.get('upload_stream_current')
307+
if upload_stream_current is None:
308+
upload_stream_current = request.context.options.pop('upload_stream_current', None)
309+
301310
response_callback = request.context.get('response_callback') or \
302311
request.context.options.pop('raw_response_hook', self._response_callback)
303312

304313
response = self.next.send(request)
314+
305315
will_retry = is_retry(response, request.context.options.get('mode'))
306-
if not will_retry and download_stream_current is not None:
316+
# Auth error could come from Bearer challenge, in which case this request will be made again
317+
is_auth_error = response.http_response.status_code == 401
318+
should_update_counts = not (will_retry or is_auth_error)
319+
320+
if should_update_counts and download_stream_current is not None:
307321
download_stream_current += int(response.http_response.headers.get('Content-Length', 0))
308322
if data_stream_total is None:
309323
content_range = response.http_response.headers.get('Content-Range')
310324
if content_range:
311325
data_stream_total = int(content_range.split(' ', 1)[1].split('/', 1)[1])
312326
else:
313327
data_stream_total = download_stream_current
314-
elif not will_retry and upload_stream_current is not None:
328+
elif should_update_counts and upload_stream_current is not None:
315329
upload_stream_current += int(response.http_request.headers.get('Content-Length', 0))
316330
for pipeline_obj in [request, response]:
317331
pipeline_obj.context['data_stream_total'] = data_stream_total
@@ -620,3 +634,24 @@ def get_backoff_time(self, settings):
620634
if self.backoff > self.random_jitter_range else 0
621635
random_range_end = self.backoff + self.random_jitter_range
622636
return random_generator.uniform(random_range_start, random_range_end)
637+
638+
639+
class StorageBearerTokenCredentialPolicy(BearerTokenCredentialPolicy):
640+
""" Custom Bearer token credential policy for following Storage Bearer challenges """
641+
642+
def __init__(self, credential, **kwargs):
643+
# type: (TokenCredential, **Any) -> None
644+
super(StorageBearerTokenCredentialPolicy, self).__init__(credential, STORAGE_OAUTH_SCOPE, **kwargs)
645+
646+
def on_challenge(self, request, response):
647+
# type: (PipelineRequest, PipelineResponse) -> bool
648+
try:
649+
auth_header = response.http_response.headers.get("WWW-Authenticate")
650+
challenge = StorageHttpChallenge(auth_header)
651+
except ValueError:
652+
return False
653+
654+
scope = challenge.resource_id + DEFAULT_OAUTH_SCOPE
655+
self.authorize_request(request, scope, tenant_id=challenge.tenant_id)
656+
657+
return True

sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py

Lines changed: 42 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,15 @@
1010
import logging
1111
from typing import Any, TYPE_CHECKING
1212

13-
from azure.core.pipeline.policies import AsyncHTTPPolicy
13+
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy, AsyncHTTPPolicy
1414
from azure.core.exceptions import AzureError
1515

16+
from .authentication import StorageHttpChallenge
17+
from .constants import DEFAULT_OAUTH_SCOPE, STORAGE_OAUTH_SCOPE
1618
from .policies import is_retry, StorageRetryPolicy
1719

1820
if TYPE_CHECKING:
21+
from azure.core.credentials_async import AsyncTokenCredential
1922
from azure.core.pipeline import PipelineRequest, PipelineResponse
2023

2124

@@ -44,28 +47,37 @@ def __init__(self, **kwargs): # pylint: disable=unused-argument
4447

4548
async def send(self, request):
4649
# type: (PipelineRequest) -> PipelineResponse
47-
data_stream_total = request.context.get('data_stream_total') or \
48-
request.context.options.pop('data_stream_total', None)
49-
download_stream_current = request.context.get('download_stream_current') or \
50-
request.context.options.pop('download_stream_current', None)
51-
upload_stream_current = request.context.get('upload_stream_current') or \
52-
request.context.options.pop('upload_stream_current', None)
50+
# Values could be 0
51+
data_stream_total = request.context.get('data_stream_total')
52+
if data_stream_total is None:
53+
data_stream_total = request.context.options.pop('data_stream_total', None)
54+
download_stream_current = request.context.get('download_stream_current')
55+
if download_stream_current is None:
56+
download_stream_current = request.context.options.pop('download_stream_current', None)
57+
upload_stream_current = request.context.get('upload_stream_current')
58+
if upload_stream_current is None:
59+
upload_stream_current = request.context.options.pop('upload_stream_current', None)
60+
5361
response_callback = request.context.get('response_callback') or \
5462
request.context.options.pop('raw_response_hook', self._response_callback)
5563

5664
response = await self.next.send(request)
5765
await response.http_response.load_body()
5866

5967
will_retry = is_retry(response, request.context.options.get('mode'))
60-
if not will_retry and download_stream_current is not None:
68+
# Auth error could come from Bearer challenge, in which case this request will be made again
69+
is_auth_error = response.http_response.status_code == 401
70+
should_update_counts = not (will_retry or is_auth_error)
71+
72+
if should_update_counts and download_stream_current is not None:
6173
download_stream_current += int(response.http_response.headers.get('Content-Length', 0))
6274
if data_stream_total is None:
6375
content_range = response.http_response.headers.get('Content-Range')
6476
if content_range:
6577
data_stream_total = int(content_range.split(' ', 1)[1].split('/', 1)[1])
6678
else:
6779
data_stream_total = download_stream_current
68-
elif not will_retry and upload_stream_current is not None:
80+
elif should_update_counts and upload_stream_current is not None:
6981
upload_stream_current += int(response.http_request.headers.get('Content-Length', 0))
7082
for pipeline_obj in [request, response]:
7183
pipeline_obj.context['data_stream_total'] = data_stream_total
@@ -218,3 +230,24 @@ def get_backoff_time(self, settings):
218230
if self.backoff > self.random_jitter_range else 0
219231
random_range_end = self.backoff + self.random_jitter_range
220232
return random_generator.uniform(random_range_start, random_range_end)
233+
234+
235+
class AsyncStorageBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy):
236+
""" Custom Bearer token credential policy for following Storage Bearer challenges """
237+
238+
def __init__(self, credential, **kwargs):
239+
# type: (AsyncTokenCredential, **Any) -> None
240+
super(AsyncStorageBearerTokenCredentialPolicy, self).__init__(credential, STORAGE_OAUTH_SCOPE, **kwargs)
241+
242+
async def on_challenge(self, request, response):
243+
# type: (PipelineRequest, PipelineResponse) -> bool
244+
try:
245+
auth_header = response.http_response.headers.get("WWW-Authenticate")
246+
challenge = StorageHttpChallenge(auth_header)
247+
except ValueError:
248+
return False
249+
250+
scope = challenge.resource_id + DEFAULT_OAUTH_SCOPE
251+
await self.authorize_request(request, scope, tenant_id=challenge.tenant_id)
252+
253+
return True

sdk/storage/azure-storage-blob/azure/storage/blob/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
# license information.
55
# --------------------------------------------------------------------------
66

7-
VERSION = "12.10.0b3"
7+
VERSION = "12.10.0b4"

0 commit comments

Comments
 (0)