diff --git a/tokencost/__init__.py b/tokencost/__init__.py
index 3184582..6973f0f 100644
--- a/tokencost/__init__.py
+++ b/tokencost/__init__.py
@@ -5,5 +5,7 @@
     calculate_prompt_cost,
     calculate_all_costs_and_tokens,
     calculate_cost_by_tokens,
+    configure_model,
+    register_model_pattern,
 )
 from .constants import TOKEN_COSTS_STATIC, TOKEN_COSTS, update_token_costs, refresh_prices
diff --git a/tokencost/costs.py b/tokencost/costs.py
index ebb9756..cc4680f 100644
--- a/tokencost/costs.py
+++ b/tokencost/costs.py
@@ -9,6 +9,8 @@
 from .constants import TOKEN_COSTS
 from decimal import Decimal
 import logging
+import re
+from typing import Optional, Tuple, Pattern
 
 logger = logging.getLogger(__name__)
 
@@ -19,6 +21,163 @@
 TokenType = Literal["input", "output", "cached"]
 
 
+# (compiled regex, pricing entry) pairs added by register_model_pattern(), in registration order.
+MODEL_PRICE_PATTERNS: List[Tuple[Pattern[str], Dict[str, Union[int, float, str, bool]]]] = []
+
+
+def _to_per_token(cost_per_1k_tokens: Union[int, float, Decimal]) -> float:
+    """Convert a price expressed per 1K tokens to a per-token float."""
+    # Go through str() so a float such as 0.003 becomes Decimal("0.003") rather than
+    # the exact binary expansion that Decimal(0.003) would produce.
+    return float(Decimal(str(cost_per_1k_tokens)) / Decimal(1000))
+
+
+def _build_pricing_entry(
+    input_cost_per_1k_tokens: Union[int, float, Decimal],
+    output_cost_per_1k_tokens: Union[int, float, Decimal],
+    max_input_tokens: Optional[int],
+    max_output_tokens: Optional[int],
+    litellm_provider: Optional[str],
+    mode: str,
+) -> Dict[str, Union[int, float, str, bool]]:
+    """Build a pricing entry in the per-token shape stored in TOKEN_COSTS.
+
+    Optional fields are only included when a value was supplied, so a missing
+    limit or provider is absent from the entry instead of being stored as None.
+    """
+    entry: Dict[str, Union[int, float, str, bool]] = {
+        "input_cost_per_token": _to_per_token(input_cost_per_1k_tokens),
+        "output_cost_per_token": _to_per_token(output_cost_per_1k_tokens),
+        "mode": mode,
+    }
+    if max_input_tokens is not None:
+        entry["max_input_tokens"] = int(max_input_tokens)
+    if max_output_tokens is not None:
+        entry["max_output_tokens"] = int(max_output_tokens)
+    if litellm_provider is not None:
+        entry["litellm_provider"] = litellm_provider
+    return entry
+
+
+def configure_model(
+    model_name: str,
+    input_cost_per_1k_tokens: Union[int, float, Decimal],
+    output_cost_per_1k_tokens: Union[int, float, Decimal],
+    *,
+    max_input_tokens: Optional[int] = None,
+    max_output_tokens: Optional[int] = None,
+    litellm_provider: Optional[str] = None,
+    mode: str = "chat",
+) -> None:
+    """
+    Register or override pricing for a specific model name.
+
+    Any existing TOKEN_COSTS entry for the model is replaced, not merged.
+
+    Args:
+        model_name: The exact model identifier to store (case-insensitive).
+        input_cost_per_1k_tokens: USD per 1K input tokens (e.g., 0.003 for $3.00/M).
+        output_cost_per_1k_tokens: USD per 1K output tokens (e.g., 0.015 for $15.00/M).
+        max_input_tokens: Optional maximum input tokens.
+        max_output_tokens: Optional maximum output tokens.
+        litellm_provider: Optional provider hint (e.g., "bedrock").
+        mode: Model mode, defaults to "chat".
+    """
+    TOKEN_COSTS[model_name.lower()] = _build_pricing_entry(
+        input_cost_per_1k_tokens,
+        output_cost_per_1k_tokens,
+        max_input_tokens,
+        max_output_tokens,
+        litellm_provider,
+        mode,
+    )
+
+
+def register_model_pattern(
+    pattern: str,
+    input_cost_per_1k_tokens: Union[int, float, Decimal],
+    output_cost_per_1k_tokens: Union[int, float, Decimal],
+    *,
+    max_input_tokens: Optional[int] = None,
+    max_output_tokens: Optional[int] = None,
+    litellm_provider: Optional[str] = None,
+    mode: str = "chat",
+) -> None:
+    """
+    Register a wildcard pattern that assigns pricing to any matching model.
+
+    The pattern supports '*' as a wildcard; every other character is matched
+    literally. It is converted to a full regex match and, like model names at
+    lookup time, is lowercased so matching is case-insensitive.
+    Example: "bedrock/anthropic.claude-3-5-sonnet-*".
+
+    Args:
+        pattern: Wildcard pattern to match against lowercased model names.
+        input_cost_per_1k_tokens: USD per 1K input tokens.
+        output_cost_per_1k_tokens: USD per 1K output tokens.
+        max_input_tokens: Optional maximum input tokens.
+        max_output_tokens: Optional maximum output tokens.
+        litellm_provider: Optional provider hint (e.g., "bedrock").
+        mode: Model mode, defaults to "chat".
+    """
+    # Lookups lowercase the model name before matching, so the pattern must be
+    # lowercased too; otherwise a pattern containing uppercase could never match.
+    regex_str = "^" + re.escape(pattern.lower()).replace(r"\*", ".*") + "$"
+    entry = _build_pricing_entry(
+        input_cost_per_1k_tokens,
+        output_cost_per_1k_tokens,
+        max_input_tokens,
+        max_output_tokens,
+        litellm_provider,
+        mode,
+    )
+    MODEL_PRICE_PATTERNS.append((re.compile(regex_str), entry))
+
+
+def _normalize_model_for_pricing(model: str) -> str:
+    """
+    Normalize a model identifier for price lookup.
+
+    Rules, tried in order:
+    - Lowercase everything
+    - Keep exact matches if present
+    - Special-case Bedrock Anthropics: strip the leading "bedrock/" prefix when the next
+      segment starts with "anthropic.", since pricing keys are stored without the prefix.
+    - Otherwise, try the last segment after '/'. This helps for provider prefixes like
+      "azure/", "openai/", etc., when prices are stored under the bare model key.
+    - Finally, try patterns added with register_model_pattern(), first registered first.
+      On a match the pricing is copied into TOKEN_COSTS under the lowercased name, so
+      later lookups hit the exact-match rule.
+
+    Returns the key to use with TOKEN_COSTS, or the lowercased input when nothing
+    matched (callers are expected to check membership themselves).
+    """
+    m = model.lower()
+    if m in TOKEN_COSTS:
+        return m
+
+    # bedrock/anthropic.* => anthropic.* (pricing keys stored this way)
+    if m.startswith("bedrock/"):
+        rest = m[len("bedrock/"):]
+        if rest.startswith("anthropic.") and rest in TOKEN_COSTS:
+            return rest
+
+    # Try last path segment as a fallback (handles e.g., azure/gpt-4o)
+    if "/" in m:
+        last = m.rsplit("/", 1)[-1]
+        if last in TOKEN_COSTS:
+            return last
+
+    # Try matching any user-registered wildcard patterns. If matched, bind pricing to this key.
+    for regex, entry in MODEL_PRICE_PATTERNS:
+        if regex.match(m):
+            # Cache a copy so later edits to TOKEN_COSTS[m] do not alter the registered entry
+            TOKEN_COSTS[m] = dict(entry)
+            return m
+
+    return m
+
+
 def _get_field_from_token_type(token_type: TokenType) -> str:
     """
     Get the field name from the token type.
@@ -97,7 +256,7 @@ def count_message_tokens(messages: List[Dict[str, str]], model: str) -> int:
     model = strip_ft_model_name(model)
 
     # Anthropic token counting requires a valid API key
-    if "claude-" in model:
+    if "claude-" in model and not model.startswith("anthropic."):
         logger.warning(
             "Warning: Anthropic token counting API is currently in beta. Please expect differences in costs!"
         )
@@ -199,7 +358,7 @@ def calculate_cost_by_tokens(num_tokens: int, model: str, token_type: TokenType)
     Returns:
         Decimal: The calculated cost in USD.
     """
-    model = model.lower()
+    model = _normalize_model_for_pricing(model)
     if model not in TOKEN_COSTS:
         raise KeyError(
             f"""Model {model} is not implemented.
@@ -238,7 +397,8 @@ def calculate_prompt_cost(prompt: Union[List[dict], str], model: str) -> Decimal
     """
     model = model.lower()
     model = strip_ft_model_name(model)
-    if model not in TOKEN_COSTS:
+    pricing_model = _normalize_model_for_pricing(model)
+    if pricing_model not in TOKEN_COSTS:
         raise KeyError(
             f"""Model {model} is not implemented.
             Double-check your spelling, or submit an issue/PR"""
@@ -253,7 +413,7 @@ def calculate_prompt_cost(prompt: Union[List[dict], str], model: str) -> Decimal
         else count_message_tokens(prompt, model)
     )
 
-    return calculate_cost_by_tokens(prompt_tokens, model, "input")
+    return calculate_cost_by_tokens(prompt_tokens, pricing_model, "input")
 
 
 def calculate_completion_cost(completion: str, model: str) -> Decimal:
@@ -273,7 +433,8 @@ def calculate_completion_cost(completion: str, model: str) -> Decimal:
         Decimal('0.000014')
     """
     model = strip_ft_model_name(model)
-    if model not in TOKEN_COSTS:
+    pricing_model = _normalize_model_for_pricing(model)
+    if pricing_model not in TOKEN_COSTS:
         raise KeyError(
             f"""Model {model} is not implemented.
             Double-check your spelling, or submit an issue/PR"""
@@ -291,7 +452,7 @@ def calculate_completion_cost(completion: str, model: str) -> Decimal:
     else:
         completion_tokens = count_string_tokens(completion, model)
 
-    return calculate_cost_by_tokens(completion_tokens, model, "output")
+    return calculate_cost_by_tokens(completion_tokens, pricing_model, "output")
 
 
 def calculate_all_costs_and_tokens(
@@ -322,7 +483,7 @@ def calculate_all_costs_and_tokens(
         else count_message_tokens(prompt, model)
     )
 
-    if "claude-" in model:
+    if "claude-" in model and not model.startswith("anthropic."):
         logger.warning("Warning: Token counting is estimated for ")
         completion_list = [{"role": "assistant", "content": completion}]
         # Anthropic appends some 13 additional tokens to the actual completion tokens