Skip to content

Commit 72dd9eb

Browse files
committed
Tackled feedback:
- remove the warning from testcontainers by pinning to the previous version without this behavior - add optimistic locking to deal with concurrent saves from different processes Signed-off-by: Filinto Duran <[email protected]>
1 parent 80fd3e9 commit 72dd9eb

File tree

6 files changed

+383
-165
lines changed

6 files changed

+383
-165
lines changed

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,12 @@ make lint # run linter
306306
make format-check # run style checker
307307
```
308308

309+
If `make format-check` fails above, format the code by running:
310+
311+
```
312+
make format
313+
```
314+
309315
## Acknowledgements
310316

311317
We'd like to acknowledge the excellent work of the open-source community, especially:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ dev = [
7272
"fakeredis>=2.31.3",
7373
"dapr>=1.14.0",
7474
"grpcio>=1.60.0",
75-
"testcontainers[redis]>=4.0.0",
75+
"testcontainers==4.12.0", # pinned to 4.12.0 because 4.13.0 has a warning bug in wait_for_logs, see https://github.com/testcontainers/testcontainers-python/issues/874
7676
]
7777

7878
[tool.uv.workspace]

src/agents/extensions/memory/dapr_session.py

Lines changed: 140 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,13 @@
2525

2626
import asyncio
2727
import json
28+
import random
2829
import time
29-
from typing import Any, Literal
30+
from typing import Any, Final, Literal
3031

3132
try:
3233
from dapr.aio.clients import DaprClient
33-
from dapr.clients.grpc._state import Consistency, StateOptions
34+
from dapr.clients.grpc._state import Concurrency, Consistency, StateOptions
3435
except ImportError as e:
3536
raise ImportError(
3637
"DaprSession requires the 'dapr' package. Install it with: pip install dapr"
@@ -47,6 +48,10 @@
4748
DAPR_CONSISTENCY_EVENTUAL: ConsistencyLevel = "eventual"
4849
DAPR_CONSISTENCY_STRONG: ConsistencyLevel = "strong"
4950

51+
_MAX_WRITE_ATTEMPTS: Final[int] = 5
52+
_RETRY_BASE_DELAY_SECONDS: Final[float] = 0.05
53+
_RETRY_MAX_DELAY_SECONDS: Final[float] = 1.0
54+
5055

5156
class DaprSession(SessionABC):
5257
"""Dapr State Store implementation of :pyclass:`agents.memory.session.Session`."""
@@ -130,12 +135,17 @@ def _get_read_metadata(self) -> dict[str, str]:
130135
metadata["consistency"] = self._consistency
131136
return metadata
132137

133-
def _get_state_options(self) -> StateOptions | None:
134-
"""Get StateOptions for write/delete consistency level."""
138+
def _get_state_options(self, *, concurrency: Concurrency | None = None) -> StateOptions | None:
139+
"""Get StateOptions configured with consistency and optional concurrency."""
140+
options_kwargs: dict[str, Any] = {}
135141
if self._consistency == DAPR_CONSISTENCY_STRONG:
136-
return StateOptions(consistency=Consistency.strong)
142+
options_kwargs["consistency"] = Consistency.strong
137143
elif self._consistency == DAPR_CONSISTENCY_EVENTUAL:
138-
return StateOptions(consistency=Consistency.eventual)
144+
options_kwargs["consistency"] = Consistency.eventual
145+
if concurrency is not None:
146+
options_kwargs["concurrency"] = concurrency
147+
if options_kwargs:
148+
return StateOptions(**options_kwargs)
139149
return None
140150

141151
def _get_metadata(self) -> dict[str, str]:
@@ -153,6 +163,57 @@ async def _deserialize_item(self, item: str) -> TResponseInputItem:
153163
"""Deserialize a JSON string to an item. Can be overridden by subclasses."""
154164
return json.loads(item) # type: ignore[no-any-return]
155165

166+
def _decode_messages(self, data: bytes | None) -> list[Any]:
167+
if not data:
168+
return []
169+
try:
170+
messages_json = data.decode("utf-8")
171+
messages = json.loads(messages_json)
172+
if isinstance(messages, list):
173+
return list(messages)
174+
except (json.JSONDecodeError, UnicodeDecodeError):
175+
return []
176+
return []
177+
178+
def _calculate_retry_delay(self, attempt: int) -> float:
179+
base: float = _RETRY_BASE_DELAY_SECONDS * (2 ** max(0, attempt - 1))
180+
delay: float = min(base, _RETRY_MAX_DELAY_SECONDS)
181+
# Add jitter (10%) similar to tracing processors to avoid thundering herd.
182+
return delay + random.uniform(0, 0.1 * delay)
183+
184+
def _is_concurrency_conflict(self, error: Exception) -> bool:
185+
code_attr = getattr(error, "code", None)
186+
if callable(code_attr):
187+
try:
188+
status_code = code_attr()
189+
except Exception:
190+
status_code = None
191+
if status_code is not None:
192+
status_name = getattr(status_code, "name", str(status_code))
193+
if status_name in {"ABORTED", "FAILED_PRECONDITION"}:
194+
return True
195+
message = str(error).lower()
196+
conflict_markers = (
197+
"etag mismatch",
198+
"etag does not match",
199+
"precondition failed",
200+
"concurrency conflict",
201+
"invalid etag",
202+
"failed to set key", # Redis state store Lua script error during conditional write
203+
"user_script", # Redis script failure hint
204+
)
205+
return any(marker in message for marker in conflict_markers)
206+
207+
async def _handle_concurrency_conflict(self, error: Exception, attempt: int) -> bool:
208+
if not self._is_concurrency_conflict(error):
209+
return False
210+
if attempt >= _MAX_WRITE_ATTEMPTS:
211+
return False
212+
delay = self._calculate_retry_delay(attempt)
213+
if delay > 0:
214+
await asyncio.sleep(delay)
215+
return True
216+
156217
# ------------------------------------------------------------------
157218
# Session protocol implementation
158219
# ------------------------------------------------------------------
@@ -175,41 +236,24 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
175236
state_metadata=self._get_read_metadata(),
176237
)
177238

178-
if not response.data:
239+
messages = self._decode_messages(response.data)
240+
if not messages:
179241
return []
180-
181-
try:
182-
# Parse the messages list from JSON
183-
messages_json = response.data.decode("utf-8")
184-
messages = json.loads(messages_json)
185-
186-
if not isinstance(messages, list):
242+
if limit is not None:
243+
if limit <= 0:
187244
return []
188-
189-
# Apply limit if specified
190-
if limit is not None:
191-
if limit <= 0:
192-
return []
193-
# Return the latest N items
194-
messages = messages[-limit:]
195-
196-
items: list[TResponseInputItem] = []
197-
for msg in messages:
198-
try:
199-
if isinstance(msg, str):
200-
item = await self._deserialize_item(msg)
201-
else:
202-
item = msg # Already deserialized
203-
items.append(item)
204-
except (json.JSONDecodeError, TypeError):
205-
# Skip corrupted messages
206-
continue
207-
208-
return items
209-
210-
except (json.JSONDecodeError, UnicodeDecodeError):
211-
# Return empty list for corrupted data
212-
return []
245+
messages = messages[-limit:]
246+
items: list[TResponseInputItem] = []
247+
for msg in messages:
248+
try:
249+
if isinstance(msg, str):
250+
item = await self._deserialize_item(msg)
251+
else:
252+
item = msg
253+
items.append(item)
254+
except (json.JSONDecodeError, TypeError):
255+
continue
256+
return items
213257

214258
async def add_items(self, items: list[TResponseInputItem]) -> None:
215259
"""Add new items to the conversation history.
@@ -221,38 +265,34 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
221265
return
222266

223267
async with self._lock:
224-
# Get existing messages with consistency level
225-
response = await self._dapr_client.get_state(
226-
store_name=self._state_store_name,
227-
key=self._messages_key,
228-
state_metadata=self._get_read_metadata(),
229-
)
230-
231-
# Parse existing messages
232-
existing_messages = []
233-
if response.data:
268+
serialized_items: list[str] = [await self._serialize_item(item) for item in items]
269+
attempt = 0
270+
while True:
271+
attempt += 1
272+
response = await self._dapr_client.get_state(
273+
store_name=self._state_store_name,
274+
key=self._messages_key,
275+
state_metadata=self._get_read_metadata(),
276+
)
277+
existing_messages = self._decode_messages(response.data)
278+
updated_messages = existing_messages + serialized_items
279+
messages_json = json.dumps(updated_messages, separators=(",", ":"))
280+
etag = response.etag
234281
try:
235-
messages_json = response.data.decode("utf-8")
236-
existing_messages = json.loads(messages_json)
237-
if not isinstance(existing_messages, list):
238-
existing_messages = []
239-
except (json.JSONDecodeError, UnicodeDecodeError):
240-
existing_messages = []
241-
242-
# Serialize and append new items
243-
for item in items:
244-
serialized = await self._serialize_item(item)
245-
existing_messages.append(serialized)
246-
247-
# Save updated messages list
248-
messages_json = json.dumps(existing_messages, separators=(",", ":"))
249-
await self._dapr_client.save_state(
250-
store_name=self._state_store_name,
251-
key=self._messages_key,
252-
value=messages_json,
253-
state_metadata=self._get_metadata(),
254-
options=self._get_state_options(),
255-
)
282+
await self._dapr_client.save_state(
283+
store_name=self._state_store_name,
284+
key=self._messages_key,
285+
value=messages_json,
286+
etag=etag,
287+
state_metadata=self._get_metadata(),
288+
options=self._get_state_options(concurrency=Concurrency.first_write),
289+
)
290+
break
291+
except Exception as error:
292+
should_retry = await self._handle_concurrency_conflict(error, attempt)
293+
if should_retry:
294+
continue
295+
raise
256296

257297
# Update metadata
258298
metadata = {
@@ -275,45 +315,41 @@ async def pop_item(self) -> TResponseInputItem | None:
275315
The most recent item if it exists, None if the session is empty
276316
"""
277317
async with self._lock:
278-
# Get messages from state store with consistency level
279-
response = await self._dapr_client.get_state(
280-
store_name=self._state_store_name,
281-
key=self._messages_key,
282-
state_metadata=self._get_read_metadata(),
283-
)
284-
285-
if not response.data:
286-
return None
287-
288-
try:
289-
# Parse the messages list
290-
messages_json = response.data.decode("utf-8")
291-
messages = json.loads(messages_json)
292-
293-
if not isinstance(messages, list) or len(messages) == 0:
294-
return None
295-
296-
# Pop the last item
297-
last_item = messages.pop()
298-
299-
# Save updated messages list
300-
messages_json = json.dumps(messages, separators=(",", ":"))
301-
await self._dapr_client.save_state(
318+
attempt = 0
319+
while True:
320+
attempt += 1
321+
response = await self._dapr_client.get_state(
302322
store_name=self._state_store_name,
303323
key=self._messages_key,
304-
value=messages_json,
305-
state_metadata=self._get_metadata(),
306-
options=self._get_state_options(),
324+
state_metadata=self._get_read_metadata(),
307325
)
308-
309-
# Deserialize and return the item
326+
messages = self._decode_messages(response.data)
327+
if not messages:
328+
return None
329+
last_item = messages.pop()
330+
messages_json = json.dumps(messages, separators=(",", ":"))
331+
etag = getattr(response, "etag", None) or None
333+
try:
334+
await self._dapr_client.save_state(
335+
store_name=self._state_store_name,
336+
key=self._messages_key,
337+
value=messages_json,
338+
etag=etag,
339+
state_metadata=self._get_metadata(),
340+
options=self._get_state_options(concurrency=Concurrency.first_write),
341+
)
342+
break
343+
except Exception as error:
344+
should_retry = await self._handle_concurrency_conflict(error, attempt)
345+
if should_retry:
346+
continue
347+
raise
348+
try:
310349
if isinstance(last_item, str):
311350
return await self._deserialize_item(last_item)
312-
else:
313-
return last_item # type: ignore[no-any-return]
314-
315-
except (json.JSONDecodeError, UnicodeDecodeError, TypeError):
316-
# Return None for corrupted data
351+
return last_item # type: ignore[no-any-return]
352+
except (json.JSONDecodeError, TypeError):
317353
return None
318354

319355
async def clear_session(self) -> None:

0 commit comments

Comments
 (0)