diff --git a/integrations/meta_llama/examples/chat_generator_with_structured_output.py b/integrations/meta_llama/examples/chat_generator_with_structured_output.py
new file mode 100644
index 0000000000..7e68a6e024
--- /dev/null
+++ b/integrations/meta_llama/examples/chat_generator_with_structured_output.py
@@ -0,0 +1,35 @@
+# SPDX-FileCopyrightText: 2024-present deepset GmbH
+#
+# SPDX-License-Identifier: Apache-2.0
+
+
+# This example demonstrates how to use the MetaLlamaChatGenerator component
+# with structured outputs.
+# To run this example, you will need to
+# set the `LLAMA_API_KEY` environment variable.
+
+from haystack.dataclasses import ChatMessage
+from pydantic import BaseModel
+
+from haystack_integrations.components.generators.meta_llama import MetaLlamaChatGenerator
+
+
+class NobelPrizeInfo(BaseModel):
+    recipient_name: str
+    award_year: int
+    category: str
+    achievement_description: str
+    nationality: str
+
+
+chat_messages = [
+    ChatMessage.from_user(
+        "In 2021, American scientist David Julius received the Nobel Prize in"
+        " Physiology or Medicine for his groundbreaking discoveries on how the human body"
+        " senses temperature and touch."
+    )
+]
+component = MetaLlamaChatGenerator(generation_kwargs={"response_format": NobelPrizeInfo})
+results = component.run(chat_messages)
+
+print(results)
diff --git a/integrations/meta_llama/pyproject.toml b/integrations/meta_llama/pyproject.toml
index 56790d4f5a..b2c13d2d19 100644
--- a/integrations/meta_llama/pyproject.toml
+++ b/integrations/meta_llama/pyproject.toml
@@ -23,7 +23,7 @@ classifiers = [
   "Programming Language :: Python :: Implementation :: CPython",
   "Programming Language :: Python :: Implementation :: PyPy",
 ]
-dependencies = ["haystack-ai>=2.13.2"]
+dependencies = ["haystack-ai>=2.19.0"]
 
 [project.urls]
 Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/meta_llama#readme"
diff --git a/integrations/meta_llama/src/haystack_integrations/components/generators/meta_llama/chat/chat_generator.py b/integrations/meta_llama/src/haystack_integrations/components/generators/meta_llama/chat/chat_generator.py
index c0b2353428..6433ce9aec 100644
--- a/integrations/meta_llama/src/haystack_integrations/components/generators/meta_llama/chat/chat_generator.py
+++ b/integrations/meta_llama/src/haystack_integrations/components/generators/meta_llama/chat/chat_generator.py
@@ -3,14 +3,16 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, Optional
 
 from haystack import component, default_to_dict, logging
 from haystack.components.generators.chat import OpenAIChatGenerator
 from haystack.dataclasses import ChatMessage, StreamingCallbackT
-from haystack.tools import Tool, Toolset
+from haystack.tools import ToolsType, serialize_tools_or_toolset
 from haystack.utils import serialize_callable
 from haystack.utils.auth import Secret
+from openai.lib._pydantic import to_strict_json_schema
+from pydantic import BaseModel
 
 logger = logging.getLogger(__name__)
 
@@ -60,7 +62,7 @@ def __init__(
         streaming_callback: Optional[StreamingCallbackT] = None,
         api_base_url: Optional[str] = "https://api.llama.com/compat/v1/",
         generation_kwargs: Optional[Dict[str, Any]] = None,
-        tools: Optional[Union[List[Tool], Toolset]] = None,
+        tools: Optional[ToolsType] = None,
     ):
         """
         Creates an instance of LlamaChatGenerator. Unless specified otherwise in the `model`, this is for Llama's
@@ -91,6 +93,12 @@ def __init__(
             events as they become available, with the stream terminated by a data: [DONE] message.
            - `safe_prompt`: Whether to inject a safety prompt before all conversations.
            - `random_seed`: The seed to use for random sampling.
+           - `response_format`: A JSON schema or a Pydantic model that enforces the structure of the model's response.
+             If provided, the output will always be validated against this format
+             (unless the model returns a tool call). For details, see the
+             [OpenAI Structured Outputs documentation](https://platform.openai.com/docs/guides/structured-outputs).
+             For structured outputs with streaming, the `response_format` must be a JSON schema
+             and not a Pydantic model.
         :param tools:
             A list of tools for which the model can prepare calls.
         """
@@ -110,7 +118,7 @@ def _prepare_api_call(
         messages: list[ChatMessage],
         streaming_callback: Optional[StreamingCallbackT] = None,
         generation_kwargs: Optional[dict[str, Any]] = None,
-        tools: Optional[Union[list[Tool], Toolset]] = None,
+        tools: Optional[ToolsType] = None,
         tools_strict: Optional[bool] = None,
     ) -> dict[str, Any]:
         api_args = super(MetaLlamaChatGenerator, self)._prepare_api_call(  # noqa: UP008
@@ -133,13 +141,29 @@ def to_dict(self) -> Dict[str, Any]:
             The serialized component as a dictionary.
         """
         callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
+        generation_kwargs = self.generation_kwargs.copy()
+        response_format = generation_kwargs.get("response_format")
+
+        # If the response format is a Pydantic model, convert it to OpenAI's JSON schema format.
+        # If it is already a JSON schema, it is left as is.
+        if response_format and isinstance(response_format, type) and issubclass(response_format, BaseModel):
+            json_schema = {
+                "type": "json_schema",
+                "json_schema": {
+                    "name": response_format.__name__,
+                    "strict": True,
+                    "schema": to_strict_json_schema(response_format),
+                },
+            }
+
+            generation_kwargs["response_format"] = json_schema
         return default_to_dict(
             self,
             model=self.model,
             streaming_callback=callback_name,
             api_base_url=self.api_base_url,
-            generation_kwargs=self.generation_kwargs,
+            generation_kwargs=generation_kwargs,
             api_key=self.api_key.to_dict(),
-            tools=[tool.to_dict() for tool in self.tools] if self.tools else None,
+            tools=serialize_tools_or_toolset(self.tools),
         )
 
diff --git a/integrations/meta_llama/tests/test_llama_chat_generator.py b/integrations/meta_llama/tests/test_llama_chat_generator.py
index 713fe46a9a..237a2f1c93 100644
--- a/integrations/meta_llama/tests/test_llama_chat_generator.py
+++ b/integrations/meta_llama/tests/test_llama_chat_generator.py
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 
+import json
 import os
 from datetime import datetime
 from unittest.mock import patch
@@ -15,6 +16,7 @@
 from openai import OpenAIError
 from openai.types.chat import ChatCompletion, ChatCompletionMessage
 from openai.types.chat.chat_completion import Choice
+from pydantic import BaseModel
 
 from haystack_integrations.components.generators.meta_llama.chat.chat_generator import (
     MetaLlamaChatGenerator,
 )
@@ -134,12 +136,44 @@ def test_to_dict_default(self, monkeypatch):
 
     def test_to_dict_with_parameters(self, monkeypatch):
         monkeypatch.setenv("ENV_VAR", "test-api-key")
+
+        class NobelPrizeInfo(BaseModel):
+            recipient_name: str
+            award_year: int
+
+        schema = {
+            "json_schema": {
+                "name": "NobelPrizeInfo",
+                "schema": {
+                    "additionalProperties": False,
+                    "properties": {
+                        "award_year": {
+                            "title": "Award Year",
+                            "type": "integer",
+                        },
+                        "recipient_name": {
+                            "title": "Recipient Name",
+                            "type": "string",
+                        },
+                    },
+                    "required": [
+                        "recipient_name",
+                        "award_year",
+                    ],
+                    "title": "NobelPrizeInfo",
+                    "type": "object",
+                },
+                "strict": True,
+            },
+            "type": "json_schema",
+        }
+
         component = MetaLlamaChatGenerator(
             api_key=Secret.from_env_var("ENV_VAR"),
             model="Llama-4-Scout-17B-16E-Instruct-FP8",
             streaming_callback=print_streaming_chunk,
             api_base_url="test-base-url",
-            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
+            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params", "response_format": NobelPrizeInfo},
         )
         data = component.to_dict()
@@ -153,7 +187,7 @@ def test_to_dict_with_parameters(self, monkeypatch):
             "model": "Llama-4-Scout-17B-16E-Instruct-FP8",
             "api_base_url": "test-base-url",
             "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
-            "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
+            "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params", "response_format": schema},
         }
 
         for key, value in expected_params.items():
@@ -293,6 +327,77 @@ def __call__(self, chunk: StreamingChunk) -> None:
         assert callback.counter > 1
         assert "Paris" in callback.responses
 
+    @pytest.mark.skipif(
+        not os.environ.get("LLAMA_API_KEY", None),
+        reason="Export an env var called LLAMA_API_KEY containing the Llama API key to run this test.",
+    )
+    @pytest.mark.integration
+    def test_live_run_response_format(self):
+        class NobelPrizeInfo(BaseModel):
+            recipient_name: str
+            award_year: int
+            category: str
+            achievement_description: str
+            nationality: str
+
+        chat_messages = [
+            ChatMessage.from_user(
+                "In 2021, American scientist David Julius received the Nobel Prize in"
+                " Physiology or Medicine for his groundbreaking discoveries on how the human body"
+                " senses temperature and touch."
+            )
+        ]
+        component = MetaLlamaChatGenerator(generation_kwargs={"response_format": NobelPrizeInfo})
+        results = component.run(chat_messages)
+        assert isinstance(results, dict)
+        assert "replies" in results
+        assert isinstance(results["replies"], list)
+        assert len(results["replies"]) == 1
+        assert isinstance(results["replies"][0], ChatMessage)
+        message = results["replies"][0]
+        assert isinstance(message.text, str)
+        msg = json.loads(message.text)
+        assert msg["recipient_name"] == "David Julius"
+        assert msg["award_year"] == 2021
+        assert "category" in msg
+        assert "achievement_description" in msg
+        assert msg["nationality"] == "American"
+
+    @pytest.mark.skipif(
+        not os.environ.get("LLAMA_API_KEY", None),
+        reason="Export an env var called LLAMA_API_KEY containing the Llama API key to run this test.",
+    )
+    @pytest.mark.integration
+    def test_live_run_with_response_format_json_schema(self):
+        response_schema = {
+            "type": "json_schema",
+            "json_schema": {
+                "name": "CapitalCity",
+                "strict": True,
+                "schema": {
+                    "title": "CapitalCity",
+                    "type": "object",
+                    "properties": {
+                        "city": {"title": "City", "type": "string"},
+                        "country": {"title": "Country", "type": "string"},
+                    },
+                    "required": ["city", "country"],
+                    "additionalProperties": False,
+                },
+            },
+        }
+
+        chat_messages = [ChatMessage.from_user("What's the capital of France?")]
+        comp = MetaLlamaChatGenerator(generation_kwargs={"response_format": response_schema})
+        results = comp.run(chat_messages)
+        assert len(results["replies"]) == 1
+        message: ChatMessage = results["replies"][0]
+        msg = json.loads(message.text)
+        assert "Paris" in msg["city"]
+        assert isinstance(msg["country"], str)
+        assert "France" in msg["country"]
+        assert message.meta["finish_reason"] == "stop"
+
     @pytest.mark.skipif(
         not os.environ.get("LLAMA_API_KEY", None),
         reason="Export an env var called LLAMA_API_KEY containing the OpenAI API key to run this test.",
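
Note (not part of the patch): since the docstring added above states that structured outputs with streaming require a JSON schema rather than a Pydantic model, here is a minimal usage sketch of that combination. It reuses the `CapitalCity` schema from the integration test and `print_streaming_chunk`, the stock Haystack callback referenced elsewhere in this diff, and assumes the `LLAMA_API_KEY` environment variable is set.

# Sketch only: structured output combined with streaming (assumes LLAMA_API_KEY is exported).
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage

from haystack_integrations.components.generators.meta_llama import MetaLlamaChatGenerator

# With streaming enabled, `response_format` must be a JSON schema, not a Pydantic model.
response_schema = {
    "type": "json_schema",
    "json_schema": {
        "name": "CapitalCity",
        "strict": True,
        "schema": {
            "title": "CapitalCity",
            "type": "object",
            "properties": {
                "city": {"title": "City", "type": "string"},
                "country": {"title": "Country", "type": "string"},
            },
            "required": ["city", "country"],
            "additionalProperties": False,
        },
    },
}

generator = MetaLlamaChatGenerator(
    generation_kwargs={"response_format": response_schema},
    streaming_callback=print_streaming_chunk,
)
results = generator.run([ChatMessage.from_user("What's the capital of France?")])
print(results["replies"][0].text)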