diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index 33aa82b6ff643..6d9ae97992e40 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -1433,6 +1433,14 @@ def _convert_to_message_template( message_ = _create_template_from_message_type( message_type_str, template, template_format=template_format ) + elif ( + hasattr(message_type_str, "model_fields") + and "type" in message_type_str.model_fields + ): + message_type = message_type_str.model_fields["type"].default + message_ = _create_template_from_message_type( + message_type, template, template_format=template_format + ) else: message_ = message_type_str( prompt=PromptTemplate.from_template( diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index 3d2fd04bdce74..02a760a18f195 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -317,6 +317,30 @@ def test_chat_prompt_template_from_messages_jinja2() -> None: ] +def test_chat_prompt_template_from_messages_using_message_classes() -> None: + """Test creating a chat prompt template using message class tuples.""" + template = ChatPromptTemplate.from_messages( + [ + (SystemMessage, "You are a helpful AI bot. Your name is {name}."), + (HumanMessage, "Hello, how are you doing?"), + (AIMessage, "I'm doing well, thanks!"), + (HumanMessage, "{user_input}"), + ] + ) + + expected = [ + SystemMessage( + content="You are a helpful AI bot. Your name is Bob.", additional_kwargs={} + ), + HumanMessage(content="Hello, how are you doing?", additional_kwargs={}), + AIMessage(content="I'm doing well, thanks!", additional_kwargs={}), + HumanMessage(content="What is your name?", additional_kwargs={}), + ] + + messages = template.format_messages(name="Bob", user_input="What is your name?") + assert messages == expected + + @pytest.mark.requires("jinja2") @pytest.mark.parametrize( ("template_format", "image_type_placeholder", "image_data_placeholder"),