|
| 1 | +# -------------------------------------------------------------------------------------------- |
| 2 | +# Copyright (c) Microsoft Corporation. All rights reserved. |
| 3 | +# Licensed under the MIT License. See License.txt in the project root for license information. |
| 4 | +# -------------------------------------------------------------------------------------------- |
| 5 | +# The core utilities in this file are copied from the Azure CLI's Security Domain module: |
| 6 | +# https://github.com/Azure/azure-cli/tree/dev/src/azure-cli/azure/cli/command_modules/keyvault/security_domain |
| 7 | +import base64 |
| 8 | +import hashlib |
| 9 | +import hmac |
| 10 | +import json |
| 11 | + |
| 12 | +from cryptography.hazmat.backends import default_backend |
| 13 | +from cryptography.hazmat.primitives import hashes, padding |
| 14 | +from cryptography.hazmat.primitives.asymmetric import padding as asymmetric_padding |
| 15 | +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes |
| 16 | + |
| 17 | +from utils import Utils |
| 18 | + |
| 19 | + |
| 20 | +class KDF: |
| 21 | + @staticmethod |
| 22 | + def to_big_endian_32bits(value): |
| 23 | + result = bytearray() |
| 24 | + result.append((value & 0xFF000000) >> 24) |
| 25 | + result.append((value & 0x00FF0000) >> 16) |
| 26 | + result.append((value & 0x0000FF00) >> 8) |
| 27 | + result.append(value & 0x000000FF) |
| 28 | + return result |
| 29 | + |
| 30 | + @staticmethod |
| 31 | + def to_big_endian_64bits(value): |
| 32 | + result = bytearray() |
| 33 | + result.append((value & 0xFF00000000000000) >> 56) |
| 34 | + result.append((value & 0x00FF000000000000) >> 48) |
| 35 | + result.append((value & 0x0000FF0000000000) >> 40) |
| 36 | + result.append((value & 0x000000FF00000000) >> 32) |
| 37 | + result.append((value & 0x00000000FF000000) >> 24) |
| 38 | + result.append((value & 0x0000000000FF0000) >> 16) |
| 39 | + result.append((value & 0x000000000000FF00) >> 8) |
| 40 | + result.append(value & 0x00000000000000FF) |
| 41 | + return result |
| 42 | + |
| 43 | + @staticmethod |
| 44 | + def test_sp800_108(): |
| 45 | + label = 'label' |
| 46 | + context = 'context' |
| 47 | + bit_length = 256 |
| 48 | + hex_result = 'f0ca51f6308791404bf68b56024ee7c64d6c737716f81d47e1e68b5c4e399575' |
| 49 | + key = bytearray() |
| 50 | + key.extend([0x41] * 32) |
| 51 | + |
| 52 | + new_key = KDF.sp800_108(key, label, context, bit_length) |
| 53 | + hex_value = new_key.hex().replace('-', '') # type: ignore |
| 54 | + return hex_value.lower() == hex_result |
| 55 | + |
| 56 | + @staticmethod |
| 57 | + def sp800_108(key_in: bytearray, label: str, context: str, bit_length): |
| 58 | + """ |
| 59 | + Note - initialize out to be the number of bytes of keying material that you need |
| 60 | + This implements SP 800-108 in counter mode, see section 5.1 |
| 61 | +
|
| 62 | + Fixed values: |
| 63 | + 1. h - The length of the output of the PRF in bits, and |
| 64 | + 2. r - The length of the binary representation of the counter i. |
| 65 | +
|
| 66 | + Input: KI, Label, Context, and L. |
| 67 | +
|
| 68 | + Process: |
| 69 | + 1. n := ⎡L/h⎤. |
| 70 | + 2. If n > 2^(r-1), then indicate an error and stop. |
| 71 | + 3. result(0):= ∅. |
| 72 | + 4. For i = 1 to n, do |
| 73 | + a. K(i) := PRF (KI, [i]2 || Label || 0x00 || Context || [L]2) |
| 74 | + b. result(i) := result(i-1) || K(i). |
| 75 | +
|
| 76 | + 5. Return: KO := the leftmost L bits of result(n). |
| 77 | + """ |
| 78 | + |
| 79 | + if bit_length <= 0 or bit_length % 8 != 0: |
| 80 | + return None |
| 81 | + |
| 82 | + L = bit_length |
| 83 | + bytes_needed = bit_length // 8 |
| 84 | + hMAC = hmac.HMAC(key_in, digestmod=hashlib.sha512) |
| 85 | + hash_bits = hMAC.digest_size |
| 86 | + n = L // hash_bits |
| 87 | + if L % hash_bits != 0: |
| 88 | + n += 1 |
| 89 | + |
| 90 | + hmac_data_suffix = bytearray() |
| 91 | + hmac_data_suffix.extend(label.encode('UTF-8')) |
| 92 | + hmac_data_suffix.append(0) |
| 93 | + hmac_data_suffix.extend(context.encode('UTF-8')) |
| 94 | + hmac_data_suffix.extend(KDF.to_big_endian_32bits(bit_length)) |
| 95 | + |
| 96 | + out_value = bytearray() |
| 97 | + for i in range(n): |
| 98 | + hmac_data = bytearray() |
| 99 | + hmac_data.extend(KDF.to_big_endian_32bits(i + 1)) |
| 100 | + hmac_data.extend(hmac_data_suffix) |
| 101 | + hMAC.update(hmac_data) |
| 102 | + hash_value = hMAC.digest() |
| 103 | + |
| 104 | + if bytes_needed > len(hash_value): |
| 105 | + out_value.extend(hash_value) |
| 106 | + bytes_needed -= len(hash_value) |
| 107 | + else: |
| 108 | + out_value.extend(hash_value[len(out_value): len(out_value) + bytes_needed]) |
| 109 | + return out_value |
| 110 | + |
| 111 | + return None |
| 112 | + |
| 113 | + |
| 114 | +class JWEHeader: # pylint: disable=too-many-instance-attributes |
| 115 | + _fields = ['alg', 'enc', 'zip', 'jku', 'jwk', 'kid', 'x5u', 'x5c', 'x5t', 'x5t_S256', 'typ', 'cty', 'crit'] |
| 116 | + |
| 117 | + def __init__(self, alg=None, enc=None, zip=None, # pylint: disable=redefined-builtin |
| 118 | + jku=None, jwk=None, kid=None, x5u=None, x5c=None, x5t=None, |
| 119 | + x5t_S256=None, typ=None, cty=None, crit=None): |
| 120 | + """ |
| 121 | + JWE header |
| 122 | +
|
| 123 | + :param alg: algorithm |
| 124 | + :param enc: encryption algorithm |
| 125 | + :param zip: compression algorithm |
| 126 | + :param jku: JWK set URL |
| 127 | + :param jwk: JSON Web key |
| 128 | + :param kid: Key ID |
| 129 | + :param x5u: X.509 certificate URL |
| 130 | + :param x5c: X.509 certificate chain |
| 131 | + :param x5t: X.509 certificate SHA-1 Thumbprint |
| 132 | + :param x5t_S256: X.509 certificate SHA-256 Thumbprint |
| 133 | + :param typ: type |
| 134 | + :param cty: content type |
| 135 | + :param crit: critical |
| 136 | + """ |
| 137 | + self.alg = alg |
| 138 | + self.enc = enc |
| 139 | + self.zip = zip |
| 140 | + self.jku = jku |
| 141 | + self.jwk = jwk |
| 142 | + self.kid = kid |
| 143 | + self.x5u = x5u |
| 144 | + self.x5c = x5c |
| 145 | + self.x5t = x5t |
| 146 | + self.x5t_S256 = x5t_S256 |
| 147 | + self.typ = typ |
| 148 | + self.cty = cty |
| 149 | + self.crit = crit |
| 150 | + |
| 151 | + @staticmethod |
| 152 | + def from_json_str(json_str): |
| 153 | + json_dict = json.loads(json_str) |
| 154 | + jwe_header = JWEHeader() |
| 155 | + for k in jwe_header._fields: |
| 156 | + if k == 'x5t_S256': |
| 157 | + v = json_dict.get('x5t#S256', None) |
| 158 | + else: |
| 159 | + v = json_dict.get(k, None) |
| 160 | + if v is not None: |
| 161 | + setattr(jwe_header, k, v) |
| 162 | + return jwe_header |
| 163 | + |
| 164 | + def to_json_str(self): |
| 165 | + json_dict = {} |
| 166 | + for k in self._fields: |
| 167 | + v = getattr(self, k, None) |
| 168 | + if v is not None: |
| 169 | + if k == 'x5t_S256': |
| 170 | + json_dict['x5t#S256'] = v |
| 171 | + else: |
| 172 | + json_dict[k] = v |
| 173 | + return json.dumps(json_dict) |
| 174 | + |
| 175 | + |
| 176 | +class JWEDecode: |
| 177 | + def __init__(self, compact_jwe=None): |
| 178 | + if compact_jwe is None: |
| 179 | + self.encoded_header = '' |
| 180 | + self.encrypted_key = None |
| 181 | + self.init_vector = None |
| 182 | + self.ciphertext = None |
| 183 | + self.auth_tag = None |
| 184 | + self.protected_header = JWEHeader() |
| 185 | + else: |
| 186 | + parts = compact_jwe.split('.') |
| 187 | + |
| 188 | + self.encoded_header = parts[0] |
| 189 | + header = base64.urlsafe_b64decode(self.encoded_header + '===').decode('ascii') # Fix incorrect padding |
| 190 | + self.protected_header = JWEHeader.from_json_str(header) |
| 191 | + self.encrypted_key = base64.urlsafe_b64decode(parts[1] + '===') |
| 192 | + self.init_vector = base64.urlsafe_b64decode(parts[2] + '===') |
| 193 | + self.ciphertext = base64.urlsafe_b64decode(parts[3] + '===') |
| 194 | + self.auth_tag = base64.urlsafe_b64decode(parts[4] + '===') |
| 195 | + |
| 196 | + def encode_header(self): |
| 197 | + header_json = self.protected_header.to_json_str().replace('": ', '":').replace('", ', '",') |
| 198 | + self.encoded_header = Utils.security_domain_b64_url_encode(header_json.encode('ascii')) |
| 199 | + |
| 200 | + def encode_compact(self): |
| 201 | + ret = [self.encoded_header + '.'] |
| 202 | + |
| 203 | + if self.encrypted_key is not None: |
| 204 | + ret.append(Utils.security_domain_b64_url_encode(self.encrypted_key)) |
| 205 | + ret.append('.') |
| 206 | + |
| 207 | + if self.init_vector is not None: |
| 208 | + ret.append(Utils.security_domain_b64_url_encode(self.init_vector)) |
| 209 | + ret.append('.') |
| 210 | + |
| 211 | + if self.ciphertext is not None: |
| 212 | + ret.append(Utils.security_domain_b64_url_encode(self.ciphertext)) |
| 213 | + ret.append('.') |
| 214 | + |
| 215 | + if self.auth_tag is not None: |
| 216 | + ret.append(Utils.security_domain_b64_url_encode(self.auth_tag)) |
| 217 | + |
| 218 | + return ''.join(ret) |
| 219 | + |
| 220 | + |
| 221 | +class JWE: |
| 222 | + def __init__(self, compact_jwe=None): |
| 223 | + self.jwe_decode = JWEDecode(compact_jwe=compact_jwe) |
| 224 | + |
| 225 | + def encode_compact(self): |
| 226 | + return self.jwe_decode.encode_compact() |
| 227 | + |
| 228 | + def get_padding_mode(self): |
| 229 | + alg = self.jwe_decode.protected_header.alg |
| 230 | + |
| 231 | + if alg == 'RSA-OAEP-256': |
| 232 | + algorithm = hashes.SHA256() |
| 233 | + return asymmetric_padding.OAEP( |
| 234 | + mgf=asymmetric_padding.MGF1(algorithm=algorithm), algorithm=algorithm, label=None) |
| 235 | + |
| 236 | + if alg == 'RSA-OAEP': |
| 237 | + algorithm = hashes.SHA1() |
| 238 | + return asymmetric_padding.OAEP( |
| 239 | + mgf=asymmetric_padding.MGF1(algorithm=algorithm), algorithm=algorithm, label=None) |
| 240 | + |
| 241 | + if alg == 'RSA1_5': |
| 242 | + return asymmetric_padding.PKCS1v15() |
| 243 | + |
| 244 | + return None |
| 245 | + |
| 246 | + def get_cek(self, private_key): |
| 247 | + return private_key.decrypt( |
| 248 | + self.jwe_decode.encrypted_key, |
| 249 | + self.get_padding_mode() |
| 250 | + ) |
| 251 | + |
| 252 | + def set_cek(self, cert, cek): |
| 253 | + public_key = cert.public_key() |
| 254 | + self.jwe_decode.encrypted_key = public_key.encrypt(bytes(cek), self.get_padding_mode()) |
| 255 | + |
| 256 | + @staticmethod |
| 257 | + def dek_from_cek(cek): |
| 258 | + dek = bytearray() |
| 259 | + for i in range(32): |
| 260 | + dek.append(cek[i + 32]) |
| 261 | + return dek |
| 262 | + |
| 263 | + @staticmethod |
| 264 | + def hmac_key_from_cek(cek): |
| 265 | + hk = bytearray() |
| 266 | + for i in range(32): |
| 267 | + hk.append(cek[i]) |
| 268 | + return hk |
| 269 | + |
| 270 | + def get_mac(self, hk): |
| 271 | + header_bytes = bytearray() |
| 272 | + header_bytes.extend(self.jwe_decode.encoded_header.encode('ascii')) |
| 273 | + auth_bits = len(header_bytes) * 8 |
| 274 | + |
| 275 | + hash_data = bytearray() |
| 276 | + hash_data.extend(header_bytes) |
| 277 | + hash_data.extend(self.jwe_decode.init_vector) # type: ignore |
| 278 | + hash_data.extend(self.jwe_decode.ciphertext) # type: ignore |
| 279 | + hash_data.extend(KDF.to_big_endian_64bits(auth_bits)) |
| 280 | + |
| 281 | + hMAC = hmac.HMAC(hk, msg=hash_data, digestmod=hashlib.sha512) |
| 282 | + return hMAC.digest() |
| 283 | + |
| 284 | + def Aes256HmacSha512Decrypt(self, cek): |
| 285 | + dek = JWE.dek_from_cek(cek) |
| 286 | + hk = JWE.hmac_key_from_cek(cek) |
| 287 | + mac_value = self.get_mac(hk) |
| 288 | + |
| 289 | + test = 0 |
| 290 | + i = 0 |
| 291 | + while i < len(self.jwe_decode.auth_tag) == 32: # type: ignore |
| 292 | + test |= (self.jwe_decode.auth_tag[i] ^ mac_value[i]) # type: ignore |
| 293 | + i += 1 |
| 294 | + |
| 295 | + if test != 0: |
| 296 | + return None |
| 297 | + |
| 298 | + aes_key = dek |
| 299 | + aes_iv = self.jwe_decode.init_vector |
| 300 | + cipher = Cipher(algorithms.AES(aes_key), modes.CBC(aes_iv), backend=default_backend()) # type: ignore |
| 301 | + decryptor = cipher.decryptor() |
| 302 | + plaintext = decryptor.update(self.jwe_decode.ciphertext) + decryptor.finalize() # type: ignore |
| 303 | + |
| 304 | + unpadder = padding.PKCS7(128).unpadder() |
| 305 | + plaintext = unpadder.update(bytes(plaintext)) + unpadder.finalize() |
| 306 | + |
| 307 | + return plaintext |
| 308 | + |
| 309 | + def Aes256HmacSha512Encrypt(self, cek, plaintext): |
| 310 | + dek = JWE.dek_from_cek(cek) |
| 311 | + hk = JWE.hmac_key_from_cek(cek) |
| 312 | + |
| 313 | + padder = padding.PKCS7(128).padder() |
| 314 | + plaintext = padder.update(bytes(plaintext)) + padder.finalize() |
| 315 | + |
| 316 | + aes_key = dek |
| 317 | + aes_iv = Utils.get_random(16) |
| 318 | + cipher = Cipher(algorithms.AES(aes_key), modes.CBC(aes_iv), backend=default_backend()) # type: ignore |
| 319 | + encryptor = cipher.encryptor() |
| 320 | + self.jwe_decode.ciphertext = encryptor.update(plaintext) + encryptor.finalize() |
| 321 | + self.jwe_decode.init_vector = aes_iv # type: ignore |
| 322 | + |
| 323 | + mac_value = self.get_mac(hk) |
| 324 | + self.jwe_decode.auth_tag = bytearray() # type: ignore |
| 325 | + for i in range(32): |
| 326 | + self.jwe_decode.auth_tag.append(mac_value[i]) # type: ignore |
| 327 | + |
| 328 | + def decrypt_using_bytes(self, cek): |
| 329 | + if self.jwe_decode.protected_header.enc == 'A256CBC-HS512': |
| 330 | + return self.Aes256HmacSha512Decrypt(cek) |
| 331 | + return None |
| 332 | + |
| 333 | + def get_cek_from_private_key(self, private_key): |
| 334 | + return private_key.decrypt(self.jwe_decode.encrypted_key, self.get_padding_mode()) |
| 335 | + |
| 336 | + def decrypt_using_private_key(self, private_key): |
| 337 | + cek = self.get_cek_from_private_key(private_key) |
| 338 | + return self.decrypt_using_bytes(cek) |
| 339 | + |
| 340 | + def encrypt_using_bytes(self, cek, plaintext, alg_id, kid=None): |
| 341 | + if kid is not None: |
| 342 | + self.jwe_decode.protected_header.alg = 'dir' |
| 343 | + self.jwe_decode.protected_header.kid = kid |
| 344 | + |
| 345 | + if alg_id == 'A256CBC-HS512': |
| 346 | + self.jwe_decode.protected_header.enc = alg_id |
| 347 | + self.jwe_decode.encode_header() |
| 348 | + self.Aes256HmacSha512Encrypt(cek, plaintext) |
| 349 | + |
| 350 | + def encrypt_using_cert(self, cert, plaintext): |
| 351 | + self.jwe_decode.protected_header.alg = 'RSA-OAEP-256' |
| 352 | + self.jwe_decode.protected_header.kid = 'not used' |
| 353 | + cek = Utils.get_random(64) |
| 354 | + self.set_cek(cert, cek) |
| 355 | + self.encrypt_using_bytes(cek, plaintext, alg_id='A256CBC-HS512') |
0 commit comments