Skip to content

Commit 89d0a8b

Browse files
Merge pull request #1381 from MervinPraison/claude/issue-1370-20260413-0932
fix: Address 3 critical production robustness gaps
2 parents c23642c + 0d32636 commit 89d0a8b

File tree

6 files changed

+182
-31
lines changed

6 files changed

+182
-31
lines changed

src/praisonai-agents/praisonaiagents/agent/agent.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1622,6 +1622,8 @@ def __init__(
16221622
# Token budget guard (zero overhead when _max_budget is None)
16231623
self._max_budget = _max_budget
16241624
self._on_budget_exceeded = _on_budget_exceeded
1625+
# Thread-safe cost/token tracking (Gap 1a fix)
1626+
self._cost_lock = threading.Lock()
16251627
self._total_cost = 0.0
16261628
self._total_tokens_in = 0
16271629
self._total_tokens_out = 0
@@ -1965,7 +1967,9 @@ def thinking_budget(self, value: Optional[int]) -> None:
19651967
@property
19661968
def total_cost(self) -> float:
19671969
"""Cumulative USD cost of all LLM calls in this agent run."""
1968-
return self._total_cost
1970+
# Thread-safe cost reading (Gap 1a fix)
1971+
with self._cost_lock:
1972+
return self._total_cost
19691973

19701974
@property
19711975
def cost_summary(self) -> dict:
@@ -1974,12 +1978,14 @@ def cost_summary(self) -> dict:
19741978
Returns:
19751979
dict with keys: tokens_in, tokens_out, cost, llm_calls
19761980
"""
1977-
return {
1978-
"tokens_in": self._total_tokens_in,
1979-
"tokens_out": self._total_tokens_out,
1980-
"cost": self._total_cost,
1981-
"llm_calls": self._llm_call_count,
1982-
}
1981+
# Thread-safe cost reading (Gap 1a fix)
1982+
with self._cost_lock:
1983+
return {
1984+
"tokens_in": self._total_tokens_in,
1985+
"tokens_out": self._total_tokens_out,
1986+
"cost": self._total_cost,
1987+
"llm_calls": self._llm_call_count,
1988+
}
19831989

19841990
@property
19851991
def context_manager(self) -> Optional[Any]:

src/praisonai-agents/praisonaiagents/agent/chat_mixin.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -653,26 +653,31 @@ def _chat_completion(self, messages, temperature=1.0, tools=None, stream=True, r
653653
)
654654

655655
# Budget tracking & enforcement (zero overhead when _max_budget is None)
656-
self._total_cost += _cost_usd
657-
self._total_tokens_in += _prompt_tokens
658-
self._total_tokens_out += _completion_tokens
659-
self._llm_call_count += 1
660-
if self._max_budget and self._total_cost >= self._max_budget:
656+
# Thread-safe cost tracking (Gap 1a fix)
657+
with self._cost_lock:
658+
self._total_cost += _cost_usd
659+
self._total_tokens_in += _prompt_tokens
660+
self._total_tokens_out += _completion_tokens
661+
self._llm_call_count += 1
662+
budget_exceeded = self._max_budget and self._total_cost >= self._max_budget
663+
current_cost = self._total_cost
664+
665+
if budget_exceeded:
661666
if self._on_budget_exceeded == "stop":
662667
raise BudgetExceededError(
663-
f"Agent '{self.name}' exceeded budget: ${self._total_cost:.4f} >= ${self._max_budget:.4f}",
668+
f"Agent '{self.name}' exceeded budget: ${current_cost:.4f} >= ${self._max_budget:.4f}",
664669
budget_type="cost",
665670
limit=self._max_budget,
666-
used=self._total_cost,
671+
used=current_cost,
667672
agent_id=self.name
668673
)
669674
elif self._on_budget_exceeded == "warn":
670675
logging.warning(
671-
f"[budget] {self.name}: ${self._total_cost:.4f} exceeded "
676+
f"[budget] {self.name}: ${current_cost:.4f} exceeded "
672677
f"${self._max_budget:.4f} budget"
673678
)
674679
elif callable(self._on_budget_exceeded):
675-
self._on_budget_exceeded(self._total_cost, self._max_budget)
680+
self._on_budget_exceeded(current_cost, self._max_budget)
676681

677682
# Trigger AFTER_LLM hook
678683
from ..hooks import HookEvent, AfterLLMInput

src/praisonai-agents/praisonaiagents/agent/tool_execution.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import contextvars
1616
import concurrent.futures
1717
from typing import List, Optional, Any, Dict, Union, TYPE_CHECKING
18+
from ..errors import ToolExecutionError
1819

1920
if TYPE_CHECKING:
2021
pass
@@ -310,8 +311,14 @@ def execute_with_context():
310311
_duration_ms = (_time.time() - _tool_start_time) * 1000
311312
_trace_emitter.tool_call_end(self.name, function_name, None, _duration_ms, str(e))
312313

313-
# Trigger OnError hook if needed (optional future step)
314-
raise
314+
# Gap 3a fix: Wrap exceptions in ToolExecutionError for better observability
315+
is_retryable = not isinstance(e, (ValueError, TypeError, AttributeError))
316+
raise ToolExecutionError(
317+
f"Tool '{function_name}' failed: {e}",
318+
tool_name=function_name,
319+
agent_id=self.name,
320+
is_retryable=is_retryable,
321+
) from e
315322

316323
def _trigger_after_agent_hook(self, prompt, response, start_time, tools_used=None):
317324
"""Trigger AFTER_AGENT hook and return response."""

src/praisonai-agents/praisonaiagents/session.py

Lines changed: 67 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ def __init__(
5252
agent_url: Optional[str] = None,
5353
memory_config: Optional[Dict[str, Any]] = None,
5454
knowledge_config: Optional[Dict[str, Any]] = None,
55-
timeout: int = 30
55+
timeout: int = 30,
56+
session_ttl: Optional[int] = None # Gap 2b: TTL in seconds
5657
):
5758
"""
5859
Initialize a new session with optional persistence or remote agent connectivity.
@@ -64,6 +65,7 @@ def __init__(
6465
memory_config: Configuration for memory system (defaults to RAG)
6566
knowledge_config: Configuration for knowledge base system
6667
timeout: HTTP timeout for remote agent calls (default: 30 seconds)
68+
session_ttl: Time-to-live in seconds after which session expires (Gap 2b)
6769
"""
6870
self.session_id = session_id or str(uuid.uuid4())[:8]
6971
self.user_id = user_id or "default_user"
@@ -112,6 +114,10 @@ def __init__(
112114
self._agents_instance = None
113115
self._agents = {} # Track agents and their chat histories
114116

117+
# Gap 2b: Session TTL and cleanup support
118+
self.session_ttl = session_ttl
119+
self._created_at = time.time() # Track creation time for TTL
120+
115121
def _get_session_dir(self):
116122
"""Return session-specific directory using paths.py."""
117123
from pathlib import Path
@@ -188,8 +194,8 @@ def Agent(
188194

189195
agent = Agent(**agent_kwargs)
190196

191-
# Create a unique key for this agent (using name and role)
192-
agent_key = f"{name}:{role}"
197+
# Create a unique key for this agent (Gap 2a fix: include session_id for proper isolation)
198+
agent_key = f"{self.session_id}:{name}:{role}"
193199

194200
# Restore chat history if it exists from previous sessions
195201
if agent_key in self._agents:
@@ -272,7 +278,7 @@ def restore_state(self) -> Dict[str, Any]:
272278

273279
def _restore_agent_chat_history(self, agent_key: str) -> List[Dict[str, Any]]:
274280
"""
275-
Restore agent chat history from memory.
281+
Restore agent chat history from SessionStore first, then memory fallback (Gap 2c fix).
276282
277283
Args:
278284
agent_key: Unique identifier for the agent
@@ -283,7 +289,21 @@ def _restore_agent_chat_history(self, agent_key: str) -> List[Dict[str, Any]]:
283289
if self.is_remote:
284290
return []
285291

286-
# Search for agent chat history in memory
292+
# Gap 2c: Try SessionStore first for clean separation
293+
try:
294+
from .session.store import get_default_session_store
295+
session_store = get_default_session_store()
296+
chat_history = session_store.get_chat_history(agent_key)
297+
if not chat_history:
298+
# Backward compatibility: fall back to legacy
299+
# "{session_id}_{agent_key}" format for existing stored conversations.
300+
chat_history = session_store.get_chat_history(f"{self.session_id}_{agent_key}")
301+
if chat_history:
302+
return chat_history
303+
except ImportError:
304+
pass
305+
306+
# Fallback: Search for agent chat history in memory (backward compatibility)
287307
results = self.memory.search_short_term(
288308
query="Agent chat history for",
289309
limit=10
@@ -352,21 +372,20 @@ def _save_agent_chat_histories(self) -> None:
352372
chat_history = agent_data.get("chat_history")
353373

354374
if chat_history is not None:
355-
# G-2 FIX: Try SessionStore first for clean separation
375+
# G-2 FIX: Use SessionStore for clean separation (Gap 2c fix)
356376
session_store = None
357377
try:
358-
from .session import get_default_session_store
378+
from .session.store import get_default_session_store
359379
session_store = get_default_session_store()
360380
except ImportError:
361381
pass
362382

363383
if session_store is not None:
364384
# Use SessionStore for conversation history
365-
session_id = f"{self.session_id}_{agent_key}"
366385
for msg in chat_history:
367386
if isinstance(msg, dict):
368387
session_store.add_message(
369-
session_id,
388+
agent_key,
370389
role=msg.get("role", "user"),
371390
content=msg.get("content", ""),
372391
)
@@ -587,10 +606,48 @@ def send_message(self, message: str, **kwargs) -> str:
587606
"""
588607
return self.chat(message, **kwargs)
589608

609+
def is_expired(self) -> bool:
    """Report whether this session has outlived its configured TTL.

    Returns:
        ``False`` when ``session_ttl`` is ``None`` (no expiry configured);
        otherwise ``True`` once strictly more than ``session_ttl`` seconds
        have passed, by ``time.time()``, since ``_created_at``.
    """
    ttl = self.session_ttl
    if ttl is not None:
        # Age is measured from the creation timestamp, not from last use.
        age_seconds = time.time() - self._created_at
        return age_seconds > ttl
    return False
614+
615+
def close(self) -> None:
    """Release the references this session holds to its local resources.

    Remote sessions return immediately. For local sessions the memory and
    knowledge references are dropped and the tracked-agents mapping is
    emptied. This method only releases references; it does not itself call
    any cleanup on the memory or knowledge backends.
    """
    if self.is_remote:
        return  # No cleanup needed for remote sessions

    # These are plain attribute assignments, so they are not wrapped in a
    # try/except that would only hide unrelated programming errors.
    if self._memory:
        self._memory = None
    if self._knowledge:
        self._knowledge = None

    # Forget every tracked agent together with its cached chat history.
    self._agents.clear()
640+
def time_to_expiry(self) -> Optional[float]:
    """Return the seconds remaining before the session's TTL elapses.

    Returns:
        ``None`` when ``session_ttl`` is ``None``; otherwise the remaining
        time in seconds as a float. An already-expired session yields
        ``0.0`` rather than a negative number.
    """
    if self.session_ttl is None:
        return None
    elapsed = time.time() - self._created_at
    # Clamp against 0.0 (not the int 0) so the declared float return type
    # also holds once the session has expired.
    return max(0.0, self.session_ttl - elapsed)
646+
590647
def __str__(self) -> str:
    """Return a short description of the session.

    Local sessions show the session id and user id; remote sessions also
    include the remote agent URL.
    """
    if not self.is_remote:
        return f"Session(id='{self.session_id}', user='{self.user_id}')"
    return f"Session(id='{self.session_id}', user='{self.user_id}', remote_agent='{self.agent_url}')"
594651

595652
def __repr__(self) -> str:
596-
return self.__str__()
653+
return self.__str__()

src/praisonai-agents/praisonaiagents/workflows/workflows.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import os
2323
import re
2424
import json
25+
import copy
26+
import time
2527
import logging
2628
from praisonaiagents._logging import get_logger
2729
from pathlib import Path
@@ -585,6 +587,8 @@ class AgentFlow:
585587
_reflection_config: Optional[Any] = field(default=None, repr=False)
586588
# Execution history for debugging (only populated when history=True)
587589
_execution_history: List[Dict[str, Any]] = field(default_factory=list, repr=False)
590+
# Gap 3c: Cross-step handoff cycle detection
591+
_handoff_chain: List[str] = field(default_factory=list, repr=False)
588592

589593
def __post_init__(self):
590594
"""Resolve consolidated params to internal values."""
@@ -899,6 +903,30 @@ def from_template(
899903
"Install with: pip install praisonai"
900904
)
901905

906+
def _check_handoff_cycle(self, step: Any) -> None:
    """Record an agent step in the handoff chain and fail if it repeats.

    A step with a truthy ``agent`` attribute is appended to
    ``self._handoff_chain``; any other step empties the chain. Once the
    chain grows past 100 entries it is cut back to the most recent 50.

    Args:
        step: The workflow step about to run. Its ``name`` attribute is the
            identifier, falling back to ``str(step)`` when there is none.

    Raises:
        HandoffCycleError: If the step's identifier is already present in
            the current handoff chain.
    """
    step_id = getattr(step, 'name', str(step))

    # A step without a (truthy) agent breaks the chain of handoffs.
    if not (hasattr(step, 'agent') and step.agent):
        self._handoff_chain.clear()
        return

    chain = self._handoff_chain
    if step_id in chain:
        # NOTE(review): any repeated agent step name counts as a cycle here,
        # including deliberate repeats - confirm this is intended for
        # loop/repeat style steps.
        first_seen = chain.index(step_id)
        cycle_path = chain[first_seen:] + [step_id]
        from ..errors import HandoffCycleError
        raise HandoffCycleError(
            f"Cross-step handoff cycle detected: {' -> '.join(cycle_path)}",
            cycle_path=cycle_path
        )

    chain.append(step_id)
    if len(chain) > 100:  # Reasonable limit
        # Entries older than the last 50 drop out of cycle detection.
        self._handoff_chain = chain[-50:]
929+
902930
def run(
903931
self,
904932
input: str = "",
@@ -933,6 +961,9 @@ def run(
933961
Returns:
934962
Dict with 'output' (final result) and 'steps' (all step results)
935963
"""
964+
# Gap 3c: Clear handoff chain at start of new workflow run
965+
self._handoff_chain.clear()
966+
936967
# Use default LLM if not specified
937968
model = llm or self.llm or "gpt-4o-mini"
938969
logger.debug(f"Workflow using model: {model} (llm={llm}, default_llm={self.llm})")
@@ -1096,6 +1127,9 @@ def run(
10961127
except Exception as e:
10971128
logger.error(f"should_run failed for {step.name}: {e}")
10981129

1130+
# Gap 3c: Check for cross-step handoff cycles
1131+
self._check_handoff_cycle(step)
1132+
10991133
# Execute step with retry and guardrail support
11001134
output = None
11011135
stop = False
@@ -1273,6 +1307,24 @@ def run(
12731307
except Exception as e:
12741308
step_error = e
12751309
output = f"Error: {e}"
1310+
1311+
# Gap 3b fix: Check if error is retryable and implement exponential backoff
1312+
is_retryable = getattr(e, 'is_retryable', True) # Default to retryable
1313+
if not is_retryable:
1314+
# Non-retryable error - break out of retry loop immediately
1315+
if verbose:
1316+
print(f"❌ {step.name} failed with non-retryable error: {e}")
1317+
break
1318+
1319+
retry_count += 1
1320+
if retry_count <= max_retries:
1321+
# Exponential backoff: wait 2^(retry_count-1) seconds
1322+
backoff_seconds = 2 ** (retry_count - 1)
1323+
if verbose:
1324+
print(f"🔄 {step.name} failed (attempt {retry_count}/{max_retries}), retrying in {backoff_seconds}s: {e}")
1325+
time.sleep(backoff_seconds)
1326+
continue # Retry
1327+
12761328
if self.on_step_error:
12771329
try:
12781330
self.on_step_error(step.name, e)
@@ -1826,7 +1878,6 @@ def _execute_specialized_agent(
18261878
# Only retry on transient errors (network, timeout, fetch failures)
18271879
if any(x in error_str for x in ['fetch', 'timeout', 'connection', 'network']):
18281880
if attempt < max_retries:
1829-
import time
18301881
time.sleep(1 * (attempt + 1)) # Linear backoff: delay grows by 1s per attempt
18311882
continue
18321883
# Non-retryable error, raise immediately
@@ -2306,7 +2357,7 @@ def execute_with_branch(step=step, idx=idx, opt_prev=optimized_previous):
23062357
emitter.set_branch(f"parallel_{idx}")
23072358
try:
23082359
return self._execute_single_step_internal(
2309-
step, opt_prev, input, all_variables.copy(), model, False, idx, stream, depth=depth+1
2360+
step, opt_prev, input, copy.deepcopy(all_variables), model, False, idx, stream, depth=depth+1
23102361
)
23112362
finally:
23122363
emitter.clear_branch()
@@ -4290,7 +4341,6 @@ def _save_checkpoint(
42904341
Path to checkpoint file
42914342
"""
42924343
import json
4293-
import time
42944344
from datetime import datetime
42954345

42964346
checkpoint_file = self._get_checkpoints_dir() / f"{name}.json"

0 commit comments

Comments
 (0)