|
10 | 10 | # fallback on pysqlite2 if Python was build without sqlite
|
11 | 11 | from pysqlite2 import dbapi2 as sqlite3
|
12 | 12 |
|
| 13 | +from dataclasses import dataclass, fields |
| 14 | +from typing import Union |
| 15 | + |
13 | 16 | from tornado import web
|
14 | 17 | from traitlets import Instance, TraitError, Unicode, validate
|
15 | 18 | from traitlets.config.configurable import LoggingConfigurable
|
|
18 | 21 | from jupyter_server.utils import ensure_async
|
19 | 22 |
|
20 | 23 |
|
| 24 | +class KernelSessionRecordConflict(Exception): |
| 25 | + """Exception class to use when two KernelSessionRecords cannot |
| 26 | + merge because of conflicting data. |
| 27 | + """ |
| 28 | + |
| 29 | + pass |
| 30 | + |
| 31 | + |
| 32 | +@dataclass |
| 33 | +class KernelSessionRecord: |
| 34 | + """A record object for tracking a Jupyter Server Kernel Session. |
| 35 | +
|
| 36 | + Two records that share a session_id must also share a kernel_id, while |
| 37 | + kernels can have multiple session (and thereby) session_ids |
| 38 | + associated with them. |
| 39 | + """ |
| 40 | + |
| 41 | + session_id: Union[None, str] = None |
| 42 | + kernel_id: Union[None, str] = None |
| 43 | + |
| 44 | + def __eq__(self, other: "KernelSessionRecord") -> bool: |
| 45 | + if isinstance(other, KernelSessionRecord): |
| 46 | + condition1 = self.kernel_id and self.kernel_id == other.kernel_id |
| 47 | + condition2 = all( |
| 48 | + [ |
| 49 | + self.session_id == other.session_id, |
| 50 | + self.kernel_id is None or other.kernel_id is None, |
| 51 | + ] |
| 52 | + ) |
| 53 | + if any([condition1, condition2]): |
| 54 | + return True |
| 55 | + # If two records share session_id but have different kernels, this is |
| 56 | + # and ill-posed expression. This should never be true. Raise an exception |
| 57 | + # to inform the user. |
| 58 | + if all( |
| 59 | + [ |
| 60 | + self.session_id, |
| 61 | + self.session_id == other.session_id, |
| 62 | + self.kernel_id != other.kernel_id, |
| 63 | + ] |
| 64 | + ): |
| 65 | + raise KernelSessionRecordConflict( |
| 66 | + "A single session_id can only have one kernel_id " |
| 67 | + "associated with. These two KernelSessionRecords share the same " |
| 68 | + "session_id but have different kernel_ids. This should " |
| 69 | + "not be possible and is likely an issue with the session " |
| 70 | + "records." |
| 71 | + ) |
| 72 | + return False |
| 73 | + |
| 74 | + def update(self, other: "KernelSessionRecord") -> None: |
| 75 | + """Updates in-place a kernel from other (only accepts positive updates""" |
| 76 | + if not isinstance(other, KernelSessionRecord): |
| 77 | + raise TypeError("'other' must be an instance of KernelSessionRecord.") |
| 78 | + |
| 79 | + if other.kernel_id and self.kernel_id and other.kernel_id != self.kernel_id: |
| 80 | + raise KernelSessionRecordConflict( |
| 81 | + "Could not update the record from 'other' because the two records conflict." |
| 82 | + ) |
| 83 | + |
| 84 | + for field in fields(self): |
| 85 | + if hasattr(other, field.name) and getattr(other, field.name): |
| 86 | + setattr(self, field.name, getattr(other, field.name)) |
| 87 | + |
| 88 | + |
| 89 | +class KernelSessionRecordList: |
| 90 | + """An object for storing and managing a list of KernelSessionRecords. |
| 91 | +
|
| 92 | + When adding a record to the list, the KernelSessionRecordList |
| 93 | + first checks if the record already exists in the list. If it does, |
| 94 | + the record will be updated with the new information; otherwise, |
| 95 | + it will be appended. |
| 96 | + """ |
| 97 | + |
| 98 | + def __init__(self, *records): |
| 99 | + self._records = [] |
| 100 | + for record in records: |
| 101 | + self.update(record) |
| 102 | + |
| 103 | + def __str__(self): |
| 104 | + return str(self._records) |
| 105 | + |
| 106 | + def __contains__(self, record: Union[KernelSessionRecord, str]): |
| 107 | + """Search for records by kernel_id and session_id""" |
| 108 | + if isinstance(record, KernelSessionRecord) and record in self._records: |
| 109 | + return True |
| 110 | + |
| 111 | + if isinstance(record, str): |
| 112 | + for r in self._records: |
| 113 | + if record in [r.session_id, r.kernel_id]: |
| 114 | + return True |
| 115 | + return False |
| 116 | + |
| 117 | + def __len__(self): |
| 118 | + return len(self._records) |
| 119 | + |
| 120 | + def get(self, record: Union[KernelSessionRecord, str]) -> KernelSessionRecord: |
| 121 | + """Return a full KernelSessionRecord from a session_id, kernel_id, or |
| 122 | + incomplete KernelSessionRecord. |
| 123 | + """ |
| 124 | + if isinstance(record, str): |
| 125 | + for r in self._records: |
| 126 | + if record == r.kernel_id or record == r.session_id: |
| 127 | + return r |
| 128 | + elif isinstance(record, KernelSessionRecord): |
| 129 | + for r in self._records: |
| 130 | + if record == r: |
| 131 | + return record |
| 132 | + raise ValueError(f"{record} not found in KernelSessionRecordList.") |
| 133 | + |
| 134 | + def update(self, record: KernelSessionRecord) -> None: |
| 135 | + """Update a record in-place or append it if not in the list.""" |
| 136 | + try: |
| 137 | + idx = self._records.index(record) |
| 138 | + self._records[idx].update(record) |
| 139 | + except ValueError: |
| 140 | + self._records.append(record) |
| 141 | + |
| 142 | + def remove(self, record: KernelSessionRecord) -> None: |
| 143 | + """Remove a record if its found in the list. If it's not found, |
| 144 | + do nothing. |
| 145 | + """ |
| 146 | + if record in self._records: |
| 147 | + self._records.remove(record) |
| 148 | + |
| 149 | + |
21 | 150 | class SessionManager(LoggingConfigurable):
|
22 | 151 |
|
23 | 152 | database_filepath = Unicode(
|
@@ -58,6 +187,10 @@ def _validate_database_filepath(self, proposal):
|
58 | 187 | ]
|
59 | 188 | )
|
60 | 189 |
|
| 190 | + def __init__(self, *args, **kwargs): |
| 191 | + super().__init__(*args, **kwargs) |
| 192 | + self._pending_sessions = KernelSessionRecordList() |
| 193 | + |
61 | 194 | # Session database initialized below
|
62 | 195 | _cursor = None
|
63 | 196 | _connection = None
|
@@ -118,15 +251,20 @@ async def create_session(
|
118 | 251 | ):
|
119 | 252 | """Creates a session and returns its model"""
|
120 | 253 | session_id = self.new_session_id()
|
| 254 | + record = KernelSessionRecord(session_id=session_id) |
| 255 | + self._pending_sessions.update(record) |
121 | 256 | if kernel_id is not None and kernel_id in self.kernel_manager:
|
122 | 257 | pass
|
123 | 258 | else:
|
124 | 259 | kernel_id = await self.start_kernel_for_session(
|
125 | 260 | session_id, path, name, type, kernel_name
|
126 | 261 | )
|
| 262 | + record.kernel_id = kernel_id |
| 263 | + self._pending_sessions.update(record) |
127 | 264 | result = await self.save_session(
|
128 | 265 | session_id, path=path, name=name, type=type, kernel_id=kernel_id
|
129 | 266 | )
|
| 267 | + self._pending_sessions.remove(record) |
130 | 268 | return result
|
131 | 269 |
|
132 | 270 | async def start_kernel_for_session(self, session_id, path, name, type, kernel_name):
|
@@ -305,6 +443,9 @@ async def list_sessions(self):
|
305 | 443 |
|
306 | 444 | async def delete_session(self, session_id):
|
307 | 445 | """Deletes the row in the session database with given session_id"""
|
| 446 | + record = KernelSessionRecord(session_id=session_id) |
| 447 | + self._pending_sessions.update(record) |
308 | 448 | session = await self.get_session(session_id=session_id)
|
309 | 449 | await ensure_async(self.kernel_manager.shutdown_kernel(session["kernel"]["id"]))
|
310 | 450 | self.cursor.execute("DELETE FROM session WHERE session_id=?", (session_id,))
|
| 451 | + self._pending_sessions.remove(record) |
0 commit comments