Skip to content

Commit 9c28b76

Browse files
authored
[KeyVault] Keyvault Keys to Test Proxy (Azure#24165)
* move conftest into the tests folder * test proxy changes * new recordings * more recordings for crud * sync test recordings * move over to test proxy * kv async recordings * simple clean ups * recordings * clean up imports * pick right vault name * clean up * fix test parse id offline test * override pytest default event loop * fix for async tests, change to aiohttp request * remove commented code * formatting fixes * Delete vcrpy recordings * with block for async client * clean up * code clean ups * move keys specific methods in to a separate class * PR comments * refactor test to use preparer
1 parent b8bcbd5 commit 9c28b76

File tree

944 files changed

+417642
-320902
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

944 files changed

+417642
-320902
lines changed

sdk/keyvault/azure-keyvault-keys/conftest.py

Lines changed: 0 additions & 8 deletions
This file was deleted.
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
import json
6+
import os
7+
8+
import pytest
9+
from azure.core.pipeline import AsyncPipeline
10+
from azure.core.pipeline.transport import AioHttpTransport, HttpRequest
11+
from azure.keyvault.keys import KeyReleasePolicy
12+
from azure.keyvault.keys._shared.client_base import DEFAULT_VERSION, ApiVersion
13+
from devtools_testutils import AzureRecordedTestCase
14+
15+
16+
async def get_attestation_token(attestation_uri):
17+
request = HttpRequest("GET", "{}/generate-test-token".format(attestation_uri))
18+
async with AsyncPipeline(transport=AioHttpTransport()) as pipeline:
19+
response = await pipeline.run(request)
20+
return json.loads(response.http_response.text())["token"]
21+
22+
23+
def get_decorator(only_hsm=False, only_vault=False, api_versions=None, **kwargs):
24+
"""returns a test decorator for test parameterization"""
25+
params = [
26+
pytest.param(p[0],p[1], id=p[0] + ("_mhsm" if p[1] else "_vault" ))
27+
for p in get_test_parameters(only_hsm, only_vault, api_versions=api_versions)
28+
]
29+
return params
30+
31+
32+
def get_release_policy(attestation_uri, **kwargs):
33+
release_policy_json = {
34+
"anyOf": [
35+
{
36+
"anyOf": [
37+
{
38+
"claim": "sdk-test",
39+
"equals": True
40+
}
41+
],
42+
"authority": attestation_uri.rstrip("/") + "/"
43+
}
44+
],
45+
"version": "1.0.0"
46+
}
47+
policy_string = json.dumps(release_policy_json).encode()
48+
return KeyReleasePolicy(policy_string, **kwargs)
49+
50+
51+
def get_test_parameters(only_hsm=False, only_vault=False, api_versions=None):
52+
"""generates a list of parameter pairs for test case parameterization, where [x, y] = [api_version, is_hsm]"""
53+
combinations = []
54+
versions = api_versions or ApiVersion
55+
hsm_supported_versions = {ApiVersion.V7_2, ApiVersion.V7_3}
56+
57+
for api_version in versions:
58+
if not only_vault and api_version in hsm_supported_versions:
59+
combinations.append([api_version, True])
60+
if not only_hsm:
61+
combinations.append([api_version, False])
62+
return combinations
63+
64+
65+
def is_public_cloud():
66+
return (".microsoftonline.com" in os.getenv('AZURE_AUTHORITY_HOST', ''))
67+
68+
69+
class AsyncKeysClientPreparer(AzureRecordedTestCase):
70+
def __init__(self, *args, **kwargs):
71+
vault_playback_url = "https://vaultname.vault.azure.net"
72+
hsm_playback_url = "https://managedhsmvaultname.vault.azure.net"
73+
self.is_logging_enabled = kwargs.pop("logging_enable", True)
74+
75+
if self.is_live:
76+
self.vault_url = os.environ["AZURE_KEYVAULT_URL"]
77+
self.managed_hsm_url = os.environ.get("AZURE_MANAGEDHSM_URL")
78+
else:
79+
self.vault_url = vault_playback_url
80+
self.managed_hsm_url = hsm_playback_url
81+
82+
self._set_mgmt_settings_real_values()
83+
84+
def __call__(self, fn):
85+
async def _preparer(test_class, api_version, is_hsm, **kwargs):
86+
87+
self._skip_if_not_configured(api_version, is_hsm)
88+
if not self.is_logging_enabled:
89+
kwargs.update({"logging_enable": False})
90+
endpoint_url = self.managed_hsm_url if is_hsm else self.vault_url
91+
client = self.create_key_client(endpoint_url, api_version=api_version, **kwargs)
92+
async with client:
93+
await fn(test_class, client, is_hsm=is_hsm, managed_hsm_url = self.managed_hsm_url, vault_url = self.vault_url)
94+
95+
return _preparer
96+
97+
98+
99+
def create_key_client(self, vault_uri, **kwargs):
100+
101+
from azure.keyvault.keys.aio import KeyClient
102+
103+
credential = self.get_credential(KeyClient, is_async=True)
104+
105+
return self.create_client_from_credential(KeyClient, credential=credential, vault_url=vault_uri, **kwargs)
106+
107+
def _set_mgmt_settings_real_values(self):
108+
if self.is_live:
109+
os.environ["AZURE_TENANT_ID"] = os.environ["KEYVAULT_TENANT_ID"]
110+
os.environ["AZURE_CLIENT_ID"] = os.environ["KEYVAULT_CLIENT_ID"]
111+
os.environ["AZURE_CLIENT_SECRET"] = os.environ["KEYVAULT_CLIENT_SECRET"]
112+
113+
def _skip_if_not_configured(self, api_version, is_hsm):
114+
if self.is_live and api_version != DEFAULT_VERSION:
115+
pytest.skip("This test only uses the default API version for live tests")
116+
if self.is_live and is_hsm and self.managed_hsm_url is None:
117+
pytest.skip("No HSM endpoint for live testing")
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import os
2+
3+
import pytest
4+
from devtools_testutils import AzureRecordedTestCase
5+
6+
7+
class KeysTestCase(AzureRecordedTestCase):
8+
def _get_attestation_uri(self):
9+
playback_uri = "https://fakeattestation.azurewebsites.net"
10+
if self.is_live:
11+
real_uri = os.environ.get("AZURE_KEYVAULT_ATTESTATION_URL")
12+
real_uri = real_uri.rstrip('/')
13+
if real_uri is None:
14+
pytest.skip("No AZURE_KEYVAULT_ATTESTATION_URL environment variable")
15+
return real_uri
16+
return playback_uri
17+
18+
def create_crypto_client(self, key, **kwargs):
19+
if kwargs.pop("is_async", False):
20+
from azure.keyvault.keys.crypto.aio import CryptographyClient
21+
credential = self.get_credential(CryptographyClient,is_async=True)
22+
else:
23+
from azure.keyvault.keys.crypto import CryptographyClient
24+
credential = self.get_credential(CryptographyClient)
25+
26+
return self.create_client_from_credential(CryptographyClient, credential=credential, key=key, **kwargs)

sdk/keyvault/azure-keyvault-keys/tests/_shared/test_case.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,13 @@
44
# ------------------------------------
55
import time
66

7-
from azure_devtools.scenario_tests.patches import patch_time_sleep_api
8-
from devtools_testutils import AzureTestCase
7+
from azure.keyvault.keys._shared import HttpChallengeCache
8+
from devtools_testutils import AzureRecordedTestCase
99

1010

11-
class KeyVaultTestCase(AzureTestCase):
12-
def __init__(self, *args, **kwargs):
13-
if "match_body" not in kwargs:
14-
kwargs["match_body"] = True
1511

16-
super(KeyVaultTestCase, self).__init__(*args, **kwargs)
17-
self.replay_patches.append(patch_time_sleep_api)
18-
19-
def setUp(self):
20-
self.list_test_size = 7
21-
super(KeyVaultTestCase, self).setUp()
2212

13+
class KeyVaultTestCase(AzureRecordedTestCase):
2314
def get_resource_name(self, name):
2415
"""helper to create resources with a consistent, test-indicative prefix"""
2516
return super(KeyVaultTestCase, self).get_resource_name("livekvtest{}".format(name))
@@ -48,3 +39,7 @@ def _poll_until_exception(self, fn, expected_exception, max_retries=20, retry_de
4839
return
4940

5041
self.fail("expected exception {expected_exception} was not raised")
42+
43+
def teardown_method(self, method):
44+
HttpChallengeCache.clear()
45+
assert len(HttpChallengeCache._cache) == 0

sdk/keyvault/azure-keyvault-keys/tests/_shared/test_case_async.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,11 @@
44
# ------------------------------------
55
import asyncio
66

7-
from azure_devtools.scenario_tests.patches import mock_in_unit_test
8-
from devtools_testutils import AzureTestCase
7+
from devtools_testutils import AzureRecordedTestCase
8+
from azure.keyvault.keys._shared import HttpChallengeCache
99

1010

11-
def skip_sleep(unit_test):
12-
async def immediate_return(_):
13-
return
14-
15-
return mock_in_unit_test(unit_test, "asyncio.sleep", immediate_return)
16-
17-
18-
class KeyVaultTestCase(AzureTestCase):
19-
def __init__(self, *args, match_body=True, **kwargs):
20-
super().__init__(*args, match_body=match_body, **kwargs)
21-
self.replay_patches.append(skip_sleep)
22-
23-
def setUp(self):
24-
self.list_test_size = 7
25-
super(KeyVaultTestCase, self).setUp()
26-
11+
class KeyVaultTestCase(AzureRecordedTestCase):
2712
def get_resource_name(self, name):
2813
"""helper to create resources with a consistent, test-indicative prefix"""
2914
return super(KeyVaultTestCase, self).get_resource_name("livekvtest{}".format(name))
@@ -51,3 +36,7 @@ async def _poll_until_exception(self, fn, expected_exception, max_retries=20, re
5136
except expected_exception:
5237
return
5338
self.fail("expected exception {expected_exception} was not raised")
39+
40+
def teardown_method(self, method):
41+
HttpChallengeCache.clear()
42+
assert len(HttpChallengeCache._cache) == 0

sdk/keyvault/azure-keyvault-keys/tests/_test_case.py

Lines changed: 29 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -2,40 +2,15 @@
22
# Copyright (c) Microsoft Corporation.
33
# Licensed under the MIT License.
44
# ------------------------------------
5-
import functools
65
import json
76
import os
87

8+
import pytest
99
from azure.core.pipeline import Pipeline
1010
from azure.core.pipeline.transport import HttpRequest, RequestsTransport
1111
from azure.keyvault.keys import KeyReleasePolicy
12-
from azure.keyvault.keys._shared import HttpChallengeCache
13-
from azure.keyvault.keys._shared.client_base import ApiVersion, DEFAULT_VERSION
14-
from devtools_testutils import AzureTestCase
15-
from parameterized import parameterized, param
16-
import pytest
17-
from six.moves.urllib_parse import urlparse
18-
19-
20-
def client_setup(testcase_func):
21-
"""decorator that creates a client to be passed in to a test method"""
22-
23-
@functools.wraps(testcase_func)
24-
def wrapper(test_class_instance, api_version, is_hsm=False, **kwargs):
25-
test_class_instance._skip_if_not_configured(api_version, is_hsm)
26-
endpoint_url = test_class_instance.managed_hsm_url if is_hsm else test_class_instance.vault_url
27-
client = test_class_instance.create_key_client(endpoint_url, api_version=api_version, **kwargs)
28-
29-
if kwargs.get("is_async"):
30-
import asyncio
31-
32-
coroutine = testcase_func(test_class_instance, client, is_hsm=is_hsm)
33-
loop = asyncio.get_event_loop()
34-
loop.run_until_complete(coroutine)
35-
else:
36-
testcase_func(test_class_instance, client, is_hsm=is_hsm)
37-
38-
return wrapper
12+
from azure.keyvault.keys._shared.client_base import DEFAULT_VERSION, ApiVersion
13+
from devtools_testutils import AzureRecordedTestCase
3914

4015

4116
def get_attestation_token(attestation_uri):
@@ -48,10 +23,10 @@ def get_attestation_token(attestation_uri):
4823
def get_decorator(only_hsm=False, only_vault=False, api_versions=None, **kwargs):
4924
"""returns a test decorator for test parameterization"""
5025
params = [
51-
param(api_version=p[0], is_hsm=p[1], **kwargs)
26+
pytest.param(p[0],p[1], id=p[0] + ("_mhsm" if p[1] else "_vault" ))
5227
for p in get_test_parameters(only_hsm, only_vault, api_versions=api_versions)
5328
]
54-
return functools.partial(parameterized.expand, params, name_func=suffixed_test_name)
29+
return params
5530

5631

5732
def get_release_policy(attestation_uri, **kwargs):
@@ -87,78 +62,51 @@ def get_test_parameters(only_hsm=False, only_vault=False, api_versions=None):
8762
return combinations
8863

8964

90-
def suffixed_test_name(testcase_func, param_num, param):
91-
api_version = param.kwargs.get("api_version")
92-
suffix = "mhsm" if param.kwargs.get("is_hsm") else "vault"
93-
return "{}_{}_{}".format(
94-
testcase_func.__name__, parameterized.to_safe_name(api_version), parameterized.to_safe_name(suffix)
95-
)
96-
97-
9865
def is_public_cloud():
9966
return (".microsoftonline.com" in os.getenv('AZURE_AUTHORITY_HOST', ''))
10067

10168

102-
class KeysTestCase(AzureTestCase):
103-
def setUp(self, *args, **kwargs):
69+
class KeysClientPreparer(AzureRecordedTestCase):
70+
def __init__(self, *args, **kwargs):
10471
vault_playback_url = "https://vaultname.vault.azure.net"
105-
hsm_playback_url = "https://managedhsmname.managedhsm.azure.net"
72+
hsm_playback_url = "https://managedhsmvaultname.vault.azure.net"
73+
self.is_logging_enabled = kwargs.pop("logging_enable", True)
10674

10775
if self.is_live:
10876
self.vault_url = os.environ["AZURE_KEYVAULT_URL"]
109-
self._scrub_url(real_url=self.vault_url, playback_url=vault_playback_url)
110-
111-
self.managed_hsm_url = os.environ.get("AZURE_MANAGEDHSM_URL")
77+
self.vault_url = self.vault_url.rstrip("/")
78+
self.managed_hsm_url = os.environ.get("AZURE_MANAGEDHSM_URL", None)
11279
if self.managed_hsm_url:
113-
self._scrub_url(real_url=self.managed_hsm_url, playback_url=hsm_playback_url)
80+
self.managed_hsm_url = self.managed_hsm_url.rstrip("/")
11481
else:
11582
self.vault_url = vault_playback_url
11683
self.managed_hsm_url = hsm_playback_url
11784

11885
self._set_mgmt_settings_real_values()
119-
super(KeysTestCase, self).setUp(*args, **kwargs)
12086

121-
def tearDown(self):
122-
HttpChallengeCache.clear()
123-
assert len(HttpChallengeCache._cache) == 0
124-
super(KeysTestCase, self).tearDown()
87+
def __call__(self, fn):
88+
def _preparer(test_class, api_version, is_hsm, **kwargs):
12589

126-
def create_key_client(self, vault_uri, **kwargs):
127-
if kwargs.pop("is_async", False):
128-
from azure.keyvault.keys.aio import KeyClient
129-
130-
credential = self.get_credential(KeyClient, is_async=True)
131-
else:
132-
from azure.keyvault.keys import KeyClient
90+
#self._skip_if_not_configured(api_version, is_hsm)
91+
if not self.is_logging_enabled:
92+
kwargs.update({"logging_enable": False})
93+
endpoint_url = self.managed_hsm_url if is_hsm else self.vault_url
94+
client = self.create_key_client(endpoint_url, api_version=api_version, **kwargs)
13395

134-
credential = self.get_credential(KeyClient)
135-
return self.create_client_from_credential(KeyClient, credential=credential, vault_url=vault_uri, **kwargs)
96+
with client:
97+
fn(test_class, client, is_hsm=is_hsm, managed_hsm_url = self.managed_hsm_url, vault_url = self.vault_url)
98+
return _preparer
99+
136100

137-
def create_crypto_client(self, key, **kwargs):
138-
if kwargs.pop("is_async", False):
139-
from azure.keyvault.keys.crypto.aio import CryptographyClient
140101

141-
credential = self.get_credential(CryptographyClient, is_async=True)
142-
else:
143-
from azure.keyvault.keys.crypto import CryptographyClient
102+
def create_key_client(self, vault_uri, **kwargs):
103+
104+
from azure.keyvault.keys import KeyClient
144105

145-
credential = self.get_credential(CryptographyClient)
146-
return self.create_client_from_credential(CryptographyClient, credential=credential, key=key, **kwargs)
106+
credential = self.get_credential(KeyClient)
107+
108+
return self.create_client_from_credential(KeyClient, credential=credential, vault_url=vault_uri, **kwargs)
147109

148-
def _get_attestation_uri(self):
149-
playback_uri = "https://fakeattestation.azurewebsites.net"
150-
if self.is_live:
151-
real_uri = os.environ.get("AZURE_KEYVAULT_ATTESTATION_URL")
152-
if real_uri is None:
153-
pytest.skip("No AZURE_KEYVAULT_ATTESTATION_URL environment variable")
154-
self._scrub_url(real_uri, playback_uri)
155-
return real_uri
156-
return playback_uri
157-
158-
def _scrub_url(self, real_url, playback_url):
159-
real = urlparse(real_url)
160-
playback = urlparse(playback_url)
161-
self.scrubber.register_name_pair(real.netloc, playback.netloc)
162110

163111
def _set_mgmt_settings_real_values(self):
164112
if self.is_live:

0 commit comments

Comments
 (0)