# Patch summary (text before the first "diff --git" header is ignored by patch tools).
#
# 1. src/ragas/dataset_schema.py, validate_user_input:
#    The old condition was `if not (<generator expression>):`. A generator
#    object is always truthy, so the condition was always False and the
#    `raise ValueError(...)` branch could never execute. Wrapping the generator
#    in `all(...)` makes the isinstance check over `messages` actually run.
#
# 2. tests/unit/test_dataset_schema.py:
#    Adds two tests that construct MultiTurnSample: one with a plain string in
#    `user_input` (expects pydantic.ValidationError) and one with
#    HumanMessage + AIMessage (expects both messages to be kept, in order).
#
# NOTE(review): the invalid-type test may also pass on the pre-fix code if
# Pydantic's own field type check rejects the string before or independently of
# validate_user_input -- confirm it fails without the `all(` change, or
# exercise the validator directly.
# NOTE(review): line breaks and indentation below were reconstructed from a
# whitespace-collapsed copy of this patch -- confirm against the original
# before applying.
diff --git a/src/ragas/dataset_schema.py b/src/ragas/dataset_schema.py
index 47d475493..7c0df3cfc 100644
--- a/src/ragas/dataset_schema.py
+++ b/src/ragas/dataset_schema.py
@@ -128,7 +128,7 @@ def validate_user_input(
         messages: t.List[t.Union[HumanMessage, AIMessage, ToolMessage]],
     ) -> t.List[t.Union[HumanMessage, AIMessage, ToolMessage]]:
         """Validates the user input messages."""
-        if not (
+        if not all(
             isinstance(m, (HumanMessage, AIMessage, ToolMessage)) for m in messages
         ):
             raise ValueError(
diff --git a/tests/unit/test_dataset_schema.py b/tests/unit/test_dataset_schema.py
index bcea4a67f..5a619c834 100644
--- a/tests/unit/test_dataset_schema.py
+++ b/tests/unit/test_dataset_schema.py
@@ -196,3 +196,31 @@ def test_evaluation_dataset_type():
 
     dataset = EvaluationDataset(samples=[multi_turn_sample])
     assert dataset.get_sample_type() == MultiTurnSample
+
+
+def test_multiturn_sample_validate_user_input_invalid_type():
+    """Test that MultiTurnSample validation correctly rejects invalid message types."""
+    from pydantic import ValidationError
+
+    with pytest.raises(ValidationError):
+        MultiTurnSample(
+            user_input=[
+                HumanMessage(content="Hello"),
+                "invalid_string",  # This should be rejected by Pydantic
+            ]
+        )
+
+
+def test_multiturn_sample_validate_user_input_valid_types():
+    """Test that MultiTurnSample validation accepts valid message types."""
+    from ragas.messages import AIMessage
+
+    sample = MultiTurnSample(
+        user_input=[
+            HumanMessage(content="Hello"),
+            AIMessage(content="Hi there"),
+        ]
+    )
+    assert len(sample.user_input) == 2
+    assert isinstance(sample.user_input[0], HumanMessage)
+    assert isinstance(sample.user_input[1], AIMessage)