Skip to content

Commit d32b887

Browse files
add hook to observe pending sessions (#751)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 57c0676 commit d32b887

File tree

2 files changed

+304
-7
lines changed

2 files changed

+304
-7
lines changed

jupyter_server/services/sessions/sessionmanager.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
# fallback on pysqlite2 if Python was build without sqlite
1111
from pysqlite2 import dbapi2 as sqlite3
1212

13+
from dataclasses import dataclass, fields
14+
from typing import Union
15+
1316
from tornado import web
1417
from traitlets import Instance, TraitError, Unicode, validate
1518
from traitlets.config.configurable import LoggingConfigurable
@@ -18,6 +21,132 @@
1821
from jupyter_server.utils import ensure_async
1922

2023

24+
class KernelSessionRecordConflict(Exception):
25+
"""Exception class to use when two KernelSessionRecords cannot
26+
merge because of conflicting data.
27+
"""
28+
29+
pass
30+
31+
32+
@dataclass
33+
class KernelSessionRecord:
34+
"""A record object for tracking a Jupyter Server Kernel Session.
35+
36+
Two records that share a session_id must also share a kernel_id, while
37+
kernels can have multiple session (and thereby) session_ids
38+
associated with them.
39+
"""
40+
41+
session_id: Union[None, str] = None
42+
kernel_id: Union[None, str] = None
43+
44+
def __eq__(self, other: "KernelSessionRecord") -> bool:
45+
if isinstance(other, KernelSessionRecord):
46+
condition1 = self.kernel_id and self.kernel_id == other.kernel_id
47+
condition2 = all(
48+
[
49+
self.session_id == other.session_id,
50+
self.kernel_id is None or other.kernel_id is None,
51+
]
52+
)
53+
if any([condition1, condition2]):
54+
return True
55+
# If two records share session_id but have different kernels, this is
56+
# and ill-posed expression. This should never be true. Raise an exception
57+
# to inform the user.
58+
if all(
59+
[
60+
self.session_id,
61+
self.session_id == other.session_id,
62+
self.kernel_id != other.kernel_id,
63+
]
64+
):
65+
raise KernelSessionRecordConflict(
66+
"A single session_id can only have one kernel_id "
67+
"associated with. These two KernelSessionRecords share the same "
68+
"session_id but have different kernel_ids. This should "
69+
"not be possible and is likely an issue with the session "
70+
"records."
71+
)
72+
return False
73+
74+
def update(self, other: "KernelSessionRecord") -> None:
75+
"""Updates in-place a kernel from other (only accepts positive updates"""
76+
if not isinstance(other, KernelSessionRecord):
77+
raise TypeError("'other' must be an instance of KernelSessionRecord.")
78+
79+
if other.kernel_id and self.kernel_id and other.kernel_id != self.kernel_id:
80+
raise KernelSessionRecordConflict(
81+
"Could not update the record from 'other' because the two records conflict."
82+
)
83+
84+
for field in fields(self):
85+
if hasattr(other, field.name) and getattr(other, field.name):
86+
setattr(self, field.name, getattr(other, field.name))
87+
88+
89+
class KernelSessionRecordList:
90+
"""An object for storing and managing a list of KernelSessionRecords.
91+
92+
When adding a record to the list, the KernelSessionRecordList
93+
first checks if the record already exists in the list. If it does,
94+
the record will be updated with the new information; otherwise,
95+
it will be appended.
96+
"""
97+
98+
def __init__(self, *records):
99+
self._records = []
100+
for record in records:
101+
self.update(record)
102+
103+
def __str__(self):
104+
return str(self._records)
105+
106+
def __contains__(self, record: Union[KernelSessionRecord, str]):
107+
"""Search for records by kernel_id and session_id"""
108+
if isinstance(record, KernelSessionRecord) and record in self._records:
109+
return True
110+
111+
if isinstance(record, str):
112+
for r in self._records:
113+
if record in [r.session_id, r.kernel_id]:
114+
return True
115+
return False
116+
117+
def __len__(self):
118+
return len(self._records)
119+
120+
def get(self, record: Union[KernelSessionRecord, str]) -> KernelSessionRecord:
121+
"""Return a full KernelSessionRecord from a session_id, kernel_id, or
122+
incomplete KernelSessionRecord.
123+
"""
124+
if isinstance(record, str):
125+
for r in self._records:
126+
if record == r.kernel_id or record == r.session_id:
127+
return r
128+
elif isinstance(record, KernelSessionRecord):
129+
for r in self._records:
130+
if record == r:
131+
return record
132+
raise ValueError(f"{record} not found in KernelSessionRecordList.")
133+
134+
def update(self, record: KernelSessionRecord) -> None:
135+
"""Update a record in-place or append it if not in the list."""
136+
try:
137+
idx = self._records.index(record)
138+
self._records[idx].update(record)
139+
except ValueError:
140+
self._records.append(record)
141+
142+
def remove(self, record: KernelSessionRecord) -> None:
143+
"""Remove a record if its found in the list. If it's not found,
144+
do nothing.
145+
"""
146+
if record in self._records:
147+
self._records.remove(record)
148+
149+
21150
class SessionManager(LoggingConfigurable):
22151

23152
database_filepath = Unicode(
@@ -58,6 +187,10 @@ def _validate_database_filepath(self, proposal):
58187
]
59188
)
60189

190+
def __init__(self, *args, **kwargs):
191+
super().__init__(*args, **kwargs)
192+
self._pending_sessions = KernelSessionRecordList()
193+
61194
# Session database initialized below
62195
_cursor = None
63196
_connection = None
@@ -118,15 +251,20 @@ async def create_session(
118251
):
119252
"""Creates a session and returns its model"""
120253
session_id = self.new_session_id()
254+
record = KernelSessionRecord(session_id=session_id)
255+
self._pending_sessions.update(record)
121256
if kernel_id is not None and kernel_id in self.kernel_manager:
122257
pass
123258
else:
124259
kernel_id = await self.start_kernel_for_session(
125260
session_id, path, name, type, kernel_name
126261
)
262+
record.kernel_id = kernel_id
263+
self._pending_sessions.update(record)
127264
result = await self.save_session(
128265
session_id, path=path, name=name, type=type, kernel_id=kernel_id
129266
)
267+
self._pending_sessions.remove(record)
130268
return result
131269

132270
async def start_kernel_for_session(self, session_id, path, name, type, kernel_name):
@@ -305,6 +443,9 @@ async def list_sessions(self):
305443

306444
async def delete_session(self, session_id):
307445
"""Deletes the row in the session database with given session_id"""
446+
record = KernelSessionRecord(session_id=session_id)
447+
self._pending_sessions.update(record)
308448
session = await self.get_session(session_id=session_id)
309449
await ensure_async(self.kernel_manager.shutdown_kernel(session["kernel"]["id"]))
310450
self.cursor.execute("DELETE FROM session WHERE session_id=?", (session_id,))
451+
self._pending_sessions.remove(record)

0 commit comments

Comments
 (0)