Skip to content

Commit ed4f1d6

Browse files
committed
style: tox -e ruff
Signed-off-by: Samantha Coyle <[email protected]>
1 parent 5cc1211 commit ed4f1d6

File tree

10 files changed

+212
-112
lines changed

10 files changed

+212
-112
lines changed

dapr_agents/agents/base.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ class AgentBase(BaseModel, ABC):
103103
)
104104
registry_store: Optional[str] = Field(
105105
default=None,
106-
description="Agent registry store name for storing static agent information. Defaults to memory_store state store name if not provided."
106+
description="Agent registry store name for storing static agent information. Defaults to memory_store state store name if not provided.",
107107
)
108108

109109
DEFAULT_SYSTEM_PROMPT: ClassVar[str]
@@ -181,11 +181,7 @@ def model_post_init(self, __context: Any) -> None:
181181

182182
# Initialize Dapr client if storage is persistent
183183
# This is needed for state store access and agent registration
184-
if (
185-
self.memory_store
186-
and self.memory_store.name
187-
and self._dapr_client is None
188-
):
184+
if self.memory_store and self.memory_store.name and self._dapr_client is None:
189185
from dapr.clients import DaprClient
190186

191187
self._dapr_client = DaprClient()
@@ -210,7 +206,7 @@ def model_post_init(self, __context: Any) -> None:
210206
"statestore_name": self.memory_store.name,
211207
"registry_name": self.registry_store,
212208
}
213-
209+
214210
self.register_agent(
215211
store_name=self.registry_store,
216212
store_key="agent_registry",
@@ -620,16 +616,20 @@ def register_agent(
620616
)
621617

622618
# reread to obtain the freshly minted ETag
623-
response = self._dapr_client.get_state(store_name=store_name, key=store_key)
619+
response = self._dapr_client.get_state(
620+
store_name=store_name, key=store_key
621+
)
624622
if not response.etag:
625623
raise RuntimeError("ETag still missing after init")
626-
627-
existing = self._deserialize_state(response.data) if response.data else {}
624+
625+
existing = (
626+
self._deserialize_state(response.data) if response.data else {}
627+
)
628628

629629
if existing.get(agent_name) == agent_metadata:
630630
logger.debug(f"Agent '{agent_name}' already registered")
631631
return
632-
632+
633633
safe_metadata = self._serialize_metadata(agent_metadata)
634634

635635
merged = {**existing, agent_name: safe_metadata}
@@ -711,12 +711,13 @@ def _deserialize_state(self, raw: Union[bytes, str, dict]) -> dict:
711711
raise ValueError(f"State is not valid JSON: {exc}") from exc
712712

713713
raise TypeError(f"Unsupported state type {type(raw)!r}")
714-
714+
715715
def _serialize_metadata(self, metadata: Any) -> Any:
716716
"""
717717
Recursively convert Pydantic models (e.g., AgentTool), lists, dicts to JSON-serializable format.
718718
Handles mixed tools: [AgentTool(...), "string", ...] → [{"name": "..."}, "string", ...]
719719
"""
720+
720721
def convert(obj: Any) -> Any:
721722
if hasattr(obj, "model_dump"):
722723
return obj.model_dump()
@@ -727,4 +728,5 @@ def convert(obj: Any) -> Any:
727728
if isinstance(obj, dict):
728729
return {k: convert(v) for k, v in obj.items()}
729730
return obj
730-
return convert(metadata)
731+
732+
return convert(metadata)

dapr_agents/agents/durableagent/agent.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,9 @@ def model_post_init(self, __context: Any) -> None:
122122

123123
# Load the current workflow instance ID from state using session_id
124124
logger.debug(f"State after loading: {self.memory_store._current_state}")
125-
if self.memory_store._current_state and self.memory_store._current_state.get("instances"):
125+
if self.memory_store._current_state and self.memory_store._current_state.get(
126+
"instances"
127+
):
126128
logger.debug(
127129
f"Found {len(self.memory_store._current_state['instances'])} instances in state"
128130
)

dapr_agents/agents/memory_store.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
logger = logging.getLogger(__name__)
1313

14+
1415
class MemoryStore(BaseModel):
1516
"""
1617
Unified storage for both Agent and DurableAgent.
@@ -61,7 +62,9 @@ class MemoryStore(BaseModel):
6162
def model_post_init(self, __context: Any) -> None:
6263
if self.name is None:
6364
if self._dapr_client is not None:
64-
logger.warning("DaprClient initialized but name is None. It will be ignored.")
65+
logger.warning(
66+
"DaprClient initialized but name is None. It will be ignored."
67+
)
6568
self._dapr_client = None
6669
else:
6770
self._dapr_client = DaprClient()
@@ -117,7 +120,9 @@ def _update_session_index(self, instance_id: str) -> None:
117120
try:
118121
session_data = json.loads(raw)
119122
if not isinstance(session_data, dict):
120-
logger.warning(f"Session data not a dict, resetting: {type(session_data)}")
123+
logger.warning(
124+
f"Session data not a dict, resetting: {type(session_data)}"
125+
)
121126
session_data = {}
122127
except json.JSONDecodeError as e:
123128
logger.error(f"Invalid session JSON for '{session_key}': {e}")
@@ -161,7 +166,9 @@ def _update_session_index(self, instance_id: str) -> None:
161166
parsed = json.loads(raw)
162167
if isinstance(parsed, dict):
163168
index_data["sessions"] = parsed.get("sessions", [])
164-
index_data["last_updated"] = parsed.get("last_updated", index_data["last_updated"])
169+
index_data["last_updated"] = parsed.get(
170+
"last_updated", index_data["last_updated"]
171+
)
165172
except json.JSONDecodeError:
166173
logger.warning("Corrupted sessions index, resetting")
167174

@@ -170,11 +177,9 @@ def _update_session_index(self, instance_id: str) -> None:
170177
index_data["last_updated"] = datetime.now().isoformat()
171178
self._save_state_with_metadata(sessions_index_key, index_data)
172179
logger.debug(f"Registered session '{session_id}' in index")
173-
180+
174181
# TODO: in future remove this in favor of just using client.save_state when we use objects and not dictionaries in storage.
175-
def _save_state_with_metadata(
176-
self, key: str, data: Any
177-
) -> None:
182+
def _save_state_with_metadata(self, key: str, data: Any) -> None:
178183
"""Save state with content type metadata."""
179184
# Serialize data to JSON string if it's not already
180185
if isinstance(data, dict):
@@ -185,7 +190,10 @@ def _save_state_with_metadata(
185190
data_to_save = json.dumps(data)
186191

187192
self._dapr_client.save_state(
188-
self.name, key, data_to_save, state_metadata={"contentType": "application/json"}
193+
self.name,
194+
key,
195+
data_to_save,
196+
state_metadata={"contentType": "application/json"},
189197
)
190198

191199
def is_persistent(self) -> bool:
@@ -276,6 +284,7 @@ def _load_messages_from_store(self) -> List[Dict[str, Any]]:
276284
return data.get("messages", [])
277285
return []
278286

287+
279288
class DurableAgentMessage(MessageContent):
280289
id: str = Field(
281290
default_factory=lambda: str(uuid.uuid4()),

dapr_agents/workflow/agentic.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class AgenticWorkflow(
6060

6161
registry_store: Optional[str] = Field(
6262
default=None,
63-
description="Agent registry store name for storing static agent information."
63+
description="Agent registry store name for storing static agent information.",
6464
)
6565

6666
# TODO: test this is respected by runtime.
@@ -109,7 +109,7 @@ def model_post_init(self, __context: Any) -> None:
109109

110110
# Set storage key based on agent name
111111
self.memory_store._set_key(self.name)
112-
112+
113113
logger.info(f"State store '{self.memory_store.name}' initialized.")
114114
self.initialize_state()
115115
if self.registry_store is None:
@@ -364,7 +364,9 @@ def register_agent(
364364
)
365365
# raise an exception to retry the entire operation
366366
raise Exception(f"No etag found for key: {store_key}")
367-
existing_data = self._deserialize_state(response.data) if response.data else {}
367+
existing_data = (
368+
self._deserialize_state(response.data) if response.data else {}
369+
)
368370
if (agent_name, agent_metadata) in existing_data.items():
369371
logger.debug(f"agent {agent_name} already registered.")
370372
return None
@@ -423,4 +425,4 @@ def _deserialize_state(self, raw: Union[bytes, str, dict]) -> dict:
423425
except json.JSONDecodeError as exc:
424426
raise ValueError(f"State is not valid JSON: {exc}") from exc
425427

426-
raise TypeError(f"Unsupported state type {type(raw)!r}")
428+
raise TypeError(f"Unsupported state type {type(raw)!r}")

dapr_agents/workflow/mixins/state.py

Lines changed: 37 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ def _reconcile_workflow_statuses(self) -> None:
5555
updated_instances.append(instance_id)
5656

5757
# Save the updated instance back to Redis
58-
instance_key = self.memory_store._get_instance_key(instance_id)
58+
instance_key = self.memory_store._get_instance_key(
59+
instance_id
60+
)
5961
self.memory_store._save_state_with_metadata(
6062
instance_key, instance_data
6163
)
@@ -70,7 +72,9 @@ def _reconcile_workflow_statuses(self) -> None:
7072
updated_instances.append(instance_id)
7173

7274
# Save the updated instance
73-
instance_key = self.memory_store._get_instance_key(instance_id)
75+
instance_key = self.memory_store._get_instance_key(
76+
instance_id
77+
)
7478
self.memory_store._save_state_with_metadata(
7579
instance_key, instance_data
7680
)
@@ -170,7 +174,9 @@ def initialize_state(self) -> None:
170174
logger.debug(
171175
"User provided a state as a Pydantic model. Converting to dict."
172176
)
173-
self.memory_store._current_state = self.memory_store._current_state.model_dump()
177+
self.memory_store._current_state = (
178+
self.memory_store._current_state.model_dump()
179+
)
174180

175181
if not isinstance(self.memory_store._current_state, dict):
176182
raise TypeError(
@@ -209,8 +215,7 @@ def load_state(self) -> dict:
209215

210216
# For durable agents, always load from database to ensure it's the source of truth
211217
response = self._dapr_client.get_state(
212-
self.memory_store.name,
213-
self.memory_store._key
218+
self.memory_store.name, self.memory_store._key
214219
)
215220
if response.data:
216221
state_data = self._deserialize_state(response.data)
@@ -227,8 +232,7 @@ def load_state(self) -> dict:
227232
# Get all sessions for this agent
228233
sessions_index_key = self.memory_store._get_sessions_index_key()
229234
response = self._dapr_client.get_state(
230-
self.memory_store.name,
231-
sessions_index_key
235+
self.memory_store.name, sessions_index_key
232236
)
233237

234238
if response.data:
@@ -242,8 +246,7 @@ def load_state(self) -> dict:
242246
for session_id in session_ids:
243247
session_key = self.memory_store._get_session_key(session_id)
244248
response = self._dapr_client.get_state(
245-
self.memory_store.name,
246-
session_key
249+
self.memory_store.name, session_key
247250
)
248251

249252
if response.data:
@@ -256,8 +259,12 @@ def load_state(self) -> dict:
256259

257260
# Load each instance
258261
for instance_id in instance_ids:
259-
instance_key = self.memory_store._get_instance_key(instance_id)
260-
response = self._dapr_client.get_state(self.memory_store.name, instance_key)
262+
instance_key = self.memory_store._get_instance_key(
263+
instance_id
264+
)
265+
response = self._dapr_client.get_state(
266+
self.memory_store.name, instance_key
267+
)
261268
if response.data:
262269
instance_data = self._deserialize_state(response.data)
263270

@@ -293,7 +300,9 @@ def load_state(self) -> dict:
293300
)
294301
return self.memory_store._current_state
295302
except Exception as e:
296-
logger.error(f"Failed to load state for key '{self.memory_store._key}': {e}")
303+
logger.error(
304+
f"Failed to load state for key '{self.memory_store._key}': {e}"
305+
)
297306
raise RuntimeError(f"Error loading workflow state: {e}") from e
298307

299308
def get_local_state_file_path(self) -> str:
@@ -428,19 +437,27 @@ def save_state(
428437
instance_json = instance_data
429438
else:
430439
instance_json = json.dumps(instance_data)
431-
self._dapr_client.save_state(self.memory_store.name, instance_key, instance_json)
440+
self._dapr_client.save_state(
441+
self.memory_store.name, instance_key, instance_json
442+
)
432443
logger.debug(
433444
f"Saved workflow instance {instance_id} to key '{instance_key}'"
434445
)
435446

436447
# Save other state data (like chat_history) to main key
437448
other_state = {
438-
k: v for k, v in self.memory_store._current_state.items() if k != "instances"
449+
k: v
450+
for k, v in self.memory_store._current_state.items()
451+
if k != "instances"
439452
}
440453
if other_state:
441454
other_state_json = json.dumps(other_state)
442-
self._dapr_client.save_state(self.memory_store.name, self.memory_store._key, other_state_json)
443-
logger.debug(f"Saved non-instance state to key '{self.memory_store._key}'")
455+
self._dapr_client.save_state(
456+
self.memory_store.name, self.memory_store._key, other_state_json
457+
)
458+
logger.debug(
459+
f"Saved non-instance state to key '{self.memory_store._key}'"
460+
)
444461

445462
if self.memory_store.local_directory is not None:
446463
self.save_state_to_disk(state_data=state_to_save)
@@ -451,7 +468,9 @@ def save_state(
451468
f"State reloaded after saving for key '{self.memory_store._key}'."
452469
)
453470
except Exception as e:
454-
logger.error(f"Failed to save state for key '{self.memory_store._key}': {e}")
471+
logger.error(
472+
f"Failed to save state for key '{self.memory_store._key}': {e}"
473+
)
455474
raise
456475

457476
def _deserialize_state(self, raw: Union[bytes, str, dict]) -> dict:
@@ -474,4 +493,4 @@ def _deserialize_state(self, raw: Union[bytes, str, dict]) -> dict:
474493
except json.JSONDecodeError as exc:
475494
raise ValueError(f"State is not valid JSON: {exc}") from exc
476495

477-
raise TypeError(f"Unsupported state type {type(raw)!r}")
496+
raise TypeError(f"Unsupported state type {type(raw)!r}")

dapr_agents/workflow/orchestrators/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,12 @@ async def trigger_agent(self, name: str, instance_id: str, **kwargs) -> None:
7676
"""Trigger a specific agent to perform an action."""
7777
pass
7878

79-
8079
def _serialize_metadata(self, metadata: Any) -> Any:
8180
"""
8281
Recursively convert Pydantic models (e.g., AgentTool), lists, dicts to JSON-serializable format.
8382
Handles mixed tools: [AgentTool(...), "string", ...] → [{"name": "..."}, "string", ...]
8483
"""
84+
8585
def convert(obj: Any) -> Any:
8686
if hasattr(obj, "model_dump"):
8787
return obj.model_dump()
@@ -92,4 +92,5 @@ def convert(obj: Any) -> Any:
9292
if isinstance(obj, dict):
9393
return {k: convert(v) for k, v in obj.items()}
9494
return obj
95-
return convert(metadata)
95+
96+
return convert(metadata)

0 commit comments

Comments
 (0)