Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 82 additions & 3 deletions routstr/payment/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import random

import httpx
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Request
from pydantic.v1 import BaseModel
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession

from ..core.db import ModelRow, create_session, get_session
from ..core.db import ApiKey, ModelRow, create_session, get_session
from ..core.logging import get_logger
from ..core.settings import settings
from ..wallet import deserialize_token_from_string
from .price import sats_usd_price

logger = get_logger(__name__)
Expand Down Expand Up @@ -521,11 +522,89 @@ def _pricing_matches(
return True



async def _get_request_balance(request: Request, session: AsyncSession) -> int | None:
    """Return the balance implied by the credentials on ``request``, if any.

    The credential is taken from the ``x-cashu`` header when present;
    otherwise from the second whitespace-separated field of the
    ``authorization`` header (the part after the scheme, e.g. ``Bearer``).

    Two credential kinds are recognised:

    * ``sk-...`` API keys: the text after ``sk-`` is used as the primary key
      of an ``ApiKey`` row, and ``balance - reserved_balance`` is returned.
    * Anything else is parsed as a Cashu token and its amount is returned in
      msat.

    Args:
        request: Incoming request whose headers are inspected.
        session: Database session used to look up API keys.

    Returns:
        The balance as an integer, or ``None`` when no credential is supplied,
        the API key is unknown, or the credential cannot be resolved. The
        caller treats ``None`` as "do not filter".
        NOTE(review): the API-key balance is presumably stored in msat, since
        the caller compares it against msat costs -- confirm against ``ApiKey``.
    """
    headers = request.headers
    token: str | None = None

    if x_cashu := headers.get("x-cashu"):
        token = x_cashu
    elif auth := headers.get("authorization"):
        # split() without a separator collapses runs of whitespace, so
        # "Bearer  <token>" still yields the token instead of an empty string
        # (RFC 7235 allows one or more spaces after the auth scheme).
        parts = auth.split()
        if len(parts) > 1:
            token = parts[1]

    if not token:
        return None

    # Handle API keys (sk-*)
    if token.startswith("sk-"):
        try:
            # sk- keys use the part after "sk-" as the ID
            key_id = token[3:]
            key = await session.get(ApiKey, key_id)
            if key:
                return key.balance - key.reserved_balance
        except Exception as e:
            # Best effort: this lookup only drives optional filtering, so a
            # database problem must not fail the whole request.
            logger.warning(f"Error checking API key balance: {e}")
        return None

    # Handle Cashu tokens
    try:
        token_obj = deserialize_token_from_string(token)
        # Any unit other than "msat" is treated as sats and scaled by 1000.
        # NOTE(review): tokens in other units (if supported) would be
        # mis-scaled here -- confirm which units can occur.
        amount_msat = (
            token_obj.amount if token_obj.unit == "msat" else token_obj.amount * 1000
        )
        return amount_msat
    except Exception as e:
        # An unparsable credential is not an error for this optional filter.
        logger.debug(f"Failed to deserialize cashu token for balance check: {e}")
        return None


@models_router.get("/v1/models")
@models_router.get("/models", include_in_schema=False)
async def models(
    request: Request, session: AsyncSession = Depends(get_session)
) -> dict:
    """Get all available models from all providers with database overrides applied.

    When the request carries a credential whose balance can be resolved, the
    list is narrowed to the models whose estimated maximum request cost (in
    msat) does not exceed that balance. Requests without a resolvable balance
    get the unfiltered list.

    Args:
        request: Incoming request; its headers are inspected for credentials.
        session: Database session injected by FastAPI.

    Returns:
        A dict of the form ``{"data": [...]}`` containing the models.
    """
    # Imported here rather than at module level, as in the original.
    from ..proxy import get_unique_models

    items = get_unique_models()

    # Optional: Filter by user balance if authenticated
    user_balance = await _get_request_balance(request, session)
    if user_balance is None:
        return {"data": items}

    min_cost = settings.min_request_msat
    # Flat per-request cost in msat. It applies under fixed pricing and as the
    # fallback for models without pricing data, and is computed once here.
    flat_cost = max(min_cost, settings.fixed_cost_per_request * 1000)

    # Tolerance discounts the model-specific maximum cost; it is not applied
    # when fixed pricing is enabled.
    tol_factor = 1.0
    if not settings.fixed_pricing and settings.tolerance_percentage > 0:
        tol_factor = 1.0 - (settings.tolerance_percentage / 100.0)

    def _cost_msats(model):
        """Return the minimum balance (msat) required to list ``model``."""
        if settings.fixed_pricing or not model.sats_pricing:
            return flat_cost
        # Use model specific pricing, never below the configured minimum.
        return max(min_cost, int(model.sats_pricing.max_cost * 1000 * tol_factor))

    affordable = [model for model in items if _cost_msats(model) <= user_balance]
    return {"data": affordable}
Loading