@@ -0,0 +1,35 @@
# SPDX-FileCopyrightText: 2024-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0


# This example demonstrates how to use the MetaLlamaChatGenerator component
# with structured outputs.
# To run this example, you will need to
# set the `LLAMA_API_KEY` environment variable

from haystack.dataclasses import ChatMessage
from pydantic import BaseModel

from haystack_integrations.components.generators.meta_llama import MetaLlamaChatGenerator


class NobelPrizeInfo(BaseModel):
recipient_name: str
award_year: int
category: str
achievement_description: str
nationality: str


chat_messages = [
ChatMessage.from_user(
"In 2021, American scientist David Julius received the Nobel Prize in"
" Physiology or Medicine for his groundbreaking discoveries on how the human body"
" senses temperature and touch."
)
]
component = MetaLlamaChatGenerator(generation_kwargs={"response_format": NobelPrizeInfo})
results = component.run(chat_messages)

# print(results)
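A short follow-up sketch, not part of the change itself, of how the structured reply could be consumed; it assumes the model returns JSON text that conforms to `NobelPrizeInfo`:

# Hedged sketch: parse the JSON reply back into the Pydantic model
# (assumes the reply text is valid JSON matching the schema above).
reply = results["replies"][0]
prize_info = NobelPrizeInfo.model_validate_json(reply.text)
print(prize_info.recipient_name, prize_info.award_year, prize_info.category)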
2 changes: 1 addition & 1 deletion integrations/meta_llama/pyproject.toml
@@ -23,7 +23,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = ["haystack-ai>=2.13.2"]
dependencies = ["haystack-ai>=2.19.0"]

[project.urls]
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/meta_llama#readme"
@@ -3,14 +3,16 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, Optional

from haystack import component, default_to_dict, logging
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage, StreamingCallbackT
from haystack.tools import Tool, Toolset
from haystack.tools import ToolsType, serialize_tools_or_toolset
from haystack.utils import serialize_callable
from haystack.utils.auth import Secret
from openai.lib._pydantic import to_strict_json_schema
from pydantic import BaseModel

logger = logging.getLogger(__name__)

@@ -60,7 +62,7 @@ def __init__(
streaming_callback: Optional[StreamingCallbackT] = None,
api_base_url: Optional[str] = "https://api.llama.com/compat/v1/",
generation_kwargs: Optional[Dict[str, Any]] = None,
tools: Optional[Union[List[Tool], Toolset]] = None,
tools: Optional[ToolsType] = None,
):
"""
Creates an instance of LlamaChatGenerator. Unless specified otherwise in the `model`, this is for Llama's
@@ -91,6 +93,12 @@
events as they become available, with the stream terminated by a data: [DONE] message.
- `safe_prompt`: Whether to inject a safety prompt before all conversations.
- `random_seed`: The seed to use for random sampling.
- `response_format`: A JSON schema or a Pydantic model that enforces the structure of the model's response.
If provided, the output will always be validated against this
format (unless the model returns a tool call).
For details, see the [OpenAI Structured Outputs documentation](https://platform.openai.com/docs/guides/structured-outputs).
For structured outputs with streaming, the `response_format` must be a JSON
schema and not a Pydantic model.
:param tools:
A list of tools for which the model can prepare calls.
"""
@@ -110,7 +118,7 @@ def _prepare_api_call(
messages: list[ChatMessage],
streaming_callback: Optional[StreamingCallbackT] = None,
generation_kwargs: Optional[dict[str, Any]] = None,
tools: Optional[Union[list[Tool], Toolset]] = None,
tools: Optional[ToolsType] = None,
tools_strict: Optional[bool] = None,
) -> dict[str, Any]:
api_args = super(MetaLlamaChatGenerator, self)._prepare_api_call( # noqa: UP008
@@ -133,13 +141,29 @@ def to_dict(self) -> Dict[str, Any]:
The serialized component as a dictionary.
"""
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
generation_kwargs = self.generation_kwargs.copy()
response_format = generation_kwargs.get("response_format")

# If the response format is a Pydantic model, it is converted to OpenAI's JSON schema format
# If it is already a JSON schema, it is left as is
if response_format and isinstance(response_format, type) and issubclass(response_format, BaseModel):
json_schema = {
"type": "json_schema",
"json_schema": {
"name": response_format.__name__,
"strict": True,
"schema": to_strict_json_schema(response_format),
},
}

generation_kwargs["response_format"] = json_schema

return default_to_dict(
self,
model=self.model,
streaming_callback=callback_name,
api_base_url=self.api_base_url,
generation_kwargs=self.generation_kwargs,
generation_kwargs=generation_kwargs,
api_key=self.api_key.to_dict(),
tools=[tool.to_dict() for tool in self.tools] if self.tools else None,
tools=serialize_tools_or_toolset(self.tools),
)
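As a rough check of the conversion in `to_dict` above, the following sketch (assuming `LLAMA_API_KEY` is set so the default `Secret` resolves, and using a hypothetical `City` model) shows the Pydantic class being serialized into an OpenAI-style strict JSON schema inside `generation_kwargs`:

from pydantic import BaseModel
from haystack_integrations.components.generators.meta_llama import MetaLlamaChatGenerator

class City(BaseModel):  # hypothetical model, for illustration only
    name: str

generator = MetaLlamaChatGenerator(generation_kwargs={"response_format": City})
data = generator.to_dict()
# The Pydantic class has been replaced by a {"type": "json_schema", ...} dict.
print(data["init_parameters"]["generation_kwargs"]["response_format"]["type"])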
109 changes: 107 additions & 2 deletions integrations/meta_llama/tests/test_llama_chat_generator.py
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates

import json
import os
from datetime import datetime
from unittest.mock import patch
@@ -15,6 +16,7 @@
from openai import OpenAIError
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice
from pydantic import BaseModel

from haystack_integrations.components.generators.meta_llama.chat.chat_generator import (
MetaLlamaChatGenerator,
@@ -134,12 +136,44 @@ def test_to_dict_default(self, monkeypatch):

def test_to_dict_with_parameters(self, monkeypatch):
monkeypatch.setenv("ENV_VAR", "test-api-key")

class NobelPrizeInfo(BaseModel):
recipient_name: str
award_year: int

schema = {
"json_schema": {
"name": "NobelPrizeInfo",
"schema": {
"additionalProperties": False,
"properties": {
"award_year": {
"title": "Award Year",
"type": "integer",
},
"recipient_name": {
"title": "Recipient Name",
"type": "string",
},
},
"required": [
"recipient_name",
"award_year",
],
"title": "NobelPrizeInfo",
"type": "object",
},
"strict": True,
},
"type": "json_schema",
}

component = MetaLlamaChatGenerator(
api_key=Secret.from_env_var("ENV_VAR"),
model="Llama-4-Scout-17B-16E-Instruct-FP8",
streaming_callback=print_streaming_chunk,
api_base_url="test-base-url",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params", "response_format": NobelPrizeInfo},
)
data = component.to_dict()

@@ -153,7 +187,7 @@ def test_to_dict_with_parameters(self, monkeypatch):
"model": "Llama-4-Scout-17B-16E-Instruct-FP8",
"api_base_url": "test-base-url",
"streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params", "response_format": schema},
}

for key, value in expected_params.items():
@@ -293,6 +327,77 @@ def __call__(self, chunk: StreamingChunk) -> None:
assert callback.counter > 1
assert "Paris" in callback.responses

@pytest.mark.skipif(
not os.environ.get("LLAMA_API_KEY", None),
reason="Export an env var called LLAMA_API_KEY containing the Llama API key to run this test.",
)
@pytest.mark.integration
def test_live_run_response_format(self):
class NobelPrizeInfo(BaseModel):
recipient_name: str
award_year: int
category: str
achievement_description: str
nationality: str

chat_messages = [
ChatMessage.from_user(
"In 2021, American scientist David Julius received the Nobel Prize in"
" Physiology or Medicine for his groundbreaking discoveries on how the human body"
" senses temperature and touch."
)
]
component = MetaLlamaChatGenerator(generation_kwargs={"response_format": NobelPrizeInfo})
results = component.run(chat_messages)
assert isinstance(results, dict)
assert "replies" in results
assert isinstance(results["replies"], list)
assert len(results["replies"]) == 1
assert isinstance(results["replies"][0], ChatMessage)
message = results["replies"][0]
assert isinstance(message.text, str)
msg = json.loads(message.text)
assert msg["recipient_name"] == "David Julius"
assert msg["award_year"] == 2021
assert "category" in msg
assert "achievement_description" in msg
assert msg["nationality"] == "American"

@pytest.mark.skipif(
not os.environ.get("LLAMA_API_KEY", None),
reason="Export an env var called LLAMA_API_KEY containing the Llama API key to run this test.",
)
@pytest.mark.integration
def test_live_run_with_response_format_json_schema(self):
response_schema = {
"type": "json_schema",
"json_schema": {
"name": "CapitalCity",
"strict": True,
"schema": {
"title": "CapitalCity",
"type": "object",
"properties": {
"city": {"title": "City", "type": "string"},
"country": {"title": "Country", "type": "string"},
},
"required": ["city", "country"],
"additionalProperties": False,
},
},
}

chat_messages = [ChatMessage.from_user("What's the capital of France?")]
comp = MetaLlamaChatGenerator(generation_kwargs={"response_format": response_schema})
results = comp.run(chat_messages)
assert len(results["replies"]) == 1
message: ChatMessage = results["replies"][0]
msg = json.loads(message.text)
assert "Paris" in msg["city"]
assert isinstance(msg["country"], str)
assert "France" in msg["country"]
assert message.meta["finish_reason"] == "stop"

@pytest.mark.skipif(
not os.environ.get("LLAMA_API_KEY", None),
reason="Export an env var called LLAMA_API_KEY containing the OpenAI API key to run this test.",