Skip to content

Commit 8d1d407

Browse files
Address PR #252 review findings: error handling, types, tests (#257)
* Address PR #252 review findings: error handling, types, tests Error handling: - Add CorruptMirrorError/ValueError handling to all mirror endpoints - Block unknown models on platform/community keys (fail-closed) - Add OSError handling to create_mirror_endpoint - Make cleanup_expired_mirrors resilient to per-mirror failures - Narrow scheduler cleanup catch to expected exception types - Add field_validator to RefreshMirrorRequest.community_ids Type design: - Make MirrorInfo a frozen dataclass with tuple community_ids - Move is_safe_identifier to src/core/validation.py (shared utility) - Add non-negativity validation to MODEL_PRICING at import time - Expand SecureFormatter key patterns for Anthropic/OpenAI keys Code quality: - Replace deprecated asyncio.get_event_loop() with get_running_loop() - Fix ContextVar comment accuracy (request lifecycle, not per-task) - Use get_active_mirror() instead of _active_mirror_id.get() - Fix docstring inaccuracies (caching, asyncio, model names) Tests: - Add active_mirror_context tests (set/reset, exception safety) - Add MirrorInfo invariant tests (empty ids, invalid id, immutability) - Add serialization round-trip test - Add TTL clamping test - Add run_sync_now invalid sync_type test - Update cost protection test for fail-closed behavior Closes #256 * Use generic redaction placeholder, remove misleading __all__ - Change redaction string from "sk-or-v1-***[redacted]" to "***[key-redacted]" since the pattern now covers multiple providers - Remove __all__ from mirror.py since no callers use wildcard imports from that module (is_safe_identifier now lives in core.validation) * Add ValueError catch to delete endpoint, validate community IDs - Add missing ValueError handling in delete_mirror_endpoint for consistency with all other mirror endpoints - Add community ID validation in MirrorInfo.__post_init__ so corrupt metadata with path-traversal community IDs is caught at load time - Document CorruptMirrorError in refresh_mirror docstring
1 parent 5a323f0 commit 8d1d407

File tree

13 files changed

+314
-65
lines changed

13 files changed

+314
-65
lines changed

src/api/routers/community.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -620,12 +620,19 @@ def _check_model_cost(model: str, key_source: str) -> None:
620620

621621
pricing = MODEL_PRICING.get(model)
622622
if pricing is None:
623-
logger.warning(
624-
"Model %s not in pricing table; allowing without cost check. "
623+
logger.error(
624+
"Model %s not in pricing table; blocking on platform/community key. "
625625
"Add this model to MODEL_PRICING in src/metrics/cost.py.",
626626
model,
627627
)
628-
return
628+
raise HTTPException(
629+
status_code=403,
630+
detail=(
631+
f"Model '{model}' is not in the approved pricing list and cannot be used "
632+
"with platform or community keys. To use this model, provide your own "
633+
"API key via the X-OpenRouter-Key header."
634+
),
635+
)
629636
input_rate = pricing.input_per_1m
630637

631638
if input_rate >= COST_BLOCK_THRESHOLD:

src/api/routers/mirrors.py

Lines changed: 75 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,15 @@
1515
from pydantic import BaseModel, Field, field_validator
1616

1717
from src.api.security import RequireAuth
18+
from src.core.validation import is_safe_identifier
1819
from src.knowledge.db import active_mirror_context
1920
from src.knowledge.mirror import (
21+
CorruptMirrorError,
2022
MirrorInfo,
2123
create_mirror,
2224
delete_mirror,
2325
get_mirror,
2426
get_mirror_db_path,
25-
is_safe_identifier,
2627
list_mirrors,
2728
refresh_mirror,
2829
)
@@ -93,6 +94,16 @@ class RefreshMirrorRequest(BaseModel):
9394
default=None, description="Specific communities to refresh, or null for all"
9495
)
9596

97+
@field_validator("community_ids")
98+
@classmethod
99+
def validate_community_ids(cls, v: list[str] | None) -> list[str] | None:
100+
if v is None:
101+
return v
102+
for cid in v:
103+
if not is_safe_identifier(cid):
104+
raise ValueError(f"Invalid community ID: {cid!r}")
105+
return list(dict.fromkeys(v))
106+
96107

97108
SyncType = Literal["github", "papers", "docstrings", "mailman", "faq", "beps", "all"]
98109

@@ -135,6 +146,12 @@ async def create_mirror_endpoint(
135146
)
136147
except ValueError as e:
137148
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
149+
except OSError as e:
150+
logger.error("Failed to create mirror: %s", e, exc_info=True)
151+
raise HTTPException(
152+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
153+
detail="Failed to create mirror due to a server filesystem error.",
154+
) from e
138155

139156
logger.info(
140157
"Mirror created: %s (communities=%s, owner=%s)",
@@ -161,7 +178,18 @@ async def get_mirror_endpoint(
161178
_auth: RequireAuth,
162179
) -> MirrorResponse:
163180
"""Get metadata for a specific mirror."""
164-
info = get_mirror(mirror_id)
181+
try:
182+
info = get_mirror(mirror_id)
183+
except ValueError:
184+
raise HTTPException(
185+
status_code=status.HTTP_400_BAD_REQUEST,
186+
detail=f"Invalid mirror ID format: '{mirror_id}'",
187+
)
188+
except CorruptMirrorError:
189+
raise HTTPException(
190+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
191+
detail=f"Mirror '{mirror_id}' has corrupt metadata. Delete and recreate it.",
192+
)
165193
if not info:
166194
raise HTTPException(
167195
status_code=status.HTTP_404_NOT_FOUND,
@@ -182,6 +210,11 @@ async def delete_mirror_endpoint(
182210
status_code=status.HTTP_404_NOT_FOUND,
183211
detail=f"Mirror '{mirror_id}' not found",
184212
)
213+
except ValueError:
214+
raise HTTPException(
215+
status_code=status.HTTP_400_BAD_REQUEST,
216+
detail=f"Invalid mirror ID format: '{mirror_id}'",
217+
)
185218
except OSError as e:
186219
logger.error("Failed to delete mirror %s: %s", mirror_id, e, exc_info=True)
187220
raise HTTPException(
@@ -205,6 +238,11 @@ async def refresh_mirror_endpoint(
205238
info = refresh_mirror(mirror_id, community_ids=body.community_ids)
206239
except ValueError as e:
207240
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
241+
except CorruptMirrorError:
242+
raise HTTPException(
243+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
244+
detail=f"Mirror '{mirror_id}' has corrupt metadata. Delete and recreate it.",
245+
)
208246

209247
logger.info("Mirror refreshed via API: %s", mirror_id)
210248
return MirrorResponse.from_info(info)
@@ -222,7 +260,18 @@ async def sync_mirror_endpoint(
222260
databases instead of production. Supports sync types: github, papers,
223261
docstrings, mailman, faq, beps, or all.
224262
"""
225-
info = get_mirror(mirror_id)
263+
try:
264+
info = get_mirror(mirror_id)
265+
except ValueError:
266+
raise HTTPException(
267+
status_code=status.HTTP_400_BAD_REQUEST,
268+
detail=f"Invalid mirror ID format: '{mirror_id}'",
269+
)
270+
except CorruptMirrorError:
271+
raise HTTPException(
272+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
273+
detail=f"Mirror '{mirror_id}' has corrupt metadata. Delete and recreate it.",
274+
)
226275
if not info:
227276
raise HTTPException(
228277
status_code=status.HTTP_404_NOT_FOUND,
@@ -235,8 +284,9 @@ async def sync_mirror_endpoint(
235284
)
236285

237286
# Run sync in a thread with the mirror context explicitly copied.
238-
# asyncio.to_thread copies ContextVars on Python 3.12+ but not 3.11,
239-
# so we capture the context and run within it for compatibility.
287+
# We use run_in_executor with an explicit context copy instead of
288+
# asyncio.to_thread because to_thread only copies ContextVars
289+
# automatically on Python 3.12+.
240290
from src.api.scheduler import run_sync_now
241291

242292
ctx = contextvars.copy_context()
@@ -246,14 +296,21 @@ def _run_sync_in_mirror() -> dict[str, int]:
246296
return run_sync_now(body.sync_type)
247297

248298
try:
249-
results = await asyncio.get_event_loop().run_in_executor(None, ctx.run, _run_sync_in_mirror)
299+
loop = asyncio.get_running_loop()
300+
results = await loop.run_in_executor(None, ctx.run, _run_sync_in_mirror)
250301
total = sum(results.values())
251302
return MirrorSyncResponse(
252303
message=f"Sync completed: {total} items synced into mirror {mirror_id}",
253304
items_synced=results,
254305
)
255306
except ValueError as e:
256307
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
308+
except OSError as e:
309+
logger.error("Mirror sync I/O error for %s: %s", mirror_id, e, exc_info=True)
310+
raise HTTPException(
311+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
312+
detail=f"Sync failed due to a filesystem error: {e}",
313+
) from e
257314
except Exception as e:
258315
logger.error("Mirror sync failed for %s: %s", mirror_id, e, exc_info=True)
259316
raise HTTPException(
@@ -274,7 +331,18 @@ async def download_mirror_db(
274331
"""
275332
from fastapi.responses import FileResponse
276333

277-
info = get_mirror(mirror_id)
334+
try:
335+
info = get_mirror(mirror_id)
336+
except ValueError:
337+
raise HTTPException(
338+
status_code=status.HTTP_400_BAD_REQUEST,
339+
detail=f"Invalid mirror ID format: '{mirror_id}'",
340+
)
341+
except CorruptMirrorError:
342+
raise HTTPException(
343+
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
344+
detail=f"Mirror '{mirror_id}' has corrupt metadata. Delete and recreate it.",
345+
)
278346
if not info:
279347
raise HTTPException(
280348
status_code=status.HTTP_404_NOT_FOUND,

src/api/scheduler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -292,14 +292,14 @@ def _run_beps_sync_for_community(community_id: str) -> bool:
292292
def _cleanup_mirrors() -> None:
293293
"""Remove expired ephemeral database mirrors."""
294294
global _mirror_cleanup_failures
295-
try:
296-
from src.knowledge.mirror import cleanup_expired_mirrors
295+
from src.knowledge.mirror import CorruptMirrorError, cleanup_expired_mirrors
297296

297+
try:
298298
deleted = cleanup_expired_mirrors()
299299
if deleted:
300300
logger.info("Mirror cleanup: removed %d expired mirrors", deleted)
301301
_mirror_cleanup_failures = 0
302-
except Exception:
302+
except (OSError, ValueError, CorruptMirrorError):
303303
_mirror_cleanup_failures += 1
304304
logger.error(
305305
"Mirror cleanup failed (consecutive failures: %d)",

src/core/logging.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""Secure logging configuration with API key redaction.
22
3-
Provides a custom log formatter that automatically redacts OpenRouter API keys
4-
from log messages to prevent credential exposure in centralized logging systems.
3+
Provides a custom log formatter that automatically redacts API keys
4+
(OpenRouter, Anthropic, OpenAI) from log messages to prevent credential
5+
exposure in centralized logging systems.
56
67
Supports both text and JSON-structured logging formats.
78
"""
@@ -17,12 +18,19 @@
1718
class SecureFormatter(logging.Formatter):
1819
"""Custom log formatter that redacts API keys from log messages.
1920
20-
Automatically detects and redacts OpenRouter API keys in the format
21-
sk-or-v1-[64 hex chars] to prevent accidental credential exposure.
21+
Automatically detects and redacts API keys from OpenRouter, Anthropic,
22+
and OpenAI to prevent accidental credential exposure.
2223
"""
2324

24-
# Pattern to match OpenRouter API keys: sk-or-v1-[64 hex chars]
25-
API_KEY_PATTERN = re.compile(r"sk-or-v1-[0-9a-f]{64}", re.IGNORECASE)
25+
# Patterns for API keys from various providers.
26+
# IGNORECASE as defense-in-depth; real keys use lowercase hex.
27+
API_KEY_PATTERN = re.compile(
28+
r"sk-or-v1-[0-9a-f]{64}" # OpenRouter: sk-or-v1-[64 hex chars]
29+
r"|sk-ant-[a-zA-Z0-9_-]{80,}" # Anthropic: sk-ant-...
30+
r"|sk-proj-[a-zA-Z0-9_-]{40,}" # OpenAI project keys: sk-proj-...
31+
r"|sk-[a-zA-Z0-9]{48,}", # Generic OpenAI keys: sk-...
32+
re.IGNORECASE,
33+
)
2634

2735
def format(self, record: logging.LogRecord) -> str:
2836
"""Format log record and redact any API keys.
@@ -55,7 +63,7 @@ def format(self, record: logging.LogRecord) -> str:
5563
if len(formatted) > 100_000: # 100KB limit
5664
formatted = formatted[:100_000] + "... [truncated for safety]"
5765

58-
formatted = self.API_KEY_PATTERN.sub("sk-or-v1-***[redacted]", formatted)
66+
formatted = self.API_KEY_PATTERN.sub("***[key-redacted]", formatted)
5967
except re.error as e:
6068
# Regex pattern is broken - this is a code bug
6169
print(f"CRITICAL: Redaction regex failed: {e}", file=sys.stderr)
@@ -137,7 +145,7 @@ def format(self, record: logging.LogRecord) -> str:
137145
json_str = json.dumps(log_entry, default=str)
138146

139147
# Redact API keys from the JSON string
140-
json_str = self.API_KEY_PATTERN.sub("sk-or-v1-***[redacted]", json_str)
148+
json_str = self.API_KEY_PATTERN.sub("***[key-redacted]", json_str)
141149

142150
return json_str
143151

@@ -154,7 +162,7 @@ def format(self, record: logging.LogRecord) -> str:
154162
"original_message": safe_msg,
155163
}
156164
fallback_json = json.dumps(error_entry)
157-
return self.API_KEY_PATTERN.sub("sk-or-v1-***[redacted]", fallback_json)
165+
return self.API_KEY_PATTERN.sub("***[key-redacted]", fallback_json)
158166
except Exception as e:
159167
# Unexpected errors - surface to stderr and re-raise
160168
print(f"CRITICAL: Unexpected error in SecureJSONFormatter: {e}", file=sys.stderr)

src/core/services/litellm_llm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,9 @@ def create_openrouter_llm(
6969
provider: Specific provider to use (e.g., "Cerebras", "DeepInfra/FP8").
7070
Ignored for Anthropic models, which always use "Anthropic" provider.
7171
user_id: User identifier for cache optimization (sticky routing)
72-
enable_caching: Enable prompt caching. If None (default), enabled for all models.
73-
OpenRouter/LiteLLM gracefully handles models that don't support caching.
72+
enable_caching: Enable prompt caching. If None (default), caching is requested
73+
for all models. Models that do not support caching will ignore the
74+
cache_control markers without error.
7475
7576
Returns:
7677
LLM instance configured for OpenRouter

src/core/services/llm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def get_model(
143143
model_name: Model name. Supports:
144144
- OpenRouter format: 'creator/model' (e.g., 'openai/gpt-oss-120b', 'qwen/qwen3-235b')
145145
- Direct OpenAI: 'gpt-4o', 'gpt-4o-mini', etc.
146-
- Direct Anthropic: 'claude-3-5-sonnet', etc.
146+
- Direct Anthropic: 'claude-3.5-sonnet', etc.
147147
If not provided, uses settings.default_model.
148148
api_key: Optional API key override (for BYOK).
149149
temperature: Model temperature. If not provided, uses settings.llm_temperature.

src/core/validation.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Shared input validation utilities.
2+
3+
Provides common validation functions used across modules for
4+
preventing path traversal and ensuring safe identifiers.
5+
"""
6+
7+
8+
def is_safe_identifier(value: str) -> bool:
9+
"""Check if a string is a safe identifier (alphanumeric, hyphens, underscores).
10+
11+
Used for both mirror IDs and community IDs to prevent path traversal.
12+
"""
13+
return bool(value) and value.replace("-", "").replace("_", "").isalnum()

src/knowledge/db.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,16 @@
2121
from pathlib import Path
2222

2323
from src.cli.config import get_data_dir
24-
from src.knowledge.mirror import _validate_mirror_id, is_safe_identifier
24+
from src.core.validation import is_safe_identifier
25+
from src.knowledge.mirror import _validate_mirror_id
2526

2627
logger = logging.getLogger(__name__)
2728

2829
# ContextVar for transparent mirror routing. When set, get_db_path() returns
2930
# the mirror's database path instead of the production path.
30-
# ContextVar is safe for concurrent async tasks; each task gets its own copy,
31-
# so mirror routing in one request does not affect other requests.
31+
# Safe for concurrent requests because the middleware sets and resets the
32+
# value around each request's lifecycle. Nested async calls within the
33+
# same request inherit the value.
3234
_active_mirror_id: contextvars.ContextVar[str | None] = contextvars.ContextVar(
3335
"_active_mirror_id", default=None
3436
)
@@ -448,7 +450,7 @@ def get_db_path(project: str = "hed") -> Path:
448450
"Use only alphanumeric characters, hyphens, and underscores."
449451
)
450452

451-
mirror_id = _active_mirror_id.get()
453+
mirror_id = get_active_mirror()
452454
if mirror_id:
453455
# mirror_id was already validated when set via set_active_mirror()
454456
return get_data_dir() / "mirrors" / mirror_id / f"{project}.db"

0 commit comments

Comments
 (0)