diff --git a/conftest.py b/conftest.py index 0c2399b..7226336 100644 --- a/conftest.py +++ b/conftest.py @@ -1,6 +1,12 @@ import pytest -pytest_plugins = ("pytest_jupyter.jupyter_server", "jupyter_server.pytest_plugin") +pytest_plugins = ("pytest_jupyter.jupyter_server", "jupyter_server.pytest_plugin", "pytest_asyncio") + + +def pytest_configure(config): + """Configure pytest settings.""" + # Set asyncio fixture loop scope to function to avoid warnings + config.option.asyncio_default_fixture_loop_scope = "function" @pytest.fixture diff --git a/jupyter_server_documents/kernels/kernel_client.py b/jupyter_server_documents/kernels/kernel_client.py index 9fd05c6..7c42e1d 100644 --- a/jupyter_server_documents/kernels/kernel_client.py +++ b/jupyter_server_documents/kernels/kernel_client.py @@ -80,7 +80,7 @@ async def stop_listening(self): """ # If the listening task isn't defined yet # do nothing. - if not self._listening_task: + if not hasattr(self, '_listening_task') or not self._listening_task: return # Attempt to cancel the task. @@ -93,6 +93,9 @@ async def stop_listening(self): # Log any exceptions that were raised. except Exception as err: self.log.error(err) + finally: + # Clear the task reference + self._listening_task = None _listening_task: t.Optional[t.Awaitable] = Any(allow_none=True) diff --git a/jupyter_server_documents/kernels/kernel_manager.py b/jupyter_server_documents/kernels/kernel_manager.py index 00083da..a59e443 100644 --- a/jupyter_server_documents/kernels/kernel_manager.py +++ b/jupyter_server_documents/kernels/kernel_manager.py @@ -127,8 +127,9 @@ async def connect(self): await asyncio.sleep(0.1) async def disconnect(self): - await self.main_client.stop_listening() - self.main_client.stop_channels() + if self.main_client: + await self.main_client.stop_listening() + self.main_client.stop_channels() async def broadcast_state(self): """Broadcast state to all listeners""" diff --git a/jupyter_server_documents/kernels/multi_kernel_manager.py b/jupyter_server_documents/kernels/multi_kernel_manager.py index 722ef9d..bc26806 100644 --- a/jupyter_server_documents/kernels/multi_kernel_manager.py +++ b/jupyter_server_documents/kernels/multi_kernel_manager.py @@ -7,4 +7,11 @@ def start_watching_activity(self, kernel_id): pass def stop_buffering(self, kernel_id): - pass \ No newline at end of file + pass + + # NOTE: Since we disable watching activity and buffering here, + # this method needs to be forked and remove code related to these things. + async def restart_kernel(self, kernel_id, now=False): + """Restart a kernel by kernel_id""" + self._check_kernel_id(kernel_id) + await self.pinned_superclass._async_restart_kernel(self, kernel_id, now=now) \ No newline at end of file diff --git a/jupyter_server_documents/tests/kernels/__init__.py b/jupyter_server_documents/tests/kernels/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/jupyter_server_documents/tests/kernels/conftest.py b/jupyter_server_documents/tests/kernels/conftest.py new file mode 100644 index 0000000..d957c17 --- /dev/null +++ b/jupyter_server_documents/tests/kernels/conftest.py @@ -0,0 +1,23 @@ +"""Configuration for kernel tests.""" + +import pytest +from unittest.mock import MagicMock + + +@pytest.fixture +def mock_logger(): + """Create a mock logger for testing.""" + return MagicMock() + + +@pytest.fixture +def mock_session(): + """Create a mock session for testing.""" + session = MagicMock() + session.msg_header.return_value = {"msg_id": "test-msg-id"} + session.msg.return_value = {"test": "message"} + session.serialize.return_value = ["", "serialized", "msg"] + session.deserialize.return_value = {"msg_type": "test", "content": b"test"} + session.unpack.return_value = {"test": "data"} + session.feed_identities.return_value = ([], [b"test", b"message"]) + return session \ No newline at end of file diff --git a/jupyter_server_documents/tests/kernels/test_kernel_client.py b/jupyter_server_documents/tests/kernels/test_kernel_client.py new file mode 100644 index 0000000..3f24bab --- /dev/null +++ b/jupyter_server_documents/tests/kernels/test_kernel_client.py @@ -0,0 +1,105 @@ +import pytest +from unittest.mock import MagicMock, patch + +from jupyter_server_documents.kernels.kernel_client import DocumentAwareKernelClient +from jupyter_server_documents.kernels.message_cache import KernelMessageCache +from jupyter_server_documents.outputs import OutputProcessor + + +class TestDocumentAwareKernelClient: + """Test cases for DocumentAwareKernelClient.""" + + def test_default_message_cache(self): + """Test that message cache is created by default.""" + client = DocumentAwareKernelClient() + assert isinstance(client.message_cache, KernelMessageCache) + + def test_default_output_processor(self): + """Test that output processor is created by default.""" + client = DocumentAwareKernelClient() + assert isinstance(client.output_processor, OutputProcessor) + + @pytest.mark.asyncio + async def test_stop_listening_no_task(self): + """Test that stop_listening does nothing when no task exists.""" + client = DocumentAwareKernelClient() + client._listening_task = None + + # Should not raise an exception + await client.stop_listening() + + def test_add_listener(self): + """Test adding a listener.""" + client = DocumentAwareKernelClient() + + def test_listener(channel, msg): + pass + + client.add_listener(test_listener) + + assert test_listener in client._listeners + + def test_remove_listener(self): + """Test removing a listener.""" + client = DocumentAwareKernelClient() + + def test_listener(channel, msg): + pass + + client.add_listener(test_listener) + client.remove_listener(test_listener) + + assert test_listener not in client._listeners + + @pytest.mark.asyncio + async def test_add_yroom(self): + """Test adding a YRoom.""" + client = DocumentAwareKernelClient() + + mock_yroom = MagicMock() + await client.add_yroom(mock_yroom) + + assert mock_yroom in client._yrooms + + @pytest.mark.asyncio + async def test_remove_yroom(self): + """Test removing a YRoom.""" + client = DocumentAwareKernelClient() + + mock_yroom = MagicMock() + client._yrooms.add(mock_yroom) + + await client.remove_yroom(mock_yroom) + + assert mock_yroom not in client._yrooms + + def test_send_kernel_info_creates_message(self): + """Test that send_kernel_info creates a kernel info message.""" + client = DocumentAwareKernelClient() + + # Mock session + from jupyter_client.session import Session + client.session = Session() + + with patch.object(client, 'handle_incoming_message') as mock_handle: + client.send_kernel_info() + + # Verify that handle_incoming_message was called with shell channel + mock_handle.assert_called_once() + args, kwargs = mock_handle.call_args + assert args[0] == "shell" # Channel name + assert isinstance(args[1], list) # Message list + + @pytest.mark.asyncio + async def test_handle_outgoing_message_control_channel(self): + """Test that control channel messages bypass document handling.""" + client = DocumentAwareKernelClient() + + msg = [b"test", b"message"] + + with patch.object(client, 'handle_document_related_message') as mock_handle_doc: + with patch.object(client, 'send_message_to_listeners') as mock_send: + await client.handle_outgoing_message("control", msg) + + mock_handle_doc.assert_not_called() + mock_send.assert_called_once_with("control", msg) \ No newline at end of file diff --git a/jupyter_server_documents/tests/kernels/test_kernel_client_integration.py b/jupyter_server_documents/tests/kernels/test_kernel_client_integration.py new file mode 100644 index 0000000..7543962 --- /dev/null +++ b/jupyter_server_documents/tests/kernels/test_kernel_client_integration.py @@ -0,0 +1,450 @@ +import pytest +import asyncio +import json +from unittest.mock import MagicMock, AsyncMock, patch +from jupyter_client.session import Session +from jupyter_server_documents.ydocs import YNotebook +import pycrdt + +from jupyter_server_documents.kernels.kernel_client import DocumentAwareKernelClient +from jupyter_server_documents.rooms.yroom import YRoom +from jupyter_server_documents.outputs import OutputProcessor + + +class TestDocumentAwareKernelClientIntegration: + """Integration tests for DocumentAwareKernelClient with YDoc updates.""" + + @pytest.fixture + def mock_yroom_with_notebook(self): + """Create a mock YRoom with a real YNotebook.""" + # Create a real YDoc and YNotebook + ydoc = pycrdt.Doc() + awareness = MagicMock(spec=pycrdt.Awareness) # Mock awareness instead of using real one + ynotebook = YNotebook(ydoc, awareness) + + # Add a simple notebook structure with one cell + ynotebook.set({ + "cells": [ + { + "cell_type": "code", + "id": "test-cell-1", + "source": "2 + 2", + "metadata": {}, + "outputs": [], + "execution_count": None + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.9.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 + }) + + # Create mock YRoom + yroom = MagicMock(spec=YRoom) + yroom.get_jupyter_ydoc = AsyncMock(return_value=ynotebook) + yroom.get_awareness = MagicMock(return_value=awareness) + + return yroom, ynotebook + + @pytest.fixture + def kernel_client_with_yroom(self, mock_yroom_with_notebook): + """Create a DocumentAwareKernelClient with a real YRoom and YNotebook.""" + yroom, ynotebook = mock_yroom_with_notebook + + client = DocumentAwareKernelClient() + client.session = Session() + client.log = MagicMock() + + # Add the YRoom to the client + client._yrooms = {yroom} + + # Mock output processor + client.output_processor = MagicMock(spec=OutputProcessor) + client.output_processor.process_output = MagicMock() + + return client, yroom, ynotebook + + def create_kernel_message(self, session, msg_type, content, parent_msg_id=None, cell_id=None): + """Helper to create properly formatted kernel messages.""" + parent_header = {"msg_id": parent_msg_id} if parent_msg_id else {} + metadata = {"cellId": cell_id} if cell_id else {} + + msg = session.msg(msg_type, content, parent=parent_header, metadata=metadata) + return session.serialize(msg) + + @pytest.mark.asyncio + async def test_execute_input_updates_execution_count(self, kernel_client_with_yroom): + """Test that execute_input messages update execution count in YDoc.""" + client, yroom, ynotebook = kernel_client_with_yroom + + # Mock message cache to return cell_id + parent_msg_id = "execute-request-123" + cell_id = "test-cell-1" + client.message_cache.get = MagicMock(return_value={"cell_id": cell_id}) + + # Create execute_input message + content = {"code": "2 + 2", "execution_count": 1} + msg_parts = self.create_kernel_message( + client.session, "execute_input", content, parent_msg_id, cell_id + ) + + # Process the message + await client.handle_document_related_message(msg_parts[1:]) # Skip delimiter + + # Verify the execution count was updated in the YDoc + cells = ynotebook.get_cell_list() + target_cell = next((cell for cell in cells if cell.get("id") == cell_id), None) + assert target_cell is not None + assert target_cell.get("execution_count") == 1 + + @pytest.mark.asyncio + async def test_status_message_updates_cell_execution_state(self, kernel_client_with_yroom): + """Test that status messages update cell execution state in YDoc.""" + client, yroom, ynotebook = kernel_client_with_yroom + + # Mock message cache to return cell_id and channel + parent_msg_id = "execute-request-123" + cell_id = "test-cell-1" + client.message_cache.get = MagicMock(return_value={ + "cell_id": cell_id, + "channel": "shell" + }) + + # Create status message with 'busy' state + content = {"execution_state": "busy"} + msg_parts = self.create_kernel_message( + client.session, "status", content, parent_msg_id, cell_id + ) + + # Process the message + await client.handle_document_related_message(msg_parts[1:]) # Skip delimiter + + # Verify the cell execution state was updated to 'running' (converted from 'busy') + cells = ynotebook.get_cell_list() + target_cell = next((cell for cell in cells if cell.get("id") == cell_id), None) + assert target_cell is not None + assert target_cell.get("execution_state") == "running" + + @pytest.mark.asyncio + async def test_kernel_info_reply_updates_language_info(self, kernel_client_with_yroom): + """Test that kernel_info_reply updates language info in YDoc metadata.""" + client, yroom, ynotebook = kernel_client_with_yroom + + # Mock message cache + parent_msg_id = "kernel-info-request-123" + client.message_cache.get = MagicMock(return_value={"cell_id": None}) + + # Create kernel_info_reply message + content = { + "language_info": { + "name": "python", + "version": "3.9.0", + "mimetype": "text/x-python", + "file_extension": ".py" + } + } + msg_parts = self.create_kernel_message( + client.session, "kernel_info_reply", content, parent_msg_id + ) + + # Process the message + await client.handle_document_related_message(msg_parts[1:]) # Skip delimiter + + # Verify language info was updated in notebook metadata + metadata = ynotebook.get_meta() + assert "language_info" in metadata["metadata"] + assert metadata["metadata"]["language_info"]["name"] == "python" + assert metadata["metadata"]["language_info"]["version"] == "3.9.0" + + @pytest.mark.asyncio + async def test_output_message_processed_and_suppressed(self, kernel_client_with_yroom): + """Test that output messages are processed by output processor and suppressed.""" + client, yroom, ynotebook = kernel_client_with_yroom + + # Mock message cache to return cell_id + parent_msg_id = "execute-request-123" + cell_id = "test-cell-1" + client.message_cache.get = MagicMock(return_value={"cell_id": cell_id}) + + # Create execute_result message (output) + content = { + "data": {"text/plain": "4"}, + "metadata": {}, + "execution_count": 1 + } + msg_parts = self.create_kernel_message( + client.session, "execute_result", content, parent_msg_id, cell_id + ) + + # Process the message + result = await client.handle_document_related_message(msg_parts[1:]) # Skip delimiter + + # Verify the output processor was called + client.output_processor.process_output.assert_called_once_with( + "execute_result", cell_id, content + ) + + # Verify the message was suppressed (returned None) + assert result is None + + @pytest.mark.asyncio + async def test_stream_output_message_processed(self, kernel_client_with_yroom): + """Test that stream output messages are processed correctly.""" + client, yroom, ynotebook = kernel_client_with_yroom + + # Mock message cache to return cell_id + parent_msg_id = "execute-request-123" + cell_id = "test-cell-1" + client.message_cache.get = MagicMock(return_value={"cell_id": cell_id}) + + # Create stream message + content = { + "name": "stdout", + "text": "4\n" + } + msg_parts = self.create_kernel_message( + client.session, "stream", content, parent_msg_id, cell_id + ) + + # Process the message + result = await client.handle_document_related_message(msg_parts[1:]) # Skip delimiter + + # Verify the output processor was called + client.output_processor.process_output.assert_called_once_with( + "stream", cell_id, content + ) + + # Verify the message was suppressed + assert result is None + + @pytest.mark.asyncio + async def test_error_output_message_processed(self, kernel_client_with_yroom): + """Test that error output messages are processed correctly.""" + client, yroom, ynotebook = kernel_client_with_yroom + + # Mock message cache to return cell_id + parent_msg_id = "execute-request-123" + cell_id = "test-cell-1" + client.message_cache.get = MagicMock(return_value={"cell_id": cell_id}) + + # Create error message + content = { + "ename": "NameError", + "evalue": "name 'x' is not defined", + "traceback": ["Traceback (most recent call last):", "NameError: name 'x' is not defined"] + } + msg_parts = self.create_kernel_message( + client.session, "error", content, parent_msg_id, cell_id + ) + + # Process the message + result = await client.handle_document_related_message(msg_parts[1:]) # Skip delimiter + + # Verify the output processor was called + client.output_processor.process_output.assert_called_once_with( + "error", cell_id, content + ) + + # Verify the message was suppressed + assert result is None + + @pytest.mark.asyncio + async def test_complete_execution_flow(self, kernel_client_with_yroom): + """Test complete execution flow: execute_input -> status -> output -> status.""" + client, yroom, ynotebook = kernel_client_with_yroom + + parent_msg_id = "execute-request-123" + cell_id = "test-cell-1" + + # Mock message cache to return cell_id and channel + client.message_cache.get = MagicMock(return_value={ + "cell_id": cell_id, + "channel": "shell" + }) + + # Step 1: Execute input + execute_input_content = {"code": "2 + 2", "execution_count": 1} + msg_parts = self.create_kernel_message( + client.session, "execute_input", execute_input_content, parent_msg_id, cell_id + ) + await client.handle_document_related_message(msg_parts[1:]) + + # Step 2: Status busy + status_busy_content = {"execution_state": "busy"} + msg_parts = self.create_kernel_message( + client.session, "status", status_busy_content, parent_msg_id, cell_id + ) + await client.handle_document_related_message(msg_parts[1:]) + + # Step 3: Execute result + result_content = { + "data": {"text/plain": "4"}, + "metadata": {}, + "execution_count": 1 + } + msg_parts = self.create_kernel_message( + client.session, "execute_result", result_content, parent_msg_id, cell_id + ) + await client.handle_document_related_message(msg_parts[1:]) + + # Step 4: Status idle + status_idle_content = {"execution_state": "idle"} + msg_parts = self.create_kernel_message( + client.session, "status", status_idle_content, parent_msg_id, cell_id + ) + await client.handle_document_related_message(msg_parts[1:]) + + # Verify final state of the cell in YDoc + cells = ynotebook.get_cell_list() + target_cell = next((cell for cell in cells if cell.get("id") == cell_id), None) + assert target_cell is not None + assert target_cell.get("execution_count") == 1 + assert target_cell.get("execution_state") == "idle" + + # Verify output processor was called for the result + client.output_processor.process_output.assert_called_with( + "execute_result", cell_id, result_content + ) + + @pytest.mark.asyncio + async def test_awareness_state_updates_for_kernel_status(self, kernel_client_with_yroom): + """Test that kernel status updates awareness state.""" + client, yroom, ynotebook = kernel_client_with_yroom + + # Mock message cache to return shell channel (for notebook-level status) + parent_msg_id = "kernel-info-request-123" + client.message_cache.get = MagicMock(return_value={ + "cell_id": None, + "channel": "shell" + }) + + # Create status message for kernel-level state + content = {"execution_state": "busy"} + msg_parts = self.create_kernel_message( + client.session, "status", content, parent_msg_id + ) + + # Process the message + await client.handle_document_related_message(msg_parts[1:]) + + # Verify awareness was updated + awareness = yroom.get_awareness() + awareness.set_local_state_field.assert_called_once_with( + "kernel", {"execution_state": "busy"} + ) + + @pytest.mark.asyncio + async def test_multiple_cells_execution_states(self, kernel_client_with_yroom): + """Test that multiple cells can have different execution states.""" + client, yroom, ynotebook = kernel_client_with_yroom + + # Add another cell to the notebook + cells = ynotebook.get_cell_list() + ynotebook.append_cell({ + "cell_type": "code", + "id": "test-cell-2", + "source": "print('hello')", + "metadata": {}, + "outputs": [], + "execution_count": None + }) + + # Mock message cache to return different cell_ids + def mock_get(msg_id): + if msg_id == "execute-request-123": + return {"cell_id": "test-cell-1", "channel": "shell"} + elif msg_id == "execute-request-456": + return {"cell_id": "test-cell-2", "channel": "shell"} + return None + + client.message_cache.get = MagicMock(side_effect=mock_get) + + # Set first cell to busy + content1 = {"execution_state": "busy"} + msg_parts1 = self.create_kernel_message( + client.session, "status", content1, "execute-request-123", "test-cell-1" + ) + await client.handle_document_related_message(msg_parts1[1:]) + + # Set second cell to idle + content2 = {"execution_state": "idle"} + msg_parts2 = self.create_kernel_message( + client.session, "status", content2, "execute-request-456", "test-cell-2" + ) + await client.handle_document_related_message(msg_parts2[1:]) + + # Verify both cells have correct states + cells = ynotebook.get_cell_list() + cell1 = next((cell for cell in cells if cell.get("id") == "test-cell-1"), None) + cell2 = next((cell for cell in cells if cell.get("id") == "test-cell-2"), None) + + assert cell1 is not None + assert cell1.get("execution_state") == "running" # 'busy' -> 'running' + assert cell2 is not None + assert cell2.get("execution_state") == "idle" + + @pytest.mark.asyncio + async def test_message_without_cell_id_skips_cell_updates(self, kernel_client_with_yroom): + """Test that messages without cell_id don't update cell-specific data.""" + client, yroom, ynotebook = kernel_client_with_yroom + + # Mock message cache to return no cell_id + parent_msg_id = "some-request-123" + client.message_cache.get = MagicMock(return_value={"cell_id": None}) + + # Create execute_input message without cell_id + content = {"code": "2 + 2", "execution_count": 1} + msg_parts = self.create_kernel_message( + client.session, "execute_input", content, parent_msg_id + ) + + # Process the message + await client.handle_document_related_message(msg_parts[1:]) + + # Verify no cell was updated (execution_count should remain None) + cells = ynotebook.get_cell_list() + for cell in cells: + assert cell.get("execution_count") is None + + @pytest.mark.asyncio + async def test_display_data_message_processing(self, kernel_client_with_yroom): + """Test that display_data messages are processed correctly.""" + client, yroom, ynotebook = kernel_client_with_yroom + + # Mock message cache to return cell_id + parent_msg_id = "execute-request-123" + cell_id = "test-cell-1" + client.message_cache.get = MagicMock(return_value={"cell_id": cell_id}) + + # Create display_data message + content = { + "data": { + "text/plain": "Hello World", + "text/html": "
Hello World
" + }, + "metadata": {} + } + msg_parts = self.create_kernel_message( + client.session, "display_data", content, parent_msg_id, cell_id + ) + + # Process the message + result = await client.handle_document_related_message(msg_parts[1:]) + + # Verify the output processor was called + client.output_processor.process_output.assert_called_once_with( + "display_data", cell_id, content + ) + + # Verify the message was suppressed + assert result is None \ No newline at end of file diff --git a/jupyter_server_documents/tests/kernels/test_kernel_manager.py b/jupyter_server_documents/tests/kernels/test_kernel_manager.py new file mode 100644 index 0000000..b6c7d48 --- /dev/null +++ b/jupyter_server_documents/tests/kernels/test_kernel_manager.py @@ -0,0 +1,71 @@ +import pytest +from unittest.mock import patch + +from jupyter_server_documents.kernels.kernel_manager import NextGenKernelManager +from jupyter_server_documents.kernels.states import ExecutionStates, LifecycleStates + + +class TestNextGenKernelManager: + """Test cases for NextGenKernelManager.""" + + def test_set_state_lifecycle_only(self): + """Test setting only lifecycle state.""" + km = NextGenKernelManager() + km.set_state(LifecycleStates.STARTING) + assert km.lifecycle_state == LifecycleStates.STARTING.value + + def test_set_state_execution_only(self): + """Test setting only execution state.""" + km = NextGenKernelManager() + km.set_state(execution_state=ExecutionStates.IDLE) + assert km.execution_state == ExecutionStates.IDLE.value + + def test_set_state_both(self): + """Test setting both lifecycle and execution states.""" + km = NextGenKernelManager() + km.set_state(LifecycleStates.CONNECTED, ExecutionStates.BUSY) + assert km.lifecycle_state == LifecycleStates.CONNECTED.value + assert km.execution_state == ExecutionStates.BUSY.value + + def test_lifecycle_state_validation(self): + """Test lifecycle state validation.""" + km = NextGenKernelManager() + with pytest.raises(Exception): + km.lifecycle_state = "invalid_state" + + def test_execution_state_validation(self): + """Test execution state validation.""" + km = NextGenKernelManager() + with pytest.raises(Exception): + km.execution_state = "invalid_state" + + def test_execution_state_listener_non_iopub_channel(self): + """Test execution state listener ignores non-iopub channels.""" + km = NextGenKernelManager() + original_state = km.execution_state + + km.execution_state_listener("shell", [b"test", b"message"]) + + # State should remain unchanged + assert km.execution_state == original_state + + @pytest.mark.asyncio + async def test_disconnect_without_client(self): + """Test disconnecting when no client exists.""" + km = NextGenKernelManager() + km.main_client = None + + # Should not raise an exception + await km.disconnect() + + @pytest.mark.asyncio + async def test_restart_kernel_sets_state(self): + """Test that restart_kernel sets restarting state.""" + km = NextGenKernelManager() + + with patch('jupyter_client.manager.AsyncKernelManager.restart_kernel') as mock_restart: + mock_restart.return_value = None + await km.restart_kernel() + + assert km.lifecycle_state == LifecycleStates.RESTARTING.value + mock_restart.assert_called_once() \ No newline at end of file diff --git a/jupyter_server_documents/tests/kernels/test_multi_kernel_manager.py b/jupyter_server_documents/tests/kernels/test_multi_kernel_manager.py new file mode 100644 index 0000000..6cfd872 --- /dev/null +++ b/jupyter_server_documents/tests/kernels/test_multi_kernel_manager.py @@ -0,0 +1,82 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from jupyter_server_documents.kernels.multi_kernel_manager import NextGenMappingKernelManager + + +@pytest.fixture +def multi_kernel_manager(): + """Create a NextGenMappingKernelManager instance for testing.""" + mkm = NextGenMappingKernelManager() + mkm._check_kernel_id = MagicMock() + mkm.pinned_superclass = MagicMock() + mkm.pinned_superclass._async_restart_kernel = AsyncMock() + return mkm + + +class TestNextGenMappingKernelManager: + """Test cases for NextGenMappingKernelManager.""" + + def test_start_watching_activity_noop(self, multi_kernel_manager): + """Test that start_watching_activity does nothing.""" + # Should not raise an exception + multi_kernel_manager.start_watching_activity("test-kernel-id") + + def test_stop_buffering_noop(self, multi_kernel_manager): + """Test that stop_buffering does nothing.""" + # Should not raise an exception + multi_kernel_manager.stop_buffering("test-kernel-id") + + @pytest.mark.asyncio + async def test_restart_kernel_checks_id(self, multi_kernel_manager): + """Test that restart_kernel checks kernel ID.""" + kernel_id = "test-kernel-id" + + await multi_kernel_manager.restart_kernel(kernel_id) + + multi_kernel_manager._check_kernel_id.assert_called_once_with(kernel_id) + + @pytest.mark.asyncio + async def test_restart_kernel_calls_superclass(self, multi_kernel_manager): + """Test that restart_kernel calls the superclass method.""" + kernel_id = "test-kernel-id" + + await multi_kernel_manager.restart_kernel(kernel_id, now=True) + + multi_kernel_manager.pinned_superclass._async_restart_kernel.assert_called_once_with( + multi_kernel_manager, kernel_id, now=True + ) + + @pytest.mark.asyncio + async def test_restart_kernel_default_now_parameter(self, multi_kernel_manager): + """Test that restart_kernel uses default now=False.""" + kernel_id = "test-kernel-id" + + await multi_kernel_manager.restart_kernel(kernel_id) + + multi_kernel_manager.pinned_superclass._async_restart_kernel.assert_called_once_with( + multi_kernel_manager, kernel_id, now=False + ) + + @pytest.mark.asyncio + async def test_restart_kernel_propagates_exceptions(self, multi_kernel_manager): + """Test that restart_kernel propagates exceptions from superclass.""" + kernel_id = "test-kernel-id" + test_exception = Exception("Test restart error") + multi_kernel_manager.pinned_superclass._async_restart_kernel.side_effect = test_exception + + with pytest.raises(Exception, match="Test restart error"): + await multi_kernel_manager.restart_kernel(kernel_id) + + @pytest.mark.asyncio + async def test_restart_kernel_propagates_id_check_exceptions(self, multi_kernel_manager): + """Test that restart_kernel propagates exceptions from kernel ID check.""" + kernel_id = "invalid-kernel-id" + test_exception = ValueError("Invalid kernel ID") + multi_kernel_manager._check_kernel_id.side_effect = test_exception + + with pytest.raises(ValueError, match="Invalid kernel ID"): + await multi_kernel_manager.restart_kernel(kernel_id) + + # Superclass method should not be called if ID check fails + multi_kernel_manager.pinned_superclass._async_restart_kernel.assert_not_called() \ No newline at end of file diff --git a/jupyter_server_documents/tests/kernels/test_states.py b/jupyter_server_documents/tests/kernels/test_states.py new file mode 100644 index 0000000..3ca60d3 --- /dev/null +++ b/jupyter_server_documents/tests/kernels/test_states.py @@ -0,0 +1,175 @@ +import pytest + +from jupyter_server_documents.kernels.states import LifecycleStates, ExecutionStates, StrContainerEnum, StrContainerEnumMeta + + +class TestStrContainerEnumMeta: + """Test cases for StrContainerEnumMeta.""" + + def test_contains_by_name(self): + """Test that enum names are found with 'in' operator.""" + assert "IDLE" in ExecutionStates + assert "STARTED" in LifecycleStates + + def test_contains_by_value(self): + """Test that enum values are found with 'in' operator.""" + assert "idle" in ExecutionStates + assert "started" in LifecycleStates + + def test_contains_missing(self): + """Test that missing items are not found.""" + assert "MISSING" not in ExecutionStates + assert "missing" not in LifecycleStates + + +class TestStrContainerEnum: + """Test cases for StrContainerEnum base class.""" + + def test_is_string_subclass(self): + """Test that StrContainerEnum is a string subclass.""" + assert issubclass(StrContainerEnum, str) + + def test_enum_value_is_string(self): + """Test that enum values can be used as strings.""" + idle_state = ExecutionStates.IDLE + assert isinstance(idle_state, str) + assert idle_state == "idle" + assert idle_state.upper() == "IDLE" + + +class TestLifecycleStates: + """Test cases for LifecycleStates enum.""" + + def test_all_states_defined(self): + """Test that all expected lifecycle states are defined.""" + expected_states = [ + "UNKNOWN", "STARTING", "STARTED", "TERMINATING", "CONNECTING", + "CONNECTED", "RESTARTING", "RECONNECTING", "CULLED", + "DISCONNECTED", "TERMINATED", "DEAD" + ] + + for state in expected_states: + assert hasattr(LifecycleStates, state) + + def test_state_values(self): + """Test that state values are lowercase versions of names.""" + assert LifecycleStates.UNKNOWN.value == "unknown" + assert LifecycleStates.STARTING.value == "starting" + assert LifecycleStates.STARTED.value == "started" + assert LifecycleStates.TERMINATING.value == "terminating" + assert LifecycleStates.CONNECTING.value == "connecting" + assert LifecycleStates.CONNECTED.value == "connected" + assert LifecycleStates.RESTARTING.value == "restarting" + assert LifecycleStates.RECONNECTING.value == "reconnecting" + assert LifecycleStates.CULLED.value == "culled" + assert LifecycleStates.DISCONNECTED.value == "disconnected" + assert LifecycleStates.TERMINATED.value == "terminated" + assert LifecycleStates.DEAD.value == "dead" + + def test_state_equality(self): + """Test that states can be compared by value.""" + assert LifecycleStates.UNKNOWN == "unknown" + assert LifecycleStates.STARTING == "starting" + assert LifecycleStates.CONNECTED == "connected" + + def test_state_membership(self): + """Test state membership using 'in' operator.""" + assert "starting" in LifecycleStates + assert "STARTING" in LifecycleStates + assert "connected" in LifecycleStates + assert "CONNECTED" in LifecycleStates + assert "invalid_state" not in LifecycleStates + + def test_state_iteration(self): + """Test iterating over lifecycle states.""" + states = list(LifecycleStates) + assert len(states) == 12 # Total number of defined states + assert LifecycleStates.UNKNOWN in states + assert LifecycleStates.DEAD in states + + +class TestExecutionStates: + """Test cases for ExecutionStates enum.""" + + def test_all_states_defined(self): + """Test that all expected execution states are defined.""" + expected_states = ["BUSY", "IDLE", "STARTING", "UNKNOWN", "DEAD"] + + for state in expected_states: + assert hasattr(ExecutionStates, state) + + def test_state_values(self): + """Test that state values are lowercase versions of names.""" + assert ExecutionStates.BUSY.value == "busy" + assert ExecutionStates.IDLE.value == "idle" + assert ExecutionStates.STARTING.value == "starting" + assert ExecutionStates.UNKNOWN.value == "unknown" + assert ExecutionStates.DEAD.value == "dead" + + def test_state_equality(self): + """Test that states can be compared by value.""" + assert ExecutionStates.BUSY == "busy" + assert ExecutionStates.IDLE == "idle" + assert ExecutionStates.STARTING == "starting" + assert ExecutionStates.UNKNOWN == "unknown" + assert ExecutionStates.DEAD == "dead" + + def test_state_membership(self): + """Test state membership using 'in' operator.""" + assert "busy" in ExecutionStates + assert "BUSY" in ExecutionStates + assert "idle" in ExecutionStates + assert "IDLE" in ExecutionStates + assert "invalid_state" not in ExecutionStates + + def test_state_iteration(self): + """Test iterating over execution states.""" + states = list(ExecutionStates) + assert len(states) == 5 # Total number of defined states + assert ExecutionStates.BUSY in states + assert ExecutionStates.IDLE in states + + def test_state_string_operations(self): + """Test that states can be used in string operations.""" + busy_state = ExecutionStates.BUSY + assert busy_state.upper() == "BUSY" + assert busy_state.capitalize() == "Busy" + assert len(busy_state) == 4 + assert busy_state.startswith("b") + + +class TestEnumIntegration: + """Integration tests for both enums.""" + + def test_enum_types_are_different(self): + """Test that the two enum types are distinct.""" + # Since both are StrContainerEnum subclasses, they compare as equal strings + # but they are different types + assert type(LifecycleStates.STARTING) != type(ExecutionStates.STARTING) + assert LifecycleStates.STARTING is not ExecutionStates.STARTING + + def test_enum_values_can_be_same(self): + """Test that enum values can be the same string.""" + # Both have "starting", "unknown", "dead" values + assert LifecycleStates.STARTING.value == ExecutionStates.STARTING.value == "starting" + assert LifecycleStates.UNKNOWN.value == ExecutionStates.UNKNOWN.value == "unknown" + assert LifecycleStates.DEAD.value == ExecutionStates.DEAD.value == "dead" + + def test_enum_members_are_unique_within_enum(self): + """Test that enum members are unique within their enum.""" + lifecycle_values = [state.value for state in LifecycleStates] + execution_values = [state.value for state in ExecutionStates] + + # Check for uniqueness within each enum + assert len(lifecycle_values) == len(set(lifecycle_values)) + assert len(execution_values) == len(set(execution_values)) + + def test_enum_membership_is_type_specific(self): + """Test that membership checks are type-specific.""" + # "idle" is in ExecutionStates but not in LifecycleStates + assert "idle" in ExecutionStates + assert "idle" not in LifecycleStates + + # "connected" is in LifecycleStates but not in ExecutionStates + assert "connected" in LifecycleStates + assert "connected" not in ExecutionStates \ No newline at end of file diff --git a/jupyter_server_documents/tests/kernels/test_websocket_connection.py b/jupyter_server_documents/tests/kernels/test_websocket_connection.py new file mode 100644 index 0000000..3e73536 --- /dev/null +++ b/jupyter_server_documents/tests/kernels/test_websocket_connection.py @@ -0,0 +1,124 @@ +import pytest +from unittest.mock import MagicMock, patch +from tornado.websocket import WebSocketClosedError + +from jupyter_server_documents.kernels.websocket_connection import NextGenKernelWebsocketConnection + + +class TestNextGenKernelWebsocketConnection: + """Test cases for NextGenKernelWebsocketConnection.""" + + def test_kernel_ws_protocol(self): + """Test that the websocket protocol is set correctly.""" + assert NextGenKernelWebsocketConnection.kernel_ws_protocol == "v1.kernel.websocket.jupyter.org" + + def test_inheritance(self): + """Test that the class inherits from BaseKernelWebsocketConnection.""" + from jupyter_server.services.kernels.connection.base import BaseKernelWebsocketConnection + + assert issubclass(NextGenKernelWebsocketConnection, BaseKernelWebsocketConnection) + + # Test that required methods are implemented + conn = NextGenKernelWebsocketConnection() + assert hasattr(conn, 'connect') + assert hasattr(conn, 'disconnect') + assert hasattr(conn, 'handle_incoming_message') + assert hasattr(conn, 'handle_outgoing_message') + assert hasattr(conn, 'kernel_ws_protocol') + + @patch('jupyter_server_documents.kernels.websocket_connection.deserialize_msg_from_ws_v1') + def test_handle_incoming_message_deserializes(self, mock_deserialize): + """Test that incoming messages are deserialized correctly.""" + conn = NextGenKernelWebsocketConnection() + + # Mock the kernel_manager property + mock_kernel_manager = MagicMock() + mock_kernel_manager.main_client = MagicMock() + + with patch.object(type(conn), 'kernel_manager', mock_kernel_manager): + mock_deserialize.return_value = ("shell", [b"test", b"message"]) + + incoming_msg = b"test_websocket_message" + conn.handle_incoming_message(incoming_msg) + + mock_deserialize.assert_called_once_with(incoming_msg) + + @patch('jupyter_server_documents.kernels.websocket_connection.deserialize_msg_from_ws_v1') + def test_handle_incoming_message_no_client(self, mock_deserialize): + """Test that incoming messages are ignored when no client exists.""" + conn = NextGenKernelWebsocketConnection() + + # Mock the kernel_manager property with no client + mock_kernel_manager = MagicMock() + mock_kernel_manager.main_client = None + + with patch.object(type(conn), 'kernel_manager', mock_kernel_manager): + mock_deserialize.return_value = ("shell", [b"test", b"message"]) + + incoming_msg = b"test_websocket_message" + + # Should not raise an exception + conn.handle_incoming_message(incoming_msg) + + @patch('jupyter_server_documents.kernels.websocket_connection.serialize_msg_to_ws_v1') + def test_handle_outgoing_message_removes_signature(self, mock_serialize): + """Test that the signature is properly removed from outgoing messages.""" + conn = NextGenKernelWebsocketConnection() + + # Mock websocket_handler and log to avoid traitlet validation + mock_handler = MagicMock() + mock_log = MagicMock() + + with patch.object(type(conn), 'websocket_handler', mock_handler): + with patch.object(type(conn), 'log', mock_log): + mock_serialize.return_value = b"serialized_message" + + # Message with signature at index 0 + msg = [b"signature", b"header", b"parent", b"metadata", b"content"] + conn.handle_outgoing_message("iopub", msg) + + # Should call serialize with msg[1:] (signature removed) + mock_serialize.assert_called_once_with( + [b"header", b"parent", b"metadata", b"content"], "iopub" + ) + + @patch('jupyter_server_documents.kernels.websocket_connection.serialize_msg_to_ws_v1') + def test_handle_outgoing_message_websocket_closed(self, mock_serialize): + """Test that closed websocket errors are handled gracefully.""" + conn = NextGenKernelWebsocketConnection() + + mock_serialize.return_value = b"serialized_message" + + # Mock websocket_handler to raise WebSocketClosedError + mock_handler = MagicMock() + mock_handler.write_message.side_effect = WebSocketClosedError() + mock_log = MagicMock() + + with patch.object(type(conn), 'websocket_handler', mock_handler): + with patch.object(type(conn), 'log', mock_log): + msg = [b"signature", b"header", b"parent", b"metadata", b"content"] + conn.handle_outgoing_message("iopub", msg) + + mock_log.warning.assert_called_once_with( + "A ZMQ message arrived on a closed websocket channel." + ) + + @patch('jupyter_server_documents.kernels.websocket_connection.serialize_msg_to_ws_v1') + def test_handle_outgoing_message_general_exception(self, mock_serialize): + """Test that general exceptions are handled gracefully.""" + conn = NextGenKernelWebsocketConnection() + + mock_serialize.return_value = b"serialized_message" + test_exception = Exception("Test exception") + + # Mock websocket_handler to raise exception + mock_handler = MagicMock() + mock_handler.write_message.side_effect = test_exception + mock_log = MagicMock() + + with patch.object(type(conn), 'websocket_handler', mock_handler): + with patch.object(type(conn), 'log', mock_log): + msg = [b"signature", b"header", b"parent", b"metadata", b"content"] + conn.handle_outgoing_message("iopub", msg) + + mock_log.error.assert_called_once_with(test_exception) \ No newline at end of file diff --git a/jupyter_server_documents/tests/test_yroom_file_api.py b/jupyter_server_documents/tests/test_yroom_file_api.py index 43e8795..a0a0dcc 100644 --- a/jupyter_server_documents/tests/test_yroom_file_api.py +++ b/jupyter_server_documents/tests/test_yroom_file_api.py @@ -104,12 +104,12 @@ def empty_yunicode() -> YUnicode: @pytest.mark.asyncio(loop_scope="module") async def test_load_plaintext_file( - plaintext_file_api: Awaitable[YRoomFileAPI], + plaintext_file_api: YRoomFileAPI, empty_yunicode: YUnicode, mock_plaintext_file: str, ): # Load content into JupyterYDoc - file_api = await plaintext_file_api + file_api = plaintext_file_api jupyter_ydoc = empty_yunicode file_api.load_content_into(jupyter_ydoc) await file_api.until_content_loaded diff --git a/jupyter_server_documents/ydocs.py b/jupyter_server_documents/ydocs.py index ae23f24..fac0234 100644 --- a/jupyter_server_documents/ydocs.py +++ b/jupyter_server_documents/ydocs.py @@ -74,6 +74,28 @@ def scan_cells(self, cell_id): break return cell_index, target_cell + def get_cell_list(self): + """Get a list of all cells in the notebook. + + Returns a list of pycrdt.Map objects representing the cells. + This method is used by the integration tests. + + :return: List of cells + :rtype: List[pycrdt.Map] + """ + return [self.ycells[i] for i in range(len(self.ycells))] + + def get_meta(self): + """Get the notebook metadata. + + Returns the full metadata structure including nbformat info and custom metadata. + This method is used by the integration tests. + + :return: The notebook metadata + :rtype: Dict + """ + return self._ymeta.to_py() + ydocs = {ep.name: ep.load() for ep in entry_points(group="jupyter_ydoc")}