Skip to content

Commit e9b888f

Browse files
fix: correct MultiTurnSample user_input validation logic
Fixed validation bug where generator expression was not being evaluated. Changed from checking generator object to using all() to properly validate all messages are instances of HumanMessage, AIMessage, or ToolMessage. Added tests to verify validation works correctly.
1 parent 41cc83b commit e9b888f

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

src/ragas/dataset_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def validate_user_input(
128128
messages: t.List[t.Union[HumanMessage, AIMessage, ToolMessage]],
129129
) -> t.List[t.Union[HumanMessage, AIMessage, ToolMessage]]:
130130
"""Validates the user input messages."""
131-
if not (
131+
if not all(
132132
isinstance(m, (HumanMessage, AIMessage, ToolMessage)) for m in messages
133133
):
134134
raise ValueError(

tests/unit/test_dataset_schema.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,3 +196,31 @@ def test_evaluation_dataset_type():
196196

197197
dataset = EvaluationDataset(samples=[multi_turn_sample])
198198
assert dataset.get_sample_type() == MultiTurnSample
199+
200+
201+
def test_multiturn_sample_validate_user_input_invalid_type():
202+
"""Test that MultiTurnSample validation correctly rejects invalid message types."""
203+
from pydantic import ValidationError
204+
205+
with pytest.raises(ValidationError):
206+
MultiTurnSample(
207+
user_input=[
208+
HumanMessage(content="Hello"),
209+
"invalid_string", # This should be rejected by Pydantic
210+
]
211+
)
212+
213+
214+
def test_multiturn_sample_validate_user_input_valid_types():
215+
"""Test that MultiTurnSample validation accepts valid message types."""
216+
from ragas.messages import AIMessage
217+
218+
sample = MultiTurnSample(
219+
user_input=[
220+
HumanMessage(content="Hello"),
221+
AIMessage(content="Hi there"),
222+
]
223+
)
224+
assert len(sample.user_input) == 2
225+
assert isinstance(sample.user_input[0], HumanMessage)
226+
assert isinstance(sample.user_input[1], AIMessage)

0 commit comments

Comments
 (0)