Skip to content

Commit a4a8daf

Browse files
committed
chore(python-sdk): modernise type hints
Leverage __future__ import annotations to enable support for streamlined type hints using | vs Union from PEP 604. Future annotations was enabled across the codebase so that any additional use of Optional or Union would be flagged by type linters. Removed unused imports, guarded import under TYPE_CHECKING where necessary. Fixed accept parameter definition for EventEncoder and added a missing comment. Fixes: #50
1 parent af51738 commit a4a8daf

File tree

8 files changed

+129
-109
lines changed

8 files changed

+129
-109
lines changed

python-sdk/ag_ui/core/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
This module contains the core types and events for the Agent User Interaction Protocol.
33
"""
44

5+
from __future__ import annotations
6+
57
from ag_ui.core.events import (
68
EventType,
79
BaseEvent,

python-sdk/ag_ui/core/events.py

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
This module contains the event types for the Agent User Interaction Protocol Python SDK.
33
"""
44

5+
from __future__ import annotations
6+
57
from enum import Enum
6-
from typing import Any, List, Literal, Optional, Union, Annotated
8+
from typing import Any, List, Literal, Optional, Annotated
79
from pydantic import Field
810

911
from .types import Message, State, ConfiguredBaseModel
@@ -38,8 +40,8 @@ class BaseEvent(ConfiguredBaseModel):
3840
Base event for all events in the Agent User Interaction Protocol.
3941
"""
4042
type: EventType
41-
timestamp: Optional[int] = None
42-
raw_event: Optional[Any] = None
43+
timestamp: int | None = None
44+
raw_event: Any = None
4345

4446

4547
class TextMessageStartEvent(BaseEvent):
@@ -76,9 +78,9 @@ class TextMessageChunkEvent(BaseEvent):
7678
Event containing a chunk of text message content.
7779
"""
7880
type: Literal[EventType.TEXT_MESSAGE_CHUNK]
79-
message_id: Optional[str] = None
81+
message_id: str | None = None
8082
role: Optional[Literal["assistant"]] = None
81-
delta: Optional[str] = None
83+
delta: str | None = None
8284

8385
class ToolCallStartEvent(BaseEvent):
8486
"""
@@ -87,7 +89,7 @@ class ToolCallStartEvent(BaseEvent):
8789
type: Literal[EventType.TOOL_CALL_START]
8890
tool_call_id: str
8991
tool_call_name: str
90-
parent_message_id: Optional[str] = None
92+
parent_message_id: str | None = None
9193

9294

9395
class ToolCallArgsEvent(BaseEvent):
@@ -111,10 +113,10 @@ class ToolCallChunkEvent(BaseEvent):
111113
Event containing a chunk of tool call content.
112114
"""
113115
type: Literal[EventType.TOOL_CALL_CHUNK]
114-
tool_call_id: Optional[str] = None
115-
tool_call_name: Optional[str] = None
116-
parent_message_id: Optional[str] = None
117-
delta: Optional[str] = None
116+
tool_call_id: str | None = None
117+
tool_call_name: str | None = None
118+
parent_message_id: str | None = None
119+
delta: str | None = None
118120

119121
class StateSnapshotEvent(BaseEvent):
120122
"""
@@ -146,7 +148,7 @@ class RawEvent(BaseEvent):
146148
"""
147149
type: Literal[EventType.RAW]
148150
event: Any
149-
source: Optional[str] = None
151+
source: str | None = None
150152

151153

152154
class CustomEvent(BaseEvent):
@@ -182,7 +184,7 @@ class RunErrorEvent(BaseEvent):
182184
"""
183185
type: Literal[EventType.RUN_ERROR]
184186
message: str
185-
code: Optional[str] = None
187+
code: str | None = None
186188

187189

188190
class StepStartedEvent(BaseEvent):
@@ -202,25 +204,23 @@ class StepFinishedEvent(BaseEvent):
202204

203205

204206
Event = Annotated[
205-
Union[
206-
TextMessageStartEvent,
207-
TextMessageContentEvent,
208-
TextMessageEndEvent,
209-
TextMessageChunkEvent,
210-
ToolCallStartEvent,
211-
ToolCallArgsEvent,
212-
ToolCallEndEvent,
213-
ToolCallChunkEvent,
214-
StateSnapshotEvent,
215-
StateDeltaEvent,
216-
MessagesSnapshotEvent,
217-
RawEvent,
218-
CustomEvent,
219-
RunStartedEvent,
220-
RunFinishedEvent,
221-
RunErrorEvent,
222-
StepStartedEvent,
223-
StepFinishedEvent,
224-
],
207+
TextMessageStartEvent |
208+
TextMessageContentEvent |
209+
TextMessageEndEvent |
210+
TextMessageChunkEvent |
211+
ToolCallStartEvent |
212+
ToolCallArgsEvent |
213+
ToolCallEndEvent |
214+
ToolCallChunkEvent |
215+
StateSnapshotEvent |
216+
StateDeltaEvent |
217+
MessagesSnapshotEvent |
218+
RawEvent |
219+
CustomEvent |
220+
RunStartedEvent |
221+
RunFinishedEvent |
222+
RunErrorEvent |
223+
StepStartedEvent |
224+
StepFinishedEvent,
225225
Field(discriminator="type")
226226
]

python-sdk/ag_ui/core/types.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
This module contains the types for the Agent User Interaction Protocol Python SDK.
33
"""
44

5-
from typing import Any, List, Literal, Optional, Union, Annotated
5+
from __future__ import annotations
6+
7+
from typing import Any, List, Literal, Optional, Annotated
68
from pydantic import BaseModel, Field, ConfigDict
79
from pydantic.alias_generators import to_camel
810

@@ -41,8 +43,8 @@ class BaseMessage(ConfiguredBaseModel):
4143
"""
4244
id: str
4345
role: str
44-
content: Optional[str] = None
45-
name: Optional[str] = None
46+
content: str | None = None
47+
name: str | None = None
4648

4749

4850
class DeveloperMessage(BaseMessage):
@@ -88,7 +90,7 @@ class ToolMessage(ConfiguredBaseModel):
8890

8991

9092
Message = Annotated[
91-
Union[DeveloperMessage, SystemMessage, AssistantMessage, UserMessage, ToolMessage],
93+
DeveloperMessage | SystemMessage | AssistantMessage | UserMessage | ToolMessage,
9294
Field(discriminator="role")
9395
]
9496

python-sdk/ag_ui/encoder/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
This module contains the EventEncoder class.
33
"""
44

5+
from __future__ import annotations
6+
57
from ag_ui.encoder.encoder import EventEncoder, AGUI_MEDIA_TYPE
68

79
__all__ = ["EventEncoder", "AGUI_MEDIA_TYPE"]

python-sdk/ag_ui/encoder/encoder.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,25 @@
22
This module contains the EventEncoder class
33
"""
44

5-
from ag_ui.core.events import BaseEvent
5+
from __future__ import annotations
6+
7+
from typing import TYPE_CHECKING
8+
9+
if TYPE_CHECKING:
10+
from ag_ui.core.events import BaseEvent
11+
612

713
AGUI_MEDIA_TYPE = "application/vnd.ag-ui.event+proto"
814

915
class EventEncoder:
1016
"""
1117
Encodes Agent User Interaction events.
1218
"""
13-
def __init__(self, accept: str = None):
14-
pass
19+
def __init__(self, accept: str | None = None) -> None:
20+
"""
21+
Initializes the EventEncoder.
22+
"""
23+
self.accept = accept
1524

1625
def get_content_type(self) -> str:
1726
"""

python-sdk/tests/test_encoder.py

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import unittest
24
import json
35
from datetime import datetime
@@ -23,15 +25,15 @@ def test_encode_method(self):
2325
# Create a test event
2426
timestamp = int(datetime.now().timestamp() * 1000)
2527
event = BaseEvent(type=EventType.RAW, timestamp=timestamp)
26-
28+
2729
# Create encoder and encode event
2830
encoder = EventEncoder()
2931
encoded = encoder.encode(event)
30-
32+
3133
# The encode method calls encode_sse, so the result should be in SSE format
3234
expected = f"data: {event.model_dump_json(by_alias=True, exclude_none=True)}\n\n"
3335
self.assertEqual(encoded, expected)
34-
36+
3537
# Verify that camelCase is used in the encoded output
3638
self.assertIn('"type":', encoded)
3739
self.assertIn('"timestamp":', encoded)
@@ -48,25 +50,25 @@ def test_encode_sse_method(self):
4850
delta="Hello, world!",
4951
timestamp=1648214400000
5052
)
51-
53+
5254
# Create encoder and encode event to SSE
5355
encoder = EventEncoder()
5456
encoded_sse = encoder._encode_sse(event)
55-
57+
5658
# Verify the format is correct for SSE (data: [json]\n\n)
5759
self.assertTrue(encoded_sse.startswith("data: "))
5860
self.assertTrue(encoded_sse.endswith("\n\n"))
59-
61+
6062
# Extract and verify the JSON content
6163
json_content = encoded_sse[6:-2] # Remove "data: " prefix and "\n\n" suffix
6264
decoded = json.loads(json_content)
63-
65+
6466
# Check that all fields were properly encoded
6567
self.assertEqual(decoded["type"], "TEXT_MESSAGE_CONTENT")
6668
self.assertEqual(decoded["messageId"], "msg_123") # Check snake_case converted to camelCase
6769
self.assertEqual(decoded["delta"], "Hello, world!")
6870
self.assertEqual(decoded["timestamp"], 1648214400000)
69-
71+
7072
# Verify that snake_case has been converted to camelCase
7173
self.assertIn("messageId", decoded) # camelCase key exists
7274
self.assertNotIn("message_id", decoded) # snake_case key doesn't exist
@@ -75,12 +77,12 @@ def test_encode_with_different_event_types(self):
7577
"""Test encoding different types of events"""
7678
# Create encoder
7779
encoder = EventEncoder()
78-
80+
7981
# Test with a basic BaseEvent
8082
base_event = BaseEvent(type=EventType.RAW, timestamp=1648214400000)
8183
encoded_base = encoder.encode(base_event)
8284
self.assertIn('"type":"RAW"', encoded_base)
83-
85+
8486
# Test with a more complex event
8587
content_event = TextMessageContentEvent(
8688
type=EventType.TEXT_MESSAGE_CONTENT,
@@ -89,20 +91,20 @@ def test_encode_with_different_event_types(self):
8991
timestamp=1648214400000
9092
)
9193
encoded_content = encoder.encode(content_event)
92-
94+
9395
# Verify correct encoding and camelCase conversion
9496
self.assertIn('"type":"TEXT_MESSAGE_CONTENT"', encoded_content)
9597
self.assertIn('"messageId":"msg_456"', encoded_content) # Check snake_case converted to camelCase
9698
self.assertIn('"delta":"Testing different events"', encoded_content)
97-
99+
98100
# Extract JSON and verify camelCase conversion
99101
json_content = encoded_content.split("data: ")[1].rstrip("\n\n")
100102
decoded = json.loads(json_content)
101-
103+
102104
# Verify messageId is camelCase (not message_id)
103105
self.assertIn("messageId", decoded)
104106
self.assertNotIn("message_id", decoded)
105-
107+
106108
def test_null_value_exclusion(self):
107109
"""Test that fields with None values are excluded from the JSON output"""
108110
# Create an event with some fields set to None
@@ -111,22 +113,22 @@ def test_null_value_exclusion(self):
111113
timestamp=1648214400000,
112114
raw_event=None # Explicitly set to None
113115
)
114-
116+
115117
# Create encoder and encode event
116118
encoder = EventEncoder()
117119
encoded = encoder.encode(event)
118-
120+
119121
# Extract JSON
120122
json_content = encoded.split("data: ")[1].rstrip("\n\n")
121123
decoded = json.loads(json_content)
122-
124+
123125
# Verify fields that are present
124126
self.assertIn("type", decoded)
125127
self.assertIn("timestamp", decoded)
126-
128+
127129
# Verify null fields are excluded
128130
self.assertNotIn("rawEvent", decoded)
129-
131+
130132
# Test with another event that has optional fields
131133
# Create event with some optional fields set to None
132134
event_with_optional = ToolCallStartEvent(
@@ -136,18 +138,18 @@ def test_null_value_exclusion(self):
136138
parent_message_id=None, # Optional field explicitly set to None
137139
timestamp=1648214400000
138140
)
139-
141+
140142
encoded_optional = encoder.encode(event_with_optional)
141143
json_content_optional = encoded_optional.split("data: ")[1].rstrip("\n\n")
142144
decoded_optional = json.loads(json_content_optional)
143-
145+
144146
# Required fields should be present
145147
self.assertIn("toolCallId", decoded_optional)
146148
self.assertIn("toolCallName", decoded_optional)
147-
149+
148150
# Optional field with None value should be excluded
149151
self.assertNotIn("parentMessageId", decoded_optional)
150-
152+
151153
def test_round_trip_serialization(self):
152154
"""Test that events can be serialized to JSON with camelCase and deserialized back correctly"""
153155
# Create a complex event with multiple fields
@@ -158,10 +160,10 @@ def test_round_trip_serialization(self):
158160
parent_message_id="msg_parent_456",
159161
timestamp=1648214400000
160162
)
161-
163+
162164
# Serialize to JSON with camelCase fields
163165
json_str = original_event.model_dump_json(by_alias=True)
164-
166+
165167
# Verify JSON uses camelCase
166168
json_data = json.loads(json_str)
167169
self.assertIn("toolCallId", json_data)
@@ -170,19 +172,19 @@ def test_round_trip_serialization(self):
170172
self.assertNotIn("tool_call_id", json_data)
171173
self.assertNotIn("tool_call_name", json_data)
172174
self.assertNotIn("parent_message_id", json_data)
173-
175+
174176
# Deserialize back to an event
175177
deserialized_event = ToolCallStartEvent.model_validate_json(json_str)
176-
178+
177179
# Verify the deserialized event is equivalent to the original
178180
self.assertEqual(deserialized_event.type, original_event.type)
179181
self.assertEqual(deserialized_event.tool_call_id, original_event.tool_call_id)
180182
self.assertEqual(deserialized_event.tool_call_name, original_event.tool_call_name)
181183
self.assertEqual(deserialized_event.parent_message_id, original_event.parent_message_id)
182184
self.assertEqual(deserialized_event.timestamp, original_event.timestamp)
183-
185+
184186
# Verify complete equality using model_dump
185187
self.assertEqual(
186-
original_event.model_dump(),
188+
original_event.model_dump(),
187189
deserialized_event.model_dump()
188190
)

0 commit comments

Comments
 (0)