Skip to content

Commit 12e1744

Browse files
Copilotjgbradley1
andcommitted
Fix pytest errors in storage factory tests by updating PipelineStorage interface implementation
Co-authored-by: jgbradley1 <[email protected]>
1 parent 728db4d commit 12e1744

File tree

2 files changed

+29
-20
lines changed

2 files changed

+29
-20
lines changed

graphrag/config/defaults.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ class GlobalSearchDefaults:
239239
class StorageDefaults:
240240
"""Default values for storage."""
241241

242-
type = StorageType.file
242+
type: StorageType = StorageType.file
243243
base_dir: str = DEFAULT_OUTPUT_BASE_DIR
244244
connection_string: None = None
245245
container_name: None = None

tests/integration/storage/test_factory.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -103,41 +103,50 @@ def test_create_unknown_storage():
103103

104104
def test_register_class_directly_raises_error():
105105
"""Test that registering a class directly raises a TypeError."""
106+
import re
107+
from collections.abc import Iterator
108+
from typing import Any
109+
106110
from graphrag.storage.pipeline_storage import PipelineStorage
107111

108112
class CustomStorage(PipelineStorage):
109113
def __init__(self, **kwargs):
110114
pass
111115

112-
def child_exists(self, name: str) -> bool:
113-
return False
114-
115-
def create_child(self, name: str) -> PipelineStorage:
116-
return self
117-
118-
def delete_child(self, name: str) -> None:
119-
pass
120-
121-
def list_children(self) -> list[str]:
122-
return []
123-
124-
def get(self, key: str) -> bytes | None:
116+
def find(
117+
self,
118+
file_pattern: re.Pattern[str],
119+
base_dir: str | None = None,
120+
file_filter: dict[str, Any] | None = None,
121+
max_count=-1,
122+
) -> Iterator[tuple[str, dict[str, Any]]]:
123+
return iter([])
124+
125+
async def get(
126+
self, key: str, as_bytes: bool | None = None, encoding: str | None = None
127+
) -> Any:
125128
return None
126129

127-
def set(self, key: str, value: bytes) -> None:
130+
async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
128131
pass
129132

130-
def delete(self, key: str) -> None:
133+
async def delete(self, key: str) -> None:
131134
pass
132135

133-
def has(self, key: str) -> bool:
136+
async def has(self, key: str) -> bool:
134137
return False
135138

136-
def list(self) -> list[str]:
139+
async def clear(self) -> None:
140+
pass
141+
142+
def child(self, name: str | None) -> "PipelineStorage":
143+
return self
144+
145+
def keys(self) -> list[str]:
137146
return []
138147

139-
def clear(self) -> None:
140-
pass
148+
async def get_creation_date(self, key: str) -> str:
149+
return "2024-01-01 00:00:00 +0000"
141150

142151
# Attempting to register a class directly should raise TypeError
143152
with pytest.raises(

0 commit comments

Comments
 (0)