Skip to content

Commit a7416bc

Browse files
authored
Add to_dict and from_dict methods for Stores (#5541)
* Add to_dict and from_dict methods for Stores * Add release notes * Add tests with custom init parameters
1 parent 094d857 commit a7416bc

File tree

7 files changed

+164
-1
lines changed

7 files changed

+164
-1
lines changed

haystack/preview/document_stores/decorator.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
from typing import Dict, Any, Type
12
import logging
23

4+
from haystack.preview.document_stores.protocols import Store
5+
from haystack.preview.document_stores.errors import StoreDeserializationError
36

47
logger = logging.getLogger(__name__)
58

@@ -27,6 +30,9 @@ def _decorate(self, cls):
2730
self.registry[cls.__name__] = cls
2831
logger.debug("Registered Store %s", cls)
2932

33+
cls.to_dict = _default_store_to_dict
34+
cls.from_dict = classmethod(_default_store_from_dict)
35+
3036
return cls
3137

3238
def __call__(self, cls=None):
@@ -37,3 +43,28 @@ def __call__(self, cls=None):
3743

3844

3945
store = _Store()
46+
47+
48+
def _default_store_to_dict(store_: Store) -> Dict[str, Any]:
    """
    Fallback `to_dict` implementation for stores.

    Produces a dictionary with three entries:

    - "hash": the value of `id(store_)`
    - "type": the name of the class of `store_`
    - "init_parameters": the `init_parameters` attribute of `store_`, or an
      empty dictionary when the instance has no such attribute
    """
    serialized: Dict[str, Any] = {}
    serialized["hash"] = id(store_)
    serialized["type"] = store_.__class__.__name__
    serialized["init_parameters"] = getattr(store_, "init_parameters", {})
    return serialized
58+
59+
60+
def _default_store_from_dict(cls: Type[Store], data: Dict[str, Any]) -> Store:
    """
    Default store deserializer.

    Builds an instance of `cls` by passing `data["init_parameters"]` to its
    constructor as keyword arguments.

    :param cls: the store class to instantiate.
    :param data: the serialized store. It must contain a "type" entry equal to
        `cls.__name__`; the "init_parameters" entry is optional.
    :return: a new instance of `cls`.
    :raises StoreDeserializationError: if "type" is missing from `data` or
        names a class other than `cls`.
    """
    # Validate first so malformed data is rejected before anything else is read.
    if "type" not in data:
        raise StoreDeserializationError("Missing 'type' in store serialization data")
    if data["type"] != cls.__name__:
        raise StoreDeserializationError(f"Store '{data['type']}' can't be deserialized as '{cls.__name__}'")

    # An "init_parameters" key with no value (e.g. an empty YAML entry) loads as
    # None; treat it like an absent key instead of failing on `cls(**None)`.
    init_params = data.get("init_parameters")
    if init_params is None:
        init_params = {}
    return cls(**init_params)

haystack/preview/document_stores/errors.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,7 @@ class DuplicateDocumentError(StoreError):
1212

1313
class MissingDocumentError(StoreError):
    # NOTE(review): no raise site is visible in this view; presumably signals
    # that a requested document is not in the store — confirm against the stores.
    pass
15+
16+
17+
class StoreDeserializationError(StoreError):
    """
    Raised when serialized data cannot be turned back into a store, for example
    when its "type" entry is missing or names a different class.
    """

    pass

haystack/preview/document_stores/memory/document_store.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,13 @@ def __init__(
4747
self.bm25_algorithm = algorithm_class
4848
self.bm25_parameters = bm25_parameters or {}
4949

50+
# Used to convert this instance to a dictionary for serialization
51+
self.init_parameters = {
52+
"bm25_tokenization_regex": bm25_tokenization_regex,
53+
"bm25_algorithm": bm25_algorithm,
54+
"bm25_parameters": self.bm25_parameters,
55+
}
56+
5057
def count_documents(self) -> int:
5158
"""
5259
Returns the number of how many documents are present in the document store.

haystack/preview/document_stores/protocols.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,17 @@ class Store(Protocol):
2525
you're using.
2626
"""
2727

28+
def to_dict(self) -> Dict[str, Any]:
    """
    Serializes this store to a dictionary.

    Classes decorated with `@store` get a default implementation that returns
    the keys "hash", "type" and "init_parameters".
    """
32+
33+
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "Store":
    """
    Deserializes the store from a dictionary.

    Classes decorated with `@store` get a default implementation that requires
    `data["type"]` to equal the class name, raising `StoreDeserializationError`
    otherwise, and passes `data["init_parameters"]` (if present) to the
    constructor as keyword arguments.
    """
38+
2839
def count_documents(self) -> int:
2940
"""
3041
Returns the number of documents stored.
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
features:
3+
- Add `from_dict` and `to_dict` methods to `Store` `Protocol`
4+
- Add default `from_dict` and `to_dict` implementations to classes decorated with `@store`
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from unittest.mock import Mock
2+
3+
import pytest
4+
5+
from haystack.preview.testing.factory import store_class
6+
from haystack.preview.document_stores.decorator import _default_store_to_dict, _default_store_from_dict
7+
from haystack.preview.document_stores.errors import StoreDeserializationError
8+
9+
10+
@pytest.mark.unit
def test_default_store_to_dict():
    # With no extra fields the serialized init parameters are expected to be empty.
    store_type = store_class("MyStore")
    instance = store_type()
    serialized = _default_store_to_dict(instance)
    expected = {"hash": id(instance), "type": "MyStore", "init_parameters": {}}
    assert serialized == expected
16+
17+
18+
@pytest.mark.unit
def test_default_store_to_dict_with_custom_init_parameters():
    # `init_parameters` is supplied through the factory's extra fields.
    store_type = store_class("MyStore", extra_fields={"init_parameters": {"custom_param": True}})
    instance = store_type()
    serialized = _default_store_to_dict(instance)
    assert serialized == {
        "hash": id(instance),
        "type": "MyStore",
        "init_parameters": {"custom_param": True},
    }
25+
26+
27+
@pytest.mark.unit
def test_default_store_from_dict():
    store_type = store_class("MyStore")
    # The payload carries no "init_parameters" entry at all.
    payload = {"type": "MyStore"}
    restored = _default_store_from_dict(store_type, payload)
    assert isinstance(restored, store_type)
32+
33+
34+
@pytest.mark.unit
def test_default_store_from_dict_with_custom_init_parameters():
    # Constructor that records the keyword argument it receives.
    def init_with_param(self, custom_param: int):
        self.custom_param = custom_param

    store_type = store_class("MyStore", extra_fields={"__init__": init_with_param})
    payload = {"type": "MyStore", "init_parameters": {"custom_param": 100}}
    restored = _default_store_from_dict(store_type, payload)

    assert isinstance(restored, store_type)
    assert restored.custom_param == 100
44+
45+
46+
@pytest.mark.unit
def test_default_store_from_dict_without_type():
    # An empty payload has no "type" entry and must be rejected.
    empty_payload = {}
    with pytest.raises(StoreDeserializationError, match="Missing 'type' in store serialization data"):
        _default_store_from_dict(Mock, empty_payload)
50+
51+
52+
@pytest.mark.unit
def test_default_store_from_dict_unregistered_store(request):
    # The registry is global, so a fixed name could collide with a store
    # registered by another test. The name of this test function is used as
    # the store name to make sure it is not registered.
    store_name = request.node.name
    expected_message = f"Store '{store_name}' can't be deserialized as 'Mock'"
    payload = {"type": store_name}

    with pytest.raises(StoreDeserializationError, match=expected_message):
        _default_store_from_dict(Mock, payload)

test/preview/document_stores/test_memory.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import logging
2+
from unittest.mock import patch
23

34
import pandas as pd
45
import pytest
56

67
from haystack.preview import Document
78
from haystack.preview.document_stores import Store, MemoryDocumentStore
8-
99
from haystack.testing.preview.document_store import DocumentStoreBaseTests
1010

1111

@@ -18,6 +18,53 @@ class TestMemoryDocumentStore(DocumentStoreBaseTests):
1818
def docstore(self) -> MemoryDocumentStore:
1919
return MemoryDocumentStore()
2020

21+
@pytest.mark.unit
def test_to_dict(self):
    # A store built with no arguments should serialize these init parameters.
    document_store = MemoryDocumentStore()
    expected_init_parameters = {
        "bm25_tokenization_regex": r"(?u)\b\w\w+\b",
        "bm25_algorithm": "BM25Okapi",
        "bm25_parameters": {},
    }
    serialized = document_store.to_dict()
    assert serialized == {
        "hash": id(document_store),
        "type": "MemoryDocumentStore",
        "init_parameters": expected_init_parameters,
    }
34+
35+
@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
    # Every constructor argument given here should come back in the serialized form.
    document_store = MemoryDocumentStore(
        bm25_tokenization_regex="custom_regex",
        bm25_algorithm="BM25Plus",
        bm25_parameters={"key": "value"},
    )
    expected_init_parameters = {
        "bm25_tokenization_regex": "custom_regex",
        "bm25_algorithm": "BM25Plus",
        "bm25_parameters": {"key": "value"},
    }
    serialized = document_store.to_dict()
    assert serialized == {
        "hash": id(document_store),
        "type": "MemoryDocumentStore",
        "init_parameters": expected_init_parameters,
    }
50+
51+
@pytest.mark.unit
@patch("haystack.preview.document_stores.memory.document_store.re")
def test_from_dict(self, mock_regex):
    # `re` is patched in the store module so the compile call can be inspected.
    init_parameters = {
        "bm25_tokenization_regex": "custom_regex",
        "bm25_algorithm": "BM25Plus",
        "bm25_parameters": {"key": "value"},
    }
    serialized = {"type": "MemoryDocumentStore", "init_parameters": init_parameters}

    restored = MemoryDocumentStore.from_dict(serialized)

    mock_regex.compile.assert_called_with("custom_regex")
    assert restored.tokenizer
    assert restored.bm25_algorithm.__name__ == "BM25Plus"
    assert restored.bm25_parameters == {"key": "value"}
67+
2168
@pytest.mark.unit
2269
def test_bm25_retrieval(self, docstore: Store):
2370
docstore = MemoryDocumentStore()

0 commit comments

Comments
 (0)