2
2
# Copyright (c) Microsoft Corporation.
3
3
# Licensed under the MIT License.
4
4
# ------------------------------------
5
- import functools
6
5
import json
7
6
import os
8
7
8
+ import pytest
9
9
from azure .core .pipeline import Pipeline
10
10
from azure .core .pipeline .transport import HttpRequest , RequestsTransport
11
11
from azure .keyvault .keys import KeyReleasePolicy
12
- from azure .keyvault .keys ._shared import HttpChallengeCache
13
- from azure .keyvault .keys ._shared .client_base import ApiVersion , DEFAULT_VERSION
14
- from devtools_testutils import AzureTestCase
15
- from parameterized import parameterized , param
16
- import pytest
17
- from six .moves .urllib_parse import urlparse
18
-
19
-
20
- def client_setup (testcase_func ):
21
- """decorator that creates a client to be passed in to a test method"""
22
-
23
- @functools .wraps (testcase_func )
24
- def wrapper (test_class_instance , api_version , is_hsm = False , ** kwargs ):
25
- test_class_instance ._skip_if_not_configured (api_version , is_hsm )
26
- endpoint_url = test_class_instance .managed_hsm_url if is_hsm else test_class_instance .vault_url
27
- client = test_class_instance .create_key_client (endpoint_url , api_version = api_version , ** kwargs )
28
-
29
- if kwargs .get ("is_async" ):
30
- import asyncio
31
-
32
- coroutine = testcase_func (test_class_instance , client , is_hsm = is_hsm )
33
- loop = asyncio .get_event_loop ()
34
- loop .run_until_complete (coroutine )
35
- else :
36
- testcase_func (test_class_instance , client , is_hsm = is_hsm )
37
-
38
- return wrapper
12
+ from azure .keyvault .keys ._shared .client_base import DEFAULT_VERSION , ApiVersion
13
+ from devtools_testutils import AzureRecordedTestCase
39
14
40
15
41
16
def get_attestation_token (attestation_uri ):
@@ -48,10 +23,10 @@ def get_attestation_token(attestation_uri):
48
23
def get_decorator (only_hsm = False , only_vault = False , api_versions = None , ** kwargs ):
49
24
"""returns a test decorator for test parameterization"""
50
25
params = [
51
- param (api_version = p [0 ], is_hsm = p [1 ], ** kwargs )
26
+ pytest . param (p [0 ],p [ 1 ], id = p [0 ] + ( "_mhsm" if p [ 1 ] else "_vault" ) )
52
27
for p in get_test_parameters (only_hsm , only_vault , api_versions = api_versions )
53
28
]
54
- return functools . partial ( parameterized . expand , params , name_func = suffixed_test_name )
29
+ return params
55
30
56
31
57
32
def get_release_policy (attestation_uri , ** kwargs ):
@@ -87,78 +62,51 @@ def get_test_parameters(only_hsm=False, only_vault=False, api_versions=None):
87
62
return combinations
88
63
89
64
90
- def suffixed_test_name (testcase_func , param_num , param ):
91
- api_version = param .kwargs .get ("api_version" )
92
- suffix = "mhsm" if param .kwargs .get ("is_hsm" ) else "vault"
93
- return "{}_{}_{}" .format (
94
- testcase_func .__name__ , parameterized .to_safe_name (api_version ), parameterized .to_safe_name (suffix )
95
- )
96
-
97
-
98
65
def is_public_cloud ():
99
66
return (".microsoftonline.com" in os .getenv ('AZURE_AUTHORITY_HOST' , '' ))
100
67
101
68
102
- class KeysTestCase ( AzureTestCase ):
103
- def setUp (self , * args , ** kwargs ):
69
+ class KeysClientPreparer ( AzureRecordedTestCase ):
70
+ def __init__ (self , * args , ** kwargs ):
104
71
vault_playback_url = "https://vaultname.vault.azure.net"
105
- hsm_playback_url = "https://managedhsmname.managedhsm.azure.net"
72
+ hsm_playback_url = "https://managedhsmvaultname.vault.azure.net"
73
+ self .is_logging_enabled = kwargs .pop ("logging_enable" , True )
106
74
107
75
if self .is_live :
108
76
self .vault_url = os .environ ["AZURE_KEYVAULT_URL" ]
109
- self ._scrub_url (real_url = self .vault_url , playback_url = vault_playback_url )
110
-
111
- self .managed_hsm_url = os .environ .get ("AZURE_MANAGEDHSM_URL" )
77
+ self .vault_url = self .vault_url .rstrip ("/" )
78
+ self .managed_hsm_url = os .environ .get ("AZURE_MANAGEDHSM_URL" , None )
112
79
if self .managed_hsm_url :
113
- self ._scrub_url ( real_url = self .managed_hsm_url , playback_url = hsm_playback_url )
80
+ self .managed_hsm_url = self .managed_hsm_url . rstrip ( "/" )
114
81
else :
115
82
self .vault_url = vault_playback_url
116
83
self .managed_hsm_url = hsm_playback_url
117
84
118
85
self ._set_mgmt_settings_real_values ()
119
- super (KeysTestCase , self ).setUp (* args , ** kwargs )
120
86
121
- def tearDown (self ):
122
- HttpChallengeCache .clear ()
123
- assert len (HttpChallengeCache ._cache ) == 0
124
- super (KeysTestCase , self ).tearDown ()
87
+ def __call__ (self , fn ):
88
+ def _preparer (test_class , api_version , is_hsm , ** kwargs ):
125
89
126
- def create_key_client (self , vault_uri , ** kwargs ):
127
- if kwargs .pop ("is_async" , False ):
128
- from azure .keyvault .keys .aio import KeyClient
129
-
130
- credential = self .get_credential (KeyClient , is_async = True )
131
- else :
132
- from azure .keyvault .keys import KeyClient
90
+ #self._skip_if_not_configured(api_version, is_hsm)
91
+ if not self .is_logging_enabled :
92
+ kwargs .update ({"logging_enable" : False })
93
+ endpoint_url = self .managed_hsm_url if is_hsm else self .vault_url
94
+ client = self .create_key_client (endpoint_url , api_version = api_version , ** kwargs )
133
95
134
- credential = self .get_credential (KeyClient )
135
- return self .create_client_from_credential (KeyClient , credential = credential , vault_url = vault_uri , ** kwargs )
96
+ with client :
97
+ fn (test_class , client , is_hsm = is_hsm , managed_hsm_url = self .managed_hsm_url , vault_url = self .vault_url )
98
+ return _preparer
99
+
136
100
137
- def create_crypto_client (self , key , ** kwargs ):
138
- if kwargs .pop ("is_async" , False ):
139
- from azure .keyvault .keys .crypto .aio import CryptographyClient
140
101
141
- credential = self . get_credential ( CryptographyClient , is_async = True )
142
- else :
143
- from azure .keyvault .keys . crypto import CryptographyClient
102
+ def create_key_client ( self , vault_uri , ** kwargs ):
103
+
104
+ from azure .keyvault .keys import KeyClient
144
105
145
- credential = self .get_credential (CryptographyClient )
146
- return self .create_client_from_credential (CryptographyClient , credential = credential , key = key , ** kwargs )
106
+ credential = self .get_credential (KeyClient )
107
+
108
+ return self .create_client_from_credential (KeyClient , credential = credential , vault_url = vault_uri , ** kwargs )
147
109
148
- def _get_attestation_uri (self ):
149
- playback_uri = "https://fakeattestation.azurewebsites.net"
150
- if self .is_live :
151
- real_uri = os .environ .get ("AZURE_KEYVAULT_ATTESTATION_URL" )
152
- if real_uri is None :
153
- pytest .skip ("No AZURE_KEYVAULT_ATTESTATION_URL environment variable" )
154
- self ._scrub_url (real_uri , playback_uri )
155
- return real_uri
156
- return playback_uri
157
-
158
- def _scrub_url (self , real_url , playback_url ):
159
- real = urlparse (real_url )
160
- playback = urlparse (playback_url )
161
- self .scrubber .register_name_pair (real .netloc , playback .netloc )
162
110
163
111
def _set_mgmt_settings_real_values (self ):
164
112
if self .is_live :
0 commit comments