Skip to content

Commit b86ff1b

Browse files
sarmientoF, claude, njbrake
authored
fix(gateway): pricing lookup uses wrong separator, cost always NULL (#827)
## Description `_log_usage` in `chat.py` builds the pricing `model_key` using a colon separator (`provider:model`), but `pricing_init.py` stores keys using a slash separator (`provider/model`). This mismatch means the DB lookup never finds a match, so `cost` is always `NULL` and user spend is never updated for chat completions. The fix makes the colon form canonical: `pricing_init.py` and the `/v1/pricing` API now normalize every key to `provider:model` before storing it, and `_log_usage` looks up `provider:model` first, then falls back to the legacy `provider/model` form (the slash convention noted for `pricing_init.py`, `audio.py`, and `search.py`) for backwards compatibility with pricing rows stored before normalization. ### Why the existing tests don't catch this The `model_pricing` fixture in `conftest.py` creates pricing with `model_key = "gemini:gemini-2.5-flash"` (colon format), which happens to match the old lookup. In production, pricing seeded via `pricing_init.py` uses slash format — the two never match. ## PR Type - [x] Bug Fix ## Relevant Issues N/A ## Checklist - [x] I have read and understand the existing codebase and relevant files - [x] I have tested these changes locally - [x] Tests pass with my changes - [ ] I have updated documentation (if applicable) - [x] I have read the [Contributing Guidelines](CONTRIBUTING.md) - [x] I have checked my code follows the project's code style - [ ] I am an AI Agent filling out this form (check box if true) ## AI Disclosure - **AI Model used**: Claude Opus 4.6 - **AI Developer Tool used**: Claude Code - **Additional context**: Bug was discovered during production debugging (cost column showing NULL for all chat completions). The fix and PR were pair-programmed with Claude Code. --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: njbrake <njbrake@gmail.com> Co-authored-by: Nathan Brake <33383515+njbrake@users.noreply.github.com>
1 parent 68a872f commit b86ff1b

File tree

4 files changed

+128
-5
lines changed

4 files changed

+128
-5
lines changed

src/any_llm/gateway/pricing_init.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@ def initialize_pricing_from_config(config: GatewayConfig, db: Session) -> None:
2828

2929
logger.info(f"Loading pricing configuration for {len(config.pricing)} model(s)")
3030

31-
for model_key, pricing_config in config.pricing.items():
32-
provider, _ = AnyLLM.split_model_provider(model_key)
31+
for raw_model_key, pricing_config in config.pricing.items():
32+
provider, model_name = AnyLLM.split_model_provider(raw_model_key)
33+
model_key = f"{provider.value}:{model_name}"
3334

3435
if provider.value not in config.providers:
3536
msg = (

src/any_llm/gateway/routes/chat.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,10 @@ async def _log_usage(
122122
usage_log.total_tokens = usage_data.total_tokens
123123

124124
model_key = f"{provider}:{model}" if provider else model
125+
model_key_legacy = f"{provider}/{model}" if provider else None
125126
pricing = db.query(ModelPricing).filter(ModelPricing.model_key == model_key).first()
127+
if not pricing and model_key_legacy:
128+
pricing = db.query(ModelPricing).filter(ModelPricing.model_key == model_key_legacy).first()
126129

127130
if pricing:
128131
cost = (usage_data.prompt_tokens / 1_000_000) * pricing.input_price_per_million + (
@@ -135,7 +138,8 @@ async def _log_usage(
135138
if user:
136139
user.spend = float(user.spend) + cost
137140
else:
138-
logger.warning(f"No pricing configured for model '{model_key}'. Usage will be tracked without cost.")
141+
attempted = f"'{model_key}'" + (f" or '{model_key_legacy}'" if model_key_legacy else "")
142+
logger.warning(f"No pricing configured for {attempted}. Usage will be tracked without cost.")
139143

140144
db.add(usage_log)
141145
try:

src/any_llm/gateway/routes/pricing.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from pydantic import BaseModel, Field
55
from sqlalchemy.orm import Session
66

7+
from any_llm.any_llm import AnyLLM
78
from any_llm.gateway.auth import verify_master_key
89
from any_llm.gateway.db import ModelPricing, get_db
910

@@ -34,14 +35,16 @@ async def set_pricing(
3435
db: Annotated[Session, Depends(get_db)],
3536
) -> PricingResponse:
3637
"""Set or update pricing for a model."""
37-
pricing = db.query(ModelPricing).filter(ModelPricing.model_key == request.model_key).first()
38+
provider, model_name = AnyLLM.split_model_provider(request.model_key)
39+
normalized_key = f"{provider.value}:{model_name}"
40+
pricing = db.query(ModelPricing).filter(ModelPricing.model_key == normalized_key).first()
3841

3942
if pricing:
4043
pricing.input_price_per_million = request.input_price_per_million
4144
pricing.output_price_per_million = request.output_price_per_million
4245
else:
4346
pricing = ModelPricing(
44-
model_key=request.model_key,
47+
model_key=normalized_key,
4548
input_price_per_million=request.input_price_per_million,
4649
output_price_per_million=request.output_price_per_million,
4750
)

tests/gateway/test_pricing_config.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88

99
from any_llm.gateway.config import GatewayConfig, PricingConfig
1010
from any_llm.gateway.db import ModelPricing, get_db
11+
from any_llm.gateway.db.models import UsageLog
12+
from any_llm.gateway.routes.chat import _log_usage
1113
from any_llm.gateway.server import create_app
14+
from any_llm.types.completion import CompletionUsage
1215

1316

1417
def test_pricing_loaded_from_config(postgres_url: str, test_db: Session) -> None:
@@ -116,6 +119,59 @@ def test_pricing_validation_requires_configured_provider(postgres_url: str, test
116119
create_app(config)
117120

118121

122+
def test_pricing_loaded_from_config_normalizes_legacy_slash_format(postgres_url: str, test_db: Session) -> None:
123+
"""Test that pricing configured with legacy slash format is normalized to colon format."""
124+
config = GatewayConfig(
125+
database_url=postgres_url,
126+
master_key="test-master-key",
127+
host="127.0.0.1",
128+
port=8000,
129+
providers={"openai": {"api_key": "test-key"}},
130+
pricing={
131+
"openai/gpt-4": PricingConfig(
132+
input_price_per_million=30.0,
133+
output_price_per_million=60.0,
134+
),
135+
},
136+
)
137+
138+
app = create_app(config)
139+
140+
def override_get_db() -> Any:
141+
yield test_db
142+
143+
app.dependency_overrides[get_db] = override_get_db
144+
145+
with TestClient(app):
146+
# Pricing should be stored with canonical colon format, not slash
147+
pricing_slash = test_db.query(ModelPricing).filter(ModelPricing.model_key == "openai/gpt-4").first()
148+
assert pricing_slash is None, "Pricing should not be stored with legacy slash format"
149+
150+
pricing_colon = test_db.query(ModelPricing).filter(ModelPricing.model_key == "openai:gpt-4").first()
151+
assert pricing_colon is not None, "Pricing should be stored with canonical colon format"
152+
assert pricing_colon.input_price_per_million == 30.0
153+
assert pricing_colon.output_price_per_million == 60.0
154+
155+
156+
def test_set_pricing_api_normalizes_legacy_slash_format(
157+
client: TestClient,
158+
master_key_header: dict[str, str],
159+
) -> None:
160+
"""Test that the pricing API normalizes legacy slash format to colon format."""
161+
response = client.post(
162+
"/v1/pricing",
163+
json={
164+
"model_key": "gemini/gemini-2.5-flash",
165+
"input_price_per_million": 0.075,
166+
"output_price_per_million": 0.30,
167+
},
168+
headers=master_key_header,
169+
)
170+
assert response.status_code == 200
171+
data = response.json()
172+
assert data["model_key"] == "gemini:gemini-2.5-flash", "API should normalize slash to colon format"
173+
174+
119175
def test_pricing_initialization_with_no_config(postgres_url: str, test_db: Session) -> None:
120176
"""Test that app starts successfully when no pricing is configured."""
121177
config = GatewayConfig(
@@ -139,3 +195,62 @@ def override_get_db() -> Any:
139195
# No pricing should be in database
140196
pricing_count = test_db.query(ModelPricing).count()
141197
assert pricing_count == 0, "No pricing should be loaded when config is empty"
198+
199+
200+
@pytest.mark.asyncio
201+
async def test_log_usage_finds_pricing_with_legacy_slash_format(test_db: Session) -> None:
202+
"""Test that _log_usage falls back to legacy slash format when colon format is not found."""
203+
# Simulate pricing stored with legacy slash format (e.g., from before normalization fix)
204+
legacy_pricing = ModelPricing(
205+
model_key="openai/gpt-4",
206+
input_price_per_million=30.0,
207+
output_price_per_million=60.0,
208+
)
209+
test_db.add(legacy_pricing)
210+
test_db.commit()
211+
212+
usage = CompletionUsage(prompt_tokens=1000, completion_tokens=500, total_tokens=1500)
213+
214+
await _log_usage(
215+
db=test_db,
216+
api_key_obj=None,
217+
model="gpt-4",
218+
provider="openai",
219+
endpoint="/v1/chat/completions",
220+
usage_override=usage,
221+
)
222+
223+
log = test_db.query(UsageLog).first()
224+
assert log is not None
225+
assert log.cost is not None, "Cost should be calculated via legacy slash format fallback"
226+
expected_cost = (1000 / 1_000_000) * 30.0 + (500 / 1_000_000) * 60.0
227+
assert abs(log.cost - expected_cost) < 0.0001
228+
229+
230+
@pytest.mark.asyncio
231+
async def test_log_usage_finds_pricing_with_colon_format(test_db: Session) -> None:
232+
"""Test that _log_usage finds pricing with canonical colon format."""
233+
pricing = ModelPricing(
234+
model_key="openai:gpt-4",
235+
input_price_per_million=30.0,
236+
output_price_per_million=60.0,
237+
)
238+
test_db.add(pricing)
239+
test_db.commit()
240+
241+
usage = CompletionUsage(prompt_tokens=1000, completion_tokens=500, total_tokens=1500)
242+
243+
await _log_usage(
244+
db=test_db,
245+
api_key_obj=None,
246+
model="gpt-4",
247+
provider="openai",
248+
endpoint="/v1/chat/completions",
249+
usage_override=usage,
250+
)
251+
252+
log = test_db.query(UsageLog).first()
253+
assert log is not None
254+
assert log.cost is not None, "Cost should be calculated with canonical colon format"
255+
expected_cost = (1000 / 1_000_000) * 30.0 + (500 / 1_000_000) * 60.0
256+
assert abs(log.cost - expected_cost) < 0.0001

0 commit comments

Comments
 (0)