Skip to content

Commit 9de1581

Browse files
committed
fix(python): add missing thinking events to Event union
The Event union type was missing the following event types that are defined as BaseEvent subclasses: - ThinkingStartEvent - ThinkingEndEvent - ThinkingTextMessageStartEvent - ThinkingTextMessageContentEvent - ThinkingTextMessageEndEvent This caused validation errors when using TypeAdapter[Event] to parse events from models that emit thinking events (e.g., Claude with extended thinking). Also adds a test to ensure all BaseEvent subclasses are included in the Event union, preventing similar issues in the future.
1 parent 137e58a commit 9de1581

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

sdks/python/ag_ui/core/events.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,11 +279,16 @@ class StepFinishedEvent(BaseEvent):
279279
TextMessageContentEvent,
280280
TextMessageEndEvent,
281281
TextMessageChunkEvent,
282+
ThinkingTextMessageStartEvent,
283+
ThinkingTextMessageContentEvent,
284+
ThinkingTextMessageEndEvent,
282285
ToolCallStartEvent,
283286
ToolCallArgsEvent,
284287
ToolCallEndEvent,
285288
ToolCallChunkEvent,
286289
ToolCallResultEvent,
290+
ThinkingStartEvent,
291+
ThinkingEndEvent,
287292
StateSnapshotEvent,
288293
StateDeltaEvent,
289294
MessagesSnapshotEvent,

sdks/python/tests/test_events.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import unittest
22
import json
3+
import typing
34
from datetime import datetime
45
from pydantic import ValidationError, TypeAdapter
56

7+
from ag_ui.core import events as events_module
68
from ag_ui.core.types import Message, UserMessage, AssistantMessage, FunctionCall, ToolCall
79
from ag_ui.core.events import (
810
EventType,
@@ -596,6 +598,31 @@ def test_event_with_unicode_and_special_chars(self):
596598
# Verify Unicode and special characters are preserved
597599
self.assertEqual(deserialized.delta, text)
598600

601+
def test_all_event_subclasses_in_event_union(self):
602+
"""Ensure all BaseEvent subclasses are included in the Event union type"""
603+
604+
# Get all classes defined in the events module that are subclasses of BaseEvent
605+
event_subclasses = set()
606+
for name in dir(events_module):
607+
obj = getattr(events_module, name)
608+
if (
609+
isinstance(obj, type)
610+
and issubclass(obj, BaseEvent)
611+
and obj is not BaseEvent
612+
):
613+
event_subclasses.add(obj)
614+
615+
# Get all types in the Event union
616+
union_types = set(typing.get_args(typing.get_args(Event)[0]))
617+
618+
# Check that all event subclasses are in the union
619+
missing_from_union = event_subclasses - union_types
620+
self.assertEqual(
621+
missing_from_union,
622+
set(),
623+
f"The following event types are missing from the Event union: {missing_from_union}"
624+
)
599625

626+
600627
if __name__ == "__main__":
601628
unittest.main()

0 commit comments

Comments
 (0)