11from abc import ABC , abstractmethod
22from cachetools import TTLCache
3-
3+ from azure .cosmos import CosmosClient , exceptions
4+ from azure .identity import DefaultAzureCredential
45
56class StateStore (ABC ):
67 @abstractmethod
7- def get_state (self , thread_id ) :
8+ def get_state (self , thread_id : str ) -> dict :
89 pass
910
1011 @abstractmethod
11- def save_state (self , thread_id , state ) :
12+ def save_state (self , thread_id : str , state : dict ) -> None :
1213 pass
1314
1415
@@ -21,3 +22,39 @@ def get_state(self, thread_id: str) -> dict:
2122
2223 def save_state (self , thread_id : str , state : dict ) -> None :
2324 self .cache [thread_id ] = state
25+
26+
27+ class CosmosStateStore (StateStore ):
28+ def __init__ (self , endpoint , database , container , partition_key = None ):
29+ client = CosmosClient (
30+ url = endpoint ,
31+ credential = DefaultAzureCredential (),
32+ )
33+ database_client = client .get_database_client (database )
34+ self ._db = database_client .get_container_client (container )
35+ self .partition_key = partition_key
36+
37+ # Set partition key field name
38+ props = self ._db .read ()
39+ pk_paths = props ["partitionKey" ]["paths" ]
40+ if (len (pk_paths ) != 1 ):
41+ raise ValueError ("Only single partition key is supported" )
42+ self .partition_key_name = pk_paths [0 ].lstrip ("/" )
43+ if ("/" in self .partition_key_name ):
44+ raise ValueError ("Only top-level partition key is supported" )
45+
46+ def get_state (self , thread_id : str ) -> dict :
47+ try :
48+ item = self ._db .read_item (item = thread_id , partition_key = self .partition_key )
49+ return item ["state" ]
50+ except exceptions .CosmosResourceNotFoundError :
51+ return None
52+
53+ def save_state (self , thread_id : str , state : dict ) -> None :
54+ self ._db .upsert_item (
55+ body = {
56+ self .partition_key_name : self .partition_key ,
57+ "id" : thread_id ,
58+ "state" : state ,
59+ }
60+ )
0 commit comments