Skip to content

Commit 5940942

Browse files
fix: reduced __inits__ overhead in 7% (#14689)
* fix: avoid redundant __init__ calls on hot path Previously, imports on the request hot path caused __init__ to run excessively for every request. This change ensures initialization happens once, reducing cpu overhead. * fix: remove redundant __init__ import The current implementation no longer requires an import at the top of the function. * fix: placed on core utils for future reuse * test: add coverage & remove inline import A general import-checking tool across all endpoints would be a large PR. This commit focuses on a smaller, targeted fix for the discussed case. * added import check to CI
1 parent 4c983f9 commit 5940942

File tree

5 files changed

+116
-14
lines changed

5 files changed

+116
-14
lines changed

.circleci/config.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1458,6 +1458,7 @@ jobs:
14581458
# - run: python ./tests/documentation_tests/test_general_setting_keys.py
14591459
- run: python ./tests/code_coverage_tests/check_licenses.py
14601460
- run: python ./tests/code_coverage_tests/router_code_coverage.py
1461+
- run: python ./tests/code_coverage_tests/test_chat_completion_imports.py
14611462
- run: python ./tests/code_coverage_tests/info_log_check.py
14621463
- run: python ./tests/code_coverage_tests/test_ban_set_verbose.py
14631464
- run: python ./tests/code_coverage_tests/code_qa_check_tests.py
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
"""
2+
Cached imports module for LiteLLM.
3+
4+
This module provides cached import functionality to avoid repeated imports
5+
inside functions that are critical to performance.
6+
"""
7+
8+
from typing import TYPE_CHECKING, Callable, Optional, Type
9+
10+
# Type annotations for cached imports
11+
if TYPE_CHECKING:
12+
from litellm.litellm_core_utils.litellm_logging import Logging
13+
from litellm.litellm_core_utils.coroutine_checker import CoroutineChecker
14+
15+
# Global cache variables
16+
_LiteLLMLogging: Optional[Type["Logging"]] = None
17+
_coroutine_checker: Optional["CoroutineChecker"] = None
18+
_set_callbacks: Optional[Callable] = None
19+
20+
21+
def get_litellm_logging_class() -> Type["Logging"]:
22+
"""Get the cached LiteLLM Logging class, initializing if needed."""
23+
global _LiteLLMLogging
24+
if _LiteLLMLogging is not None:
25+
return _LiteLLMLogging
26+
from litellm.litellm_core_utils.litellm_logging import Logging
27+
_LiteLLMLogging = Logging
28+
return _LiteLLMLogging
29+
30+
31+
def get_coroutine_checker() -> "CoroutineChecker":
32+
"""Get the cached coroutine checker instance, initializing if needed."""
33+
global _coroutine_checker
34+
if _coroutine_checker is not None:
35+
return _coroutine_checker
36+
from litellm.litellm_core_utils.coroutine_checker import coroutine_checker
37+
_coroutine_checker = coroutine_checker
38+
return _coroutine_checker
39+
40+
41+
def get_set_callbacks() -> Callable:
42+
"""Get the cached set_callbacks function, initializing if needed."""
43+
global _set_callbacks
44+
if _set_callbacks is not None:
45+
return _set_callbacks
46+
from litellm.litellm_core_utils.litellm_logging import set_callbacks
47+
_set_callbacks = set_callbacks
48+
return _set_callbacks
49+
50+
51+
def clear_cached_imports() -> None:
52+
"""Clear all cached imports. Useful for testing or memory management."""
53+
global _LiteLLMLogging, _coroutine_checker, _set_callbacks
54+
_LiteLLMLogging = None
55+
_coroutine_checker = None
56+
_set_callbacks = None

litellm/proxy/common_utils/http_parsing_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import re
23
from typing import Any, Dict, List, Optional
34

45
import orjson
@@ -51,8 +52,6 @@ async def _read_request_body(request: Optional[Request]) -> Dict:
5152
body_str = body.decode("utf-8") if isinstance(body, bytes) else body
5253

5354
# Replace invalid surrogate pairs
54-
import re
55-
5655
# This regex finds incomplete surrogate pairs
5756
body_str = re.sub(
5857
r"[\uD800-\uDBFF](?![\uDC00-\uDFFF])", "", body_str

litellm/utils.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@
5959
import litellm.litellm_core_utils.json_validation_rule
6060
import litellm.llms
6161
import litellm.llms.gemini
62+
# Import cached imports utilities
63+
from litellm.litellm_core_utils.cached_imports import (
64+
get_coroutine_checker,
65+
get_litellm_logging_class,
66+
get_set_callbacks,
67+
)
6268
from litellm.caching._internal_lru_cache import lru_cache_wrapper
6369
from litellm.caching.caching import DualCache
6470
from litellm.caching.caching_handler import CachingHandlerResponse, LLMCachingHandler
@@ -222,6 +228,7 @@
222228
get_args,
223229
)
224230

231+
225232
from openai import OpenAIError as OriginalError
226233

227234
from litellm.litellm_core_utils.thread_pool_executor import executor
@@ -521,16 +528,12 @@ def get_dynamic_callbacks(
521528

522529

523530

524-
from litellm.litellm_core_utils.coroutine_checker import coroutine_checker
525531

526532

527533
def function_setup( # noqa: PLR0915
528534
original_function: str, rules_obj, start_time, *args, **kwargs
529535
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
530536
### NOTICES ###
531-
from litellm import Logging as LiteLLMLogging
532-
from litellm.litellm_core_utils.litellm_logging import set_callbacks
533-
534537
if litellm.set_verbose is True:
535538
verbose_logger.warning(
536539
"`litellm.set_verbose` is deprecated. Please set `os.environ['LITELLM_LOG'] = 'DEBUG'` for debug logs."
@@ -593,12 +596,12 @@ def function_setup( # noqa: PLR0915
593596
+ litellm.failure_callback
594597
)
595598
)
596-
set_callbacks(callback_list=callback_list, function_id=function_id)
599+
get_set_callbacks()(callback_list=callback_list, function_id=function_id)
597600
## ASYNC CALLBACKS
598601
if len(litellm.input_callback) > 0:
599602
removed_async_items = []
600603
for index, callback in enumerate(litellm.input_callback): # type: ignore
601-
if coroutine_checker.is_async_callable(callback):
604+
if get_coroutine_checker().is_async_callable(callback):
602605
litellm._async_input_callback.append(callback)
603606
removed_async_items.append(index)
604607

@@ -608,7 +611,7 @@ def function_setup( # noqa: PLR0915
608611
if len(litellm.success_callback) > 0:
609612
removed_async_items = []
610613
for index, callback in enumerate(litellm.success_callback): # type: ignore
611-
if coroutine_checker.is_async_callable(callback):
614+
if get_coroutine_checker().is_async_callable(callback):
612615
litellm.logging_callback_manager.add_litellm_async_success_callback(
613616
callback
614617
)
@@ -633,7 +636,7 @@ def function_setup( # noqa: PLR0915
633636
if len(litellm.failure_callback) > 0:
634637
removed_async_items = []
635638
for index, callback in enumerate(litellm.failure_callback): # type: ignore
636-
if coroutine_checker.is_async_callable(callback):
639+
if get_coroutine_checker().is_async_callable(callback):
637640
litellm.logging_callback_manager.add_litellm_async_failure_callback(
638641
callback
639642
)
@@ -666,7 +669,7 @@ def function_setup( # noqa: PLR0915
666669
removed_async_items = []
667670
for index, callback in enumerate(kwargs["success_callback"]):
668671
if (
669-
coroutine_checker.is_async_callable(callback)
672+
get_coroutine_checker().is_async_callable(callback)
670673
or callback == "dynamodb"
671674
or callback == "s3"
672675
):
@@ -790,7 +793,7 @@ def function_setup( # noqa: PLR0915
790793
call_type=call_type,
791794
):
792795
stream = True
793-
logging_obj = LiteLLMLogging(
796+
logging_obj = get_litellm_logging_class()( # Victim for object pool
794797
model=model, # type: ignore
795798
messages=messages,
796799
stream=stream,
@@ -903,7 +906,7 @@ def client(original_function): # noqa: PLR0915
903906
rules_obj = Rules()
904907

905908
def check_coroutine(value) -> bool:
906-
return coroutine_checker.is_async_callable(value)
909+
return get_coroutine_checker().is_async_callable(value)
907910

908911
async def async_pre_call_deployment_hook(kwargs: Dict[str, Any], call_type: str):
909912
"""
@@ -1597,7 +1600,7 @@ async def wrapper_async(*args, **kwargs): # noqa: PLR0915
15971600
setattr(e, "timeout", timeout)
15981601
raise e
15991602

1600-
is_coroutine = coroutine_checker.is_async_callable(original_function)
1603+
is_coroutine = get_coroutine_checker().is_async_callable(original_function)
16011604

16021605
# Return the appropriate wrapper based on the original function type
16031606
if is_coroutine:
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
## Tests that chat_completion endpoint has no imports inside function bodies
2+
## This is critical for performance optimization in the hot path
3+
4+
import ast
5+
from pathlib import Path
6+
7+
8+
def test_chat_completion_no_imports():
9+
"""Test that chat_completion endpoint has no imports in function bodies."""
10+
# Path to the proxy server file
11+
proxy_server_path = Path(__file__).parent.parent.parent / "litellm" / "proxy" / "proxy_server.py"
12+
13+
with open(proxy_server_path, 'r') as f:
14+
content = f.read()
15+
16+
# Parse the AST
17+
tree = ast.parse(content)
18+
19+
# Find the chat_completion function
20+
chat_completion_func = None
21+
for node in ast.walk(tree):
22+
if (isinstance(node, ast.AsyncFunctionDef) and node.name == "chat_completion"):
23+
chat_completion_func = node
24+
break
25+
26+
assert chat_completion_func is not None, "chat_completion function not found"
27+
28+
# Check for imports inside the function body
29+
import_violations = []
30+
31+
for node in ast.walk(chat_completion_func):
32+
if isinstance(node, (ast.Import, ast.ImportFrom)):
33+
# Get line number
34+
line_num = node.lineno
35+
import_violations.append(line_num)
36+
37+
# Assert no import violations found
38+
if import_violations:
39+
print(f"Found {len(import_violations)} import violations in chat_completion endpoint:")
40+
for line_num in import_violations:
41+
print(f" - Line {line_num}: Import statement found")
42+
print("\nchat_completion endpoint should not contain imports for optimal performance.")
43+
raise Exception("Import violations found in chat_completion endpoint")

0 commit comments

Comments
 (0)