1212from  shiny .session  import  session_context 
1313from  shiny .types  import  MISSING 
1414from  shiny .ui  import  Chat 
15- from  shiny .ui ._chat  import  as_transformed_message 
1615from  shiny .ui ._chat_normalize  import  normalize_message , normalize_message_chunk 
17- from  shiny .ui ._chat_types  import  ChatMessage 
16+ from  shiny .ui ._chat_types  import  ChatMessage ,  ChatUIMessage 
1817
1918# ---------------------------------------------------------------------- 
2019# Helpers 
@@ -53,31 +52,22 @@ def generate_content(token_count: int) -> str:
5352            return  " " .join (["foo"  for  _  in  range (1 , n )])
5453
5554        msgs  =  (
56-             as_transformed_message (
57-                 {
58-                     "content" : generate_content (102 ),
59-                     "role" : "system" ,
60-                 }
61-             ),
55+             ChatUIMessage (
56+                 content = generate_content (102 ), role = "system" 
57+             ).as_transformed_message (),
6258        )
6359
6460        # Throws since system message is too long 
6561        with  pytest .raises (ValueError ):
6662            chat ._trim_messages (msgs , token_limits = (100 , 0 ), format = MISSING )
6763
6864        msgs  =  (
69-             as_transformed_message (
70-                 {
71-                     "content" : generate_content (100 ),
72-                     "role" : "system" ,
73-                 }
74-             ),
75-             as_transformed_message (
76-                 {
77-                     "content" : generate_content (2 ),
78-                     "role" : "user" ,
79-                 }
80-             ),
65+             ChatUIMessage (
66+                 content = generate_content (100 ), role = "system" 
67+             ).as_transformed_message (),
68+             ChatUIMessage (
69+                 content = generate_content (2 ), role = "user" 
70+             ).as_transformed_message (),
8171        )
8272
8373        # Throws since only the system message fits 
@@ -93,30 +83,24 @@ def generate_content(token_count: int) -> str:
9383        content3  =  generate_content (2 )
9484
9585        msgs  =  (
96-             as_transformed_message (
97-                 {
98-                     "content" : content1 ,
99-                     "role" : "system" ,
100-                 }
101-             ),
102-             as_transformed_message (
103-                 {
104-                     "content" : content2 ,
105-                     "role" : "user" ,
106-                 }
107-             ),
108-             as_transformed_message (
109-                 {
110-                     "content" : content3 ,
111-                     "role" : "user" ,
112-                 }
113-             ),
86+             ChatUIMessage (
87+                 content = content1 ,
88+                 role = "system" ,
89+             ).as_transformed_message (),
90+             ChatUIMessage (
91+                 content = content2 ,
92+                 role = "user" ,
93+             ).as_transformed_message (),
94+             ChatUIMessage (
95+                 content = content3 ,
96+                 role = "user" ,
97+             ).as_transformed_message (),
11498        )
11599
116100        # Should discard the 1st user message 
117101        trimmed  =  chat ._trim_messages (msgs , token_limits = (103 , 0 ), format = MISSING )
118102        assert  len (trimmed ) ==  2 
119-         contents  =  [msg [ " content_server" ]  for  msg  in  trimmed ]
103+         contents  =  [msg . content_server  for  msg  in  trimmed ]
120104        assert  contents  ==  [content1 , content3 ]
121105
122106        content1  =  generate_content (50 )
@@ -125,38 +109,48 @@ def generate_content(token_count: int) -> str:
125109        content4  =  generate_content (2 )
126110
127111        msgs  =  (
128-             as_transformed_message (
129-                 {"content" : content1 , "role" : "system" },
130-             ),
131-             as_transformed_message (
132-                 {"content" : content2 , "role" : "user" },
133-             ),
134-             as_transformed_message (
135-                 {"content" : content3 , "role" : "system" },
136-             ),
137-             as_transformed_message (
138-                 {"content" : content4 , "role" : "user" },
139-             ),
112+             ChatUIMessage (
113+                 content = content1 ,
114+                 role = "system" ,
115+             ).as_transformed_message (),
116+             ChatUIMessage (
117+                 content = content2 ,
118+                 role = "user" ,
119+             ).as_transformed_message (),
120+             ChatUIMessage (
121+                 content = content3 ,
122+                 role = "system" ,
123+             ).as_transformed_message (),
124+             ChatUIMessage (
125+                 content = content4 ,
126+                 role = "user" ,
127+             ).as_transformed_message (),
140128        )
141129
142130        # Should discard the 1st user message 
143131        trimmed  =  chat ._trim_messages (msgs , token_limits = (103 , 0 ), format = MISSING )
144132        assert  len (trimmed ) ==  3 
145-         contents  =  [msg [ " content_server" ]  for  msg  in  trimmed ]
133+         contents  =  [msg . content_server  for  msg  in  trimmed ]
146134        assert  contents  ==  [content1 , content3 , content4 ]
147135
148136        content1  =  generate_content (50 )
149137        content2  =  generate_content (10 )
150138
151139        msgs  =  (
152-             as_transformed_message ({"content" : content1 , "role" : "assistant" }),
153-             as_transformed_message ({"content" : content2 , "role" : "user" }),
140+             ChatUIMessage (
141+                 content = content1 ,
142+                 role = "assistant" ,
143+             ).as_transformed_message (),
144+             ChatUIMessage (
145+                 content = content2 ,
146+                 role = "user" ,
147+             ).as_transformed_message (),
154148        )
155149
156150        # Anthropic requires 1st message to be a user message 
157151        trimmed  =  chat ._trim_messages (msgs , token_limits = (30 , 0 ), format = "anthropic" )
158152        assert  len (trimmed ) ==  1 
159-         contents  =  [msg [ " content_server" ]  for  msg  in  trimmed ]
153+         contents  =  [msg . content_server  for  msg  in  trimmed ]
160154        assert  contents  ==  [content2 ]
161155
162156
@@ -173,13 +167,15 @@ def generate_content(token_count: int) -> str:
173167
174168
175169def  test_string_normalization ():
176-     msg  =  normalize_message_chunk ("Hello world!" )
177-     assert  msg  ==  {"content" : "Hello world!" , "role" : "assistant" }
170+     m  =  normalize_message_chunk ("Hello world!" )
171+     assert  m .content  ==  "Hello world!" 
172+     assert  m .role  ==  "assistant" 
178173
179174
180175def  test_dict_normalization ():
181-     msg  =  normalize_message_chunk ({"content" : "Hello world!" , "role" : "assistant" })
182-     assert  msg  ==  {"content" : "Hello world!" , "role" : "assistant" }
176+     m  =  normalize_message_chunk ({"content" : "Hello world!" , "role" : "assistant" })
177+     assert  m .content  ==  "Hello world!" 
178+     assert  m .role  ==  "assistant" 
183179
184180
185181def  test_langchain_normalization ():
@@ -195,11 +191,15 @@ def test_langchain_normalization():
195191
196192    # Mock & normalize return value of BaseChatModel.invoke() 
197193    msg  =  BaseMessage (content = "Hello world!" , role = "assistant" , type = "foo" )
198-     assert  normalize_message (msg ) ==  {"content" : "Hello world!" , "role" : "assistant" }
194+     m  =  normalize_message (msg )
195+     assert  m .content  ==  "Hello world!" 
196+     assert  m .role  ==  "assistant" 
199197
200198    # Mock & normalize return value of BaseChatModel.stream() 
201199    chunk  =  BaseMessageChunk (content = "Hello " , type = "foo" )
202-     assert  normalize_message_chunk (chunk ) ==  {"content" : "Hello " , "role" : "assistant" }
200+     m  =  normalize_message_chunk (chunk )
201+     assert  m .content  ==  "Hello " 
202+     assert  m .role  ==  "assistant" 
203203
204204
205205def  test_google_normalization ():
@@ -256,7 +256,9 @@ def test_anthropic_normalization():
256256        usage = Usage (input_tokens = 0 , output_tokens = 0 ),
257257    )
258258
259-     assert  normalize_message (msg ) ==  {"content" : "Hello world!" , "role" : "assistant" }
259+     m  =  normalize_message (msg )
260+     assert  m .content  ==  "Hello world!" 
261+     assert  m .role  ==  "assistant" 
260262
261263    # Mock return object from Anthropic().messages.create(stream=True) 
262264    chunk  =  RawContentBlockDeltaEvent (
@@ -265,7 +267,9 @@ def test_anthropic_normalization():
265267        index = 0 ,
266268    )
267269
268-     assert  normalize_message_chunk (chunk ) ==  {"content" : "Hello " , "role" : "assistant" }
270+     m  =  normalize_message_chunk (chunk )
271+     assert  m .content  ==  "Hello " 
272+     assert  m .role  ==  "assistant" 
269273
270274
271275def  test_openai_normalization ():
@@ -310,8 +314,9 @@ def test_openai_normalization():
310314        created = int (datetime .now ().timestamp ()),
311315    )
312316
313-     msg  =  normalize_message (completion )
314-     assert  msg  ==  {"content" : "Hello world!" , "role" : "assistant" }
317+     m  =  normalize_message (completion )
318+     assert  m .content  ==  "Hello world!" 
319+     assert  m .role  ==  "assistant" 
315320
316321    # Mock return object from OpenAI().chat.completions.create(stream=True) 
317322    chunk  =  ChatCompletionChunk (
@@ -330,8 +335,9 @@ def test_openai_normalization():
330335        ],
331336    )
332337
333-     msg  =  normalize_message_chunk (chunk )
334-     assert  msg  ==  {"content" : "Hello " , "role" : "assistant" }
338+     m  =  normalize_message_chunk (chunk )
339+     assert  m .content  ==  "Hello " 
340+     assert  m .role  ==  "assistant" 
335341
336342
337343def  test_ollama_normalization ():
@@ -344,8 +350,13 @@ def test_ollama_normalization():
344350    )
345351
346352    msg_dict  =  {"content" : "Hello world!" , "role" : "assistant" }
347-     assert  normalize_message (msg ) ==  msg_dict 
348-     assert  normalize_message_chunk (msg ) ==  msg_dict 
353+     m  =  normalize_message (msg )
354+     assert  m .content  ==  msg_dict ["content" ]
355+     assert  m .role  ==  msg_dict ["role" ]
356+ 
357+     m  =  normalize_message_chunk (msg )
358+     assert  m .content  ==  msg_dict ["content" ]
359+     assert  m .role  ==  msg_dict ["role" ]
349360
350361
351362# ------------------------------------------------------------------------------------ 
0 commit comments