3838 as_provider_message ,
3939)
4040from ._chat_tokenizer import TokenEncoding , TokenizersEncoding , get_default_tokenizer
41- from ._chat_types import ChatMessage , ClientMessage , TransformedMessage
41+ from ._chat_types import ChatMessage , ChatUIMessage , ClientMessage , TransformedMessage
4242from ._html_deps_py_shiny import chat_deps
4343from .fill import as_fill_item , as_fillable_container
4444
@@ -240,7 +240,7 @@ async def _init_chat():
240240 @reactive .effect (priority = 9999 )
241241 @reactive .event (self ._user_input )
242242 async def _on_user_input ():
243- msg = ChatMessage (content = self ._user_input (), role = "user" )
243+ msg = ChatUIMessage (content = self ._user_input (), role = "user" )
244244 # It's possible that during the transform, a message is appended, so get
245245 # the length now, so we can insert the new message at the right index
246246 n_pre = len (self ._messages ())
@@ -251,7 +251,7 @@ async def _on_user_input():
251251 else :
252252 # A transformed value of None is a special signal to suspend input
253253 # handling (i.e., don't generate a response)
254- self ._store_message (as_transformed_message (msg ), index = n_pre )
254+ self ._store_message (msg . as_transformed_message (), index = n_pre )
255255 await self ._remove_loading_message ()
256256 self ._suspend_input_handler = True
257257
@@ -492,14 +492,17 @@ def messages(
492492 res : list [ChatMessage | ProviderMessage ] = []
493493 for i , m in enumerate (messages ):
494494 transform = False
495- if m [ " role" ] == "assistant" :
495+ if m . role == "assistant" :
496496 transform = transform_assistant
497- elif m [ " role" ] == "user" :
497+ elif m . role == "user" :
498498 transform = transform_user == "all" or (
499499 transform_user == "last" and i == len (messages ) - 1
500500 )
501- content_key = m ["transform_key" if transform else "pre_transform_key" ]
502- chat_msg = ChatMessage (content = str (m [content_key ]), role = m ["role" ])
501+ content_key = getattr (
502+ m , "transform_key" if transform else "pre_transform_key"
503+ )
504+ content = getattr (m , content_key )
505+ chat_msg = ChatMessage (content = str (content ), role = m .role )
503506 if not isinstance (format , MISSING_TYPE ):
504507 chat_msg = as_provider_message (chat_msg , format )
505508 res .append (chat_msg )
@@ -593,9 +596,9 @@ async def _append_message(
593596 else :
594597 msg = normalize_message_chunk (message )
595598 # Update the current stream message
596- chunk_content = msg [ " content" ]
599+ chunk_content = msg . content
597600 self ._current_stream_message += chunk_content
598- msg [ " content" ] = self ._current_stream_message
601+ msg . content = self ._current_stream_message
599602 if chunk == "end" :
600603 self ._current_stream_message = ""
601604
@@ -771,7 +774,7 @@ async def _send_append_message(
771774 chunk : ChunkOption = False ,
772775 icon : HTML | Tag | TagList | None = None ,
773776 ):
774- if message [ " role" ] == "system" :
777+ if message . role == "system" :
775778 # System messages are not displayed in the UI
776779 return
777780
@@ -786,21 +789,21 @@ async def _send_append_message(
786789 elif chunk == "end" :
787790 chunk_type = "message_end"
788791
789- content = message [ " content_client" ]
792+ content = message . content_client
790793 content_type = "html" if isinstance (content , HTML ) else "markdown"
791794
792795 # TODO: pass along dependencies for both content and icon (if any)
793796 msg = ClientMessage (
794797 content = str (content ),
795- role = message [ " role" ] ,
798+ role = message . role ,
796799 content_type = content_type ,
797800 chunk_type = chunk_type ,
798801 )
799802
800803 if icon is not None :
801804 msg ["icon" ] = str (icon )
802805
803- deps = message .get ( " html_deps" , [])
806+ deps = message .html_deps
804807 if deps :
805808 msg ["html_deps" ] = deps
806809
@@ -928,19 +931,19 @@ async def _transform_wrapper(content: str, chunk: str, done: bool):
928931
929932 async def _transform_message (
930933 self ,
931- message : ChatMessage ,
934+ message : ChatUIMessage ,
932935 chunk : ChunkOption = False ,
933936 chunk_content : str | None = None ,
934937 ) -> TransformedMessage | None :
935- res = as_transformed_message (message )
936- key = res [ " transform_key" ]
938+ res = message . as_transformed_message ()
939+ key = res . transform_key
937940
938- if message [ " role" ] == "user" and self ._transform_user is not None :
939- content = await self ._transform_user (message [ " content" ] )
941+ if message . role == "user" and self ._transform_user is not None :
942+ content = await self ._transform_user (message . content )
940943
941- elif message [ " role" ] == "assistant" and self ._transform_assistant is not None :
944+ elif message . role == "assistant" and self ._transform_assistant is not None :
942945 content = await self ._transform_assistant (
943- message [ " content" ] ,
946+ message . content ,
944947 chunk_content or "" ,
945948 chunk == "end" or chunk is False ,
946949 )
@@ -975,7 +978,7 @@ def _store_message(
975978 messages .insert (index , message )
976979
977980 self ._messages .set (tuple (messages ))
978- if message [ " role" ] == "user" :
981+ if message . role == "user" :
979982 self ._latest_user_input .set (message )
980983
981984 return None
@@ -1000,9 +1003,9 @@ def _trim_messages(
10001003 n_other_messages : int = 0
10011004 token_counts : list [int ] = []
10021005 for m in messages :
1003- count = self ._get_token_count (m [ " content_server" ] )
1006+ count = self ._get_token_count (m . content_server )
10041007 token_counts .append (count )
1005- if m [ " role" ] == "system" :
1008+ if m . role == "system" :
10061009 n_system_tokens += count
10071010 n_system_messages += 1
10081011 else :
@@ -1023,7 +1026,7 @@ def _trim_messages(
10231026 n_other_messages2 : int = 0
10241027 token_counts .reverse ()
10251028 for i , m in enumerate (reversed (messages )):
1026- if m [ " role" ] == "system" :
1029+ if m . role == "system" :
10271030 messages2 .append (m )
10281031 continue
10291032 remaining_non_system_tokens -= token_counts [i ]
@@ -1046,13 +1049,13 @@ def _trim_anthropic_messages(
10461049 self ,
10471050 messages : tuple [TransformedMessage , ...],
10481051 ) -> tuple [TransformedMessage , ...]:
1049- if any (m [ " role" ] == "system" for m in messages ):
1052+ if any (m . role == "system" for m in messages ):
10501053 raise ValueError (
10511054 "Anthropic requires a system prompt to be specified in it's `.create()` method "
10521055 "(not in the chat messages with `role: system`)."
10531056 )
10541057 for i , m in enumerate (messages ):
1055- if m [ " role" ] == "user" :
1058+ if m . role == "user" :
10561059 return messages [i :]
10571060
10581061 return ()
@@ -1098,7 +1101,8 @@ def user_input(self, transform: bool = False) -> str | None:
10981101 if msg is None :
10991102 return None
11001103 key = "content_server" if transform else "content_client"
1101- return str (msg [key ])
1104+ val = getattr (msg , key )
1105+ return str (val )
11021106
11031107 def _user_input (self ) -> str :
11041108 id = self .user_input_id
@@ -1361,27 +1365,4 @@ def chat_ui(
13611365 return res
13621366
13631367
1364- def as_transformed_message (message : ChatMessage ) -> TransformedMessage :
1365- if message ["role" ] == "user" :
1366- transform_key = "content_server"
1367- pre_transform_key = "content_client"
1368- else :
1369- transform_key = "content_client"
1370- pre_transform_key = "content_server"
1371-
1372- res = TransformedMessage (
1373- content_client = message ["content" ],
1374- content_server = message ["content" ],
1375- role = message ["role" ],
1376- transform_key = transform_key ,
1377- pre_transform_key = pre_transform_key ,
1378- )
1379-
1380- deps = message .get ("html_deps" , [])
1381- if deps :
1382- res ["html_deps" ] = deps
1383-
1384- return res
1385-
1386-
13871368CHAT_INSTANCES : WeakValueDictionary [str , Chat ] = WeakValueDictionary ()
0 commit comments