Skip to content

Commit 0053fce

Browse files
Enhance MultiTurnSample Validator to Support Multiple Tool Calls (#2017)
- Fixes: #1995 The current `validate_user_input` method has a strict validation rule that requires each `ToolMessage` to be immediately preceded by an `AIMessage` with `tool_calls`. This prevents valid conversation patterns where: 1. Multiple `ToolMessage` instances appear in sequence 2. `ToolMessage` instances appear after `ToolMessage` types as long as an `AIMessage` appeared earlier in the conversation ### Changes This PR modifies the validation logic to: 1. Track whether we've seen an `AIMessage` at any point in the conversation 2. Allow a `ToolMessage` to follow either an `AIMessage` or another `ToolMessage` ### Example The provided sample demonstrates this pattern with: ```python from ragas.dataset_schema import MultiTurnSample from ragas.messages import HumanMessage, AIMessage, ToolMessage, ToolCall sample_input = [ HumanMessage( content="Can you provide me with details about Einstein's theory of relativity?" ), AIMessage( content="Got it! Let me fetch more details from 'General Theory of Relativity by A. Einstein'.", tool_calls=[ ToolCall( name="document_retrieve", args={"document": "General Theory of Relativity by A. Einstein"}, ), ToolCall( name="document_retrieve", args={"document": "A. Einstein biography"}, ), ], ), ToolMessage( content="Found relevant documents: 1. Relativity: The Special and the General Theory, 2. General Theory of Relativity by A. Einstein." ), ToolMessage(content="Found relevant documents: 1. A. Einstein biography"), AIMessage(content="I found some documents on Einstein's theory of relativity..."), ] sample = MultiTurnSample(user_input=sample_input) ```
1 parent cd459ee commit 0053fce

File tree

1 file changed

+26
-10
lines changed

1 file changed

+26
-10
lines changed

src/ragas/dataset_schema.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -133,18 +133,34 @@ def validate_user_input(
133133
"All inputs must be instances of HumanMessage, AIMessage, or ToolMessage."
134134
)
135135

136-
prev_message = None
137-
for m in messages:
138-
if isinstance(m, ToolMessage):
139-
if not isinstance(prev_message, AIMessage):
140-
raise ValueError(
141-
"ToolMessage instances must be preceded by an AIMessage instance."
142-
)
143-
if prev_message.tool_calls is None:
136+
has_seen_ai_message = False
137+
138+
for i, m in enumerate(messages):
139+
if isinstance(m, AIMessage):
140+
has_seen_ai_message = True
141+
142+
elif isinstance(m, ToolMessage):
143+
# Rule 1: ToolMessage must be preceded by an AIMessage somewhere in the conversation
144+
if not has_seen_ai_message:
144145
raise ValueError(
145-
f"ToolMessage instances must be preceded by an AIMessage instance with tool_calls. Got {prev_message}"
146+
"ToolMessage must be preceded by an AIMessage somewhere in the conversation."
146147
)
147-
prev_message = m
148+
149+
# Rule 2: ToolMessage must follow an AIMessage or another ToolMessage
150+
if i > 0:
151+
prev_message = messages[i - 1]
152+
153+
if isinstance(prev_message, AIMessage):
154+
# Rule 3: If following AIMessage, that message must have tool_calls
155+
if not prev_message.tool_calls:
156+
raise ValueError(
157+
"ToolMessage must follow an AIMessage where tools were called."
158+
)
159+
elif not isinstance(prev_message, ToolMessage):
160+
# Not following AIMessage or ToolMessage
161+
raise ValueError(
162+
"ToolMessage must follow an AIMessage or another ToolMessage."
163+
)
148164

149165
return messages
150166

0 commit comments

Comments
 (0)