     ImageResponse,
     ModelInfo,
     PassthroughCallTypes,
-    Usage,
     ServiceTier,
+    Usage,
 )
 from litellm.utils import get_model_info
@@ -118,21 +118,21 @@ def _generic_cost_per_character(
 def _get_service_tier_cost_key(base_key: str, service_tier: Optional[str]) -> str:
     """
     Get the appropriate cost key based on service tier.
-
+
     Args:
         base_key: The base cost key (e.g., "input_cost_per_token")
         service_tier: The service tier ("flex", "priority", or None for standard)
-
+
     Returns:
         str: The cost key to use (e.g., "input_cost_per_token_flex" or "input_cost_per_token")
     """
     if service_tier is None:
         return base_key
-
+
     # Only use service tier specific keys for "flex" and "priority"
     if service_tier.lower() in [ServiceTier.FLEX.value, ServiceTier.PRIORITY.value]:
         return f"{base_key}_{service_tier.lower()}"
-
+
     # For any other service tier, use standard pricing
     return base_key
 
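A quick usage sketch of the key mapping above; the module path is taken from the log strings later in this file, and the calls are illustrative rather than part of the diff:

# Illustrative sketch: expected key selection per service tier.
from litellm.litellm_core_utils.llm_cost_calc.utils import _get_service_tier_cost_key

assert _get_service_tier_cost_key("input_cost_per_token", None) == "input_cost_per_token"
assert _get_service_tier_cost_key("input_cost_per_token", "flex") == "input_cost_per_token_flex"
assert _get_service_tier_cost_key("input_cost_per_token", "priority") == "input_cost_per_token_priority"
# Unrecognized tiers fall back to the standard key.
assert _get_service_tier_cost_key("input_cost_per_token", "unknown") == "input_cost_per_token"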
@@ -152,25 +152,23 @@ def _get_token_base_cost(
     # Get service tier aware cost keys
     input_cost_key = _get_service_tier_cost_key("input_cost_per_token", service_tier)
     output_cost_key = _get_service_tier_cost_key("output_cost_per_token", service_tier)
-    cache_creation_cost_key = _get_service_tier_cost_key("cache_creation_input_token_cost", service_tier)
-    cache_read_cost_key = _get_service_tier_cost_key("cache_read_input_token_cost", service_tier)
-
-    prompt_base_cost = cast(
-        float, _get_cost_per_unit(model_info, input_cost_key)
+    cache_creation_cost_key = _get_service_tier_cost_key(
+        "cache_creation_input_token_cost", service_tier
     )
-    completion_base_cost = cast(
-        float, _get_cost_per_unit(model_info, output_cost_key)
+    cache_read_cost_key = _get_service_tier_cost_key(
+        "cache_read_input_token_cost", service_tier
     )
+
+    prompt_base_cost = cast(float, _get_cost_per_unit(model_info, input_cost_key))
+    completion_base_cost = cast(float, _get_cost_per_unit(model_info, output_cost_key))
     cache_creation_cost = cast(
         float, _get_cost_per_unit(model_info, cache_creation_cost_key)
     )
     cache_creation_cost_above_1hr = cast(
         float,
         _get_cost_per_unit(model_info, "cache_creation_input_token_cost_above_1hr"),
     )
-    cache_read_cost = cast(
-        float, _get_cost_per_unit(model_info, cache_read_cost_key)
-    )
+    cache_read_cost = cast(float, _get_cost_per_unit(model_info, cache_read_cost_key))
 
     ## CHECK IF ABOVE THRESHOLD
     threshold: Optional[float] = None
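For context, a hypothetical model-cost entry with the tier-specific keys these lookups expect; the key names come from the code above, the values are invented for illustration:

# Hypothetical pricing entry (illustrative values only).
model_info = {
    "input_cost_per_token": 3e-06,
    "output_cost_per_token": 1.2e-05,
    "input_cost_per_token_flex": 1.5e-06,   # picked when service_tier == "flex"
    "output_cost_per_token_flex": 6e-06,
    "cache_creation_input_token_cost": 3.75e-06,
    "cache_read_input_token_cost": 3e-07,
}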
@@ -183,7 +181,6 @@ def _get_token_base_cost(
                     1000 if "k" in threshold_str else 1
                 )
                 if usage.prompt_tokens > threshold:
-
                     prompt_base_cost = cast(
                         float, _get_cost_per_unit(model_info, key, prompt_base_cost)
                     )
@@ -278,15 +275,15 @@ def _get_cost_per_unit(
         verbose_logger.exception(
             f"litellm.litellm_core_utils.llm_cost_calc.utils.py::calculate_cost_per_component(): Exception occured - {cost_per_unit}\nDefaulting to 0.0"
         )
-
+
     # If the service tier key doesn't exist or is None, try to fall back to the standard key
     if cost_per_unit is None:
         # Check if any service tier suffix exists in the cost key using ServiceTier enum
         for service_tier in ServiceTier:
             suffix = f"_{service_tier.value}"
             if suffix in cost_key:
                 # Extract the base key by removing the matched suffix
-                base_key = cost_key.replace(suffix, '')
+                base_key = cost_key.replace(suffix, "")
                 fallback_cost = model_info.get(base_key)
                 if isinstance(fallback_cost, float):
                     return fallback_cost
@@ -300,7 +297,7 @@ def _get_cost_per_unit(
                     f"litellm.litellm_core_utils.llm_cost_calc.utils.py::_get_cost_per_unit(): Exception occured - {fallback_cost}\nDefaulting to 0.0"
                 )
                 break  # Only try the first matching suffix
-
+
     return default_value
 
 
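A minimal standalone sketch of the suffix-fallback idea shown in this hunk, using a plain dict instead of the real ModelInfo type; this is illustrative, not the actual litellm implementation:

# Illustrative re-implementation: strip the tier suffix and retry the base key.
def lookup_cost(model_info: dict, cost_key: str, default: float = 0.0) -> float:
    value = model_info.get(cost_key)
    if isinstance(value, float):
        return value
    for suffix in ("_flex", "_priority"):
        if suffix in cost_key:
            fallback = model_info.get(cost_key.replace(suffix, ""))
            if isinstance(fallback, float):
                return fallback
            break  # only try the first matching suffix
    return default

print(lookup_cost({"input_cost_per_token": 3e-06}, "input_cost_per_token_flex"))  # 3e-06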
@@ -495,7 +492,10 @@ def _calculate_input_cost(
 
 
 def generic_cost_per_token(
-    model: str, usage: Usage, custom_llm_provider: str, service_tier: Optional[str] = None
+    model: str,
+    usage: Usage,
+    custom_llm_provider: str,
+    service_tier: Optional[str] = None,
 ) -> Tuple[float, float]:
     """
     Calculates the cost per token for a given model, prompt tokens, and completion tokens.
@@ -547,7 +547,9 @@ def generic_cost_per_token(
         cache_creation_cost,
         cache_creation_cost_above_1hr,
         cache_read_cost,
-    ) = _get_token_base_cost(model_info=model_info, usage=usage, service_tier=service_tier)
+    ) = _get_token_base_cost(
+        model_info=model_info, usage=usage, service_tier=service_tier
+    )
 
     prompt_cost = _calculate_input_cost(
         prompt_tokens_details=prompt_tokens_details,
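A hedged caller sketch for the new service_tier parameter; the model/provider pair is an assumption, and actual costs depend on litellm's pricing map:

# Illustrative call, not taken from the diff; falls back to standard pricing
# when the model has no "_flex" cost keys.
from litellm.types.utils import Usage
from litellm.litellm_core_utils.llm_cost_calc.utils import generic_cost_per_token

usage = Usage(prompt_tokens=1_000, completion_tokens=200)
prompt_cost, completion_cost = generic_cost_per_token(
    model="gpt-4o-mini",
    usage=usage,
    custom_llm_provider="openai",
    service_tier="flex",
)
print(prompt_cost, completion_cost)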