1
+ """Tests for text message events with different roles."""
2
+
3
+ import unittest
4
+ from pydantic import ValidationError
5
+ from ag_ui .core import (
6
+ EventType ,
7
+ TextMessageStartEvent ,
8
+ TextMessageContentEvent ,
9
+ TextMessageEndEvent ,
10
+ TextMessageChunkEvent ,
11
+ Role ,
12
+ )
13
+
14
+ # Test all available roles for text messages (excluding "tool")
15
+ TEXT_MESSAGE_ROLES = ["developer" , "system" , "assistant" , "user" ]
16
+
17
+
18
+ class TestTextMessageRoles (unittest .TestCase ):
19
+ """Test text message events with different roles."""
20
+
21
+ def test_text_message_start_with_all_roles (self ) -> None :
22
+ """Test TextMessageStartEvent with different roles."""
23
+ for role in TEXT_MESSAGE_ROLES :
24
+ with self .subTest (role = role ):
25
+ event = TextMessageStartEvent (
26
+ message_id = "test-msg" ,
27
+ role = role ,
28
+ )
29
+
30
+ self .assertEqual (event .type , EventType .TEXT_MESSAGE_START )
31
+ self .assertEqual (event .message_id , "test-msg" )
32
+ self .assertEqual (event .role , role )
33
+
34
+ def test_text_message_chunk_with_all_roles (self ) -> None :
35
+ """Test TextMessageChunkEvent with different roles."""
36
+ for role in TEXT_MESSAGE_ROLES :
37
+ with self .subTest (role = role ):
38
+ event = TextMessageChunkEvent (
39
+ message_id = "test-msg" ,
40
+ role = role ,
41
+ delta = f"Hello from { role } " ,
42
+ )
43
+
44
+ self .assertEqual (event .type , EventType .TEXT_MESSAGE_CHUNK )
45
+ self .assertEqual (event .message_id , "test-msg" )
46
+ self .assertEqual (event .role , role )
47
+ self .assertEqual (event .delta , f"Hello from { role } " )
48
+
49
+ def test_text_message_chunk_without_role (self ) -> None :
50
+ """Test TextMessageChunkEvent without role (should be optional)."""
51
+ event = TextMessageChunkEvent (
52
+ message_id = "test-msg" ,
53
+ delta = "Hello without role" ,
54
+ )
55
+
56
+ self .assertEqual (event .type , EventType .TEXT_MESSAGE_CHUNK )
57
+ self .assertEqual (event .message_id , "test-msg" )
58
+ self .assertIsNone (event .role )
59
+ self .assertEqual (event .delta , "Hello without role" )
60
+
61
+ def test_multiple_messages_different_roles (self ) -> None :
62
+ """Test creating multiple messages with different roles."""
63
+ events = []
64
+
65
+ for role in TEXT_MESSAGE_ROLES :
66
+ start_event = TextMessageStartEvent (
67
+ message_id = f"msg-{ role } " ,
68
+ role = role ,
69
+ )
70
+ content_event = TextMessageContentEvent (
71
+ message_id = f"msg-{ role } " ,
72
+ delta = f"Message from { role } " ,
73
+ )
74
+ end_event = TextMessageEndEvent (
75
+ message_id = f"msg-{ role } " ,
76
+ )
77
+
78
+ events .extend ([start_event , content_event , end_event ])
79
+
80
+ # Verify we have 3 events per role
81
+ self .assertEqual (len (events ), len (TEXT_MESSAGE_ROLES ) * 3 )
82
+
83
+ # Verify each start event has the correct role
84
+ for i , role in enumerate (TEXT_MESSAGE_ROLES ):
85
+ start_event = events [i * 3 ]
86
+ self .assertIsInstance (start_event , TextMessageStartEvent )
87
+ self .assertEqual (start_event .role , role )
88
+ self .assertEqual (start_event .message_id , f"msg-{ role } " )
89
+
90
+ def test_text_message_serialization (self ) -> None :
91
+ """Test that text message events serialize correctly with roles."""
92
+ for role in TEXT_MESSAGE_ROLES :
93
+ with self .subTest (role = role ):
94
+ event = TextMessageStartEvent (
95
+ message_id = "test-msg" ,
96
+ role = role ,
97
+ )
98
+
99
+ # Convert to dict and back
100
+ event_dict = event .model_dump ()
101
+ self .assertEqual (event_dict ["role" ], role )
102
+ self .assertEqual (event_dict ["type" ], EventType .TEXT_MESSAGE_START )
103
+ self .assertEqual (event_dict ["message_id" ], "test-msg" )
104
+
105
+ # Recreate from dict
106
+ new_event = TextMessageStartEvent (** event_dict )
107
+ self .assertEqual (new_event .role , role )
108
+ self .assertEqual (new_event , event )
109
+
110
+ def test_invalid_role_rejected (self ) -> None :
111
+ """Test that invalid roles are rejected."""
112
+ # Test with completely invalid role
113
+ with self .assertRaises (ValidationError ):
114
+ TextMessageStartEvent (
115
+ message_id = "test-msg" ,
116
+ role = "invalid_role" , # type: ignore
117
+ )
118
+
119
+ # Test that 'tool' role is not allowed for text messages
120
+ with self .assertRaises (ValidationError ):
121
+ TextMessageStartEvent (
122
+ message_id = "test-msg" ,
123
+ role = "tool" , # type: ignore
124
+ )
125
+
126
+ # Test that 'tool' role is not allowed for chunks either
127
+ with self .assertRaises (ValidationError ):
128
+ TextMessageChunkEvent (
129
+ message_id = "test-msg" ,
130
+ role = "tool" , # type: ignore
131
+ delta = "Tool message" ,
132
+ )
133
+
134
+ def test_text_message_start_default_role (self ) -> None :
135
+ """Test that TextMessageStartEvent defaults to 'assistant' role."""
136
+ event = TextMessageStartEvent (
137
+ message_id = "test-msg" ,
138
+ )
139
+
140
+ self .assertEqual (event .type , EventType .TEXT_MESSAGE_START )
141
+ self .assertEqual (event .message_id , "test-msg" )
142
+ self .assertEqual (event .role , "assistant" ) # Should default to assistant
143
+
144
+
145
+ if __name__ == "__main__" :
146
+ unittest .main ()
0 commit comments