|
7 | 7 | # Copyright (c) Jupyter Development Team.
|
8 | 8 | # Distributed under the terms of the Modified BSD License.
|
9 | 9 |
|
| 10 | +from collections import defaultdict |
| 11 | +from functools import partial |
10 | 12 | import os
|
11 | 13 |
|
12 | 14 | from tornado import gen, web
|
|
15 | 17 |
|
16 | 18 | from jupyter_client.session import Session
|
17 | 19 | from jupyter_client.multikernelmanager import MultiKernelManager
|
18 |
| -from traitlets import Bool, Dict, List, Unicode, TraitError, Integer, default, validate |
| 20 | +from traitlets import Any, Bool, Dict, List, Unicode, TraitError, Integer, default, validate |
19 | 21 |
|
20 | 22 | from notebook.utils import to_os_path, exists
|
21 | 23 | from notebook._tz import utcnow, isoformat
|
22 | 24 | from ipython_genutils.py3compat import getcwd
|
23 | 25 |
|
24 |
| -from datetime import datetime, timedelta |
| 26 | +from datetime import timedelta |
25 | 27 |
|
26 | 28 |
|
27 | 29 | class MappingKernelManager(MultiKernelManager):
|
@@ -81,6 +83,11 @@ def _update_root_dir(self, proposal):
|
81 | 83 | Only effective if cull_idle_timeout is not 0."""
|
82 | 84 | )
|
83 | 85 |
|
| 86 | + _kernel_buffers = Any() |
| 87 | + @default('_kernel_buffers') |
| 88 | + def _default_kernel_buffers(self): |
| 89 | + return defaultdict(lambda: {'buffer': [], 'session_key': '', 'channels': {}}) |
| 90 | + |
84 | 91 | #-------------------------------------------------------------------------
|
85 | 92 | # Methods for managing kernels and sessions
|
86 | 93 | #-------------------------------------------------------------------------
|
@@ -142,10 +149,97 @@ def start_kernel(self, kernel_id=None, path=None, **kwargs):
|
142 | 149 | # py2-compat
|
143 | 150 | raise gen.Return(kernel_id)
|
144 | 151 |
|
| 152 | + def start_buffering(self, kernel_id, session_key, channels): |
| 153 | + """Start buffering messages for a kernel |
| 154 | +
|
| 155 | + Parameters |
| 156 | + ---------- |
| 157 | + kernel_id : str |
| 158 | + The id of the kernel to stop buffering. |
| 159 | + session_key: str |
| 160 | + The session_key, if any, that should get the buffer. |
| 161 | + If the session_key matches the current buffered session_key, |
| 162 | + the buffer will be returned. |
| 163 | + channels: dict({'channel': ZMQStream}) |
| 164 | + The zmq channels whose messages should be buffered. |
| 165 | + """ |
| 166 | + self.log.info("Starting buffering for %s", session_key) |
| 167 | + self._check_kernel_id(kernel_id) |
| 168 | + # clear previous buffering state |
| 169 | + self.stop_buffering(kernel_id) |
| 170 | + buffer_info = self._kernel_buffers[kernel_id] |
| 171 | + # record the session key because only one session can buffer |
| 172 | + buffer_info['session_key'] = session_key |
| 173 | + # TODO: the buffer should likely be a memory bounded queue, we're starting with a list to keep it simple |
| 174 | + buffer_info['buffer'] = [] |
| 175 | + buffer_info['channels'] = channels |
| 176 | + |
| 177 | + # forward any future messages to the internal buffer |
| 178 | + def buffer_msg(channel, msg_parts): |
| 179 | + self.log.debug("Buffering msg on %s:%s", kernel_id, channel) |
| 180 | + buffer_info['buffer'].append((channel, msg_parts)) |
| 181 | + |
| 182 | + for channel, stream in channels.items(): |
| 183 | + stream.on_recv(partial(buffer_msg, channel)) |
| 184 | + |
| 185 | + |
| 186 | + def get_buffer(self, kernel_id, session_key): |
| 187 | + """Get the buffer for a given kernel |
| 188 | +
|
| 189 | + Parameters |
| 190 | + ---------- |
| 191 | + kernel_id : str |
| 192 | + The id of the kernel to stop buffering. |
| 193 | + session_key: str, optional |
| 194 | + The session_key, if any, that should get the buffer. |
| 195 | + If the session_key matches the current buffered session_key, |
| 196 | + the buffer will be returned. |
| 197 | + """ |
| 198 | + self.log.debug("Getting buffer for %s", kernel_id) |
| 199 | + if kernel_id not in self._kernel_buffers: |
| 200 | + return |
| 201 | + |
| 202 | + buffer_info = self._kernel_buffers[kernel_id] |
| 203 | + if buffer_info['session_key'] == session_key: |
| 204 | + # remove buffer |
| 205 | + self._kernel_buffers.pop(kernel_id) |
| 206 | + # only return buffer_info if it's a match |
| 207 | + return buffer_info |
| 208 | + else: |
| 209 | + self.stop_buffering(kernel_id) |
| 210 | + |
| 211 | + def stop_buffering(self, kernel_id): |
| 212 | + """Stop buffering kernel messages |
| 213 | +
|
| 214 | + Parameters |
| 215 | + ---------- |
| 216 | + kernel_id : str |
| 217 | + The id of the kernel to stop buffering. |
| 218 | + """ |
| 219 | + self.log.debug("Clearing buffer for %s", kernel_id) |
| 220 | + self._check_kernel_id(kernel_id) |
| 221 | + |
| 222 | + if kernel_id not in self._kernel_buffers: |
| 223 | + return |
| 224 | + buffer_info = self._kernel_buffers.pop(kernel_id) |
| 225 | + # close buffering streams |
| 226 | + for stream in buffer_info['channels'].values(): |
| 227 | + if not stream.closed(): |
| 228 | + stream.on_recv(None) |
| 229 | + stream.socket.close() |
| 230 | + stream.close() |
| 231 | + |
| 232 | + msg_buffer = buffer_info['buffer'] |
| 233 | + if msg_buffer: |
| 234 | + self.log.info("Discarding %s buffered messages for %s", |
| 235 | + len(msg_buffer), buffer_info['session_key']) |
| 236 | + |
145 | 237 | def shutdown_kernel(self, kernel_id, now=False):
|
146 | 238 | """Shutdown a kernel by kernel_id"""
|
147 | 239 | self._check_kernel_id(kernel_id)
|
148 |
| - self._kernels[kernel_id]._activity_stream.close() |
| 240 | + kernel = self._kernels[kernel_id] |
| 241 | + kernel._activity_stream.close() |
| 242 | + self.stop_buffering(kernel_id) |
149 | 243 | self._kernel_connections.pop(kernel_id, None)
|
150 | 244 | return super(MappingKernelManager, self).shutdown_kernel(kernel_id, now=now)
|
151 | 245 |
|
@@ -256,6 +350,7 @@ def record_activity(msg_list):
|
256 | 350 |
|
257 | 351 | idents, fed_msg_list = session.feed_identities(msg_list)
|
258 | 352 | msg = session.deserialize(fed_msg_list)
|
| 353 | + |
259 | 354 | msg_type = msg['header']['msg_type']
|
260 | 355 | self.log.debug("activity on %s: %s", kernel_id, msg_type)
|
261 | 356 | if msg_type == 'status':
|
|
0 commit comments