Skip to content

Commit ec406bc

Browse files
committed
Add unit tests for Formatter class input and output handling
1 parent 3a4a3d2 commit ec406bc

File tree

1 file changed

+252
-0
lines changed

1 file changed

+252
-0
lines changed

tests/lib/test_formatter.py

Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
import json
2+
import uuid
3+
from typing import Any
4+
from unittest.mock import MagicMock
5+
6+
from langchain_core.documents import Document
7+
from langchain_core.load import dumpd
8+
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
9+
10+
from langgraphics.formatter import Formatter
11+
12+
13+
def make_run(
14+
run_type: str,
15+
inputs: dict[str, Any] | None = None,
16+
outputs: dict[str, Any] | None = None,
17+
error: str | None = None,
18+
) -> MagicMock:
19+
run = MagicMock()
20+
run.id = uuid.uuid4()
21+
run.run_type = run_type
22+
run.inputs = inputs or {}
23+
run.outputs = outputs or {}
24+
run.error = error
25+
return run
26+
27+
28+
def parsed_inputs(run) -> list[dict]:
29+
return json.loads(Formatter.inputs(run))
30+
31+
32+
def parsed_outputs(run) -> list[dict]:
33+
return json.loads(Formatter.outputs(run))
34+
35+
36+
class TestChatModelInputs:
37+
def test_human_message(self):
38+
run = make_run(
39+
"chat_model", inputs={"messages": [[dumpd(HumanMessage(content="Hello"))]]}
40+
)
41+
result = parsed_inputs(run)
42+
assert result == [{"role": "human", "content": "Hello"}]
43+
44+
def test_system_then_human(self):
45+
run = make_run(
46+
"chat_model",
47+
inputs={
48+
"messages": [
49+
[
50+
dumpd(SystemMessage(content="You are helpful")),
51+
dumpd(HumanMessage(content="What is 2+2?")),
52+
]
53+
]
54+
},
55+
)
56+
result = parsed_inputs(run)
57+
assert result == [
58+
{"role": "system", "content": "You are helpful"},
59+
{"role": "human", "content": "What is 2+2?"},
60+
]
61+
62+
def test_tool_call_message(self):
63+
msg = AIMessage(
64+
content="",
65+
tool_calls=[
66+
{
67+
"name": "search",
68+
"args": {"q": "foo"},
69+
"id": "c1",
70+
"type": "tool_call",
71+
}
72+
],
73+
)
74+
run = make_run("chat_model", inputs={"messages": [[dumpd(msg)]]})
75+
result = parsed_inputs(run)
76+
assert len(result) == 1
77+
assert result[0]["role"] == "ai"
78+
assert "search" in result[0]["content"]
79+
assert "foo" in result[0]["content"]
80+
81+
def test_empty_messages(self):
82+
run = make_run("chat_model", inputs={"messages": [[]]})
83+
result = parsed_inputs(run)
84+
assert result == []
85+
86+
def test_missing_messages_key(self):
87+
run = make_run("chat_model", inputs={})
88+
result = parsed_inputs(run)
89+
assert isinstance(result, list)
90+
91+
92+
class TestChatModelOutputs:
93+
def _make_output(self, msg) -> dict:
94+
return {"generations": [[{"message": dumpd(msg)}]]}
95+
96+
def test_ai_message(self):
97+
run = make_run(
98+
"chat_model",
99+
outputs=self._make_output(AIMessage(content="The answer is 42.")),
100+
)
101+
result = parsed_outputs(run)
102+
assert result == [{"role": "ai", "content": "The answer is 42."}]
103+
104+
def test_tool_call_output(self):
105+
msg = AIMessage(
106+
content="",
107+
tool_calls=[
108+
{
109+
"name": "calculator",
110+
"args": {"a": 1, "b": 2},
111+
"id": "c2",
112+
"type": "tool_call",
113+
}
114+
],
115+
)
116+
run = make_run("chat_model", outputs=self._make_output(msg))
117+
result = parsed_outputs(run)
118+
assert len(result) == 1
119+
assert result[0]["role"] == "ai"
120+
assert "calculator" in result[0]["content"]
121+
122+
def test_error_overrides_output(self):
123+
run = make_run(
124+
"chat_model",
125+
outputs=self._make_output(AIMessage(content="irrelevant")),
126+
error="Something went wrong",
127+
)
128+
result = parsed_outputs(run)
129+
assert result == [{"role": "error", "content": "Something went wrong"}]
130+
131+
132+
class TestLlmInputsOutputs:
133+
def test_single_prompt(self):
134+
run = make_run("llm", inputs={"prompts": ["Summarise this text."]})
135+
result = parsed_inputs(run)
136+
assert result == [{"role": "prompt", "content": "Summarise this text."}]
137+
138+
def test_llm_text_output(self):
139+
run = make_run("llm", outputs={"generations": [{"text": "Summary here."}]})
140+
result = parsed_outputs(run)
141+
assert result == [{"role": "text", "content": "Summary here."}]
142+
143+
def test_llm_error(self):
144+
run = make_run("llm", error="Token limit exceeded")
145+
result = parsed_outputs(run)
146+
assert result == [{"role": "error", "content": "Token limit exceeded"}]
147+
148+
149+
class TestToolInputsOutputs:
150+
def test_string_input(self):
151+
run = make_run("tool", inputs={"input": "search query"})
152+
result = parsed_inputs(run)
153+
assert result == [{"role": "input", "content": "search query"}]
154+
155+
def test_string_output(self):
156+
run = make_run("tool", outputs={"output": "search result"})
157+
result = parsed_outputs(run)
158+
assert result == [{"role": "output", "content": "search result"}]
159+
160+
def test_tool_message_output(self):
161+
tool_msg = ToolMessage(content="42", tool_call_id="c1")
162+
run = make_run("tool", outputs={"output": tool_msg})
163+
result = parsed_outputs(run)
164+
assert result == [{"role": "output", "content": "42"}]
165+
166+
def test_tool_error(self):
167+
run = make_run("tool", error="Tool failed")
168+
result = parsed_outputs(run)
169+
assert result == [{"role": "error", "content": "Tool failed"}]
170+
171+
172+
class TestRetrieverInputsOutputs:
173+
def test_query_input(self):
174+
run = make_run("retriever", inputs={"query": "meaning of life"})
175+
result = parsed_inputs(run)
176+
assert result == [{"role": "query", "content": "meaning of life"}]
177+
178+
def test_documents_output(self):
179+
docs = [
180+
Document(page_content="The answer is 42."),
181+
Document(page_content="It depends."),
182+
]
183+
run = make_run("retriever", outputs={"documents": docs})
184+
result = parsed_outputs(run)
185+
assert result == [
186+
{"role": "document", "content": "The answer is 42."},
187+
{"role": "document", "content": "It depends."},
188+
]
189+
190+
def test_empty_documents(self):
191+
run = make_run("retriever", outputs={"documents": []})
192+
result = parsed_outputs(run)
193+
assert result == []
194+
195+
196+
class TestChainInputsOutputs:
197+
def test_chain_input_messages(self):
198+
run = make_run(
199+
"chain", inputs={"messages": [HumanMessage(content="What is AI?")]}
200+
)
201+
result = parsed_inputs(run)
202+
assert result == [{"role": "human", "content": "What is AI?"}]
203+
204+
def test_chain_output_messages(self):
205+
run = make_run(
206+
"chain",
207+
outputs={
208+
"messages": [
209+
AIMessage(content="AI stands for Artificial Intelligence.")
210+
]
211+
},
212+
)
213+
result = parsed_outputs(run)
214+
assert result == [
215+
{"role": "ai", "content": "AI stands for Artificial Intelligence."}
216+
]
217+
218+
def test_chain_no_messages(self):
219+
run = make_run("chain", inputs={})
220+
result = parsed_inputs(run)
221+
assert result == []
222+
223+
def test_chain_error(self):
224+
run = make_run("chain", error="Graph error")
225+
result = parsed_outputs(run)
226+
assert result == [{"role": "error", "content": "Graph error"}]
227+
228+
229+
class TestNorm:
230+
def test_plain_message(self):
231+
msg = {"type": "human", "data": {"content": "Hello"}}
232+
assert Formatter.norm(msg) == {"role": "human", "content": "Hello"}
233+
234+
def test_tool_calls(self):
235+
msg = {
236+
"type": "ai",
237+
"data": {"tool_calls": [{"name": "fn", "args": {"x": 1}}]},
238+
}
239+
result = Formatter.norm(msg)
240+
assert result["role"] == "ai"
241+
assert "fn" in result["content"]
242+
assert "x" in result["content"]
243+
244+
def test_list_content(self):
245+
msg = {"type": "ai", "data": {"content": [{"text": "Block content"}]}}
246+
assert Formatter.norm(msg) == {"role": "ai", "content": "Block content"}
247+
248+
def test_missing_data(self):
249+
msg = {"type": "system"}
250+
result = Formatter.norm(msg)
251+
assert result["role"] == "system"
252+
assert result["content"] == ""

0 commit comments

Comments
 (0)