From c3de0f864ede172c73aa827cacd6040bb3c00e27 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Mon, 23 Jun 2025 22:40:21 -0700 Subject: [PATCH 1/4] WIP: Notebook tools --- jupyter_server_documents/tools/notebook.py | 209 +++++++++++++++ jupyter_server_documents/tools/system.py | 292 +++++++++++++++++++++ jupyter_server_documents/tools/terminal.py | 14 + jupyter_server_documents/tools/textfile.py | 14 + pyproject.toml | 1 + 5 files changed, 530 insertions(+) create mode 100644 jupyter_server_documents/tools/notebook.py create mode 100644 jupyter_server_documents/tools/system.py create mode 100644 jupyter_server_documents/tools/terminal.py create mode 100644 jupyter_server_documents/tools/textfile.py diff --git a/jupyter_server_documents/tools/notebook.py b/jupyter_server_documents/tools/notebook.py new file mode 100644 index 0000000..6a424ed --- /dev/null +++ b/jupyter_server_documents/tools/notebook.py @@ -0,0 +1,209 @@ +from typing import Literal, List, Dict, Any, Set +import nbformat + +from jupyter_server.base.call_context import CallContext + +from jupyter_server_documents.rooms.yroom import YRoom + + +def add_cell( + file_path: str, + content: str | None = None, + cell_index: int | None = None, + add_above: bool = False, + cell_type: Literal["code", "markdown", "raw"] = "code" + ): + """Adds a new cell to the Jupyter notebook above or below a specified cell index. + + This function adds a new cell to a Jupyter notebook. It first attempts to use + the in-memory YDoc representation if the notebook is currently active. If the + notebook is not active, it falls back to using the filesystem to read, modify, + and write the notebook file directly. + + Args: + file_path: The absolute path to the notebook file on the filesystem. + content: The content of the new cell. If None, an empty cell is created. + cell_index: The zero-based index where the cell should be added. If None, + the cell is added at the end of the notebook. + add_above: If True, the cell is added above the specified index. If False, + it's added below the specified index. + cell_type: The type of cell to add ("code", "markdown"). + + Returns: + None + """ + + file_id = _get_file_id(file_path) + ydoc = _get_jupyter_ydoc(file_id) + + if ydoc: + cells_count = ydoc.cell_number() + insert_index = _determine_insert_index(cells_count, cell_index, add_above) + ycell = ydoc.create_ycell({ + "cell_type": cell_type, + "source": content or "", + }) + ydoc.cells.insert(insert_index, ycell) + else: + with open(file_path, 'r', encoding='utf-8') as f: + notebook = nbformat.read(f, as_version=nbformat.NO_CONVERT) + + cells_count = len(notebook.cells) + insert_index = _determine_insert_index(cells_count, cell_index, add_above) + + if cell_type == "code": + notebook.cells.insert(insert_index, nbformat.v4.new_code_cell( + source=content or "" + )) + elif cell_type == "markdown": + notebook.cells.insert(insert_index, nbformat.v4.new_markdown_cell( + source=content or "" + )) + else: + notebook.cells.insert(insert_index, nbformat.v4.new_raw_cell( + source=content or "" + )) + + with open(file_path, 'w', encoding='utf-8') as f: + nbformat.write(notebook, f) + +def delete_cell(file_path: str, cell_index: int): + """Removes a notebook cell at the specified cell index. + + This function deletes a cell from a Jupyter notebook. It first attempts to use + the in-memory YDoc representation if the notebook is currently active. If the + notebook is not active, it falls back to using the filesystem to read, modify, + and write the notebook file directly using nbformat. + + Args: + file_path: The absolute path to the notebook file on the filesystem. + cell_index: The zero-based index of the cell to delete. + + Returns: + None + """ + + file_id = _get_file_id(file_path) + ydoc = _get_jupyter_ydoc(file_id) + if ydoc: + if 0 <= cell_index < len(ydoc.cells): + del ydoc.cells[cell_index] + else: + with open(file_path, 'r', encoding='utf-8') as f: + notebook = nbformat.read(f, as_version=nbformat.NO_CONVERT) + + if 0 <= cell_index < len(notebook.cells): + notebook.cells.pop(cell_index) + + with open(file_path, 'w', encoding='utf-8') as f: + nbformat.write(notebook, f) + +def edit_cell( + file_path: str, + cell_index: int, + content: str | None = None + ) -> None: + """Edits the content of a notebook cell at the specified index + + This function modifies the content of a cell in a Jupyter notebook. It first attempts to use + the in-memory YDoc representation if the notebook is currently active. If the + notebook is not active, it falls back to using the filesystem to read, modify, + and write the notebook file directly using nbformat. + + Args: + file_path: The absolute path to the notebook file on the filesystem. + cell_index: The zero-based index of the cell to edit. + content: The new content for the cell. If None, the cell content remains unchanged. + + Returns: + None + + Raises: + IndexError: If the cell_index is out of range for the notebook. + """ + + file_id = _get_file_id(file_path) + ydoc = _get_jupyter_ydoc(file_id) + + if ydoc: + cells_count = len(ydoc.cells) + if 0 <= cell_index < cells_count: + if content is not None: + ydoc.cells[cell_index]["source"] = content + else: + raise IndexError( + f"{cell_index=} is out of range for notebook at {file_path=} with {cells_count=}" + ) + else: + with open(file_path, 'r', encoding='utf-8') as f: + notebook = nbformat.read(f, as_version=nbformat.NO_CONVERT) + + cell_count = len(notebook.cells) + if 0 <= cell_index < cell_count: + if content is not None: + notebook.cells[cell_index].source = content + + with open(file_path, 'w', encoding='utf-8') as f: + nbformat.write(notebook, f) + else: + raise IndexError( + f"{cell_index=} is out of range for notebook at {file_path=} with {cell_count=}" + ) + +def read_cell(file_path: str, cell_index: int) -> Dict[str, Any]: + """Returns the content and metadata of a cell at the specified index""" + + with open(file_path, 'r', encoding='utf-8') as f: + notebook = nbformat.read(f, as_version=nbformat.NO_CONVERT) + + cell_count = len(notebook.cells) + if 0 <= cell_index < cell_count: + cell = notebook.cells[cell_index] + return cell + else: + raise IndexError( + f"{cell_index=} is out of range for notebook at {file_path=} with {cell_count=}" + ) + +def read_notebook(file_id: str) -> str: + """Returns the complete notebook content as markdown string""" + pass + +def read_notebook_source(file_id: str) -> Dict[str, Any]: + """Returns the complete notebook content including metadata""" + pass + +def summarize_notebook(file_id: str, max_length: int = 500) -> str: + """Generates a summary of the notebook content""" + pass + + +def _get_serverapp(): + handler = CallContext.get(CallContext.JUPYTER_HANDLER) + serverapp = handler.serverapp + return serverapp + +def _get_jupyter_ydoc(file_id: str) -> YRoom | None: + serverapp = _get_serverapp() + yroom_manager = serverapp.web_app.settings["yroom_manager"] + room_id = f"json:notebook:{file_id}" + if yroom_manager.has_room(room_id): + yroom = yroom_manager.get_room(room_id) + notebook = yroom.get_jupyter_ydoc() + return notebook + +def _get_file_id(file_path: str) -> str: + serverapp = _get_serverapp() + file_id_manager = serverapp.web_app.settings["file_id_manager"] + file_id = file_id_manager.get_id(file_path) + + return file_id + +def _determine_insert_index(cells_count: int, cell_index: int, add_above: bool) -> int: + if cell_index is None: + insert_index = cells_count + else: + if not (0 <= cell_index < cells_count): + cell_index = max(0, min(cell_index, cells_count)) + insert_index = cell_index if add_above else cell_index + 1 + return insert_index diff --git a/jupyter_server_documents/tools/system.py b/jupyter_server_documents/tools/system.py new file mode 100644 index 0000000..ea4f92b --- /dev/null +++ b/jupyter_server_documents/tools/system.py @@ -0,0 +1,292 @@ +import subprocess +import shlex +import os +from typing import Optional, Dict, List, Union, Any, Set + +# Whitelist of allowed commands for security +ALLOWED_COMMANDS: Set[str] = { + # Basic file and directory operations + "ls", "grep", "find", "cat", "head", "tail", "wc", + # Search tools + "grep", "rg", "ack", "ag", + # File manipulation + "cp", "mv", "rm", "mkdir", "touch", "chmod", "chown", + # Archive tools + "tar", "gzip", "gunzip", "zip", "unzip", + # Text processing + "sed", "awk", "cut", "sort", "uniq", "tr", "diff", "patch", + # Network tools + "curl", "wget", "ping", "netstat", "ssh", "scp", + # System information + "ps", "top", "df", "du", "free", "uname", "whoami", "date", + # Development tools + "git", "npm", "pip", "python", "node", "java", "javac", "gcc", "make", + # Package managers + "apt", "apt-get", "yum", "brew", "conda", + # Jupyter specific + "jupyter", "ipython", "nbconvert" +} + +class CommandNotAllowedError(Exception): + """Exception raised when a command is not in the whitelist.""" + pass + + +def bash(command: str, description: str, timeout: Optional[int] = None) -> Dict[str, Any]: + """Runs a bash command and returns the result + + Parameters + ---------- + command : str + The bash command to execute + description : str + A description of what the command does (for logging purposes) + timeout : Optional[int], optional + Timeout in seconds for the command execution, by default None + + Returns + ------- + Dict[str, Any] + A dictionary containing: + - stdout: The standard output as a string + - stderr: The standard error as a string + - returncode: The return code of the command + + Raises + ------ + CommandNotAllowedError + If the command is not in the whitelist of allowed commands + subprocess.TimeoutExpired + If the command execution times out + """ + try: + # Split the command into arguments if it's not already a list + if isinstance(command, str): + args = shlex.split(command) + else: + args = command + + # Check if the command is in the whitelist + if not args or args[0] not in ALLOWED_COMMANDS: + raise CommandNotAllowedError( + f"Command '{args[0] if args else ''}' is not in the whitelist of allowed commands. " + f"Allowed commands: {', '.join(sorted(ALLOWED_COMMANDS))}" + ) + + # Execute the command + process = subprocess.Popen( + args, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + shell=False + ) + + # Wait for the command to complete with timeout if specified + stdout, stderr = process.communicate(timeout=timeout) + + # Return the result + return { + "stdout": stdout, + "stderr": stderr, + "returncode": process.returncode + } + except subprocess.TimeoutExpired: + # Kill the process if it times out + process.kill() + stdout, stderr = process.communicate() + raise subprocess.TimeoutExpired( + cmd=command, + timeout=timeout, + output=stdout, + stderr=stderr + ) + +def glob(pattern: str, path: str = ".") -> List[str]: + """Runs the unix glob command and returns the result + + Parameters + ---------- + pattern : str + The glob pattern to match (e.g., "*.py", "**/*.md") + path : str, optional + The base path to search from, by default "." + + Returns + ------- + List[str] + A list of file paths matching the pattern + """ + import glob as glob_module + import os + + # Join the path and pattern if path is provided + search_pattern = os.path.join(path, pattern) + + # Use recursive glob if the pattern contains ** + if "**" in pattern: + return glob_module.glob(search_pattern, recursive=True) + else: + return glob_module.glob(search_pattern) + +def grep(pattern: str, path: str, include: Optional[str] = None) -> Dict[str, Any]: + """Runs the unix grep command and returns the result + + Parameters + ---------- + pattern : str + The pattern to search for + path : str + The path to search in (file or directory) + include : Optional[str], optional + File pattern to include in the search (e.g., "*.py"), by default None + + Returns + ------- + Dict[str, Any] + A dictionary containing: + - stdout: The standard output as a string + - stderr: The standard error as a string + - returncode: The return code of the command + - matches: A parsed list of matches (if returncode is 0) + """ + import shutil + import json + + # Check if ripgrep is available (preferred for performance) + use_ripgrep = shutil.which("rg") is not None + + if use_ripgrep: + # Construct ripgrep command + cmd = ["rg", "--json", pattern] + + # Add include pattern if specified + if include: + cmd.extend(["--glob", include]) + + # Add path + cmd.append(path) + + # Execute the command + result = bash(cmd, f"Searching for '{pattern}' in {path} using ripgrep", None) + + # Parse the JSON output if successful + if result["returncode"] == 0 and result["stdout"]: + matches = [] + for line in result["stdout"].strip().split("\n"): + try: + match_data = json.loads(line) + if match_data.get("type") == "match": + matches.append({ + "path": match_data.get("data", {}).get("path", {}).get("text", ""), + "line_number": match_data.get("data", {}).get("line_number", 0), + "line": match_data.get("data", {}).get("lines", {}).get("text", "").strip() + }) + except json.JSONDecodeError: + pass + + result["matches"] = matches + else: + # Fallback to standard grep + cmd = ["grep", "-r", "--line-number"] + + # Add include pattern if specified + if include: + cmd.extend(["--include", include]) + + # Add pattern and path + cmd.extend([pattern, path]) + + # Execute the command + result = bash(cmd, f"Searching for '{pattern}' in {path} using grep", None) + + # Parse the output if successful + if result["returncode"] == 0 and result["stdout"]: + matches = [] + for line in result["stdout"].strip().split("\n"): + parts = line.split(":", 2) + if len(parts) >= 3: + matches.append({ + "path": parts[0], + "line_number": int(parts[1]), + "line": parts[2].strip() + }) + + result["matches"] = matches + elif result["returncode"] == 1: + # grep returns 1 when no matches are found (not an error) + result["matches"] = [] + + return result + +def ls(path: str = ".", ignore: Optional[List[str]] = None) -> Dict[str, Any]: + """Runs the unix ls command and returns the result + + Parameters + ---------- + path : str, optional + The path to list contents for, by default "." + ignore : Optional[List[str]], optional + List of patterns to ignore, by default None + + Returns + ------- + Dict[str, Any] + A dictionary containing: + - stdout: The standard output as a string + - stderr: The standard error as a string + - returncode: The return code of the command + - entries: A parsed list of directory entries (if returncode is 0) + """ + import os + import fnmatch + + # Default to empty list if ignore is None + ignore = ignore or [] + + # Construct ls command with long format + cmd = ["ls", "-la"] + + # Add path + cmd.append(path) + + # Execute the command + result = bash(cmd, f"Listing contents of {path}", None) + + # Parse the output if successful + if result["returncode"] == 0 and result["stdout"]: + entries = [] + lines = result["stdout"].strip().split("\n") + + # Skip the total line and parse each entry + for line in lines[1:]: # Skip the "total X" line + parts = line.split(None, 8) + if len(parts) >= 9: + name = parts[8] + + # Skip if the entry matches any ignore pattern + if any(fnmatch.fnmatch(name, pattern) for pattern in ignore): + continue + + # Determine if it's a directory from the first character of permissions + is_dir = parts[0].startswith("d") + + entry = { + "name": name, + "type": "directory" if is_dir else "file", + "permissions": parts[0], + "links": int(parts[1]), + "owner": parts[2], + "group": parts[3], + "size": int(parts[4]), + "modified": f"{parts[5]} {parts[6]} {parts[7]}" + } + + # Add full path + entry["path"] = os.path.join(path, name) + + entries.append(entry) + + result["entries"] = entries + + return result diff --git a/jupyter_server_documents/tools/terminal.py b/jupyter_server_documents/tools/terminal.py new file mode 100644 index 0000000..9933ef8 --- /dev/null +++ b/jupyter_server_documents/tools/terminal.py @@ -0,0 +1,14 @@ +from typing import Dict, Any, Optional, List + + +def create_new_terminal(name: Optional[str] = None) -> str: + """Creates a new terminal session and returns its ID""" + pass + +def run_terminal_command(terminal_id: str, command: str) -> bool: + """Runs a command in the specified terminal session""" + pass + +def read_terminal_output(terminal_id: str, max_lines: int = 100) -> List[str]: + """Returns the output from the specified terminal session""" + pass diff --git a/jupyter_server_documents/tools/textfile.py b/jupyter_server_documents/tools/textfile.py new file mode 100644 index 0000000..b87d41b --- /dev/null +++ b/jupyter_server_documents/tools/textfile.py @@ -0,0 +1,14 @@ +from typing import Dict, Any, Optional + + +def edit_text_file(file_id: str, content: str) -> bool: + """Edits the content of a text file with the specified file ID""" + pass + +def read_text_file(file_id: str) -> str: + """Returns the content of a text file with the specified file ID""" + pass + +def summarize_text_file(file_id: str, max_length: int = 500) -> str: + """Generates a summary of the text file content""" + pass diff --git a/pyproject.toml b/pyproject.toml index 8a212cd..a4688ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "jupyter_server_fileid>=0.9.0,<0.10.0", "pycrdt>=0.12.0,<0.13.0", "jupyter_ydoc>=3.0.0,<4.0.0", + "jupyter_server_ai_tools>=0.2.0" ] dynamic = ["version", "description", "authors", "urls", "keywords"] From d9e4825650948e0c61ffe7844a576f1970f62c64 Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Mon, 23 Jun 2025 22:44:53 -0700 Subject: [PATCH 2/4] Updated allowed commands set --- jupyter_server_documents/tools/system.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/jupyter_server_documents/tools/system.py b/jupyter_server_documents/tools/system.py index ea4f92b..8b86a66 100644 --- a/jupyter_server_documents/tools/system.py +++ b/jupyter_server_documents/tools/system.py @@ -6,9 +6,7 @@ # Whitelist of allowed commands for security ALLOWED_COMMANDS: Set[str] = { # Basic file and directory operations - "ls", "grep", "find", "cat", "head", "tail", "wc", - # Search tools - "grep", "rg", "ack", "ag", + "find", "cat", "head", "tail", "wc", # File manipulation "cp", "mv", "rm", "mkdir", "touch", "chmod", "chown", # Archive tools @@ -20,7 +18,7 @@ # System information "ps", "top", "df", "du", "free", "uname", "whoami", "date", # Development tools - "git", "npm", "pip", "python", "node", "java", "javac", "gcc", "make", + "git", "npm", "pip", "python", "node", "make", # Package managers "apt", "apt-get", "yum", "brew", "conda", # Jupyter specific From 4806623a693b462cf9f3a08cf8e4ff39e1035cdf Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Mon, 23 Jun 2025 22:45:29 -0700 Subject: [PATCH 3/4] Removed unused imports --- jupyter_server_documents/tools/system.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/jupyter_server_documents/tools/system.py b/jupyter_server_documents/tools/system.py index 8b86a66..1df4b9f 100644 --- a/jupyter_server_documents/tools/system.py +++ b/jupyter_server_documents/tools/system.py @@ -1,7 +1,6 @@ import subprocess import shlex -import os -from typing import Optional, Dict, List, Union, Any, Set +from typing import Optional, Dict, List, Any, Set # Whitelist of allowed commands for security ALLOWED_COMMANDS: Set[str] = { From 476591774fec5992e1b890b51bb59c6674243a2c Mon Sep 17 00:00:00 2001 From: Piyush Jain Date: Tue, 24 Jun 2025 11:38:52 -0700 Subject: [PATCH 4/4] Updated cell tools to use cell id --- jupyter_server_documents/tools/notebook.py | 85 ++++++++++++++-------- 1 file changed, 53 insertions(+), 32 deletions(-) diff --git a/jupyter_server_documents/tools/notebook.py b/jupyter_server_documents/tools/notebook.py index 6a424ed..c51de7c 100644 --- a/jupyter_server_documents/tools/notebook.py +++ b/jupyter_server_documents/tools/notebook.py @@ -9,11 +9,11 @@ def add_cell( file_path: str, content: str | None = None, - cell_index: int | None = None, + cell_id: str | None = None, add_above: bool = False, cell_type: Literal["code", "markdown", "raw"] = "code" ): - """Adds a new cell to the Jupyter notebook above or below a specified cell index. + """Adds a new cell to the Jupyter notebook above or below a specified cell. This function adds a new cell to a Jupyter notebook. It first attempts to use the in-memory YDoc representation if the notebook is currently active. If the @@ -23,11 +23,11 @@ def add_cell( Args: file_path: The absolute path to the notebook file on the filesystem. content: The content of the new cell. If None, an empty cell is created. - cell_index: The zero-based index where the cell should be added. If None, - the cell is added at the end of the notebook. - add_above: If True, the cell is added above the specified index. If False, - it's added below the specified index. - cell_type: The type of cell to add ("code", "markdown"). + cell_id: The UUID of the cell to add relative to. If None, + the cell is added at the end of the notebook. + add_above: If True, the cell is added above the specified cell. If False, + it's added below the specified cell. + cell_type: The type of cell to add ("code", "markdown", "raw"). Returns: None @@ -38,6 +38,7 @@ def add_cell( if ydoc: cells_count = ydoc.cell_number() + cell_index = _get_cell_index_from_id(ydoc, cell_id) if cell_id else None insert_index = _determine_insert_index(cells_count, cell_index, add_above) ycell = ydoc.create_ycell({ "cell_type": cell_type, @@ -49,6 +50,7 @@ def add_cell( notebook = nbformat.read(f, as_version=nbformat.NO_CONVERT) cells_count = len(notebook.cells) + cell_index = _get_cell_index_from_id_nbformat(notebook, cell_id) if cell_id else None insert_index = _determine_insert_index(cells_count, cell_index, add_above) if cell_type == "code": @@ -67,8 +69,8 @@ def add_cell( with open(file_path, 'w', encoding='utf-8') as f: nbformat.write(notebook, f) -def delete_cell(file_path: str, cell_index: int): - """Removes a notebook cell at the specified cell index. +def delete_cell(file_path: str, cell_id: str): + """Removes a notebook cell with the specified cell ID. This function deletes a cell from a Jupyter notebook. It first attempts to use the in-memory YDoc representation if the notebook is currently active. If the @@ -77,7 +79,7 @@ def delete_cell(file_path: str, cell_index: int): Args: file_path: The absolute path to the notebook file on the filesystem. - cell_index: The zero-based index of the cell to delete. + cell_id: The UUID of the cell to delete. Returns: None @@ -86,13 +88,15 @@ def delete_cell(file_path: str, cell_index: int): file_id = _get_file_id(file_path) ydoc = _get_jupyter_ydoc(file_id) if ydoc: - if 0 <= cell_index < len(ydoc.cells): + cell_index = _get_cell_index_from_id(ydoc, cell_id) + if cell_index is not None and 0 <= cell_index < len(ydoc.cells): del ydoc.cells[cell_index] else: with open(file_path, 'r', encoding='utf-8') as f: notebook = nbformat.read(f, as_version=nbformat.NO_CONVERT) - if 0 <= cell_index < len(notebook.cells): + cell_index = _get_cell_index_from_id_nbformat(notebook, cell_id) + if cell_index is not None and 0 <= cell_index < len(notebook.cells): notebook.cells.pop(cell_index) with open(file_path, 'w', encoding='utf-8') as f: @@ -100,10 +104,10 @@ def delete_cell(file_path: str, cell_index: int): def edit_cell( file_path: str, - cell_index: int, + cell_id: str, content: str | None = None ) -> None: - """Edits the content of a notebook cell at the specified index + """Edits the content of a notebook cell with the specified ID This function modifies the content of a cell in a Jupyter notebook. It first attempts to use the in-memory YDoc representation if the notebook is currently active. If the @@ -112,57 +116,57 @@ def edit_cell( Args: file_path: The absolute path to the notebook file on the filesystem. - cell_index: The zero-based index of the cell to edit. + cell_id: The UUID of the cell to edit. content: The new content for the cell. If None, the cell content remains unchanged. Returns: None Raises: - IndexError: If the cell_index is out of range for the notebook. + ValueError: If the cell_id is not found in the notebook. """ file_id = _get_file_id(file_path) ydoc = _get_jupyter_ydoc(file_id) if ydoc: - cells_count = len(ydoc.cells) - if 0 <= cell_index < cells_count: + cell_index = _get_cell_index_from_id(ydoc, cell_id) + if cell_index is not None: if content is not None: ydoc.cells[cell_index]["source"] = content else: - raise IndexError( - f"{cell_index=} is out of range for notebook at {file_path=} with {cells_count=}" + raise ValueError( + f"Cell with {cell_id=} not found in notebook at {file_path=}" ) else: with open(file_path, 'r', encoding='utf-8') as f: notebook = nbformat.read(f, as_version=nbformat.NO_CONVERT) - cell_count = len(notebook.cells) - if 0 <= cell_index < cell_count: + cell_index = _get_cell_index_from_id_nbformat(notebook, cell_id) + if cell_index is not None: if content is not None: notebook.cells[cell_index].source = content with open(file_path, 'w', encoding='utf-8') as f: nbformat.write(notebook, f) else: - raise IndexError( - f"{cell_index=} is out of range for notebook at {file_path=} with {cell_count=}" + raise ValueError( + f"Cell with {cell_id=} not found in notebook at {file_path=}" ) -def read_cell(file_path: str, cell_index: int) -> Dict[str, Any]: - """Returns the content and metadata of a cell at the specified index""" +def read_cell(file_path: str, cell_id: str) -> Dict[str, Any]: + """Returns the content and metadata of a cell with the specified ID""" with open(file_path, 'r', encoding='utf-8') as f: - notebook = nbformat.read(f, as_version=nbformat.NO_CONVERT) - - cell_count = len(notebook.cells) - if 0 <= cell_index < cell_count: + notebook = nbformat.read(f, as_version=nbformat.NO_CONVERT) + + cell_index = _get_cell_index_from_id_nbformat(notebook, cell_id) + if cell_index is not None: cell = notebook.cells[cell_index] return cell else: - raise IndexError( - f"{cell_index=} is out of range for notebook at {file_path=} with {cell_count=}" + raise ValueError( + f"Cell with {cell_id=} not found in notebook at {file_path=}" ) def read_notebook(file_id: str) -> str: @@ -199,6 +203,23 @@ def _get_file_id(file_path: str) -> str: return file_id +def _get_cell_index_from_id(ydoc, cell_id: str) -> int | None: + """Get cell index from cell_id using YDoc interface.""" + try: + cell_index, _ = ydoc.find_cell(cell_id) + return cell_index + except (AttributeError, KeyError): + return None + +def _get_cell_index_from_id_nbformat(notebook, cell_id: str) -> int | None: + """Get cell index from cell_id using nbformat interface.""" + for i, cell in enumerate(notebook.cells): + if hasattr(cell, 'id') and cell.id == cell_id: + return i + elif hasattr(cell, 'metadata') and cell.metadata.get('id') == cell_id: + return i + return None + def _determine_insert_index(cells_count: int, cell_index: int, add_above: bool) -> int: if cell_index is None: insert_index = cells_count