Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 39 additions & 3 deletions routstr/payment/cost_calculation.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ async def calculate_cost( # todo: can be sync
MSATS_PER_1K_OUTPUT_TOKENS: float = (
float(settings.fixed_per_1k_output_tokens) * 1000.0
)
MSATS_PER_1K_IMAGE_COMPLETION_TOKENS: float = 0.0

if not settings.fixed_pricing:
response_model = response_data.get("model", "")
Expand Down Expand Up @@ -104,18 +105,21 @@ async def calculate_cost( # todo: can be sync
try:
mspp = float(model_obj.sats_pricing.prompt)
mspc = float(model_obj.sats_pricing.completion)
mspci = float(getattr(model_obj.sats_pricing, "completion_image", 0.0))
except Exception:
return CostDataError(message="Invalid pricing data", code="pricing_invalid")

MSATS_PER_1K_INPUT_TOKENS = mspp * 1_000_000.0
MSATS_PER_1K_OUTPUT_TOKENS = mspc * 1_000_000.0
MSATS_PER_1K_IMAGE_COMPLETION_TOKENS = mspci * 1_000_000.0

logger.info(
"Applied model-specific pricing",
extra={
"model": response_model,
"input_price_msats_per_1k": MSATS_PER_1K_INPUT_TOKENS,
"output_price_msats_per_1k": MSATS_PER_1K_OUTPUT_TOKENS,
"image_completion_price_msats_per_1k": MSATS_PER_1K_IMAGE_COMPLETION_TOKENS,
},
)

Expand All @@ -128,13 +132,44 @@ async def calculate_cost( # todo: can be sync
},
)
return cost_data
usage_data = response_data["usage"]
input_tokens = usage_data.get("prompt_tokens", 0)
output_tokens = usage_data.get("completion_tokens", 0)

input_tokens = response_data.get("usage", {}).get("prompt_tokens", 0)
output_tokens = response_data.get("usage", {}).get("completion_tokens", 0)
# added for response api
input_tokens = (
input_tokens if input_tokens != 0 else usage_data.get("input_tokens", 0)
)
output_tokens = (
output_tokens if output_tokens != 0 else usage_data.get("output_tokens", 0)
)

# Calculate image completion cost
image_completion_msats = 0.0
if MSATS_PER_1K_IMAGE_COMPLETION_TOKENS > 0:
completion_details = usage_data.get("completion_tokens_details", {})
image_tokens = completion_details.get("image_tokens", 0)

if image_tokens > 0:
if output_tokens >= image_tokens:
output_tokens -= image_tokens

image_completion_msats = round(
image_tokens / 1000 * MSATS_PER_1K_IMAGE_COMPLETION_TOKENS, 3
)

logger.info(
"Calculated image completion cost",
extra={
"image_tokens": image_tokens,
"image_completion_msats": image_completion_msats,
},
)

input_msats = round(input_tokens / 1000 * MSATS_PER_1K_INPUT_TOKENS, 3)

output_msats = round(output_tokens / 1000 * MSATS_PER_1K_OUTPUT_TOKENS, 3)
token_based_cost = math.ceil(input_msats + output_msats)
token_based_cost = math.ceil(input_msats + output_msats + image_completion_msats)

logger.info(
"Calculated token-based cost",
Expand All @@ -143,6 +178,7 @@ async def calculate_cost( # todo: can be sync
"output_tokens": output_tokens,
"input_cost_msats": input_msats,
"output_cost_msats": output_msats,
"image_completion_msats": image_completion_msats,
"total_cost_msats": token_based_cost,
"model": response_data.get("model", "unknown"),
},
Expand Down
25 changes: 25 additions & 0 deletions routstr/payment/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class Pricing(BaseModel):
completion: float
request: float = 0.0
image: float = 0.0
completion_image: float = 0.0
web_search: float = 0.0
internal_reasoning: float = 0.0
input_cache_read: float = 0.0
Expand All @@ -40,6 +41,13 @@ class Pricing(BaseModel):
max_cost: float = 0.0 # in sats not msats


# Manual per-model pricing overrides, keyed by model id.
# Each entry maps a pricing field name to the value that should be used for
# that model; at present every entry only sets `completion_image`.
# Consumers visible in this module:
#   * when fetching models, each override is written into the model's
#     "pricing" dict as a string (skipped if that dict is missing or empty);
#   * when building a Model from a stored row, an override is applied only
#     for keys that are absent from the row's pricing dict.
# NOTE(review): values are presumably in the same unit as the other pricing
# fields returned by the upstream API — confirm before adding entries.
PRICING_OVERRIDES = {
"gemini-3-pro-image-preview": {"completion_image": 0.00012},
"gemini-2.5-flash-image": {"completion_image": 0.00003},
"gemini-2.0-flash": {"completion_image": 0.00003},
}


class TopProvider(BaseModel):
context_length: int | None = None
max_completion_tokens: int | None = None
Expand Down Expand Up @@ -116,6 +124,16 @@ async def async_fetch_openrouter_models(source_filter: str | None = None) -> lis
if not _has_valid_pricing(model):
continue

# Apply manual pricing overrides
if model_id in PRICING_OVERRIDES:
pricing = model.get("pricing", {})
if pricing:
for k, v in PRICING_OVERRIDES[model_id].items():
pricing[k] = str(
v
) # OpenRouter API returns strings for pricing
model["pricing"] = pricing

models_data.append(model)

return models_data
Expand Down Expand Up @@ -148,6 +166,12 @@ def _row_to_model(
if isinstance(pricing, dict) and float(pricing.get("request", 0.0)) <= 0.0:
pricing["request"] = max(pricing.get("request", 0.0), 0.0)

# Apply defaults for missing fields from manual overrides
if row.id in PRICING_OVERRIDES and isinstance(pricing, dict):
for k, v in PRICING_OVERRIDES[row.id].items():
if k not in pricing:
pricing[k] = v

parsed_pricing = Pricing.parse_obj(pricing)
model = Model(
id=row.id,
Expand Down Expand Up @@ -507,6 +531,7 @@ def _pricing_matches(
"completion",
"request",
"image",
"completion_image",
"web_search",
"internal_reasoning",
]
Expand Down
Loading