Skip to content

Commit 759bd39

Browse files
janste63c00kiemon5ter
authored andcommitted
Fix 2 inactive key problems.
Fix that updated keys are marked as inactive if we get 304 from server. Make sure we return inactive keys when calling get_jwt_verify_keys().
1 parent 2b00abb commit 759bd39

File tree

5 files changed

+28
-16
lines changed

5 files changed

+28
-16
lines changed

src/cryptojwt/key_bundle.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ def do_remote(self):
364364
"""
365365
Load a JWKS from a webpage.
366366
367-
:return: True or False if load was successful
367+
:return: True if load was successful or False if remote hasn't been modified
368368
"""
369369
# if self.verify_ssl is not None:
370370
# self.httpc_params["verify"] = self.verify_ssl
@@ -408,10 +408,12 @@ def do_remote(self):
408408
if hasattr(_http_resp, "headers"):
409409
headers = getattr(_http_resp, "headers")
410410
self.last_remote = headers.get("last-modified") or headers.get("date")
411+
res = True
411412

412413
elif _http_resp.status_code == 304: # Not modified
413414
LOGGER.debug("%s not modified since %s", self.source, self.last_remote)
414415
self.time_out = time.time() + self.cache_time
416+
res = False
415417

416418
else:
417419
LOGGER.warning(
@@ -424,7 +426,7 @@ def do_remote(self):
424426

425427
self.last_updated = time.time()
426428
self.ignore_errors_until = None
427-
return True
429+
return res
428430

429431
def _parse_remote_response(self, response):
430432
"""
@@ -465,7 +467,6 @@ def update(self):
465467
This is a forced update, will happen even if cache time has not elapsed.
466468
Replaced keys will be marked as inactive and not removed.
467469
"""
468-
res = True # An update was successful
469470
if self.source:
470471
_old_keys = self._keys # just in case
471472

@@ -478,21 +479,25 @@ def update(self):
478479
self.do_local_jwk(self.source)
479480
elif self.fileformat == "der":
480481
self.do_local_der(self.source, self.keytype, self.keyusage)
482+
updated = True
481483
elif self.remote:
482-
res = self.do_remote()
484+
updated = self.do_remote()
483485
except Exception as err:
484486
LOGGER.error("Key bundle update failed: %s", err)
485487
self._keys = _old_keys # restore
486488
return False
487489

488-
now = time.time()
489-
for _key in _old_keys:
490-
if _key not in self._keys:
491-
if not _key.inactive_since: # If already marked don't mess
492-
_key.inactive_since = now
493-
self._keys.append(_key)
490+
if updated:
491+
now = time.time()
492+
for _key in _old_keys:
493+
if _key not in self._keys:
494+
if not _key.inactive_since: # If already marked don't mess
495+
_key.inactive_since = now
496+
self._keys.append(_key)
497+
else:
498+
self._keys = _old_keys
494499

495-
return res
500+
return True
496501

497502
def get(self, typ="", only_active=True):
498503
"""

src/cryptojwt/key_issuer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def get(self, key_use, key_type="", kid=None, alg="", **kwargs):
297297
else:
298298
_bkeys = bundle.keys()
299299
for key in _bkeys:
300-
if key.inactive_since and key_use != "sig":
300+
if key.inactive_since and key_use != "ver":
301301
# Skip inactive keys unless for signature verification
302302
continue
303303
if not key.use or use == key.use:

src/cryptojwt/key_jar.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -601,13 +601,13 @@ def get_jwt_verify_keys(self, jwt, **kwargs):
601601
except KeyError:
602602
pass
603603

604-
keys = self._add_key([], _iss, "sig", _key_type, _kid, nki, allow_missing_kid)
604+
keys = self._add_key([], _iss, "ver", _key_type, _kid, nki, allow_missing_kid)
605605

606606
if _key_type == "oct":
607-
keys.extend(self.get(key_use="sig", issuer_id="", key_type=_key_type))
607+
keys.extend(self.get(key_use="ver", issuer_id="", key_type=_key_type))
608608
else:
609609
# No issuer, just use all keys I have
610-
keys = self.get(key_use="sig", issuer_id="", key_type=_key_type)
610+
keys = self.get(key_use="ver", issuer_id="", key_type=_key_type)
611611

612612
# Only want the appropriate keys.
613613
keys = [k for k in keys if k.appropriate_for("verify")]

tests/test_03_key_bundle.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1009,7 +1009,7 @@ def test_remote_not_modified():
10091009

10101010
with responses.RequestsMock() as rsps:
10111011
rsps.add(method="GET", url=source, status=304, headers=headers)
1012-
assert kb.do_remote()
1012+
assert not kb.do_remote()
10131013
assert kb.last_remote == headers.get("Last-Modified")
10141014
timeout2 = kb.time_out
10151015

@@ -1019,6 +1019,7 @@ def test_remote_not_modified():
10191019
kb2 = KeyBundle().load(exp)
10201020
assert kb2.source == source
10211021
assert len(kb2.keys()) == 3
1022+
assert len(kb2.active_keys()) == 3
10221023
assert len(kb2.get("rsa")) == 1
10231024
assert len(kb2.get("oct")) == 1
10241025
assert len(kb2.get("ec")) == 1

tests/test_04_key_jar.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,12 @@ def test_aud(self):
746746
keys = self.bob_keyjar.get_jwt_verify_keys(_jwt.jwt, no_kid_issuer=no_kid_issuer)
747747
assert len(keys) == 1
748748

749+
def test_inactive_verify_key(self):
750+
_jwt = factory(self.sjwt_b)
751+
self.alice_keyjar.return_issuer("Bob")[0].mark_all_as_inactive()
752+
keys = self.alice_keyjar.get_jwt_verify_keys(_jwt.jwt)
753+
assert len(keys) == 1
754+
749755

750756
def test_copy():
751757
kj = KeyJar()

0 commit comments

Comments
 (0)