@@ -489,15 +489,20 @@ def encrypt(self, key, iv="", cek="", **kwargs):
489489 else :
490490 raise ParameterError ("Zip has unknown value: %s" % self ["zip" ])
491491
492+ kwarg_cek = cek or None
493+
492494 _enc = self ["enc" ]
493495 cek , iv = self ._generate_key_and_iv (_enc , cek , iv )
496+ self ["cek" ] = cek
494497
495498 logger .debug ("cek: %s, iv: %s" % ([c for c in cek ], [c for c in iv ]))
496499
497500 _encrypt = RSAEncrypter (self .with_digest ).encrypt
498501
499502 _alg = self ["alg" ]
500- if _alg == "RSA-OAEP" :
503+ if kwarg_cek :
504+ jwe_enc_key = ''
505+ elif _alg == "RSA-OAEP" :
501506 jwe_enc_key = _encrypt (cek , key , 'pkcs1_oaep_padding' )
502507 elif _alg == "RSA1_5" :
503508 jwe_enc_key = _encrypt (cek , key )
@@ -511,7 +516,7 @@ def encrypt(self, key, iv="", cek="", **kwargs):
511516 ctxt , tag , key = self .enc_setup (_enc , _msg , enc_header , cek , iv )
512517 return jwe .pack (parts = [jwe_enc_key , iv , ctxt , tag ])
513518
514- def decrypt (self , token , key ):
519+ def decrypt (self , token , key , cek = None ):
515520 """ Decrypts a JWT
516521
517522 :param token: The JWT
@@ -529,13 +534,16 @@ def decrypt(self, token, key):
529534 _decrypt = RSAEncrypter (self .with_digest ).decrypt
530535
531536 _alg = jwe .headers ["alg" ]
532- if _alg == "RSA-OAEP" :
537+ if cek :
538+ pass
539+ elif _alg == "RSA-OAEP" :
533540 cek = _decrypt (jek , key , 'pkcs1_oaep_padding' )
534541 elif _alg == "RSA1_5" :
535542 cek = _decrypt (jek , key )
536543 else :
537544 raise NotSupportedAlgorithm (_alg )
538545
546+ self ["cek" ] = cek
539547 enc = jwe .headers ["enc" ]
540548 try :
541549 assert enc in SUPPORTED ["enc" ]
@@ -687,7 +695,7 @@ def encrypt(self, key, iv="", cek="", **kwargs):
687695 return jwe .pack (parts = [kwargs ['encrypted_key' ], iv , ctxt , tag ])
688696 return jwe .pack (parts = [iv , ctxt , tag ])
689697
690- def decrypt (self , token = None , key = None ):
698+ def decrypt (self , token = None , key = None , ** kwargs ):
691699
692700 if not self .cek :
693701 raise Exception ("Content Encryption Key is Not Yet Set" )
@@ -747,7 +755,7 @@ def encrypt(self, keys=None, cek="", iv="", **kwargs):
747755 :return: Encrypted message
748756 """
749757
750- encrypted_key = cek = iv = None
758+ # encrypted_key = cek = iv = None
751759 _alg = self ["alg" ]
752760
753761 # Find Usable Keys
@@ -801,6 +809,7 @@ def encrypt(self, keys=None, cek="", iv="", **kwargs):
801809
802810 try :
803811 token = encrypter .encrypt (_key , ** kwargs )
812+ self ["cek" ] = encrypter .cek if 'cek' in encrypter else None
804813 except TypeError as err :
805814 raise err
806815 else :
@@ -811,7 +820,7 @@ def encrypt(self, keys=None, cek="", iv="", **kwargs):
811820 logger .error ("Could not find any suitable encryption key" )
812821 raise NoSuitableEncryptionKey ()
813822
814- def decrypt (self , token = None , keys = None , alg = None ):
823+ def decrypt (self , token = None , keys = None , alg = None , cek = None ):
815824 if token :
816825 jwe = JWEnc ().unpack (token )
817826 # header, ek, eiv, ctxt, tag = token.split(b".")
@@ -829,7 +838,7 @@ def decrypt(self, token=None, keys=None, alg=None):
829838 else :
830839 keys = self ._pick_keys (self ._get_keys (), use = "enc" , alg = _alg )
831840
832- if not keys :
841+ if not keys and not cek :
833842 raise NoSuitableDecryptionKey (_alg )
834843
835844 if _alg in ["RSA-OAEP" , "RSA1_5" ]:
@@ -847,10 +856,21 @@ def decrypt(self, token=None, keys=None, alg=None):
847856 else :
848857 raise NotSupportedAlgorithm
849858
859+ if cek :
860+ try :
861+ msg = decrypter .decrypt (as_bytes (token ), None , cek = cek )
862+ self ["cek" ] = decrypter .cek if 'cek' in decrypter else None
863+ except (KeyError , DecryptionFailed ):
864+ pass
865+ else :
866+ logger .debug ("Decrypted message using exiting CEK" )
867+ return msg
868+
850869 for key in keys :
851870 _key = key .encryption_key (alg = _alg , private = False )
852871 try :
853872 msg = decrypter .decrypt (as_bytes (token ), _key )
873+ self ["cek" ] = decrypter .cek if 'cek' in decrypter else None
854874 except (KeyError , DecryptionFailed ):
855875 pass
856876 else :
0 commit comments