Skip to content

Commit 615ed00

Browse files
committed
Update tests
1 parent 1efe094 commit 615ed00

File tree

1 file changed

+80
-69
lines changed

1 file changed

+80
-69
lines changed

tests/pytest/test_chat.py

Lines changed: 80 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,8 @@
1212
from shiny.session import session_context
1313
from shiny.types import MISSING
1414
from shiny.ui import Chat
15-
from shiny.ui._chat import as_transformed_message
1615
from shiny.ui._chat_normalize import normalize_message, normalize_message_chunk
17-
from shiny.ui._chat_types import ChatMessage
16+
from shiny.ui._chat_types import ChatMessage, ChatUIMessage
1817

1918
# ----------------------------------------------------------------------
2019
# Helpers
@@ -53,31 +52,22 @@ def generate_content(token_count: int) -> str:
5352
return " ".join(["foo" for _ in range(1, n)])
5453

5554
msgs = (
56-
as_transformed_message(
57-
{
58-
"content": generate_content(102),
59-
"role": "system",
60-
}
61-
),
55+
ChatUIMessage(
56+
content=generate_content(102), role="system"
57+
).as_transformed_message(),
6258
)
6359

6460
# Throws since system message is too long
6561
with pytest.raises(ValueError):
6662
chat._trim_messages(msgs, token_limits=(100, 0), format=MISSING)
6763

6864
msgs = (
69-
as_transformed_message(
70-
{
71-
"content": generate_content(100),
72-
"role": "system",
73-
}
74-
),
75-
as_transformed_message(
76-
{
77-
"content": generate_content(2),
78-
"role": "user",
79-
}
80-
),
65+
ChatUIMessage(
66+
content=generate_content(100), role="system"
67+
).as_transformed_message(),
68+
ChatUIMessage(
69+
content=generate_content(2), role="user"
70+
).as_transformed_message(),
8171
)
8272

8373
# Throws since only the system message fits
@@ -93,30 +83,24 @@ def generate_content(token_count: int) -> str:
9383
content3 = generate_content(2)
9484

9585
msgs = (
96-
as_transformed_message(
97-
{
98-
"content": content1,
99-
"role": "system",
100-
}
101-
),
102-
as_transformed_message(
103-
{
104-
"content": content2,
105-
"role": "user",
106-
}
107-
),
108-
as_transformed_message(
109-
{
110-
"content": content3,
111-
"role": "user",
112-
}
113-
),
86+
ChatUIMessage(
87+
content=content1,
88+
role="system",
89+
).as_transformed_message(),
90+
ChatUIMessage(
91+
content=content2,
92+
role="user",
93+
).as_transformed_message(),
94+
ChatUIMessage(
95+
content=content3,
96+
role="user",
97+
).as_transformed_message(),
11498
)
11599

116100
# Should discard the 1st user message
117101
trimmed = chat._trim_messages(msgs, token_limits=(103, 0), format=MISSING)
118102
assert len(trimmed) == 2
119-
contents = [msg["content_server"] for msg in trimmed]
103+
contents = [msg.content_server for msg in trimmed]
120104
assert contents == [content1, content3]
121105

122106
content1 = generate_content(50)
@@ -125,38 +109,48 @@ def generate_content(token_count: int) -> str:
125109
content4 = generate_content(2)
126110

127111
msgs = (
128-
as_transformed_message(
129-
{"content": content1, "role": "system"},
130-
),
131-
as_transformed_message(
132-
{"content": content2, "role": "user"},
133-
),
134-
as_transformed_message(
135-
{"content": content3, "role": "system"},
136-
),
137-
as_transformed_message(
138-
{"content": content4, "role": "user"},
139-
),
112+
ChatUIMessage(
113+
content=content1,
114+
role="system",
115+
).as_transformed_message(),
116+
ChatUIMessage(
117+
content=content2,
118+
role="user",
119+
).as_transformed_message(),
120+
ChatUIMessage(
121+
content=content3,
122+
role="system",
123+
).as_transformed_message(),
124+
ChatUIMessage(
125+
content=content4,
126+
role="user",
127+
).as_transformed_message(),
140128
)
141129

142130
# Should discard the 1st user message
143131
trimmed = chat._trim_messages(msgs, token_limits=(103, 0), format=MISSING)
144132
assert len(trimmed) == 3
145-
contents = [msg["content_server"] for msg in trimmed]
133+
contents = [msg.content_server for msg in trimmed]
146134
assert contents == [content1, content3, content4]
147135

148136
content1 = generate_content(50)
149137
content2 = generate_content(10)
150138

151139
msgs = (
152-
as_transformed_message({"content": content1, "role": "assistant"}),
153-
as_transformed_message({"content": content2, "role": "user"}),
140+
ChatUIMessage(
141+
content=content1,
142+
role="assistant",
143+
).as_transformed_message(),
144+
ChatUIMessage(
145+
content=content2,
146+
role="user",
147+
).as_transformed_message(),
154148
)
155149

156150
# Anthropic requires 1st message to be a user message
157151
trimmed = chat._trim_messages(msgs, token_limits=(30, 0), format="anthropic")
158152
assert len(trimmed) == 1
159-
contents = [msg["content_server"] for msg in trimmed]
153+
contents = [msg.content_server for msg in trimmed]
160154
assert contents == [content2]
161155

162156

@@ -173,13 +167,15 @@ def generate_content(token_count: int) -> str:
173167

174168

175169
def test_string_normalization():
176-
msg = normalize_message_chunk("Hello world!")
177-
assert msg == {"content": "Hello world!", "role": "assistant"}
170+
m = normalize_message_chunk("Hello world!")
171+
assert m.content == "Hello world!"
172+
assert m.role == "assistant"
178173

179174

180175
def test_dict_normalization():
181-
msg = normalize_message_chunk({"content": "Hello world!", "role": "assistant"})
182-
assert msg == {"content": "Hello world!", "role": "assistant"}
176+
m = normalize_message_chunk({"content": "Hello world!", "role": "assistant"})
177+
assert m.content == "Hello world!"
178+
assert m.role == "assistant"
183179

184180

185181
def test_langchain_normalization():
@@ -195,11 +191,15 @@ def test_langchain_normalization():
195191

196192
# Mock & normalize return value of BaseChatModel.invoke()
197193
msg = BaseMessage(content="Hello world!", role="assistant", type="foo")
198-
assert normalize_message(msg) == {"content": "Hello world!", "role": "assistant"}
194+
m = normalize_message(msg)
195+
assert m.content == "Hello world!"
196+
assert m.role == "assistant"
199197

200198
# Mock & normalize return value of BaseChatModel.stream()
201199
chunk = BaseMessageChunk(content="Hello ", type="foo")
202-
assert normalize_message_chunk(chunk) == {"content": "Hello ", "role": "assistant"}
200+
m = normalize_message_chunk(chunk)
201+
assert m.content == "Hello "
202+
assert m.role == "assistant"
203203

204204

205205
def test_google_normalization():
@@ -256,7 +256,9 @@ def test_anthropic_normalization():
256256
usage=Usage(input_tokens=0, output_tokens=0),
257257
)
258258

259-
assert normalize_message(msg) == {"content": "Hello world!", "role": "assistant"}
259+
m = normalize_message(msg)
260+
assert m.content == "Hello world!"
261+
assert m.role == "assistant"
260262

261263
# Mock return object from Anthropic().messages.create(stream=True)
262264
chunk = RawContentBlockDeltaEvent(
@@ -265,7 +267,9 @@ def test_anthropic_normalization():
265267
index=0,
266268
)
267269

268-
assert normalize_message_chunk(chunk) == {"content": "Hello ", "role": "assistant"}
270+
m = normalize_message_chunk(chunk)
271+
assert m.content == "Hello "
272+
assert m.role == "assistant"
269273

270274

271275
def test_openai_normalization():
@@ -310,8 +314,9 @@ def test_openai_normalization():
310314
created=int(datetime.now().timestamp()),
311315
)
312316

313-
msg = normalize_message(completion)
314-
assert msg == {"content": "Hello world!", "role": "assistant"}
317+
m = normalize_message(completion)
318+
assert m.content == "Hello world!"
319+
assert m.role == "assistant"
315320

316321
# Mock return object from OpenAI().chat.completions.create(stream=True)
317322
chunk = ChatCompletionChunk(
@@ -330,8 +335,9 @@ def test_openai_normalization():
330335
],
331336
)
332337

333-
msg = normalize_message_chunk(chunk)
334-
assert msg == {"content": "Hello ", "role": "assistant"}
338+
m = normalize_message_chunk(chunk)
339+
assert m.content == "Hello "
340+
assert m.role == "assistant"
335341

336342

337343
def test_ollama_normalization():
@@ -344,8 +350,13 @@ def test_ollama_normalization():
344350
)
345351

346352
msg_dict = {"content": "Hello world!", "role": "assistant"}
347-
assert normalize_message(msg) == msg_dict
348-
assert normalize_message_chunk(msg) == msg_dict
353+
m = normalize_message(msg)
354+
assert m.content == msg_dict["content"]
355+
assert m.role == msg_dict["role"]
356+
357+
m = normalize_message_chunk(msg)
358+
assert m.content == msg_dict["content"]
359+
assert m.role == msg_dict["role"]
349360

350361

351362
# ------------------------------------------------------------------------------------

0 commit comments

Comments
 (0)