Commit c456e5c

Lorenze/ensure hooks work with lite agents flows (#3981)
* liteagent support hooks
* wip llm.call hooks work - needs tests for this
* fix tests
* fixed more
* more tool hooks test cassettes
1 parent 633e279 commit c456e5c

17 files changed: +1639 −52 lines changed

lib/crewai/src/crewai/hooks/llm_hooks.py

Lines changed: 66 additions & 19 deletions
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, cast

 from crewai.events.event_listener import event_listener
 from crewai.hooks.types import AfterLLMCallHookType, BeforeLLMCallHookType
@@ -9,51 +9,98 @@

 if TYPE_CHECKING:
     from crewai.agents.crew_agent_executor import CrewAgentExecutor
+    from crewai.lite_agent import LiteAgent
+    from crewai.llms.base_llm import BaseLLM
+    from crewai.utilities.types import LLMMessage


 class LLMCallHookContext:
-    """Context object passed to LLM call hooks with full executor access.
+    """Context object passed to LLM call hooks.

-    Provides hooks with complete access to the executor state, allowing
+    Provides hooks with complete access to the execution state, allowing
     modification of messages, responses, and executor attributes.

+    Supports both executor-based calls (agents in crews/flows) and direct LLM calls.
+
     Attributes:
-        executor: Full reference to the CrewAgentExecutor instance
-        messages: Direct reference to executor.messages (mutable list).
+        executor: Reference to the executor (CrewAgentExecutor/LiteAgent) or None for direct calls
+        messages: Direct reference to messages (mutable list).
             Can be modified in both before_llm_call and after_llm_call hooks.
             Modifications in after_llm_call hooks persist to the next iteration,
             allowing hooks to modify conversation history for subsequent LLM calls.
             IMPORTANT: Modify messages in-place (e.g., append, extend, remove items).
             Do NOT replace the list (e.g., context.messages = []), as this will break
             the executor. Use context.messages.append() or context.messages.extend()
             instead of assignment.
-        agent: Reference to the agent executing the task
-        task: Reference to the task being executed
-        crew: Reference to the crew instance
+        agent: Reference to the agent executing the task (None for direct LLM calls)
+        task: Reference to the task being executed (None for direct LLM calls or LiteAgent)
+        crew: Reference to the crew instance (None for direct LLM calls or LiteAgent)
         llm: Reference to the LLM instance
-        iterations: Current iteration count
+        iterations: Current iteration count (0 for direct LLM calls)
         response: LLM response string (only set for after_llm_call hooks).
             Can be modified by returning a new string from after_llm_call hook.
     """

+    executor: CrewAgentExecutor | LiteAgent | None
+    messages: list[LLMMessage]
+    agent: Any
+    task: Any
+    crew: Any
+    llm: BaseLLM | None | str | Any
+    iterations: int
+    response: str | None
+
     def __init__(
         self,
-        executor: CrewAgentExecutor,
+        executor: CrewAgentExecutor | LiteAgent | None = None,
         response: str | None = None,
+        messages: list[LLMMessage] | None = None,
+        llm: BaseLLM | str | Any | None = None,  # TODO: look into
+        agent: Any | None = None,
+        task: Any | None = None,
+        crew: Any | None = None,
     ) -> None:
-        """Initialize hook context with executor reference.
+        """Initialize hook context with executor reference or direct parameters.

         Args:
-            executor: The CrewAgentExecutor instance
+            executor: The CrewAgentExecutor or LiteAgent instance (None for direct LLM calls)
             response: Optional response string (for after_llm_call hooks)
+            messages: Optional messages list (for direct LLM calls when executor is None)
+            llm: Optional LLM instance (for direct LLM calls when executor is None)
+            agent: Optional agent reference (for direct LLM calls when executor is None)
+            task: Optional task reference (for direct LLM calls when executor is None)
+            crew: Optional crew reference (for direct LLM calls when executor is None)
         """
-        self.executor = executor
-        self.messages = executor.messages
-        self.agent = executor.agent
-        self.task = executor.task
-        self.crew = executor.crew
-        self.llm = executor.llm
-        self.iterations = executor.iterations
+        if executor is not None:
+            # Existing path: extract from executor
+            self.executor = executor
+            self.messages = executor.messages
+            self.llm = executor.llm
+            self.iterations = executor.iterations
+            # Handle CrewAgentExecutor vs LiteAgent differences
+            if hasattr(executor, "agent"):
+                self.agent = executor.agent
+                self.task = cast("CrewAgentExecutor", executor).task
+                self.crew = cast("CrewAgentExecutor", executor).crew
+            else:
+                # LiteAgent case - is the agent itself, doesn't have task/crew
+                self.agent = (
+                    executor.original_agent
+                    if hasattr(executor, "original_agent")
+                    else executor
+                )
+                self.task = None
+                self.crew = None
+        else:
+            # New path: direct LLM call with explicit parameters
+            self.executor = None
+            self.messages = messages or []
+            self.llm = llm
+            self.agent = agent
+            self.task = task
+            self.crew = crew
+            self.iterations = 0

        self.response = response

    def request_human_input(
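
For reference, a hook written against this context object looks like the sketch below — a minimal sketch, assuming hooks end up in the global registry returned by get_before_llm_call_hooks()/get_after_llm_call_hooks() (the registration step itself is not shown in this commit). The return-value semantics (False blocks the call, a returned str replaces the response) come from the hook-invocation code later in this diff; "SECRET" is an illustrative marker.

from crewai.hooks.llm_hooks import LLMCallHookContext


def block_leaky_prompts(context: LLMCallHookContext) -> bool | None:
    """before_llm_call hook: inspect/mutate context.messages in place."""
    for message in context.messages:
        # LLMMessage is treated as a dict here.
        if "SECRET" in str(message.get("content", "")):
            return False  # False blocks the call (llm.call() raises ValueError)
    return None  # None/True lets the call proceed


def trim_response(context: LLMCallHookContext) -> str | None:
    """after_llm_call hook: return a new string to replace context.response."""
    if context.response is not None:
        return context.response.strip()
    return None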

lib/crewai/src/crewai/lite_agent.py

Lines changed: 30 additions & 1 deletion
@@ -38,6 +38,8 @@
 )
 from crewai.events.types.logging_events import AgentLogsExecutionEvent
 from crewai.flow.flow_trackable import FlowTrackable
+from crewai.hooks.llm_hooks import get_after_llm_call_hooks, get_before_llm_call_hooks
+from crewai.hooks.types import AfterLLMCallHookType, BeforeLLMCallHookType
 from crewai.lite_agent_output import LiteAgentOutput
 from crewai.llm import LLM
 from crewai.llms.base_llm import BaseLLM
@@ -155,6 +157,12 @@ class LiteAgent(FlowTrackable, BaseModel):
     _guardrail: GuardrailCallable | None = PrivateAttr(default=None)
     _guardrail_retry_count: int = PrivateAttr(default=0)
     _callbacks: list[TokenCalcHandler] = PrivateAttr(default_factory=list)
+    _before_llm_call_hooks: list[BeforeLLMCallHookType] = PrivateAttr(
+        default_factory=get_before_llm_call_hooks
+    )
+    _after_llm_call_hooks: list[AfterLLMCallHookType] = PrivateAttr(
+        default_factory=get_after_llm_call_hooks
+    )

     @model_validator(mode="after")
     def setup_llm(self) -> Self:
@@ -246,6 +254,26 @@ def _original_role(self) -> str:
         """Return the original role for compatibility with tool interfaces."""
         return self.role

+    @property
+    def before_llm_call_hooks(self) -> list[BeforeLLMCallHookType]:
+        """Get the before_llm_call hooks for this agent."""
+        return self._before_llm_call_hooks
+
+    @property
+    def after_llm_call_hooks(self) -> list[AfterLLMCallHookType]:
+        """Get the after_llm_call hooks for this agent."""
+        return self._after_llm_call_hooks
+
+    @property
+    def messages(self) -> list[LLMMessage]:
+        """Get the messages list for hook context compatibility."""
+        return self._messages
+
+    @property
+    def iterations(self) -> int:
+        """Get the current iteration count for hook context compatibility."""
+        return self._iterations
+
     def kickoff(
         self,
         messages: str | list[LLMMessage],
@@ -504,7 +532,7 @@ def _invoke_loop(self) -> AgentFinish:
             AgentFinish: The final result of the agent execution.
         """
         # Execute the agent loop
-        formatted_answer = None
+        formatted_answer: AgentAction | AgentFinish | None = None
         while not isinstance(formatted_answer, AgentFinish):
             try:
                 if has_reached_max_iterations(self._iterations, self.max_iterations):
@@ -526,6 +554,7 @@
                     callbacks=self._callbacks,
                     printer=self._printer,
                     from_agent=self,
+                    executor_context=self,
                 )

             except Exception as e:
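
The properties above make a LiteAgent satisfy the executor surface LLMCallHookContext reads (messages, iterations, llm, agent resolution), so globally registered hooks now fire during kickoff(). A rough usage sketch; the role/goal/backstory/model values are illustrative, not from this diff:

from crewai.lite_agent import LiteAgent

agent = LiteAgent(
    role="Researcher",
    goal="Summarize the input",
    backstory="A lightweight agent used outside a crew",
    llm="gpt-4o-mini",  # illustrative model name
)

# The PrivateAttr defaults snapshot the global hook registry at construction.
print(len(agent.before_llm_call_hooks), len(agent.after_llm_call_hooks))

# kickoff() now passes executor_context=self into the LLM call path, so hooks
# receive this LiteAgent as context.executor (task and crew are None).
result = agent.kickoff("What does this commit change?")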

lib/crewai/src/crewai/llm.py

Lines changed: 20 additions & 9 deletions
@@ -1642,6 +1642,10 @@ def call(
             if message.get("role") == "system":
                 msg_role: Literal["assistant"] = "assistant"
                 message["role"] = msg_role
+
+            if not self._invoke_before_llm_call_hooks(messages, from_agent):
+                raise ValueError("LLM call blocked by before_llm_call hook")
+
             # --- 5) Set up callbacks if provided
             with suppress_warnings():
                 if callbacks and len(callbacks) > 0:
@@ -1651,7 +1655,16 @@
             params = self._prepare_completion_params(messages, tools)
             # --- 7) Make the completion call and handle response
             if self.stream:
-                return self._handle_streaming_response(
+                result = self._handle_streaming_response(
+                    params=params,
+                    callbacks=callbacks,
+                    available_functions=available_functions,
+                    from_task=from_task,
+                    from_agent=from_agent,
+                    response_model=response_model,
+                )
+            else:
+                result = self._handle_non_streaming_response(
                     params=params,
                     callbacks=callbacks,
                     available_functions=available_functions,
@@ -1660,14 +1673,12 @@
                     from_task=from_task,
                     from_agent=from_agent,
                     response_model=response_model,
                 )

-            return self._handle_non_streaming_response(
-                params=params,
-                callbacks=callbacks,
-                available_functions=available_functions,
-                from_task=from_task,
-                from_agent=from_agent,
-                response_model=response_model,
-            )
+            if isinstance(result, str):
+                result = self._invoke_after_llm_call_hooks(
+                    messages, result, from_agent
+                )
+
+            return result
         except LLMContextLengthExceededError:
             # Re-raise LLMContextLengthExceededError as it should be handled
             # by the CrewAgentExecutor._invoke_loop method, which can then decide
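
On the direct-call path this means llm.call() itself is now hook-aware: a before hook returning False surfaces as the ValueError added above, and string results pass through the after hooks. A minimal sketch, assuming the usual string-or-messages call() signature and an illustrative model name:

from crewai.llm import LLM

llm = LLM(model="gpt-4o-mini")  # illustrative model name

try:
    answer = llm.call("Reply with the single word: ready")
except ValueError:
    # Raised when a before_llm_call hook returns False for this call.
    answer = None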

lib/crewai/src/crewai/llms/base_llm.py

Lines changed: 132 additions & 1 deletion
@@ -314,7 +314,7 @@ def _emit_call_completed_event(
         call_type: LLMCallType,
         from_task: Task | None = None,
         from_agent: Agent | None = None,
-        messages: str | list[dict[str, Any]] | None = None,
+        messages: str | list[LLMMessage] | None = None,
     ) -> None:
         """Emit LLM call completed event."""
         crewai_event_bus.emit(
@@ -586,3 +586,134 @@ def get_token_usage_summary(self) -> UsageMetrics:
             Dictionary with token usage totals
         """
         return UsageMetrics(**self._token_usage)
+
+    def _invoke_before_llm_call_hooks(
+        self,
+        messages: list[LLMMessage],
+        from_agent: Agent | None = None,
+    ) -> bool:
+        """Invoke before_llm_call hooks for direct LLM calls (no agent context).
+
+        This method should be called by native provider implementations before
+        making the actual LLM call when from_agent is None (direct calls).
+
+        Args:
+            messages: The messages being sent to the LLM
+            from_agent: The agent making the call (None for direct calls)
+
+        Returns:
+            True if LLM call should proceed, False if blocked by hook
+
+        Example:
+            >>> # In a native provider's call() method:
+            >>> if from_agent is None and not self._invoke_before_llm_call_hooks(
+            ...     messages, from_agent
+            ... ):
+            ...     raise ValueError("LLM call blocked by hook")
+        """
+        # Only invoke hooks for direct calls (no agent context)
+        if from_agent is not None:
+            return True
+
+        from crewai.hooks.llm_hooks import (
+            LLMCallHookContext,
+            get_before_llm_call_hooks,
+        )
+        from crewai.utilities.printer import Printer
+
+        before_hooks = get_before_llm_call_hooks()
+        if not before_hooks:
+            return True
+
+        hook_context = LLMCallHookContext(
+            executor=None,
+            messages=messages,
+            llm=self,
+            agent=None,
+            task=None,
+            crew=None,
+        )
+        printer = Printer()
+
+        try:
+            for hook in before_hooks:
+                result = hook(hook_context)
+                if result is False:
+                    printer.print(
+                        content="LLM call blocked by before_llm_call hook",
+                        color="yellow",
+                    )
+                    return False
+        except Exception as e:
+            printer.print(
+                content=f"Error in before_llm_call hook: {e}",
+                color="yellow",
+            )
+
+        return True
+
+    def _invoke_after_llm_call_hooks(
+        self,
+        messages: list[LLMMessage],
+        response: str,
+        from_agent: Agent | None = None,
+    ) -> str:
+        """Invoke after_llm_call hooks for direct LLM calls (no agent context).
+
+        This method should be called by native provider implementations after
+        receiving the LLM response when from_agent is None (direct calls).
+
+        Args:
+            messages: The messages that were sent to the LLM
+            response: The response from the LLM
+            from_agent: The agent that made the call (None for direct calls)
+
+        Returns:
+            The potentially modified response string
+
+        Example:
+            >>> # In a native provider's call() method:
+            >>> if from_agent is None and isinstance(result, str):
+            ...     result = self._invoke_after_llm_call_hooks(
+            ...         messages, result, from_agent
+            ...     )
+        """
+        # Only invoke hooks for direct calls (no agent context)
+        if from_agent is not None or not isinstance(response, str):
+            return response
+
+        from crewai.hooks.llm_hooks import (
+            LLMCallHookContext,
+            get_after_llm_call_hooks,
+        )
+        from crewai.utilities.printer import Printer
+
+        after_hooks = get_after_llm_call_hooks()
+        if not after_hooks:
+            return response
+
+        hook_context = LLMCallHookContext(
+            executor=None,
+            messages=messages,
+            llm=self,
+            agent=None,
+            task=None,
+            crew=None,
+            response=response,
+        )
+        printer = Printer()
+        modified_response = response
+
+        try:
+            for hook in after_hooks:
+                result = hook(hook_context)
+                if result is not None and isinstance(result, str):
+                    modified_response = result
+                    hook_context.response = modified_response
+        except Exception as e:
+            printer.print(
+                content=f"Error in after_llm_call hook: {e}",
+                color="yellow",
+            )
+
+        return modified_response
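
Putting the two helpers together, a native provider built on BaseLLM would wire them in roughly as below, mirroring the docstring examples above; _complete() is a hypothetical provider internal, not part of this diff:

from typing import Any

from crewai.llms.base_llm import BaseLLM
from crewai.utilities.types import LLMMessage


class MyProviderLLM(BaseLLM):  # hypothetical native provider
    def call(self, messages: list[LLMMessage], **kwargs: Any) -> str:
        # Gate the request; the helper only runs hooks when there is no agent context.
        if not self._invoke_before_llm_call_hooks(messages, from_agent=None):
            raise ValueError("LLM call blocked by hook")

        response = self._complete(messages, **kwargs)  # hypothetical provider internal

        # Give after-hooks a chance to inspect or replace the string response.
        if isinstance(response, str):
            response = self._invoke_after_llm_call_hooks(
                messages, response, from_agent=None
            )
        return response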
