diff --git a/shiny/ui/_chat.py b/shiny/ui/_chat.py index d1f0c6104..d728bcc7d 100644 --- a/shiny/ui/_chat.py +++ b/shiny/ui/_chat.py @@ -489,7 +489,7 @@ def messages( transform_user == "last" and i == len(messages) - 1 ) content_key = m["transform_key" if transform else "pre_transform_key"] - chat_msg = ChatMessage(content=m[content_key], role=m["role"]) + chat_msg = ChatMessage(content=str(m[content_key]), role=m["role"]) if not isinstance(format, MISSING_TYPE): chat_msg = as_provider_message(chat_msg, format) res.append(chat_msg) @@ -635,7 +635,7 @@ async def _send_append_message( content_type = "html" if isinstance(content, HTML) else "markdown" msg = ClientMessage( - content=content, + content=str(content), role=message["role"], content_type=content_type, chunk_type=chunk_type, @@ -790,7 +790,7 @@ async def _transform_message( if content is None: return None - res[key] = content + res[key] = content # type: ignore return res @@ -950,7 +950,7 @@ def user_input(self, transform: bool = False) -> str | None: if msg is None: return None key = "content_server" if transform else "content_client" - return msg[key] + return str(msg[key]) def _user_input(self) -> str: id = self.user_input_id diff --git a/shiny/ui/_chat_normalize.py b/shiny/ui/_chat_normalize.py index 2d5063324..7bec8102c 100644 --- a/shiny/ui/_chat_normalize.py +++ b/shiny/ui/_chat_normalize.py @@ -2,6 +2,8 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Optional, cast +from htmltools import HTML + from ._chat_types import ChatMessage if TYPE_CHECKING: @@ -49,10 +51,10 @@ def normalize_chunk(self, chunk: Any) -> ChatMessage: return ChatMessage(content=x or "", role="assistant") def can_normalize(self, message: Any) -> bool: - return isinstance(message, str) or message is None + return isinstance(message, (str, HTML)) or message is None def can_normalize_chunk(self, chunk: Any) -> bool: - return isinstance(chunk, str) or chunk is None + return isinstance(chunk, (str, HTML)) or chunk is None class DictNormalizer(BaseMessageNormalizer): diff --git a/shiny/ui/_chat_types.py b/shiny/ui/_chat_types.py index 458286720..34924904e 100644 --- a/shiny/ui/_chat_types.py +++ b/shiny/ui/_chat_types.py @@ -2,6 +2,8 @@ from typing import Literal, TypedDict +from htmltools import HTML + Role = Literal["assistant", "user", "system"] @@ -14,7 +16,7 @@ class ChatMessage(TypedDict): # A message once transformed have been applied class TransformedMessage(TypedDict): - content_client: str + content_client: str | HTML content_server: str role: Role transform_key: Literal["content_client", "content_server"] diff --git a/tests/playwright/shiny/components/chat/transform_assistant/app.py b/tests/playwright/shiny/components/chat/transform_assistant/app.py index fe89d84b5..17f46a2d3 100644 --- a/tests/playwright/shiny/components/chat/transform_assistant/app.py +++ b/tests/playwright/shiny/components/chat/transform_assistant/app.py @@ -1,3 +1,5 @@ +from typing import Union + from shiny.express import render, ui # Set some Shiny page options @@ -12,7 +14,7 @@ # TODO: test with append_message_stream() as well @chat.transform_assistant_response -def transform(content: str) -> str: +def transform(content: str) -> Union[str, ui.HTML]: if content == "return HTML": return ui.HTML(f"Transformed response: {content}") else: diff --git a/tests/playwright/shiny/components/chat/transform_assistant/test_chat_transform_assistant.py b/tests/playwright/shiny/components/chat/transform_assistant/test_chat_transform_assistant.py index d38f35807..b0d177f15 100644 --- a/tests/playwright/shiny/components/chat/transform_assistant/test_chat_transform_assistant.py +++ b/tests/playwright/shiny/components/chat/transform_assistant/test_chat_transform_assistant.py @@ -1,7 +1,6 @@ from playwright.sync_api import Page, expect from utils.deploy_utils import skip_on_webkit -from shiny import ui from shiny.playwright import controller from shiny.run import ShinyAppProc @@ -48,7 +47,7 @@ def test_validate_chat_transform_assistant(page: Page, local_app: ShinyAppProc) {"content": "Transformed response: `hello`", "role": "assistant"}, {"content": "return HTML", "role": "user"}, { - "content": ui.HTML("Transformed response: return HTML"), + "content": "Transformed response: return HTML", "role": "assistant", }, ]