Skip to content

Commit a059333

Browse files
committed
cleanup code
1 parent dbd6737 commit a059333

File tree

1 file changed

+30
-40
lines changed

1 file changed

+30
-40
lines changed

graphrag/storage/cosmosdb_pipeline_storage.py

Lines changed: 30 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ def __init__(
3636
cosmosdb_account_url: str | None,
3737
connection_string: str | None,
3838
database_name: str,
39-
encoding: str | None = None,
39+
encoding: str = "utf-8",
4040
current_container: str | None = None,
4141
):
42-
"""Initialize the CosmosDB-Storage."""
42+
"""Initialize the CosmosDB Storage."""
4343
if connection_string:
4444
self._cosmos_client = CosmosClient.from_connection_string(connection_string)
4545
else:
@@ -54,11 +54,11 @@ def __init__(
5454
credential=DefaultAzureCredential(),
5555
)
5656

57-
self._encoding = encoding or "utf-8"
57+
self._encoding = encoding
5858
self._database_name = database_name
5959
self._connection_string = connection_string
6060
self._cosmosdb_account_url = cosmosdb_account_url
61-
self._current_container = current_container or None
61+
self._current_container = current_container
6262
self._cosmosdb_account_name = (
6363
cosmosdb_account_url.split("//")[1].split(".")[0]
6464
if cosmosdb_account_url
@@ -73,27 +73,25 @@ def __init__(
7373
self._cosmosdb_account_name,
7474
self._database_name,
7575
)
76-
self.create_database()
77-
if self._current_container is not None:
78-
self.create_container()
76+
self._create_database()
77+
if self._current_container:
78+
self._create_container()
7979

80-
def create_database(self) -> None:
80+
def _create_database(self) -> None:
8181
"""Create the database if it doesn't exist."""
82-
database_name = self._database_name
83-
self._cosmos_client.create_database_if_not_exists(id=database_name)
82+
self._cosmos_client.create_database_if_not_exists(id=self._database_name)
8483

85-
def delete_database(self) -> None:
84+
def _delete_database(self) -> None:
8685
"""Delete the database if it exists."""
87-
if self.database_exists():
86+
if self._database_exists():
8887
self._cosmos_client.delete_database(self._database_name)
8988

90-
def database_exists(self) -> bool:
89+
def _database_exists(self) -> bool:
9190
"""Check if the database exists."""
92-
database_name = self._database_name
9391
database_names = [
9492
database["id"] for database in self._cosmos_client.list_databases()
9593
]
96-
return database_name in database_names
94+
return self._database_name in database_names
9795

9896
def find(
9997
self,
@@ -131,8 +129,7 @@ def item_filter(item: dict[str, Any]) -> bool:
131129
)
132130

133131
try:
134-
database_client = self._database_client
135-
container_client = database_client.get_container_client(
132+
container_client = self._database_client.get_container_client(
136133
str(self._current_container)
137134
)
138135
query = "SELECT * FROM c WHERE CONTAINS(c.id, @pattern)"
@@ -176,9 +173,8 @@ async def get(
176173
) -> Any:
177174
"""Get a file in the database for the given key."""
178175
try:
179-
database_client = self._database_client
180-
if self._current_container is not None:
181-
container_client = database_client.get_container_client(
176+
if self._current_container:
177+
container_client = self._database_client.get_container_client(
182178
self._current_container
183179
)
184180
item = container_client.read_item(item=key, partition_key=key)
@@ -199,9 +195,8 @@ async def get(
199195
async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
200196
"""Set a file in the database for the given key."""
201197
try:
202-
database_client = self._database_client
203-
if self._current_container is not None:
204-
container_client = database_client.get_container_client(
198+
if self._current_container:
199+
container_client = self._database_client.get_container_client(
205200
self._current_container
206201
)
207202
if isinstance(value, bytes):
@@ -222,9 +217,8 @@ async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
222217

223218
async def has(self, key: str) -> bool:
224219
"""Check if the given file exists in the cosmosdb storage."""
225-
database_client = self._database_client
226-
if self._current_container is not None:
227-
container_client = database_client.get_container_client(
220+
if self._current_container:
221+
container_client = self._database_client.get_container_client(
228222
self._current_container
229223
)
230224
item_names = [item["id"] for item in container_client.read_all_items()]
@@ -233,9 +227,8 @@ async def has(self, key: str) -> bool:
233227

234228
async def delete(self, key: str) -> None:
235229
"""Delete the given file from the cosmosdb storage."""
236-
database_client = self._database_client
237-
if self._current_container is not None:
238-
container_client = database_client.get_container_client(
230+
if self._current_container:
231+
container_client = self._database_client.get_container_client(
239232
self._current_container
240233
)
241234
container_client.delete_item(item=key, partition_key=key)
@@ -252,27 +245,24 @@ def child(self, name: str | None) -> PipelineStorage:
252245
"""Create a child storage instance."""
253246
return self
254247

255-
def create_container(self) -> None:
248+
def _create_container(self) -> None:
256249
"""Create a container for the current container name if it doesn't exist."""
257-
database_client = self._database_client
258-
if self._current_container is not None:
250+
if self._current_container:
259251
partition_key = PartitionKey(path="/id", kind="Hash")
260-
database_client.create_container_if_not_exists(
252+
self._database_client.create_container_if_not_exists(
261253
id=self._current_container,
262254
partition_key=partition_key,
263255
)
264256

265-
def delete_container(self) -> None:
257+
def _delete_container(self) -> None:
266258
"""Delete the container with the current container name if it exists."""
267-
database_client = self._database_client
268-
if self.container_exists() and self._current_container is not None:
269-
database_client.delete_container(self._current_container)
259+
if self._container_exists() and self._current_container:
260+
self._database_client.delete_container(self._current_container)
270261

271-
def container_exists(self) -> bool:
262+
def _container_exists(self) -> bool:
272263
"""Check if the container with the current container name exists."""
273-
database_client = self._database_client
274264
container_names = [
275-
container["id"] for container in database_client.list_containers()
265+
container["id"] for container in self._database_client.list_containers()
276266
]
277267
return self._current_container in container_names
278268

0 commit comments

Comments
 (0)