Skip to content

Commit a294024

Browse files
Security hardening: logging, cost protection, SSRF, model validation (#248)
* Security hardening: logging, cost protection - Wire up SecureFormatter in app startup (#65): call configure_secure_logging() before any logging occurs - Add cost manipulation protection (#67): block models above $15/1M input tokens on platform/community keys, warn above $5/1M; BYOK users unrestricted - Verified SSRF protection (#66) and model validation (#68) already have comprehensive test coverage Closes #65, closes #66, closes #67, closes #68 * Address PR review findings - Fix misleading "fallback rate" comment in _check_model_cost - Add logging for unknown models (operator visibility) - Extract _models_by_cost() test helper to reduce duplication - Add boundary test at exact block threshold - Add BYOK + unknown model test - Assert BYOK guidance in error message - Fix module docstring wording * Fix SecureJSONFormatter broad exception catch Split the catch-all Exception handler into specific expected errors (ValueError, TypeError, KeyError) that include context for debugging, and unexpected errors that re-raise after printing to stderr. Matches the pattern already used in SecureFormatter.format().
1 parent 0d70be0 commit a294024

File tree

5 files changed

+157
-4
lines changed

5 files changed

+157
-4
lines changed

src/api/main.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,14 @@
2525
from src.api.routers.widget_test import router as widget_test_router
2626
from src.api.scheduler import start_scheduler, stop_scheduler
2727
from src.assistants import discover_assistants, registry
28+
from src.core.logging import configure_secure_logging
2829
from src.metrics.db import init_metrics_db
2930
from src.metrics.middleware import MetricsMiddleware
3031

32+
# Configure secure logging before any other logging occurs.
33+
# This ensures all log output uses SecureFormatter which redacts API keys.
34+
configure_secure_logging()
35+
3136
logger = logging.getLogger(__name__)
3237

3338
# Discover assistants at module load time to populate registry

src/api/routers/community.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from src.assistants.registry import AssistantInfo
3535
from src.core.config.community import WidgetConfig
3636
from src.core.services.litellm_llm import create_openrouter_llm
37-
from src.metrics.cost import estimate_cost
37+
from src.metrics.cost import COST_BLOCK_THRESHOLD, COST_WARN_THRESHOLD, MODEL_PRICING, estimate_cost
3838
from src.metrics.db import (
3939
RequestLogEntry,
4040
extract_token_usage,
@@ -602,6 +602,48 @@ def _select_model(
602602
return (default_model, default_provider)
603603

604604

605+
def _check_model_cost(model: str, key_source: str) -> None:
606+
"""Check if a model's cost exceeds platform thresholds.
607+
608+
Only enforced when using platform or community API keys (not BYOK).
609+
Logs a warning for moderately expensive models and blocks very expensive ones.
610+
611+
Args:
612+
model: Model identifier (e.g., "openai/gpt-4o").
613+
key_source: One of "byok", "community", or "platform".
614+
615+
Raises:
616+
HTTPException(403): If model cost exceeds the block threshold.
617+
"""
618+
if key_source == "byok":
619+
return
620+
621+
pricing = MODEL_PRICING.get(model)
622+
if pricing is None:
623+
logger.info("Model %s not in pricing table; allowing without cost check", model)
624+
return
625+
input_rate = pricing[0]
626+
627+
if input_rate >= COST_BLOCK_THRESHOLD:
628+
raise HTTPException(
629+
status_code=403,
630+
detail=(
631+
f"Model '{model}' costs ${input_rate:.2f}/1M input tokens, "
632+
f"which exceeds the platform limit of ${COST_BLOCK_THRESHOLD:.2f}/1M. "
633+
"To use expensive models, provide your own API key via the "
634+
"X-OpenRouter-Key header. Get a key at: https://openrouter.ai/keys"
635+
),
636+
)
637+
638+
if input_rate >= COST_WARN_THRESHOLD:
639+
logger.warning(
640+
"Model %s costs $%.2f/1M input tokens (warn threshold: $%.2f)",
641+
model,
642+
input_rate,
643+
COST_WARN_THRESHOLD,
644+
)
645+
646+
605647
def _derive_user_id(token: str) -> str:
606648
"""Derive a stable user ID from API token for cache optimization.
607649
@@ -717,6 +759,10 @@ def create_community_assistant(
717759
selected_model, selected_provider = _select_model(
718760
community_info, requested_model, has_byok=bool(byok)
719761
)
762+
763+
# Block expensive models on platform/community keys
764+
_check_model_cost(selected_model, key_source)
765+
720766
logger.debug(
721767
"Using model %s",
722768
selected_model,

src/core/logging.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,15 +141,23 @@ def format(self, record: logging.LogRecord) -> str:
141141

142142
return json_str
143143

144-
except Exception as e:
145-
# Fallback to safe error message
144+
except (ValueError, TypeError, KeyError) as e:
145+
# Expected serialization errors - include context for debugging
146+
safe_msg = str(getattr(record, "msg", "<no message>"))[:200]
147+
safe_name = getattr(record, "name", "<unknown>")
146148
error_entry = {
147149
"timestamp": datetime.now(UTC).isoformat(),
148150
"level": "ERROR",
149151
"logger": "logging",
150-
"message": f"[LOGGING ERROR: {type(e).__name__}]",
152+
"message": f"[LOGGING ERROR: {type(e).__name__}: {e}]",
153+
"original_logger": safe_name,
154+
"original_message": safe_msg,
151155
}
152156
return json.dumps(error_entry)
157+
except Exception as e:
158+
# Unexpected errors - surface to stderr and re-raise
159+
print(f"CRITICAL: Unexpected error in SecureJSONFormatter: {e}", file=sys.stderr)
160+
raise
153161

154162

155163
def configure_secure_logging(

src/metrics/cost.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@
4040
_FALLBACK_INPUT_RATE = 1.00 # USD per 1M tokens
4141
_FALLBACK_OUTPUT_RATE = 3.00 # USD per 1M tokens
4242

43+
# Cost protection thresholds (USD per 1M input tokens)
44+
# Applied only when using platform/community keys (not BYOK)
45+
COST_WARN_THRESHOLD = 5.0 # Log warning for models above this
46+
COST_BLOCK_THRESHOLD = 15.0 # Block requests for models above this
47+
4348

4449
def estimate_cost(
4550
model: str | None,
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
"""Tests for model cost protection.
2+
3+
Verifies that expensive models are blocked when using platform/community keys,
4+
but allowed when users provide their own API key (BYOK).
5+
"""
6+
7+
import pytest
8+
from fastapi import HTTPException
9+
10+
from src.api.routers.community import _check_model_cost
11+
from src.metrics.cost import COST_BLOCK_THRESHOLD, COST_WARN_THRESHOLD, MODEL_PRICING
12+
13+
14+
def _models_by_cost(min_rate: float = 0.0, max_rate: float = float("inf")) -> list[str]:
15+
"""Return model names with input rates in [min_rate, max_rate)."""
16+
return [m for m, (inp, _) in MODEL_PRICING.items() if min_rate <= inp < max_rate]
17+
18+
19+
class TestCheckModelCost:
20+
"""Tests for _check_model_cost() pre-invocation cost guard."""
21+
22+
def test_cheap_model_on_platform_key_allowed(self) -> None:
23+
"""Cheap models should be allowed on platform keys without error."""
24+
cheap_models = _models_by_cost(max_rate=COST_WARN_THRESHOLD)
25+
assert cheap_models, "Test requires at least one cheap model in MODEL_PRICING"
26+
27+
_check_model_cost(cheap_models[0], "platform")
28+
_check_model_cost(cheap_models[0], "community")
29+
30+
def test_expensive_model_blocked_on_platform_key(self) -> None:
31+
"""Models above block threshold should be rejected with 403 on platform keys."""
32+
expensive_models = _models_by_cost(min_rate=COST_BLOCK_THRESHOLD)
33+
assert expensive_models, "Test requires at least one expensive model in MODEL_PRICING"
34+
35+
with pytest.raises(HTTPException) as exc_info:
36+
_check_model_cost(expensive_models[0], "platform")
37+
assert exc_info.value.status_code == 403
38+
assert "exceeds the platform limit" in exc_info.value.detail
39+
assert "openrouter.ai/keys" in exc_info.value.detail
40+
41+
def test_expensive_model_blocked_on_community_key(self) -> None:
42+
"""Models above block threshold should also be rejected on community keys."""
43+
expensive_models = _models_by_cost(min_rate=COST_BLOCK_THRESHOLD)
44+
assert expensive_models, "Test requires at least one expensive model in MODEL_PRICING"
45+
46+
with pytest.raises(HTTPException) as exc_info:
47+
_check_model_cost(expensive_models[0], "community")
48+
assert exc_info.value.status_code == 403
49+
50+
def test_expensive_model_allowed_with_byok(self) -> None:
51+
"""BYOK users should be able to use any model, even expensive ones."""
52+
expensive_models = _models_by_cost(min_rate=COST_BLOCK_THRESHOLD)
53+
assert expensive_models, "Test requires at least one expensive model in MODEL_PRICING"
54+
55+
_check_model_cost(expensive_models[0], "byok")
56+
57+
def test_unknown_model_allowed_on_platform_key(self) -> None:
58+
"""Unknown models (not in pricing table) should be allowed."""
59+
_check_model_cost("unknown/made-up-model-xyz", "platform")
60+
61+
def test_unknown_model_allowed_with_byok(self) -> None:
62+
"""BYOK users with unknown models should also be allowed."""
63+
_check_model_cost("unknown/made-up-model-xyz", "byok")
64+
65+
def test_warn_threshold_model_not_blocked(self) -> None:
66+
"""Models between warn and block thresholds should be allowed (just warned)."""
67+
warn_only_models = _models_by_cost(
68+
min_rate=COST_WARN_THRESHOLD, max_rate=COST_BLOCK_THRESHOLD
69+
)
70+
if not warn_only_models:
71+
pytest.skip("No models between warn and block thresholds in current pricing")
72+
73+
_check_model_cost(warn_only_models[0], "platform")
74+
75+
def test_model_at_exact_block_threshold_is_blocked(self) -> None:
76+
"""A model priced exactly at the block threshold should be blocked."""
77+
exact_models = [m for m, (inp, _) in MODEL_PRICING.items() if inp == COST_BLOCK_THRESHOLD]
78+
if not exact_models:
79+
pytest.skip("No model priced exactly at block threshold")
80+
81+
with pytest.raises(HTTPException) as exc_info:
82+
_check_model_cost(exact_models[0], "platform")
83+
assert exc_info.value.status_code == 403
84+
85+
def test_thresholds_are_sane(self) -> None:
86+
"""Sanity check: warn threshold should be lower than block threshold."""
87+
assert COST_WARN_THRESHOLD < COST_BLOCK_THRESHOLD
88+
assert COST_WARN_THRESHOLD > 0
89+
assert COST_BLOCK_THRESHOLD > 0

0 commit comments

Comments
 (0)