Skip to content

Commit 43700e2

Browse files
authored
[Key Vault] Refine Security Domain tests to not require CLI use mid-execution (#42089)
* Update tests to not require Azure CLI * CredScan suppressions for test certs/keys * Reorganize testing utilities * Fix transfer key file extension; gitignore transfer key * Address feedback
1 parent ec33763 commit 43700e2

File tree

13 files changed

+877
-35
lines changed

13 files changed

+877
-35
lines changed

eng/CredScanSuppression.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@
3535
"eng/common/testproxy/dotnet-devcert.pfx",
3636
"sdk/confidentialledger/azure-confidentialledger/tests/_shared/constants.py",
3737
"sdk/keyvault/azure-keyvault-certificates/tests/ca.key",
38+
"sdk/keyvault/azure-keyvault-securitydomain/tests/resources/certificate0.cer",
39+
"sdk/keyvault/azure-keyvault-securitydomain/tests/resources/certificate1.cer",
40+
"sdk/keyvault/azure-keyvault-securitydomain/tests/resources/certificate2.cer",
41+
"sdk/keyvault/azure-keyvault-securitydomain/tests/resources/certificate0.pem",
42+
"sdk/keyvault/azure-keyvault-securitydomain/tests/resources/certificate1.pem",
43+
"sdk/keyvault/azure-keyvault-securitydomain/tests/resources/certificate2.pem",
3844
"sdk/identity/azure-identity/tests/certificate.pfx",
3945
"sdk/identity/azure-identity/tests/certificate.pem",
4046
"sdk/identity/azure-identity/tests/certificate-with-password.pfx",

sdk/keyvault/.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,7 @@
22
*.key
33
*.pfx
44
*security-domain.json
5-
*.pem
5+
*.pem
6+
!azure-keyvault-securitydomain/tests/resources/*.cer
7+
!azure-keyvault-securitydomain/tests/resources/*.pem
8+
*transfer-key.pem
Lines changed: 355 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,355 @@
1+
# --------------------------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See License.txt in the project root for license information.
4+
# --------------------------------------------------------------------------------------------
5+
# The core utilities in this file are copied from the Azure CLI's Security Domain module:
6+
# https://github.com/Azure/azure-cli/tree/dev/src/azure-cli/azure/cli/command_modules/keyvault/security_domain
7+
import base64
8+
import hashlib
9+
import hmac
10+
import json
11+
12+
from cryptography.hazmat.backends import default_backend
13+
from cryptography.hazmat.primitives import hashes, padding
14+
from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding
15+
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
16+
17+
from utils import Utils
18+
19+
20+
class KDF:
21+
@staticmethod
22+
def to_big_endian_32bits(value):
23+
result = bytearray()
24+
result.append((value & 0xFF000000) >> 24)
25+
result.append((value & 0x00FF0000) >> 16)
26+
result.append((value & 0x0000FF00) >> 8)
27+
result.append(value & 0x000000FF)
28+
return result
29+
30+
@staticmethod
31+
def to_big_endian_64bits(value):
32+
result = bytearray()
33+
result.append((value & 0xFF00000000000000) >> 56)
34+
result.append((value & 0x00FF000000000000) >> 48)
35+
result.append((value & 0x0000FF0000000000) >> 40)
36+
result.append((value & 0x000000FF00000000) >> 32)
37+
result.append((value & 0x00000000FF000000) >> 24)
38+
result.append((value & 0x0000000000FF0000) >> 16)
39+
result.append((value & 0x000000000000FF00) >> 8)
40+
result.append(value & 0x00000000000000FF)
41+
return result
42+
43+
@staticmethod
44+
def test_sp800_108():
45+
label = 'label'
46+
context = 'context'
47+
bit_length = 256
48+
hex_result = 'f0ca51f6308791404bf68b56024ee7c64d6c737716f81d47e1e68b5c4e399575'
49+
key = bytearray()
50+
key.extend([0x41] * 32)
51+
52+
new_key = KDF.sp800_108(key, label, context, bit_length)
53+
hex_value = new_key.hex().replace('-', '') # type: ignore
54+
return hex_value.lower() == hex_result
55+
56+
@staticmethod
57+
def sp800_108(key_in: bytearray, label: str, context: str, bit_length):
58+
"""
59+
Note - initialize out to be the number of bytes of keying material that you need
60+
This implements SP 800-108 in counter mode, see section 5.1
61+
62+
Fixed values:
63+
1. h - The length of the output of the PRF in bits, and
64+
2. r - The length of the binary representation of the counter i.
65+
66+
Input: KI, Label, Context, and L.
67+
68+
Process:
69+
1. n := ⎡L/h⎤.
70+
2. If n > 2^(r-1), then indicate an error and stop.
71+
3. result(0):= ∅.
72+
4. For i = 1 to n, do
73+
a. K(i) := PRF (KI, [i]2 || Label || 0x00 || Context || [L]2)
74+
b. result(i) := result(i-1) || K(i).
75+
76+
5. Return: KO := the leftmost L bits of result(n).
77+
"""
78+
79+
if bit_length <= 0 or bit_length % 8 != 0:
80+
return None
81+
82+
L = bit_length
83+
bytes_needed = bit_length // 8
84+
hMAC = hmac.HMAC(key_in, digestmod=hashlib.sha512)
85+
hash_bits = hMAC.digest_size
86+
n = L // hash_bits
87+
if L % hash_bits != 0:
88+
n += 1
89+
90+
hmac_data_suffix = bytearray()
91+
hmac_data_suffix.extend(label.encode('UTF-8'))
92+
hmac_data_suffix.append(0)
93+
hmac_data_suffix.extend(context.encode('UTF-8'))
94+
hmac_data_suffix.extend(KDF.to_big_endian_32bits(bit_length))
95+
96+
out_value = bytearray()
97+
for i in range(n):
98+
hmac_data = bytearray()
99+
hmac_data.extend(KDF.to_big_endian_32bits(i + 1))
100+
hmac_data.extend(hmac_data_suffix)
101+
hMAC.update(hmac_data)
102+
hash_value = hMAC.digest()
103+
104+
if bytes_needed > len(hash_value):
105+
out_value.extend(hash_value)
106+
bytes_needed -= len(hash_value)
107+
else:
108+
out_value.extend(hash_value[len(out_value): len(out_value) + bytes_needed])
109+
return out_value
110+
111+
return None
112+
113+
114+
class JWEHeader: # pylint: disable=too-many-instance-attributes
115+
_fields = ['alg', 'enc', 'zip', 'jku', 'jwk', 'kid', 'x5u', 'x5c', 'x5t', 'x5t_S256', 'typ', 'cty', 'crit']
116+
117+
def __init__(self, alg=None, enc=None, zip=None, # pylint: disable=redefined-builtin
118+
jku=None, jwk=None, kid=None, x5u=None, x5c=None, x5t=None,
119+
x5t_S256=None, typ=None, cty=None, crit=None):
120+
"""
121+
JWE header
122+
123+
:param alg: algorithm
124+
:param enc: encryption algorithm
125+
:param zip: compression algorithm
126+
:param jku: JWK set URL
127+
:param jwk: JSON Web key
128+
:param kid: Key ID
129+
:param x5u: X.509 certificate URL
130+
:param x5c: X.509 certificate chain
131+
:param x5t: X.509 certificate SHA-1 Thumbprint
132+
:param x5t_S256: X.509 certificate SHA-256 Thumbprint
133+
:param typ: type
134+
:param cty: content type
135+
:param crit: critical
136+
"""
137+
self.alg = alg
138+
self.enc = enc
139+
self.zip = zip
140+
self.jku = jku
141+
self.jwk = jwk
142+
self.kid = kid
143+
self.x5u = x5u
144+
self.x5c = x5c
145+
self.x5t = x5t
146+
self.x5t_S256 = x5t_S256
147+
self.typ = typ
148+
self.cty = cty
149+
self.crit = crit
150+
151+
@staticmethod
152+
def from_json_str(json_str):
153+
json_dict = json.loads(json_str)
154+
jwe_header = JWEHeader()
155+
for k in jwe_header._fields:
156+
if k == 'x5t_S256':
157+
v = json_dict.get('x5t#S256', None)
158+
else:
159+
v = json_dict.get(k, None)
160+
if v is not None:
161+
setattr(jwe_header, k, v)
162+
return jwe_header
163+
164+
def to_json_str(self):
165+
json_dict = {}
166+
for k in self._fields:
167+
v = getattr(self, k, None)
168+
if v is not None:
169+
if k == 'x5t_S256':
170+
json_dict['x5t#S256'] = v
171+
else:
172+
json_dict[k] = v
173+
return json.dumps(json_dict)
174+
175+
176+
class JWEDecode:
177+
def __init__(self, compact_jwe=None):
178+
if compact_jwe is None:
179+
self.encoded_header = ''
180+
self.encrypted_key = None
181+
self.init_vector = None
182+
self.ciphertext = None
183+
self.auth_tag = None
184+
self.protected_header = JWEHeader()
185+
else:
186+
parts = compact_jwe.split('.')
187+
188+
self.encoded_header = parts[0]
189+
header = base64.urlsafe_b64decode(self.encoded_header + '===').decode('ascii') # Fix incorrect padding
190+
self.protected_header = JWEHeader.from_json_str(header)
191+
self.encrypted_key = base64.urlsafe_b64decode(parts[1] + '===')
192+
self.init_vector = base64.urlsafe_b64decode(parts[2] + '===')
193+
self.ciphertext = base64.urlsafe_b64decode(parts[3] + '===')
194+
self.auth_tag = base64.urlsafe_b64decode(parts[4] + '===')
195+
196+
def encode_header(self):
197+
header_json = self.protected_header.to_json_str().replace('": ', '":').replace('", ', '",')
198+
self.encoded_header = Utils.security_domain_b64_url_encode(header_json.encode('ascii'))
199+
200+
def encode_compact(self):
201+
ret = [self.encoded_header + '.']
202+
203+
if self.encrypted_key is not None:
204+
ret.append(Utils.security_domain_b64_url_encode(self.encrypted_key))
205+
ret.append('.')
206+
207+
if self.init_vector is not None:
208+
ret.append(Utils.security_domain_b64_url_encode(self.init_vector))
209+
ret.append('.')
210+
211+
if self.ciphertext is not None:
212+
ret.append(Utils.security_domain_b64_url_encode(self.ciphertext))
213+
ret.append('.')
214+
215+
if self.auth_tag is not None:
216+
ret.append(Utils.security_domain_b64_url_encode(self.auth_tag))
217+
218+
return ''.join(ret)
219+
220+
221+
class JWE:
222+
def __init__(self, compact_jwe=None):
223+
self.jwe_decode = JWEDecode(compact_jwe=compact_jwe)
224+
225+
def encode_compact(self):
226+
return self.jwe_decode.encode_compact()
227+
228+
def get_padding_mode(self):
229+
alg = self.jwe_decode.protected_header.alg
230+
231+
if alg == 'RSA-OAEP-256':
232+
algorithm = hashes.SHA256()
233+
return asymmetric_padding.OAEP(
234+
mgf=asymmetric_padding.MGF1(algorithm=algorithm), algorithm=algorithm, label=None)
235+
236+
if alg == 'RSA-OAEP':
237+
algorithm = hashes.SHA1()
238+
return asymmetric_padding.OAEP(
239+
mgf=asymmetric_padding.MGF1(algorithm=algorithm), algorithm=algorithm, label=None)
240+
241+
if alg == 'RSA1_5':
242+
return asymmetric_padding.PKCS1v15()
243+
244+
return None
245+
246+
def get_cek(self, private_key):
247+
return private_key.decrypt(
248+
self.jwe_decode.encrypted_key,
249+
self.get_padding_mode()
250+
)
251+
252+
def set_cek(self, cert, cek):
253+
public_key = cert.public_key()
254+
self.jwe_decode.encrypted_key = public_key.encrypt(bytes(cek), self.get_padding_mode())
255+
256+
@staticmethod
257+
def dek_from_cek(cek):
258+
dek = bytearray()
259+
for i in range(32):
260+
dek.append(cek[i + 32])
261+
return dek
262+
263+
@staticmethod
264+
def hmac_key_from_cek(cek):
265+
hk = bytearray()
266+
for i in range(32):
267+
hk.append(cek[i])
268+
return hk
269+
270+
def get_mac(self, hk):
271+
header_bytes = bytearray()
272+
header_bytes.extend(self.jwe_decode.encoded_header.encode('ascii'))
273+
auth_bits = len(header_bytes) * 8
274+
275+
hash_data = bytearray()
276+
hash_data.extend(header_bytes)
277+
hash_data.extend(self.jwe_decode.init_vector) # type: ignore
278+
hash_data.extend(self.jwe_decode.ciphertext) # type: ignore
279+
hash_data.extend(KDF.to_big_endian_64bits(auth_bits))
280+
281+
hMAC = hmac.HMAC(hk, msg=hash_data, digestmod=hashlib.sha512)
282+
return hMAC.digest()
283+
284+
def Aes256HmacSha512Decrypt(self, cek):
285+
dek = JWE.dek_from_cek(cek)
286+
hk = JWE.hmac_key_from_cek(cek)
287+
mac_value = self.get_mac(hk)
288+
289+
test = 0
290+
i = 0
291+
while i < len(self.jwe_decode.auth_tag) == 32: # type: ignore
292+
test |= (self.jwe_decode.auth_tag[i] ^ mac_value[i]) # type: ignore
293+
i += 1
294+
295+
if test != 0:
296+
return None
297+
298+
aes_key = dek
299+
aes_iv = self.jwe_decode.init_vector
300+
cipher = Cipher(algorithms.AES(aes_key), modes.CBC(aes_iv), backend=default_backend()) # type: ignore
301+
decryptor = cipher.decryptor()
302+
plaintext = decryptor.update(self.jwe_decode.ciphertext) + decryptor.finalize() # type: ignore
303+
304+
unpadder = padding.PKCS7(128).unpadder()
305+
plaintext = unpadder.update(bytes(plaintext)) + unpadder.finalize()
306+
307+
return plaintext
308+
309+
def Aes256HmacSha512Encrypt(self, cek, plaintext):
310+
dek = JWE.dek_from_cek(cek)
311+
hk = JWE.hmac_key_from_cek(cek)
312+
313+
padder = padding.PKCS7(128).padder()
314+
plaintext = padder.update(bytes(plaintext)) + padder.finalize()
315+
316+
aes_key = dek
317+
aes_iv = Utils.get_random(16)
318+
cipher = Cipher(algorithms.AES(aes_key), modes.CBC(aes_iv), backend=default_backend()) # type: ignore
319+
encryptor = cipher.encryptor()
320+
self.jwe_decode.ciphertext = encryptor.update(plaintext) + encryptor.finalize()
321+
self.jwe_decode.init_vector = aes_iv # type: ignore
322+
323+
mac_value = self.get_mac(hk)
324+
self.jwe_decode.auth_tag = bytearray() # type: ignore
325+
for i in range(32):
326+
self.jwe_decode.auth_tag.append(mac_value[i]) # type: ignore
327+
328+
def decrypt_using_bytes(self, cek):
329+
if self.jwe_decode.protected_header.enc == 'A256CBC-HS512':
330+
return self.Aes256HmacSha512Decrypt(cek)
331+
return None
332+
333+
def get_cek_from_private_key(self, private_key):
334+
return private_key.decrypt(self.jwe_decode.encrypted_key, self.get_padding_mode())
335+
336+
def decrypt_using_private_key(self, private_key):
337+
cek = self.get_cek_from_private_key(private_key)
338+
return self.decrypt_using_bytes(cek)
339+
340+
def encrypt_using_bytes(self, cek, plaintext, alg_id, kid=None):
341+
if kid is not None:
342+
self.jwe_decode.protected_header.alg = 'dir'
343+
self.jwe_decode.protected_header.kid = kid
344+
345+
if alg_id == 'A256CBC-HS512':
346+
self.jwe_decode.protected_header.enc = alg_id
347+
self.jwe_decode.encode_header()
348+
self.Aes256HmacSha512Encrypt(cek, plaintext)
349+
350+
def encrypt_using_cert(self, cert, plaintext):
351+
self.jwe_decode.protected_header.alg = 'RSA-OAEP-256'
352+
self.jwe_decode.protected_header.kid = 'not used'
353+
cek = Utils.get_random(64)
354+
self.set_cek(cert, cek)
355+
self.encrypt_using_bytes(cek, plaintext, alg_id='A256CBC-HS512')

0 commit comments

Comments
 (0)