Skip to content

Commit 3a97a6a

Browse files
Mariko WakabayashiZsailer
authored andcommitted
Create AsyncFileContentsManager
1 parent ca60bcf commit 3a97a6a

File tree

10 files changed

+679
-143
lines changed

10 files changed

+679
-143
lines changed

jupyter_server/serverapp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@
6868
from .log import log_request
6969
from .services.kernels.kernelmanager import MappingKernelManager, AsyncMappingKernelManager
7070
from .services.config import ConfigManager
71-
from .services.contents.manager import ContentsManager
72-
from .services.contents.filemanager import FileContentsManager
71+
from .services.contents.manager import AsyncContentsManager, ContentsManager
72+
from .services.contents.filemanager import AsyncFileContentsManager, FileContentsManager
7373
from .services.contents.largefilemanager import LargeFileManager
7474
from .services.sessions.sessionmanager import SessionManager
7575
from .gateway.managers import GatewayKernelManager, GatewayKernelSpecManager, GatewaySessionManager, GatewayClient

jupyter_server/services/contents/filecheckpoints.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@
77
from tornado.web import HTTPError
88

99
from .checkpoints import (
10+
AsyncCheckpoints,
1011
Checkpoints,
1112
GenericCheckpointsMixin,
1213
)
13-
from .fileio import FileManagerMixin
14+
from .fileio import AsyncFileManagerMixin, FileManagerMixin
1415

16+
from anyio import run_sync_in_worker_thread
1517
from jupyter_core.utils import ensure_dir_exists
18+
from jupyter_server.utils import ensure_async
1619
from ipython_genutils.py3compat import getcwd
1720
from traitlets import Unicode
1821

@@ -137,6 +140,70 @@ def no_such_checkpoint(self, path, checkpoint_id):
137140
)
138141

139142

143+
class AsyncFileCheckpoints(FileCheckpoints, AsyncFileManagerMixin, AsyncCheckpoints):
144+
async def create_checkpoint(self, contents_mgr, path):
145+
"""Create a checkpoint."""
146+
checkpoint_id = u'checkpoint'
147+
src_path = contents_mgr._get_os_path(path)
148+
dest_path = self.checkpoint_path(checkpoint_id, path)
149+
await self._copy(src_path, dest_path)
150+
return (await self.checkpoint_model(checkpoint_id, dest_path))
151+
152+
async def restore_checkpoint(self, contents_mgr, checkpoint_id, path):
153+
"""Restore a checkpoint."""
154+
src_path = self.checkpoint_path(checkpoint_id, path)
155+
dest_path = contents_mgr._get_os_path(path)
156+
await self._copy(src_path, dest_path)
157+
158+
async def checkpoint_model(self, checkpoint_id, os_path):
159+
"""construct the info dict for a given checkpoint"""
160+
stats = await run_sync_in_worker_thread(os.stat, os_path)
161+
last_modified = tz.utcfromtimestamp(stats.st_mtime)
162+
info = dict(
163+
id=checkpoint_id,
164+
last_modified=last_modified,
165+
)
166+
return info
167+
168+
# ContentsManager-independent checkpoint API
169+
async def rename_checkpoint(self, checkpoint_id, old_path, new_path):
170+
"""Rename a checkpoint from old_path to new_path."""
171+
old_cp_path = self.checkpoint_path(checkpoint_id, old_path)
172+
new_cp_path = self.checkpoint_path(checkpoint_id, new_path)
173+
if os.path.isfile(old_cp_path):
174+
self.log.debug(
175+
"Renaming checkpoint %s -> %s",
176+
old_cp_path,
177+
new_cp_path,
178+
)
179+
with self.perm_to_403():
180+
await run_sync_in_worker_thread(shutil.move, old_cp_path, new_cp_path)
181+
182+
async def delete_checkpoint(self, checkpoint_id, path):
183+
"""delete a file's checkpoint"""
184+
path = path.strip('/')
185+
cp_path = self.checkpoint_path(checkpoint_id, path)
186+
if not os.path.isfile(cp_path):
187+
self.no_such_checkpoint(path, checkpoint_id)
188+
189+
self.log.debug("unlinking %s", cp_path)
190+
with self.perm_to_403():
191+
await run_sync_in_worker_thread(os.unlink, cp_path)
192+
193+
async def list_checkpoints(self, path):
194+
"""list the checkpoints for a given file
195+
196+
This contents manager currently only supports one checkpoint per file.
197+
"""
198+
path = path.strip('/')
199+
checkpoint_id = "checkpoint"
200+
os_path = self.checkpoint_path(checkpoint_id, path)
201+
if not os.path.isfile(os_path):
202+
return []
203+
else:
204+
return [await self.checkpoint_model(checkpoint_id, os_path)]
205+
206+
140207
class GenericFileCheckpoints(GenericCheckpointsMixin, FileCheckpoints):
141208
"""
142209
Local filesystem Checkpoints that works with any conforming

jupyter_server/services/contents/fileio.py

Lines changed: 112 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77

88
from contextlib import contextmanager
99
import errno
10+
from functools import partial
1011
import io
1112
import os
1213
import shutil
1314

15+
from anyio import open_file, run_sync_in_worker_thread
1416
from tornado.web import HTTPError
1517

1618
from jupyter_server.utils import (
@@ -32,6 +34,11 @@ def replace_file(src, dst):
3234
"""
3335
os.replace(src, dst)
3436

37+
async def async_replace_file(src, dst):
38+
""" replace dst with src asynchronously
39+
"""
40+
await run_sync_in_worker_thread(os.replace, src, dst)
41+
3542
def copy2_safe(src, dst, log=None):
3643
"""copy src to dst
3744
@@ -44,6 +51,18 @@ def copy2_safe(src, dst, log=None):
4451
if log:
4552
log.debug("copystat on %s failed", dst, exc_info=True)
4653

54+
async def async_copy2_safe(src, dst, log=None):
55+
"""copy src to dst asynchronously
56+
57+
like shutil.copy2, but log errors in copystat instead of raising
58+
"""
59+
await run_sync_in_worker_thread(shutil.copyfile, src, dst)
60+
try:
61+
await run_sync_in_worker_thread(shutil.copystat, src, dst)
62+
except OSError:
63+
if log:
64+
log.debug("copystat on %s failed", dst, exc_info=True)
65+
4766
def path_to_intermediate(path):
4867
'''Name of the intermediate file used in atomic writes.
4968
@@ -116,11 +135,10 @@ def atomic_writing(path, text=True, encoding='utf-8', log=None, **kwargs):
116135
os.remove(tmp_path)
117136

118137

119-
120138
@contextmanager
121139
def _simple_writing(path, text=True, encoding='utf-8', log=None, **kwargs):
122140
"""Context manager to write file without doing atomic writing
123-
( for weird filesystem eg: nfs).
141+
(for weird filesystem eg: nfs).
124142
125143
Parameters
126144
----------
@@ -159,8 +177,6 @@ def _simple_writing(path, text=True, encoding='utf-8', log=None, **kwargs):
159177
fileobj.close()
160178

161179

162-
163-
164180
class FileManagerMixin(Configurable):
165181
"""
166182
Mixin for ContentsAPI classes that interact with the filesystem.
@@ -186,7 +202,7 @@ class FileManagerMixin(Configurable):
186202

187203
@contextmanager
188204
def open(self, os_path, *args, **kwargs):
189-
"""wrapper around io.open that turns permission errors into 403"""
205+
"""wrapper around open that turns permission errors into 403"""
190206
with self.perm_to_403(os_path):
191207
with io.open(os_path, *args, **kwargs) as f:
192208
yield f
@@ -330,3 +346,94 @@ def _save_file(self, os_path, content, format):
330346

331347
with self.atomic_writing(os_path, text=False) as f:
332348
f.write(bcontent)
349+
350+
class AsyncFileManagerMixin(FileManagerMixin):
351+
"""
352+
Mixin for ContentsAPI classes that interact with the filesystem asynchronously.
353+
"""
354+
async def _copy(self, src, dest):
355+
"""copy src to dest
356+
357+
like shutil.copy2, but log errors in copystat
358+
"""
359+
await async_copy2_safe(src, dest, log=self.log)
360+
361+
async def _read_notebook(self, os_path, as_version=4):
362+
"""Read a notebook from an os path."""
363+
with self.open(os_path, 'r', encoding='utf-8') as f:
364+
try:
365+
return await run_sync_in_worker_thread(partial(nbformat.read, as_version=as_version), f)
366+
except Exception as e:
367+
e_orig = e
368+
369+
# If use_atomic_writing is enabled, we'll guess that it was also
370+
# enabled when this notebook was written and look for a valid
371+
# atomic intermediate.
372+
tmp_path = path_to_intermediate(os_path)
373+
374+
if not self.use_atomic_writing or not os.path.exists(tmp_path):
375+
raise HTTPError(
376+
400,
377+
u"Unreadable Notebook: %s %r" % (os_path, e_orig),
378+
)
379+
380+
# Move the bad file aside, restore the intermediate, and try again.
381+
invalid_file = path_to_invalid(os_path)
382+
async_replace_file(os_path, invalid_file)
383+
async_replace_file(tmp_path, os_path)
384+
return await self._read_notebook(os_path, as_version)
385+
386+
async def _save_notebook(self, os_path, nb):
387+
"""Save a notebook to an os_path."""
388+
with self.atomic_writing(os_path, encoding='utf-8') as f:
389+
await run_sync_in_worker_thread(partial(nbformat.write, version=nbformat.NO_CONVERT), nb, f)
390+
391+
async def _read_file(self, os_path, format):
392+
"""Read a non-notebook file.
393+
394+
os_path: The path to be read.
395+
format:
396+
If 'text', the contents will be decoded as UTF-8.
397+
If 'base64', the raw bytes contents will be encoded as base64.
398+
If not specified, try to decode as UTF-8, and fall back to base64
399+
"""
400+
if not os.path.isfile(os_path):
401+
raise HTTPError(400, "Cannot read non-file %s" % os_path)
402+
403+
with self.open(os_path, 'rb') as f:
404+
bcontent = await run_sync_in_worker_thread(f.read)
405+
406+
if format is None or format == 'text':
407+
# Try to interpret as unicode if format is unknown or if unicode
408+
# was explicitly requested.
409+
try:
410+
return bcontent.decode('utf8'), 'text'
411+
except UnicodeError as e:
412+
if format == 'text':
413+
raise HTTPError(
414+
400,
415+
"%s is not UTF-8 encoded" % os_path,
416+
reason='bad format',
417+
) from e
418+
return encodebytes(bcontent).decode('ascii'), 'base64'
419+
420+
async def _save_file(self, os_path, content, format):
421+
"""Save content of a generic file."""
422+
if format not in {'text', 'base64'}:
423+
raise HTTPError(
424+
400,
425+
"Must specify format of file contents as 'text' or 'base64'",
426+
)
427+
try:
428+
if format == 'text':
429+
bcontent = content.encode('utf8')
430+
else:
431+
b64_bytes = content.encode('ascii')
432+
bcontent = decodebytes(b64_bytes)
433+
except Exception as e:
434+
raise HTTPError(
435+
400, u'Encoding error saving %s: %s' % (os_path, e)
436+
) from e
437+
438+
with self.atomic_writing(os_path, text=False) as f:
439+
await run_sync_in_worker_thread(f.write, bcontent)

0 commit comments

Comments
 (0)