tokencost/__init__.py (2 additions, 0 deletions)
@@ -5,5 +5,7 @@
calculate_prompt_cost,
calculate_all_costs_and_tokens,
calculate_cost_by_tokens,
configure_model,
register_model_pattern,
)
from .constants import TOKEN_COSTS_STATIC, TOKEN_COSTS, update_token_costs, refresh_prices
tokencost/costs.py (127 additions, 7 deletions)
@@ -9,6 +9,8 @@
from .constants import TOKEN_COSTS
from decimal import Decimal
import logging
import re
from typing import Optional, Tuple, Pattern

logger = logging.getLogger(__name__)

@@ -19,6 +21,122 @@
TokenType = Literal["input", "output", "cached"]


MODEL_PRICE_PATTERNS: List[Tuple[Pattern[str], Dict[str, Union[int, float, str, bool]]]] = []


def _to_per_token(cost_per_1k_tokens: Union[int, float, Decimal]) -> float:
"""Convert a price expressed per 1K tokens to a per-token float."""
return float(Decimal(str(cost_per_1k_tokens)) / Decimal(1000))


def configure_model(
model_name: str,
input_cost_per_1k_tokens: Union[int, float, Decimal],
output_cost_per_1k_tokens: Union[int, float, Decimal],
*,
max_input_tokens: Optional[int] = None,
max_output_tokens: Optional[int] = None,
litellm_provider: Optional[str] = None,
mode: str = "chat",
) -> None:
"""
Register or override pricing for a specific model name.

Args:
model_name: The exact model identifier to store (case-insensitive).
input_cost_per_1k_tokens: USD per 1K input tokens (e.g., 0.003 for $3.00/M).
output_cost_per_1k_tokens: USD per 1K output tokens (e.g., 0.015 for $15.00/M).
max_input_tokens: Optional maximum input tokens.
max_output_tokens: Optional maximum output tokens.
litellm_provider: Optional provider hint (e.g., "bedrock").
mode: Model mode, defaults to "chat".
"""
key = model_name.lower()
TOKEN_COSTS[key] = {
"input_cost_per_token": _to_per_token(input_cost_per_1k_tokens),
"output_cost_per_token": _to_per_token(output_cost_per_1k_tokens),
"mode": mode,
}
if max_input_tokens is not None:
TOKEN_COSTS[key]["max_input_tokens"] = int(max_input_tokens)
if max_output_tokens is not None:
TOKEN_COSTS[key]["max_output_tokens"] = int(max_output_tokens)
if litellm_provider is not None:
TOKEN_COSTS[key]["litellm_provider"] = litellm_provider
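
For reference, a minimal usage sketch of configure_model (not part of this diff; the model id and prices below are invented for illustration):

from tokencost import configure_model, calculate_cost_by_tokens

configure_model(
    "mycorp/custom-llm-v1",           # hypothetical model id
    input_cost_per_1k_tokens=0.003,   # $3.00 per million input tokens
    output_cost_per_1k_tokens=0.015,  # $15.00 per million output tokens
    max_input_tokens=128_000,
)

# The key is stored lowercased, so the exact-match lookup succeeds right away.
cost = calculate_cost_by_tokens(1_000, "mycorp/custom-llm-v1", "input")
print(cost)  # roughly Decimal('0.003'), i.e. $0.003 for 1,000 input tokens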


def register_model_pattern(
pattern: str,
input_cost_per_1k_tokens: Union[int, float, Decimal],
output_cost_per_1k_tokens: Union[int, float, Decimal],
*,
max_input_tokens: Optional[int] = None,
max_output_tokens: Optional[int] = None,
litellm_provider: Optional[str] = None,
mode: str = "chat",
) -> None:
"""
    Register a wildcard pattern that assigns pricing to any matching model name.

    The pattern supports '*' as a wildcard; everything else is matched literally, and the
    whole model name must match. Patterns are only consulted after exact and prefix-based
    lookups fail (see _normalize_model_for_pricing).
    Example: "bedrock/anthropic.claude-3-5-sonnet-*".
"""
# Convert simple wildcard pattern to regex
regex_str = "^" + re.escape(pattern).replace(r"\*", ".*") + "$"
compiled = re.compile(regex_str)
entry: Dict[str, Union[int, float, str, bool]] = {
"input_cost_per_token": _to_per_token(input_cost_per_1k_tokens),
"output_cost_per_token": _to_per_token(output_cost_per_1k_tokens),
"mode": mode,
}
if max_input_tokens is not None:
entry["max_input_tokens"] = int(max_input_tokens)
if max_output_tokens is not None:
entry["max_output_tokens"] = int(max_output_tokens)
if litellm_provider is not None:
entry["litellm_provider"] = litellm_provider
MODEL_PRICE_PATTERNS.append((compiled, entry))
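
A short sketch of how a registered pattern is picked up (the pattern and prices are invented; any model id matching the wildcard resolves to this entry and is cached under its exact name):

from tokencost import register_model_pattern, calculate_cost_by_tokens

register_model_pattern(
    "mycorp/llm-*",                  # hypothetical wildcard pattern
    input_cost_per_1k_tokens=0.001,
    output_cost_per_1k_tokens=0.002,
)

# "mycorp/llm-small-2025" has no exact price entry, so the pattern supplies one.
cost = calculate_cost_by_tokens(2_000, "mycorp/llm-small-2025", "output")
print(cost)  # roughly Decimal('0.004')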


def _normalize_model_for_pricing(model: str) -> str:
"""
Normalize a model identifier for price lookup.

Rules:
- Lowercase everything
- Keep exact matches if present
    - Special-case Bedrock Anthropic models: strip the leading "bedrock/" prefix when the
      next segment starts with "anthropic.", since pricing keys are stored without the prefix.
    - Otherwise, try the last segment after '/'. This helps for provider prefixes like
      "azure/", "openai/", etc., when prices are stored under the bare model key.
    - Finally, try any wildcard patterns registered via register_model_pattern and cache
      the matched pricing under the exact (lowercased) model name.
"""
m = model.lower()
if m in TOKEN_COSTS:
return m

    # bedrock/anthropic.* => anthropic.* (pricing keys are stored without the prefix)
    if m.startswith("bedrock/"):
        rest = m.split("/", 1)[1]
        if rest.startswith("anthropic.") and rest in TOKEN_COSTS:
            return rest

# Try last path segment as a fallback (handles e.g., azure/gpt-4o)
if "/" in m:
last = m.split("/")[-1]
if last in TOKEN_COSTS:
return last

# Try matching any user-registered wildcard patterns. If matched, bind pricing to this key.
for regex, entry in MODEL_PRICE_PATTERNS:
if regex.match(m):
# Cache the computed pricing under the exact model string
TOKEN_COSTS[m] = dict(entry)
return m

return m
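
To illustrate the lookup order, a small sketch against the internal helper (imported directly only for demonstration, and assuming the bundled price table contains the bare "gpt-4o" and "anthropic.claude-*" keys, as litellm's pricing data does):

from tokencost.costs import _normalize_model_for_pricing

# Unknown provider prefix: fall back to the last path segment if that bare key is priced.
_normalize_model_for_pricing("someprovider/gpt-4o")  # -> "gpt-4o"

# Bedrock Anthropic ids: the "bedrock/" prefix is dropped when the bare
# "anthropic.*" key carries the pricing.
_normalize_model_for_pricing("bedrock/anthropic.claude-3-haiku-20240307-v1:0")

# Anything still unmatched is tried against register_model_pattern() entries; if
# nothing matches, the lowercased name is returned unchanged and callers raise KeyError.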


def _get_field_from_token_type(token_type: TokenType) -> str:
"""
Get the field name from the token type.
@@ -97,7 +215,7 @@ def count_message_tokens(messages: List[Dict[str, str]], model: str) -> int:
model = strip_ft_model_name(model)

# Anthropic token counting requires a valid API key
if "claude-" in model:
if "claude-" in model and not model.startswith("anthropic."):
logger.warning(
"Warning: Anthropic token counting API is currently in beta. Please expect differences in costs!"
)
@@ -199,7 +317,7 @@ def calculate_cost_by_tokens(num_tokens: int, model: str, token_type: TokenType)
Returns:
Decimal: The calculated cost in USD.
"""
model = model.lower()
model = _normalize_model_for_pricing(model)
if model not in TOKEN_COSTS:
raise KeyError(
f"""Model {model} is not implemented.
@@ -238,7 +356,8 @@ def calculate_prompt_cost(prompt: Union[List[dict], str], model: str) -> Decimal
"""
model = model.lower()
model = strip_ft_model_name(model)
if model not in TOKEN_COSTS:
pricing_model = _normalize_model_for_pricing(model)
if pricing_model not in TOKEN_COSTS:
raise KeyError(
f"""Model {model} is not implemented.
Double-check your spelling, or submit an issue/PR"""
@@ -253,7 +372,7 @@ def calculate_prompt_cost(prompt: Union[List[dict], str], model: str) -> Decimal
else count_message_tokens(prompt, model)
)

return calculate_cost_by_tokens(prompt_tokens, model, "input")
return calculate_cost_by_tokens(prompt_tokens, pricing_model, "input")


def calculate_completion_cost(completion: str, model: str) -> Decimal:
@@ -273,7 +392,8 @@ def calculate_completion_cost(completion: str, model: str) -> Decimal:
Decimal('0.000014')
"""
model = strip_ft_model_name(model)
if model not in TOKEN_COSTS:
pricing_model = _normalize_model_for_pricing(model)
if pricing_model not in TOKEN_COSTS:
raise KeyError(
f"""Model {model} is not implemented.
Double-check your spelling, or submit an issue/PR"""
@@ -291,7 +411,7 @@ def calculate_completion_cost(completion: str, model: str) -> Decimal:
else:
completion_tokens = count_string_tokens(completion, model)

return calculate_cost_by_tokens(completion_tokens, model, "output")
return calculate_cost_by_tokens(completion_tokens, pricing_model, "output")


def calculate_all_costs_and_tokens(
@@ -322,7 +442,7 @@ def calculate_all_costs_and_tokens(
else count_message_tokens(prompt, model)
)

if "claude-" in model:
if "claude-" in model and not model.startswith("anthropic."):
logger.warning("Warning: Token counting is estimated for ")
completion_list = [{"role": "assistant", "content": completion}]
# Anthropic appends some 13 additional tokens to the actual completion tokens