Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 0 additions & 13 deletions shiny/session/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
from contextvars import ContextVar, Token
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar

from htmltools import TagChild

if TYPE_CHECKING:
from ._session import Session

Expand Down Expand Up @@ -134,17 +132,6 @@ def require_active_session(session: Optional[Session]) -> Session:
return session


def process_ui(ui: TagChild) -> tuple[str, list[dict[str, str]]]:
    """
    Process a UI element with the session, returning the HTML and dependencies.

    Parameters
    ----------
    ui
        The UI element to render: a tag-like object, or a plain ``str``,
        ``float``, or ``int``.

    Returns
    -------
    A tuple ``(html, deps)`` where ``html`` is the rendered HTML string and
    ``deps`` is a list of HTML-dependency dicts (empty for primitive inputs).
    """
    # Primitive values render as their string form and carry no dependencies,
    # so no session is needed for them.
    if isinstance(ui, (str, float, int)):
        return str(ui), []
    # NOTE(review): non-primitive UI requires an active session; this raises
    # (via require_active_session) when called outside one — confirm callers
    # only hit this path inside a session context.
    session = require_active_session(None)
    res = session._process_ui(ui)
    return res["html"], res["deps"]


# Ideally I'd love not to limit the types for T, but if I don't, the type checker has
# trouble figuring out what `T` is supposed to be when run_thunk is actually used. For
# now, just keep expanding the possible types, as needed.
Expand Down
96 changes: 40 additions & 56 deletions shiny/ui/_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@
as_provider_message,
)
from ._chat_tokenizer import TokenEncoding, TokenizersEncoding, get_default_tokenizer
from ._chat_types import ChatMessage, ClientMessage, TransformedMessage
from ._chat_types import ChatMessage, ChatMessageDict, ClientMessage, TransformedMessage
from ._html_deps_py_shiny import chat_deps
from .fill import as_fill_item, as_fillable_container

__all__ = (
"Chat",
"ChatExpress",
"chat_ui",
"ChatMessage",
"ChatMessageDict",
)


Expand Down Expand Up @@ -251,7 +251,10 @@ async def _on_user_input():
else:
# A transformed value of None is a special signal to suspend input
# handling (i.e., don't generate a response)
self._store_message(as_transformed_message(msg), index=n_pre)
self._store_message(
TransformedMessage.from_chat_message(msg),
index=n_pre,
)
await self._remove_loading_message()
self._suspend_input_handler = True

Expand Down Expand Up @@ -412,7 +415,7 @@ def messages(
token_limits: tuple[int, int] | None = None,
transform_user: Literal["all", "last", "none"] = "all",
transform_assistant: bool = False,
) -> tuple[ChatMessage, ...]: ...
) -> tuple[ChatMessageDict, ...]: ...

def messages(
self,
Expand All @@ -421,7 +424,7 @@ def messages(
token_limits: tuple[int, int] | None = None,
transform_user: Literal["all", "last", "none"] = "all",
transform_assistant: bool = False,
) -> tuple[ChatMessage | ProviderMessage, ...]:
) -> tuple[ChatMessageDict | ProviderMessage, ...]:
"""
Reactively read chat messages

Expand Down Expand Up @@ -489,17 +492,20 @@ def messages(
if token_limits is not None:
messages = self._trim_messages(messages, token_limits, format)

res: list[ChatMessage | ProviderMessage] = []
res: list[ChatMessageDict | ProviderMessage] = []
for i, m in enumerate(messages):
transform = False
if m["role"] == "assistant":
if m.role == "assistant":
transform = transform_assistant
elif m["role"] == "user":
elif m.role == "user":
transform = transform_user == "all" or (
transform_user == "last" and i == len(messages) - 1
)
content_key = m["transform_key" if transform else "pre_transform_key"]
chat_msg = ChatMessage(content=str(m[content_key]), role=m["role"])
content_key = getattr(
m, "transform_key" if transform else "pre_transform_key"
)
content = getattr(m, content_key)
chat_msg = ChatMessageDict(content=str(content), role=m.role)
if not isinstance(format, MISSING_TYPE):
chat_msg = as_provider_message(chat_msg, format)
res.append(chat_msg)
Expand Down Expand Up @@ -593,9 +599,9 @@ async def _append_message(
else:
msg = normalize_message_chunk(message)
# Update the current stream message
chunk_content = msg["content"]
chunk_content = msg.content
self._current_stream_message += chunk_content
msg["content"] = self._current_stream_message
msg.content = self._current_stream_message
if chunk == "end":
self._current_stream_message = ""

Expand Down Expand Up @@ -739,7 +745,7 @@ async def _append_message_stream(
):
id = _utils.private_random_id()

empty = ChatMessage(content="", role="assistant")
empty = ChatMessageDict(content="", role="assistant")
await self._append_message(empty, chunk="start", stream_id=id, icon=icon)

try:
Expand Down Expand Up @@ -771,7 +777,7 @@ async def _send_append_message(
chunk: ChunkOption = False,
icon: HTML | Tag | TagList | None = None,
):
if message["role"] == "system":
if message.role == "system":
# System messages are not displayed in the UI
return

Expand All @@ -786,21 +792,21 @@ async def _send_append_message(
elif chunk == "end":
chunk_type = "message_end"

content = message["content_client"]
content = message.content_client
content_type = "html" if isinstance(content, HTML) else "markdown"

# TODO: pass along dependencies for both content and icon (if any)
msg = ClientMessage(
content=str(content),
role=message["role"],
role=message.role,
content_type=content_type,
chunk_type=chunk_type,
)

if icon is not None:
msg["icon"] = str(icon)

deps = message.get("html_deps", [])
deps = message.html_deps
if deps:
msg["html_deps"] = deps

Expand Down Expand Up @@ -932,15 +938,15 @@ async def _transform_message(
chunk: ChunkOption = False,
chunk_content: str | None = None,
) -> TransformedMessage | None:
res = as_transformed_message(message)
key = res["transform_key"]
res = TransformedMessage.from_chat_message(message)
key = res.transform_key

if message["role"] == "user" and self._transform_user is not None:
content = await self._transform_user(message["content"])
if message.role == "user" and self._transform_user is not None:
content = await self._transform_user(message.content)

elif message["role"] == "assistant" and self._transform_assistant is not None:
elif message.role == "assistant" and self._transform_assistant is not None:
content = await self._transform_assistant(
message["content"],
message.content,
chunk_content or "",
chunk == "end" or chunk is False,
)
Expand All @@ -950,7 +956,7 @@ async def _transform_message(
if content is None:
return None

res[key] = content # type: ignore
setattr(res, key, content)

return res

Expand All @@ -975,7 +981,7 @@ def _store_message(
messages.insert(index, message)

self._messages.set(tuple(messages))
if message["role"] == "user":
if message.role == "user":
self._latest_user_input.set(message)

return None
Expand All @@ -1000,9 +1006,9 @@ def _trim_messages(
n_other_messages: int = 0
token_counts: list[int] = []
for m in messages:
count = self._get_token_count(m["content_server"])
count = self._get_token_count(m.content_server)
token_counts.append(count)
if m["role"] == "system":
if m.role == "system":
n_system_tokens += count
n_system_messages += 1
else:
Expand All @@ -1023,7 +1029,7 @@ def _trim_messages(
n_other_messages2: int = 0
token_counts.reverse()
for i, m in enumerate(reversed(messages)):
if m["role"] == "system":
if m.role == "system":
messages2.append(m)
continue
remaining_non_system_tokens -= token_counts[i]
Expand All @@ -1046,13 +1052,13 @@ def _trim_anthropic_messages(
self,
messages: tuple[TransformedMessage, ...],
) -> tuple[TransformedMessage, ...]:
if any(m["role"] == "system" for m in messages):
if any(m.role == "system" for m in messages):
raise ValueError(
"Anthropic requires a system prompt to be specified in its `.create()` method "
"(not in the chat messages with `role: system`)."
)
for i, m in enumerate(messages):
if m["role"] == "user":
if m.role == "user":
return messages[i:]

return ()
Expand Down Expand Up @@ -1098,7 +1104,8 @@ def user_input(self, transform: bool = False) -> str | None:
if msg is None:
return None
key = "content_server" if transform else "content_client"
return str(msg[key])
val = getattr(msg, key)
return str(val)

def _user_input(self) -> str:
id = self.user_input_id
Expand Down Expand Up @@ -1194,7 +1201,7 @@ class ChatExpress(Chat):
def ui(
self,
*,
messages: Optional[Sequence[str | ChatMessage]] = None,
messages: Optional[Sequence[str | ChatMessageDict]] = None,
placeholder: str = "Enter a message...",
width: CssUnit = "min(680px, 100%)",
height: CssUnit = "auto",
Expand Down Expand Up @@ -1244,7 +1251,7 @@ def ui(
def chat_ui(
id: str,
*,
messages: Optional[Sequence[TagChild | ChatMessage]] = None,
messages: Optional[Sequence[TagChild | ChatMessageDict]] = None,
placeholder: str = "Enter a message...",
width: CssUnit = "min(680px, 100%)",
height: CssUnit = "auto",
Expand Down Expand Up @@ -1361,27 +1368,4 @@ def chat_ui(
return res


def as_transformed_message(message: ChatMessage) -> TransformedMessage:
    """
    Wrap a plain chat message as a TransformedMessage.

    Both ``content_client`` and ``content_server`` start out as the original
    content; ``transform_key`` records which of the two a later transform
    should overwrite, and ``pre_transform_key`` which one keeps the original.

    Parameters
    ----------
    message
        The chat message (``role`` + ``content``, optionally ``html_deps``).

    Returns
    -------
    A ``TransformedMessage`` mirroring *message*.
    """
    # User input is transformed for the server's benefit; assistant/system
    # output is transformed for display on the client — hence the swapped
    # key pairs per role.
    if message["role"] == "user":
        transform_key = "content_server"
        pre_transform_key = "content_client"
    else:
        transform_key = "content_client"
        pre_transform_key = "content_server"

    res = TransformedMessage(
        content_client=message["content"],
        content_server=message["content"],
        role=message["role"],
        transform_key=transform_key,
        pre_transform_key=pre_transform_key,
    )

    # Only copy html_deps over when present, so the key stays absent otherwise.
    deps = message.get("html_deps", [])
    if deps:
        res["html_deps"] = deps

    return res


CHAT_INSTANCES: WeakValueDictionary[str, Chat] = WeakValueDictionary()
15 changes: 3 additions & 12 deletions shiny/ui/_chat_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional, cast

from htmltools import HTML, TagChild, Tagifiable
from htmltools import HTML, Tagifiable

from ..session._utils import process_ui
from ._chat_types import ChatMessage

if TYPE_CHECKING:
Expand Down Expand Up @@ -63,21 +62,13 @@ def normalize(self, message: Any) -> ChatMessage:
x = cast("dict[str, Any]", message)
if "content" not in x:
raise ValueError("Message must have 'content' key")
content, deps = process_ui(cast(TagChild, x["content"]))
res = ChatMessage(content=content, role=x.get("role", "assistant"))
if deps:
res["html_deps"] = deps
return res
return ChatMessage(content=x["content"], role=x.get("role", "assistant"))

def normalize_chunk(self, chunk: Any) -> ChatMessage:
x = cast("dict[str, Any]", chunk)
if "content" not in x:
raise ValueError("Message must have 'content' key")
content, deps = process_ui(cast(TagChild, x["content"]))
res = ChatMessage(content=content, role=x.get("role", "assistant"))
if deps:
res["html_deps"] = deps
return res
return ChatMessage(content=x["content"], role=x.get("role", "assistant"))

def can_normalize(self, message: Any) -> bool:
return isinstance(message, dict)
Expand Down
14 changes: 7 additions & 7 deletions shiny/ui/_chat_provider_types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys
from typing import TYPE_CHECKING, Literal, Union

from ._chat_types import ChatMessage
from ._chat_types import ChatMessageDict

if TYPE_CHECKING:
from anthropic.types import MessageParam as AnthropicMessage
Expand Down Expand Up @@ -47,7 +47,7 @@
# TODO: use a strategy pattern to allow others to register
# their own message formats
def as_provider_message(
message: ChatMessage, format: ProviderMessageFormat
message: ChatMessageDict, format: ProviderMessageFormat
) -> "ProviderMessage":
if format == "anthropic":
return as_anthropic_message(message)
Expand All @@ -62,7 +62,7 @@ def as_provider_message(
raise ValueError(f"Unknown format: {format}")


def as_anthropic_message(message: ChatMessage) -> "AnthropicMessage":
def as_anthropic_message(message: ChatMessageDict) -> "AnthropicMessage":
from anthropic.types import MessageParam as AnthropicMessage

if message["role"] == "system":
Expand All @@ -72,7 +72,7 @@ def as_anthropic_message(message: ChatMessage) -> "AnthropicMessage":
return AnthropicMessage(content=message["content"], role=message["role"])


def as_google_message(message: ChatMessage) -> "GoogleMessage":
def as_google_message(message: ChatMessageDict) -> "GoogleMessage":
if sys.version_info < (3, 9):
raise ValueError("Google requires Python 3.9")

Expand All @@ -89,7 +89,7 @@ def as_google_message(message: ChatMessage) -> "GoogleMessage":
return gtypes.ContentDict(parts=[message["content"]], role=role)


def as_langchain_message(message: ChatMessage) -> "LangChainMessage":
def as_langchain_message(message: ChatMessageDict) -> "LangChainMessage":
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

content = message["content"]
Expand All @@ -103,7 +103,7 @@ def as_langchain_message(message: ChatMessage) -> "LangChainMessage":
raise ValueError(f"Unknown role: {message['role']}")


def as_openai_message(message: ChatMessage) -> "OpenAIMessage":
def as_openai_message(message: ChatMessageDict) -> "OpenAIMessage":
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionSystemMessageParam,
Expand All @@ -121,7 +121,7 @@ def as_openai_message(message: ChatMessage) -> "OpenAIMessage":
raise ValueError(f"Unknown role: {role}")


def as_ollama_message(message: ChatMessage) -> "OllamaMessage":
def as_ollama_message(message: ChatMessageDict) -> "OllamaMessage":
from ollama import Message as OllamaMessage

return OllamaMessage(content=message["content"], role=message["role"])
Loading
Loading