Skip to content

Commit 1d170d7

Browse files
Add tutorial for evaluating LangGraph agents (#1636)
- Fixes #1635 This PR adds a detailed tutorial to guide users through building a ReAct agent using LangGraph. The tutorial also walks users through setting up an evaluation pipeline using Ragas to assess the agent's performance. --------- Co-authored-by: Jithin James <[email protected]>
1 parent 96c2952 commit 1d170d7

File tree

9 files changed

+1440
-9
lines changed

9 files changed

+1440
-9
lines changed
8.01 KB
Loading

docs/howtos/integrations/_langgraph_agent_evaluation.md

Lines changed: 424 additions & 0 deletions
Large diffs are not rendered by default.

docs/howtos/integrations/langgraph_agent_evaluation.ipynb

Lines changed: 783 additions & 0 deletions
Large diffs are not rendered by default.

docs/references/integrations.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,7 @@
1616
::: ragas.integrations.helicone
1717
options:
1818
show_root_heading: true
19+
20+
::: ragas.integrations.langgraph
21+
options:
22+
show_root_heading: true

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ nav:
9090
- Integrations:
9191
- howtos/integrations/index.md
9292
- LlamaIndex: howtos/integrations/_llamaindex.md
93+
- LangGraph: howtos/integrations/_langgraph_agent_evaluation.md
9394
- Migrations:
9495
- From v0.1 to v0.2: howtos/migrations/migrate_from_v01_to_v02.md
9596
- 📖 References:

src/ragas/evaluation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@
77
from langchain_core.callbacks import BaseCallbackHandler, BaseCallbackManager
88
from langchain_core.embeddings import Embeddings as LangchainEmbeddings
99
from langchain_core.language_models import BaseLanguageModel as LangchainLLM
10-
11-
from llama_index.core.base.llms.base import BaseLLM as LlamaIndexLLM
1210
from llama_index.core.base.embeddings.base import BaseEmbedding as LlamaIndexEmbedding
11+
from llama_index.core.base.llms.base import BaseLLM as LlamaIndexLLM
1312

1413
from ragas._analytics import EvaluationEvent, track, track_was_completed
1514
from ragas.callbacks import ChainType, RagasTracer, new_group
@@ -61,7 +60,9 @@ def evaluate(
6160
dataset: t.Union[Dataset, EvaluationDataset],
6261
metrics: t.Optional[t.Sequence[Metric]] = None,
6362
llm: t.Optional[BaseRagasLLM | LangchainLLM | LlamaIndexLLM] = None,
64-
embeddings: t.Optional[BaseRagasEmbeddings | LangchainEmbeddings | LlamaIndexEmbedding] = None,
63+
embeddings: t.Optional[
64+
BaseRagasEmbeddings | LangchainEmbeddings | LlamaIndexEmbedding
65+
] = None,
6566
callbacks: Callbacks = None,
6667
in_ci: bool = False,
6768
run_config: RunConfig = RunConfig(),
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import json
2+
from typing import List, Union
3+
4+
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
5+
6+
import ragas.messages as r
7+
8+
9+
def convert_to_ragas_messages(
10+
messages: List[Union[HumanMessage, SystemMessage, AIMessage, ToolMessage]]
11+
) -> List[Union[r.HumanMessage, r.AIMessage, r.ToolMessage]]:
12+
"""
13+
Convert LangChain messages into Ragas messages for agent evaluation.
14+
15+
Parameters
16+
----------
17+
messages : List[Union[HumanMessage, SystemMessage, AIMessage, ToolMessage]]
18+
List of LangChain message objects to be converted.
19+
20+
Returns
21+
-------
22+
List[Union[r.HumanMessage, r.AIMessage, r.ToolMessage]]
23+
List of corresponding Ragas message objects.
24+
25+
Raises
26+
------
27+
ValueError
28+
If an unsupported message type is encountered.
29+
TypeError
30+
If message content is not a string.
31+
32+
Notes
33+
-----
34+
SystemMessages are skipped in the conversion process.
35+
"""
36+
37+
def _validate_string_content(message, message_type: str) -> str:
38+
if not isinstance(message.content, str):
39+
raise TypeError(
40+
f"{message_type} content must be a string, got {type(message.content).__name__}. "
41+
f"Content: {message.content}"
42+
)
43+
return message.content
44+
45+
MESSAGE_TYPE_MAP = {
46+
HumanMessage: lambda m: r.HumanMessage(
47+
content=_validate_string_content(m, "HumanMessage")
48+
),
49+
ToolMessage: lambda m: r.ToolMessage(
50+
content=_validate_string_content(m, "ToolMessage")
51+
),
52+
}
53+
54+
def _extract_tool_calls(message: AIMessage) -> List[r.ToolCall]:
55+
tool_calls = message.additional_kwargs.get("tool_calls", [])
56+
return [
57+
r.ToolCall(
58+
name=tool_call["function"]["name"],
59+
args=json.loads(tool_call["function"]["arguments"]),
60+
)
61+
for tool_call in tool_calls
62+
]
63+
64+
def _convert_ai_message(message: AIMessage) -> r.AIMessage:
65+
tool_calls = _extract_tool_calls(message) if message.additional_kwargs else None
66+
return r.AIMessage(
67+
content=_validate_string_content(message, "AIMessage"),
68+
tool_calls=tool_calls,
69+
)
70+
71+
def _convert_message(message):
72+
if isinstance(message, SystemMessage):
73+
return None # Skip SystemMessages
74+
if isinstance(message, AIMessage):
75+
return _convert_ai_message(message)
76+
converter = MESSAGE_TYPE_MAP.get(type(message))
77+
if converter is None:
78+
raise ValueError(f"Unsupported message type: {type(message).__name__}")
79+
return converter(message)
80+
81+
return [
82+
converted
83+
for message in messages
84+
if (converted := _convert_message(message)) is not None
85+
]

src/ragas/metrics/_topic_adherence.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,7 @@ class TopicClassificationOutput(BaseModel):
4848
class TopicClassificationPrompt(
4949
PydanticPrompt[TopicClassificationInput, TopicClassificationOutput]
5050
):
51-
instruction = (
52-
"Given a set of topics classify if the topic falls into any of the given reference topics."
53-
)
51+
instruction = "Given a set of topics classify if the topic falls into any of the given reference topics."
5452
input_model = TopicClassificationInput
5553
output_model = TopicClassificationOutput
5654
examples = [
@@ -149,10 +147,14 @@ class TopicAdherenceScore(MetricWithLLM, MultiTurnMetric):
149147
topic_classification_prompt: PydanticPrompt = TopicClassificationPrompt()
150148
topic_refused_prompt: PydanticPrompt = TopicRefusedPrompt()
151149

152-
async def _multi_turn_ascore(self, sample: MultiTurnSample, callbacks: Callbacks) -> float:
150+
async def _multi_turn_ascore(
151+
self, sample: MultiTurnSample, callbacks: Callbacks
152+
) -> float:
153153
assert self.llm is not None, "LLM must be set"
154154
assert isinstance(sample.user_input, list), "Sample user_input must be a list"
155-
assert isinstance(sample.reference_topics, list), "Sample reference_topics must be a list"
155+
assert isinstance(
156+
sample.reference_topics, list
157+
), "Sample reference_topics must be a list"
156158
user_input = sample.pretty_repr()
157159

158160
prompt_input = TopicExtractionInput(user_input=user_input)
@@ -168,7 +170,9 @@ async def _multi_turn_ascore(self, sample: MultiTurnSample, callbacks: Callbacks
168170
data=prompt_input, llm=self.llm, callbacks=callbacks
169171
)
170172
topic_answered_verdict.append(response.refused_to_answer)
171-
topic_answered_verdict = np.array([not answer for answer in topic_answered_verdict])
173+
topic_answered_verdict = np.array(
174+
[not answer for answer in topic_answered_verdict]
175+
)
172176

173177
prompt_input = TopicClassificationInput(
174178
reference_topics=sample.reference_topics, topics=topics

tests/unit/test_langgraph.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import json
2+
3+
import pytest
4+
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
5+
6+
import ragas.messages as r
7+
from ragas.integrations.langgraph import convert_to_ragas_messages
8+
9+
10+
def test_human_message_conversion():
11+
"""Test conversion of HumanMessage with valid string content"""
12+
messages = [
13+
HumanMessage(content="Hello, add 4 and 5"),
14+
ToolMessage(content="9", tool_call_id="1"),
15+
]
16+
result = convert_to_ragas_messages(messages)
17+
18+
assert len(result) == 2
19+
assert isinstance(result[0], r.HumanMessage)
20+
assert result[0].content == "Hello, add 4 and 5"
21+
22+
23+
def test_human_message_invalid_content():
24+
"""Test HumanMessage with invalid content type raises TypeError"""
25+
messages = [HumanMessage(content=["invalid", "content"])]
26+
27+
with pytest.raises(TypeError) as exc_info:
28+
convert_to_ragas_messages(messages)
29+
assert "HumanMessage content must be a string" in str(exc_info.value)
30+
31+
32+
def test_ai_message_conversion():
33+
"""Test conversion of AIMessage with valid string content"""
34+
messages = [AIMessage(content="I'm doing well, thanks!")]
35+
result = convert_to_ragas_messages(messages)
36+
37+
assert len(result) == 1
38+
assert isinstance(result[0], r.AIMessage)
39+
assert result[0].content == "I'm doing well, thanks!"
40+
assert result[0].tool_calls is None
41+
42+
43+
def test_ai_message_with_tool_calls():
44+
"""Test conversion of AIMessage with tool calls"""
45+
46+
tool_calls = [
47+
{
48+
"function": {
49+
"arguments": '{"metal_name": "gold"}',
50+
"name": "get_metal_price",
51+
}
52+
},
53+
{
54+
"function": {
55+
"arguments": '{"metal_name": "silver"}',
56+
"name": "get_metal_price",
57+
}
58+
},
59+
]
60+
61+
messages = [
62+
AIMessage(
63+
content="Find the difference in the price of gold and silver?",
64+
additional_kwargs={"tool_calls": tool_calls},
65+
)
66+
]
67+
68+
result = convert_to_ragas_messages(messages)
69+
assert len(result) == 1
70+
assert isinstance(result[0], r.AIMessage)
71+
assert result[0].content == "Find the difference in the price of gold and silver?"
72+
assert len(result[0].tool_calls) == 2
73+
assert result[0].tool_calls[0].name == "get_metal_price"
74+
assert result[0].tool_calls[0].args == {"metal_name": "gold"}
75+
assert result[0].tool_calls[1].name == "get_metal_price"
76+
assert result[0].tool_calls[1].args == {"metal_name": "silver"}
77+
78+
79+
def test_tool_message_conversion():
80+
"""Test conversion of ToolMessage with valid string content"""
81+
messages = [
82+
HumanMessage(content="Hello, add 4 and 5"),
83+
ToolMessage(content="9", tool_call_id="2"),
84+
]
85+
result = convert_to_ragas_messages(messages)
86+
87+
assert len(result) == 2
88+
assert isinstance(result[1], r.ToolMessage)
89+
assert result[1].content == "9"
90+
91+
92+
def test_system_message_skipped():
93+
"""Test that SystemMessages are properly skipped"""
94+
messages = [SystemMessage(content="System prompt"), HumanMessage(content="Hello")]
95+
result = convert_to_ragas_messages(messages)
96+
97+
assert len(result) == 1
98+
assert isinstance(result[0], r.HumanMessage)
99+
assert result[0].content == "Hello"
100+
101+
102+
def test_unsupported_message_type():
103+
"""Test that unsupported message types raise ValueError"""
104+
105+
class CustomMessage:
106+
content = "test"
107+
108+
messages = [CustomMessage()]
109+
110+
with pytest.raises(ValueError) as exc_info:
111+
convert_to_ragas_messages(messages)
112+
assert "Unsupported message type: CustomMessage" in str(exc_info.value)
113+
114+
115+
def test_empty_message_list():
116+
"""Test conversion of empty message list"""
117+
messages = []
118+
result = convert_to_ragas_messages(messages)
119+
assert result == []
120+
121+
122+
def test_invalid_tool_calls_json():
123+
"""Test handling of invalid JSON in tool calls"""
124+
tool_calls = [{"function": {"name": "search", "arguments": "invalid json"}}]
125+
126+
messages = [AIMessage(content="Test", additional_kwargs={"tool_calls": tool_calls})]
127+
128+
with pytest.raises(json.JSONDecodeError):
129+
convert_to_ragas_messages(messages)

0 commit comments

Comments
 (0)