Skip to content

Commit 8d96626

Browse files
fix: iscoroutine removed from hot path +50 RPS (#14649)
* fix: iscoroutine removed from hot path * fix: replace all instances & separate concerns 1. Replaced all instances of iscoroutine with is_async_callable 2. Place the coroutine checker in its own file * fix: PR comment changes * fix: missing config setting declaration * fix: revert non-performance related changes * fix: revert to initial implementation * fix: remove dead const
1 parent 00d8ded commit 8d96626

File tree

7 files changed

+246
-14
lines changed

7 files changed

+246
-14
lines changed

docs/my-website/docs/proxy/config_settings.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -772,3 +772,4 @@ router_settings:
772772
| WEBHOOK_URL | URL for receiving webhooks from external services
773773
| SPEND_LOG_RUN_LOOPS | Constant for setting how many runs of 1000 batch deletes should spend_log_cleanup task run |
774774
| SPEND_LOG_CLEANUP_BATCH_SIZE | Number of logs deleted per batch during cleanup. Default is 1000 |
775+
| COROUTINE_CHECKER_MAX_SIZE_IN_MEMORY | Maximum number of entries kept in the CoroutineChecker in-memory cache; the cache is cleared once this size is reached. Default is 1000 |

litellm/caching/redis_cache.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import litellm
2020
from litellm._logging import print_verbose, verbose_logger
2121
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
22+
from litellm.litellm_core_utils.coroutine_checker import coroutine_checker
2223
from litellm.types.caching import RedisPipelineIncrementOperation
2324
from litellm.types.services import ServiceTypes
2425

@@ -138,7 +139,7 @@ def __init__(
138139
self.redis_flush_size = redis_flush_size
139140
self.redis_version = "Unknown"
140141
try:
141-
if not inspect.iscoroutinefunction(self.redis_client):
142+
if not coroutine_checker.is_async_callable(self.redis_client):
142143
self.redis_version = self.redis_client.info()["redis_version"] # type: ignore
143144
except Exception:
144145
pass

litellm/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,3 +1063,6 @@
10631063
"SMTP_SENDER_EMAIL",
10641064
"TEST_EMAIL_ADDRESS",
10651065
]
1066+
1067+
# CoroutineChecker cache configuration
1068+
COROUTINE_CHECKER_MAX_SIZE_IN_MEMORY = int(os.getenv("COROUTINE_CHECKER_MAX_SIZE_IN_MEMORY", 1000))
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# CoroutineChecker utility: cached check for whether a function or callable object is a coroutine function (it does not detect coroutine objects)
2+
3+
import inspect
4+
from typing import Any
5+
from weakref import WeakKeyDictionary
6+
from litellm.constants import (
7+
COROUTINE_CHECKER_MAX_SIZE_IN_MEMORY,
8+
)
9+
10+
11+
class CoroutineChecker:
    """Memoised "is this callback async?" lookup.

    Answers are remembered in a ``WeakKeyDictionary`` keyed by the callback,
    so an entry vanishes once its callback is garbage-collected. The cache is
    also capped: when it already holds ``COROUTINE_CHECKER_MAX_SIZE_IN_MEMORY``
    entries it is emptied before the next result is stored.
    """

    def __init__(self):
        self._max_size = COROUTINE_CHECKER_MAX_SIZE_IN_MEMORY
        self._cache = WeakKeyDictionary()

    def is_async_callable(self, callback: Any) -> bool:
        """Report whether ``callback`` is a coroutine function or async callable.

        Plain functions and bound methods are inspected directly. Any other
        object is inspected through its ``__call__`` attribute when it has
        one, so an instance of a class defining ``async def __call__`` counts
        as async.

        NOTE: because of that redirection, a ``functools.partial`` wrapping an
        async function is reported as ``False`` (its ``__call__`` is not a
        coroutine function).

        Failures while reading the cache, resolving ``__call__``, inspecting
        the target, or storing the result are swallowed; an inspection failure
        yields ``False``.

        Args:
            callback: Any object; it does not have to be callable, hashable
                or weak-referenceable.

        Returns:
            ``True`` if the resolved target is a coroutine function,
            otherwise ``False``.
        """
        # Step 1: serve a previous answer. Keys that cannot be hashed or
        # weak-referenced make the lookup raise; treat that as a miss.
        try:
            remembered = self._cache.get(callback)
        except Exception:
            remembered = None
        if remembered is not None:
            # ``False`` is a legitimate cached answer, hence the explicit
            # comparison against None instead of a truthiness test.
            return remembered

        # Step 2: choose what to inspect. Functions and methods are inspected
        # as-is; everything else is swapped for its ``__call__`` if present.
        probe = callback
        if not (inspect.isfunction(callback) or inspect.ismethod(callback)):
            try:
                dunder_call = getattr(callback, "__call__", None)
            except Exception:
                # A custom __getattr__ may raise anything; keep the object.
                dunder_call = None
            if dunder_call is not None:
                probe = dunder_call

        # Step 3: the actual (comparatively slow) inspection.
        try:
            is_async = inspect.iscoroutinefunction(probe)
        except Exception:
            is_async = False

        # Step 4: remember the answer, wiping the cache first if it is full.
        try:
            if len(self._cache) >= self._max_size:
                self._cache.clear()
            self._cache[callback] = is_async
        except Exception:
            # Not cacheable (unhashable / not weak-referenceable): skip.
            pass

        return is_async
61+
62+
# Global instance for backward compatibility and convenience
63+
coroutine_checker = CoroutineChecker()

litellm/router.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
from litellm.integrations.custom_logger import CustomLogger
5656
from litellm.litellm_core_utils.asyncify import run_async_function
5757
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
58+
from litellm.litellm_core_utils.coroutine_checker import coroutine_checker
5859
from litellm.litellm_core_utils.credential_accessor import CredentialAccessor
5960
from litellm.litellm_core_utils.dd_tracing import tracer
6061
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
@@ -4049,7 +4050,7 @@ async def async_function_with_retries(self, *args, **kwargs): # noqa: PLR0915
40494050
try:
40504051
# if the function call is successful, no exception will be raised and we'll break out of the loop
40514052
response = await self.make_call(original_function, *args, **kwargs)
4052-
if inspect.iscoroutinefunction(
4053+
if coroutine_checker.is_async_callable(
40534054
response
40544055
): # async errors are often returned as coroutines
40554056
response = await response
@@ -4097,7 +4098,7 @@ async def make_call(self, original_function: Any, *args, **kwargs):
40974098
"""
40984099
model_group = kwargs.get("model")
40994100
response = original_function(*args, **kwargs)
4100-
if inspect.iscoroutinefunction(response) or inspect.isawaitable(response):
4101+
if coroutine_checker.is_async_callable(response) or inspect.isawaitable(response):
41014102
response = await response
41024103
## PROCESS RESPONSE HEADERS
41034104
response = await self.set_response_headers(

litellm/utils.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,12 @@ def get_dynamic_callbacks(
520520
return returned_callbacks
521521

522522

523+
524+
from litellm.litellm_core_utils.coroutine_checker import coroutine_checker
525+
526+
527+
528+
523529
def function_setup( # noqa: PLR0915
524530
original_function: str, rules_obj, start_time, *args, **kwargs
525531
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
@@ -594,7 +600,7 @@ def function_setup( # noqa: PLR0915
594600
if len(litellm.input_callback) > 0:
595601
removed_async_items = []
596602
for index, callback in enumerate(litellm.input_callback): # type: ignore
597-
if inspect.iscoroutinefunction(callback):
603+
if coroutine_checker.is_async_callable(callback):
598604
litellm._async_input_callback.append(callback)
599605
removed_async_items.append(index)
600606

@@ -604,7 +610,7 @@ def function_setup( # noqa: PLR0915
604610
if len(litellm.success_callback) > 0:
605611
removed_async_items = []
606612
for index, callback in enumerate(litellm.success_callback): # type: ignore
607-
if inspect.iscoroutinefunction(callback):
613+
if coroutine_checker.is_async_callable(callback):
608614
litellm.logging_callback_manager.add_litellm_async_success_callback(
609615
callback
610616
)
@@ -629,7 +635,7 @@ def function_setup( # noqa: PLR0915
629635
if len(litellm.failure_callback) > 0:
630636
removed_async_items = []
631637
for index, callback in enumerate(litellm.failure_callback): # type: ignore
632-
if inspect.iscoroutinefunction(callback):
638+
if coroutine_checker.is_async_callable(callback):
633639
litellm.logging_callback_manager.add_litellm_async_failure_callback(
634640
callback
635641
)
@@ -662,7 +668,7 @@ def function_setup( # noqa: PLR0915
662668
removed_async_items = []
663669
for index, callback in enumerate(kwargs["success_callback"]):
664670
if (
665-
inspect.iscoroutinefunction(callback)
671+
coroutine_checker.is_async_callable(callback)
666672
or callback == "dynamodb"
667673
or callback == "s3"
668674
):
@@ -899,12 +905,7 @@ def client(original_function): # noqa: PLR0915
899905
rules_obj = Rules()
900906

901907
def check_coroutine(value) -> bool:
    """Return True if ``value`` is a coroutine object or an async callable.

    Args:
        value: Any object, typically a function or the result of calling one.

    Returns:
        ``True`` when ``value`` is a coroutine object, a coroutine function,
        or an object whose ``__call__`` is a coroutine function; otherwise
        ``False``.
    """
    # Coroutine *objects* are checked directly and first: the cached checker
    # only inspects functions/callables, so on its own it reports False for
    # an already-created coroutine, which this helper previously accepted.
    if inspect.iscoroutine(value):
        return True
    return coroutine_checker.is_async_callable(value)
908909

909910
async def async_pre_call_deployment_hook(kwargs: Dict[str, Any], call_type: str):
910911
"""
@@ -1598,7 +1599,7 @@ async def wrapper_async(*args, **kwargs): # noqa: PLR0915
15981599
setattr(e, "timeout", timeout)
15991600
raise e
16001601

1601-
is_coroutine = inspect.iscoroutinefunction(original_function)
1602+
is_coroutine = coroutine_checker.is_async_callable(original_function)
16021603

16031604
# Return the appropriate wrapper based on the original function type
16041605
if is_coroutine:
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
"""
2+
Unit tests for CoroutineChecker utility class.
3+
4+
Focused test suite covering core functionality and main edge cases.
5+
"""
6+
7+
import pytest
8+
from unittest.mock import patch
9+
10+
from litellm.litellm_core_utils.coroutine_checker import CoroutineChecker, coroutine_checker
11+
12+
13+
class TestCoroutineChecker:
    """Test cases for CoroutineChecker class."""

    def setup_method(self):
        """Give every test its own checker so cache state never leaks."""
        self.checker = CoroutineChecker()

    def test_init(self):
        """Test CoroutineChecker initialization."""
        checker = CoroutineChecker()
        assert isinstance(checker, CoroutineChecker)

    @pytest.mark.parametrize("obj,expected,description", [
        # Basic function types
        (lambda: "sync", False, "sync lambda"),
        (len, False, "built-in function"),
        # Non-callable objects
        ("string", False, "string"),
        (123, False, "integer"),
        ([], False, "list"),
        ({}, False, "dict"),
        (None, False, "None"),
    ])
    def test_is_async_callable_basic_and_non_callable(self, obj, expected, description):
        """Test is_async_callable with basic types and non-callable objects."""
        assert self.checker.is_async_callable(obj) is expected, f"Failed for {description}: {obj}"

    def test_is_async_callable_async_and_sync_callables(self):
        """Test is_async_callable with various async and sync callable types."""
        # Async and sync functions
        async def async_func():
            return "async"

        def sync_func():
            return "sync"

        # Class methods
        class TestClass:
            def sync_method(self):
                return "sync"

            async def async_method(self):
                return "async"

        obj = TestClass()

        # Callable objects
        class SyncCallable:
            def __call__(self):
                return "sync"

        class AsyncCallable:
            async def __call__(self):
                return "async"

        # Test all async callables
        assert self.checker.is_async_callable(async_func) is True
        assert self.checker.is_async_callable(obj.async_method) is True
        assert self.checker.is_async_callable(AsyncCallable()) is True

        # Test all sync callables
        assert self.checker.is_async_callable(sync_func) is False
        assert self.checker.is_async_callable(obj.sync_method) is False
        assert self.checker.is_async_callable(SyncCallable()) is False

    def test_is_async_callable_caching(self):
        """Test that is_async_callable stores results and serves them from the cache."""
        async def async_func():
            return "async"

        # First call computes and stores the result
        result1 = self.checker.is_async_callable(async_func)
        assert result1 is True

        # Test that callable objects are cached
        assert async_func in self.checker._cache
        assert self.checker._cache[async_func] is True

        # The second call must come from the cache: with inspection broken,
        # an uncached lookup would fall back to False.
        with patch('inspect.iscoroutinefunction', side_effect=Exception("Inspect error")):
            result2 = self.checker.is_async_callable(async_func)
        assert result2 is True

    def test_cached_false_result_is_served_from_cache(self):
        """Test that a cached False is returned instead of being recomputed."""
        def sync_func():
            return "sync"

        assert self.checker.is_async_callable(sync_func) is False
        assert self.checker._cache[sync_func] is False

        # If the cache were bypassed, the patched inspection would say True.
        with patch('inspect.iscoroutinefunction', return_value=True):
            assert self.checker.is_async_callable(sync_func) is False

    def test_cache_is_cleared_when_full(self):
        """Test that the cache is emptied once it reaches its size limit."""
        self.checker._max_size = 2

        def first():
            return 1

        def second():
            return 2

        def third():
            return 3

        self.checker.is_async_callable(first)
        self.checker.is_async_callable(second)
        assert len(self.checker._cache) == 2

        # The cache is full, so it is cleared before the new entry is stored.
        self.checker.is_async_callable(third)
        assert list(self.checker._cache.keys()) == [third]

    def test_edge_cases_and_error_handling(self):
        """Test edge cases and error handling."""
        from functools import partial

        # Error handling cases
        class ProblematicCallable:
            def __getattr__(self, name):
                if name == "__call__":
                    raise Exception("Cannot access __call__")
                raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

        class UnstringableCallable:
            def __str__(self):
                raise Exception("Cannot convert to string")

            async def __call__(self):
                return "async"

        # Generator functions
        def sync_generator():
            yield "sync"

        async def async_generator():
            yield "async"

        # Partial functions
        def sync_func(x, y):
            return x + y

        async def async_func(x, y):
            return x + y

        sync_partial = partial(sync_func, 1)
        async_partial = partial(async_func, 1)

        # Test error handling
        assert self.checker.is_async_callable(ProblematicCallable()) is False
        assert self.checker.is_async_callable(UnstringableCallable()) is True

        # Test generators (both sync and async generators are not coroutine functions)
        assert self.checker.is_async_callable(sync_generator) is False
        assert self.checker.is_async_callable(async_generator) is False

        # Test partial functions: the checker inspects a partial through its
        # __call__ attribute, which is not a coroutine function, so even a
        # partial wrapping an async function is reported as False.
        assert self.checker.is_async_callable(sync_partial) is False
        assert self.checker.is_async_callable(async_partial) is False

    def test_error_handling_in_inspect(self):
        """Test error handling when inspect.iscoroutinefunction raises exception."""
        with patch('inspect.iscoroutinefunction', side_effect=Exception("Inspect error")):
            async def async_func():
                return "async"

            # Should return False when inspect raises exception
            assert self.checker.is_async_callable(async_func) is False

    def test_global_coroutine_checker_instance(self):
        """Test the global coroutine_checker instance."""
        assert isinstance(coroutine_checker, CoroutineChecker)

        async def async_func():
            return "async"

        def sync_func():
            return "sync"

        assert coroutine_checker.is_async_callable(async_func) is True
        assert coroutine_checker.is_async_callable(sync_func) is False

0 commit comments

Comments
 (0)