Skip to content

Commit 4404668

Browse files
authored
Add graphrag-storage. (#2127)
* Add graphrag-storage.
1 parent 20a96cb commit 4404668

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+793
-552
lines changed

docs/config/yaml.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ Our pipeline can ingest .csv, .txt, or .json data from an input location. See th
8181
#### Fields
8282

8383
- `storage` **StorageConfig**
84-
- `type` **file|blob|cosmosdb** - The storage type to use. Default=`file`
84+
- `type` **file|memory|blob|cosmosdb** - The storage type to use. Default=`file`
8585
- `base_dir` **str** - The base directory to write output artifacts to, relative to the root.
8686
- `connection_string` **str** - (blob/cosmosdb only) The Azure Storage connection string.
8787
- `container_name` **str** - (blob/cosmosdb only) The Azure Storage container name.

packages/graphrag-common/graphrag_common/factory/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@
33

44
"""The GraphRAG factory module."""
55

6-
from graphrag_common.factory.factory import Factory
6+
from graphrag_common.factory.factory import Factory, ServiceScope
77

8-
__all__ = ["Factory"]
8+
__all__ = ["Factory", "ServiceScope"]

packages/graphrag-common/graphrag_common/factory/factory.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,14 @@ def create(self, strategy: str, init_args: dict[str, Any] | None = None) -> T:
8484
msg = f"Strategy '{strategy}' is not registered. Registered strategies are: {', '.join(list(self._service_initializers.keys()))}"
8585
raise ValueError(msg)
8686

87+
# Delete entries with value None
88+
init_args = {k: v for k, v in (init_args or {}).items() if v is not None}
89+
8790
service_descriptor = self._service_initializers[strategy]
8891
if service_descriptor.scope == "singleton":
8992
if strategy not in self._initialized_services:
9093
self._initialized_services[strategy] = service_descriptor.initializer(
91-
**(init_args or {})
94+
**init_args
9295
)
9396
return self._initialized_services[strategy]
9497

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# GraphRAG Storage
2+
3+
## Basic
4+
5+
```python
6+
import asyncio
7+
from graphrag_storage import StorageConfig, create_storage, StorageType
8+
9+
async def run():
10+
storage = create_storage(
11+
StorageConfig(
12+
type=StorageType.File
13+
base_dir="output"
14+
)
15+
)
16+
17+
await storage.set("my_key", "value")
18+
print(await storage.get("my_key"))
19+
20+
if __name__ == "__main__":
21+
asyncio.run(run())
22+
```
23+
24+
## Custom Storage
25+
26+
```python
27+
import asyncio
28+
from typing import Any
29+
from graphrag_storage import Storage, StorageConfig, create_storage, register_storage
30+
31+
class MyStorage(Storage):
32+
def __init__(self, some_setting: str, optional_setting: str = "default setting", **kwargs: Any):
33+
# Validate settings and initialize
34+
...
35+
36+
#Implement rest of interface
37+
...
38+
39+
register_storage("MyStorage", MyStorage)
40+
41+
async def run():
42+
storage = create_storage(
43+
StorageConfig(
44+
type="MyStorage"
45+
some_setting="My Setting"
46+
)
47+
)
48+
# Or use the factory directly to instantiate with a dict instead of using
49+
# StorageConfig + create_factory
50+
# from graphrag_storage.storage_factory import storage_factory
51+
# storage = storage_factory.create(strategy="MyStorage", init_args={"some_setting": "My Setting"})
52+
53+
await storage.set("my_key", "value")
54+
print(await storage.get("my_key"))
55+
56+
if __name__ == "__main__":
57+
asyncio.run(run())
58+
```
59+
60+
### Details
61+
62+
By default, the `create_storage` comes with the following storage providers registered that correspond to the entries in the `StorageType` enum.
63+
64+
- `FileStorage`
65+
- `AzureBlobStorage`
66+
- `AzureCosmosStorage`
67+
- `MemoryStorage`
68+
69+
The preregistration happens dynamically, e.g., `FileStorage` is only imported and registered if you request a `FileStorage` with `create_storage(StorageType.File, ...)`. There is no need to manually import and register builtin storage providers when using `create_storage`.
70+
71+
If you want a clean factory with no preregistered storage providers then directly import `storage_factory` and bypass using `create_storage`. The downside is that `storage_factory.create` uses a dict for init args instead of the strongly typed `StorageConfig` used with `create_storage`.
72+
73+
```python
74+
from graphrag_storage.storage_factory import storage_factory
75+
from graphrag_storage.file_storage import FileStorage
76+
77+
# storage_factory has no preregistered providers so you must register any
78+
# providers you plan on using.
79+
# May also register a custom implementation, see above for example.
80+
storage_factory.register("my_storage_key", FileStorage)
81+
82+
storage = storage_factory.create(strategy="my_storage_key", init_args={"base_dir": "...", "other_settings": "..."})
83+
84+
...
85+
86+
```
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright (c) 2024 Microsoft Corporation.
2+
# Licensed under the MIT License
3+
4+
"""The GraphRAG Storage package."""
5+
6+
from graphrag_storage.storage import Storage
7+
from graphrag_storage.storage_config import StorageConfig
8+
from graphrag_storage.storage_factory import (
9+
create_storage,
10+
register_storage,
11+
)
12+
from graphrag_storage.storage_type import StorageType
13+
14+
__all__ = [
15+
"Storage",
16+
"StorageConfig",
17+
"StorageType",
18+
"create_storage",
19+
"register_storage",
20+
]

packages/graphrag/graphrag/storage/blob_pipeline_storage.py renamed to packages/graphrag-storage/graphrag_storage/azure_blob_storage.py

Lines changed: 41 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Copyright (c) 2024 Microsoft Corporation.
22
# Licensed under the MIT License
33

4-
"""Azure Blob Storage implementation of PipelineStorage."""
4+
"""Azure Blob Storage implementation of Storage."""
55

66
import logging
77
import re
@@ -12,61 +12,68 @@
1212
from azure.identity import DefaultAzureCredential
1313
from azure.storage.blob import BlobServiceClient
1414

15-
from graphrag.storage.pipeline_storage import (
16-
PipelineStorage,
15+
from graphrag_storage.storage import (
16+
Storage,
1717
get_timestamp_formatted_with_local_tz,
1818
)
1919

2020
logger = logging.getLogger(__name__)
2121

2222

23-
class BlobPipelineStorage(PipelineStorage):
23+
class AzureBlobStorage(Storage):
2424
"""The Blob-Storage implementation."""
2525

2626
_connection_string: str | None
2727
_container_name: str
2828
_base_dir: str | None
2929
_encoding: str
3030
_storage_account_blob_url: str | None
31+
_blob_service_client: BlobServiceClient
32+
_storage_account_name: str | None
3133

32-
def __init__(self, **kwargs: Any) -> None:
34+
def __init__(
35+
self,
36+
container_name: str,
37+
account_url: str | None = None,
38+
connection_string: str | None = None,
39+
base_dir: str | None = None,
40+
encoding: str = "utf-8",
41+
**kwargs: Any,
42+
) -> None:
3343
"""Create a new BlobStorage instance."""
34-
connection_string = kwargs.get("connection_string")
35-
storage_account_blob_url = kwargs.get("storage_account_blob_url")
36-
base_dir = kwargs.get("base_dir")
37-
container_name = kwargs["container_name"]
38-
if container_name is None:
39-
msg = "No container name provided for blob storage."
40-
raise ValueError(msg)
41-
if connection_string is None and storage_account_blob_url is None:
42-
msg = "No storage account blob url provided for blob storage."
44+
if connection_string is not None and account_url is not None:
45+
msg = "AzureBlobStorage requires only one of connection_string or storage_account_blob_url to be specified, not both."
46+
logger.error(msg)
4347
raise ValueError(msg)
4448

49+
_validate_blob_container_name(container_name)
50+
4551
logger.info(
46-
"Creating blob storage at [%s] and base_dir [%s]", container_name, base_dir
52+
"Creating blob storage at [%s] and base_dir [%s]",
53+
container_name,
54+
base_dir,
4755
)
4856
if connection_string:
4957
self._blob_service_client = BlobServiceClient.from_connection_string(
5058
connection_string
5159
)
52-
else:
53-
if storage_account_blob_url is None:
54-
msg = "Either connection_string or storage_account_blob_url must be provided."
55-
raise ValueError(msg)
56-
60+
elif account_url:
5761
self._blob_service_client = BlobServiceClient(
58-
account_url=storage_account_blob_url,
62+
account_url=account_url,
5963
credential=DefaultAzureCredential(),
6064
)
61-
self._encoding = kwargs.get("encoding", "utf-8")
65+
else:
66+
msg = "AzureBlobStorage requires either a connection_string or storage_account_blob_url to be specified."
67+
logger.error(msg)
68+
raise ValueError(msg)
69+
70+
self._encoding = encoding
6271
self._container_name = container_name
6372
self._connection_string = connection_string
6473
self._base_dir = base_dir
65-
self._storage_account_blob_url = storage_account_blob_url
74+
self._storage_account_blob_url = account_url
6675
self._storage_account_name = (
67-
storage_account_blob_url.split("//")[1].split(".")[0]
68-
if storage_account_blob_url
69-
else None
76+
account_url.split("//")[1].split(".")[0] if account_url else None
7077
)
7178
self._create_container()
7279

@@ -208,17 +215,17 @@ async def delete(self, key: str) -> None:
208215
async def clear(self) -> None:
209216
"""Clear the cache."""
210217

211-
def child(self, name: str | None) -> "PipelineStorage":
218+
def child(self, name: str | None) -> "Storage":
212219
"""Create a child storage instance."""
213220
if name is None:
214221
return self
215222
path = str(Path(self._base_dir) / name) if self._base_dir else name
216-
return BlobPipelineStorage(
223+
return AzureBlobStorage(
217224
connection_string=self._connection_string,
218225
container_name=self._container_name,
219226
encoding=self._encoding,
220227
base_dir=path,
221-
storage_account_blob_url=self._storage_account_blob_url,
228+
account_url=self._storage_account_blob_url,
222229
)
223230

224231
def keys(self) -> list[str]:
@@ -245,7 +252,7 @@ async def get_creation_date(self, key: str) -> str:
245252
return ""
246253

247254

248-
def validate_blob_container_name(container_name: str):
255+
def _validate_blob_container_name(container_name: str) -> None:
249256
"""
250257
Check if the provided blob container name is valid based on Azure rules.
251258
@@ -265,34 +272,7 @@ def validate_blob_container_name(container_name: str):
265272
-------
266273
bool: True if valid, False otherwise.
267274
"""
268-
# Check the length of the name
269-
if len(container_name) < 3 or len(container_name) > 63:
270-
return ValueError(
271-
f"Container name must be between 3 and 63 characters in length. Name provided was {len(container_name)} characters long."
272-
)
273-
274-
# Check if the name starts with a letter or number
275-
if not container_name[0].isalnum():
276-
return ValueError(
277-
f"Container name must start with a letter or number. Starting character was {container_name[0]}."
278-
)
279-
280-
# Check for valid characters (letters, numbers, hyphen) and lowercase letters
281-
if not re.match(r"^[a-z0-9-]+$", container_name):
282-
return ValueError(
283-
f"Container name must only contain:\n- lowercase letters\n- numbers\n- or hyphens\nName provided was {container_name}."
284-
)
285-
286-
# Check for consecutive hyphens
287-
if "--" in container_name:
288-
return ValueError(
289-
f"Container name cannot contain consecutive hyphens. Name provided was {container_name}."
290-
)
291-
292-
# Check for hyphens at the end of the name
293-
if container_name[-1] == "-":
294-
return ValueError(
295-
f"Container name cannot end with a hyphen. Name provided was {container_name}."
296-
)
297-
298-
return True
275+
# Match alphanumeric or single hyphen not at the start or end, repeated 3-63 times.
276+
if not re.match(r"^(?:[0-9a-z]|(?<!^)-(?!$)){3,63}$", container_name):
277+
msg = f"Container name must be between 3 and 63 characters long and contain only lowercase letters, numbers, or hyphens. Name provided was {container_name}."
278+
raise ValueError(msg)

0 commit comments

Comments
 (0)