
Commit 8a012f9

Merge pull request #14797 from BerriAI/litellm_anthopic_token_count_issue
fix linting issue
2 parents 7216983 + 3776a00 commit 8a012f9

6 files changed, +104 -80 lines changed

litellm/cost_calculator.py

Lines changed: 5 additions & 2 deletions
@@ -330,7 +330,9 @@ def cost_per_token( # noqa: PLR0915
     elif custom_llm_provider == "bedrock":
         return bedrock_cost_per_token(model=model, usage=usage_block)
     elif custom_llm_provider == "openai":
-        return openai_cost_per_token(model=model, usage=usage_block, service_tier=service_tier)
+        return openai_cost_per_token(
+            model=model, usage=usage_block, service_tier=service_tier
+        )
     elif custom_llm_provider == "databricks":
         return databricks_cost_per_token(model=model, usage=usage_block)
     elif custom_llm_provider == "fireworks_ai":
@@ -351,6 +353,7 @@ def cost_per_token( # noqa: PLR0915
         from litellm.llms.dashscope.cost_calculator import (
             cost_per_token as dashscope_cost_per_token,
         )
+
         return dashscope_cost_per_token(model=model, usage=usage_block)
     else:
         model_info = _cached_get_model_info_helper(
@@ -663,7 +666,7 @@ def completion_cost( # noqa: PLR0915
             completion_response=completion_response
         )
        rerank_billed_units: Optional[RerankBilledUnits] = None
-
+
        # Extract service_tier from optional_params if not provided directly
        if service_tier is None and optional_params is not None:
            service_tier = optional_params.get("service_tier")
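Note on the completion_cost() hunk above: it is what threads service_tier into the cost path. When the caller does not pass a tier explicitly, it is recovered from the request's optional_params. A minimal, self-contained sketch of that fallback (the function name here is a stand-in for illustration, not litellm's API):

from typing import Optional


def resolve_service_tier(
    service_tier: Optional[str], optional_params: Optional[dict]
) -> Optional[str]:
    # Prefer an explicitly passed tier; otherwise fall back to the
    # request's optional_params, mirroring the hunk above.
    if service_tier is None and optional_params is not None:
        service_tier = optional_params.get("service_tier")
    return service_tier


assert resolve_service_tier(None, {"service_tier": "flex"}) == "flex"
assert resolve_service_tier("priority", {"service_tier": "flex"}) == "priority"
assert resolve_service_tier(None, None) is None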

litellm/litellm_core_utils/litellm_logging.py

Lines changed: 13 additions & 5 deletions
@@ -1228,7 +1228,9 @@ def _response_cost_calculator(
             "standard_built_in_tools_params": self.standard_built_in_tools_params,
             "router_model_id": router_model_id,
             "litellm_logging_obj": self,
-            "service_tier": self.optional_params.get("service_tier") if self.optional_params else None,
+            "service_tier": self.optional_params.get("service_tier")
+            if self.optional_params
+            else None,
         }
     except Exception as e:  # error creating kwargs for cost calculation
         debug_info = StandardLoggingModelCostFailureDebugInformation(
@@ -4191,16 +4193,22 @@ def _generate_cold_storage_object_key(
 
     # Get the actual s3_path from the configured cold storage logger instance
     s3_path = ""  # default value
-
+
     # Try to get the actual logger instance from the logger name
     try:
-        custom_logger = litellm.logging_callback_manager.get_active_custom_logger_for_callback_name(configured_cold_storage_logger)
-        if custom_logger and hasattr(custom_logger, 's3_path') and custom_logger.s3_path:
+        custom_logger = litellm.logging_callback_manager.get_active_custom_logger_for_callback_name(
+            configured_cold_storage_logger
+        )
+        if (
+            custom_logger
+            and hasattr(custom_logger, "s3_path")
+            and custom_logger.s3_path
+        ):
             s3_path = custom_logger.s3_path
     except Exception:
         # If any error occurs in getting the logger instance, use default empty s3_path
         pass
-
+
     s3_object_key = get_s3_object_key(
         s3_path=s3_path,  # Use actual s3_path from logger configuration
         team_alias_prefix="",  # Don't split by team alias for cold storage
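The cold-storage hunk keeps its defensive shape after reformatting: the callback may not be registered, and a registered logger may not expose a usable s3_path. A hedged sketch of the same guard with stand-in types (S3Logger and resolve_s3_path are illustrative names, not litellm classes; the real lookup goes through litellm.logging_callback_manager):

from typing import Optional


class S3Logger:
    def __init__(self, s3_path: str = "") -> None:
        self.s3_path = s3_path


def resolve_s3_path(custom_logger: Optional[object]) -> str:
    s3_path = ""  # default when no logger is configured
    try:
        # Only trust s3_path when the logger exists, has the attribute,
        # and the attribute is non-empty, same guard as the diff above.
        if (
            custom_logger
            and hasattr(custom_logger, "s3_path")
            and custom_logger.s3_path
        ):
            s3_path = custom_logger.s3_path
    except Exception:
        pass  # fall back to the default empty path
    return s3_path


assert resolve_s3_path(S3Logger("cold-storage/")) == "cold-storage/"
assert resolve_s3_path(S3Logger()) == ""
assert resolve_s3_path(None) == ""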

litellm/litellm_core_utils/llm_cost_calc/utils.py

Lines changed: 23 additions & 21 deletions
@@ -11,8 +11,8 @@
     ImageResponse,
     ModelInfo,
     PassthroughCallTypes,
-    Usage,
     ServiceTier,
+    Usage,
 )
 from litellm.utils import get_model_info
 
@@ -118,21 +118,21 @@ def _generic_cost_per_character(
 def _get_service_tier_cost_key(base_key: str, service_tier: Optional[str]) -> str:
     """
     Get the appropriate cost key based on service tier.
-
+
     Args:
         base_key: The base cost key (e.g., "input_cost_per_token")
         service_tier: The service tier ("flex", "priority", or None for standard)
-
+
     Returns:
         str: The cost key to use (e.g., "input_cost_per_token_flex" or "input_cost_per_token")
     """
     if service_tier is None:
         return base_key
-
+
     # Only use service tier specific keys for "flex" and "priority"
     if service_tier.lower() in [ServiceTier.FLEX.value, ServiceTier.PRIORITY.value]:
         return f"{base_key}_{service_tier.lower()}"
-
+
     # For any other service tier, use standard pricing
     return base_key
 
@@ -152,25 +152,23 @@ def _get_token_base_cost(
     # Get service tier aware cost keys
     input_cost_key = _get_service_tier_cost_key("input_cost_per_token", service_tier)
     output_cost_key = _get_service_tier_cost_key("output_cost_per_token", service_tier)
-    cache_creation_cost_key = _get_service_tier_cost_key("cache_creation_input_token_cost", service_tier)
-    cache_read_cost_key = _get_service_tier_cost_key("cache_read_input_token_cost", service_tier)
-
-    prompt_base_cost = cast(
-        float, _get_cost_per_unit(model_info, input_cost_key)
+    cache_creation_cost_key = _get_service_tier_cost_key(
+        "cache_creation_input_token_cost", service_tier
     )
-    completion_base_cost = cast(
-        float, _get_cost_per_unit(model_info, output_cost_key)
+    cache_read_cost_key = _get_service_tier_cost_key(
+        "cache_read_input_token_cost", service_tier
     )
+
+    prompt_base_cost = cast(float, _get_cost_per_unit(model_info, input_cost_key))
+    completion_base_cost = cast(float, _get_cost_per_unit(model_info, output_cost_key))
     cache_creation_cost = cast(
         float, _get_cost_per_unit(model_info, cache_creation_cost_key)
     )
     cache_creation_cost_above_1hr = cast(
         float,
         _get_cost_per_unit(model_info, "cache_creation_input_token_cost_above_1hr"),
     )
-    cache_read_cost = cast(
-        float, _get_cost_per_unit(model_info, cache_read_cost_key)
-    )
+    cache_read_cost = cast(float, _get_cost_per_unit(model_info, cache_read_cost_key))
 
     ## CHECK IF ABOVE THRESHOLD
     threshold: Optional[float] = None
@@ -183,7 +181,6 @@ def _get_token_base_cost(
                 1000 if "k" in threshold_str else 1
             )
             if usage.prompt_tokens > threshold:
-
                 prompt_base_cost = cast(
                     float, _get_cost_per_unit(model_info, key, prompt_base_cost)
                 )
@@ -278,15 +275,15 @@ def _get_cost_per_unit(
         verbose_logger.exception(
             f"litellm.litellm_core_utils.llm_cost_calc.utils.py::calculate_cost_per_component(): Exception occured - {cost_per_unit}\nDefaulting to 0.0"
         )
-
+
     # If the service tier key doesn't exist or is None, try to fall back to the standard key
     if cost_per_unit is None:
         # Check if any service tier suffix exists in the cost key using ServiceTier enum
         for service_tier in ServiceTier:
             suffix = f"_{service_tier.value}"
             if suffix in cost_key:
                 # Extract the base key by removing the matched suffix
-                base_key = cost_key.replace(suffix, '')
+                base_key = cost_key.replace(suffix, "")
                 fallback_cost = model_info.get(base_key)
                 if isinstance(fallback_cost, float):
                     return fallback_cost
@@ -300,7 +297,7 @@ def _get_cost_per_unit(
                     f"litellm.litellm_core_utils.llm_cost_calc.utils.py::_get_cost_per_unit(): Exception occured - {fallback_cost}\nDefaulting to 0.0"
                 )
                 break  # Only try the first matching suffix
-
+
     return default_value
 
 
@@ -495,7 +492,10 @@ def _calculate_input_cost(
 
 
 def generic_cost_per_token(
-    model: str, usage: Usage, custom_llm_provider: str, service_tier: Optional[str] = None
+    model: str,
+    usage: Usage,
+    custom_llm_provider: str,
+    service_tier: Optional[str] = None,
 ) -> Tuple[float, float]:
     """
     Calculates the cost per token for a given model, prompt tokens, and completion tokens.
@@ -547,7 +547,9 @@ def generic_cost_per_token(
         cache_creation_cost,
         cache_creation_cost_above_1hr,
         cache_read_cost,
-    ) = _get_token_base_cost(model_info=model_info, usage=usage, service_tier=service_tier)
+    ) = _get_token_base_cost(
+        model_info=model_info, usage=usage, service_tier=service_tier
+    )
 
     prompt_cost = _calculate_input_cost(
         prompt_tokens_details=prompt_tokens_details,
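The core of the service-tier logic is _get_service_tier_cost_key, shown in full above: "flex" and "priority" map to suffixed pricing keys, and any other tier falls back to standard pricing. A self-contained sketch of that behavior (stand-in enum and function mirroring the diff, runnable outside litellm):

from enum import Enum
from typing import Optional


class ServiceTier(Enum):
    FLEX = "flex"
    PRIORITY = "priority"


def get_service_tier_cost_key(base_key: str, service_tier: Optional[str]) -> str:
    # No tier: use the standard pricing key.
    if service_tier is None:
        return base_key
    # Only "flex" and "priority" get tier-specific keys.
    if service_tier.lower() in [ServiceTier.FLEX.value, ServiceTier.PRIORITY.value]:
        return f"{base_key}_{service_tier.lower()}"
    # Any other tier uses standard pricing.
    return base_key


assert get_service_tier_cost_key("input_cost_per_token", "flex") == "input_cost_per_token_flex"
assert get_service_tier_cost_key("input_cost_per_token", None) == "input_cost_per_token"
assert get_service_tier_cost_key("input_cost_per_token", "default") == "input_cost_per_token"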

litellm/llms/openai/cost_calculation.py

Lines changed: 7 additions & 2 deletions
@@ -18,7 +18,9 @@ def cost_router(call_type: CallTypes) -> Literal["cost_per_token", "cost_per_sec
     return "cost_per_token"
 
 
-def cost_per_token(model: str, usage: Usage, service_tier: Optional[str] = None) -> Tuple[float, float]:
+def cost_per_token(
+    model: str, usage: Usage, service_tier: Optional[str] = None
+) -> Tuple[float, float]:
     """
     Calculates the cost per token for a given model, prompt tokens, and completion tokens.
 
@@ -31,7 +33,10 @@ def cost_per_token(model: str, usage: Usage, service_tier: Optional[str] = None)
     """
     ## CALCULATE INPUT COST
     return generic_cost_per_token(
-        model=model, usage=usage, custom_llm_provider="openai", service_tier=service_tier
+        model=model,
+        usage=usage,
+        custom_llm_provider="openai",
+        service_tier=service_tier,
    )
    # ### Non-cached text tokens
    # non_cached_text_tokens = usage.prompt_tokens
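End to end, an OpenAI flex request should now price against the *_flex keys when the model map publishes them. A worked sketch with a toy pricing table; the MODEL_INFO dict and its rates are made up for illustration and are not litellm's real model map:

from typing import Optional, Tuple

# Hypothetical pricing entry in the style of litellm's model map.
MODEL_INFO = {
    "input_cost_per_token": 2e-6,
    "input_cost_per_token_flex": 1e-6,  # flex tier is discounted
    "output_cost_per_token": 8e-6,
    "output_cost_per_token_flex": 4e-6,
}


def cost_per_token(
    prompt_tokens: int, completion_tokens: int, service_tier: Optional[str] = None
) -> Tuple[float, float]:
    suffix = f"_{service_tier}" if service_tier in ("flex", "priority") else ""
    # Fall back to the standard key if no tier-specific price is published.
    input_rate = MODEL_INFO.get(f"input_cost_per_token{suffix}") or MODEL_INFO["input_cost_per_token"]
    output_rate = MODEL_INFO.get(f"output_cost_per_token{suffix}") or MODEL_INFO["output_cost_per_token"]
    return prompt_tokens * input_rate, completion_tokens * output_rate


prompt_cost, completion_cost = cost_per_token(1000, 200, service_tier="flex")
assert abs(prompt_cost - 0.001) < 1e-12 and abs(completion_cost - 0.0008) < 1e-12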

litellm/types/utils.py

Lines changed: 14 additions & 20 deletions
@@ -9,30 +9,19 @@
     Literal,
     Mapping,
     Optional,
-    Tuple,
     Union,
 )
 
 import fastuuid as uuid
-from aiohttp import FormData
 from openai._models import BaseModel as OpenAIObject
-from openai.types.audio.transcription_create_params import FileTypes  # type: ignore
-from openai.types.chat.chat_completion import ChatCompletion
 from openai.types.completion_usage import (
     CompletionTokensDetails,
     CompletionUsage,
     PromptTokensDetails,
 )
-from openai.types.moderation import (
-    Categories,
-    CategoryAppliedInputTypes,
-    CategoryScores,
-)
-from openai.types.moderation_create_response import Moderation, ModerationCreateResponse
 from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
-from typing_extensions import Callable, Dict, Required, TypedDict, override
+from typing_extensions import Required, TypedDict
 
-import litellm
 from litellm.types.llms.base import (
     BaseLiteLLMOpenAIResponseObject,
     LiteLLMPydanticObjectBase,
@@ -57,7 +46,6 @@
     OpenAIRealtimeStreamList,
     WebSearchOptions,
 )
-from .rerank import RerankResponse
 
 if TYPE_CHECKING:
     from .vector_stores import VectorStoreSearchResponse
@@ -123,12 +111,18 @@ class ModelInfoBase(ProviderSpecificModelInfo, total=False):
     max_output_tokens: Required[Optional[int]]
     input_cost_per_token: Required[float]
     input_cost_per_token_flex: Optional[float]  # OpenAI flex service tier pricing
-    input_cost_per_token_priority: Optional[float]  # OpenAI priority service tier pricing
+    input_cost_per_token_priority: Optional[
+        float
+    ]  # OpenAI priority service tier pricing
     cache_creation_input_token_cost: Optional[float]
     cache_creation_input_token_cost_above_1hr: Optional[float]
     cache_read_input_token_cost: Optional[float]
-    cache_read_input_token_cost_flex: Optional[float]  # OpenAI flex service tier pricing
-    cache_read_input_token_cost_priority: Optional[float]  # OpenAI priority service tier pricing
+    cache_read_input_token_cost_flex: Optional[
+        float
+    ]  # OpenAI flex service tier pricing
+    cache_read_input_token_cost_priority: Optional[
+        float
+    ]  # OpenAI priority service tier pricing
     input_cost_per_character: Optional[float]  # only for vertex ai models
     input_cost_per_audio_token: Optional[float]
     input_cost_per_token_above_128k_tokens: Optional[float]  # only for vertex ai models
@@ -147,7 +141,9 @@ class ModelInfoBase(ProviderSpecificModelInfo, total=False):
     output_cost_per_token_batches: Optional[float]
     output_cost_per_token: Required[float]
     output_cost_per_token_flex: Optional[float]  # OpenAI flex service tier pricing
-    output_cost_per_token_priority: Optional[float]  # OpenAI priority service tier pricing
+    output_cost_per_token_priority: Optional[
+        float
+    ]  # OpenAI priority service tier pricing
     output_cost_per_character: Optional[float]  # only for vertex ai models
     output_cost_per_audio_token: Optional[float]
     output_cost_per_token_above_128k_tokens: Optional[
@@ -1141,9 +1137,6 @@ def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
 
-from openai.types.chat import ChatCompletionChunk
-
-
 class ModelResponseBase(OpenAIObject):
     id: str
     """A unique identifier for the completion."""
@@ -2592,6 +2585,7 @@ class SpecialEnums(Enum):
 
 class ServiceTier(Enum):
     """Enum for service tier types used in cost calculations."""
+
     FLEX = "flex"
     PRIORITY = "priority"
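Since the new ModelInfoBase fields are all Optional, a tier-specific key can be absent or None; _get_cost_per_unit (see the utils.py diff above) strips the tier suffix and retries the base key in that case. A hedged sketch of that fallback, reusing a stand-in enum rather than litellm's types:

from enum import Enum
from typing import Optional


class ServiceTier(Enum):
    FLEX = "flex"
    PRIORITY = "priority"


def get_cost_per_unit(model_info: dict, cost_key: str, default: float = 0.0) -> float:
    cost = model_info.get(cost_key)
    if isinstance(cost, float):
        return cost
    # Tier-specific key missing or None: strip the tier suffix and
    # retry with the standard key, as _get_cost_per_unit does above.
    for tier in ServiceTier:
        suffix = f"_{tier.value}"
        if suffix in cost_key:
            fallback = model_info.get(cost_key.replace(suffix, ""))
            if isinstance(fallback, float):
                return fallback
            break  # only the first matching suffix is tried
    return default


# No priority price published, so the standard rate applies.
info = {"input_cost_per_token": 2e-6, "input_cost_per_token_flex": 1e-6}
assert get_cost_per_unit(info, "input_cost_per_token_priority") == 2e-6
assert get_cost_per_unit(info, "input_cost_per_token_flex") == 1e-6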
