diff --git a/openhands-agent-server/openhands/agent_server/event_router.py b/openhands-agent-server/openhands/agent_server/event_router.py index b0bf91ee8c..7637ed834a 100644 --- a/openhands-agent-server/openhands/agent_server/event_router.py +++ b/openhands-agent-server/openhands/agent_server/event_router.py @@ -38,20 +38,25 @@ def normalize_datetime_to_server_timezone(dt: datetime) -> datetime: """ - Normalize datetime to server timezone for consistent comparison. + Normalize datetime to server timezone for consistent comparison with events. - If the datetime has timezone info, convert to server native timezone. + Event timestamps are stored as naive datetimes in server local time. + This function ensures filter datetimes are also naive in server local time + so they can be compared correctly. + + If the datetime has timezone info, convert to server native timezone and + strip the tzinfo to make it naive. If it's naive (no timezone), assume it's already in server timezone. Args: dt: Input datetime (may be timezone-aware or naive) Returns: - Datetime in server native timezone (timezone-aware) + Naive datetime in server local time """ if dt.tzinfo is not None: - # Timezone-aware: convert to server native timezone - return dt.astimezone(None) + # Timezone-aware: convert to server native timezone, then make naive + return dt.astimezone(None).replace(tzinfo=None) else: # Naive datetime: assume it's already in server timezone return dt diff --git a/openhands-agent-server/openhands/agent_server/sockets.py b/openhands-agent-server/openhands/agent_server/sockets.py index 427432cef2..0f7682b345 100644 --- a/openhands-agent-server/openhands/agent_server/sockets.py +++ b/openhands-agent-server/openhands/agent_server/sockets.py @@ -9,7 +9,8 @@ import logging from dataclasses import dataclass -from typing import Annotated +from datetime import datetime +from typing import Annotated, Literal from uuid import UUID from fastapi import ( @@ -24,6 +25,7 @@ from openhands.agent_server.conversation_service import ( get_default_conversation_service, ) +from openhands.agent_server.event_router import normalize_datetime_to_server_timezone from openhands.agent_server.models import BashEventBase, ExecuteBashRequest from openhands.agent_server.pub_sub import Subscriber from openhands.sdk import Event, Message @@ -91,9 +93,54 @@ async def events_socket( conversation_id: UUID, websocket: WebSocket, session_api_key: Annotated[str | None, Query(alias="session_api_key")] = None, - resend_all: Annotated[bool, Query()] = False, + resend_mode: Annotated[ + Literal["all", "since"] | None, + Query( + description=( + "Mode for resending historical events on connect. " + "'all' sends all events, 'since' sends events after 'after_timestamp'." + ) + ), + ] = None, + after_timestamp: Annotated[ + datetime | None, + Query( + description=( + "Required when resend_mode='since'. Events with timestamp >= this " + "value will be sent. Accepts ISO 8601 format. Timezone-aware " + "datetimes are converted to server local time; naive datetimes " + "assumed in server timezone." + ) + ), + ] = None, + # Deprecated parameter - kept for backward compatibility + resend_all: Annotated[ + bool, + Query( + include_in_schema=False, + deprecated=True, + ), + ] = False, ): - """WebSocket endpoint for conversation events.""" + """WebSocket endpoint for conversation events. + + Args: + conversation_id: The conversation ID to subscribe to. + websocket: The WebSocket connection. + session_api_key: Optional API key for authentication. + resend_mode: Mode for resending historical events on connect. + - 'all': Resend all existing events + - 'since': Resend events after 'after_timestamp' (requires after_timestamp) + - None: Don't resend, just subscribe to new events + after_timestamp: Required when resend_mode='since'. Events with + timestamp >= this value will be sent. Timestamps are interpreted in + server local time. Timezone-aware datetimes are converted to server + timezone. Enables efficient bi-directional loading where REST fetches + historical events and WebSocket handles events after a specific point. + resend_all: DEPRECATED. Use resend_mode='all' instead. Kept for + backward compatibility - if True and resend_mode is None, behaves + as resend_mode='all'. + """ if not await _accept_authenticated_websocket(websocket, session_api_key): return @@ -108,12 +155,44 @@ async def events_socket( _WebSocketSubscriber(websocket) ) + # Determine effective resend mode (handle deprecated resend_all) + effective_mode = resend_mode + if effective_mode is None and resend_all: + logger.warning( + "resend_all is deprecated, use resend_mode='all' instead: " + f"{conversation_id}" + ) + effective_mode = "all" + + # Normalize timezone-aware datetimes to server timezone + normalized_after_timestamp = ( + normalize_datetime_to_server_timezone(after_timestamp) + if after_timestamp + else None + ) + try: - # Resend all existing events if requested - if resend_all: - logger.info(f"Resending events: {conversation_id}") + # Resend existing events based on mode + if effective_mode == "all": + logger.info(f"Resending all events: {conversation_id}") async for event in page_iterator(event_service.search_events): await _send_event(event, websocket) + elif effective_mode == "since": + if not normalized_after_timestamp: + logger.warning( + f"resend_mode='since' requires after_timestamp, " + f"no events will be resent: {conversation_id}" + ) + else: + logger.info( + f"Resending events since {normalized_after_timestamp}: " + f"{conversation_id}" + ) + async for event in page_iterator( + event_service.search_events, + timestamp__gte=normalized_after_timestamp, + ): + await _send_event(event, websocket) # Listen for messages over the socket while True: @@ -140,9 +219,34 @@ async def events_socket( async def bash_events_socket( websocket: WebSocket, session_api_key: Annotated[str | None, Query(alias="session_api_key")] = None, - resend_all: Annotated[bool, Query()] = False, + resend_mode: Annotated[ + Literal["all"] | None, + Query( + description=( + "Mode for resending historical events on connect. " + "'all' sends all events." + ) + ), + ] = None, + # Deprecated parameter - kept for backward compatibility + resend_all: Annotated[ + bool, + Query( + include_in_schema=False, + deprecated=True, + ), + ] = False, ): - """WebSocket endpoint for bash events.""" + """WebSocket endpoint for bash events. + + Args: + websocket: The WebSocket connection. + session_api_key: Optional API key for authentication. + resend_mode: Mode for resending historical events on connect. + - 'all': Resend all existing bash events + - None: Don't resend, just subscribe to new events + resend_all: DEPRECATED. Use resend_mode='all' instead. + """ if not await _accept_authenticated_websocket(websocket, session_api_key): return @@ -150,9 +254,16 @@ async def bash_events_socket( subscriber_id = await bash_event_service.subscribe_to_events( _BashWebSocketSubscriber(websocket) ) + + # Determine effective resend mode (handle deprecated resend_all) + effective_mode = resend_mode + if effective_mode is None and resend_all: + logger.warning("resend_all is deprecated, use resend_mode='all' instead") + effective_mode = "all" + try: # Resend all existing events if requested - if resend_all: + if effective_mode == "all": logger.info("Resending bash events") async for event in page_iterator(bash_event_service.search_bash_events): await _send_bash_event(event, websocket) diff --git a/tests/agent_server/test_event_router.py b/tests/agent_server/test_event_router.py index eb4630e573..c3cab3158f 100644 --- a/tests/agent_server/test_event_router.py +++ b/tests/agent_server/test_event_router.py @@ -1,5 +1,6 @@ """Tests for event_router.py endpoints.""" +from datetime import UTC, datetime, timedelta, timezone from pathlib import Path from typing import cast from unittest.mock import AsyncMock, MagicMock @@ -10,7 +11,10 @@ from fastapi.testclient import TestClient from openhands.agent_server.dependencies import get_event_service -from openhands.agent_server.event_router import event_router +from openhands.agent_server.event_router import ( + event_router, + normalize_datetime_to_server_timezone, +) from openhands.agent_server.event_service import EventService from openhands.agent_server.models import SendMessageRequest from openhands.sdk import Message @@ -18,6 +22,45 @@ from openhands.sdk.llm.message import ImageContent, TextContent +def test_normalize_datetime_naive_passthrough(): + """Naive datetimes should be returned unchanged.""" + naive_dt = datetime(2025, 1, 15, 10, 30, 0) + result = normalize_datetime_to_server_timezone(naive_dt) + + assert result == naive_dt + assert result.tzinfo is None + + +def test_normalize_datetime_utc_converted_to_naive(): + """UTC datetime should be converted to server local time and made naive.""" + utc_dt = datetime(2025, 1, 15, 10, 30, 0, tzinfo=UTC) + result = normalize_datetime_to_server_timezone(utc_dt) + + assert result.tzinfo is None + expected = utc_dt.astimezone(None).replace(tzinfo=None) + assert result == expected + + +def test_normalize_datetime_preserves_microseconds(): + """Microseconds should be preserved through conversion.""" + utc_dt = datetime(2025, 1, 15, 10, 30, 0, 123456, tzinfo=UTC) + result = normalize_datetime_to_server_timezone(utc_dt) + + assert result.microsecond == 123456 + + +def test_normalize_datetime_fixed_offset_timezone(): + """Test with a specific fixed offset timezone (UTC+5:30).""" + ist = timezone(timedelta(hours=5, minutes=30)) + ist_dt = datetime(2025, 1, 15, 16, 0, 0, tzinfo=ist) + + result = normalize_datetime_to_server_timezone(ist_dt) + + assert result.tzinfo is None + expected = ist_dt.astimezone(None).replace(tzinfo=None) + assert result == expected + + @pytest.fixture def client(): """Create a test client for the FastAPI app without authentication.""" diff --git a/tests/agent_server/test_event_router_websocket.py b/tests/agent_server/test_event_router_websocket.py index eb24572dbe..636df193e5 100644 --- a/tests/agent_server/test_event_router_websocket.py +++ b/tests/agent_server/test_event_router_websocket.py @@ -1,5 +1,7 @@ """Tests for websocket functionality in event_router.py""" +from datetime import UTC, datetime +from typing import cast from unittest.mock import AsyncMock, MagicMock, patch from uuid import uuid4 @@ -7,8 +9,10 @@ from fastapi import WebSocketDisconnect from openhands.agent_server.event_service import EventService +from openhands.agent_server.models import EventPage from openhands.agent_server.sockets import _WebSocketSubscriber from openhands.sdk import Message +from openhands.sdk.event import Event from openhands.sdk.event.llm_convertible import MessageEvent from openhands.sdk.llm.message import TextContent @@ -42,358 +46,492 @@ def sample_conversation_id(): return uuid4() -class TestWebSocketSubscriber: - """Test cases for _WebSocketSubscriber class.""" +@pytest.mark.asyncio +async def test_websocket_subscriber_call_success(mock_websocket): + """Test successful event sending through WebSocket subscriber.""" + subscriber = _WebSocketSubscriber(websocket=mock_websocket) + event = MessageEvent( + id="test_event", + source="user", + llm_message=Message(role="user", content=[TextContent(text="test")]), + ) + + await subscriber(event) + + mock_websocket.send_json.assert_called_once() + call_args = mock_websocket.send_json.call_args[0][0] + assert call_args["id"] == "test_event" + + +@pytest.mark.asyncio +async def test_websocket_subscriber_call_exception(mock_websocket): + """Test exception handling in WebSocket subscriber.""" + mock_websocket.send_json.side_effect = Exception("Connection error") + subscriber = _WebSocketSubscriber(websocket=mock_websocket) + event = MessageEvent( + id="test_event", + source="user", + llm_message=Message(role="user", content=[TextContent(text="test")]), + ) + + # Should not raise exception, just log it + await subscriber(event) + + mock_websocket.send_json.assert_called_once() + + +@pytest.mark.asyncio +async def test_websocket_disconnect_breaks_loop( + mock_websocket, mock_event_service, sample_conversation_id +): + """Test that WebSocketDisconnect exception breaks the loop.""" + mock_websocket.receive_json.side_effect = WebSocketDisconnect() + + with ( + patch( + "openhands.agent_server.sockets.conversation_service" + ) as mock_conv_service, + patch("openhands.agent_server.sockets.get_default_config") as mock_config, + ): + mock_config.return_value.session_api_keys = None + mock_conv_service.get_event_service = AsyncMock(return_value=mock_event_service) - @pytest.mark.asyncio - async def test_websocket_subscriber_call_success(self, mock_websocket): - """Test successful event sending through WebSocket subscriber.""" - subscriber = _WebSocketSubscriber(websocket=mock_websocket) - event = MessageEvent( - id="test_event", - source="user", - llm_message=Message(role="user", content=[TextContent(text="test")]), + from openhands.agent_server.sockets import events_socket + + await events_socket( + sample_conversation_id, mock_websocket, session_api_key=None ) - await subscriber(event) + mock_event_service.unsubscribe_from_events.assert_called() - mock_websocket.send_json.assert_called_once() - call_args = mock_websocket.send_json.call_args[0][0] - assert call_args["id"] == "test_event" - @pytest.mark.asyncio - async def test_websocket_subscriber_call_exception(self, mock_websocket): - """Test exception handling in WebSocket subscriber.""" - mock_websocket.send_json.side_effect = Exception("Connection error") - subscriber = _WebSocketSubscriber(websocket=mock_websocket) - event = MessageEvent( - id="test_event", - source="user", - llm_message=Message(role="user", content=[TextContent(text="test")]), +@pytest.mark.asyncio +async def test_websocket_no_double_unsubscription( + mock_websocket, mock_event_service, sample_conversation_id +): + """Test that unsubscription only happens once even with disconnect.""" + subscriber_id = uuid4() + mock_event_service.subscribe_to_events.return_value = subscriber_id + mock_websocket.receive_json.side_effect = WebSocketDisconnect() + + with ( + patch( + "openhands.agent_server.sockets.conversation_service" + ) as mock_conv_service, + patch("openhands.agent_server.sockets.get_default_config") as mock_config, + ): + mock_config.return_value.session_api_keys = None + mock_conv_service.get_event_service = AsyncMock(return_value=mock_event_service) + + from openhands.agent_server.sockets import events_socket + + await events_socket( + sample_conversation_id, mock_websocket, session_api_key=None ) - # Should not raise exception, just log it - await subscriber(event) + assert mock_event_service.unsubscribe_from_events.call_count == 1 + mock_event_service.unsubscribe_from_events.assert_called_with(subscriber_id) + - mock_websocket.send_json.assert_called_once() +@pytest.mark.asyncio +async def test_websocket_general_exception_continues_loop( + mock_websocket, mock_event_service, sample_conversation_id +): + """Test that general exceptions don't break the loop immediately.""" + call_count = 0 + def side_effect(): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ValueError("Some error") + elif call_count == 2: + raise WebSocketDisconnect() -class TestWebSocketDisconnectHandling: - """Test cases for WebSocket disconnect handling in the socket endpoint.""" + mock_websocket.receive_json.side_effect = side_effect - @pytest.mark.asyncio - async def test_websocket_disconnect_breaks_loop( - self, mock_websocket, mock_event_service, sample_conversation_id + with ( + patch( + "openhands.agent_server.sockets.conversation_service" + ) as mock_conv_service, + patch("openhands.agent_server.sockets.get_default_config") as mock_config, ): - """Test that WebSocketDisconnect exception breaks the loop.""" - # Setup mock to raise WebSocketDisconnect on first receive_json call - mock_websocket.receive_json.side_effect = WebSocketDisconnect() - - with ( - patch( - "openhands.agent_server.sockets.conversation_service" - ) as mock_conv_service, - patch("openhands.agent_server.sockets.get_default_config") as mock_config, - ): - # Mock config to not require authentication - mock_config.return_value.session_api_keys = None - mock_conv_service.get_event_service = AsyncMock( - return_value=mock_event_service - ) + mock_config.return_value.session_api_keys = None + mock_conv_service.get_event_service = AsyncMock(return_value=mock_event_service) - # Import and call the socket function directly - from openhands.agent_server.sockets import events_socket + from openhands.agent_server.sockets import events_socket + + await events_socket( + sample_conversation_id, mock_websocket, session_api_key=None + ) + + assert mock_websocket.receive_json.call_count == 2 + mock_event_service.unsubscribe_from_events.assert_called_once() + + +@pytest.mark.asyncio +async def test_websocket_successful_message_processing( + mock_websocket, mock_event_service, sample_conversation_id +): + """Test successful message processing before disconnect.""" + message_data = {"role": "user", "content": "Hello"} + call_count = 0 + + def side_effect(): + nonlocal call_count + call_count += 1 + if call_count == 1: + return message_data + else: + raise WebSocketDisconnect() + + mock_websocket.receive_json.side_effect = side_effect + + with ( + patch( + "openhands.agent_server.sockets.conversation_service" + ) as mock_conv_service, + patch("openhands.agent_server.sockets.get_default_config") as mock_config, + ): + mock_config.return_value.session_api_keys = None + mock_conv_service.get_event_service = AsyncMock(return_value=mock_event_service) + + from openhands.agent_server.sockets import events_socket + + await events_socket( + sample_conversation_id, mock_websocket, session_api_key=None + ) + + mock_event_service.send_message.assert_called_once() + assert mock_websocket.receive_json.call_count == 2 - # This should not hang or loop infinitely - await events_socket( - sample_conversation_id, mock_websocket, session_api_key=None - ) - # Verify that unsubscribe was called - mock_event_service.unsubscribe_from_events.assert_called() +@pytest.mark.asyncio +async def test_websocket_unsubscribe_in_finally_when_no_disconnect( + mock_websocket, mock_event_service, sample_conversation_id +): + """Test that unsubscribe is called even when there's no WebSocketDisconnect.""" + mock_websocket.receive_json.side_effect = RuntimeError("Connection broken") - @pytest.mark.asyncio - async def test_websocket_no_double_unsubscription( - self, mock_websocket, mock_event_service, sample_conversation_id + with ( + patch( + "openhands.agent_server.sockets.conversation_service" + ) as mock_conv_service, + patch("openhands.agent_server.sockets.get_default_config") as mock_config, ): - """Test that unsubscription only happens once even with disconnect.""" - subscriber_id = uuid4() - mock_event_service.subscribe_to_events.return_value = subscriber_id - mock_websocket.receive_json.side_effect = WebSocketDisconnect() - - with ( - patch( - "openhands.agent_server.sockets.conversation_service" - ) as mock_conv_service, - patch("openhands.agent_server.sockets.get_default_config") as mock_config, - ): - # Mock config to not require authentication - mock_config.return_value.session_api_keys = None - mock_conv_service.get_event_service = AsyncMock( - return_value=mock_event_service - ) + mock_config.return_value.session_api_keys = None + mock_conv_service.get_event_service = AsyncMock(return_value=mock_event_service) - from openhands.agent_server.sockets import events_socket + from openhands.agent_server.sockets import events_socket + with pytest.raises(RuntimeError): await events_socket( sample_conversation_id, mock_websocket, session_api_key=None ) - # Should be called exactly once (not in both except and finally blocks) - assert mock_event_service.unsubscribe_from_events.call_count == 1 - mock_event_service.unsubscribe_from_events.assert_called_with(subscriber_id) + mock_event_service.unsubscribe_from_events.assert_called_once() + + +@pytest.mark.asyncio +async def test_resend_mode_none_no_resend( + mock_websocket, mock_event_service, sample_conversation_id +): + """Test that resend_mode=None doesn't trigger event resend.""" + mock_websocket.receive_json.side_effect = WebSocketDisconnect() - @pytest.mark.asyncio - async def test_websocket_general_exception_continues_loop( - self, mock_websocket, mock_event_service, sample_conversation_id + with ( + patch( + "openhands.agent_server.sockets.conversation_service" + ) as mock_conv_service, + patch("openhands.agent_server.sockets.get_default_config") as mock_config, ): - """Test that general exceptions don't break the loop immediately.""" - call_count = 0 - - def side_effect(): - nonlocal call_count - call_count += 1 - if call_count == 1: - raise ValueError("Some error") - elif call_count == 2: - raise WebSocketDisconnect() # This should break the loop - - mock_websocket.receive_json.side_effect = side_effect - - with ( - patch( - "openhands.agent_server.sockets.conversation_service" - ) as mock_conv_service, - patch("openhands.agent_server.sockets.get_default_config") as mock_config, - ): - # Mock config to not require authentication - mock_config.return_value.session_api_keys = None - mock_conv_service.get_event_service = AsyncMock( - return_value=mock_event_service - ) + mock_config.return_value.session_api_keys = None + mock_conv_service.get_event_service = AsyncMock(return_value=mock_event_service) - from openhands.agent_server.sockets import events_socket + from openhands.agent_server.sockets import events_socket - await events_socket( - sample_conversation_id, mock_websocket, session_api_key=None - ) + await events_socket( + sample_conversation_id, + mock_websocket, + session_api_key=None, + resend_mode=None, + ) - # Should have been called twice (once for ValueError, once for disconnect) - assert mock_websocket.receive_json.call_count == 2 - mock_event_service.unsubscribe_from_events.assert_called_once() + mock_event_service.search_events.assert_not_called() - @pytest.mark.asyncio - async def test_websocket_successful_message_processing( - self, mock_websocket, mock_event_service, sample_conversation_id + +@pytest.mark.asyncio +async def test_resend_mode_all_resends_events( + mock_websocket, mock_event_service, sample_conversation_id +): + """Test that resend_mode='all' resends all existing events.""" + mock_events = [ + MessageEvent( + id="event1", + source="user", + llm_message=Message(role="user", content=[TextContent(text="Hello")]), + ), + MessageEvent( + id="event2", + source="agent", + llm_message=Message(role="assistant", content=[TextContent(text="Hi")]), + ), + ] + mock_event_page = EventPage(items=cast(list[Event], mock_events), next_page_id=None) + mock_event_service.search_events = AsyncMock(return_value=mock_event_page) + mock_websocket.receive_json.side_effect = WebSocketDisconnect() + + with ( + patch( + "openhands.agent_server.sockets.conversation_service" + ) as mock_conv_service, + patch("openhands.agent_server.sockets.get_default_config") as mock_config, ): - """Test successful message processing before disconnect.""" - message_data = {"role": "user", "content": "Hello"} - call_count = 0 - - def side_effect(): - nonlocal call_count - call_count += 1 - if call_count == 1: - return message_data - else: - raise WebSocketDisconnect() - - mock_websocket.receive_json.side_effect = side_effect - - with ( - patch( - "openhands.agent_server.sockets.conversation_service" - ) as mock_conv_service, - patch("openhands.agent_server.sockets.get_default_config") as mock_config, - ): - # Mock config to not require authentication - mock_config.return_value.session_api_keys = None - mock_conv_service.get_event_service = AsyncMock( - return_value=mock_event_service - ) + mock_config.return_value.session_api_keys = None + mock_conv_service.get_event_service = AsyncMock(return_value=mock_event_service) - from openhands.agent_server.sockets import events_socket + from openhands.agent_server.sockets import events_socket - await events_socket( - sample_conversation_id, mock_websocket, session_api_key=None - ) + await events_socket( + sample_conversation_id, + mock_websocket, + session_api_key=None, + resend_mode="all", + ) - # Should have processed the message - mock_event_service.send_message.assert_called_once() - args, kwargs = mock_event_service.send_message.call_args - message = args[0] - assert message.role == "user" - assert len(message.content) == 1 - assert message.content[0].text == "Hello" - # send_message only takes a message parameter, no run parameter - - @pytest.mark.asyncio - async def test_websocket_unsubscribe_in_finally_when_no_disconnect( - self, mock_websocket, mock_event_service, sample_conversation_id + mock_event_service.search_events.assert_called_once_with(page_id=None) + assert mock_websocket.send_json.call_count == 2 + sent_events = [call[0][0] for call in mock_websocket.send_json.call_args_list] + assert sent_events[0]["id"] == "event1" + assert sent_events[1]["id"] == "event2" + + +@pytest.mark.asyncio +async def test_resend_mode_since_with_timestamp( + mock_websocket, mock_event_service, sample_conversation_id +): + """Test that resend_mode='since' with after_timestamp filters events.""" + mock_events = [ + MessageEvent( + id="event1", + source="user", + llm_message=Message(role="user", content=[TextContent(text="Hello")]), + ), + ] + mock_event_page = EventPage(items=cast(list[Event], mock_events), next_page_id=None) + mock_event_service.search_events = AsyncMock(return_value=mock_event_page) + mock_websocket.receive_json.side_effect = WebSocketDisconnect() + + # Use a naive timestamp + test_timestamp = datetime(2024, 1, 15, 10, 30, 0) + + with ( + patch( + "openhands.agent_server.sockets.conversation_service" + ) as mock_conv_service, + patch("openhands.agent_server.sockets.get_default_config") as mock_config, ): - """Test that unsubscription happens in finally block when no disconnect.""" - # Simulate a different kind of exception that doesn't trigger disconnect handler - mock_websocket.receive_json.side_effect = RuntimeError("Unexpected error") - - with ( - patch( - "openhands.agent_server.sockets.conversation_service" - ) as mock_conv_service, - patch("openhands.agent_server.sockets.get_default_config") as mock_config, - ): - # Mock config to not require authentication - mock_config.return_value.session_api_keys = None - mock_conv_service.get_event_service = AsyncMock( - return_value=mock_event_service - ) + mock_config.return_value.session_api_keys = None + mock_conv_service.get_event_service = AsyncMock(return_value=mock_event_service) - from openhands.agent_server.sockets import events_socket + from openhands.agent_server.sockets import events_socket - # This should raise the RuntimeError but still clean up - with pytest.raises(RuntimeError): - await events_socket( - sample_conversation_id, mock_websocket, session_api_key=None - ) + await events_socket( + sample_conversation_id, + mock_websocket, + session_api_key=None, + resend_mode="since", + after_timestamp=test_timestamp, + ) - # Should still unsubscribe in the finally block - mock_event_service.unsubscribe_from_events.assert_called_once() + mock_event_service.search_events.assert_called_once_with( + page_id=None, timestamp__gte=test_timestamp + ) -class TestResendAllFunctionality: - """Test cases for resend_all parameter functionality.""" +@pytest.mark.asyncio +async def test_resend_mode_since_without_timestamp_logs_warning( + mock_websocket, mock_event_service, sample_conversation_id +): + """Test that resend_mode='since' without after_timestamp logs warning.""" + mock_websocket.receive_json.side_effect = WebSocketDisconnect() - @pytest.mark.asyncio - async def test_resend_all_false_no_resend( - self, mock_websocket, mock_event_service, sample_conversation_id + with ( + patch( + "openhands.agent_server.sockets.conversation_service" + ) as mock_conv_service, + patch("openhands.agent_server.sockets.get_default_config") as mock_config, + patch("openhands.agent_server.sockets.logger") as mock_logger, ): - """Test that resend_all=False doesn't trigger event resend.""" - mock_websocket.receive_json.side_effect = WebSocketDisconnect() - - with ( - patch( - "openhands.agent_server.sockets.conversation_service" - ) as mock_conv_service, - patch("openhands.agent_server.sockets.get_default_config") as mock_config, - ): - mock_config.return_value.session_api_keys = None - mock_conv_service.get_event_service = AsyncMock( - return_value=mock_event_service - ) + mock_config.return_value.session_api_keys = None + mock_conv_service.get_event_service = AsyncMock(return_value=mock_event_service) - from openhands.agent_server.sockets import events_socket + from openhands.agent_server.sockets import events_socket - await events_socket( - sample_conversation_id, - mock_websocket, - session_api_key=None, - resend_all=False, - ) - - # search_events should not be called when not resending - mock_event_service.search_events.assert_not_called() + await events_socket( + sample_conversation_id, + mock_websocket, + session_api_key=None, + resend_mode="since", + after_timestamp=None, + ) - @pytest.mark.asyncio - async def test_resend_all_true_resends_events( - self, mock_websocket, mock_event_service, sample_conversation_id + # Should log a warning and not call search_events + mock_logger.warning.assert_called() + warning_call = str(mock_logger.warning.call_args) + assert "resend_mode='since' requires after_timestamp" in warning_call + mock_event_service.search_events.assert_not_called() + + +@pytest.mark.asyncio +async def test_resend_mode_since_timezone_aware_is_normalized( + mock_websocket, mock_event_service, sample_conversation_id +): + """Test that timezone-aware timestamps are normalized to naive server time.""" + mock_events = [ + MessageEvent( + id="event1", + source="user", + llm_message=Message(role="user", content=[TextContent(text="Hello")]), + ), + ] + mock_event_page = EventPage(items=cast(list[Event], mock_events), next_page_id=None) + mock_event_service.search_events = AsyncMock(return_value=mock_event_page) + mock_websocket.receive_json.side_effect = WebSocketDisconnect() + + # Use a timezone-aware timestamp (UTC) + test_timestamp = datetime(2024, 1, 15, 10, 30, 0, tzinfo=UTC) + + with ( + patch( + "openhands.agent_server.sockets.conversation_service" + ) as mock_conv_service, + patch("openhands.agent_server.sockets.get_default_config") as mock_config, ): - """Test that resend_all=True resends all existing events.""" - # Create mock events to resend - mock_events = [ - MessageEvent( - id="event1", - source="user", - llm_message=Message(role="user", content=[TextContent(text="Hello")]), - ), - MessageEvent( - id="event2", - source="agent", - llm_message=Message(role="assistant", content=[TextContent(text="Hi")]), - ), - ] + mock_config.return_value.session_api_keys = None + mock_conv_service.get_event_service = AsyncMock(return_value=mock_event_service) - from typing import cast + from openhands.agent_server.sockets import events_socket - from openhands.agent_server.models import EventPage - from openhands.sdk.event import Event + await events_socket( + sample_conversation_id, + mock_websocket, + session_api_key=None, + resend_mode="since", + after_timestamp=test_timestamp, + ) - mock_event_page = EventPage( - items=cast(list[Event], mock_events), next_page_id=None + # search_events should be called with the normalized timestamp + mock_event_service.search_events.assert_called_once() + call_args = mock_event_service.search_events.call_args + passed_timestamp = call_args.kwargs["timestamp__gte"] + # The timestamp should be naive (no tzinfo) + assert passed_timestamp is not None + assert passed_timestamp.tzinfo is None + # It should represent the same instant in time (converted to local) + expected = test_timestamp.astimezone(None).replace(tzinfo=None) + assert passed_timestamp == expected + + +# Backward compatibility tests for deprecated resend_all parameter + + +@pytest.mark.asyncio +async def test_deprecated_resend_all_true_still_works( + mock_websocket, mock_event_service, sample_conversation_id +): + """Test backward compatibility: resend_all=True still resends all events.""" + mock_events = [ + MessageEvent( + id="event1", + source="user", + llm_message=Message(role="user", content=[TextContent(text="Hello")]), + ), + ] + mock_event_page = EventPage(items=cast(list[Event], mock_events), next_page_id=None) + mock_event_service.search_events = AsyncMock(return_value=mock_event_page) + mock_websocket.receive_json.side_effect = WebSocketDisconnect() + + with ( + patch( + "openhands.agent_server.sockets.conversation_service" + ) as mock_conv_service, + patch("openhands.agent_server.sockets.get_default_config") as mock_config, + patch("openhands.agent_server.sockets.logger") as mock_logger, + ): + mock_config.return_value.session_api_keys = None + mock_conv_service.get_event_service = AsyncMock(return_value=mock_event_service) + + from openhands.agent_server.sockets import events_socket + + await events_socket( + sample_conversation_id, + mock_websocket, + session_api_key=None, + resend_all=True, ) - mock_event_service.search_events = AsyncMock(return_value=mock_event_page) - mock_websocket.receive_json.side_effect = WebSocketDisconnect() - - with ( - patch( - "openhands.agent_server.sockets.conversation_service" - ) as mock_conv_service, - patch("openhands.agent_server.sockets.get_default_config") as mock_config, - ): - mock_config.return_value.session_api_keys = None - mock_conv_service.get_event_service = AsyncMock( - return_value=mock_event_service - ) - from openhands.agent_server.sockets import events_socket + # Should log deprecation warning + mock_logger.warning.assert_called() + warning_call = str(mock_logger.warning.call_args) + assert "resend_all is deprecated" in warning_call - await events_socket( - sample_conversation_id, - mock_websocket, - session_api_key=None, - resend_all=True, - ) + # But still function correctly + mock_event_service.search_events.assert_called_once_with(page_id=None) + assert mock_websocket.send_json.call_count == 1 - # search_events should be called to get all events - mock_event_service.search_events.assert_called_once_with(page_id=None) - # All events should be sent through websocket - assert mock_websocket.send_json.call_count == 2 - sent_events = [call[0][0] for call in mock_websocket.send_json.call_args_list] - assert sent_events[0]["id"] == "event1" - assert sent_events[1]["id"] == "event2" +@pytest.mark.asyncio +async def test_deprecated_resend_all_false_no_resend( + mock_websocket, mock_event_service, sample_conversation_id +): + """Test backward compatibility: resend_all=False doesn't trigger event resend.""" + mock_websocket.receive_json.side_effect = WebSocketDisconnect() - @pytest.mark.asyncio - async def test_resend_all_handles_search_events_exception( - self, mock_websocket, mock_event_service, sample_conversation_id + with ( + patch( + "openhands.agent_server.sockets.conversation_service" + ) as mock_conv_service, + patch("openhands.agent_server.sockets.get_default_config") as mock_config, ): - """Test that exceptions during search_events cause the WebSocket to fail.""" - mock_event_service.search_events = AsyncMock( - side_effect=Exception("Search failed") + mock_config.return_value.session_api_keys = None + mock_conv_service.get_event_service = AsyncMock(return_value=mock_event_service) + + from openhands.agent_server.sockets import events_socket + + await events_socket( + sample_conversation_id, + mock_websocket, + session_api_key=None, + resend_all=False, ) - with ( - patch( - "openhands.agent_server.sockets.conversation_service" - ) as mock_conv_service, - patch("openhands.agent_server.sockets.get_default_config") as mock_config, - ): - mock_config.return_value.session_api_keys = None - mock_conv_service.get_event_service = AsyncMock( - return_value=mock_event_service - ) + mock_event_service.search_events.assert_not_called() + - from openhands.agent_server.sockets import events_socket - - # Should raise the exception from search_events - with pytest.raises(Exception, match="Search failed"): - await events_socket( - sample_conversation_id, - mock_websocket, - session_api_key=None, - resend_all=True, - ) - - # search_events should be called - mock_event_service.search_events.assert_called_once() - # WebSocket should be subscribed but then unsubscribed due to exception - mock_event_service.subscribe_to_events.assert_called_once() - mock_event_service.unsubscribe_from_events.assert_called_once() - - @pytest.mark.asyncio - async def test_resend_all_handles_send_json_exception( - self, mock_websocket, mock_event_service, sample_conversation_id +@pytest.mark.asyncio +async def test_resend_mode_takes_precedence_over_resend_all( + mock_websocket, mock_event_service, sample_conversation_id +): + """Test that resend_mode takes precedence over deprecated resend_all.""" + mock_websocket.receive_json.side_effect = WebSocketDisconnect() + + with ( + patch( + "openhands.agent_server.sockets.conversation_service" + ) as mock_conv_service, + patch("openhands.agent_server.sockets.get_default_config") as mock_config, + patch("openhands.agent_server.sockets.logger") as mock_logger, ): - """Test that exceptions during send_json are handled gracefully.""" - # Create mock events to resend + mock_config.return_value.session_api_keys = None + mock_conv_service.get_event_service = AsyncMock(return_value=mock_event_service) + + from openhands.agent_server.sockets import events_socket + + # If resend_mode is explicitly None and resend_all=True, it should + # fallback to resend_all behavior for backward compat. But if + # resend_mode is set, it takes precedence over resend_all. + # Let's test with resend_mode="all" and resend_all=False mock_events = [ MessageEvent( id="event1", @@ -401,46 +539,21 @@ async def test_resend_all_handles_send_json_exception( llm_message=Message(role="user", content=[TextContent(text="Hello")]), ), ] - - from typing import cast - - from openhands.agent_server.models import EventPage - from openhands.sdk.event import Event - mock_event_page = EventPage( items=cast(list[Event], mock_events), next_page_id=None ) mock_event_service.search_events = AsyncMock(return_value=mock_event_page) - # Make send_json fail during resend - mock_websocket.send_json.side_effect = Exception("Send failed") - mock_websocket.receive_json.side_effect = WebSocketDisconnect() - - with ( - patch( - "openhands.agent_server.sockets.conversation_service" - ) as mock_conv_service, - patch("openhands.agent_server.sockets.get_default_config") as mock_config, - ): - mock_config.return_value.session_api_keys = None - mock_conv_service.get_event_service = AsyncMock( - return_value=mock_event_service - ) - - from openhands.agent_server.sockets import events_socket - - # Should not raise exception, should handle gracefully - await events_socket( - sample_conversation_id, - mock_websocket, - session_api_key=None, - resend_all=True, - ) + await events_socket( + sample_conversation_id, + mock_websocket, + session_api_key=None, + resend_mode="all", + resend_all=False, # This should be ignored since resend_mode is set + ) - # search_events should be called - mock_event_service.search_events.assert_called_once() - # send_json should be called (and fail) - mock_websocket.send_json.assert_called_once() - # WebSocket should still be subscribed and unsubscribed normally - mock_event_service.subscribe_to_events.assert_called_once() - mock_event_service.unsubscribe_from_events.assert_called_once() + # resend_mode="all" should trigger resend, not the resend_all=False + mock_event_service.search_events.assert_called_once() + # No deprecation warning since we're using the new API + warning_calls = [str(c) for c in mock_logger.warning.call_args_list] + assert not any("resend_all is deprecated" in w for w in warning_calls) diff --git a/tests/agent_server/test_event_service.py b/tests/agent_server/test_event_service.py index a861d34edc..3ced94c964 100644 --- a/tests/agent_server/test_event_service.py +++ b/tests/agent_server/test_event_service.py @@ -399,11 +399,17 @@ async def test_search_events_timestamp_range_filter( async def test_search_events_timestamp_filter_with_timezone_aware( self, event_service, mock_conversation_with_timestamped_events ): - """Test filtering events with timezone-aware datetime.""" + """Test filtering events with timezone-aware datetime requires normalization. + + Event timestamps are naive (server local time), so callers must normalize + timezone-aware datetimes to naive before filtering. This is done by the + REST/WebSocket API layer via normalize_datetime_to_server_timezone(). + """ event_service._conversation = mock_conversation_with_timestamped_events - # Filter events >= 12:00:00 UTC (should return events 3, 4, 5) - filter_time = datetime(2025, 1, 1, 12, 0, 0, tzinfo=UTC) + # Filter events >= 12:00:00 (naive, as if normalized by API layer) + # The API layer would convert a tz-aware datetime to naive server time + filter_time = datetime(2025, 1, 1, 12, 0, 0) # naive datetime result = await event_service.search_events(timestamp__gte=filter_time) assert len(result.items) == 3 @@ -541,11 +547,16 @@ async def test_count_events_timestamp_range_filter( async def test_count_events_timestamp_filter_with_timezone_aware( self, event_service, mock_conversation_with_timestamped_events ): - """Test counting events with timezone-aware datetime.""" + """Test counting events with timezone-aware datetime requires normalization. + + Event timestamps are naive (server local time), so callers must normalize + timezone-aware datetimes to naive before filtering. This is done by the + REST/WebSocket API layer via normalize_datetime_to_server_timezone(). + """ event_service._conversation = mock_conversation_with_timestamped_events - # Count events >= 12:00:00 UTC (should return 3) - filter_time = datetime(2025, 1, 1, 12, 0, 0, tzinfo=UTC) + # Count events >= 12:00:00 (naive, as if normalized by API layer) + filter_time = datetime(2025, 1, 1, 12, 0, 0) # naive datetime result = await event_service.count_events(timestamp__gte=filter_time) assert result == 3