3838 as_provider_message ,
3939)
4040from ._chat_tokenizer import TokenEncoding , TokenizersEncoding , get_default_tokenizer
41- from ._chat_types import ChatMessage , ClientMessage , TransformedMessage
41+ from ._chat_types import ChatMessage , ChatMessageDict , ClientMessage , TransformedMessage
4242from ._html_deps_py_shiny import chat_deps
4343from .fill import as_fill_item , as_fillable_container
4444
4545__all__ = (
4646 "Chat" ,
4747 "ChatExpress" ,
4848 "chat_ui" ,
49- "ChatMessage " ,
49+ "ChatMessageDict " ,
5050)
5151
5252
@@ -251,7 +251,10 @@ async def _on_user_input():
251251 else :
252252 # A transformed value of None is a special signal to suspend input
253253 # handling (i.e., don't generate a response)
254- self ._store_message (as_transformed_message (msg ), index = n_pre )
254+ self ._store_message (
255+ TransformedMessage .from_chat_message (msg ),
256+ index = n_pre ,
257+ )
255258 await self ._remove_loading_message ()
256259 self ._suspend_input_handler = True
257260
@@ -412,7 +415,7 @@ def messages(
412415 token_limits : tuple [int , int ] | None = None ,
413416 transform_user : Literal ["all" , "last" , "none" ] = "all" ,
414417 transform_assistant : bool = False ,
415- ) -> tuple [ChatMessage , ...]: ...
418+ ) -> tuple [ChatMessageDict , ...]: ...
416419
417420 def messages (
418421 self ,
@@ -421,7 +424,7 @@ def messages(
421424 token_limits : tuple [int , int ] | None = None ,
422425 transform_user : Literal ["all" , "last" , "none" ] = "all" ,
423426 transform_assistant : bool = False ,
424- ) -> tuple [ChatMessage | ProviderMessage , ...]:
427+ ) -> tuple [ChatMessageDict | ProviderMessage , ...]:
425428 """
426429 Reactively read chat messages
427430
@@ -489,17 +492,20 @@ def messages(
489492 if token_limits is not None :
490493 messages = self ._trim_messages (messages , token_limits , format )
491494
492- res : list [ChatMessage | ProviderMessage ] = []
495+ res : list [ChatMessageDict | ProviderMessage ] = []
493496 for i , m in enumerate (messages ):
494497 transform = False
495- if m [ " role" ] == "assistant" :
498+ if m . role == "assistant" :
496499 transform = transform_assistant
497- elif m [ " role" ] == "user" :
500+ elif m . role == "user" :
498501 transform = transform_user == "all" or (
499502 transform_user == "last" and i == len (messages ) - 1
500503 )
501- content_key = m ["transform_key" if transform else "pre_transform_key" ]
502- chat_msg = ChatMessage (content = str (m [content_key ]), role = m ["role" ])
504+ content_key = getattr (
505+ m , "transform_key" if transform else "pre_transform_key"
506+ )
507+ content = getattr (m , content_key )
508+ chat_msg = ChatMessageDict (content = str (content ), role = m .role )
503509 if not isinstance (format , MISSING_TYPE ):
504510 chat_msg = as_provider_message (chat_msg , format )
505511 res .append (chat_msg )
@@ -593,9 +599,9 @@ async def _append_message(
593599 else :
594600 msg = normalize_message_chunk (message )
595601 # Update the current stream message
596- chunk_content = msg [ " content" ]
602+ chunk_content = msg . content
597603 self ._current_stream_message += chunk_content
598- msg [ " content" ] = self ._current_stream_message
604+ msg . content = self ._current_stream_message
599605 if chunk == "end" :
600606 self ._current_stream_message = ""
601607
@@ -739,7 +745,7 @@ async def _append_message_stream(
739745 ):
740746 id = _utils .private_random_id ()
741747
742- empty = ChatMessage (content = "" , role = "assistant" )
748+ empty = ChatMessageDict (content = "" , role = "assistant" )
743749 await self ._append_message (empty , chunk = "start" , stream_id = id , icon = icon )
744750
745751 try :
@@ -771,7 +777,7 @@ async def _send_append_message(
771777 chunk : ChunkOption = False ,
772778 icon : HTML | Tag | TagList | None = None ,
773779 ):
774- if message [ " role" ] == "system" :
780+ if message . role == "system" :
775781 # System messages are not displayed in the UI
776782 return
777783
@@ -786,21 +792,21 @@ async def _send_append_message(
786792 elif chunk == "end" :
787793 chunk_type = "message_end"
788794
789- content = message [ " content_client" ]
795+ content = message . content_client
790796 content_type = "html" if isinstance (content , HTML ) else "markdown"
791797
792798 # TODO: pass along dependencies for both content and icon (if any)
793799 msg = ClientMessage (
794800 content = str (content ),
795- role = message [ " role" ] ,
801+ role = message . role ,
796802 content_type = content_type ,
797803 chunk_type = chunk_type ,
798804 )
799805
800806 if icon is not None :
801807 msg ["icon" ] = str (icon )
802808
803- deps = message .get ( " html_deps" , [])
809+ deps = message .html_deps
804810 if deps :
805811 msg ["html_deps" ] = deps
806812
@@ -932,15 +938,15 @@ async def _transform_message(
932938 chunk : ChunkOption = False ,
933939 chunk_content : str | None = None ,
934940 ) -> TransformedMessage | None :
935- res = as_transformed_message (message )
936- key = res [ " transform_key" ]
941+ res = TransformedMessage . from_chat_message (message )
942+ key = res . transform_key
937943
938- if message [ " role" ] == "user" and self ._transform_user is not None :
939- content = await self ._transform_user (message [ " content" ] )
944+ if message . role == "user" and self ._transform_user is not None :
945+ content = await self ._transform_user (message . content )
940946
941- elif message [ " role" ] == "assistant" and self ._transform_assistant is not None :
947+ elif message . role == "assistant" and self ._transform_assistant is not None :
942948 content = await self ._transform_assistant (
943- message [ " content" ] ,
949+ message . content ,
944950 chunk_content or "" ,
945951 chunk == "end" or chunk is False ,
946952 )
@@ -950,7 +956,7 @@ async def _transform_message(
950956 if content is None :
951957 return None
952958
953- res [ key ] = content # type: ignore
959+ setattr ( res , key , content )
954960
955961 return res
956962
@@ -975,7 +981,7 @@ def _store_message(
975981 messages .insert (index , message )
976982
977983 self ._messages .set (tuple (messages ))
978- if message [ " role" ] == "user" :
984+ if message . role == "user" :
979985 self ._latest_user_input .set (message )
980986
981987 return None
@@ -1000,9 +1006,9 @@ def _trim_messages(
10001006 n_other_messages : int = 0
10011007 token_counts : list [int ] = []
10021008 for m in messages :
1003- count = self ._get_token_count (m [ " content_server" ] )
1009+ count = self ._get_token_count (m . content_server )
10041010 token_counts .append (count )
1005- if m [ " role" ] == "system" :
1011+ if m . role == "system" :
10061012 n_system_tokens += count
10071013 n_system_messages += 1
10081014 else :
@@ -1023,7 +1029,7 @@ def _trim_messages(
10231029 n_other_messages2 : int = 0
10241030 token_counts .reverse ()
10251031 for i , m in enumerate (reversed (messages )):
1026- if m [ " role" ] == "system" :
1032+ if m . role == "system" :
10271033 messages2 .append (m )
10281034 continue
10291035 remaining_non_system_tokens -= token_counts [i ]
@@ -1046,13 +1052,13 @@ def _trim_anthropic_messages(
10461052 self ,
10471053 messages : tuple [TransformedMessage , ...],
10481054 ) -> tuple [TransformedMessage , ...]:
1049- if any (m [ " role" ] == "system" for m in messages ):
1055+ if any (m . role == "system" for m in messages ):
10501056 raise ValueError (
10511057 "Anthropic requires a system prompt to be specified in it's `.create()` method "
10521058 "(not in the chat messages with `role: system`)."
10531059 )
10541060 for i , m in enumerate (messages ):
1055- if m [ " role" ] == "user" :
1061+ if m . role == "user" :
10561062 return messages [i :]
10571063
10581064 return ()
@@ -1098,7 +1104,8 @@ def user_input(self, transform: bool = False) -> str | None:
10981104 if msg is None :
10991105 return None
11001106 key = "content_server" if transform else "content_client"
1101- return str (msg [key ])
1107+ val = getattr (msg , key )
1108+ return str (val )
11021109
11031110 def _user_input (self ) -> str :
11041111 id = self .user_input_id
@@ -1194,7 +1201,7 @@ class ChatExpress(Chat):
11941201 def ui (
11951202 self ,
11961203 * ,
1197- messages : Optional [Sequence [str | ChatMessage ]] = None ,
1204+ messages : Optional [Sequence [str | ChatMessageDict ]] = None ,
11981205 placeholder : str = "Enter a message..." ,
11991206 width : CssUnit = "min(680px, 100%)" ,
12001207 height : CssUnit = "auto" ,
@@ -1244,7 +1251,7 @@ def ui(
12441251def chat_ui (
12451252 id : str ,
12461253 * ,
1247- messages : Optional [Sequence [TagChild | ChatMessage ]] = None ,
1254+ messages : Optional [Sequence [TagChild | ChatMessageDict ]] = None ,
12481255 placeholder : str = "Enter a message..." ,
12491256 width : CssUnit = "min(680px, 100%)" ,
12501257 height : CssUnit = "auto" ,
@@ -1361,27 +1368,4 @@ def chat_ui(
13611368 return res
13621369
13631370
1364- def as_transformed_message (message : ChatMessage ) -> TransformedMessage :
1365- if message ["role" ] == "user" :
1366- transform_key = "content_server"
1367- pre_transform_key = "content_client"
1368- else :
1369- transform_key = "content_client"
1370- pre_transform_key = "content_server"
1371-
1372- res = TransformedMessage (
1373- content_client = message ["content" ],
1374- content_server = message ["content" ],
1375- role = message ["role" ],
1376- transform_key = transform_key ,
1377- pre_transform_key = pre_transform_key ,
1378- )
1379-
1380- deps = message .get ("html_deps" , [])
1381- if deps :
1382- res ["html_deps" ] = deps
1383-
1384- return res
1385-
1386-
13871371CHAT_INSTANCES : WeakValueDictionary [str , Chat ] = WeakValueDictionary ()
0 commit comments