Commit d220802

Add service_tier based pricing support for openai
1 parent 52a56bd commit d220802

File tree

9 files changed (+378, -11 lines)

litellm/cost_calculator.py
Lines changed: 14 additions & 1 deletion

@@ -148,6 +148,8 @@ def cost_per_token( # noqa: PLR0915
     ### CALL TYPE ###
     call_type: CallTypesLiteral = "completion",
     audio_transcription_file_duration: float = 0.0,  # for audio transcription calls - the file time in seconds
+    ### SERVICE TIER ###
+    service_tier: Optional[str] = None,  # for OpenAI service tier pricing
 ) -> Tuple[float, float]:  # type: ignore
     """
     Calculates the cost per token for a given model, prompt tokens, and completion tokens.

@@ -278,6 +280,7 @@ def cost_per_token( # noqa: PLR0915
             model=model_without_prefix,
             usage=usage_block,
             custom_llm_provider=custom_llm_provider,
+            service_tier=service_tier,
         )

         return prompt_cost, completion_cost

@@ -327,7 +330,7 @@ def cost_per_token( # noqa: PLR0915
     elif custom_llm_provider == "bedrock":
         return bedrock_cost_per_token(model=model, usage=usage_block)
     elif custom_llm_provider == "openai":
-        return openai_cost_per_token(model=model, usage=usage_block)
+        return openai_cost_per_token(model=model, usage=usage_block, service_tier=service_tier)
     elif custom_llm_provider == "databricks":
         return databricks_cost_per_token(model=model, usage=usage_block)
     elif custom_llm_provider == "fireworks_ai":

@@ -606,6 +609,8 @@ def completion_cost( # noqa: PLR0915
     litellm_model_name: Optional[str] = None,
     router_model_id: Optional[str] = None,
     litellm_logging_obj: Optional[LitellmLoggingObject] = None,
+    ### SERVICE TIER ###
+    service_tier: Optional[str] = None,  # for OpenAI service tier pricing
 ) -> float:
     """
     Calculate the cost of a given completion call fot GPT-3.5-turbo, llama2, any litellm supported llm.

@@ -658,6 +663,10 @@ def completion_cost( # noqa: PLR0915
         completion_response=completion_response
     )
     rerank_billed_units: Optional[RerankBilledUnits] = None
+
+    # Extract service_tier from optional_params if not provided directly
+    if service_tier is None and optional_params is not None:
+        service_tier = optional_params.get("service_tier")

     selected_model = _select_model_name_for_cost_calc(
         model=model,

@@ -909,6 +918,7 @@ def completion_cost( # noqa: PLR0915
         call_type=cast(CallTypesLiteral, call_type),
         audio_transcription_file_duration=audio_transcription_file_duration,
         rerank_billed_units=rerank_billed_units,
+        service_tier=service_tier,
     )
     _final_cost = (
         prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar

@@ -1003,6 +1013,8 @@ def response_cost_calculator(
     litellm_model_name: Optional[str] = None,
     router_model_id: Optional[str] = None,
     litellm_logging_obj: Optional[LitellmLoggingObject] = None,
+    ### SERVICE TIER ###
+    service_tier: Optional[str] = None,  # for OpenAI service tier pricing
 ) -> float:
     """
     Returns

@@ -1036,6 +1048,7 @@ def response_cost_calculator(
             litellm_model_name=litellm_model_name,
             router_model_id=router_model_id,
             litellm_logging_obj=litellm_logging_obj,
+            service_tier=service_tier,
         )
         return response_cost
     except Exception as e:
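Taken together, these changes let callers pass service_tier to completion_cost directly, or rely on it being read from optional_params. A minimal usage sketch follows; the service_tier pass-through on litellm.completion and the flex pricing keys for this model are assumptions about the surrounding setup, not part of this commit:

import litellm

# Request OpenAI's flex tier (assumes litellm forwards service_tier to OpenAI).
response = litellm.completion(
    model="gpt-4o",
    messages=[{"role": "user", "content": "hello"}],
    service_tier="flex",
)

# Pass the tier explicitly ...
cost = litellm.completion_cost(completion_response=response, service_tier="flex")

# ... or omit it: completion_cost now falls back to optional_params["service_tier"].
print(f"${cost:.6f}")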

litellm/litellm_core_utils/litellm_logging.py
Lines changed: 1 addition & 0 deletions

@@ -1228,6 +1228,7 @@ def _response_cost_calculator(
                 "standard_built_in_tools_params": self.standard_built_in_tools_params,
                 "router_model_id": router_model_id,
                 "litellm_logging_obj": self,
+                "service_tier": self.optional_params.get("service_tier") if self.optional_params else None,
             }
         except Exception as e:  # error creating kwargs for cost calculation
             debug_info = StandardLoggingModelCostFailureDebugInformation(
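Because _response_cost_calculator builds its kwargs from self.optional_params, the service tier recorded at request time flows into cost calculation automatically; nothing in the logging path needs to pass service_tier by hand.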

litellm/litellm_core_utils/llm_cost_calc/utils.py
Lines changed: 60 additions & 8 deletions

@@ -4,14 +4,15 @@
 from typing import Any, Literal, Optional, Tuple, TypedDict, cast

 import litellm
-from litellm._logging import verbose_logger
+from litellm._logging import verbose_logger, verbose_proxy_logger
 from litellm.types.utils import (
     CacheCreationTokenDetails,
     CallTypes,
     ImageResponse,
     ModelInfo,
     PassthroughCallTypes,
     Usage,
+    ServiceTier,
 )
 from litellm.utils import get_model_info

@@ -114,8 +115,30 @@ def _generic_cost_per_character(
     return prompt_cost, completion_cost


+def _get_service_tier_cost_key(base_key: str, service_tier: Optional[str]) -> str:
+    """
+    Get the appropriate cost key based on service tier.
+
+    Args:
+        base_key: The base cost key (e.g., "input_cost_per_token")
+        service_tier: The service tier ("flex", "priority", or None for standard)
+
+    Returns:
+        str: The cost key to use (e.g., "input_cost_per_token_flex" or "input_cost_per_token")
+    """
+    if service_tier is None:
+        return base_key
+
+    # Only use service tier specific keys for "flex" and "priority"
+    if service_tier.lower() in [ServiceTier.FLEX.value, ServiceTier.PRIORITY.value]:
+        return f"{base_key}_{service_tier.lower()}"
+
+    # For any other service tier, use standard pricing
+    return base_key
+
+
 def _get_token_base_cost(
-    model_info: ModelInfo, usage: Usage
+    model_info: ModelInfo, usage: Usage, service_tier: Optional[str] = None
 ) -> Tuple[float, float, float, float, float]:
     """
     Return prompt cost, completion cost, and cache costs for a given model and usage.

@@ -126,21 +149,27 @@ def _get_token_base_cost(
     Returns:
         Tuple[float, float, float, float] - (prompt_cost, completion_cost, cache_creation_cost, cache_read_cost)
     """
+    # Get service tier aware cost keys
+    input_cost_key = _get_service_tier_cost_key("input_cost_per_token", service_tier)
+    output_cost_key = _get_service_tier_cost_key("output_cost_per_token", service_tier)
+    cache_creation_cost_key = _get_service_tier_cost_key("cache_creation_input_token_cost", service_tier)
+    cache_read_cost_key = _get_service_tier_cost_key("cache_read_input_token_cost", service_tier)
+
     prompt_base_cost = cast(
-        float, _get_cost_per_unit(model_info, "input_cost_per_token")
+        float, _get_cost_per_unit(model_info, input_cost_key)
     )
     completion_base_cost = cast(
-        float, _get_cost_per_unit(model_info, "output_cost_per_token")
+        float, _get_cost_per_unit(model_info, output_cost_key)
     )
     cache_creation_cost = cast(
-        float, _get_cost_per_unit(model_info, "cache_creation_input_token_cost")
+        float, _get_cost_per_unit(model_info, cache_creation_cost_key)
     )
     cache_creation_cost_above_1hr = cast(
         float,
         _get_cost_per_unit(model_info, "cache_creation_input_token_cost_above_1hr"),
     )
     cache_read_cost = cast(
-        float, _get_cost_per_unit(model_info, "cache_read_input_token_cost")
+        float, _get_cost_per_unit(model_info, cache_read_cost_key)
     )

     ## CHECK IF ABOVE THRESHOLD

@@ -249,6 +278,29 @@ def _get_cost_per_unit(
         verbose_logger.exception(
             f"litellm.litellm_core_utils.llm_cost_calc.utils.py::calculate_cost_per_component(): Exception occured - {cost_per_unit}\nDefaulting to 0.0"
         )
+
+    # If the service tier key doesn't exist or is None, try to fall back to the standard key
+    if cost_per_unit is None:
+        # Check if any service tier suffix exists in the cost key using ServiceTier enum
+        for service_tier in ServiceTier:
+            suffix = f"_{service_tier.value}"
+            if suffix in cost_key:
+                # Extract the base key by removing the matched suffix
+                base_key = cost_key.replace(suffix, '')
+                fallback_cost = model_info.get(base_key)
+                if isinstance(fallback_cost, float):
+                    return fallback_cost
+                if isinstance(fallback_cost, int):
+                    return float(fallback_cost)
+                if isinstance(fallback_cost, str):
+                    try:
+                        return float(fallback_cost)
+                    except ValueError:
+                        verbose_logger.exception(
+                            f"litellm.litellm_core_utils.llm_cost_calc.utils.py::_get_cost_per_unit(): Exception occured - {fallback_cost}\nDefaulting to 0.0"
+                        )
+                break  # Only try the first matching suffix
+
     return default_value

@@ -443,7 +495,7 @@ def _calculate_input_cost(


 def generic_cost_per_token(
-    model: str, usage: Usage, custom_llm_provider: str
+    model: str, usage: Usage, custom_llm_provider: str, service_tier: Optional[str] = None
 ) -> Tuple[float, float]:
     """
     Calculates the cost per token for a given model, prompt tokens, and completion tokens.

@@ -495,7 +547,7 @@ def generic_cost_per_token(
         cache_creation_cost,
         cache_creation_cost_above_1hr,
         cache_read_cost,
-    ) = _get_token_base_cost(model_info=model_info, usage=usage)
+    ) = _get_token_base_cost(model_info=model_info, usage=usage, service_tier=service_tier)

     prompt_cost = _calculate_input_cost(
         prompt_tokens_details=prompt_tokens_details,
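The key selection and fallback can be exercised in isolation. A self-contained sketch; the ServiceTier enum here is a local stand-in mirroring litellm.types.utils.ServiceTier, and the pricing numbers are illustrative, not real rates:

from enum import Enum
from typing import Optional


class ServiceTier(Enum):  # local stand-in for litellm.types.utils.ServiceTier
    FLEX = "flex"
    PRIORITY = "priority"


def get_service_tier_cost_key(base_key: str, service_tier: Optional[str]) -> str:
    # Same rule as _get_service_tier_cost_key: suffix only for known tiers.
    if service_tier and service_tier.lower() in (ServiceTier.FLEX.value, ServiceTier.PRIORITY.value):
        return f"{base_key}_{service_tier.lower()}"
    return base_key


# Illustrative pricing entry: a discounted flex input rate, no flex output key.
model_info = {
    "input_cost_per_token": 2.5e-06,
    "input_cost_per_token_flex": 1.25e-06,
    "output_cost_per_token": 1e-05,
}

key = get_service_tier_cost_key("input_cost_per_token", "flex")
print(key, model_info[key])  # input_cost_per_token_flex 1.25e-06

# "output_cost_per_token_flex" is absent, so _get_cost_per_unit strips the
# "_flex" suffix and falls back to the standard key:
key = get_service_tier_cost_key("output_cost_per_token", "flex")
print(model_info.get(key) or model_info["output_cost_per_token"])  # 1e-05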

litellm/llms/openai/cost_calculation.py
Lines changed: 2 additions & 2 deletions

@@ -18,7 +18,7 @@ def cost_router(call_type: CallTypes) -> Literal["cost_per_token", "cost_per_sec
     return "cost_per_token"


-def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
+def cost_per_token(model: str, usage: Usage, service_tier: Optional[str] = None) -> Tuple[float, float]:
     """
     Calculates the cost per token for a given model, prompt tokens, and completion tokens.

@@ -31,7 +31,7 @@ def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
     """
     ## CALCULATE INPUT COST
     return generic_cost_per_token(
-        model=model, usage=usage, custom_llm_provider="openai"
+        model=model, usage=usage, custom_llm_provider="openai", service_tier=service_tier
     )
     # ### Non-cached text tokens
     # non_cached_text_tokens = usage.prompt_tokens
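End to end, the OpenAI path now threads service_tier from cost_per_token down through generic_cost_per_token. A sketch via the top-level helper, assuming flex rates exist in the cost map for this model:

from litellm import cost_per_token

prompt_cost, completion_cost = cost_per_token(
    model="gpt-4o",
    prompt_tokens=1000,
    completion_tokens=200,
    custom_llm_provider="openai",
    service_tier="flex",  # resolves *_flex keys, falling back to standard pricing
)
print(prompt_cost, completion_cost)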
