Skip to content

Commit e0865f9

Browse files
committed
add find test
1 parent 731b5db commit e0865f9

File tree

5 files changed

+72
-47
lines changed

5 files changed

+72
-47
lines changed

graphrag/index/run/workflow.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ async def _inject_workflow_data_dependencies(
3333
workflow_dependencies: dict[str, list[str]],
3434
dataset: pd.DataFrame,
3535
storage: PipelineStorage,
36-
extension: str,
3736
) -> None:
3837
"""Inject the data dependencies into the workflow."""
3938
workflow.add_table(DEFAULT_INPUT_NAME, dataset)
@@ -42,7 +41,7 @@ async def _inject_workflow_data_dependencies(
4241
for id in deps:
4342
workflow_id = f"workflow:{id}"
4443
try:
45-
table = await _load_table_from_storage(f"{id}.{extension}", storage)
44+
table = await _load_table_from_storage(f"{id}.parquet", storage)
4645
except ValueError:
4746
# our workflows allow for transient tables, and we avoid putting those in storage
4847
# however, we need to keep the table in the dependency list for proper execution order.
@@ -98,7 +97,10 @@ async def _process_workflow(
9897
context.stats.workflows[workflow_name] = {"overall": 0.0}
9998

10099
await _inject_workflow_data_dependencies(
101-
workflow, workflow_dependencies, dataset, context.storage, "parquet"
100+
workflow,
101+
workflow_dependencies,
102+
dataset,
103+
context.storage,
102104
)
103105

104106
workflow_start_time = time.time()

graphrag/storage/blob_pipeline_storage.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def __init__(
3232
self,
3333
connection_string: str | None,
3434
container_name: str,
35-
encoding: str | None = None,
35+
encoding: str = "utf-8",
3636
path_prefix: str | None = None,
3737
storage_account_blob_url: str | None = None,
3838
):
@@ -50,7 +50,7 @@ def __init__(
5050
account_url=storage_account_blob_url,
5151
credential=DefaultAzureCredential(),
5252
)
53-
self._encoding = encoding or "utf-8"
53+
self._encoding = encoding
5454
self._container_name = container_name
5555
self._connection_string = connection_string
5656
self._path_prefix = path_prefix or ""
@@ -198,7 +198,7 @@ async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
198198
if isinstance(value, bytes):
199199
blob_client.upload_blob(value, overwrite=True)
200200
else:
201-
coding = encoding or "utf-8"
201+
coding = encoding or self._encoding
202202
blob_client.upload_blob(value.encode(coding), overwrite=True)
203203
except Exception:
204204
log.exception("Error setting key %s: %s", key)

graphrag/storage/cosmosdb_pipeline_storage.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class CosmosDBPipelineStorage(PipelineStorage):
2828
_cosmosdb_account_url: str | None
2929
_connection_string: str | None
3030
_database_name: str
31-
_current_container: str | None
31+
container_name: str | None
3232
_encoding: str
3333

3434
def __init__(
@@ -58,7 +58,7 @@ def __init__(
5858
self._database_name = database_name
5959
self._connection_string = connection_string
6060
self._cosmosdb_account_url = cosmosdb_account_url
61-
self._current_container = current_container
61+
self.container_name = current_container
6262
self._cosmosdb_account_name = (
6363
cosmosdb_account_url.split("//")[1].split(".")[0]
6464
if cosmosdb_account_url
@@ -74,7 +74,7 @@ def __init__(
7474
self._database_name,
7575
)
7676
self._create_database()
77-
if self._current_container:
77+
if self.container_name:
7878
self._create_container()
7979

8080
def _create_database(self) -> None:
@@ -117,7 +117,7 @@ def find(
117117

118118
log.info(
119119
"search container %s for documents matching %s",
120-
self._current_container,
120+
self.container_name,
121121
file_pattern.pattern,
122122
)
123123

@@ -130,7 +130,7 @@ def item_filter(item: dict[str, Any]) -> bool:
130130

131131
try:
132132
container_client = self._database_client.get_container_client(
133-
str(self._current_container)
133+
str(self.container_name)
134134
)
135135
query = "SELECT * FROM c WHERE CONTAINS(c.id, @pattern)"
136136
parameters: list[dict[str, Any]] = [
@@ -173,9 +173,9 @@ async def get(
173173
) -> Any:
174174
"""Get a file in the database for the given key."""
175175
try:
176-
if self._current_container:
176+
if self.container_name:
177177
container_client = self._database_client.get_container_client(
178-
self._current_container
178+
self.container_name
179179
)
180180
item = container_client.read_item(item=key, partition_key=key)
181181
item_body = item.get("body")
@@ -195,9 +195,9 @@ async def get(
195195
async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
196196
"""Set a file in the database for the given key."""
197197
try:
198-
if self._current_container:
198+
if self.container_name:
199199
container_client = self._database_client.get_container_client(
200-
self._current_container
200+
self.container_name
201201
)
202202
if isinstance(value, bytes):
203203
value_df = pd.read_parquet(BytesIO(value))
@@ -217,29 +217,29 @@ async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
217217

218218
async def has(self, key: str) -> bool:
219219
"""Check if the given file exists in the cosmosdb storage."""
220-
if self._current_container:
220+
if self.container_name:
221221
container_client = self._database_client.get_container_client(
222-
self._current_container
222+
self.container_name
223223
)
224224
item_names = [item["id"] for item in container_client.read_all_items()]
225225
return key in item_names
226226
return False
227227

228228
async def delete(self, key: str) -> None:
229229
"""Delete the given file from the cosmosdb storage."""
230-
if self._current_container:
230+
if self.container_name:
231231
container_client = self._database_client.get_container_client(
232-
self._current_container
232+
self.container_name
233233
)
234234
container_client.delete_item(item=key, partition_key=key)
235235

236236
# Function currently deletes all items within the current container, then deletes the container itself.
237237
# TODO: Decide the granularity of deletion (e.g. delete all items within the current container, delete the current container, delete the current database)
238238
async def clear(self) -> None:
239239
"""Clear the cosmosdb storage."""
240-
if self._current_container:
240+
if self.container_name:
241241
container_client = self._database_client.get_container_client(
242-
self._current_container
242+
self.container_name
243243
)
244244
for item in container_client.read_all_items():
245245
item_id = item["id"]
@@ -257,24 +257,24 @@ def child(self, name: str | None) -> PipelineStorage:
257257

258258
def _create_container(self) -> None:
259259
"""Create a container for the current container name if it doesn't exist."""
260-
if self._current_container:
260+
if self.container_name:
261261
partition_key = PartitionKey(path="/id", kind="Hash")
262262
self._database_client.create_container_if_not_exists(
263-
id=self._current_container,
263+
id=self.container_name,
264264
partition_key=partition_key,
265265
)
266266

267267
def _delete_container(self) -> None:
268268
"""Delete the container with the current container name if it exists."""
269-
if self._container_exists() and self._current_container:
270-
self._database_client.delete_container(self._current_container)
269+
if self._container_exists() and self.container_name:
270+
self._database_client.delete_container(self.container_name)
271271

272272
def _container_exists(self) -> bool:
273273
"""Check if the container with the current container name exists."""
274274
container_names = [
275275
container["id"] for container in self._database_client.list_containers()
276276
]
277-
return self._current_container in container_names
277+
return self.container_name in container_names
278278

279279

280280
# TODO remove this helper function and have the factory instantiate the class directly

graphrag/utils/storage.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""Storage functions for the GraphRAG run module."""
55

66
import logging
7-
from io import BytesIO, StringIO
7+
from io import BytesIO
88

99
import pandas as pd
1010

@@ -18,22 +18,8 @@ async def _load_table_from_storage(name: str, storage: PipelineStorage) -> pd.Da
1818
msg = f"Could not find {name} in storage!"
1919
raise ValueError(msg)
2020
try:
21-
log.info("read table from storage: %s", name)
22-
match name.split(".")[-1]:
23-
case "parquet":
24-
return pd.read_parquet(BytesIO(await storage.get(name, as_bytes=True)))
25-
case "json":
26-
return pd.read_json(
27-
StringIO(await storage.get(name)),
28-
lines=False,
29-
orient="records",
30-
)
31-
case "csv":
32-
return pd.read_csv(BytesIO(await storage.get(name, as_bytes=True)))
33-
case _:
34-
msg = f"Unknown file extension for {name}"
35-
log.exception(msg)
36-
raise
21+
log.info("reading table from storage: %s", name)
22+
return pd.read_parquet(BytesIO(await storage.get(name, as_bytes=True)))
3723
except Exception:
3824
log.exception("error loading table from storage: %s", name)
3925
raise

tests/integration/storage/test_cosmosdb_storage.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,57 @@
22
# Licensed under the MIT License
33
"""CosmosDB Storage Tests."""
44

5+
import re
56
import sys
67

78
import pytest
89

10+
from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage
11+
12+
# cspell:disable-next-line well-known-key
13+
WELL_KNOWN_COSMOS_ACCOUNT_URL = "https://localhost:8081"
14+
WELL_KNOWN_COSMOS_CONNECTION_STRING = "AccountEndpoint=https://localhost:8081/;AccountKey=C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==;"
15+
916
# the cosmosdb emulator is only available on windows runners at this time
1017
if not sys.platform.startswith("win"):
1118
pytest.skip("encountered windows-only tests -- skipping", allow_module_level=True)
1219

1320

14-
def test_find():
15-
print("test_find")
16-
assert True
21+
async def test_find():
22+
storage = CosmosDBPipelineStorage(
23+
cosmosdb_account_url=WELL_KNOWN_COSMOS_ACCOUNT_URL,
24+
connection_string=WELL_KNOWN_COSMOS_CONNECTION_STRING,
25+
database_name="testfind",
26+
)
27+
try:
28+
try:
29+
items = list(
30+
storage.find(base_dir="input", file_pattern=re.compile(r".*\.txt$"))
31+
)
32+
items = [item[0] for item in items]
33+
assert items == []
34+
35+
await storage.set("christmas.txt", "Merry Christmas!", encoding="utf-8")
36+
items = list(
37+
storage.find(base_dir="input", file_pattern=re.compile(r".*\.txt$"))
38+
)
39+
items = [item[0] for item in items]
40+
assert items == ["christmas.txt"]
41+
42+
await storage.set("test.txt", "Hello, World!", encoding="utf-8")
43+
items = list(storage.find(file_pattern=re.compile(r".*\.txt$")))
44+
items = [item[0] for item in items]
45+
assert items == ["christmas.txt", "test.txt"]
46+
47+
output = await storage.get("test.txt")
48+
assert output == "Hello, World!"
49+
finally:
50+
await storage.delete("test.txt")
51+
output = await storage.get("test.txt")
52+
assert output is None
53+
finally:
54+
storage._delete_container() # noqa: SLF001
1755

1856

1957
def test_child():
20-
print("test_child")
2158
assert True

0 commit comments

Comments
 (0)