Skip to content

Commit 54710a8

Browse files
fix: hash callback args correctly to ensure caching works
1 parent 5abf976 commit 54710a8

File tree

3 files changed

+161
-8
lines changed

3 files changed

+161
-8
lines changed
Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,75 @@
11
"""Utility functions for the crewai project module."""
22

33
from collections.abc import Callable
4-
from functools import lru_cache
5-
from typing import ParamSpec, TypeVar, cast
4+
from functools import wraps
5+
from typing import Any, ParamSpec, TypeVar, cast
6+
7+
from pydantic import BaseModel
8+
9+
from crewai.agents.cache.cache_handler import CacheHandler
610

711

812
P = ParamSpec("P")
913
R = TypeVar("R")
14+
cache = CacheHandler()
15+
16+
17+
def _make_hashable(arg: Any) -> Any:
18+
"""Convert argument to hashable form for caching.
19+
20+
Args:
21+
arg: The argument to convert.
22+
23+
Returns:
24+
Hashable representation of the argument.
25+
"""
26+
if isinstance(arg, BaseModel):
27+
return arg.model_dump_json()
28+
if isinstance(arg, dict):
29+
return tuple(sorted((k, _make_hashable(v)) for k, v in arg.items()))
30+
if isinstance(arg, list):
31+
return tuple(_make_hashable(item) for item in arg)
32+
if hasattr(arg, "__dict__"):
33+
return ("__instance__", id(arg))
34+
return arg
1035

1136

1237
def memoize(meth: Callable[P, R]) -> Callable[P, R]:
1338
"""Memoize a method by caching its results based on arguments.
1439
40+
Handles Pydantic BaseModel instances by converting them to JSON strings
41+
before hashing for cache lookup.
42+
1543
Args:
1644
meth: The method to memoize.
1745
1846
Returns:
1947
A memoized version of the method that caches results.
2048
"""
21-
return cast(Callable[P, R], lru_cache(typed=True)(meth))
49+
50+
@wraps(meth)
51+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
52+
"""Wrapper that converts arguments to hashable form before caching.
53+
54+
Args:
55+
*args: Positional arguments to the memoized method.
56+
**kwargs: Keyword arguments to the memoized method.
57+
58+
Returns:
59+
The result of the memoized method call.
60+
"""
61+
hashable_args = tuple(_make_hashable(arg) for arg in args)
62+
hashable_kwargs = tuple(
63+
sorted((k, _make_hashable(v)) for k, v in kwargs.items())
64+
)
65+
cache_key = str((hashable_args, hashable_kwargs))
66+
67+
cached_result: R | None = cache.read(tool=meth.__name__, input=cache_key)
68+
if cached_result is not None:
69+
return cached_result
70+
71+
result = meth(*args, **kwargs)
72+
cache.add(tool=meth.__name__, input=cache_key, output=result)
73+
return result
74+
75+
return cast(Callable[P, R], wrapper)

lib/crewai/tests/experimental/evaluation/test_agent_evaluator.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,18 +62,23 @@ def test_evaluate_current_iteration(self, mock_crew):
6262
agents=mock_crew.agents, evaluators=[GoalAlignmentEvaluator()]
6363
)
6464

65-
task_completed_event = threading.Event()
65+
task_completed_condition = threading.Condition()
66+
task_completed = False
6667

6768
@crewai_event_bus.on(TaskCompletedEvent)
6869
async def on_task_completed(source, event):
6970
# TaskCompletedEvent fires AFTER evaluation results are stored
70-
task_completed_event.set()
71+
nonlocal task_completed
72+
with task_completed_condition:
73+
task_completed = True
74+
task_completed_condition.notify()
7175

7276
mock_crew.kickoff()
7377

74-
assert task_completed_event.wait(timeout=5), (
75-
"Timeout waiting for task completion"
76-
)
78+
with task_completed_condition:
79+
assert task_completed_condition.wait_for(
80+
lambda: task_completed, timeout=5
81+
), "Timeout waiting for task completion"
7782

7883
results = agent_evaluator.get_evaluation_results()
7984

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
"""Test callback decorator with TaskOutput arguments."""
2+
3+
from unittest.mock import MagicMock, patch
4+
5+
from crewai import Agent, Crew, Task
6+
from crewai.project import CrewBase, callback, task
7+
from crewai.tasks.output_format import OutputFormat
8+
from crewai.tasks.task_output import TaskOutput
9+
10+
11+
def test_callback_decorator_with_taskoutput() -> None:
12+
"""Test that @callback decorator works with TaskOutput arguments."""
13+
14+
@CrewBase
15+
class TestCrew:
16+
"""Test crew with callback."""
17+
18+
callback_called = False
19+
callback_output = None
20+
21+
@callback
22+
def task_callback(self, output: TaskOutput) -> None:
23+
"""Test callback that receives TaskOutput."""
24+
self.callback_called = True
25+
self.callback_output = output
26+
27+
@task
28+
def test_task(self) -> Task:
29+
"""Test task with callback."""
30+
return Task(
31+
description="Test task",
32+
expected_output="Test output",
33+
callback=self.task_callback,
34+
)
35+
36+
test_crew = TestCrew()
37+
task_instance = test_crew.test_task()
38+
39+
test_output = TaskOutput(
40+
description="Test task",
41+
agent="Test Agent",
42+
raw="test result",
43+
output_format=OutputFormat.RAW,
44+
)
45+
46+
task_instance.callback(test_output)
47+
48+
assert test_crew.callback_called
49+
assert test_crew.callback_output == test_output
50+
51+
52+
def test_callback_decorator_with_taskoutput_integration() -> None:
53+
"""Integration test for callback with actual task execution."""
54+
55+
@CrewBase
56+
class TestCrew:
57+
"""Test crew with callback integration."""
58+
59+
callback_called = False
60+
received_output: TaskOutput | None = None
61+
62+
@callback
63+
def task_callback(self, output: TaskOutput) -> None:
64+
"""Callback executed after task completion."""
65+
self.callback_called = True
66+
self.received_output = output
67+
68+
@task
69+
def test_task(self) -> Task:
70+
"""Test task."""
71+
return Task(
72+
description="Test task",
73+
expected_output="Test output",
74+
callback=self.task_callback,
75+
)
76+
77+
test_crew = TestCrew()
78+
79+
agent = Agent(
80+
role="Test Agent",
81+
goal="Test goal",
82+
backstory="Test backstory",
83+
)
84+
85+
task_instance = test_crew.test_task()
86+
task_instance.agent = agent
87+
88+
with patch.object(Agent, "execute_task") as mock_execute:
89+
mock_execute.return_value = "test result"
90+
task_instance.execute_sync()
91+
92+
assert test_crew.callback_called
93+
assert test_crew.received_output is not None
94+
assert test_crew.received_output.raw == "test result"

0 commit comments

Comments
 (0)