Skip to content

Commit c59173a

Browse files
fix: ensure async methods are executable for annotations
1 parent 4d8eec9 commit c59173a

File tree

4 files changed

+180
-19
lines changed

4 files changed

+180
-19
lines changed

lib/crewai/src/crewai/project/annotations.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
from __future__ import annotations
44

5+
import asyncio
56
from collections.abc import Callable
67
from functools import wraps
8+
import inspect
79
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload
810

911
from crewai.project.utils import memoize
@@ -156,6 +158,23 @@ def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]:
156158
return CacheHandlerMethod(memoize(meth))
157159

158160

161+
def _call_method(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
162+
"""Call a method, awaiting it if async and running in an event loop."""
163+
result = method(*args, **kwargs)
164+
if inspect.iscoroutine(result):
165+
try:
166+
loop = asyncio.get_running_loop()
167+
except RuntimeError:
168+
loop = None
169+
if loop and loop.is_running():
170+
import concurrent.futures
171+
172+
with concurrent.futures.ThreadPoolExecutor() as pool:
173+
return pool.submit(asyncio.run, result).result()
174+
return asyncio.run(result)
175+
return result
176+
177+
159178
@overload
160179
def crew(
161180
meth: Callable[Concatenate[SelfT, P], Crew],
@@ -198,7 +217,7 @@ def wrapper(self: CrewInstance, *args: Any, **kwargs: Any) -> Crew:
198217

199218
# Instantiate tasks in order
200219
for _, task_method in tasks:
201-
task_instance = task_method(self)
220+
task_instance = _call_method(task_method, self)
202221
instantiated_tasks.append(task_instance)
203222
agent_instance = getattr(task_instance, "agent", None)
204223
if agent_instance and agent_instance.role not in agent_roles:
@@ -207,15 +226,15 @@ def wrapper(self: CrewInstance, *args: Any, **kwargs: Any) -> Crew:
207226

208227
# Instantiate agents not included by tasks
209228
for _, agent_method in agents:
210-
agent_instance = agent_method(self)
229+
agent_instance = _call_method(agent_method, self)
211230
if agent_instance.role not in agent_roles:
212231
instantiated_agents.append(agent_instance)
213232
agent_roles.add(agent_instance.role)
214233

215234
self.agents = instantiated_agents
216235
self.tasks = instantiated_tasks
217236

218-
crew_instance = meth(self, *args, **kwargs)
237+
crew_instance: Crew = _call_method(meth, self, *args, **kwargs)
219238

220239
def callback_wrapper(
221240
hook: Callable[Concatenate[CrewInstance, P2], R2], instance: CrewInstance

lib/crewai/src/crewai/project/utils.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""Utility functions for the crewai project module."""
22

3-
from collections.abc import Callable
3+
from collections.abc import Callable, Coroutine
44
from functools import wraps
5+
import inspect
56
from typing import Any, ParamSpec, TypeVar, cast
67

78
from pydantic import BaseModel
@@ -37,27 +38,25 @@ def _make_hashable(arg: Any) -> Any:
3738
def memoize(meth: Callable[P, R]) -> Callable[P, R]:
3839
"""Memoize a method by caching its results based on arguments.
3940
40-
Handles Pydantic BaseModel instances by converting them to JSON strings
41-
before hashing for cache lookup.
41+
Handles both sync and async methods. Pydantic BaseModel instances are
42+
converted to JSON strings before hashing for cache lookup.
4243
4344
Args:
4445
meth: The method to memoize.
4546
4647
Returns:
4748
A memoized version of the method that caches results.
4849
"""
50+
if inspect.iscoroutinefunction(meth):
51+
return cast(Callable[P, R], _memoize_async(meth))
52+
return _memoize_sync(meth)
4953

50-
@wraps(meth)
51-
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
52-
"""Wrapper that converts arguments to hashable form before caching.
5354

54-
Args:
55-
*args: Positional arguments to the memoized method.
56-
**kwargs: Keyword arguments to the memoized method.
55+
def _memoize_sync(meth: Callable[P, R]) -> Callable[P, R]:
56+
"""Memoize a synchronous method."""
5757

58-
Returns:
59-
The result of the memoized method call.
60-
"""
58+
@wraps(meth)
59+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
6160
hashable_args = tuple(_make_hashable(arg) for arg in args)
6261
hashable_kwargs = tuple(
6362
sorted((k, _make_hashable(v)) for k, v in kwargs.items())
@@ -73,3 +72,27 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
7372
return result
7473

7574
return cast(Callable[P, R], wrapper)
75+
76+
77+
def _memoize_async(
78+
meth: Callable[P, Coroutine[Any, Any, R]],
79+
) -> Callable[P, Coroutine[Any, Any, R]]:
80+
"""Memoize an async method."""
81+
82+
@wraps(meth)
83+
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
84+
hashable_args = tuple(_make_hashable(arg) for arg in args)
85+
hashable_kwargs = tuple(
86+
sorted((k, _make_hashable(v)) for k, v in kwargs.items())
87+
)
88+
cache_key = str((hashable_args, hashable_kwargs))
89+
90+
cached_result: R | None = cache.read(tool=meth.__name__, input=cache_key)
91+
if cached_result is not None:
92+
return cached_result
93+
94+
result = await meth(*args, **kwargs)
95+
cache.add(tool=meth.__name__, input=cache_key, output=result)
96+
return result
97+
98+
return wrapper

lib/crewai/src/crewai/project/wrappers.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
from __future__ import annotations
44

5+
import asyncio
56
from collections.abc import Callable
67
from functools import partial
8+
import inspect
79
from pathlib import Path
810
from typing import (
911
TYPE_CHECKING,
@@ -132,6 +134,22 @@ class CrewClass(Protocol):
132134
crew: Callable[..., Crew]
133135

134136

137+
def _resolve_result(result: Any) -> Any:
138+
"""Resolve a potentially async result to its value."""
139+
if inspect.iscoroutine(result):
140+
try:
141+
loop = asyncio.get_running_loop()
142+
except RuntimeError:
143+
loop = None
144+
if loop and loop.is_running():
145+
import concurrent.futures
146+
147+
with concurrent.futures.ThreadPoolExecutor() as pool:
148+
return pool.submit(asyncio.run, result).result()
149+
return asyncio.run(result)
150+
return result
151+
152+
135153
class DecoratedMethod(Generic[P, R]):
136154
"""Base wrapper for methods with decorator metadata.
137155
@@ -162,7 +180,12 @@ def __get__(
162180
"""
163181
if obj is None:
164182
return self
165-
bound = partial(self._meth, obj)
183+
inner = partial(self._meth, obj)
184+
185+
def _bound(*args: Any, **kwargs: Any) -> R:
186+
result: R = _resolve_result(inner(*args, **kwargs)) # type: ignore[call-arg]
187+
return result
188+
166189
for attr in (
167190
"is_agent",
168191
"is_llm",
@@ -174,8 +197,8 @@ def __get__(
174197
"is_crew",
175198
):
176199
if hasattr(self, attr):
177-
setattr(bound, attr, getattr(self, attr))
178-
return bound
200+
setattr(_bound, attr, getattr(self, attr))
201+
return _bound
179202

180203
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
181204
"""Call the wrapped method.
@@ -236,6 +259,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> TaskResultT:
236259
The task result with name ensured.
237260
"""
238261
result = self._task_method.unwrap()(self._obj, *args, **kwargs)
262+
result = _resolve_result(result)
239263
return self._task_method.ensure_task_name(result)
240264

241265

@@ -292,7 +316,9 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> TaskResultT:
292316
Returns:
293317
The task instance with name set if not already provided.
294318
"""
295-
return self.ensure_task_name(self._meth(*args, **kwargs))
319+
result = self._meth(*args, **kwargs)
320+
result = _resolve_result(result)
321+
return self.ensure_task_name(result)
296322

297323
def unwrap(self) -> Callable[P, TaskResultT]:
298324
"""Get the original unwrapped method.

lib/crewai/tests/test_project.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,99 @@ def another_simple_tool():
272272
return "Hi!"
273273

274274

275+
class TestAsyncDecoratorSupport:
276+
"""Tests for async method support in @agent, @task decorators."""
277+
278+
def test_async_agent_memoization(self):
279+
"""Async agent methods should be properly memoized."""
280+
281+
class AsyncAgentCrew:
282+
call_count = 0
283+
284+
@agent
285+
async def async_agent(self):
286+
AsyncAgentCrew.call_count += 1
287+
return Agent(
288+
role="Async Agent", goal="Async Goal", backstory="Async Backstory"
289+
)
290+
291+
crew = AsyncAgentCrew()
292+
first_call = crew.async_agent()
293+
second_call = crew.async_agent()
294+
295+
assert first_call is second_call, "Async agent memoization failed"
296+
assert AsyncAgentCrew.call_count == 1, "Async agent called more than once"
297+
298+
def test_async_task_memoization(self):
299+
"""Async task methods should be properly memoized."""
300+
301+
class AsyncTaskCrew:
302+
call_count = 0
303+
304+
@task
305+
async def async_task(self):
306+
AsyncTaskCrew.call_count += 1
307+
return Task(
308+
description="Async Description", expected_output="Async Output"
309+
)
310+
311+
crew = AsyncTaskCrew()
312+
first_call = crew.async_task()
313+
second_call = crew.async_task()
314+
315+
assert first_call is second_call, "Async task memoization failed"
316+
assert AsyncTaskCrew.call_count == 1, "Async task called more than once"
317+
318+
def test_async_task_name_inference(self):
319+
"""Async task should have name inferred from method name."""
320+
321+
class AsyncTaskNameCrew:
322+
@task
323+
async def my_async_task(self):
324+
return Task(
325+
description="Async Description", expected_output="Async Output"
326+
)
327+
328+
crew = AsyncTaskNameCrew()
329+
task_instance = crew.my_async_task()
330+
331+
assert task_instance.name == "my_async_task", (
332+
"Async task name not inferred correctly"
333+
)
334+
335+
def test_async_agent_returns_agent_not_coroutine(self):
336+
"""Async agent decorator should return Agent, not coroutine."""
337+
338+
class AsyncAgentTypeCrew:
339+
@agent
340+
async def typed_async_agent(self):
341+
return Agent(
342+
role="Typed Agent", goal="Typed Goal", backstory="Typed Backstory"
343+
)
344+
345+
crew = AsyncAgentTypeCrew()
346+
result = crew.typed_async_agent()
347+
348+
assert isinstance(result, Agent), (
349+
f"Expected Agent, got {type(result).__name__}"
350+
)
351+
352+
def test_async_task_returns_task_not_coroutine(self):
353+
"""Async task decorator should return Task, not coroutine."""
354+
355+
class AsyncTaskTypeCrew:
356+
@task
357+
async def typed_async_task(self):
358+
return Task(
359+
description="Typed Description", expected_output="Typed Output"
360+
)
361+
362+
crew = AsyncTaskTypeCrew()
363+
result = crew.typed_async_task()
364+
365+
assert isinstance(result, Task), f"Expected Task, got {type(result).__name__}"
366+
367+
275368
def test_internal_crew_with_mcp():
276369
from crewai_tools.adapters.tool_collection import ToolCollection
277370

0 commit comments

Comments
 (0)