Skip to content

Commit 39a780f

Browse files
committed
Use a classmethod
1 parent 615ed00 commit 39a780f

File tree

3 files changed

+54
-47
lines changed

3 files changed

+54
-47
lines changed

shiny/ui/_chat.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,10 @@ async def _on_user_input():
251251
else:
252252
# A transformed value of None is a special signal to suspend input
253253
# handling (i.e., don't generate a response)
254-
self._store_message(msg.as_transformed_message(), index=n_pre)
254+
self._store_message(
255+
TransformedMessage.from_chat_message(msg),
256+
index=n_pre,
257+
)
255258
await self._remove_loading_message()
256259
self._suspend_input_handler = True
257260

@@ -935,7 +938,7 @@ async def _transform_message(
935938
chunk: ChunkOption = False,
936939
chunk_content: str | None = None,
937940
) -> TransformedMessage | None:
938-
res = message.as_transformed_message()
941+
res = TransformedMessage.from_chat_message(message)
939942
key = res.transform_key
940943

941944
if message.role == "user" and self._transform_user is not None:

shiny/ui/_chat_types.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,23 +38,6 @@ def __init__(
3838
self.content = content
3939
self.html_deps = deps
4040

41-
def as_transformed_message(self) -> "TransformedMessage":
42-
if self.role == "user":
43-
transform_key = "content_server"
44-
pre_transform_key = "content_client"
45-
else:
46-
transform_key = "content_client"
47-
pre_transform_key = "content_server"
48-
49-
return TransformedMessage(
50-
content_client=self.content,
51-
content_server=self.content,
52-
role=self.role,
53-
transform_key=transform_key,
54-
pre_transform_key=pre_transform_key,
55-
html_deps=self.html_deps,
56-
)
57-
5841

5942
# A message once transformed have been applied
6043
@dataclass
@@ -66,6 +49,24 @@ class TransformedMessage:
6649
pre_transform_key: Literal["content_client", "content_server"]
6750
html_deps: list[dict[str, str]] | None = None
6851

52+
@classmethod
53+
def from_chat_message(cls, message: ChatUIMessage) -> "TransformedMessage":
54+
if message.role == "user":
55+
transform_key = "content_server"
56+
pre_transform_key = "content_client"
57+
else:
58+
transform_key = "content_client"
59+
pre_transform_key = "content_server"
60+
61+
return TransformedMessage(
62+
content_client=message.content,
63+
content_server=message.content,
64+
role=message.role,
65+
transform_key=transform_key,
66+
pre_transform_key=pre_transform_key,
67+
html_deps=message.html_deps,
68+
)
69+
6970

7071
# A message that can be sent to the client
7172
class ClientMessage(TypedDict):

tests/pytest/test_chat.py

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from shiny.types import MISSING
1414
from shiny.ui import Chat
1515
from shiny.ui._chat_normalize import normalize_message, normalize_message_chunk
16-
from shiny.ui._chat_types import ChatMessage, ChatUIMessage
16+
from shiny.ui._chat_types import ChatMessage, ChatUIMessage, Role, TransformedMessage
1717

1818
# ----------------------------------------------------------------------
1919
# Helpers
@@ -42,6 +42,12 @@ def is_type_in_union(type: object, union: object) -> bool:
4242
return False
4343

4444

45+
def transformed_message(content: str, role: Role) -> TransformedMessage:
46+
return TransformedMessage.from_chat_message(
47+
ChatUIMessage(content=content, role=role)
48+
)
49+
50+
4551
def test_chat_message_trimming():
4652
with session_context(test_session):
4753
chat = Chat(id="chat")
@@ -52,22 +58,19 @@ def generate_content(token_count: int) -> str:
5258
return " ".join(["foo" for _ in range(1, n)])
5359

5460
msgs = (
55-
ChatUIMessage(
56-
content=generate_content(102), role="system"
57-
).as_transformed_message(),
61+
transformed_message(
62+
content=generate_content(102),
63+
role="system",
64+
),
5865
)
5966

6067
# Throws since system message is too long
6168
with pytest.raises(ValueError):
6269
chat._trim_messages(msgs, token_limits=(100, 0), format=MISSING)
6370

6471
msgs = (
65-
ChatUIMessage(
66-
content=generate_content(100), role="system"
67-
).as_transformed_message(),
68-
ChatUIMessage(
69-
content=generate_content(2), role="user"
70-
).as_transformed_message(),
72+
transformed_message(content=generate_content(100), role="system"),
73+
transformed_message(content=generate_content(2), role="user"),
7174
)
7275

7376
# Throws since only the system message fits
@@ -83,18 +86,18 @@ def generate_content(token_count: int) -> str:
8386
content3 = generate_content(2)
8487

8588
msgs = (
86-
ChatUIMessage(
89+
transformed_message(
8790
content=content1,
8891
role="system",
89-
).as_transformed_message(),
90-
ChatUIMessage(
92+
),
93+
transformed_message(
9194
content=content2,
9295
role="user",
93-
).as_transformed_message(),
94-
ChatUIMessage(
96+
),
97+
transformed_message(
9598
content=content3,
9699
role="user",
97-
).as_transformed_message(),
100+
),
98101
)
99102

100103
# Should discard the 1st user message
@@ -109,22 +112,22 @@ def generate_content(token_count: int) -> str:
109112
content4 = generate_content(2)
110113

111114
msgs = (
112-
ChatUIMessage(
115+
transformed_message(
113116
content=content1,
114117
role="system",
115-
).as_transformed_message(),
116-
ChatUIMessage(
118+
),
119+
transformed_message(
117120
content=content2,
118121
role="user",
119-
).as_transformed_message(),
120-
ChatUIMessage(
122+
),
123+
transformed_message(
121124
content=content3,
122125
role="system",
123-
).as_transformed_message(),
124-
ChatUIMessage(
126+
),
127+
transformed_message(
125128
content=content4,
126129
role="user",
127-
).as_transformed_message(),
130+
),
128131
)
129132

130133
# Should discard the 1st user message
@@ -137,14 +140,14 @@ def generate_content(token_count: int) -> str:
137140
content2 = generate_content(10)
138141

139142
msgs = (
140-
ChatUIMessage(
143+
transformed_message(
141144
content=content1,
142145
role="assistant",
143-
).as_transformed_message(),
144-
ChatUIMessage(
146+
),
147+
transformed_message(
145148
content=content2,
146149
role="user",
147-
).as_transformed_message(),
150+
),
148151
)
149152

150153
# Anthropic requires 1st message to be a user message

0 commit comments

Comments
 (0)