Skip to content

Commit 5cade6e

Browse files
committed
Update tests
1 parent dadd0c4 commit 5cade6e

File tree

1 file changed

+50
-58
lines changed

1 file changed

+50
-58
lines changed

tests/pytest/test_chat.py

Lines changed: 50 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,8 @@
1111
from shiny.session import session_context
1212
from shiny.types import MISSING
1313
from shiny.ui import Chat
14-
from shiny.ui._chat import as_transformed_message
1514
from shiny.ui._chat_normalize import normalize_message, normalize_message_chunk
16-
from shiny.ui._chat_types import ChatMessage
15+
from shiny.ui._chat_types import ChatMessage, ChatUIMessage
1716

1817
# ----------------------------------------------------------------------
1918
# Helpers
@@ -52,31 +51,22 @@ def generate_content(token_count: int) -> str:
5251
return " ".join(["foo" for _ in range(1, n)])
5352

5453
msgs = (
55-
as_transformed_message(
56-
{
57-
"content": generate_content(102),
58-
"role": "system",
59-
}
60-
),
54+
ChatUIMessage(
55+
content=generate_content(102), role="system"
56+
).as_transformed_message(),
6157
)
6258

6359
# Throws since system message is too long
6460
with pytest.raises(ValueError):
6561
chat._trim_messages(msgs, token_limits=(100, 0), format=MISSING)
6662

6763
msgs = (
68-
as_transformed_message(
69-
{
70-
"content": generate_content(100),
71-
"role": "system",
72-
}
73-
),
74-
as_transformed_message(
75-
{
76-
"content": generate_content(2),
77-
"role": "user",
78-
}
79-
),
64+
ChatUIMessage(
65+
content=generate_content(100), role="system"
66+
).as_transformed_message(),
67+
ChatUIMessage(
68+
content=generate_content(2), role="user"
69+
).as_transformed_message(),
8070
)
8171

8272
# Throws since only the system message fits
@@ -92,30 +82,24 @@ def generate_content(token_count: int) -> str:
9282
content3 = generate_content(2)
9383

9484
msgs = (
95-
as_transformed_message(
96-
{
97-
"content": content1,
98-
"role": "system",
99-
}
100-
),
101-
as_transformed_message(
102-
{
103-
"content": content2,
104-
"role": "user",
105-
}
106-
),
107-
as_transformed_message(
108-
{
109-
"content": content3,
110-
"role": "user",
111-
}
112-
),
85+
ChatUIMessage(
86+
content=content1,
87+
role="system",
88+
).as_transformed_message(),
89+
ChatUIMessage(
90+
content=content2,
91+
role="user",
92+
).as_transformed_message(),
93+
ChatUIMessage(
94+
content=content3,
95+
role="user",
96+
).as_transformed_message(),
11397
)
11498

11599
# Should discard the 1st user message
116100
trimmed = chat._trim_messages(msgs, token_limits=(103, 0), format=MISSING)
117101
assert len(trimmed) == 2
118-
contents = [msg["content_server"] for msg in trimmed]
102+
contents = [msg.content_server for msg in trimmed]
119103
assert contents == [content1, content3]
120104

121105
content1 = generate_content(50)
@@ -124,38 +108,48 @@ def generate_content(token_count: int) -> str:
124108
content4 = generate_content(2)
125109

126110
msgs = (
127-
as_transformed_message(
128-
{"content": content1, "role": "system"},
129-
),
130-
as_transformed_message(
131-
{"content": content2, "role": "user"},
132-
),
133-
as_transformed_message(
134-
{"content": content3, "role": "system"},
135-
),
136-
as_transformed_message(
137-
{"content": content4, "role": "user"},
138-
),
111+
ChatUIMessage(
112+
content=content1,
113+
role="system",
114+
).as_transformed_message(),
115+
ChatUIMessage(
116+
content=content2,
117+
role="user",
118+
).as_transformed_message(),
119+
ChatUIMessage(
120+
content=content3,
121+
role="system",
122+
).as_transformed_message(),
123+
ChatUIMessage(
124+
content=content4,
125+
role="user",
126+
).as_transformed_message(),
139127
)
140128

141129
# Should discard the 1st user message
142130
trimmed = chat._trim_messages(msgs, token_limits=(103, 0), format=MISSING)
143131
assert len(trimmed) == 3
144-
contents = [msg["content_server"] for msg in trimmed]
132+
contents = [msg.content_server for msg in trimmed]
145133
assert contents == [content1, content3, content4]
146134

147135
content1 = generate_content(50)
148136
content2 = generate_content(10)
149137

150138
msgs = (
151-
as_transformed_message({"content": content1, "role": "assistant"}),
152-
as_transformed_message({"content": content2, "role": "user"}),
139+
ChatUIMessage(
140+
content=content1,
141+
role="assistant",
142+
).as_transformed_message(),
143+
ChatUIMessage(
144+
content=content2,
145+
role="user",
146+
).as_transformed_message(),
153147
)
154148

155149
# Anthropic requires 1st message to be a user message
156150
trimmed = chat._trim_messages(msgs, token_limits=(30, 0), format="anthropic")
157151
assert len(trimmed) == 1
158-
contents = [msg["content_server"] for msg in trimmed]
152+
contents = [msg.content_server for msg in trimmed]
159153
assert contents == [content2]
160154

161155

@@ -403,9 +397,7 @@ def test_as_google_message():
403397

404398

405399
def test_as_langchain_message():
406-
from langchain_core.language_models.base import (
407-
LanguageModelInput,
408-
)
400+
from langchain_core.language_models.base import LanguageModelInput
409401
from langchain_core.language_models.base import (
410402
Sequence as LangchainSequence, # pyright: ignore[reportPrivateImportUsage]
411403
)

0 commit comments

Comments
 (0)