Skip to content

Commit c963bf0

Browse files
Kristian NylundKristian Nylund
authored andcommitted
Added Cosmos state store
1 parent 9883baf commit c963bf0

File tree

3 files changed

+59
-3
lines changed

3 files changed

+59
-3
lines changed

text_2_sql/autogen/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ dependencies = [
1818
"sqlparse>=0.4.4",
1919
"nltk>=3.8.1",
2020
"cachetools>=5.5.1",
21+
"azure-cosmos>=4.9.0",
22+
"azure-identity>=1.19.0",
2123
]
2224

2325
[dependency-groups]

text_2_sql/autogen/src/autogen_text_2_sql/state_store.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
from abc import ABC, abstractmethod
22
from cachetools import TTLCache
3-
3+
from azure.cosmos import CosmosClient, exceptions
4+
from azure.identity import DefaultAzureCredential
45

56
class StateStore(ABC):
67
@abstractmethod
7-
def get_state(self, thread_id):
8+
def get_state(self, thread_id: str) -> dict:
89
pass
910

1011
@abstractmethod
11-
def save_state(self, thread_id, state):
12+
def save_state(self, thread_id: str, state: dict) -> None:
1213
pass
1314

1415

@@ -21,3 +22,39 @@ def get_state(self, thread_id: str) -> dict:
2122

2223
def save_state(self, thread_id: str, state: dict) -> None:
2324
self.cache[thread_id] = state
25+
26+
27+
class CosmosStateStore(StateStore):
28+
def __init__(self, endpoint, database, container, partition_key = None):
29+
client = CosmosClient(
30+
url=endpoint,
31+
credential=DefaultAzureCredential(),
32+
)
33+
database_client = client.get_database_client(database)
34+
self._db = database_client.get_container_client(container)
35+
self.partition_key = partition_key
36+
37+
# Set partition key field name
38+
props = self._db.read()
39+
pk_paths = props["partitionKey"]["paths"]
40+
if (len(pk_paths) != 1):
41+
raise ValueError("Only single partition key is supported")
42+
self.partition_key_name = pk_paths[0].lstrip("/")
43+
if ("/" in self.partition_key_name):
44+
raise ValueError("Only top-level partition key is supported")
45+
46+
def get_state(self, thread_id: str) -> dict:
47+
try:
48+
item = self._db.read_item(item=thread_id, partition_key=self.partition_key)
49+
return item["state"]
50+
except exceptions.CosmosResourceNotFoundError:
51+
return None
52+
53+
def save_state(self, thread_id: str, state: dict) -> None:
54+
self._db.upsert_item(
55+
body={
56+
self.partition_key_name: self.partition_key,
57+
"id": thread_id,
58+
"state": state,
59+
}
60+
)

uv.lock

Lines changed: 17 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)