Skip to content

Commit 02db2e8

Browse files
[Performance] RPS Improvement +500 RPS when sending the user field (#14616)
* perf tool
* fix: cache type issue
* fix: exception hanging & cache setting — (1) removed unhandled exceptions, (2) set cache value to a dict
1 parent 68105ce commit 02db2e8

File tree

3 files changed

+134
-12
lines changed

3 files changed

+134
-12
lines changed

litellm/proxy/auth/auth_checks.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -469,16 +469,13 @@ def check_in_budget(end_user_obj: LiteLLM_EndUserTable):
469469
# check if in cache
470470
cached_user_obj = await user_api_key_cache.async_get_cache(key=_key)
471471
if cached_user_obj is not None:
472-
if isinstance(cached_user_obj, dict):
473-
return_obj = LiteLLM_EndUserTable(**cached_user_obj)
474-
check_in_budget(end_user_obj=return_obj)
475-
return return_obj
476-
elif isinstance(cached_user_obj, LiteLLM_EndUserTable):
477-
return_obj = cached_user_obj
478-
check_in_budget(end_user_obj=return_obj)
479-
return return_obj
472+
# Convert cached dict to LiteLLM_EndUserTable instance
473+
return_obj = LiteLLM_EndUserTable(**cached_user_obj)
474+
check_in_budget(end_user_obj=return_obj)
475+
return return_obj
476+
480477
# else, check db
481-
try:
478+
try:
482479
response = await prisma_client.db.litellm_endusertable.find_unique(
483480
where={"user_id": end_user_id},
484481
include={"litellm_budget_table": True},
@@ -487,9 +484,9 @@ def check_in_budget(end_user_obj: LiteLLM_EndUserTable):
487484
if response is None:
488485
raise Exception
489486

490-
# save the end-user object to cache
487+
# save the end-user object to cache (always store as dict for consistency)
491488
await user_api_key_cache.async_set_cache(
492-
key="end_user_id:{}".format(end_user_id), value=response
489+
key="end_user_id:{}".format(end_user_id), value=response.dict()
493490
)
494491

495492
_response = LiteLLM_EndUserTable(**response.dict())
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
"""
2+
Performance utilities for LiteLLM proxy server.
3+
4+
This module provides performance monitoring and profiling functionality for endpoint
5+
performance analysis using cProfile with configurable sampling rates.
6+
"""
7+
8+
import asyncio
9+
import cProfile
10+
import functools
11+
import threading
12+
from pathlib import Path as PathLib
13+
14+
from litellm._logging import verbose_proxy_logger
15+
16+
# Global profiling state.
# A single shared cProfile.Profile accumulates stats across all sampled
# requests; the locks below keep its lifecycle and the sample counter
# safe under concurrent requests.
_profile_lock = threading.Lock()  # guards creation, enable/disable, and stat dumps of _profiler
_profiler = None  # lazily-created shared cProfile.Profile (None until first sampled request)
_last_profile_file_path = None  # most recent output path used by profile_endpoint
_sample_counter = 0  # monotonically increasing request count for deterministic sampling
_sample_counter_lock = threading.Lock()  # guards _sample_counter increments
22+
23+
24+
def _should_sample(profile_sampling_rate: float) -> bool:
    """Decide whether the current request should be profiled.

    Uses a deterministic, lock-protected global counter so the fraction of
    sampled requests converges to ``profile_sampling_rate`` over time.

    Args:
        profile_sampling_rate: Fraction of requests to sample (0.0 to 1.0).

    Returns:
        True when this request should be profiled.
    """
    if profile_sampling_rate >= 1.0:
        return True  # Always sample
    if profile_sampling_rate <= 0.0:
        return False  # Never sample

    global _sample_counter
    with _sample_counter_lock:
        _sample_counter += 1
        count = _sample_counter
    # Sample whenever the integer part of count*rate advances. Unlike the
    # previous `count % int(1.0 / rate) == 0`, this is exact for any rate:
    # e.g. rate=0.6 now samples 60% of requests instead of 100%
    # (int(1/0.6) truncated to 1, which sampled every request).
    return int(count * profile_sampling_rate) > int((count - 1) * profile_sampling_rate)
38+
39+
40+
def _start_profiling(profile_sampling_rate: float) -> None:
    """Create and enable the shared cProfile profiler on first call.

    Subsequent calls are no-ops: a single global profiler accumulates stats
    across all sampled requests for the process lifetime.

    Args:
        profile_sampling_rate: Configured sampling rate, logged for context.
    """
    global _profiler
    with _profile_lock:
        if _profiler is None:
            _profiler = cProfile.Profile()
            _profiler.enable()
            # Lazy %-style args avoid building the message when the log
            # level filters it out (hot-path logging idiom).
            verbose_proxy_logger.info(
                "Profiling started with sampling rate: %s", profile_sampling_rate
            )
48+
49+
50+
def _start_profiling_for_request(profile_sampling_rate: float) -> bool:
    """Begin profiling for this request when the sampler selects it.

    Args:
        profile_sampling_rate: Fraction of requests to profile (0.0 to 1.0).

    Returns:
        True when this request is being profiled, False otherwise.
    """
    # Guard clause: skip requests the sampler does not select.
    if not _should_sample(profile_sampling_rate):
        return False
    _start_profiling(profile_sampling_rate)
    return True
56+
57+
58+
def _save_stats(profile_file: PathLib) -> None:
    """Dump accumulated profiler stats to *profile_file*.

    cProfile requires the profiler to be disabled while dumping; a
    try/finally guarantees it is re-enabled even when ``dump_stats`` fails,
    so profiling keeps running for subsequent requests. (The previous
    version re-enabled in the except handler on a best-effort basis.)

    Args:
        profile_file: Destination path for the pstats dump.
    """
    with _profile_lock:
        if _profiler is None:
            return  # nothing to save yet
        try:
            _profiler.disable()
            try:
                _profiler.dump_stats(str(profile_file))
            finally:
                # Always resume profiling, success or failure.
                _profiler.enable()
            verbose_proxy_logger.debug("Profiling stats saved to %s", profile_file)
        except Exception as e:
            verbose_proxy_logger.error("Error saving profiling stats: %s", e)
77+
78+
79+
def profile_endpoint(
    sampling_rate: float = 1.0, profile_file: str = "endpoint_profile.pstat"
):
    """Decorator that profiles a sampled fraction of endpoint calls.

    Works with both async and sync endpoints. Stats are saved after each
    sampled call — including calls that raise — via try/finally, which
    removes the duplicated save logic the original had in both the success
    and exception branches.

    Args:
        sampling_rate: Fraction of requests to profile (0.0 to 1.0).
            - 1.0: profile all requests (100%)
            - 0.1: profile 1 in 10 requests (10%)
            - 0.0: profile no requests (0%)
        profile_file: Path the pstats output is written to after each
            sampled request. The default preserves the previously
            hard-coded ``endpoint_profile.pstat``.
    """

    def decorator(func):
        file_path_obj = PathLib(profile_file)

        def _record_path() -> None:
            # Remember the most recent output path for external inspection.
            global _last_profile_file_path
            _last_profile_file_path = file_path_obj

        if asyncio.iscoroutinefunction(func):

            @functools.wraps(func)
            async def async_wrapper(*args, **kwargs):
                is_sampling = _start_profiling_for_request(sampling_rate)
                _record_path()
                try:
                    return await func(*args, **kwargs)
                finally:
                    # Save on success *and* on exception.
                    if is_sampling:
                        _save_stats(file_path_obj)

            return async_wrapper

        @functools.wraps(func)
        def sync_wrapper(*args, **kwargs):
            is_sampling = _start_profiling_for_request(sampling_rate)
            _record_path()
            try:
                return func(*args, **kwargs)
            finally:
                # Save on success *and* on exception.
                if is_sampling:
                    _save_stats(file_path_obj)

        return sync_wrapper

    return decorator

tests/proxy_unit_tests/test_auth_checks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ async def test_get_end_user_object(customer_spend, customer_budget):
4848
)
4949
_cache = DualCache()
5050
_key = "end_user_id:{}".format(end_user_id)
51-
_cache.set_cache(key=_key, value=end_user_obj)
51+
_cache.set_cache(key=_key, value=end_user_obj.model_dump())
5252
try:
5353
await get_end_user_object(
5454
end_user_id=end_user_id,

0 commit comments

Comments
 (0)