Skip to content

Commit 1607087

Browse files
author
Snehit Gajjar
committed
Added tests for memory_pipeline_storage
1 parent 8bd8fb2 commit 1607087

File tree

3 files changed

+100
-5
lines changed

3 files changed

+100
-5
lines changed

graphrag/storage/memory_pipeline_storage.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
"""A module containing 'InMemoryStorage' model."""
55

6-
import asyncio
76
import logging
87
import re
98
from collections.abc import Iterator
@@ -36,7 +35,8 @@ def find(
3635
"""Find files in the storage using a file pattern, as well as a custom filter function."""
3736
# In-memory storage does not filter by base_dir, file_pattern or file_filter; every stored key is returned
3837
results: list[tuple[str, dict[str, Any]]] = []
39-
results = [(key, {}) for key in self._storage["files"]]
38+
if self._storage.get("files") is not None:
39+
results = [(key, {}) for key in self._storage["files"]]
4040
return iter(results)
4141

4242
async def get(
@@ -84,7 +84,7 @@ async def delete(self, key: str) -> None:
8484
Args:
8585
- key - The key to delete.
8686
"""
87-
del self._storage[key]
87+
del self._storage["files"][key]
8888

8989
async def clear(self) -> None:
9090
"""Clear the storage."""
@@ -109,14 +109,19 @@ def set_sync(self, key: str, value: Any, encoding: str | None = None) -> None:
109109
- key - The key to set the value for.
110110
- value - The value to set.
111111
"""
112-
task = asyncio.create_task(self.set(key, value, encoding))
113-
logger.info("Setting value for key '%s' in memory storage", task.get_name())
112+
logger.info("Setting value for key '%s' in memory storage", key)
113+
if "files" not in self._storage:
114+
self._storage["files"] = {}
115+
self._storage["files"][key] = value
114116

115117

116118
def create_memory_storage(**kwargs: Any) -> PipelineStorage:
117119
"""Create a memory based storage."""
118120
logger.info("Creating memory storage")
119121
memorystorage = MemoryPipelineStorage()
122+
if kwargs.get("input_files") is None:
123+
return memorystorage
124+
120125
for key, value in kwargs["input_files"].items():
121126
memorystorage.set_sync(key, value)
122127
return memorystorage

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ mkdocs-jupyter = "^0.25.1"
109109
mkdocs-exclude-search = "^0.6.6"
110110
pytest-dotenv = "^0.5.2"
111111
mkdocs-typer = "^0.0.3"
112+
freezegun = "^1.5.3"
112113

113114
[build-system]
114115
requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning>=1.0.0,<2.0.0"]
@@ -268,3 +269,6 @@ asyncio_default_fixture_loop_scope = "function"
268269
asyncio_mode = "auto"
269270
timeout = 1000
270271
env_files = [".env"]
272+
pythonpath = [
273+
"./tests"
274+
]
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Copyright (c) 2024 Microsoft Corporation.
2+
# Licensed under the MIT License
3+
"""Memory Pipeline Storage Tests."""
4+
5+
import os
6+
import re
7+
from datetime import datetime, timezone
8+
9+
import pandas as pd
10+
from freezegun import freeze_time
11+
12+
from graphrag.storage.memory_pipeline_storage import (
13+
MemoryPipelineStorage,
14+
create_memory_storage,
15+
)
16+
17+
__dirname__ = os.path.dirname(__file__)
18+
19+
20+
async def test_find():
21+
input_files = {
22+
"input_files": {
23+
"tests/fixtures/text/input/dulce.txt": pd.DataFrame({
24+
"text": ["Dulce et decorum est"],
25+
"creation_date": [datetime(2023, 1, 1, tzinfo=timezone.utc)],
26+
}),
27+
}
28+
}
29+
30+
storage = create_memory_storage(**input_files)
31+
32+
items = list(
33+
storage.find(
34+
base_dir="any/path",
35+
file_pattern=re.compile(r".*\.txt$"),
36+
file_filter=None,
37+
)
38+
)
39+
assert items == [("tests/fixtures/text/input/dulce.txt", {})]
40+
output = await storage.get("tests/fixtures/text/input/dulce.txt")
41+
assert len(output) > 0
42+
43+
await storage.set("test.txt", "Hello, World!", encoding="utf-8")
44+
output = await storage.get("test.txt")
45+
assert output == "Hello, World!"
46+
await storage.delete("test.txt")
47+
output = await storage.get("test.txt")
48+
assert output is None
49+
50+
51+
@freeze_time("2023-10-01T15:55:00.12345")
52+
async def test_get_creation_date():
53+
storage = MemoryPipelineStorage()
54+
55+
creation_date = await storage.get_creation_date(
56+
"tests/fixtures/text/input/dulce.txt"
57+
)
58+
59+
parsed_datetime = datetime.now(timezone.utc).isoformat()
60+
61+
assert parsed_datetime == creation_date
62+
63+
64+
def test_child():
65+
input_files = {
66+
"input_files": {
67+
"tests/fixtures/text/input/dulce.txt": pd.DataFrame({
68+
"text": ["Dulce et decorum est"],
69+
"creation_date": [datetime(2023, 1, 1, tzinfo=timezone.utc)],
70+
}),
71+
}
72+
}
73+
74+
storage = create_memory_storage(**input_files)
75+
76+
child_storage = storage.child("tests/fixtures/text/input")
77+
items = list(
78+
storage.find(base_dir="any/path", file_pattern=re.compile(r".*\.txt$"))
79+
)
80+
assert items == [("tests/fixtures/text/input/dulce.txt", {})]
81+
assert child_storage == storage
82+
83+
items = list(
84+
child_storage.find(base_dir="any/path", file_pattern=re.compile(r".*\.txt$"))
85+
)
86+
assert items == [("tests/fixtures/text/input/dulce.txt", {})]

0 commit comments

Comments
 (0)