diff --git a/haystack_experimental/chat_message_stores/__init__.py b/haystack_experimental/chat_message_stores/__init__.py index a0a07a14..9872e916 100644 --- a/haystack_experimental/chat_message_stores/__init__.py +++ b/haystack_experimental/chat_message_stores/__init__.py @@ -4,4 +4,4 @@ from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore -_all_ = ["InMemoryChatMessageStore"] +__all__ = ["InMemoryChatMessageStore"] diff --git a/haystack_experimental/chat_message_stores/in_memory.py b/haystack_experimental/chat_message_stores/in_memory.py index 2be43d7a..979ce7d0 100644 --- a/haystack_experimental/chat_message_stores/in_memory.py +++ b/haystack_experimental/chat_message_stores/in_memory.py @@ -2,42 +2,58 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, Iterable, List +from typing import Any, Iterable, Optional -from haystack import default_from_dict, default_to_dict, logging +from haystack import default_from_dict, default_to_dict from haystack.dataclasses import ChatMessage from haystack_experimental.chat_message_stores.types import ChatMessageStore -logger = logging.getLogger(__name__) +# Global storage for all InMemoryChatMessageStore instances, indexed by the index name. +_STORAGES: dict[str, list[ChatMessage]] = {} -class InMemoryChatMessageStore(ChatMessageStore): +class InMemoryChatMessageStore: """ Stores chat messages in-memory. - """ - def __init__( - self, - ): - """ - Initializes the InMemoryChatMessageStore. - """ - self.messages = [] + The `index` parameter is used as a unique identifier for each conversation or chat session. + It acts as a namespace that isolates messages from different sessions. Each `index` value corresponds to a + separate list of `ChatMessage` objects stored in memory. + + Typical usage involves providing a unique `index` (for example, a session ID or conversation ID) + whenever you write, read, or delete messages. 
This ensures that chat messages from different + conversations do not overlap. + + Usage example: + ```python + from haystack.dataclasses import ChatMessage + from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore + + message_store = InMemoryChatMessageStore() + + messages = [ + ChatMessage.from_assistant("Hello, how can I help you?"), + ChatMessage.from_user("Hi, I have a question about Python. What is a Protocol?"), + ] + message_store.write_messages(index="user_456_session_123", messages=messages) + retrieved_messages = message_store.retrieve(index="user_456_session_123") + + print(retrieved_messages) + ``` + """ - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """ Serializes the component to a dictionary. :returns: Dictionary with serialized data. """ - return default_to_dict( - self, - ) + return default_to_dict(self) @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "InMemoryChatMessageStore": + def from_dict(cls, data: dict[str, Any]) -> "InMemoryChatMessageStore": """ Deserializes the component from a dictionary. @@ -48,19 +64,26 @@ def from_dict(cls, data: Dict[str, Any]) -> "InMemoryChatMessageStore": """ return default_from_dict(cls, data) - def count_messages(self) -> int: + def count_messages(self, index: str) -> int: """ - Returns the number of chat messages stored. + Returns the number of chat messages stored in this store. + + :param index: + The index for which to count messages. :returns: The number of messages. """ - return len(self.messages) + return len(_STORAGES.get(index, [])) - def write_messages(self, messages: List[ChatMessage]) -> int: + def write_messages(self, index: str, messages: list[ChatMessage]) -> int: """ Writes chat messages to the ChatMessageStore. - :param messages: A list of ChatMessages to write. + :param index: + The index under which to store the messages. + :param messages: + A list of ChatMessages to write. + :returns: The number of messages written. 
:raises ValueError: If messages is not a list of ChatMessages. @@ -68,19 +91,39 @@ def write_messages(self, messages: List[ChatMessage]) -> int: if not isinstance(messages, Iterable) or any(not isinstance(message, ChatMessage) for message in messages): raise ValueError("Please provide a list of ChatMessages.") - self.messages.extend(messages) + for message in messages: + if index not in _STORAGES: + _STORAGES[index] = [] + _STORAGES[index].append(message) + return len(messages) - def delete_messages(self) -> None: + def delete_messages(self, index: str) -> None: """ Deletes all stored chat messages. + + :param index: + The index from which to delete messages. """ - self.messages = [] + _STORAGES.pop(index, None) - def retrieve(self) -> List[ChatMessage]: + def retrieve(self, index: str, last_k: Optional[int] = None) -> list[ChatMessage]: """ Retrieves all stored chat messages. + :param index: + The index from which to retrieve messages. + :param last_k: + The number of last messages to retrieve. If None, retrieves all messages. + :returns: A list of chat messages. 
""" - return self.messages + if last_k is not None and last_k <= 0: + raise ValueError("last_k must be greater than 0") + + messages = _STORAGES.get(index, []) + + if last_k is not None: + return messages[-last_k:] + + return messages diff --git a/haystack_experimental/chat_message_stores/types.py b/haystack_experimental/chat_message_stores/types.py index 5300a1eb..5212771b 100644 --- a/haystack_experimental/chat_message_stores/types.py +++ b/haystack_experimental/chat_message_stores/types.py @@ -2,16 +2,15 @@ # # SPDX-License-Identifier: Apache-2.0 -from abc import ABC, abstractmethod -from typing import Any, Dict, List +from typing import Any, Dict, List, Protocol -from haystack import logging from haystack.dataclasses import ChatMessage -logger = logging.getLogger(__name__) +# Ellipsis are needed for the type checker, it's safe to disable module-wide +# pylint: disable=unnecessary-ellipsis -class ChatMessageStore(ABC): +class ChatMessageStore(Protocol): """ Stores ChatMessages to be used by the components of a Pipeline. @@ -22,53 +21,49 @@ class ChatMessageStore(ABC): In order to write or retrieve chat messages, consider using a ChatMessageWriter or ChatMessageRetriever. """ - @abstractmethod - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """ Serializes this store to a dictionary. :returns: The serialized store as a dictionary. """ + ... @classmethod - @abstractmethod - def from_dict(cls, data: Dict[str, Any]) -> "ChatMessageStore": + def from_dict(cls, data: dict[str, Any]) -> "ChatMessageStore": """ Deserializes the store from a dictionary. :param data: The dictionary to deserialize from. :returns: The deserialized store. """ + ... - @abstractmethod - def count_messages(self) -> int: + def count_messages(self, index: str) -> int: """ Returns the number of chat messages stored. + :param index: The index for which to count messages. + :returns: The number of messages. """ + ... 
- @abstractmethod - def write_messages(self, messages: List[ChatMessage]) -> int: + def write_messages(self, index: str, messages: list[ChatMessage]) -> int: """ Writes chat messages to the ChatMessageStore. + :param index: The index under which to store the messages. :param messages: A list of ChatMessages to write. - :returns: The number of messages written. - :raises ValueError: If messages is not a list of ChatMessages. + :returns: The number of messages written. """ + ... - @abstractmethod - def delete_messages(self) -> None: + def delete_messages(self, index: str) -> None: """ Deletes all stored chat messages. - """ - - @abstractmethod - def retrieve(self) -> List[ChatMessage]: - """ - Retrieves all stored chat messages. - :returns: A list of chat messages. + :param index: The index from which to delete all messages. """ + ... \ No newline at end of file diff --git a/haystack_experimental/components/retrievers/chat_message_retriever.py b/haystack_experimental/components/retrievers/chat_message_retriever.py index 55ba85ac..1aa7b0e7 100644 --- a/haystack_experimental/components/retrievers/chat_message_retriever.py +++ b/haystack_experimental/components/retrievers/chat_message_retriever.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional +from typing import Any, Optional from haystack import DeserializationError, component, default_from_dict, default_to_dict, logging from haystack.core.serialization import import_class_by_name @@ -53,18 +53,17 @@ def __init__(self, message_store: ChatMessageStore, last_k: int = 10): raise ValueError(f"last_k must be greater than 0. Currently, the last_k is {last_k}") self.last_k = last_k - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """ Serializes the component to a dictionary. :returns: Dictionary with serialized data. 
""" - message_store = self.message_store.to_dict() - return default_to_dict(self, message_store=message_store, last_k=self.last_k) + return default_to_dict(self, message_store=self.message_store.to_dict(), last_k=self.last_k) @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ChatMessageRetriever": + def from_dict(cls, data: dict[str, Any]) -> "ChatMessageRetriever": """ Deserializes the component from a dictionary. @@ -88,14 +87,21 @@ def from_dict(cls, data: Dict[str, Any]) -> "ChatMessageRetriever": data["init_parameters"]["message_store"] = default_from_dict(message_store_class, message_store_data) return default_from_dict(cls, data) - @component.output_types(messages=List[ChatMessage]) - def run(self, last_k: Optional[int] = None) -> Dict[str, List[ChatMessage]]: + # TODO Consider adding messages as an optional input parameter to append to the retrieved messages + # Optionally allow for a init param to pass a custom function for how to combine the messages?? + @component.output_types(messages=list[ChatMessage]) + def run(self, index: str, last_k: Optional[int] = None) -> dict[str, list[ChatMessage]]: """ Run the ChatMessageRetriever + :param index: + A unique identifier for the chat session or conversation whose messages should be retrieved. + Each `index` corresponds to a distinct chat history stored in the underlying ChatMessageStore. + For example, use a session ID or conversation ID to isolate messages from different chat sessions. :param last_k: The number of last messages to retrieve. This parameter takes precedence over the last_k parameter passed to the ChatMessageRetriever constructor. If unspecified, the last_k parameter passed to the constructor will be used. + :returns: - `messages` - The retrieved chat messages. 
:raises ValueError: If last_k is not None and is less than 1 @@ -103,6 +109,6 @@ def run(self, last_k: Optional[int] = None) -> Dict[str, List[ChatMessage]]: if last_k is not None and last_k <= 0: raise ValueError("last_k must be greater than 0") - last_k = last_k or self.last_k + resolved_last_k = last_k or self.last_k - return {"messages": self.message_store.retrieve()[-last_k:]} + return {"messages": self.message_store.retrieve(index=index, last_k=resolved_last_k)} diff --git a/haystack_experimental/components/writers/__init__.py b/haystack_experimental/components/writers/__init__.py index 187d1b91..5aca5669 100644 --- a/haystack_experimental/components/writers/__init__.py +++ b/haystack_experimental/components/writers/__init__.py @@ -4,4 +4,4 @@ from haystack_experimental.components.writers.chat_message_writer import ChatMessageWriter -_all_ = ["ChatMessageWriter"] +__all__ = ["ChatMessageWriter"] diff --git a/haystack_experimental/components/writers/chat_message_writer.py b/haystack_experimental/components/writers/chat_message_writer.py index d111326d..6e624aba 100644 --- a/haystack_experimental/components/writers/chat_message_writer.py +++ b/haystack_experimental/components/writers/chat_message_writer.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List +from typing import Any from haystack import DeserializationError, component, default_from_dict, default_to_dict, logging from haystack.core.serialization import import_class_by_name @@ -10,8 +10,6 @@ from haystack_experimental.chat_message_stores.types import ChatMessageStore -logger = logging.getLogger(__name__) - @component class ChatMessageWriter: @@ -34,7 +32,7 @@ class ChatMessageWriter: ``` """ - def __init__(self, message_store: ChatMessageStore): + def __init__(self, message_store: ChatMessageStore) -> None: """ Create a ChatMessageWriter component. 
@@ -43,7 +41,7 @@ def __init__(self, message_store: ChatMessageStore): """ self.message_store = message_store - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """ Serializes the component to a dictionary. @@ -53,7 +51,7 @@ def to_dict(self) -> Dict[str, Any]: return default_to_dict(self, message_store=self.message_store.to_dict()) @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "ChatMessageWriter": + def from_dict(cls, data: dict[str, Any]) -> "ChatMessageWriter": """ Deserializes the component from a dictionary. @@ -79,18 +77,20 @@ def from_dict(cls, data: Dict[str, Any]) -> "ChatMessageWriter": return default_from_dict(cls, data) @component.output_types(messages_written=int) - def run(self, messages: List[ChatMessage]) -> Dict[str, int]: + def run(self, index: str, messages: list[ChatMessage]) -> dict[str, int]: """ Run the ChatMessageWriter on the given input data. + :param index: + A unique identifier for the chat session or conversation under which the messages should be written. + Each `index` corresponds to a distinct chat history stored in the underlying ChatMessageStore. + For example, use a session ID or conversation ID to isolate messages from different chat sessions. :param messages: A list of chat messages to write to the store. + :returns: - `messages_written`: Number of messages written to the ChatMessageStore. - - :raises ValueError: - If the specified message store is not found. 
""" - messages_written = self.message_store.write_messages(messages=messages) + messages_written = self.message_store.write_messages(index=index, messages=messages) return {"messages_written": messages_written} diff --git a/test/chat_message_stores/test_in_memory_chat_message_store.py b/test/chat_message_stores/test_in_memory_chat_message_store.py index e89a3be2..1fb2aa3c 100644 --- a/test/chat_message_stores/test_in_memory_chat_message_store.py +++ b/test/chat_message_stores/test_in_memory_chat_message_store.py @@ -10,10 +10,10 @@ def test_init(self): Test that the InMemoryChatMessageStore can be initialized and that it works as expected. """ store = InMemoryChatMessageStore() - assert store.count_messages() == 0 - assert store.retrieve() == [] - assert store.write_messages([]) == 0 - assert not store.delete_messages() + assert store.count_messages(index="test") == 0 + assert store.retrieve(index="test") == [] + assert store.write_messages(index="test", messages=[]) == 0 + assert not store.delete_messages(index="test") def test_to_dict(self): """ @@ -41,46 +41,50 @@ def test_count_messages(self): Test that the InMemoryChatMessageStore can count the number of messages in the store correctly. 
""" store = InMemoryChatMessageStore() - assert store.count_messages() == 0 - store.write_messages(messages=[ChatMessage.from_user("Hello, how can I help you?")]) - assert store.count_messages() == 1 - store.write_messages(messages=[ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?")]) - assert store.count_messages() == 2 - store.write_messages(messages=[ChatMessage.from_user("Hola, ¿cómo puedo ayudarte?")]) - assert store.count_messages() == 3 + assert store.count_messages(index="test") == 0 + store.write_messages(index="test", messages=[ChatMessage.from_user("Hello, how can I help you?")]) + assert store.count_messages(index="test") == 1 + store.write_messages(index="test", messages=[ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?")]) + assert store.count_messages(index="test") == 2 + store.write_messages(index="test", messages=[ChatMessage.from_user("Hola, ¿cómo puedo ayudarte?")]) + assert store.count_messages(index="test") == 3 + # Clean up + store.delete_messages(index="test") def test_retrieve(self): """ Test that the InMemoryChatMessageStore can retrieve all messages from the store correctly. 
""" store = InMemoryChatMessageStore() - assert store.retrieve() == [] - store.write_messages(messages=[ChatMessage.from_user("Hello, how can I help you?")]) - assert store.retrieve() == [ChatMessage.from_user("Hello, how can I help you?")] - store.write_messages(messages=[ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?")]) - assert store.retrieve() == [ + assert store.retrieve(index="test") == [] + store.write_messages(index="test", messages=[ChatMessage.from_user("Hello, how can I help you?")]) + assert store.retrieve(index="test") == [ChatMessage.from_user("Hello, how can I help you?")] + store.write_messages(index="test", messages=[ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?")]) + assert store.retrieve(index="test") == [ ChatMessage.from_user("Hello, how can I help you?"), ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?"), ] - store.write_messages(messages=[ChatMessage.from_user("Hola, ¿cómo puedo ayudarte?")]) + store.write_messages(index="test", messages=[ChatMessage.from_user("Hola, ¿cómo puedo ayudarte?")]) assert store.retrieve() == [ ChatMessage.from_user("Hello, how can I help you?"), ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?"), ChatMessage.from_user("Hola, ¿cómo puedo ayudarte?"), ] + # Clean up + store.delete_messages(index="test") def test_delete_messages(self): """ Test that the InMemoryChatMessageStore can delete all messages from the store correctly. 
""" store = InMemoryChatMessageStore() - assert store.count_messages() == 0 - store.write_messages(messages=[ChatMessage.from_user("Hello, how can I help you?")]) - assert store.count_messages() == 1 - store.delete_messages() - assert store.count_messages() == 0 - store.write_messages(messages=[ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?")]) - store.write_messages(messages=[ChatMessage.from_user("Hola, ¿cómo puedo ayudarte?")]) - assert store.count_messages() == 2 - store.delete_messages() - assert store.count_messages() == 0 + assert store.count_messages(index="test") == 0 + store.write_messages(index="test", messages=[ChatMessage.from_user("Hello, how can I help you?")]) + assert store.count_messages(index="test") == 1 + store.delete_messages(index="test") + assert store.count_messages(index="test") == 0 + store.write_messages(index="test", messages=[ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?")]) + store.write_messages(index="test", messages=[ChatMessage.from_user("Hola, ¿cómo puedo ayudarte?")]) + assert store.count_messages(index="test") == 2 + store.delete_messages(index="test") + assert store.count_messages(index="test") == 0 diff --git a/test/components/retrievers/test_chat_message_retriever.py b/test/components/retrievers/test_chat_message_retriever.py index 1368ef25..eb9772e9 100644 --- a/test/components/retrievers/test_chat_message_retriever.py +++ b/test/components/retrievers/test_chat_message_retriever.py @@ -17,7 +17,7 @@ def test_init(self): retriever = ChatMessageRetriever(message_store) assert retriever.message_store == message_store - assert retriever.run() == {"messages": []} + assert retriever.run(index="test") == {"messages": []} def test_retrieve_messages(self): """ @@ -29,11 +29,13 @@ def test_retrieve_messages(self): ] message_store = InMemoryChatMessageStore() - message_store.write_messages(messages) + message_store.write_messages(index="test", messages=messages) retriever = ChatMessageRetriever(message_store) 
assert retriever.message_store == message_store - assert retriever.run() == {"messages": messages} + assert retriever.run(index="test") == {"messages": messages} + # Clean up + message_store.delete_messages(index="test") def test_retrieve_messages_last_k(self): """ @@ -47,31 +49,39 @@ def test_retrieve_messages_last_k(self): ] message_store = InMemoryChatMessageStore() - message_store.write_messages(messages) + message_store.write_messages(index="test", messages=messages) retriever = ChatMessageRetriever(message_store) assert retriever.message_store == message_store - assert retriever.run(last_k=1) == { - "messages": [ChatMessage.from_user("Bonjour, comment puis-je vous aider?")]} + assert retriever.run(index="test", last_k=1) == { + "messages": [ChatMessage.from_user("Bonjour, comment puis-je vous aider?")] + } - assert retriever.run(last_k=2) == { - "messages": [ChatMessage.from_user("Hola, como puedo ayudarte?"), - ChatMessage.from_user("Bonjour, comment puis-je vous aider?") - ]} + assert retriever.run(index="test", last_k=2) == { + "messages": [ + ChatMessage.from_user("Hola, como puedo ayudarte?"), + ChatMessage.from_user("Bonjour, comment puis-je vous aider?") + ] + } # outliers - assert retriever.run(last_k=10) == { - "messages": [ChatMessage.from_user("Hello, how can I help you?"), - ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?"), - ChatMessage.from_user("Hola, como puedo ayudarte?"), - ChatMessage.from_user("Bonjour, comment puis-je vous aider?") - ]} + assert retriever.run(index="test", last_k=10) == { + "messages": [ + ChatMessage.from_user("Hello, how can I help you?"), + ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?"), + ChatMessage.from_user("Hola, como puedo ayudarte?"), + ChatMessage.from_user("Bonjour, comment puis-je vous aider?") + ] + } with pytest.raises(ValueError): - retriever.run(last_k=0) + retriever.run(index="test", last_k=0) with pytest.raises(ValueError): - retriever.run(last_k=-1) + retriever.run(index="test", 
last_k=-1) + + # Clean up + message_store.delete_messages(index="test") def test_retrieve_messages_last_k_init(self): """ @@ -86,20 +96,26 @@ def test_retrieve_messages_last_k_init(self): ] message_store = InMemoryChatMessageStore() - message_store.write_messages(messages) + message_store.write_messages(index="test", messages=messages) retriever = ChatMessageRetriever(message_store, last_k=2) assert retriever.message_store == message_store # last_k is 1 here from run parameter, overrides init of 2 - assert retriever.run(last_k=1) == { - "messages": [ChatMessage.from_user("Bonjour, comment puis-je vous aider?")]} + assert retriever.run(index="test", last_k=1) == { + "messages": [ChatMessage.from_user("Bonjour, comment puis-je vous aider?")] + } # last_k is 2 here from init - assert retriever.run() == { - "messages": [ChatMessage.from_user("Hola, como puedo ayudarte?"), - ChatMessage.from_user("Bonjour, comment puis-je vous aider?") - ]} + assert retriever.run(index="test") == { + "messages": [ + ChatMessage.from_user("Hola, como puedo ayudarte?"), + ChatMessage.from_user("Bonjour, comment puis-je vous aider?") + ] + } + + # Clean up + message_store.delete_messages(index="test") def test_to_dict(self): """ @@ -145,27 +161,32 @@ def test_chat_message_retriever_pipeline(self): """ Test that the ChatMessageRetriever can be used in a pipeline and that it works as expected. """ + index = "user_123_session_456" store = InMemoryChatMessageStore() - store.write_messages([ChatMessage.from_assistant("Hello, how can I help you?")]) + store.write_messages(index=index, messages=[ChatMessage.from_assistant("Hello, how can I help you?")]) - pipe = Pipeline() - pipe.add_component("memory_retriever", ChatMessageRetriever(store)) - pipe.add_component("prompt_builder", ChatPromptBuilder(variables=["query", "memories"])) - pipe.connect("memory_retriever", "prompt_builder.memories") - user_prompt = """ - Given the following information, answer the question. 
+ template = ChatMessage.from_user(""" +Given the following information, answer the question. - Context: - {% for memory in memories %} - {{ memory.text }} - {% endfor %} +Context: +{% for msg in messages %} + {{ msg.text }} +{% endfor %} - Question: {{ query }} - Answer: - """ - question = "What is the capital of France?" +Question: {{ query }} +Answer: +""") - res = pipe.run(data={"prompt_builder": {"template": [ChatMessage.from_user(user_prompt)], "query": question}}) + pipe = Pipeline() + pipe.add_component("memory_retriever", ChatMessageRetriever(store)) + pipe.add_component( + "prompt_builder", ChatPromptBuilder(template=[template], required_variables=["query", "messages"]), + ) + pipe.connect("memory_retriever.messages", "prompt_builder.messages") + + res = pipe.run( + data={"prompt_builder": {"query": "What is the capital of France?"}, "memory_retriever": {"index": index}} + ) resulting_prompt = res["prompt_builder"]["prompt"][0].text assert "France" in resulting_prompt assert "how can I help you" in resulting_prompt @@ -176,8 +197,6 @@ def test_chat_message_retriever_pipeline_serde(self): """ pipe = Pipeline() pipe.add_component("memory_retriever", ChatMessageRetriever(InMemoryChatMessageStore())) - pipe.add_component("prompt_builder", ChatPromptBuilder(template=[ChatMessage.from_user("no template")], - variables=["query"])) # now serialize and deserialize the pipeline data = pipe.to_dict()