diff --git a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py index 5d45f50a7b..fe3513ae58 100644 --- a/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py +++ b/pydantic_ai_slim/pydantic_ai/ui/ag_ui/_adapter.py @@ -2,11 +2,12 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from functools import cached_property from typing import ( TYPE_CHECKING, Any, + cast, ) from ... import ExternalToolset, ToolDefinition @@ -107,7 +108,14 @@ def toolset(self) -> AbstractToolset[AgentDepsT] | None: @cached_property def state(self) -> dict[str, Any] | None: """Frontend state from the AG-UI run input.""" - return self.run_input.state + state = self.run_input.state + if state is None: + return None + + if isinstance(state, Mapping) and not state: + return None + + return cast('dict[str, Any]', state) @classmethod def load_messages(cls, messages: Sequence[Message]) -> list[ModelMessage]: diff --git a/tests/test_ag_ui.py b/tests/test_ag_ui.py index 05071d2259..33fbff65df 100644 --- a/tests/test_ag_ui.py +++ b/tests/test_ag_ui.py @@ -226,6 +226,27 @@ async def simple_stream(messages: list[ModelMessage], agent_info: AgentInfo) -> yield '(no tool calls)' +async def test_agui_adapter_state_none() -> None: + """Ensure adapter exposes `None` state when no frontend state provided.""" + agent = Agent( + model=FunctionModel(stream_function=simple_stream), + ) + + run_input = RunAgentInput( + thread_id=uuid_str(), + run_id=uuid_str(), + messages=[], + state=None, + context=[], + tools=[], + forwarded_props=None, + ) + + adapter = AGUIAdapter(agent=agent, run_input=run_input, accept=None) + + assert adapter.state is None + + async def test_basic_user_message() -> None: """Test basic user message with text response.""" agent = Agent( @@ -1193,6 +1214,24 @@ async def test_request_with_state_without_handler() -> None: pass +async def test_request_with_empty_state_without_handler() -> None: + agent = Agent(model=FunctionModel(stream_function=simple_stream)) + + run_input = create_input( + UserMessage( + id='msg_1', + content='Hello, how are you?', + ), + state={}, + ) + + events = list[dict[str, Any]]() + async for event in run_ag_ui(agent, run_input): + events.append(json.loads(event.removeprefix('data: '))) + + assert events == simple_result() + + async def test_request_with_state_with_custom_handler() -> None: @dataclass class CustomStateDeps: