Commit 1675212

fix(Chat)!: Move away from inheriting from TypedDict for internal Chat classes (#1897)
1 parent 2f46b2c commit 1675212

6 files changed (+177, -159 lines)
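This is a breaking change to Shiny's internal Chat message types: they no longer inherit from TypedDict, so internal code switches from dict-style access (msg["role"], msg["content"]) to attribute access (msg.role, msg.content), and the dict-shaped type that stays in the public API is renamed from ChatMessage to ChatMessageDict. The new class definitions live in shiny/ui/_chat_types.py, which is not included in the view below; the following is only a rough, assumed sketch of their shape, to illustrate why every call site in the diffs changes.

from dataclasses import dataclass
from typing import Literal, TypedDict

Role = Literal["assistant", "user", "system"]


class ChatMessageDict(TypedDict):
    # Dict-shaped message kept for user-facing APIs (previously exported as ChatMessage).
    content: str
    role: Role


@dataclass
class ChatMessage:
    # Internal message type: plain attributes instead of TypedDict keys.
    content: str
    role: Role


msg = ChatMessage(content="Hello", role="user")
print(msg.role)  # attribute access replaces msg["role"]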

shiny/session/_utils.py

Lines changed: 0 additions & 13 deletions
@@ -10,8 +10,6 @@
 from contextvars import ContextVar, Token
 from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar

-from htmltools import TagChild
-
 if TYPE_CHECKING:
     from ._session import Session

@@ -134,17 +132,6 @@ def require_active_session(session: Optional[Session]) -> Session:
     return session


-def process_ui(ui: TagChild) -> tuple[str, list[dict[str, str]]]:
-    """
-    Process a UI element with the session, returning the HTML and dependencies.
-    """
-    if isinstance(ui, (str, float, int)):
-        return str(ui), []
-    session = require_active_session(None)
-    res = session._process_ui(ui)
-    return res["html"], res["deps"]
-
-
 # Ideally I'd love not to limit the types for T, but if I don't, the type checker has
 # trouble figuring out what `T` is supposed to be when run_thunk is actually used. For
 # now, just keep expanding the possible types, as needed.

shiny/ui/_chat.py

Lines changed: 40 additions & 56 deletions
@@ -38,15 +38,15 @@
     as_provider_message,
 )
 from ._chat_tokenizer import TokenEncoding, TokenizersEncoding, get_default_tokenizer
-from ._chat_types import ChatMessage, ClientMessage, TransformedMessage
+from ._chat_types import ChatMessage, ChatMessageDict, ClientMessage, TransformedMessage
 from ._html_deps_py_shiny import chat_deps
 from .fill import as_fill_item, as_fillable_container

 __all__ = (
     "Chat",
     "ChatExpress",
     "chat_ui",
-    "ChatMessage",
+    "ChatMessageDict",
 )


@@ -251,7 +251,10 @@ async def _on_user_input():
             else:
                 # A transformed value of None is a special signal to suspend input
                 # handling (i.e., don't generate a response)
-                self._store_message(as_transformed_message(msg), index=n_pre)
+                self._store_message(
+                    TransformedMessage.from_chat_message(msg),
+                    index=n_pre,
+                )
                 await self._remove_loading_message()
                 self._suspend_input_handler = True

@@ -412,7 +415,7 @@ def messages(
         token_limits: tuple[int, int] | None = None,
         transform_user: Literal["all", "last", "none"] = "all",
         transform_assistant: bool = False,
-    ) -> tuple[ChatMessage, ...]: ...
+    ) -> tuple[ChatMessageDict, ...]: ...

     def messages(
         self,
@@ -421,7 +424,7 @@ def messages(
         token_limits: tuple[int, int] | None = None,
         transform_user: Literal["all", "last", "none"] = "all",
         transform_assistant: bool = False,
-    ) -> tuple[ChatMessage | ProviderMessage, ...]:
+    ) -> tuple[ChatMessageDict | ProviderMessage, ...]:
         """
         Reactively read chat messages

@@ -489,17 +492,20 @@ def messages(
         if token_limits is not None:
             messages = self._trim_messages(messages, token_limits, format)

-        res: list[ChatMessage | ProviderMessage] = []
+        res: list[ChatMessageDict | ProviderMessage] = []
         for i, m in enumerate(messages):
             transform = False
-            if m["role"] == "assistant":
+            if m.role == "assistant":
                 transform = transform_assistant
-            elif m["role"] == "user":
+            elif m.role == "user":
                 transform = transform_user == "all" or (
                     transform_user == "last" and i == len(messages) - 1
                 )
-            content_key = m["transform_key" if transform else "pre_transform_key"]
-            chat_msg = ChatMessage(content=str(m[content_key]), role=m["role"])
+            content_key = getattr(
+                m, "transform_key" if transform else "pre_transform_key"
+            )
+            content = getattr(m, content_key)
+            chat_msg = ChatMessageDict(content=str(content), role=m.role)
             if not isinstance(format, MISSING_TYPE):
                 chat_msg = as_provider_message(chat_msg, format)
             res.append(chat_msg)
@@ -593,9 +599,9 @@ async def _append_message(
         else:
             msg = normalize_message_chunk(message)
             # Update the current stream message
-            chunk_content = msg["content"]
+            chunk_content = msg.content
             self._current_stream_message += chunk_content
-            msg["content"] = self._current_stream_message
+            msg.content = self._current_stream_message
             if chunk == "end":
                 self._current_stream_message = ""

@@ -739,7 +745,7 @@ async def _append_message_stream(
     ):
         id = _utils.private_random_id()

-        empty = ChatMessage(content="", role="assistant")
+        empty = ChatMessageDict(content="", role="assistant")
         await self._append_message(empty, chunk="start", stream_id=id, icon=icon)

         try:
@@ -771,7 +777,7 @@ async def _send_append_message(
         chunk: ChunkOption = False,
         icon: HTML | Tag | TagList | None = None,
     ):
-        if message["role"] == "system":
+        if message.role == "system":
             # System messages are not displayed in the UI
             return

@@ -786,21 +792,21 @@ async def _send_append_message(
         elif chunk == "end":
             chunk_type = "message_end"

-        content = message["content_client"]
+        content = message.content_client
         content_type = "html" if isinstance(content, HTML) else "markdown"

         # TODO: pass along dependencies for both content and icon (if any)
         msg = ClientMessage(
             content=str(content),
-            role=message["role"],
+            role=message.role,
             content_type=content_type,
             chunk_type=chunk_type,
         )

         if icon is not None:
             msg["icon"] = str(icon)

-        deps = message.get("html_deps", [])
+        deps = message.html_deps
         if deps:
             msg["html_deps"] = deps

@@ -932,15 +938,15 @@ async def _transform_message(
         chunk: ChunkOption = False,
         chunk_content: str | None = None,
     ) -> TransformedMessage | None:
-        res = as_transformed_message(message)
-        key = res["transform_key"]
+        res = TransformedMessage.from_chat_message(message)
+        key = res.transform_key

-        if message["role"] == "user" and self._transform_user is not None:
-            content = await self._transform_user(message["content"])
+        if message.role == "user" and self._transform_user is not None:
+            content = await self._transform_user(message.content)

-        elif message["role"] == "assistant" and self._transform_assistant is not None:
+        elif message.role == "assistant" and self._transform_assistant is not None:
             content = await self._transform_assistant(
-                message["content"],
+                message.content,
                 chunk_content or "",
                 chunk == "end" or chunk is False,
             )
@@ -950,7 +956,7 @@ async def _transform_message(
         if content is None:
             return None

-        res[key] = content  # type: ignore
+        setattr(res, key, content)

         return res

@@ -975,7 +981,7 @@ def _store_message(
         messages.insert(index, message)

         self._messages.set(tuple(messages))
-        if message["role"] == "user":
+        if message.role == "user":
             self._latest_user_input.set(message)

         return None
@@ -1000,9 +1006,9 @@ def _trim_messages(
         n_other_messages: int = 0
         token_counts: list[int] = []
         for m in messages:
-            count = self._get_token_count(m["content_server"])
+            count = self._get_token_count(m.content_server)
             token_counts.append(count)
-            if m["role"] == "system":
+            if m.role == "system":
                 n_system_tokens += count
                 n_system_messages += 1
             else:
@@ -1023,7 +1029,7 @@ def _trim_messages(
         n_other_messages2: int = 0
         token_counts.reverse()
         for i, m in enumerate(reversed(messages)):
-            if m["role"] == "system":
+            if m.role == "system":
                 messages2.append(m)
                 continue
             remaining_non_system_tokens -= token_counts[i]
@@ -1046,13 +1052,13 @@ def _trim_anthropic_messages(
         self,
         messages: tuple[TransformedMessage, ...],
     ) -> tuple[TransformedMessage, ...]:
-        if any(m["role"] == "system" for m in messages):
+        if any(m.role == "system" for m in messages):
             raise ValueError(
                 "Anthropic requires a system prompt to be specified in it's `.create()` method "
                 "(not in the chat messages with `role: system`)."
             )
         for i, m in enumerate(messages):
-            if m["role"] == "user":
+            if m.role == "user":
                 return messages[i:]

         return ()
@@ -1098,7 +1104,8 @@ def user_input(self, transform: bool = False) -> str | None:
         if msg is None:
             return None
         key = "content_server" if transform else "content_client"
-        return str(msg[key])
+        val = getattr(msg, key)
+        return str(val)

     def _user_input(self) -> str:
         id = self.user_input_id
@@ -1194,7 +1201,7 @@ class ChatExpress(Chat):
     def ui(
         self,
         *,
-        messages: Optional[Sequence[str | ChatMessage]] = None,
+        messages: Optional[Sequence[str | ChatMessageDict]] = None,
         placeholder: str = "Enter a message...",
         width: CssUnit = "min(680px, 100%)",
         height: CssUnit = "auto",
@@ -1244,7 +1251,7 @@ def ui(
 def chat_ui(
     id: str,
     *,
-    messages: Optional[Sequence[TagChild | ChatMessage]] = None,
+    messages: Optional[Sequence[TagChild | ChatMessageDict]] = None,
     placeholder: str = "Enter a message...",
     width: CssUnit = "min(680px, 100%)",
     height: CssUnit = "auto",
@@ -1361,27 +1368,4 @@ def chat_ui(
     return res


-def as_transformed_message(message: ChatMessage) -> TransformedMessage:
-    if message["role"] == "user":
-        transform_key = "content_server"
-        pre_transform_key = "content_client"
-    else:
-        transform_key = "content_client"
-        pre_transform_key = "content_server"
-
-    res = TransformedMessage(
-        content_client=message["content"],
-        content_server=message["content"],
-        role=message["role"],
-        transform_key=transform_key,
-        pre_transform_key=pre_transform_key,
-    )
-
-    deps = message.get("html_deps", [])
-    if deps:
-        res["html_deps"] = deps
-
-    return res
-
-
 CHAT_INSTANCES: WeakValueDictionary[str, Chat] = WeakValueDictionary()
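The module-level as_transformed_message() helper deleted above is replaced by calls to TransformedMessage.from_chat_message(). That classmethod lives in shiny/ui/_chat_types.py, which is not part of the diff shown here, so the following is a hypothetical sketch of what it plausibly does, mirroring the removed helper field for field; the exact field names and dataclass form are assumptions.

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any


@dataclass
class TransformedMessage:
    content_client: str
    content_server: str
    role: str
    transform_key: str
    pre_transform_key: str
    html_deps: list[dict[str, str]] = field(default_factory=list)

    @classmethod
    def from_chat_message(cls, message: Any) -> TransformedMessage:
        # User messages are transformed for the server; assistant messages are
        # transformed for display on the client (same logic as the old helper).
        if message.role == "user":
            transform_key = "content_server"
            pre_transform_key = "content_client"
        else:
            transform_key = "content_client"
            pre_transform_key = "content_server"
        return cls(
            content_client=message.content,
            content_server=message.content,
            role=message.role,
            transform_key=transform_key,
            pre_transform_key=pre_transform_key,
            html_deps=list(getattr(message, "html_deps", []) or []),
        )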

shiny/ui/_chat_normalize.py

Lines changed: 3 additions & 12 deletions
@@ -2,9 +2,8 @@
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Any, Optional, cast

-from htmltools import HTML, TagChild, Tagifiable
+from htmltools import HTML, Tagifiable

-from ..session._utils import process_ui
 from ._chat_types import ChatMessage

 if TYPE_CHECKING:
@@ -63,21 +62,13 @@ def normalize(self, message: Any) -> ChatMessage:
         x = cast("dict[str, Any]", message)
         if "content" not in x:
             raise ValueError("Message must have 'content' key")
-        content, deps = process_ui(cast(TagChild, x["content"]))
-        res = ChatMessage(content=content, role=x.get("role", "assistant"))
-        if deps:
-            res["html_deps"] = deps
-        return res
+        return ChatMessage(content=x["content"], role=x.get("role", "assistant"))

     def normalize_chunk(self, chunk: Any) -> ChatMessage:
         x = cast("dict[str, Any]", chunk)
         if "content" not in x:
             raise ValueError("Message must have 'content' key")
-        content, deps = process_ui(cast(TagChild, x["content"]))
-        res = ChatMessage(content=content, role=x.get("role", "assistant"))
-        if deps:
-            res["html_deps"] = deps
-        return res
+        return ChatMessage(content=x["content"], role=x.get("role", "assistant"))

     def can_normalize(self, message: Any) -> bool:
         return isinstance(message, dict)
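With process_ui() removed, the dict strategy above no longer renders content to HTML (and collects HTML dependencies) at normalization time; it simply wraps the raw content in a ChatMessage. A minimal behavioral sketch follows; the strategy class name and this two-field ChatMessage stand-in are assumptions, since only the methods appear in the hunk above.

from typing import Any


class ChatMessage:  # stand-in for shiny.ui._chat_types.ChatMessage (not shown in this diff)
    def __init__(self, content: str, role: str):
        self.content = content
        self.role = role


class DictNormalizer:  # hypothetical name for the dict strategy shown above
    def normalize(self, message: Any) -> ChatMessage:
        x = dict(message)
        if "content" not in x:
            raise ValueError("Message must have 'content' key")
        # Content passes through untouched; it is no longer pre-rendered to HTML
        # with dependencies collected here.
        return ChatMessage(content=x["content"], role=x.get("role", "assistant"))


msg = DictNormalizer().normalize({"content": "**Hello**", "role": "user"})
print(msg.content)  # raw markdown string, not pre-rendered HTML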

shiny/ui/_chat_provider_types.py

Lines changed: 7 additions & 7 deletions
@@ -1,7 +1,7 @@
 import sys
 from typing import TYPE_CHECKING, Literal, Union

-from ._chat_types import ChatMessage
+from ._chat_types import ChatMessageDict

 if TYPE_CHECKING:
     from anthropic.types import MessageParam as AnthropicMessage
@@ -47,7 +47,7 @@
 # TODO: use a strategy pattern to allow others to register
 # their own message formats
 def as_provider_message(
-    message: ChatMessage, format: ProviderMessageFormat
+    message: ChatMessageDict, format: ProviderMessageFormat
 ) -> "ProviderMessage":
     if format == "anthropic":
         return as_anthropic_message(message)
@@ -62,7 +62,7 @@ def as_provider_message(
     raise ValueError(f"Unknown format: {format}")


-def as_anthropic_message(message: ChatMessage) -> "AnthropicMessage":
+def as_anthropic_message(message: ChatMessageDict) -> "AnthropicMessage":
     from anthropic.types import MessageParam as AnthropicMessage

     if message["role"] == "system":
@@ -72,7 +72,7 @@ def as_anthropic_message(message: ChatMessage) -> "AnthropicMessage":
     return AnthropicMessage(content=message["content"], role=message["role"])


-def as_google_message(message: ChatMessage) -> "GoogleMessage":
+def as_google_message(message: ChatMessageDict) -> "GoogleMessage":
     if sys.version_info < (3, 9):
         raise ValueError("Google requires Python 3.9")

@@ -89,7 +89,7 @@ def as_google_message(message: ChatMessage) -> "GoogleMessage":
     return gtypes.ContentDict(parts=[message["content"]], role=role)


-def as_langchain_message(message: ChatMessage) -> "LangChainMessage":
+def as_langchain_message(message: ChatMessageDict) -> "LangChainMessage":
     from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

     content = message["content"]
@@ -103,7 +103,7 @@ def as_langchain_message(message: ChatMessage) -> "LangChainMessage":
     raise ValueError(f"Unknown role: {message['role']}")


-def as_openai_message(message: ChatMessage) -> "OpenAIMessage":
+def as_openai_message(message: ChatMessageDict) -> "OpenAIMessage":
     from openai.types.chat import (
         ChatCompletionAssistantMessageParam,
         ChatCompletionSystemMessageParam,
@@ -121,7 +121,7 @@ def as_openai_message(message: ChatMessage) -> "OpenAIMessage":
     raise ValueError(f"Unknown role: {role}")


-def as_ollama_message(message: ChatMessage) -> "OllamaMessage":
+def as_ollama_message(message: ChatMessageDict) -> "OllamaMessage":
     from ollama import Message as OllamaMessage

     return OllamaMessage(content=message["content"], role=message["role"])
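Only the annotations change in this file: the converters now accept the user-facing ChatMessageDict instead of the old ChatMessage TypedDict, and their dict-style indexing stays the same. A hedged usage sketch, assuming ChatMessageDict is re-exported from shiny.ui (it is added to __all__ in _chat.py above), that "openai" is one of the ProviderMessageFormat literals, and that the openai package is installed for the lazy import inside as_openai_message():

from shiny.ui import ChatMessageDict
from shiny.ui._chat_provider_types import as_provider_message  # internal module

msg = ChatMessageDict(content="Hello!", role="user")
openai_msg = as_provider_message(msg, "openai")
print(openai_msg)  # e.g. a dict-like OpenAI message: {'role': 'user', 'content': 'Hello!'}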
