 from shiny.session import session_context
 from shiny.types import MISSING
 from shiny.ui import Chat
-from shiny.ui._chat import as_transformed_message
 from shiny.ui._chat_normalize import normalize_message, normalize_message_chunk
-from shiny.ui._chat_types import ChatMessage
+from shiny.ui._chat_types import ChatMessage, ChatUIMessage

 # ----------------------------------------------------------------------
 # Helpers
@@ -53,31 +52,22 @@ def generate_content(token_count: int) -> str:
             return " ".join(["foo" for _ in range(1, n)])

         msgs = (
-            as_transformed_message(
-                {
-                    "content": generate_content(102),
-                    "role": "system",
-                }
-            ),
+            ChatUIMessage(
+                content=generate_content(102), role="system"
+            ).as_transformed_message(),
         )

         # Throws since system message is too long
         with pytest.raises(ValueError):
             chat._trim_messages(msgs, token_limits=(100, 0), format=MISSING)

         msgs = (
-            as_transformed_message(
-                {
-                    "content": generate_content(100),
-                    "role": "system",
-                }
-            ),
-            as_transformed_message(
-                {
-                    "content": generate_content(2),
-                    "role": "user",
-                }
-            ),
+            ChatUIMessage(
+                content=generate_content(100), role="system"
+            ).as_transformed_message(),
+            ChatUIMessage(
+                content=generate_content(2), role="user"
+            ).as_transformed_message(),
         )

         # Throws since only the system message fits
@@ -93,30 +83,24 @@ def generate_content(token_count: int) -> str:
         content3 = generate_content(2)

         msgs = (
-            as_transformed_message(
-                {
-                    "content": content1,
-                    "role": "system",
-                }
-            ),
-            as_transformed_message(
-                {
-                    "content": content2,
-                    "role": "user",
-                }
-            ),
-            as_transformed_message(
-                {
-                    "content": content3,
-                    "role": "user",
-                }
-            ),
+            ChatUIMessage(
+                content=content1,
+                role="system",
+            ).as_transformed_message(),
+            ChatUIMessage(
+                content=content2,
+                role="user",
+            ).as_transformed_message(),
+            ChatUIMessage(
+                content=content3,
+                role="user",
+            ).as_transformed_message(),
         )

         # Should discard the 1st user message
         trimmed = chat._trim_messages(msgs, token_limits=(103, 0), format=MISSING)
         assert len(trimmed) == 2
-        contents = [msg["content_server"] for msg in trimmed]
+        contents = [msg.content_server for msg in trimmed]
         assert contents == [content1, content3]

         content1 = generate_content(50)
@@ -125,38 +109,48 @@ def generate_content(token_count: int) -> str:
         content4 = generate_content(2)

         msgs = (
-            as_transformed_message(
-                {"content": content1, "role": "system"},
-            ),
-            as_transformed_message(
-                {"content": content2, "role": "user"},
-            ),
-            as_transformed_message(
-                {"content": content3, "role": "system"},
-            ),
-            as_transformed_message(
-                {"content": content4, "role": "user"},
-            ),
+            ChatUIMessage(
+                content=content1,
+                role="system",
+            ).as_transformed_message(),
+            ChatUIMessage(
+                content=content2,
+                role="user",
+            ).as_transformed_message(),
+            ChatUIMessage(
+                content=content3,
+                role="system",
+            ).as_transformed_message(),
+            ChatUIMessage(
+                content=content4,
+                role="user",
+            ).as_transformed_message(),
         )

         # Should discard the 1st user message
         trimmed = chat._trim_messages(msgs, token_limits=(103, 0), format=MISSING)
         assert len(trimmed) == 3
-        contents = [msg["content_server"] for msg in trimmed]
+        contents = [msg.content_server for msg in trimmed]
         assert contents == [content1, content3, content4]

         content1 = generate_content(50)
         content2 = generate_content(10)

         msgs = (
-            as_transformed_message({"content": content1, "role": "assistant"}),
-            as_transformed_message({"content": content2, "role": "user"}),
+            ChatUIMessage(
+                content=content1,
+                role="assistant",
+            ).as_transformed_message(),
+            ChatUIMessage(
+                content=content2,
+                role="user",
+            ).as_transformed_message(),
         )

         # Anthropic requires 1st message to be a user message
         trimmed = chat._trim_messages(msgs, token_limits=(30, 0), format="anthropic")
         assert len(trimmed) == 1
-        contents = [msg["content_server"] for msg in trimmed]
+        contents = [msg.content_server for msg in trimmed]
         assert contents == [content2]

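Note: the hunks above move the message-trimming tests from the old as_transformed_message({...}) dict helper to constructing a ChatUIMessage and calling its as_transformed_message() method, and they read trimmed results through the content_server attribute instead of msg["content_server"]. Below is a minimal sketch of that pattern; it assumes only what the diff shows (the content/role constructor arguments, the as_transformed_message() method, and the content_server attribute on the result), the strings are hypothetical placeholders, and, like the tests, it may need the same session/Chat setup the test module provides.

    from shiny.ui._chat_types import ChatUIMessage

    # Build transformed messages the way the updated tests do.
    msgs = (
        ChatUIMessage(content="You are terse.", role="system").as_transformed_message(),
        ChatUIMessage(content="Hello!", role="user").as_transformed_message(),
    )
    # Attribute access replaces the old msg["content_server"] dict lookup.
    contents = [m.content_server for m in msgs]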
@@ -173,13 +167,15 @@ def generate_content(token_count: int) -> str:


 def test_string_normalization():
-    msg = normalize_message_chunk("Hello world!")
-    assert msg == {"content": "Hello world!", "role": "assistant"}
+    m = normalize_message_chunk("Hello world!")
+    assert m.content == "Hello world!"
+    assert m.role == "assistant"


 def test_dict_normalization():
-    msg = normalize_message_chunk({"content": "Hello world!", "role": "assistant"})
-    assert msg == {"content": "Hello world!", "role": "assistant"}
+    m = normalize_message_chunk({"content": "Hello world!", "role": "assistant"})
+    assert m.content == "Hello world!"
+    assert m.role == "assistant"


 def test_langchain_normalization():
@@ -195,11 +191,15 @@ def test_langchain_normalization():

     # Mock & normalize return value of BaseChatModel.invoke()
     msg = BaseMessage(content="Hello world!", role="assistant", type="foo")
-    assert normalize_message(msg) == {"content": "Hello world!", "role": "assistant"}
+    m = normalize_message(msg)
+    assert m.content == "Hello world!"
+    assert m.role == "assistant"

     # Mock & normalize return value of BaseChatModel.stream()
     chunk = BaseMessageChunk(content="Hello ", type="foo")
-    assert normalize_message_chunk(chunk) == {"content": "Hello ", "role": "assistant"}
+    m = normalize_message_chunk(chunk)
+    assert m.content == "Hello "
+    assert m.role == "assistant"


 def test_google_normalization():
@@ -256,7 +256,9 @@ def test_anthropic_normalization():
         usage=Usage(input_tokens=0, output_tokens=0),
     )

-    assert normalize_message(msg) == {"content": "Hello world!", "role": "assistant"}
+    m = normalize_message(msg)
+    assert m.content == "Hello world!"
+    assert m.role == "assistant"

     # Mock return object from Anthropic().messages.create(stream=True)
     chunk = RawContentBlockDeltaEvent(
@@ -265,7 +267,9 @@ def test_anthropic_normalization():
         index=0,
     )

-    assert normalize_message_chunk(chunk) == {"content": "Hello ", "role": "assistant"}
+    m = normalize_message_chunk(chunk)
+    assert m.content == "Hello "
+    assert m.role == "assistant"


 def test_openai_normalization():
@@ -310,8 +314,9 @@ def test_openai_normalization():
         created=int(datetime.now().timestamp()),
     )

-    msg = normalize_message(completion)
-    assert msg == {"content": "Hello world!", "role": "assistant"}
+    m = normalize_message(completion)
+    assert m.content == "Hello world!"
+    assert m.role == "assistant"

     # Mock return object from OpenAI().chat.completions.create(stream=True)
     chunk = ChatCompletionChunk(
@@ -330,8 +335,9 @@ def test_openai_normalization():
         ],
     )

-    msg = normalize_message_chunk(chunk)
-    assert msg == {"content": "Hello ", "role": "assistant"}
+    m = normalize_message_chunk(chunk)
+    assert m.content == "Hello "
+    assert m.role == "assistant"


 def test_ollama_normalization():
@@ -344,8 +350,13 @@ def test_ollama_normalization():
     )

     msg_dict = {"content": "Hello world!", "role": "assistant"}
-    assert normalize_message(msg) == msg_dict
-    assert normalize_message_chunk(msg) == msg_dict
+    m = normalize_message(msg)
+    assert m.content == msg_dict["content"]
+    assert m.role == msg_dict["role"]
+
+    m = normalize_message_chunk(msg)
+    assert m.content == msg_dict["content"]
+    assert m.role == msg_dict["role"]


 # ------------------------------------------------------------------------------------
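Note: the normalization hunks all make the same change: normalize_message() and normalize_message_chunk() now return an object exposing content and role attributes rather than a {"content": ..., "role": ...} dict, so each provider test swaps the dict-equality assertion for attribute checks. A minimal sketch of the new assertion pattern, mirroring test_string_normalization above and assuming nothing beyond what the diff shows:

    from shiny.ui._chat_normalize import normalize_message_chunk

    # Per the updated test, a bare string normalizes to an assistant message.
    m = normalize_message_chunk("Hello world!")
    assert m.content == "Hello world!"  # previously compared against a dict
    assert m.role == "assistant"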