1
+ from __future__ import annotations
2
+
1
3
import unittest
2
4
import json
3
5
from datetime import datetime
@@ -23,15 +25,15 @@ def test_encode_method(self):
23
25
# Create a test event
24
26
timestamp = int (datetime .now ().timestamp () * 1000 )
25
27
event = BaseEvent (type = EventType .RAW , timestamp = timestamp )
26
-
28
+
27
29
# Create encoder and encode event
28
30
encoder = EventEncoder ()
29
31
encoded = encoder .encode (event )
30
-
32
+
31
33
# The encode method calls encode_sse, so the result should be in SSE format
32
34
expected = f"data: { event .model_dump_json (by_alias = True , exclude_none = True )} \n \n "
33
35
self .assertEqual (encoded , expected )
34
-
36
+
35
37
# Verify that camelCase is used in the encoded output
36
38
self .assertIn ('"type":' , encoded )
37
39
self .assertIn ('"timestamp":' , encoded )
@@ -48,25 +50,25 @@ def test_encode_sse_method(self):
48
50
delta = "Hello, world!" ,
49
51
timestamp = 1648214400000
50
52
)
51
-
53
+
52
54
# Create encoder and encode event to SSE
53
55
encoder = EventEncoder ()
54
56
encoded_sse = encoder ._encode_sse (event )
55
-
57
+
56
58
# Verify the format is correct for SSE (data: [json]\n\n)
57
59
self .assertTrue (encoded_sse .startswith ("data: " ))
58
60
self .assertTrue (encoded_sse .endswith ("\n \n " ))
59
-
61
+
60
62
# Extract and verify the JSON content
61
63
json_content = encoded_sse [6 :- 2 ] # Remove "data: " prefix and "\n\n" suffix
62
64
decoded = json .loads (json_content )
63
-
65
+
64
66
# Check that all fields were properly encoded
65
67
self .assertEqual (decoded ["type" ], "TEXT_MESSAGE_CONTENT" )
66
68
self .assertEqual (decoded ["messageId" ], "msg_123" ) # Check snake_case converted to camelCase
67
69
self .assertEqual (decoded ["delta" ], "Hello, world!" )
68
70
self .assertEqual (decoded ["timestamp" ], 1648214400000 )
69
-
71
+
70
72
# Verify that snake_case has been converted to camelCase
71
73
self .assertIn ("messageId" , decoded ) # camelCase key exists
72
74
self .assertNotIn ("message_id" , decoded ) # snake_case key doesn't exist
@@ -75,12 +77,12 @@ def test_encode_with_different_event_types(self):
75
77
"""Test encoding different types of events"""
76
78
# Create encoder
77
79
encoder = EventEncoder ()
78
-
80
+
79
81
# Test with a basic BaseEvent
80
82
base_event = BaseEvent (type = EventType .RAW , timestamp = 1648214400000 )
81
83
encoded_base = encoder .encode (base_event )
82
84
self .assertIn ('"type":"RAW"' , encoded_base )
83
-
85
+
84
86
# Test with a more complex event
85
87
content_event = TextMessageContentEvent (
86
88
type = EventType .TEXT_MESSAGE_CONTENT ,
@@ -89,20 +91,20 @@ def test_encode_with_different_event_types(self):
89
91
timestamp = 1648214400000
90
92
)
91
93
encoded_content = encoder .encode (content_event )
92
-
94
+
93
95
# Verify correct encoding and camelCase conversion
94
96
self .assertIn ('"type":"TEXT_MESSAGE_CONTENT"' , encoded_content )
95
97
self .assertIn ('"messageId":"msg_456"' , encoded_content ) # Check snake_case converted to camelCase
96
98
self .assertIn ('"delta":"Testing different events"' , encoded_content )
97
-
99
+
98
100
# Extract JSON and verify camelCase conversion
99
101
json_content = encoded_content .split ("data: " )[1 ].rstrip ("\n \n " )
100
102
decoded = json .loads (json_content )
101
-
103
+
102
104
# Verify messageId is camelCase (not message_id)
103
105
self .assertIn ("messageId" , decoded )
104
106
self .assertNotIn ("message_id" , decoded )
105
-
107
+
106
108
def test_null_value_exclusion (self ):
107
109
"""Test that fields with None values are excluded from the JSON output"""
108
110
# Create an event with some fields set to None
@@ -111,22 +113,22 @@ def test_null_value_exclusion(self):
111
113
timestamp = 1648214400000 ,
112
114
raw_event = None # Explicitly set to None
113
115
)
114
-
116
+
115
117
# Create encoder and encode event
116
118
encoder = EventEncoder ()
117
119
encoded = encoder .encode (event )
118
-
120
+
119
121
# Extract JSON
120
122
json_content = encoded .split ("data: " )[1 ].rstrip ("\n \n " )
121
123
decoded = json .loads (json_content )
122
-
124
+
123
125
# Verify fields that are present
124
126
self .assertIn ("type" , decoded )
125
127
self .assertIn ("timestamp" , decoded )
126
-
128
+
127
129
# Verify null fields are excluded
128
130
self .assertNotIn ("rawEvent" , decoded )
129
-
131
+
130
132
# Test with another event that has optional fields
131
133
# Create event with some optional fields set to None
132
134
event_with_optional = ToolCallStartEvent (
@@ -136,18 +138,18 @@ def test_null_value_exclusion(self):
136
138
parent_message_id = None , # Optional field explicitly set to None
137
139
timestamp = 1648214400000
138
140
)
139
-
141
+
140
142
encoded_optional = encoder .encode (event_with_optional )
141
143
json_content_optional = encoded_optional .split ("data: " )[1 ].rstrip ("\n \n " )
142
144
decoded_optional = json .loads (json_content_optional )
143
-
145
+
144
146
# Required fields should be present
145
147
self .assertIn ("toolCallId" , decoded_optional )
146
148
self .assertIn ("toolCallName" , decoded_optional )
147
-
149
+
148
150
# Optional field with None value should be excluded
149
151
self .assertNotIn ("parentMessageId" , decoded_optional )
150
-
152
+
151
153
def test_round_trip_serialization (self ):
152
154
"""Test that events can be serialized to JSON with camelCase and deserialized back correctly"""
153
155
# Create a complex event with multiple fields
@@ -158,10 +160,10 @@ def test_round_trip_serialization(self):
158
160
parent_message_id = "msg_parent_456" ,
159
161
timestamp = 1648214400000
160
162
)
161
-
163
+
162
164
# Serialize to JSON with camelCase fields
163
165
json_str = original_event .model_dump_json (by_alias = True )
164
-
166
+
165
167
# Verify JSON uses camelCase
166
168
json_data = json .loads (json_str )
167
169
self .assertIn ("toolCallId" , json_data )
@@ -170,19 +172,19 @@ def test_round_trip_serialization(self):
170
172
self .assertNotIn ("tool_call_id" , json_data )
171
173
self .assertNotIn ("tool_call_name" , json_data )
172
174
self .assertNotIn ("parent_message_id" , json_data )
173
-
175
+
174
176
# Deserialize back to an event
175
177
deserialized_event = ToolCallStartEvent .model_validate_json (json_str )
176
-
178
+
177
179
# Verify the deserialized event is equivalent to the original
178
180
self .assertEqual (deserialized_event .type , original_event .type )
179
181
self .assertEqual (deserialized_event .tool_call_id , original_event .tool_call_id )
180
182
self .assertEqual (deserialized_event .tool_call_name , original_event .tool_call_name )
181
183
self .assertEqual (deserialized_event .parent_message_id , original_event .parent_message_id )
182
184
self .assertEqual (deserialized_event .timestamp , original_event .timestamp )
183
-
185
+
184
186
# Verify complete equality using model_dump
185
187
self .assertEqual (
186
- original_event .model_dump (),
188
+ original_event .model_dump (),
187
189
deserialized_event .model_dump ()
188
190
)
0 commit comments