2525
2626import asyncio
2727import json
28+ import random
2829import time
29- from typing import Any , Literal
30+ from typing import Any , Final , Literal
3031
3132try :
3233 from dapr .aio .clients import DaprClient
33- from dapr .clients .grpc ._state import Consistency , StateOptions
34+ from dapr .clients .grpc ._state import Concurrency , Consistency , StateOptions
3435except ImportError as e :
3536 raise ImportError (
3637 "DaprSession requires the 'dapr' package. Install it with: pip install dapr"
4748DAPR_CONSISTENCY_EVENTUAL : ConsistencyLevel = "eventual"
4849DAPR_CONSISTENCY_STRONG : ConsistencyLevel = "strong"
4950
51+ _MAX_WRITE_ATTEMPTS : Final [int ] = 5
52+ _RETRY_BASE_DELAY_SECONDS : Final [float ] = 0.05
53+ _RETRY_MAX_DELAY_SECONDS : Final [float ] = 1.0
54+
5055
5156class DaprSession (SessionABC ):
5257 """Dapr State Store implementation of :pyclass:`agents.memory.session.Session`."""
@@ -130,12 +135,17 @@ def _get_read_metadata(self) -> dict[str, str]:
130135 metadata ["consistency" ] = self ._consistency
131136 return metadata
132137
133- def _get_state_options (self ) -> StateOptions | None :
134- """Get StateOptions for write/delete consistency level."""
138+ def _get_state_options (self , * , concurrency : Concurrency | None = None ) -> StateOptions | None :
139+ """Get StateOptions configured with consistency and optional concurrency."""
140+ options_kwargs : dict [str , Any ] = {}
135141 if self ._consistency == DAPR_CONSISTENCY_STRONG :
136- return StateOptions ( consistency = Consistency .strong )
142+ options_kwargs [ " consistency" ] = Consistency .strong
137143 elif self ._consistency == DAPR_CONSISTENCY_EVENTUAL :
138- return StateOptions (consistency = Consistency .eventual )
144+ options_kwargs ["consistency" ] = Consistency .eventual
145+ if concurrency is not None :
146+ options_kwargs ["concurrency" ] = concurrency
147+ if options_kwargs :
148+ return StateOptions (** options_kwargs )
139149 return None
140150
141151 def _get_metadata (self ) -> dict [str , str ]:
@@ -153,6 +163,57 @@ async def _deserialize_item(self, item: str) -> TResponseInputItem:
153163 """Deserialize a JSON string to an item. Can be overridden by subclasses."""
154164 return json .loads (item ) # type: ignore[no-any-return]
155165
166+ def _decode_messages (self , data : bytes | None ) -> list [Any ]:
167+ if not data :
168+ return []
169+ try :
170+ messages_json = data .decode ("utf-8" )
171+ messages = json .loads (messages_json )
172+ if isinstance (messages , list ):
173+ return list (messages )
174+ except (json .JSONDecodeError , UnicodeDecodeError ):
175+ return []
176+ return []
177+
178+ def _calculate_retry_delay (self , attempt : int ) -> float :
179+ base : float = _RETRY_BASE_DELAY_SECONDS * (2 ** max (0 , attempt - 1 ))
180+ delay : float = min (base , _RETRY_MAX_DELAY_SECONDS )
181+ # Add jitter (10%) similar to tracing processors to avoid thundering herd.
182+ return delay + random .uniform (0 , 0.1 * delay )
183+
184+ def _is_concurrency_conflict (self , error : Exception ) -> bool :
185+ code_attr = getattr (error , "code" , None )
186+ if callable (code_attr ):
187+ try :
188+ status_code = code_attr ()
189+ except Exception :
190+ status_code = None
191+ if status_code is not None :
192+ status_name = getattr (status_code , "name" , str (status_code ))
193+ if status_name in {"ABORTED" , "FAILED_PRECONDITION" }:
194+ return True
195+ message = str (error ).lower ()
196+ conflict_markers = (
197+ "etag mismatch" ,
198+ "etag does not match" ,
199+ "precondition failed" ,
200+ "concurrency conflict" ,
201+ "invalid etag" ,
202+ "failed to set key" , # Redis state store Lua script error during conditional write
203+ "user_script" , # Redis script failure hint
204+ )
205+ return any (marker in message for marker in conflict_markers )
206+
207+ async def _handle_concurrency_conflict (self , error : Exception , attempt : int ) -> bool :
208+ if not self ._is_concurrency_conflict (error ):
209+ return False
210+ if attempt >= _MAX_WRITE_ATTEMPTS :
211+ return False
212+ delay = self ._calculate_retry_delay (attempt )
213+ if delay > 0 :
214+ await asyncio .sleep (delay )
215+ return True
216+
156217 # ------------------------------------------------------------------
157218 # Session protocol implementation
158219 # ------------------------------------------------------------------
@@ -175,41 +236,24 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
175236 state_metadata = self ._get_read_metadata (),
176237 )
177238
178- if not response .data :
239+ messages = self ._decode_messages (response .data )
240+ if not messages :
179241 return []
180-
181- try :
182- # Parse the messages list from JSON
183- messages_json = response .data .decode ("utf-8" )
184- messages = json .loads (messages_json )
185-
186- if not isinstance (messages , list ):
242+ if limit is not None :
243+ if limit <= 0 :
187244 return []
188-
189- # Apply limit if specified
190- if limit is not None :
191- if limit <= 0 :
192- return []
193- # Return the latest N items
194- messages = messages [- limit :]
195-
196- items : list [TResponseInputItem ] = []
197- for msg in messages :
198- try :
199- if isinstance (msg , str ):
200- item = await self ._deserialize_item (msg )
201- else :
202- item = msg # Already deserialized
203- items .append (item )
204- except (json .JSONDecodeError , TypeError ):
205- # Skip corrupted messages
206- continue
207-
208- return items
209-
210- except (json .JSONDecodeError , UnicodeDecodeError ):
211- # Return empty list for corrupted data
212- return []
245+ messages = messages [- limit :]
246+ items : list [TResponseInputItem ] = []
247+ for msg in messages :
248+ try :
249+ if isinstance (msg , str ):
250+ item = await self ._deserialize_item (msg )
251+ else :
252+ item = msg
253+ items .append (item )
254+ except (json .JSONDecodeError , TypeError ):
255+ continue
256+ return items
213257
214258 async def add_items (self , items : list [TResponseInputItem ]) -> None :
215259 """Add new items to the conversation history.
@@ -221,38 +265,34 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
221265 return
222266
223267 async with self ._lock :
224- # Get existing messages with consistency level
225- response = await self ._dapr_client .get_state (
226- store_name = self ._state_store_name ,
227- key = self ._messages_key ,
228- state_metadata = self ._get_read_metadata (),
229- )
230-
231- # Parse existing messages
232- existing_messages = []
233- if response .data :
268+ serialized_items : list [str ] = [await self ._serialize_item (item ) for item in items ]
269+ attempt = 0
270+ while True :
271+ attempt += 1
272+ response = await self ._dapr_client .get_state (
273+ store_name = self ._state_store_name ,
274+ key = self ._messages_key ,
275+ state_metadata = self ._get_read_metadata (),
276+ )
277+ existing_messages = self ._decode_messages (response .data )
278+ updated_messages = existing_messages + serialized_items
279+ messages_json = json .dumps (updated_messages , separators = ("," , ":" ))
280+ etag = response .etag
234281 try :
235- messages_json = response .data .decode ("utf-8" )
236- existing_messages = json .loads (messages_json )
237- if not isinstance (existing_messages , list ):
238- existing_messages = []
239- except (json .JSONDecodeError , UnicodeDecodeError ):
240- existing_messages = []
241-
242- # Serialize and append new items
243- for item in items :
244- serialized = await self ._serialize_item (item )
245- existing_messages .append (serialized )
246-
247- # Save updated messages list
248- messages_json = json .dumps (existing_messages , separators = ("," , ":" ))
249- await self ._dapr_client .save_state (
250- store_name = self ._state_store_name ,
251- key = self ._messages_key ,
252- value = messages_json ,
253- state_metadata = self ._get_metadata (),
254- options = self ._get_state_options (),
255- )
282+ await self ._dapr_client .save_state (
283+ store_name = self ._state_store_name ,
284+ key = self ._messages_key ,
285+ value = messages_json ,
286+ etag = etag ,
287+ state_metadata = self ._get_metadata (),
288+ options = self ._get_state_options (concurrency = Concurrency .first_write ),
289+ )
290+ break
291+ except Exception as error :
292+ should_retry = await self ._handle_concurrency_conflict (error , attempt )
293+ if should_retry :
294+ continue
295+ raise
256296
257297 # Update metadata
258298 metadata = {
@@ -275,45 +315,41 @@ async def pop_item(self) -> TResponseInputItem | None:
275315 The most recent item if it exists, None if the session is empty
276316 """
277317 async with self ._lock :
278- # Get messages from state store with consistency level
279- response = await self ._dapr_client .get_state (
280- store_name = self ._state_store_name ,
281- key = self ._messages_key ,
282- state_metadata = self ._get_read_metadata (),
283- )
284-
285- if not response .data :
286- return None
287-
288- try :
289- # Parse the messages list
290- messages_json = response .data .decode ("utf-8" )
291- messages = json .loads (messages_json )
292-
293- if not isinstance (messages , list ) or len (messages ) == 0 :
294- return None
295-
296- # Pop the last item
297- last_item = messages .pop ()
298-
299- # Save updated messages list
300- messages_json = json .dumps (messages , separators = ("," , ":" ))
301- await self ._dapr_client .save_state (
318+ attempt = 0
319+ while True :
320+ attempt += 1
321+ response = await self ._dapr_client .get_state (
302322 store_name = self ._state_store_name ,
303323 key = self ._messages_key ,
304- value = messages_json ,
305- state_metadata = self ._get_metadata (),
306- options = self ._get_state_options (),
324+ state_metadata = self ._get_read_metadata (),
307325 )
308-
309- # Deserialize and return the item
326+ messages = self ._decode_messages (response .data )
327+ if not messages :
328+ return None
329+ last_item = messages .pop ()
330+ messages_json = json .dumps (messages , separators = ("," , ":" ))
331+ etag = getattr (response , "etag" , None ) or None
332+ etag = getattr (response , "etag" , None ) or None
333+ try :
334+ await self ._dapr_client .save_state (
335+ store_name = self ._state_store_name ,
336+ key = self ._messages_key ,
337+ value = messages_json ,
338+ etag = etag ,
339+ state_metadata = self ._get_metadata (),
340+ options = self ._get_state_options (concurrency = Concurrency .first_write ),
341+ )
342+ break
343+ except Exception as error :
344+ should_retry = await self ._handle_concurrency_conflict (error , attempt )
345+ if should_retry :
346+ continue
347+ raise
348+ try :
310349 if isinstance (last_item , str ):
311350 return await self ._deserialize_item (last_item )
312- else :
313- return last_item # type: ignore[no-any-return]
314-
315- except (json .JSONDecodeError , UnicodeDecodeError , TypeError ):
316- # Return None for corrupted data
351+ return last_item # type: ignore[no-any-return]
352+ except (json .JSONDecodeError , TypeError ):
317353 return None
318354
319355 async def clear_session (self ) -> None :
0 commit comments