@@ -4,14 +4,15 @@
 from typing import Any, Literal, Optional, Tuple, TypedDict, cast
 
 import litellm
-from litellm._logging import verbose_logger
+from litellm._logging import verbose_logger, verbose_proxy_logger
 from litellm.types.utils import (
     CacheCreationTokenDetails,
     CallTypes,
     ImageResponse,
     ModelInfo,
     PassthroughCallTypes,
     Usage,
+    ServiceTier,
 )
 from litellm.utils import get_model_info
 
@@ -114,8 +115,30 @@ def _generic_cost_per_character(
     return prompt_cost, completion_cost
 
 
+def _get_service_tier_cost_key(base_key: str, service_tier: Optional[str]) -> str:
+    """
+    Get the appropriate cost key based on service tier.
+
+    Args:
+        base_key: The base cost key (e.g., "input_cost_per_token")
+        service_tier: The service tier ("flex", "priority", or None for standard)
+
+    Returns:
+        str: The cost key to use (e.g., "input_cost_per_token_flex" or "input_cost_per_token")
+    """
+    if service_tier is None:
+        return base_key
+
+    # Only use service tier specific keys for "flex" and "priority"
+    if service_tier.lower() in [ServiceTier.FLEX.value, ServiceTier.PRIORITY.value]:
+        return f"{base_key}_{service_tier.lower()}"
+
+    # For any other service tier, use standard pricing
+    return base_key
+
+
 def _get_token_base_cost(
-    model_info: ModelInfo, usage: Usage
+    model_info: ModelInfo, usage: Usage, service_tier: Optional[str] = None
 ) -> Tuple[float, float, float, float, float]:
     """
     Return prompt cost, completion cost, and cache costs for a given model and usage.
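As a sanity check on the key mapping, here is a minimal standalone sketch of the helper's behavior. The `ServiceTier` enum shape is an assumption (a string enum with `FLEX = "flex"` and `PRIORITY = "priority"`), inferred from how the comparison above uses `.value`:

```python
from enum import Enum
from typing import Optional


class ServiceTier(str, Enum):  # assumed shape of litellm's ServiceTier enum
    FLEX = "flex"
    PRIORITY = "priority"


def get_service_tier_cost_key(base_key: str, service_tier: Optional[str]) -> str:
    """Standalone mirror of _get_service_tier_cost_key from the diff above."""
    if service_tier is None:
        return base_key
    if service_tier.lower() in [ServiceTier.FLEX.value, ServiceTier.PRIORITY.value]:
        return f"{base_key}_{service_tier.lower()}"
    return base_key


assert get_service_tier_cost_key("input_cost_per_token", None) == "input_cost_per_token"
assert get_service_tier_cost_key("input_cost_per_token", "flex") == "input_cost_per_token_flex"
assert get_service_tier_cost_key("input_cost_per_token", "PRIORITY") == "input_cost_per_token_priority"
# Unrecognized tiers (e.g., "batch") silently fall back to standard pricing:
assert get_service_tier_cost_key("input_cost_per_token", "batch") == "input_cost_per_token"
```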
@@ -126,21 +149,27 @@ def _get_token_base_cost(
     Returns:
         Tuple[float, float, float, float, float] - (prompt_cost, completion_cost, cache_creation_cost, cache_creation_cost_above_1hr, cache_read_cost)
     """
+    # Get service tier aware cost keys
+    input_cost_key = _get_service_tier_cost_key("input_cost_per_token", service_tier)
+    output_cost_key = _get_service_tier_cost_key("output_cost_per_token", service_tier)
+    cache_creation_cost_key = _get_service_tier_cost_key("cache_creation_input_token_cost", service_tier)
+    cache_read_cost_key = _get_service_tier_cost_key("cache_read_input_token_cost", service_tier)
+
     prompt_base_cost = cast(
-        float, _get_cost_per_unit(model_info, "input_cost_per_token")
+        float, _get_cost_per_unit(model_info, input_cost_key)
     )
     completion_base_cost = cast(
-        float, _get_cost_per_unit(model_info, "output_cost_per_token")
+        float, _get_cost_per_unit(model_info, output_cost_key)
     )
     cache_creation_cost = cast(
-        float, _get_cost_per_unit(model_info, "cache_creation_input_token_cost")
+        float, _get_cost_per_unit(model_info, cache_creation_cost_key)
     )
     cache_creation_cost_above_1hr = cast(
         float,
         _get_cost_per_unit(model_info, "cache_creation_input_token_cost_above_1hr"),
     )
     cache_read_cost = cast(
-        float, _get_cost_per_unit(model_info, "cache_read_input_token_cost")
+        float, _get_cost_per_unit(model_info, cache_read_cost_key)
     )
 
     ## CHECK IF ABOVE THRESHOLD
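The tier-aware keys only change anything when the model's pricing map actually defines them. A hypothetical pricing entry illustrates this (the key names come from this hunk; the prices are invented):

```python
# Hypothetical pricing map: key names match this hunk, prices are invented.
model_info = {
    "input_cost_per_token": 2e-06,
    "output_cost_per_token": 8e-06,
    "input_cost_per_token_flex": 1e-06,   # flex tier, discounted
    "output_cost_per_token_flex": 4e-06,
}

# With service_tier="flex", _get_service_tier_cost_key resolves
# "input_cost_per_token" to "input_cost_per_token_flex", so:
assert model_info["input_cost_per_token_flex"] == 1e-06  # flex rate wins
assert model_info["input_cost_per_token"] == 2e-06       # standard rate
```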
@@ -249,6 +278,29 @@ def _get_cost_per_unit(
             verbose_logger.exception(
                 f"litellm.litellm_core_utils.llm_cost_calc.utils.py::calculate_cost_per_component(): Exception occurred - {cost_per_unit}\nDefaulting to 0.0"
             )
+
+    # If the service tier key doesn't exist or is None, try to fall back to the standard key
+    if cost_per_unit is None:
+        # Check if any service tier suffix exists in the cost key using ServiceTier enum
+        for service_tier in ServiceTier:
+            suffix = f"_{service_tier.value}"
+            if suffix in cost_key:
+                # Extract the base key by removing the matched suffix
+                base_key = cost_key.replace(suffix, '')
+                fallback_cost = model_info.get(base_key)
+                if isinstance(fallback_cost, float):
+                    return fallback_cost
+                if isinstance(fallback_cost, int):
+                    return float(fallback_cost)
+                if isinstance(fallback_cost, str):
+                    try:
+                        return float(fallback_cost)
+                    except ValueError:
+                        verbose_logger.exception(
+                            f"litellm.litellm_core_utils.llm_cost_calc.utils.py::_get_cost_per_unit(): Exception occurred - {fallback_cost}\nDefaulting to 0.0"
+                        )
+                break  # Only try the first matching suffix
+
     return default_value
 
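This fallback means a flex request against a model with no flex pricing is billed at standard rates rather than 0.0. A small self-contained sketch of that path (a plain dict stands in for `ModelInfo`; the price is invented):

```python
# No "_flex" price defined for this model; the fallback strips the tier
# suffix and re-reads the standard key. Plain dict stands in for ModelInfo.
model_info = {"input_cost_per_token": 2e-06}

cost_key = "input_cost_per_token_flex"
if model_info.get(cost_key) is None:
    for suffix in ("_flex", "_priority"):
        if suffix in cost_key:
            base_key = cost_key.replace(suffix, "")
            print(model_info.get(base_key))  # 2e-06 -- standard pricing applies
            break
```

One design note: `str.replace` removes the suffix wherever it occurs in the key, so `cost_key[: -len(suffix)]` guarded by `str.endswith` would be a stricter equivalent; with the current key naming the tier suffix only ever appears at the end, so both behave the same.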
@@ -443,7 +495,7 @@ def _calculate_input_cost(
 
 
 def generic_cost_per_token(
-    model: str, usage: Usage, custom_llm_provider: str
+    model: str, usage: Usage, custom_llm_provider: str, service_tier: Optional[str] = None
 ) -> Tuple[float, float]:
     """
     Calculates the cost per token for a given model, prompt tokens, and completion tokens.
@@ -495,7 +547,7 @@ def generic_cost_per_token(
         cache_creation_cost,
         cache_creation_cost_above_1hr,
         cache_read_cost,
-    ) = _get_token_base_cost(model_info=model_info, usage=usage)
+    ) = _get_token_base_cost(model_info=model_info, usage=usage, service_tier=service_tier)
 
     prompt_cost = _calculate_input_cost(
         prompt_tokens_details=prompt_tokens_details,
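Taken together, a caller that threads the request's service tier through now gets tier-aware pricing end to end. A hypothetical invocation (the model name and whether it defines flex prices are illustrative; the `Usage` constructor kwargs are assumed to mirror the OpenAI usage object, and the module path is taken from the log strings in this diff):

```python
# Hypothetical end-to-end call; kwargs for Usage are an assumption.
from litellm.litellm_core_utils.llm_cost_calc.utils import generic_cost_per_token
from litellm.types.utils import Usage

usage = Usage(prompt_tokens=1_000, completion_tokens=200, total_tokens=1_200)
prompt_cost, completion_cost = generic_cost_per_token(
    model="gpt-4o",
    usage=usage,
    custom_llm_provider="openai",
    service_tier="flex",  # billed at *_flex rates if defined, else standard
)
print(prompt_cost, completion_cost)
```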