Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 84 additions & 10 deletions jupyter_server_nbmodel/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

import asyncio
import json
import os
import typing as t

from dataclasses import dataclass
from functools import partial
from datetime import datetime, timezone

Expand All @@ -25,6 +25,42 @@
from jupyter_server_nbmodel.event_logger import event_logger


@dataclass
class StreamState:
"""State for tracking stream output text processing across messages."""

cursor: int = 0
name: str = ""
stripped_newline: bool = False # Track if trailing \n was stripped


def _apply_terminal_controls(text: str, new_text: str, cursor: int) -> tuple[str, int]:
"""Apply terminal control characters (\\r, \\b, \\n) to text.

Mirrors JupyterLab's packages/outputarea/src/model.ts Private.processText
"""
chars = list(text)
for char in new_text:
match char:
case "\b":
if cursor > 0 and chars[cursor - 1] != "\n":
del chars[cursor - 1]
cursor -= 1
case "\r":
while cursor > 0 and chars[cursor - 1] != "\n":
cursor -= 1
case "\n":
chars.append("\n")
cursor = len(chars)
case _:
if cursor < len(chars):
chars[cursor] = char
else:
chars.append(char)
cursor += 1
return "".join(chars), cursor


if t.TYPE_CHECKING:
import jupyter_client
from nbformat import NotebookNode
Expand Down Expand Up @@ -86,40 +122,78 @@ async def _get_ycell(
return ycell


def _output_hook(outputs: list[NotebookNode], ycell: y.Map | None, msg: dict) -> None:
def _output_hook(
outputs: list[NotebookNode],
ycell: y.Map | None,
stream_state: StreamState,
msg: dict,
) -> None:
"""Callback on execution request when an output is emitted.

Args:
outputs: A list of previously emitted outputs
ycell: The cell being executed
stream_state: Mutable server-side state for tracking stream text processing
msg: The output message
"""
msg_type = msg["header"]["msg_type"]
if msg_type in ("display_data", "stream", "execute_result", "error"):
# FIXME support for version
output = nbformat.v4.output_from_msg(msg)
outputs.append(output)

if ycell is not None:
cell_outputs = ycell["outputs"]
if msg_type == "stream":
with cell_outputs.doc.transaction():
text = output["text"]
# FIXME Logic is quite complex at https://github.com/jupyterlab/jupyterlab/blob/7ae2d436fc410b0cff51042a3350ba71f54f4445/packages/outputarea/src/model.ts#L518
if text.endswith((os.linesep, "\n")):
text = text[:-1]
if (not cell_outputs) or (cell_outputs[-1].get("name", None) != output["name"]):
output["text"] = [text]
stream_name = output["name"]

if stream_state.name != stream_name or not cell_outputs:
# Different stream or first output - start fresh
stream_state.name = stream_name
stream_state.stripped_newline = False
processed_text, stream_state.cursor = _apply_terminal_controls("", text, 0)
# Strip trailing newline for storage (matches JupyterLab behavior)
if processed_text.endswith("\n"):
processed_text = processed_text[:-1]
stream_state.stripped_newline = True
stream_state.cursor = len(processed_text)
output["text"] = [processed_text]
cell_outputs.append(output)
else:
# Same stream - combine with previous, processing \r and \b
last_output = cell_outputs[-1]
last_output["text"].append(text)
current_text = "".join(last_output["text"])
# Restore stripped newline before processing
if stream_state.stripped_newline:
current_text += "\n"
stream_state.cursor = len(current_text)
processed_text, stream_state.cursor = _apply_terminal_controls(
current_text, text, stream_state.cursor
)
# Strip trailing newline for storage
if processed_text.endswith("\n"):
processed_text = processed_text[:-1]
stream_state.stripped_newline = True
stream_state.cursor = len(processed_text)
else:
stream_state.stripped_newline = False
last_output["text"] = [processed_text]
cell_outputs[-1] = last_output
else:
# Non-stream output resets stream state
stream_state.name = ""
stream_state.cursor = 0
stream_state.stripped_newline = False
with cell_outputs.doc.transaction():
cell_outputs.append(output)
elif msg_type == "clear_output":
# FIXME msg.content.wait - if true should clear at the next message
outputs.clear()
stream_state.name = ""
stream_state.cursor = 0
stream_state.stripped_newline = False
if ycell is not None:
del ycell["outputs"][:]
elif msg_type == "update_display_data":
Expand Down Expand Up @@ -208,13 +282,13 @@ async def _execute_snippet(
}
)
outputs = []
stream_state = StreamState()
# FIXME we don't check if the session is consistent (aka the kernel is linked to the document)
# - should we?
reply = await ensure_async(
client.execute_interactive(
snippet,
# FIXME stream partial results
output_hook=partial(_output_hook, outputs, ycell),
output_hook=partial(_output_hook, outputs, ycell, stream_state),
stdin_hook=stdin_hook if client.allow_stdin else None,
)
)
Expand Down
38 changes: 37 additions & 1 deletion jupyter_server_nbmodel/tests/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from jupyter_client.kernelspec import NATIVE_KERNEL_NAME
from jupyter_client.asynchronous.client import AsyncKernelClient
from jupyter_server_nbmodel.models import PendingInput
from jupyter_server_nbmodel.actions import kernel_worker
from jupyter_server_nbmodel.actions import kernel_worker, _apply_terminal_controls

TEST_TIMEOUT = 15

Expand All @@ -23,6 +23,42 @@
ANSI_REGEX = re.compile("\x1b\\[(.*?)([@-~])")


def test_apply_terminal_controls_backspace_and_carriage_return():
"""Test terminal control processing with \\b and \\r across multiple messages.

Simulates:
print('1110\\b', end='', flush=True) # "111" (backspace deletes '0')
print('11', end='', flush=True) # "11111"
print('\\r2 ', end='', flush=True) # "2 111" (CR + overwrite)
print('3', end='', flush=True) # "2 311"
print('4') # "2 341\\n"
"""
# Message 1: "1110\b" -> "111"
text, cursor = _apply_terminal_controls("", "1110\b", 0)
assert text == "111"
assert cursor == 3

# Message 2: "11" -> "11111"
text, cursor = _apply_terminal_controls(text, "11", cursor)
assert text == "11111"
assert cursor == 5

# Message 3: "\r2 " -> "2 111" (carriage return moves to start, overwrites)
text, cursor = _apply_terminal_controls(text, "\r2 ", cursor)
assert text == "2 111"
assert cursor == 2

# Message 4: "3" -> "2 311"
text, cursor = _apply_terminal_controls(text, "3", cursor)
assert text == "2 311"
assert cursor == 3

# Message 5: "4\n" -> "2 341\n"
text, cursor = _apply_terminal_controls(text, "4\n", cursor)
assert text == "2 341\n"
assert cursor == 6


def strip_ansi(text: str):
return ANSI_REGEX.sub("", text)

Expand Down