1111from shiny .session import session_context
1212from shiny .types import MISSING
1313from shiny .ui import Chat
14- from shiny .ui ._chat import as_transformed_message
1514from shiny .ui ._chat_normalize import normalize_message , normalize_message_chunk
16- from shiny .ui ._chat_types import ChatMessage
15+ from shiny .ui ._chat_types import ChatMessage , ChatUIMessage
1716
1817# ----------------------------------------------------------------------
1918# Helpers
@@ -52,31 +51,22 @@ def generate_content(token_count: int) -> str:
5251 return " " .join (["foo" for _ in range (1 , n )])
5352
5453 msgs = (
55- as_transformed_message (
56- {
57- "content" : generate_content (102 ),
58- "role" : "system" ,
59- }
60- ),
54+ ChatUIMessage (
55+ content = generate_content (102 ), role = "system"
56+ ).as_transformed_message (),
6157 )
6258
6359 # Throws since system message is too long
6460 with pytest .raises (ValueError ):
6561 chat ._trim_messages (msgs , token_limits = (100 , 0 ), format = MISSING )
6662
6763 msgs = (
68- as_transformed_message (
69- {
70- "content" : generate_content (100 ),
71- "role" : "system" ,
72- }
73- ),
74- as_transformed_message (
75- {
76- "content" : generate_content (2 ),
77- "role" : "user" ,
78- }
79- ),
64+ ChatUIMessage (
65+ content = generate_content (100 ), role = "system"
66+ ).as_transformed_message (),
67+ ChatUIMessage (
68+ content = generate_content (2 ), role = "user"
69+ ).as_transformed_message (),
8070 )
8171
8272 # Throws since only the system message fits
@@ -92,30 +82,24 @@ def generate_content(token_count: int) -> str:
9282 content3 = generate_content (2 )
9383
9484 msgs = (
95- as_transformed_message (
96- {
97- "content" : content1 ,
98- "role" : "system" ,
99- }
100- ),
101- as_transformed_message (
102- {
103- "content" : content2 ,
104- "role" : "user" ,
105- }
106- ),
107- as_transformed_message (
108- {
109- "content" : content3 ,
110- "role" : "user" ,
111- }
112- ),
85+ ChatUIMessage (
86+ content = content1 ,
87+ role = "system" ,
88+ ).as_transformed_message (),
89+ ChatUIMessage (
90+ content = content2 ,
91+ role = "user" ,
92+ ).as_transformed_message (),
93+ ChatUIMessage (
94+ content = content3 ,
95+ role = "user" ,
96+ ).as_transformed_message (),
11397 )
11498
11599 # Should discard the 1st user message
116100 trimmed = chat ._trim_messages (msgs , token_limits = (103 , 0 ), format = MISSING )
117101 assert len (trimmed ) == 2
118- contents = [msg [ " content_server" ] for msg in trimmed ]
102+ contents = [msg . content_server for msg in trimmed ]
119103 assert contents == [content1 , content3 ]
120104
121105 content1 = generate_content (50 )
@@ -124,38 +108,48 @@ def generate_content(token_count: int) -> str:
124108 content4 = generate_content (2 )
125109
126110 msgs = (
127- as_transformed_message (
128- {"content" : content1 , "role" : "system" },
129- ),
130- as_transformed_message (
131- {"content" : content2 , "role" : "user" },
132- ),
133- as_transformed_message (
134- {"content" : content3 , "role" : "system" },
135- ),
136- as_transformed_message (
137- {"content" : content4 , "role" : "user" },
138- ),
111+ ChatUIMessage (
112+ content = content1 ,
113+ role = "system" ,
114+ ).as_transformed_message (),
115+ ChatUIMessage (
116+ content = content2 ,
117+ role = "user" ,
118+ ).as_transformed_message (),
119+ ChatUIMessage (
120+ content = content3 ,
121+ role = "system" ,
122+ ).as_transformed_message (),
123+ ChatUIMessage (
124+ content = content4 ,
125+ role = "user" ,
126+ ).as_transformed_message (),
139127 )
140128
141129 # Should discard the 1st user message
142130 trimmed = chat ._trim_messages (msgs , token_limits = (103 , 0 ), format = MISSING )
143131 assert len (trimmed ) == 3
144- contents = [msg [ " content_server" ] for msg in trimmed ]
132+ contents = [msg . content_server for msg in trimmed ]
145133 assert contents == [content1 , content3 , content4 ]
146134
147135 content1 = generate_content (50 )
148136 content2 = generate_content (10 )
149137
150138 msgs = (
151- as_transformed_message ({"content" : content1 , "role" : "assistant" }),
152- as_transformed_message ({"content" : content2 , "role" : "user" }),
139+ ChatUIMessage (
140+ content = content1 ,
141+ role = "assistant" ,
142+ ).as_transformed_message (),
143+ ChatUIMessage (
144+ content = content2 ,
145+ role = "user" ,
146+ ).as_transformed_message (),
153147 )
154148
155149 # Anthropic requires 1st message to be a user message
156150 trimmed = chat ._trim_messages (msgs , token_limits = (30 , 0 ), format = "anthropic" )
157151 assert len (trimmed ) == 1
158- contents = [msg [ " content_server" ] for msg in trimmed ]
152+ contents = [msg . content_server for msg in trimmed ]
159153 assert contents == [content2 ]
160154
161155
@@ -172,13 +166,15 @@ def generate_content(token_count: int) -> str:
172166
173167
def test_string_normalization():
    """A bare string chunk should normalize to an assistant-role message."""
    result = normalize_message_chunk("Hello world!")
    # Normalization wraps plain strings as assistant content.
    assert (result.content, result.role) == ("Hello world!", "assistant")
177172
178173
def test_dict_normalization():
    """A plain dict chunk should normalize into a message object with the same fields."""
    payload = {"content": "Hello world!", "role": "assistant"}
    result = normalize_message_chunk(payload)
    # The dict's keys map straight onto the message attributes.
    assert result.content == payload["content"]
    assert result.role == payload["role"]
182178
183179
184180def test_langchain_normalization ():
@@ -194,11 +190,15 @@ def test_langchain_normalization():
194190
195191 # Mock & normalize return value of BaseChatModel.invoke()
196192 msg = BaseMessage (content = "Hello world!" , role = "assistant" , type = "foo" )
197- assert normalize_message (msg ) == {"content" : "Hello world!" , "role" : "assistant" }
193+ m = normalize_message (msg )
194+ assert m .content == "Hello world!"
195+ assert m .role == "assistant"
198196
199197 # Mock & normalize return value of BaseChatModel.stream()
200198 chunk = BaseMessageChunk (content = "Hello " , type = "foo" )
201- assert normalize_message_chunk (chunk ) == {"content" : "Hello " , "role" : "assistant" }
199+ m = normalize_message_chunk (chunk )
200+ assert m .content == "Hello "
201+ assert m .role == "assistant"
202202
203203
204204def test_google_normalization ():
@@ -255,7 +255,9 @@ def test_anthropic_normalization():
255255 usage = Usage (input_tokens = 0 , output_tokens = 0 ),
256256 )
257257
258- assert normalize_message (msg ) == {"content" : "Hello world!" , "role" : "assistant" }
258+ m = normalize_message (msg )
259+ assert m .content == "Hello world!"
260+ assert m .role == "assistant"
259261
260262 # Mock return object from Anthropic().messages.create(stream=True)
261263 chunk = RawContentBlockDeltaEvent (
@@ -264,7 +266,9 @@ def test_anthropic_normalization():
264266 index = 0 ,
265267 )
266268
267- assert normalize_message_chunk (chunk ) == {"content" : "Hello " , "role" : "assistant" }
269+ m = normalize_message_chunk (chunk )
270+ assert m .content == "Hello "
271+ assert m .role == "assistant"
268272
269273
270274def test_openai_normalization ():
@@ -309,8 +313,9 @@ def test_openai_normalization():
309313 created = int (datetime .now ().timestamp ()),
310314 )
311315
312- msg = normalize_message (completion )
313- assert msg == {"content" : "Hello world!" , "role" : "assistant" }
316+ m = normalize_message (completion )
317+ assert m .content == "Hello world!"
318+ assert m .role == "assistant"
314319
315320 # Mock return object from OpenAI().chat.completions.create(stream=True)
316321 chunk = ChatCompletionChunk (
@@ -329,8 +334,9 @@ def test_openai_normalization():
329334 ],
330335 )
331336
332- msg = normalize_message_chunk (chunk )
333- assert msg == {"content" : "Hello " , "role" : "assistant" }
337+ m = normalize_message_chunk (chunk )
338+ assert m .content == "Hello "
339+ assert m .role == "assistant"
334340
335341
336342def test_ollama_normalization ():
@@ -343,8 +349,13 @@ def test_ollama_normalization():
343349 )
344350
345351 msg_dict = {"content" : "Hello world!" , "role" : "assistant" }
346- assert normalize_message (msg ) == msg_dict
347- assert normalize_message_chunk (msg ) == msg_dict
352+ m = normalize_message (msg )
353+ assert m .content == msg_dict ["content" ]
354+ assert m .role == msg_dict ["role" ]
355+
356+ m = normalize_message_chunk (msg )
357+ assert m .content == msg_dict ["content" ]
358+ assert m .role == msg_dict ["role" ]
348359
349360
350361# ------------------------------------------------------------------------------------
@@ -403,9 +414,7 @@ def test_as_google_message():
403414
404415
405416def test_as_langchain_message ():
406- from langchain_core .language_models .base import (
407- LanguageModelInput ,
408- )
417+ from langchain_core .language_models .base import LanguageModelInput
409418 from langchain_core .language_models .base import (
410419 Sequence as LangchainSequence , # pyright: ignore[reportPrivateImportUsage]
411420 )
0 commit comments