Skip to content

Commit edab740

Browse files
committed
Stricter
1 parent 3a80531 commit edab740

File tree

3 files changed

+35
-11
lines changed

3 files changed

+35
-11
lines changed

src/cryptojwt/jwe/jwe.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import logging
22

3-
from cryptojwt.jwk.jwk import key_from_jwk_dict
43
from ..jwk.asym import AsymmetricKey
4+
from ..jwk.ec import ECKey
5+
from ..jwk.hmac import SYMKey
6+
from ..jwk.jwk import key_from_jwk_dict
7+
from ..jwk.rsa import RSAKey
58
from ..jwx import JWx
69

710
from .exception import DecryptionFailed
@@ -82,14 +85,16 @@ def encrypt(self, keys=None, cek="", iv="", **kwargs):
8285

8386
# Determine Encryption Class by Algorithm
8487
if _alg in ["RSA-OAEP", "RSA-OAEP-256", "RSA1_5"]:
88+
keys = [k for k in keys if isinstance(k, RSAKey)]
8589
encrypter = JWE_RSA(self.msg, **self._dict)
8690
elif _alg.startswith("A") and _alg.endswith("KW"):
91+
keys = [k for k in keys if isinstance(k, SYMKey)]
8792
encrypter = JWE_SYM(self.msg, **self._dict)
8893
elif _alg.startswith("ECDH-ES"):
89-
90-
# ECDH-ES Requires the Server ECDH-ES Key to be set
94+
keys = [k for k in keys if isinstance(k, ECKey)]
9195
if not keys:
92-
raise NoSuitableECDHKey(_alg)
96+
logger.error(KEY_ERR.format(_alg))
97+
raise NoSuitableEncryptionKey(_alg)
9398

9499
encrypter = JWE_EC(**self._dict)
95100
cek, encrypted_key, iv, params, eprivk = encrypter.enc_setup(
@@ -100,17 +105,25 @@ def encrypt(self, keys=None, cek="", iv="", **kwargs):
100105
logger.error("'{}' is not a supported algorithm".format(_alg))
101106
raise NotSupportedAlgorithm
102107

108+
if not keys:
109+
logger.error(KEY_ERR.format(_alg))
110+
raise NoSuitableEncryptionKey(_alg)
111+
103112
if cek:
104113
kwargs["cek"] = cek
105114

106115
if iv:
107116
kwargs["iv"] = iv
108117

109118
for key in keys:
110-
if isinstance(key, AsymmetricKey):
119+
if isinstance(key, SYMKey):
120+
_key = key.key
121+
elif isinstance(key, ECKey):
111122
_key = key.public_key()
123+
elif isinstance(key, RSAKey):
124+
_key = key.public_key()
112125
else:
113-
_key = key.key
126+
raise ValueError('Unknown key type')
114127

115128
if key.kid:
116129
encrypter["kid"] = key.kid

src/cryptojwt/jwe/jwe_hmac.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from ..exception import MissingKey
1212
from ..exception import WrongNumberOfParts
1313
from ..jwk.hmac import SYMKey
14-
from ..utils import intarr2str
14+
from ..utils import intarr2str, as_bytes
1515

1616
logger = logging.getLogger(__name__)
1717

@@ -31,7 +31,7 @@ def encrypt(self, key, iv="", cek="", **kwargs):
3131
:param kwargs: Extra keyword arguments, just ignore for now.
3232
:return:
3333
"""
34-
_msg = self.msg
34+
_msg = as_bytes(self.msg)
3535

3636
_args = self._dict
3737
try:
@@ -60,8 +60,8 @@ def encrypt(self, key, iv="", cek="", **kwargs):
6060

6161
_enc = self["enc"]
6262
_auth_data = jwe.b64_encode_header()
63-
ctxt, tag, cek = self.enc_setup(_enc, _msg.encode(),
64-
auth_data=_auth_data, key=cek, iv=iv)
63+
ctxt, tag, cek = self.enc_setup(_enc, _msg, auth_data=_auth_data,
64+
key=cek, iv=iv)
6565
return jwe.pack(parts=[jek, iv, ctxt, tag])
6666

6767
def decrypt(self, token, key=None, cek=None):

tests/test_07_jwe.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,4 +393,15 @@ def test_encrypt_jwk_key():
393393
assert _enc
394394
decryptor = factory(_enc, alg="ECDH-ES", enc="A128GCM")
395395
res = decryptor.decrypt()
396-
assert res == plain
396+
assert res == plain
397+
398+
399+
def test_sym_encrypt_decrypt_JWE():
400+
encryption_key = SYMKey(use="enc", key='DukeofHazardpass',
401+
kid="some-key-id")
402+
jwe = JWE(plain, alg="A128KW", enc="A128CBC-HS256")
403+
_jwe = jwe.encrypt(keys=[encryption_key], kid="some-key-id")
404+
decryptor = factory(_jwe, alg="A128KW", enc="A128CBC-HS256")
405+
406+
resp = decryptor.decrypt(_jwe, [encryption_key])
407+
assert resp == plain

0 commit comments

Comments
 (0)