
Commit 895c41e

Merge pull request #14619 from BerriAI/litellm_dev_09_16_2025_p1
UI - allow team member to view service account keys they create + Anthropic - include cache creation tokens in prompt token total (separate out during cost tracking)
2 parents: 69c0148 + 08ba38a

File tree: 12 files changed, +510 -213 lines

litellm/litellm_core_utils/llm_cost_calc/utils.py

Lines changed: 79 additions & 33 deletions
@@ -113,20 +113,30 @@ def _generic_cost_per_character(
     return prompt_cost, completion_cost
 
 
-def _get_token_base_cost(model_info: ModelInfo, usage: Usage) -> Tuple[float, float, float, float]:
+def _get_token_base_cost(
+    model_info: ModelInfo, usage: Usage
+) -> Tuple[float, float, float, float]:
     """
     Return prompt cost, completion cost, and cache costs for a given model and usage.
 
     If input_tokens > threshold and `input_cost_per_token_above_[x]k_tokens` or `input_cost_per_token_above_[x]_tokens` is set,
     then we use the corresponding threshold cost for all token types.
-
+
     Returns:
         Tuple[float, float, float, float] - (prompt_cost, completion_cost, cache_creation_cost, cache_read_cost)
     """
-    prompt_base_cost = cast(float, _get_cost_per_unit(model_info, "input_cost_per_token"))
-    completion_base_cost = cast(float, _get_cost_per_unit(model_info, "output_cost_per_token"))
-    cache_creation_cost = cast(float, _get_cost_per_unit(model_info, "cache_creation_input_token_cost"))
-    cache_read_cost = cast(float, _get_cost_per_unit(model_info, "cache_read_input_token_cost"))
+    prompt_base_cost = cast(
+        float, _get_cost_per_unit(model_info, "input_cost_per_token")
+    )
+    completion_base_cost = cast(
+        float, _get_cost_per_unit(model_info, "output_cost_per_token")
+    )
+    cache_creation_cost = cast(
+        float, _get_cost_per_unit(model_info, "cache_creation_input_token_cost")
+    )
+    cache_read_cost = cast(
+        float, _get_cost_per_unit(model_info, "cache_read_input_token_cost")
+    )
 
     ## CHECK IF ABOVE THRESHOLD
     threshold: Optional[float] = None
@@ -140,27 +150,44 @@ def _get_token_base_cost(model_info: ModelInfo, usage: Usage) -> Tuple[float, fl
                 )
                 if usage.prompt_tokens > threshold:
 
-                    prompt_base_cost = cast(float, _get_cost_per_unit(model_info, key, prompt_base_cost))
-                    completion_base_cost = cast(float, _get_cost_per_unit(
-                        model_info,
-                        f"output_cost_per_token_above_{threshold_str}_tokens",
-                        completion_base_cost,
-                    ))
-
+                    prompt_base_cost = cast(
+                        float, _get_cost_per_unit(model_info, key, prompt_base_cost)
+                    )
+                    completion_base_cost = cast(
+                        float,
+                        _get_cost_per_unit(
+                            model_info,
+                            f"output_cost_per_token_above_{threshold_str}_tokens",
+                            completion_base_cost,
+                        ),
+                    )
+
                     # Apply tiered pricing to cache costs
-                    cache_creation_tiered_key = f"cache_creation_input_token_cost_above_{threshold_str}_tokens"
-                    cache_read_tiered_key = f"cache_read_input_token_cost_above_{threshold_str}_tokens"
-
+                    cache_creation_tiered_key = (
+                        f"cache_creation_input_token_cost_above_{threshold_str}_tokens"
+                    )
+                    cache_read_tiered_key = (
+                        f"cache_read_input_token_cost_above_{threshold_str}_tokens"
+                    )
+
                     if cache_creation_tiered_key in model_info:
-                        cache_creation_cost = cast(float, _get_cost_per_unit(
-                            model_info, cache_creation_tiered_key, cache_creation_cost
-                        ))
-
+                        cache_creation_cost = cast(
+                            float,
+                            _get_cost_per_unit(
+                                model_info,
+                                cache_creation_tiered_key,
+                                cache_creation_cost,
+                            ),
+                        )
+
                     if cache_read_tiered_key in model_info:
-                        cache_read_cost = cast(float, _get_cost_per_unit(
-                            model_info, cache_read_tiered_key, cache_read_cost
-                        ))
-
+                        cache_read_cost = cast(
+                            float,
+                            _get_cost_per_unit(
+                                model_info, cache_read_tiered_key, cache_read_cost
+                            ),
+                        )
+
                     break
         except (IndexError, ValueError):
             continue
@@ -195,7 +222,9 @@ def calculate_cost_component(
     return 0.0
 
 
-def _get_cost_per_unit(model_info: ModelInfo, cost_key: str, default_value: Optional[float] = 0.0) -> Optional[float]:
+def _get_cost_per_unit(
+    model_info: ModelInfo, cost_key: str, default_value: Optional[float] = 0.0
+) -> Optional[float]:
     # Sometimes the cost per unit is a string (e.g.: If a value like "3e-7" was read from the config.yaml)
     cost_per_unit = model_info.get(cost_key)
     if isinstance(cost_per_unit, float):
@@ -210,7 +239,6 @@ def _get_cost_per_unit(model_info: ModelInfo, cost_key: str, default_value: Opti
         f"litellm.litellm_core_utils.llm_cost_calc.utils.py::calculate_cost_per_component(): Exception occured - {cost_per_unit}\nDefaulting to 0.0"
     )
     return default_value
-
 
 
 def generic_cost_per_token(
@@ -238,6 +266,7 @@ def generic_cost_per_token(
     ### PROCESSING COST
     text_tokens = usage.prompt_tokens
     cache_hit_tokens = 0
+    cache_creation_tokens = 0
     audio_tokens = 0
     character_count = 0
     image_count = 0
@@ -249,6 +278,13 @@
             )
             or 0
         )
+        cache_creation_tokens = (
+            cast(
+                Optional[int],
+                getattr(usage.prompt_tokens_details, "cache_creation_tokens", 0),
+            )
+            or 0
+        )
         text_tokens = (
             cast(
                 Optional[int], getattr(usage.prompt_tokens_details, "text_tokens", None)
@@ -279,11 +315,17 @@
     )
 
     ## EDGE CASE - text tokens not set inside PromptTokensDetails
+
     if text_tokens == 0:
-        text_tokens = usage.prompt_tokens - cache_hit_tokens - audio_tokens
+        text_tokens = (
+            usage.prompt_tokens
+            - cache_hit_tokens
+            - audio_tokens
+            - cache_creation_tokens
+        )
 
-    prompt_base_cost, completion_base_cost, cache_creation_cost, cache_read_cost = _get_token_base_cost(
-        model_info=model_info, usage=usage
+    prompt_base_cost, completion_base_cost, cache_creation_cost, cache_read_cost = (
+        _get_token_base_cost(model_info=model_info, usage=usage)
     )
 
     prompt_cost = float(text_tokens) * prompt_base_cost
@@ -297,7 +339,7 @@
     )
 
     ### CACHE WRITING COST - Now uses tiered pricing
-    prompt_cost += float(usage._cache_creation_input_tokens or 0) * cache_creation_cost
+    prompt_cost += float(cache_creation_tokens) * cache_creation_cost
 
     ### CHARACTER COST
 
@@ -350,8 +392,12 @@
     ## TEXT COST
     completion_cost = float(text_tokens) * completion_base_cost
 
-    _output_cost_per_audio_token = _get_cost_per_unit(model_info, "output_cost_per_audio_token", None)
-    _output_cost_per_reasoning_token = _get_cost_per_unit(model_info, "output_cost_per_reasoning_token", None)
+    _output_cost_per_audio_token = _get_cost_per_unit(
+        model_info, "output_cost_per_audio_token", None
+    )
+    _output_cost_per_reasoning_token = _get_cost_per_unit(
+        model_info, "output_cost_per_reasoning_token", None
+    )
 
     ## AUDIO COST
     if not is_text_tokens_total and audio_tokens is not None and audio_tokens > 0:
@@ -397,7 +443,7 @@ def _call_type_has_image_response(call_type: str) -> bool:
         ]:
             return True
         return False
-
+
     @staticmethod
     def route_image_generation_cost_calculator(
         model: str,
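
Note: this file also extends the tiered ("above N tokens") pricing lookup to the cache rate keys. A minimal sketch of that selection logic, with made-up model_info entries and a made-up 200k threshold (not taken from litellm's pricing map):

```python
# Sketch of the tiered cache-rate selection (hypothetical model_info entries
# and a hypothetical 200k threshold).
model_info = {
    "cache_creation_input_token_cost": 3.75e-6,
    "cache_read_input_token_cost": 3e-7,
    "cache_creation_input_token_cost_above_200k_tokens": 7.5e-6,
    "cache_read_input_token_cost_above_200k_tokens": 6e-7,
}
threshold = 200_000
prompt_tokens = 250_000

# Base rates apply below the threshold...
cache_creation_cost = model_info["cache_creation_input_token_cost"]
cache_read_cost = model_info["cache_read_input_token_cost"]

# ...and the "*_above_200k_tokens" rates replace them once the prompt crosses it.
if prompt_tokens > threshold:
    cache_creation_cost = model_info.get(
        "cache_creation_input_token_cost_above_200k_tokens", cache_creation_cost
    )
    cache_read_cost = model_info.get(
        "cache_read_input_token_cost_above_200k_tokens", cache_read_cost
    )

print(cache_creation_cost, cache_read_cost)  # 7.5e-06 6e-07
```

If a tiered key is missing, the defaults passed to `_get_cost_per_unit` are kept, mirroring the fallback behavior in the diff above.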

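The net effect on prompt cost: cache-creation tokens now arrive inside `prompt_tokens` (see the Anthropic change below), get subtracted back out of the text-token count, and are billed at their own rate. A standalone arithmetic sketch with hypothetical rates and token counts:

```python
# Standalone sketch of the revised prompt-cost arithmetic.
# Rates and token counts are hypothetical, not from litellm's pricing map.
input_cost_per_token = 3e-6
cache_creation_input_token_cost = 3.75e-6
cache_read_input_token_cost = 3e-7

prompt_tokens = 1_000        # total, now including cache-creation tokens
cache_read_tokens = 200      # "cache hit" tokens
cache_creation_tokens = 300  # cache-write tokens

# Mirrors generic_cost_per_token(): cache tokens are carved out of the text
# total, then billed at their own rates.
text_tokens = prompt_tokens - cache_read_tokens - cache_creation_tokens

prompt_cost = (
    text_tokens * input_cost_per_token
    + cache_read_tokens * cache_read_input_token_cost
    + cache_creation_tokens * cache_creation_input_token_cost
)
print(f"prompt cost: ${prompt_cost:.6f}")  # prompt cost: $0.002685
```
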
litellm/llms/anthropic/chat/transformation.py

Lines changed: 1 addition & 0 deletions
@@ -826,6 +826,7 @@ def calculate_usage(
             and _usage["cache_creation_input_tokens"] is not None
         ):
             cache_creation_input_tokens = _usage["cache_creation_input_tokens"]
+            prompt_tokens += cache_creation_input_tokens
         if (
             "cache_read_input_tokens" in _usage
             and _usage["cache_read_input_tokens"] is not None

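With this one-line change, Anthropic's `cache_creation_input_tokens` are folded into the reported `prompt_tokens` and separated out again during cost tracking. A rough sketch of the mapping with made-up numbers (the helper below is illustrative, not litellm's `calculate_usage`):

```python
# Illustrative mapping of an Anthropic usage block; only the change shown in
# this diff is modeled (cache-creation tokens added to the prompt total).
# This is a sketch, not litellm's calculate_usage().
from typing import Dict


def map_anthropic_usage(_usage: Dict[str, int]) -> Dict[str, int]:
    prompt_tokens = _usage.get("input_tokens", 0)
    cache_creation_input_tokens = _usage.get("cache_creation_input_tokens") or 0
    cache_read_input_tokens = _usage.get("cache_read_input_tokens") or 0
    prompt_tokens += cache_creation_input_tokens  # new behavior in this commit
    return {
        "prompt_tokens": prompt_tokens,
        "cache_creation_input_tokens": cache_creation_input_tokens,
        "cache_read_input_tokens": cache_read_input_tokens,
    }


print(
    map_anthropic_usage(
        {"input_tokens": 100, "cache_creation_input_tokens": 40, "cache_read_input_tokens": 25}
    )
)
# {'prompt_tokens': 140, 'cache_creation_input_tokens': 40, 'cache_read_input_tokens': 25}
```
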
litellm/proxy/management_endpoints/key_management_endpoints.py

Lines changed: 9 additions & 0 deletions
@@ -2462,6 +2462,9 @@ async def list_keys(
     include_team_keys: bool = Query(
         False, description="Include all keys for teams that user is an admin of."
     ),
+    include_created_by_keys: bool = Query(
+        False, description="Include keys created by the user"
+    ),
     sort_by: Optional[str] = Query(
         default=None,
         description="Column to sort by (e.g. 'user_id', 'created_at', 'spend')",
@@ -2524,6 +2527,7 @@
             return_full_object=return_full_object,
             organization_id=organization_id,
             admin_team_ids=admin_team_ids,
+            include_created_by_keys=include_created_by_keys,
             sort_by=sort_by,
             sort_order=sort_order,
         )
@@ -2601,6 +2605,7 @@ async def _list_key_helper(
     admin_team_ids: Optional[
         List[str]
     ] = None,  # New parameter for teams where user is admin
+    include_created_by_keys: bool = False,
     sort_by: Optional[str] = None,
     sort_order: str = "desc",
 ) -> KeyListResponseObject:
@@ -2650,6 +2655,10 @@
         if user_condition:
             or_conditions.append(user_condition)
 
+    # Add condition for created by keys if provided
+    if include_created_by_keys and user_id:
+        or_conditions.append({"created_by": user_id})
+
     # Add condition for admin team keys if provided
     if admin_team_ids:
         or_conditions.append({"team_id": {"in": admin_team_ids}})

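Callers opt in via the new query parameter. A hedged example, assuming a local proxy at http://localhost:4000 that serves `list_keys` at GET `/key/list` and a placeholder bearer token:

```python
# Hedged example: list keys, including keys where created_by == caller.
# Assumes the proxy runs at http://localhost:4000 and serves list_keys at
# GET /key/list; the bearer token below is a placeholder.
import requests

resp = requests.get(
    "http://localhost:4000/key/list",
    headers={"Authorization": "Bearer sk-1234"},
    params={
        "include_created_by_keys": "true",  # new flag from this commit
        "include_team_keys": "false",
    },
    timeout=30,
)
resp.raise_for_status()
print(resp.json())
```
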
litellm/types/utils.py

Lines changed: 18 additions & 0 deletions
@@ -883,6 +883,9 @@ class PromptTokensDetailsWrapper(
     video_length_seconds: Optional[float] = None
     """Length of videos sent to the model. Used for Vertex AI multimodal embeddings."""
 
+    cache_creation_tokens: Optional[int] = None
+    """Number of cache creation tokens sent to the model. Used for Anthropic prompt caching."""
+
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         if self.character_count is None:
@@ -893,6 +896,8 @@ def __init__(self, *args, **kwargs):
             del self.video_length_seconds
         if self.web_search_requests is None:
             del self.web_search_requests
+        if self.cache_creation_tokens is None:
+            del self.cache_creation_tokens
 
 
 class ServerToolUse(BaseModel):
@@ -954,6 +959,7 @@ def __init__(
         # handle prompt_tokens_details
         _prompt_tokens_details: Optional[PromptTokensDetailsWrapper] = None
 
+        # guarantee prompt_token_details is always a PromptTokensDetailsWrapper
         if prompt_tokens_details:
             if isinstance(prompt_tokens_details, dict):
                 _prompt_tokens_details = PromptTokensDetailsWrapper(
@@ -988,6 +994,18 @@ def __init__(
             else:
                 _prompt_tokens_details.cached_tokens = params["cache_read_input_tokens"]
 
+        if "cache_creation_input_tokens" in params and isinstance(
+            params["cache_creation_input_tokens"], int
+        ):
+            if _prompt_tokens_details is None:
+                _prompt_tokens_details = PromptTokensDetailsWrapper(
+                    cache_creation_tokens=params["cache_creation_input_tokens"]
+                )
+            else:
+                _prompt_tokens_details.cache_creation_tokens = params[
+                    "cache_creation_input_tokens"
+                ]
+
         super().__init__(
             prompt_tokens=prompt_tokens or 0,
             completion_tokens=completion_tokens or 0,

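A short sketch of the resulting `Usage` behavior, assuming this commit's constructor (token counts are made up; the expected attribute values follow from the branches shown above):

```python
# Sketch: Usage now mirrors cache_creation_input_tokens into
# prompt_tokens_details.cache_creation_tokens (values are made up).
from litellm.types.utils import Usage

usage = Usage(
    prompt_tokens=140,  # total, already including the 40 cache-creation tokens
    completion_tokens=20,
    total_tokens=160,
    cache_creation_input_tokens=40,
    cache_read_input_tokens=25,
)

print(usage.prompt_tokens_details.cache_creation_tokens)  # expected: 40
print(usage.prompt_tokens_details.cached_tokens)          # expected: 25
```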