Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion haystack_experimental/chat_message_stores/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@

from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore

_all_ = ["InMemoryChatMessageStore"]
__all__ = ["InMemoryChatMessageStore"]
97 changes: 70 additions & 27 deletions haystack_experimental/chat_message_stores/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,58 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, Iterable, List
from typing import Any, Iterable, Optional

from haystack import default_from_dict, default_to_dict, logging
from haystack import default_from_dict, default_to_dict
from haystack.dataclasses import ChatMessage

from haystack_experimental.chat_message_stores.types import ChatMessageStore

logger = logging.getLogger(__name__)
# Global storage for all InMemoryDocumentStore instances, indexed by the index name.
_STORAGES: dict[str, list[ChatMessage]] = {}


class InMemoryChatMessageStore(ChatMessageStore):
class InMemoryChatMessageStore:
"""
Stores chat messages in-memory.
"""

def __init__(
self,
):
"""
Initializes the InMemoryChatMessageStore.
"""
self.messages = []
The `index` parameter is used as a unique identifier for each conversation or chat session.
It acts as a namespace that isolates messages from different sessions. Each `index` value corresponds to a
separate list of `ChatMessage` objects stored in memory.

Typical usage involves providing a unique `index` (for example, a session ID or conversation ID)
whenever you write, read, or delete messages. This ensures that chat messages from different
conversations do not overlap.

Usage example:
```python
from haystack.dataclasses import ChatMessage
from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore

message_store = InMemoryChatMessageStore()

messages = [
ChatMessage.from_assistant("Hello, how can I help you?"),
ChatMessage.from_user("Hi, I have a question about Python. What is a Protocol?"),
]
message_store.write_messages(messages, index="user_456_session_123")
retrieved_messages = message_store.retrieve(index="user_456_session_123")

print(retrieved_messages)
```
"""

def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""
Serializes the component to a dictionary.

:returns:
Dictionary with serialized data.
"""
return default_to_dict(
self,
)
return default_to_dict(self)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "InMemoryChatMessageStore":
def from_dict(cls, data: dict[str, Any]) -> "InMemoryChatMessageStore":
"""
Deserializes the component from a dictionary.

Expand All @@ -48,39 +64,66 @@ def from_dict(cls, data: Dict[str, Any]) -> "InMemoryChatMessageStore":
"""
return default_from_dict(cls, data)

def count_messages(self) -> int:
def count_messages(self, index: str) -> int:
"""
Returns the number of chat messages stored.
Returns the number of chat messages stored in this store.

:param index:
The index for which to count messages.

:returns: The number of messages.
"""
return len(self.messages)
return len(_STORAGES.get(index, []))

def write_messages(self, messages: List[ChatMessage]) -> int:
def write_messages(self, index: str, messages: list[ChatMessage]) -> int:
"""
Writes chat messages to the ChatMessageStore.

:param messages: A list of ChatMessages to write.
:param index:
The index under which to store the messages.
:param messages:
A list of ChatMessages to write.

:returns: The number of messages written.

:raises ValueError: If messages is not a list of ChatMessages.
"""
if not isinstance(messages, Iterable) or any(not isinstance(message, ChatMessage) for message in messages):
raise ValueError("Please provide a list of ChatMessages.")

self.messages.extend(messages)
for message in messages:
if index not in _STORAGES:
_STORAGES[index] = []
_STORAGES[index].append(message)

return len(messages)

def delete_messages(self) -> None:
def delete_messages(self, index: str) -> None:
"""
Deletes all stored chat messages.

:param index:
The index from which to delete messages.
"""
self.messages = []
_STORAGES.pop(index, None)

def retrieve(self) -> List[ChatMessage]:
def retrieve(self, index: str, last_k: Optional[int] = None) -> list[ChatMessage]:
"""
Retrieves all stored chat messages.

:param index:
The index from which to retrieve messages.
:param last_k:
The number of last messages to retrieve. If None, retrieves all messages.

:returns: A list of chat messages.
"""
return self.messages
if last_k is not None and last_k <= 0:
raise ValueError("last_k must be greater than 0")

messages = _STORAGES.get(index, [])

if last_k is not None:
return messages[-last_k:]

return messages
43 changes: 19 additions & 24 deletions haystack_experimental/chat_message_stores/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@
#
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from typing import Any, Dict, List
from typing import Any, Dict, List, Protocol

from haystack import logging
from haystack.dataclasses import ChatMessage

logger = logging.getLogger(__name__)
# Ellipsis are needed for the type checker, it's safe to disable module-wide
# pylint: disable=unnecessary-ellipsis


class ChatMessageStore(ABC):
class ChatMessageStore(Protocol):
"""
Stores ChatMessages to be used by the components of a Pipeline.

Expand All @@ -22,53 +21,49 @@ class ChatMessageStore(ABC):
In order to write or retrieve chat messages, consider using a ChatMessageWriter or ChatMessageRetriever.
"""

@abstractmethod
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""
Serializes this store to a dictionary.

:returns: The serialized store as a dictionary.
"""
...

@classmethod
@abstractmethod
def from_dict(cls, data: Dict[str, Any]) -> "ChatMessageStore":
def from_dict(cls, data: dict[str, Any]) -> "ChatMessageStore":
"""
Deserializes the store from a dictionary.

:param data: The dictionary to deserialize from.
:returns: The deserialized store.
"""
...

@abstractmethod
def count_messages(self) -> int:
def count_messages(self, index: str) -> int:
"""
Returns the number of chat messages stored.

:param index: The index for which to count messages.

:returns: The number of messages.
"""
...

@abstractmethod
def write_messages(self, messages: List[ChatMessage]) -> int:
def write_messages(self, index: str, messages: list[ChatMessage]) -> int:
"""
Writes chat messages to the ChatMessageStore.

:param index: The index under which to store the messages.
:param messages: A list of ChatMessages to write.
:returns: The number of messages written.

:raises ValueError: If messages is not a list of ChatMessages.
:returns: The number of messages written.
"""
...

@abstractmethod
def delete_messages(self) -> None:
def delete_messages(self, index: str) -> None:
"""
Deletes all stored chat messages.
"""

@abstractmethod
def retrieve(self) -> List[ChatMessage]:
"""
Retrieves all stored chat messages.

:returns: A list of chat messages.
:param index: The index from which to delete all messages.
"""
...
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Optional
from typing import Any, Optional

from haystack import DeserializationError, component, default_from_dict, default_to_dict, logging
from haystack.core.serialization import import_class_by_name
Expand Down Expand Up @@ -53,18 +53,17 @@ def __init__(self, message_store: ChatMessageStore, last_k: int = 10):
raise ValueError(f"last_k must be greater than 0. Currently, the last_k is {last_k}")
self.last_k = last_k

def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""
Serializes the component to a dictionary.

:returns:
Dictionary with serialized data.
"""
message_store = self.message_store.to_dict()
return default_to_dict(self, message_store=message_store, last_k=self.last_k)
return default_to_dict(self, message_store=self.message_store.to_dict(), last_k=self.last_k)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ChatMessageRetriever":
def from_dict(cls, data: dict[str, Any]) -> "ChatMessageRetriever":
"""
Deserializes the component from a dictionary.

Expand All @@ -88,21 +87,28 @@ def from_dict(cls, data: Dict[str, Any]) -> "ChatMessageRetriever":
data["init_parameters"]["message_store"] = default_from_dict(message_store_class, message_store_data)
return default_from_dict(cls, data)

@component.output_types(messages=List[ChatMessage])
def run(self, last_k: Optional[int] = None) -> Dict[str, List[ChatMessage]]:
# TODO Consider adding messages as an optional input parameter to append to the retrieved messages
# Optionally allow for a init param to pass a custom function for how to combine the messages??
@component.output_types(messages=list[ChatMessage])
def run(self, index: str, last_k: Optional[int] = None) -> dict[str, list[ChatMessage]]:
"""
Run the ChatMessageRetriever

:param index:
A unique identifier for the chat session or conversation whose messages should be retrieved.
Each `index` corresponds to a distinct chat history stored in the underlying ChatMessageStore.
For example, use a session ID or conversation ID to isolate messages from different chat sessions.
:param last_k: The number of last messages to retrieve. This parameter takes precedence over the last_k
parameter passed to the ChatMessageRetriever constructor. If unspecified, the last_k parameter passed
to the constructor will be used.

:returns:
- `messages` - The retrieved chat messages.
:raises ValueError: If last_k is not None and is less than 1
"""
if last_k is not None and last_k <= 0:
raise ValueError("last_k must be greater than 0")

last_k = last_k or self.last_k
resolved_last_k = last_k or self.last_k

return {"messages": self.message_store.retrieve()[-last_k:]}
return {"messages": self.message_store.retrieve(index=index, last_k=resolved_last_k)}
2 changes: 1 addition & 1 deletion haystack_experimental/components/writers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@

from haystack_experimental.components.writers.chat_message_writer import ChatMessageWriter

_all_ = ["ChatMessageWriter"]
__all__ = ["ChatMessageWriter"]
22 changes: 11 additions & 11 deletions haystack_experimental/components/writers/chat_message_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List
from typing import Any

from haystack import DeserializationError, component, default_from_dict, default_to_dict, logging
from haystack.core.serialization import import_class_by_name
from haystack.dataclasses import ChatMessage

from haystack_experimental.chat_message_stores.types import ChatMessageStore

logger = logging.getLogger(__name__)


@component
class ChatMessageWriter:
Expand All @@ -34,7 +32,7 @@ class ChatMessageWriter:
```
"""

def __init__(self, message_store: ChatMessageStore):
def __init__(self, message_store: ChatMessageStore) -> None:
"""
Create a ChatMessageWriter component.

Expand All @@ -43,7 +41,7 @@ def __init__(self, message_store: ChatMessageStore):
"""
self.message_store = message_store

def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""
Serializes the component to a dictionary.

Expand All @@ -53,7 +51,7 @@ def to_dict(self) -> Dict[str, Any]:
return default_to_dict(self, message_store=self.message_store.to_dict())

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ChatMessageWriter":
def from_dict(cls, data: dict[str, Any]) -> "ChatMessageWriter":
"""
Deserializes the component from a dictionary.

Expand All @@ -79,18 +77,20 @@ def from_dict(cls, data: Dict[str, Any]) -> "ChatMessageWriter":
return default_from_dict(cls, data)

@component.output_types(messages_written=int)
def run(self, messages: List[ChatMessage]) -> Dict[str, int]:
def run(self, index: str, messages: list[ChatMessage]) -> dict[str, int]:
"""
Run the ChatMessageWriter on the given input data.

:param index:
A unique identifier for the chat session or conversation whose messages should be retrieved.
Each `index` corresponds to a distinct chat history stored in the underlying ChatMessageStore.
For example, use a session ID or conversation ID to isolate messages from different chat sessions.
:param messages:
A list of chat messages to write to the store.

:returns:
- `messages_written`: Number of messages written to the ChatMessageStore.

:raises ValueError:
If the specified message store is not found.
"""

messages_written = self.message_store.write_messages(messages=messages)
messages_written = self.message_store.write_messages(index=index, messages=messages)
return {"messages_written": messages_written}
Loading
Loading