@@ -104,10 +104,15 @@ def __init__(
104
104
start_line = end_line
105
105
106
106
@property
107
- def source (self ) -> str :
107
+ def notebook_source (self ) -> str :
108
108
"""Concatenated notebook source."""
109
109
return "\n " .join (cell .source for cell in self ._cells )
110
110
111
+ @property
112
+ def notebook_uri (self ) -> str :
113
+ """The notebook document's URI."""
114
+ return self ._document .uri
115
+
111
116
def notebook_position (
112
117
self , cell_uri : str , cell_position : Position
113
118
) -> Position :
@@ -241,6 +246,7 @@ def cell_filename(
241
246
workspace : Workspace ,
242
247
cell_uri : str ,
243
248
) -> str :
249
+ """Get the filename (used in diagnostics) for a cell URI."""
244
250
mapper = notebook_coordinate_mapper (workspace , cell_uri = cell_uri )
245
251
if mapper is None :
246
252
raise ValueError (
@@ -296,50 +302,56 @@ def cell_filename(
296
302
T = TypeVar ("T" )
297
303
298
304
299
- class ServerWrapper :
305
+ class ServerWrapper ( LanguageServer ) :
300
306
def __init__ (self , server : LanguageServer ):
301
307
self ._wrapped = server
302
- self .workspace = WorkspaceWrapper (server .workspace )
308
+ self ._workspace = WorkspaceWrapper (server .workspace )
309
+
310
+ @property
311
+ def workspace (self ) -> Workspace :
312
+ return self ._workspace
303
313
304
314
def __getattr__ (self , name : str ) -> Any :
305
315
return getattr (self ._wrapped , name )
306
316
307
317
308
- class WorkspaceWrapper :
318
+ class WorkspaceWrapper ( Workspace ) :
309
319
def __init__ (self , workspace : Workspace ):
310
320
self ._wrapped = workspace
311
321
312
322
def __getattr__ (self , name : str ) -> Any :
313
323
return getattr (self ._wrapped , name )
314
324
315
325
def get_text_document (self , doc_uri : str ) -> TextDocument :
316
- notebook = notebook_coordinate_mapper (self ._wrapped , cell_uri = doc_uri )
317
- if notebook is None :
326
+ mapper = notebook_coordinate_mapper (self ._wrapped , cell_uri = doc_uri )
327
+ if mapper is None :
318
328
return self ._wrapped .get_text_document (doc_uri )
319
- return TextDocument (uri = notebook ._document .uri , source = notebook .source )
329
+ return TextDocument (
330
+ uri = mapper .notebook_uri , source = mapper .notebook_source
331
+ )
320
332
321
333
322
- def _map_params_to_notebook (
323
- notebook : NotebookCoordinateMapper , params : T_params
334
+ def _notebook_params (
335
+ mapper : NotebookCoordinateMapper , params : T_params
324
336
) -> T_params :
325
337
if hasattr (params , "position" ):
326
- notebook_position = notebook .notebook_position (
338
+ notebook_position = mapper .notebook_position (
327
339
params .text_document .uri , params .position
328
340
)
329
- # Ignore mypy error since it doesn't seem to narrow params via the hasattr.
341
+ # Ignore mypy error since it doesn't seem to narrow via hasattr.
330
342
params = attrs .evolve (params , position = notebook_position ) # type: ignore[call-arg]
331
343
332
344
if hasattr (params , "range" ):
333
- notebook_range = notebook .notebook_range (
345
+ notebook_range = mapper .notebook_range (
334
346
params .text_document .uri , params .range
335
347
)
336
- # Ignore mypy error since it doesn't seem to narrow params via the hasattr.
348
+ # Ignore mypy error since it doesn't seem to narrow via hasattr.
337
349
params = attrs .evolve (params , range = notebook_range ) # type: ignore[call-arg]
338
350
339
351
return params
340
352
341
353
342
- def _map_result_to_cells (
354
+ def _cell_results (
343
355
workspace : Workspace ,
344
356
mapper : Optional [NotebookCoordinateMapper ],
345
357
params : _TextDocumentCoordinatesParams ,
@@ -363,12 +375,24 @@ def _map_result_to_cells(
363
375
def supports_notebooks (
364
376
f : Callable [[T_ls , T_params ], T ],
365
377
) -> Callable [[T_ls , T_params ], T ]:
378
+ """Decorator to add basic notebook support to a language server feature.
379
+
380
+ It works by converting params from cell coordinates to notebook coordinates
381
+ before calling the wrapped function, and then converting the result back
382
+ to cell coordinates.
383
+ """
384
+
366
385
def wrapped (ls : T_ls , params : T_params ) -> T :
367
386
mapper = notebook_coordinate_mapper (
368
387
ls .workspace , cell_uri = params .text_document .uri
369
388
)
370
- params = _map_params_to_notebook (mapper , params ) if mapper else params
371
- result = f (ls , params )
372
- return _map_result_to_cells (ls .workspace , mapper , params , result )
389
+ notebook_params = (
390
+ _notebook_params (mapper , params ) if mapper else params
391
+ )
392
+ notebook_server = cast (T_ls , ServerWrapper (ls ))
393
+ result = f (notebook_server , notebook_params )
394
+ return _cell_results (
395
+ notebook_server .workspace , mapper , notebook_params , result
396
+ )
373
397
374
398
return wrapped
0 commit comments