Commit 4fcbfb6

fix: ensure callback execution in all achat() paths
- Enhanced _chat_completion method signature to accept task metadata parameters
- Updated all calls to _chat_completion to pass task context
- Fixed NameError where task_name, task_description, task_id were undefined
- Updated _apply_guardrail_with_retry to propagate task metadata
- Ensures complete task metadata propagation in sync chat execution path

Co-authored-by: Mervin Praison <[email protected]>
1 parent e138249 commit 4fcbfb6
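
The whole commit is one pattern applied repeatedly: optional task metadata (task_name, task_description, task_id) is threaded through every intermediate call so that the callback at the end of the chain always receives full task context. A minimal, self-contained sketch of that pattern, where run_completion and on_response are hypothetical stand-ins and not PraisonAI APIs:

# Sketch of the metadata-propagation pattern in this commit.
# run_completion and on_response are made-up names, not PraisonAI APIs.

def on_response(prompt, response, task_name=None, task_description=None, task_id=None):
    # Callback that needs the task context of the call that produced `response`.
    print(f"[{task_id}] {task_name}: {len(response)} chars")

def run_completion(prompt, task_name=None, task_description=None, task_id=None):
    # Inner layer: accepts the metadata even though it only forwards it.
    response = prompt.upper()  # stand-in for the real LLM call
    on_response(prompt, response, task_name=task_name,
                task_description=task_description, task_id=task_id)
    return response

def chat(prompt, task_name=None, task_description=None, task_id=None):
    # Outer layer: every call site forwards the metadata explicitly.
    # Before this commit, some intermediate calls dropped these kwargs,
    # so the callback fired without task context on those paths.
    return run_completion(prompt, task_name=task_name,
                          task_description=task_description, task_id=task_id)

chat("hello", task_name="demo", task_id="t-1")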

File tree

1 file changed (+14 −14 lines): src/praisonai-agents/praisonaiagents/agent/agent.py

src/praisonai-agents/praisonaiagents/agent/agent.py

Lines changed: 14 additions & 14 deletions
@@ -653,7 +653,7 @@ def _process_guardrail(self, task_output):
                 error=f"Agent guardrail validation error: {str(e)}"
             )
 
-    def _apply_guardrail_with_retry(self, response_text, prompt, temperature=0.2, tools=None):
+    def _apply_guardrail_with_retry(self, response_text, prompt, temperature=0.2, tools=None, task_name=None, task_description=None, task_id=None):
         """Apply guardrail validation with retry logic.
 
         Args:
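
Because all three new parameters default to None, the widened signature is backward compatible: existing call sites that pass no task metadata keep working unchanged. Shown on a stub with the same parameter list, not the real Agent method:

# Backward compatibility of the widened signature (stub, not the real method):

def _apply_guardrail_with_retry(response_text, prompt, temperature=0.2, tools=None,
                                task_name=None, task_description=None, task_id=None):
    return (response_text, task_name, task_id)

print(_apply_guardrail_with_retry("ok", "hi"))          # legacy call: ('ok', None, None)
print(_apply_guardrail_with_retry("ok", "hi", 0.2, None,
                                  task_name="summarize", task_id="task-42"))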
@@ -707,7 +707,7 @@ def _apply_guardrail_with_retry(self, response_text, prompt, temperature=0.2, to
             # Regenerate response for retry
             try:
                 retry_prompt = f"{prompt}\n\nNote: Previous response failed validation due to: {guardrail_result.error}. Please provide an improved response."
-                response = self._chat_completion([{"role": "user", "content": retry_prompt}], temperature, tools)
+                response = self._chat_completion([{"role": "user", "content": retry_prompt}], temperature, tools, task_name=task_name, task_description=task_description, task_id=task_id)
                 if response and response.choices:
                     current_response = response.choices[0].message.content.strip()
                 else:
@@ -1072,7 +1072,7 @@ def _process_stream_response(self, messages, temperature, start_time, formatted_
                 reasoning_steps=reasoning_steps
             )
 
-    def _chat_completion(self, messages, temperature=0.2, tools=None, stream=True, reasoning_steps=False):
+    def _chat_completion(self, messages, temperature=0.2, tools=None, stream=True, reasoning_steps=False, task_name=None, task_description=None, task_id=None):
         start_time = time.time()
         logging.debug(f"{self.name} sending messages to LLM: {messages}")
 
@@ -1297,7 +1297,7 @@ def chat(self, prompt, temperature=0.2, tools=None, output_json=None, output_pyd
 
                 # Apply guardrail validation for custom LLM response
                 try:
-                    validated_response = self._apply_guardrail_with_retry(response_text, prompt, temperature, tools)
+                    validated_response = self._apply_guardrail_with_retry(response_text, prompt, temperature, tools, task_name, task_description, task_id)
                     return validated_response
                 except Exception as e:
                     logging.error(f"Agent {self.name}: Guardrail validation failed for custom LLM: {e}")
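
The calls above pass task_name, task_description, and task_id positionally; the NameError called out in the commit message is the standard Python failure when a forwarded name was never bound in the caller's scope. A toy reproduction of that bug class, not the project's actual code:

# Toy reproduction of the NameError this commit fixes (illustrative only):

def inner(messages, task_name=None):
    return task_name

def outer(messages):
    # Bug: task_name is referenced here but was never defined in this scope.
    return inner(messages, task_name=task_name)

try:
    outer([])
except NameError as e:
    print(e)  # name 'task_name' is not defined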
@@ -1357,7 +1357,7 @@ def chat(self, prompt, temperature=0.2, tools=None, output_json=None, output_pyd
                 agent_tools=agent_tools
             )
 
-            response = self._chat_completion(messages, temperature=temperature, tools=tools if tools else None, reasoning_steps=reasoning_steps, stream=self.stream)
+            response = self._chat_completion(messages, temperature=temperature, tools=tools if tools else None, reasoning_steps=reasoning_steps, stream=self.stream, task_name=task_name, task_description=task_description, task_id=task_id)
             if not response:
                 # Rollback chat history on response failure
                 self.chat_history = self.chat_history[:chat_history_length]
@@ -1372,7 +1372,7 @@ def chat(self, prompt, temperature=0.2, tools=None, output_json=None, output_pyd
                     self.chat_history.append({"role": "assistant", "content": response_text})
                     # Apply guardrail validation even for JSON output
                     try:
-                        validated_response = self._apply_guardrail_with_retry(response_text, original_prompt, temperature, tools)
+                        validated_response = self._apply_guardrail_with_retry(response_text, original_prompt, temperature, tools, task_name, task_description, task_id)
                         # Execute callback after validation
                         self._execute_callback_and_display(original_prompt, validated_response, time.time() - start_time, task_name, task_description, task_id)
                         return validated_response
@@ -1391,7 +1391,7 @@ def chat(self, prompt, temperature=0.2, tools=None, output_json=None, output_pyd
                 if reasoning_steps and hasattr(response.choices[0].message, 'reasoning_content'):
                     # Apply guardrail to reasoning content
                     try:
-                        validated_reasoning = self._apply_guardrail_with_retry(response.choices[0].message.reasoning_content, original_prompt, temperature, tools)
+                        validated_reasoning = self._apply_guardrail_with_retry(response.choices[0].message.reasoning_content, original_prompt, temperature, tools, task_name, task_description, task_id)
                         # Execute callback after validation
                         self._execute_callback_and_display(original_prompt, validated_reasoning, time.time() - start_time, task_name, task_description, task_id)
                         return validated_reasoning
@@ -1402,7 +1402,7 @@ def chat(self, prompt, temperature=0.2, tools=None, output_json=None, output_pyd
                         return None
                 # Apply guardrail to regular response
                 try:
-                    validated_response = self._apply_guardrail_with_retry(response_text, original_prompt, temperature, tools)
+                    validated_response = self._apply_guardrail_with_retry(response_text, original_prompt, temperature, tools, task_name, task_description, task_id)
                     # Execute callback after validation
                     self._execute_callback_and_display(original_prompt, validated_response, time.time() - start_time, task_name, task_description, task_id)
                     return validated_response
@@ -1426,7 +1426,7 @@ def chat(self, prompt, temperature=0.2, tools=None, output_json=None, output_pyd
                 if self._using_custom_llm or self._openai_client is None:
                     # For custom LLMs, we need to handle reflection differently
                     # Use non-streaming to get complete JSON response
-                    reflection_response = self._chat_completion(messages, temperature=temperature, tools=None, stream=False, reasoning_steps=False)
+                    reflection_response = self._chat_completion(messages, temperature=temperature, tools=None, stream=False, reasoning_steps=False, task_name=task_name, task_description=task_description, task_id=task_id)
 
                     if not reflection_response or not reflection_response.choices:
                         raise Exception("No response from reflection request")
@@ -1470,7 +1470,7 @@ def __init__(self, data):
                         self.chat_history.append({"role": "assistant", "content": response_text})
                         # Apply guardrail validation after satisfactory reflection
                         try:
-                            validated_response = self._apply_guardrail_with_retry(response_text, original_prompt, temperature, tools)
+                            validated_response = self._apply_guardrail_with_retry(response_text, original_prompt, temperature, tools, task_name, task_description, task_id)
                             # Execute callback after validation
                             self._execute_callback_and_display(original_prompt, validated_response, time.time() - start_time, task_name, task_description, task_id)
                             return validated_response
@@ -1488,7 +1488,7 @@ def __init__(self, data):
                         self.chat_history.append({"role": "assistant", "content": response_text})
                         # Apply guardrail validation after max reflections
                         try:
-                            validated_response = self._apply_guardrail_with_retry(response_text, original_prompt, temperature, tools)
+                            validated_response = self._apply_guardrail_with_retry(response_text, original_prompt, temperature, tools, task_name, task_description, task_id)
                             # Execute callback after validation
                             self._execute_callback_and_display(original_prompt, validated_response, time.time() - start_time, task_name, task_description, task_id)
                             return validated_response
@@ -1503,7 +1503,7 @@ def __init__(self, data):
                     messages.append({"role": "user", "content": "Now regenerate your response using the reflection you made"})
                     # For custom LLMs during reflection, always use non-streaming to ensure complete responses
                     use_stream = self.stream if not self._using_custom_llm else False
-                    response = self._chat_completion(messages, temperature=temperature, tools=None, stream=use_stream)
+                    response = self._chat_completion(messages, temperature=temperature, tools=None, stream=use_stream, task_name=task_name, task_description=task_description, task_id=task_id)
                     response_text = response.choices[0].message.content.strip()
                     reflection_count += 1
                     continue  # Continue the loop for more reflections
@@ -1620,7 +1620,7 @@ async def achat(self, prompt: str, temperature=0.2, tools=None, output_json=None
 
                 # Apply guardrail validation for custom LLM response
                 try:
-                    validated_response = self._apply_guardrail_with_retry(response_text, prompt, temperature, tools)
+                    validated_response = self._apply_guardrail_with_retry(response_text, prompt, temperature, tools, task_name, task_description, task_id)
                     # Execute callback after validation
                     self._execute_callback_and_display(normalized_content, validated_response, time.time() - start_time, task_name, task_description, task_id)
                     return validated_response
@@ -1810,7 +1810,7 @@ async def achat(self, prompt: str, temperature=0.2, tools=None, output_json=None
 
             # Apply guardrail validation for OpenAI client response
             try:
-                validated_response = self._apply_guardrail_with_retry(response_text, original_prompt, temperature, tools)
+                validated_response = self._apply_guardrail_with_retry(response_text, original_prompt, temperature, tools, task_name, task_description, task_id)
                 # Execute callback after validation
                 self._execute_callback_and_display(original_prompt, validated_response, time.time() - start_time, task_name, task_description, task_id)
                 return validated_response
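
The _apply_guardrail_with_retry hunks matter because a regenerated response must report the same task context as the first attempt. A hedged sketch of a retry loop that keeps metadata attached across regenerations, where validate and complete are illustrative stand-ins rather than the methods above:

# Sketch: a guardrail retry loop that preserves task metadata across
# regenerations. validate/complete are toy stand-ins, not PraisonAI code.

def complete(prompt, task_name=None, task_description=None, task_id=None):
    return f"response to {prompt!r}"  # stand-in for the LLM call

def validate(text):
    return len(text) < 200  # toy guardrail

def apply_guardrail_with_retry(text, prompt, retries=3, **task_meta):
    for _ in range(retries):
        if validate(text):
            return text
        # The regeneration must forward task_meta; dropping it here is
        # exactly the class of bug this commit fixes.
        text = complete(f"{prompt}\n\nPlease provide an improved response.", **task_meta)
    raise ValueError("guardrail validation failed after retries")

# First attempt fails the toy guardrail, the retry passes:
print(apply_guardrail_with_retry("x" * 300, "hello", task_name="demo", task_id="t-1"))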
