Skip to content

Commit 9d52aa5

Browse files
committed
Fix: Additional fixes for device/login flow session lifecycle coherency issues.
Updated unit tests accordingly.
1 parent 4e4b739 commit 9d52aa5

File tree

5 files changed

+76
-107
lines changed

5 files changed

+76
-107
lines changed

credenza/api/session/storage/session_store.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,9 @@ def create_session(self,
149149
session_json = self.crypto_codec.encrypt(session_json)
150150

151151
session_key = access_token if use_access_token_as_session_key else self.generate_session_key()
152-
self.map_session(session_key, session_id)
153-
self.backend.setex(self._key(session_id), session_json, int(expires_at - now))
152+
ttl = int(expires_at - now)
153+
self.map_session(session_key, session_id, ttl)
154+
self.backend.setex(self._key(session_id), session_json, ttl)
154155

155156
logger.debug(f"Created session {session_id} (realm={realm})")
156157
return session_key, session_data
@@ -178,18 +179,21 @@ def get_session_by_session_key(self, session_key) -> (str or None, SessionData o
178179
return session_id, session
179180

180181
def update_session(self, session_id, session_data: SessionData):
181-
session_data.updated_at = time.time()
182-
session_data.expires_at = session_data.updated_at + self.ttl
182+
now = time.time()
183+
session_data.updated_at = now
184+
if session_data.expires_at < (now + self.ttl):
185+
session_data.expires_at = (now + self.ttl)
186+
ttl = int(session_data.expires_at - now)
183187
session_key = self.get_session_key_for_session_id(session_id)
184188
if self.crypto_codec:
185-
session_data = self.crypto_codec.encrypt(json.dumps(session_data.to_dict(), separators=(",", ":")))
189+
session_json = self.crypto_codec.encrypt(json.dumps(session_data.to_dict(), separators=(",", ":")))
186190
else:
187-
session_data = json.dumps(session_data.to_dict(), separators=(",", ":"))
188-
189-
self.map_session(session_key, session_id)
190-
self.backend.setex(self._key(session_id), session_data, self.ttl)
191+
session_json = json.dumps(session_data.to_dict(), separators=(",", ":"))
192+
self.map_session(session_key, session_id, ttl)
193+
self.backend.setex(self._key(session_id), session_json, ttl)
191194

192195
logger.debug(f"Updated session {session_id}")
196+
return session_key, session_data
193197

194198
def delete_session(self, session_id):
195199
session = self.get_session_data(session_id)
@@ -233,9 +237,6 @@ def tag_session_metadata(self, session_id: str, metadata: dict, scope: str = "sy
233237
target = getattr(session.session_metadata, scope)
234238
target.update(metadata)
235239

236-
# Update timestamp
237-
session.updated_at = time.time()
238-
239240
# Save updated session back as dict
240241
self.update_session(session_id, session)
241242
logger.debug(f"Tagged session {session_id} metadata[{scope}]: {metadata}")

credenza/refresh/refresh_worker.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def run_refresh_worker(app):
4646
sub = session.userinfo.get("sub")
4747
sys_metadata = session.session_metadata.system or {}
4848
refresh_expires_at = sys_metadata.get("refresh_expires_at")
49-
if refresh_expires_at and now > refresh_expires_at:
49+
if refresh_expires_at and now > (refresh_expires_at - session_expiry_threshold):
5050
audit_event("refresh_expired", session_id=sid)
5151
revoke_tokens(sid, session)
5252
store.delete_session(sid)
@@ -68,14 +68,8 @@ def run_refresh_worker(app):
6868
if allow_auto_refresh:
6969
modified = bool(refresh_additional_tokens(sid, session)) or modified
7070

71-
# Just bump session TTL if allowed and it is about to expire and hasn't otherwise been modified
72-
if not modified and allow_auto_refresh:
73-
ttl = store.get_ttl(sid)
74-
if 0 < ttl < session_expiry_threshold:
75-
modified = True
76-
7771
if modified:
7872
store.update_session(sid, session)
79-
audit_event("device_session_extended", session_id=sid, user=user, sub=sub, realm=realm)
73+
audit_event("device_session_updated", session_id=sid, user=user, sub=sub, realm=realm)
8074

8175
time.sleep(interval)

credenza/rest/session.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,11 @@ def get_session():
5050
realm = session.realm
5151

5252
if request.method == "PUT":
53-
# Extend session lifetime
54-
session.updated_at = now
55-
session.expires_at = now + store.ttl
56-
5753
# Enforce max refreshable lifetime
54+
session_expiry_threshold = current_app.config.get("SESSION_EXPIRY_THRESHOLD", 300)
5855
refresh_expires_at = session.session_metadata.system.get("refresh_expires_at")
59-
if refresh_expires_at and now > refresh_expires_at:
56+
if refresh_expires_at and now > (refresh_expires_at - session_expiry_threshold):
57+
revoke_tokens(sid, session)
6058
store.delete_session(sid)
6159
audit_event("refresh_expired", session_id=sid)
6260
abort(401, "Session has expired and can no longer be refreshed")
@@ -70,8 +68,13 @@ def get_session():
7068
provider = get_augmentation_provider(realm)
7169
provider.enrich_userinfo(session.userinfo, session.additional_tokens)
7270

73-
store.update_session(sid, session)
74-
audit_event("session_extended", session_id=sid, user=user, sub=sub, realm=realm)
71+
skey, session_data = store.update_session(sid, session)
72+
audit_event("session_updated",
73+
session_id=sid,
74+
user=user,
75+
sub=sub,
76+
realm=realm,
77+
expires_at=datetime.fromtimestamp(session_data.expires_at, timezone.utc).isoformat())
7578

7679
response = make_session_response(sid, session)
7780
return make_json_response(response)

test/refresh/test_refresh_worker.py

Lines changed: 4 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,9 @@ def refresh_access_token(self, refresh_token):
175175
run_refresh_worker(app)
176176

177177
# Verify audit events
178-
# - one session refresh success
178+
# - one session update success
179179
assert any(
180-
ev == "device_session_extended" and kw.get("session_id") == sid
180+
ev == "device_session_updated" and kw.get("session_id") == sid
181181
for ev, kw in audit_calls
182182
), f"Missing success event in {audit_calls}"
183183

@@ -209,52 +209,6 @@ def refresh_access_token(self, refresh_token):
209209
# The "fail" block should have been removed entirely
210210
assert "fail" not in new_sess.additional_tokens
211211

212-
def test_device_session_refresh(app, store, base_session, factory, profiles, audit_calls, frozen_time, monkeypatch):
213-
sid = "S3"
214-
now = int(frozen_time)
215-
app.config["OIDC_CLIENT_FACTORY"] = factory
216-
app.config["OIDC_IDP_PROFILES"] = profiles
217-
218-
# Prepare a device session that has “expired” to force the extension path
219-
sess = copy.deepcopy(base_session)
220-
sess.session_metadata.system = {
221-
"device_session": True,
222-
"allow_automatic_refresh": True,
223-
}
224-
sess.expires_at = now - 1
225-
226-
# Drive exactly one session with TTL below token_expiry_threshold
227-
monkeypatch.setattr(store, "list_session_ids", lambda: [sid])
228-
monkeypatch.setattr(store, "get_session_data", lambda s: sess)
229-
monkeypatch.setattr(store, "get_ttl", lambda s: 10)
230-
231-
# Stub out update_session to mirror real behavior
232-
updated = []
233-
def fake_update_session(session_id, session_data):
234-
# frozen_time == time.time() in this test
235-
session_data.updated_at = frozen_time
236-
session_data.expires_at = session_data.updated_at + store.ttl
237-
updated.append((session_id, session_data))
238-
239-
monkeypatch.setattr(store, "update_session", fake_update_session)
240-
241-
# Run one iteration (should raise StopIteration at loop end)
242-
with app.app_context():
243-
with pytest.raises(StopIteration):
244-
run_refresh_worker(app)
245-
246-
# We expect exactly one session-extension audit event
247-
assert any(ev == "device_session_extended" for ev, _ in audit_calls), \
248-
f"got audit events {audit_calls}"
249-
250-
# update_session should have been called once for our SID
251-
assert len(updated) == 1
252-
called_sid, new_sess = updated[0]
253-
assert called_sid == sid
254-
255-
# expires_at must equal frozen_time + store.ttl
256-
assert new_sess.expires_at == now + store.ttl
257-
258212
def test_device_access_token_refresh(app,
259213
store,
260214
base_session,
@@ -321,8 +275,8 @@ def refresh_access_token(self, refresh_token):
321275
events = [ev for ev, _ in audit_calls]
322276
assert "access_token_refreshed" in events
323277

324-
# And the worker should then have extended the session
325-
assert "device_session_extended" in events
278+
# And the worker should then have emitted the device session updated event
279+
assert "device_session_updated" in events
326280

327281
# We must have called update_session exactly once
328282
assert len(updated) == 1

test/rest/test_session.py

Lines changed: 47 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -81,19 +81,29 @@ def test_get_session_invalid_bearer(monkeypatch, client):
8181
assert resp.status_code == 404
8282

8383
def test_put_session_extend(client, store, frozen_time):
84-
sid, session = sm.get_current_session()
84+
# Capture original values before the PUT
85+
before = store.get_session_data("fake_current_sid")
86+
old_expires_at = before.expires_at
87+
8588
resp = client.put("/session")
8689
assert resp.status_code == 200
87-
updated = store.get_session_data(sid)
88-
# Expires_at should have been set to frozen_time + ttl
89-
assert updated.expires_at == frozen_time + store.ttl
90-
# updated_at should match frozen_time
91-
assert updated.updated_at == frozen_time
9290

93-
def test_put_session_expired(client):
94-
# simulate that refresh_expires_at is in the past
91+
after = store.get_session_data("fake_current_sid")
92+
93+
# update_session sets updated_at to time.time() (frozen_time in tests)
94+
assert after.updated_at == pytest.approx(frozen_time, abs=1)
95+
96+
# update_session sets expires_at = max(current_expires_at, now + ttl)
97+
expected_expires = max(old_expires_at, frozen_time + store.ttl)
98+
assert after.expires_at == pytest.approx(expected_expires, abs=1)
99+
100+
101+
def test_put_session_expired(client, app, monkeypatch):
95102
sid, sess = sm.get_current_session()
96-
sess.session_metadata.system["refresh_expires_at"] = int(time.time()) - 1
103+
sess.session_metadata.system["refresh_expires_at"] = int(time.time()) + 60
104+
monkeypatch.setattr(sm, "revoke_tokens", lambda sid, session: None)
105+
with app.app_context():
106+
app.config["SESSION_EXPIRY_THRESHOLD"] = 300
97107
resp = client.put("/session")
98108
assert resp.status_code == 401
99109

@@ -118,11 +128,14 @@ def test_put_session_with_refresh_access_token(client,
118128
sess.access_token = "old_access_token"
119129
sess.id_token = "old_id_token"
120130

131+
# Record current expires_at so we can assert the max() behavior
132+
original_expires_at = sess.expires_at
133+
121134
# Stub get_current_session -> (sid, sess)
122135
monkeypatch.setattr(sm,"get_current_session", lambda: (sid, sess))
123136

124137
# Stub out map_session
125-
monkeypatch.setattr(store, "map_session", lambda session_key, session_id: "dummy-map-key")
138+
monkeypatch.setattr(store, "map_session", lambda session_key, session_id, ttl: "dummy-map-key")
126139

127140
# Configure our dummy OIDC client factory in Flask config
128141
class DummyClient:
@@ -154,7 +167,7 @@ def get_client(self, realm, native_client=False):
154167

155168
# Audit events
156169
assert any(ev == "access_token_refreshed" for ev, _ in audit_calls), audit_calls
157-
assert any(ev == "session_extended" for ev, _ in audit_calls), audit_calls
170+
assert any(ev == "session_updated" for ev, _ in audit_calls), audit_calls
158171

159172
# Fetch the persisted session from the store
160173
updated = store.get_session_data(sid)
@@ -169,8 +182,10 @@ def get_client(self, realm, native_client=False):
169182
assert meta["token_expires_at"] == now + 3600
170183
assert meta["refresh_expires_at"] == now + 7200
171184

172-
# Finally, ensure update_session bumped the session TTL
173-
assert updated.expires_at == pytest.approx(now + store.ttl, abs=1)
185+
# update_session should have bumped updated_at and applied max() for expires_at
186+
assert updated.updated_at == pytest.approx(now, abs=1)
187+
expected_expires = max(original_expires_at, now + store.ttl)
188+
assert updated.expires_at == pytest.approx(expected_expires, abs=1)
174189

175190

176191
def test_put_session_with_refresh_access_token_failure(client,
@@ -193,11 +208,13 @@ def test_put_session_with_refresh_access_token_failure(client,
193208
sess.access_token = "old_access_token"
194209
sess.id_token = "old_id_token"
195210

211+
original_expires_at = sess.expires_at
212+
196213
# Patch get_current_session
197214
monkeypatch.setattr(sm, "get_current_session", lambda: (sid, sess))
198215

199216
# Patch map_session
200-
monkeypatch.setattr(store, "map_session", lambda session_key, session_id: "dummy-map-key")
217+
monkeypatch.setattr(store, "map_session", lambda session_key, session_id, ttl: "dummy-map-key")
201218

202219
# Configure dummy client that fails
203220
class DummyFailClient:
@@ -219,7 +236,7 @@ def get_client(self, realm, native_client=False):
219236

220237
# Check audit includes the failure
221238
assert any(ev == "access_token_refresh_failed" for ev, _ in audit_calls), audit_calls
222-
assert any(ev == "session_extended" for ev, _ in audit_calls), audit_calls
239+
assert any(ev == "session_updated" for ev, _ in audit_calls), audit_calls
223240

224241
# Fetch stored session
225242
updated = store.get_session_data(sid)
@@ -234,8 +251,10 @@ def get_client(self, realm, native_client=False):
234251
assert meta["token_expires_at"] == now - 1
235252
assert meta["refresh_expires_at"] == now + 600
236253

237-
# But TTL was still bumped by update_session
238-
assert updated.expires_at == pytest.approx(now + store.ttl, abs=1)
254+
# update_session should have bumped updated_at and applied max() for expires_at
255+
assert updated.updated_at == pytest.approx(now, abs=1)
256+
expected_expires = max(original_expires_at, now + store.ttl)
257+
assert updated.expires_at == pytest.approx(expected_expires, abs=1)
239258

240259

241260
def test_put_session_additional_tokens_refresh(client,
@@ -250,9 +269,6 @@ def test_put_session_additional_tokens_refresh(client,
250269
threshold = 500
251270

252271
# Prepare a session with four additional token blocks:
253-
# - "good" and "fail" are expired by > threshold
254-
# - "not_due" is far in the future
255-
# - "no_rt" has no refresh_token
256272
sess = copy.deepcopy(base_session)
257273
sess.additional_tokens = {
258274
"good": {"refresh_token": "rt_good", "expires_at": now - threshold - 1},
@@ -265,14 +281,16 @@ def test_put_session_additional_tokens_refresh(client,
265281
sess.expires_at = now + 10000
266282

267283
# Stub get_current_session so GET/PUT /session uses our SID and session
268-
monkeypatch.setattr(sm,"get_current_session", lambda: (sid, sess))
284+
monkeypatch.setattr(sm, "get_current_session", lambda: (sid, sess))
269285

270-
# Stub out map_session
271-
monkeypatch.setattr(store, "map_session", lambda session_key, session_id: "dummy-map-key")
286+
monkeypatch.setattr(store, "map_session", lambda session_key, session_id, ttl: "dummy-map-key")
272287

273-
# Capture calls to update_session
274-
updated = []
275-
monkeypatch.setattr(store, "update_session", lambda s, sd: updated.append((s, sd)))
288+
# Capture calls to update_session, and make it return (session_key, session_data)
289+
updated_calls = []
290+
def _update_session(s, sd):
291+
updated_calls.append((s, sd))
292+
return ("dummy-map-key", sd)
293+
monkeypatch.setattr(store, "update_session", _update_session)
276294

277295
# Provide a DummyClient that succeeds for "rt_good" and raises for "rt_fail"
278296
class DummyClient:
@@ -297,7 +315,6 @@ def get_client(self, realm, native_client=False):
297315

298316
with app.app_context():
299317
app.config["TOKEN_EXPIRY_THRESHOLD"] = 300
300-
# Inject our dummy factory
301318
app.config["OIDC_CLIENT_FACTORY"] = DummyFactory()
302319

303320
# Perform the PUT /session
@@ -310,8 +327,8 @@ def get_client(self, realm, native_client=False):
310327
assert "additional_token_refresh_failed" in events, audit_calls
311328

312329
# Only the "good" path should have resulted in update_session
313-
assert len(updated) == 1
314-
called_sid, new_sess = updated[0]
330+
assert len(updated_calls) == 1
331+
called_sid, new_sess = updated_calls[0]
315332
assert called_sid == sid
316333

317334
# Verify the "good" block was updated correctly
@@ -434,4 +451,4 @@ def test_make_session_response_legacy(app, store, base_session):
434451
assert abs(since_dt.timestamp() - base_session.created_at) < 1
435452
assert abs(expires_dt.timestamp() - base_session.expires_at) < 1
436453

437-
assert resp["seconds_remaining"] == store.get_ttl("sid123")
454+
assert resp["seconds_remaining"] == store.get_ttl("sid123")

0 commit comments

Comments
 (0)