Skip to content

Commit ef19493

Browse files
authored
[Key Vault] Handle cryptography RSA keys without local key material (#34330)
1 parent 84012b6 commit ef19493

File tree

8 files changed

+156
-63
lines changed

8 files changed

+156
-63
lines changed

sdk/keyvault/azure-keyvault-keys/assets.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
"AssetsRepo": "Azure/azure-sdk-assets",
33
"AssetsRepoPrefixPath": "python",
44
"TagPrefix": "python/keyvault/azure-keyvault-keys",
5-
"Tag": "python/keyvault/azure-keyvault-keys_d3b8eaba1a"
5+
"Tag": "python/keyvault/azure-keyvault-keys_6a80b2f740"
66
}

sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/_models.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,10 @@ class KeyProperties(object):
6464
6565
:keyword bool managed: Whether the key's lifetime is managed by Key Vault.
6666
:keyword tags: Application specific metadata in the form of key-value pairs.
67-
:paramtype tags: dict[str, str]
67+
:paramtype tags: dict[str, str] or None
6868
:keyword release_policy: The azure.keyvault.keys.KeyReleasePolicy specifying the rules under which the key
6969
can be exported.
70-
:paramtype release_policy: ~azure.keyvault.keys.KeyReleasePolicy
70+
:paramtype release_policy: ~azure.keyvault.keys.KeyReleasePolicy or None
7171
"""
7272

7373
def __init__(self, key_id: str, attributes: "Optional[_models.KeyAttributes]" = None, **kwargs: Any) -> None:
@@ -288,6 +288,8 @@ def __init__(self, encoded_policy: bytes, **kwargs: Any) -> None:
288288
class ReleaseKeyResult(object):
289289
"""The result of a key release operation.
290290
291+
:ivar str value: A signed token containing the released key.
292+
291293
:param str value: A signed token containing the released key.
292294
"""
293295

sdk/keyvault/azure-keyvault-keys/azure/keyvault/keys/crypto/_models.py

Lines changed: 73 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -120,19 +120,14 @@ def get_signature_algorithm(padding: AsymmetricPadding, algorithm: HashAlgorithm
120120
class KeyVaultRSAPublicKey(RSAPublicKey):
121121
"""An `RSAPublicKey` implementation based on a key managed by Key Vault.
122122
123-
Only synchronous clients and operations are supported at this time.
123+
This class should not be instantiated directly. Instead, use the
124+
:func:`~azure.keyvault.keys.crypto.CryptographyClient.create_rsa_public_key` method to create a key based on the
125+
client's key. Only synchronous clients and operations are supported at this time.
124126
"""
125127

126-
def __init__(self, client: "CryptographyClient", key_material: JsonWebKey) -> None:
127-
"""Creates a `KeyVaultRSAPublicKey` from a `CryptographyClient` and key.
128-
129-
:param client: The client that will be used to communicate with Key Vault.
130-
:type client: ~azure.keyvault.keys.crypto.CryptographyClient
131-
:param key_material: They Key Vault key's material, as a `JsonWebKey`.
132-
:type key_material: ~azure.keyvault.keys.JsonWebKey
133-
"""
128+
def __init__(self, client: "CryptographyClient", key_material: Optional[JsonWebKey] = None) -> None:
134129
self._client: "CryptographyClient" = client
135-
self._key: JsonWebKey = key_material
130+
self._key: Optional[JsonWebKey] = key_material
136131

137132
def encrypt(self, plaintext: bytes, padding: AsymmetricPadding) -> bytes:
138133
"""Encrypts the given plaintext.
@@ -156,7 +151,15 @@ def key_size(self) -> int:
156151
157152
:returns: The key's size.
158153
:rtype: int
154+
155+
:raises ValueError: if the client is unable to obtain the key material from Key Vault.
159156
"""
157+
if self._key is None:
158+
raise ValueError(
159+
"Key material could not be obtained from Key Vault. Only remote cryptographic operations "
160+
"(encrypt, verify) can be performed."
161+
)
162+
160163
public_key = self.public_numbers().public_key()
161164
return public_key.key_size
162165

@@ -165,7 +168,15 @@ def public_numbers(self) -> RSAPublicNumbers:
165168
166169
:returns: The public numbers of the key.
167170
:rtype: RSAPublicNumbers
171+
172+
:raises ValueError: if the client is unable to obtain the key material from Key Vault.
168173
"""
174+
if self._key is None:
175+
raise ValueError(
176+
"Key material could not be obtained from Key Vault. Only remote cryptographic operations "
177+
"(encrypt, verify) can be performed."
178+
)
179+
169180
e = int.from_bytes(self._key.e, "big") # type: ignore[attr-defined]
170181
n = int.from_bytes(self._key.n, "big") # type: ignore[attr-defined]
171182
return RSAPublicNumbers(e, n)
@@ -184,7 +195,15 @@ def public_bytes(self, encoding: Encoding, format: PublicFormat) -> bytes:
184195
185196
:returns: The serialized key.
186197
:rtype: bytes
198+
199+
:raises ValueError: if the client is unable to obtain the key material from Key Vault.
187200
"""
201+
if self._key is None:
202+
raise ValueError(
203+
"Key material could not be obtained from Key Vault. Only remote cryptographic operations "
204+
"(encrypt, verify) can be performed."
205+
)
206+
188207
public_key = self.public_numbers().public_key()
189208
return public_key.public_bytes(encoding=encoding, format=format)
190209

@@ -250,12 +269,18 @@ def recover_data_from_signature(
250269
251270
:returns: The signed data.
252271
:rtype: bytes
253-
:raises:
254-
NotImplementedError if the local version of `cryptography` doesn't support this method.
255-
:class:`~cryptography.exceptions.InvalidSignature` if the signature is invalid.
256-
:class:`~cryptography.exceptions.UnsupportedAlgorithm` if the signature data recovery is not supported with
272+
:raises NotImplementedError: if the local version of `cryptography` doesn't support this method.
273+
:raises ~cryptography.exceptions.InvalidSignature: if the signature is invalid.
274+
:raises ~cryptography.exceptions.UnsupportedAlgorithm: if the signature data recovery is not supported with
257275
the provided `padding` type.
276+
:raises ValueError: if the client is unable to obtain the key material from Key Vault.
258277
"""
278+
if self._key is None:
279+
raise ValueError(
280+
"Key material could not be obtained from Key Vault. Only remote cryptographic operations "
281+
"(encrypt, verify) can be performed."
282+
)
283+
259284
public_key = self.public_numbers().public_key()
260285
try:
261286
return public_key.recover_data_from_signature(signature=signature, padding=padding, algorithm=algorithm)
@@ -270,9 +295,13 @@ def __eq__(self, other: object) -> bool:
270295
:param object other: Another object to compare with this instance. Currently, only comparisons with
271296
`KeyVaultRSAPrivateKey` or `JsonWebKey` instances are supported.
272297
273-
:returns: True if the objects are equal; False otherwise.
298+
:returns: True if the objects are equal; False if the objects are unequal or if key material can't be obtained
299+
from Key Vault for comparison.
274300
:rtype: bool
275301
"""
302+
if self._key is None:
303+
return False
304+
276305
if isinstance(other, KeyVaultRSAPublicKey):
277306
return all(getattr(self._key, field) == getattr(other._key, field) for field in self._key._FIELDS)
278307
if isinstance(other, JsonWebKey):
@@ -289,19 +318,14 @@ def verifier( # pylint:disable=docstring-missing-param,docstring-missing-return
289318
class KeyVaultRSAPrivateKey(RSAPrivateKey):
290319
"""An `RSAPrivateKey` implementation based on a key managed by Key Vault.
291320
292-
Only synchronous clients and operations are supported at this time.
321+
This class should not be instantiated directly. Instead, use the
322+
:func:`~azure.keyvault.keys.crypto.CryptographyClient.create_rsa_private_key` method to create a key based on the
323+
client's key. Only synchronous clients and operations are supported at this time.
293324
"""
294325

295-
def __init__(self, client: "CryptographyClient", key_material: JsonWebKey) -> None:
296-
"""Creates a `KeyVaultRSAPrivateKey` from a `CryptographyClient` and key.
297-
298-
:param client: The client that will be used to communicate with Key Vault.
299-
:type client: ~azure.keyvault.keys.crypto.CryptographyClient
300-
:param key_material: They Key Vault key's material, as a `JsonWebKey`.
301-
:type key_material: ~azure.keyvault.keys.JsonWebKey
302-
"""
326+
def __init__(self, client: "CryptographyClient", key_material: Optional[JsonWebKey]) -> None:
303327
self._client: "CryptographyClient" = client
304-
self._key: JsonWebKey = key_material
328+
self._key: Optional[JsonWebKey] = key_material
305329

306330
def decrypt(self, ciphertext: bytes, padding: AsymmetricPadding) -> bytes:
307331
"""Decrypts the provided ciphertext.
@@ -325,7 +349,15 @@ def key_size(self) -> int:
325349
326350
:returns: The key's size.
327351
:rtype: int
352+
353+
:raises ValueError: if the client is unable to obtain the key material from Key Vault.
328354
"""
355+
if self._key is None:
356+
raise ValueError(
357+
"Key material could not be obtained from Key Vault. Only remote cryptographic operations "
358+
"(decrypt, sign) can be performed."
359+
)
360+
329361
# Key size only requires public modulus, which we can always get
330362
# Relying on private numbers instead would cause issues for keys stored in KV (which doesn't return private key)
331363
return self.public_key().key_size
@@ -374,7 +406,15 @@ def private_numbers(self) -> RSAPrivateNumbers:
374406
375407
:returns: The private numbers of the key.
376408
:rtype: ~cryptography.hazmat.primitives.asymmetric.rsa.RSAPrivateNumbers
409+
410+
:raises ValueError: if the client is unable to obtain the key material from Key Vault.
377411
"""
412+
if self._key is None:
413+
raise ValueError(
414+
"Key material could not be obtained from Key Vault. Only remote cryptographic operations "
415+
"(decrypt, sign) can be performed."
416+
)
417+
378418
# Fetch public numbers from JWK
379419
e = int.from_bytes(self._key.e, "big") # type: ignore[attr-defined]
380420
n = int.from_bytes(self._key.n, "big") # type: ignore[attr-defined]
@@ -420,7 +460,15 @@ def private_bytes(
420460
421461
:returns: The serialized key.
422462
:rtype: bytes
463+
464+
:raises ValueError: if the client is unable to obtain the key material from Key Vault.
423465
"""
466+
if self._key is None:
467+
raise ValueError(
468+
"Key material could not be obtained from Key Vault. Only remote cryptographic operations "
469+
"(decrypt, sign) can be performed."
470+
)
471+
424472
try:
425473
private_numbers = self.private_numbers()
426474
except ValueError as exc:

sdk/keyvault/azure-keyvault-keys/tests/_async_test_case.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from azure.core.pipeline import AsyncPipeline
1010
from azure.core.pipeline.transport import AioHttpTransport, HttpRequest
1111
from azure.keyvault.keys import KeyReleasePolicy
12-
from azure.keyvault.keys._shared.client_base import DEFAULT_VERSION, ApiVersion
1312
from devtools_testutils import AzureRecordedTestCase
1413
from _test_case import HSM_SUPPORTED_VERSIONS
1514

@@ -42,7 +41,7 @@ def get_release_policy(attestation_uri, **kwargs):
4241
def get_test_parameters(only_hsm=False, only_vault=False, api_versions=None):
4342
"""generates a list of parameter pairs for test case parameterization, where [x, y] = [api_version, is_hsm]"""
4443
combinations = []
45-
versions = api_versions or ApiVersion
44+
versions = api_versions or pytest.api_version # pytest.api_version -> [DEFAULT_VERSION] if live, ApiVersion if not
4645

4746
for api_version in versions:
4847
if not only_vault and api_version in HSM_SUPPORTED_VERSIONS:
@@ -74,7 +73,7 @@ def __init__(self, *args, **kwargs):
7473
def __call__(self, fn):
7574
async def _preparer(test_class, api_version, is_hsm, **kwargs):
7675

77-
self._skip_if_not_configured(api_version, is_hsm)
76+
self._skip_if_not_configured(is_hsm)
7877
if not self.is_logging_enabled:
7978
kwargs.update({"logging_enable": False})
8079
endpoint_url = self.managed_hsm_url if is_hsm else self.vault_url
@@ -98,10 +97,8 @@ def _set_mgmt_settings_real_values(self):
9897
if self.is_live:
9998
os.environ["AZURE_TENANT_ID"] = os.environ["KEYVAULT_TENANT_ID"]
10099
os.environ["AZURE_CLIENT_ID"] = os.environ["KEYVAULT_CLIENT_ID"]
101-
os.environ["AZURE_CLIENT_SECRET"] = os.environ["KEYVAULT_CLIENT_SECRET"]
100+
os.environ["AZURE_CLIENT_SECRET"] = os.environ.get("KEYVAULT_CLIENT_SECRET", "") # Empty for user auth
102101

103-
def _skip_if_not_configured(self, api_version, is_hsm):
104-
if self.is_live and api_version != DEFAULT_VERSION:
105-
pytest.skip("This test only uses the default API version for live tests")
102+
def _skip_if_not_configured(self, is_hsm):
106103
if self.is_live and is_hsm and self.managed_hsm_url is None:
107104
pytest.skip("No HSM endpoint for live testing")

sdk/keyvault/azure-keyvault-keys/tests/_test_case.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from azure.core.pipeline import Pipeline
1010
from azure.core.pipeline.transport import HttpRequest, RequestsTransport
1111
from azure.keyvault.keys import KeyReleasePolicy
12-
from azure.keyvault.keys._shared.client_base import DEFAULT_VERSION, ApiVersion
12+
from azure.keyvault.keys._shared.client_base import ApiVersion
1313
from devtools_testutils import AzureRecordedTestCase
1414

1515

@@ -44,7 +44,7 @@ def get_release_policy(attestation_uri, **kwargs):
4444
def get_test_parameters(only_hsm=False, only_vault=False, api_versions=None):
4545
"""generates a list of parameter pairs for test case parameterization, where [x, y] = [api_version, is_hsm]"""
4646
combinations = []
47-
versions = api_versions or pytest.api_version
47+
versions = api_versions or pytest.api_version # pytest.api_version -> [DEFAULT_VERSION] if live, ApiVersion if not
4848

4949
for api_version in versions:
5050
if not only_vault and api_version in HSM_SUPPORTED_VERSIONS:
@@ -79,7 +79,7 @@ def __init__(self, *args, **kwargs):
7979
def __call__(self, fn):
8080
def _preparer(test_class, api_version, is_hsm, **kwargs):
8181

82-
self._skip_if_not_configured(api_version, is_hsm)
82+
self._skip_if_not_configured(is_hsm)
8383
if not self.is_logging_enabled:
8484
kwargs.update({"logging_enable": False})
8585
endpoint_url = self.managed_hsm_url if is_hsm else self.vault_url
@@ -102,11 +102,8 @@ def _set_mgmt_settings_real_values(self):
102102
if self.is_live:
103103
os.environ["AZURE_TENANT_ID"] = os.environ["KEYVAULT_TENANT_ID"]
104104
os.environ["AZURE_CLIENT_ID"] = os.environ["KEYVAULT_CLIENT_ID"]
105-
os.environ["AZURE_CLIENT_SECRET"] = os.environ["KEYVAULT_CLIENT_SECRET"]
105+
os.environ["AZURE_CLIENT_SECRET"] = os.environ.get("KEYVAULT_CLIENT_SECRET", "") # Empty for user auth
106106

107-
def _skip_if_not_configured(self, api_version, is_hsm):
108-
109-
if self.is_live and api_version != DEFAULT_VERSION:
110-
pytest.skip("This test only uses the default API version for live tests")
107+
def _skip_if_not_configured(self, is_hsm):
111108
if self.is_live and is_hsm and self.managed_hsm_url is None:
112109
pytest.skip("No HSM endpoint for live testing")

0 commit comments

Comments
 (0)