Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lm_proxy/__main__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Provides the CLI entry point when the package is executed as a Python module."""

from .app import cli_app


Expand Down
1 change: 1 addition & 0 deletions lm_proxy/api_key_check/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Collection of built-in API-key checkers for usage in the configuration."""

from .in_config import check_api_key_in_config
from .with_request import CheckAPIKeyWithRequest
from .allow_all import AllowAll
Expand Down
6 changes: 2 additions & 4 deletions lm_proxy/api_key_check/allow_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
This module provides a simple authentication strategy for development or testing
environments where all API keys should be accepted without validation.
"""

from typing import Optional
from dataclasses import dataclass

Expand All @@ -25,10 +26,7 @@ class AllowAll:
group: str = "default"
capture_api_key: bool = True

def __call__(
self,
api_key: Optional[str]
) -> tuple[str, dict[str, Optional[str]]]:
def __call__(self, api_key: Optional[str]) -> tuple[str, dict[str, Optional[str]]]:
"""
Validate an API key (accepts all keys without verification).
Expand Down
1 change: 1 addition & 0 deletions lm_proxy/api_key_check/in_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
For using this function,
set "api_key_check" configuration value to "lm_proxy.api_key_check.check_api_key_in_config".
"""

from typing import Optional
from ..bootstrap import env

Expand Down
7 changes: 3 additions & 4 deletions lm_proxy/api_key_check/with_request.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
API key check implementation using HTTP requests.
"""

from typing import Optional
from dataclasses import dataclass, field
import requests
Expand All @@ -13,6 +14,7 @@ class CheckAPIKeyWithRequest: # pylint: disable=too-many-instance-attributes
"""
Validates a Client API key by making an HTTP request to a specified URL.
"""

url: str = field()
method: str = field(default="get")
headers: dict = field(default_factory=dict)
Expand Down Expand Up @@ -45,10 +47,7 @@ def check_func(api_key: str) -> Optional[tuple[str, dict]]:
for k, v in self.headers.items()
}
response = requests.request(
method=self.method,
url=url,
headers=headers,
timeout=self.timeout
method=self.method, url=url, headers=headers, timeout=self.timeout
)
response.raise_for_status()
group = self.default_group
Expand Down
9 changes: 3 additions & 6 deletions lm_proxy/app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
LM-Proxy Application Entrypoint
"""

import logging
from typing import Optional
from fastapi import FastAPI
Expand All @@ -19,9 +20,7 @@
@cli_app.callback(invoke_without_command=True)
def run_server(
config: Optional[str] = typer.Option(None, help="Path to the configuration file"),
debug: Optional[bool] = typer.Option(
None, help="Enable debug mode (more verbose logging)"
),
debug: Optional[bool] = typer.Option(None, help="Enable debug mode (more verbose logging)"),
env_file: Optional[str] = typer.Option(
".env",
"--env",
Expand Down Expand Up @@ -55,9 +54,7 @@ def web_app():
"""
Entrypoint for ASGI server
"""
app = FastAPI(
title="LM-Proxy", description="OpenAI-compatible proxy server for LLM inference"
)
app = FastAPI(title="LM-Proxy", description="OpenAI-compatible proxy server for LLM inference")
OpenAIHTTPException.register(app)
app.add_api_route(
path=f"{env.config.api_prefix}/chat/completions",
Expand Down
8 changes: 4 additions & 4 deletions lm_proxy/base_types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Base types used in LM-Proxy."""

import uuid
from dataclasses import dataclass, field
from datetime import datetime
Expand All @@ -17,6 +18,7 @@ class ChatCompletionRequest(BaseModel):
"""
Request model for chat/completions endpoint.
"""

model: str
messages: List[mc.Msg | dict]
# | dict --> support of messages with lists of dicts
Expand Down Expand Up @@ -58,6 +60,7 @@ class RequestContext: # pylint: disable=too-many-instance-attributes
"""
Stores information about a single LLM request/response cycle for usage in middleware.
"""

id: Optional[str] = field(default_factory=lambda: str(uuid.uuid4()))
request: Optional[ChatCompletionRequest] = field(default=None)
http_request: Optional[Request] = field(default=None)
Expand All @@ -83,7 +86,4 @@ def to_dict(self) -> dict:
return data


THandler = Callable[
[RequestContext],
Union[Awaitable[None], None]
]
THandler = Callable[[RequestContext], Union[Awaitable[None], None]]
17 changes: 5 additions & 12 deletions lm_proxy/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def format(self, record):

class Env:
"""Runtime environment singleton."""

config: Config
connections: dict[str, mc.types.LLMAsyncFunctionType]
debug: bool
Expand All @@ -66,9 +67,7 @@ def init(config: Config | str | PathLike, debug: bool = False):
if isinstance(config, (str, PathLike)):
config = Config.load(config)
else:
raise ValueError(
"config must be a path (str or PathLike) or Config instance"
)
raise ValueError("config must be a path (str or PathLike) or Config instance")
env.config = config

env._init_components()
Expand All @@ -84,18 +83,12 @@ def init(config: Config | str | PathLike, debug: bool = False):
if inspect.iscoroutinefunction(conn_config):
env.connections[conn_name] = conn_config
elif isinstance(conn_config, str):
env.connections[conn_name] = resolve_instance_or_callable(
conn_config
)
env.connections[conn_name] = resolve_instance_or_callable(conn_config)
else:
mc.configure(
**conn_config, EMBEDDING_DB_TYPE=mc.EmbeddingDbType.NONE
)
mc.configure(**conn_config, EMBEDDING_DB_TYPE=mc.EmbeddingDbType.NONE)
env.connections[conn_name] = mc.env().llm_async_function
except mc.LLMConfigError as e:
raise ValueError(
f"Error in configuration for connection '{conn_name}': {e}"
) from e
raise ValueError(f"Error in configuration for connection '{conn_name}': {e}") from e

logging.info("Done initializing %d connections.", len(env.connections))

Expand Down
1 change: 1 addition & 0 deletions lm_proxy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class ModelListingMode(StrEnum):

class Group(BaseModel):
"""User group configuration."""

api_keys: list[str] = Field(default_factory=list)
allowed_connections: str = Field(default="*") # Comma-separated list or "*"

Expand Down
1 change: 1 addition & 0 deletions lm_proxy/config_loaders/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Built-in configuration loaders for different file formats."""

from .python import load_python_config
from .toml import load_toml_config
from .yaml import load_yaml_config
Expand Down
1 change: 1 addition & 0 deletions lm_proxy/config_loaders/json.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""JSON configuration loader."""

import json


Expand Down
1 change: 1 addition & 0 deletions lm_proxy/config_loaders/python.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Loader for Python configuration files."""

import importlib.util
from ..config import Config

Expand Down
1 change: 1 addition & 0 deletions lm_proxy/config_loaders/toml.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""TOML configuration loader."""

import tomllib


Expand Down
12 changes: 3 additions & 9 deletions lm_proxy/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@
return connection_name, model_part


def resolve_connection_and_model(
config: Config, external_model: str
) -> tuple[str, str]:
def resolve_connection_and_model(config: Config, external_model: str) -> tuple[str, str]:
"""
Resolves the connection name and model name based on routing rules.
Args:
Expand Down Expand Up @@ -150,9 +148,7 @@
"""
if not api_key:
return None
return hashlib.md5(
(api_key + env.config.encryption_key).encode("utf-8")
).hexdigest()
return hashlib.md5((api_key + env.config.encryption_key).encode("utf-8")).hexdigest()

Check failure

Code scanning / CodeQL

Use of a broken or weak cryptographic hashing algorithm on sensitive data High

Sensitive data (password)
is used in a hashing algorithm (MD5) that is insecure for password hashing, since it is not a computationally expensive hash function.
Sensitive data (password)
is used in a hashing algorithm (MD5) that is insecure for password hashing, since it is not a computationally expensive hash function.
Sensitive data (password)
is used in a hashing algorithm (MD5) that is insecure for password hashing, since it is not a computationally expensive hash function.

Copilot Autofix

AI 16 days ago

General approach: replace the use of MD5 with a modern cryptographic hash function from the SHA-2 family (e.g., SHA-256) for deriving this API key identifier. Since this identifier is not a password verifier but a pseudonymous ID, we do not need a slow KDF like bcrypt/Argon2; a simple SHA-256 hash is sufficient and standard.

Best concrete fix without changing functionality: keep the logic and types of api_key_id exactly the same (returns a deterministic hex string based on api_key and env.config.encryption_key), but change hashlib.md5 to hashlib.sha256. This preserves:

  • Same function signature and return type (str | None).
  • Same salting/combining scheme with env.config.encryption_key.
  • Same encoding (utf-8) and hex-digest format.

Only the length and value of the identifier change, which should be acceptable for any logging/metrics usage and does not affect authorization or API behavior.

Specific changes:

  • File: lm_proxy/core.py
  • In api_key_id (around line 151), replace:
    • hashlib.md5(...).hexdigest()
      with:
    • hashlib.sha256(...).hexdigest()
  • No new imports are needed because hashlib is already imported at the top of the file.

Suggested changeset 1
lm_proxy/core.py

Autofix patch

Autofix patch
Run the following command in your local git repository to apply this patch
cat << 'EOF' | git apply
diff --git a/lm_proxy/core.py b/lm_proxy/core.py
--- a/lm_proxy/core.py
+++ b/lm_proxy/core.py
@@ -148,7 +148,7 @@
     """
     if not api_key:
         return None
-    return hashlib.md5((api_key + env.config.encryption_key).encode("utf-8")).hexdigest()
+    return hashlib.sha256((api_key + env.config.encryption_key).encode("utf-8")).hexdigest()
 
 
 def fail_if_service_disabled():
EOF
@@ -148,7 +148,7 @@
"""
if not api_key:
return None
return hashlib.md5((api_key + env.config.encryption_key).encode("utf-8")).hexdigest()
return hashlib.sha256((api_key + env.config.encryption_key).encode("utf-8")).hexdigest()


def fail_if_service_disabled():
Copilot is powered by AI and may make mistakes. Always verify output.
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gito is sha256 slower than md5?



def fail_if_service_disabled():
Expand Down Expand Up @@ -207,9 +203,7 @@
return group, api_key, user_info


async def chat_completions(
request: ChatCompletionRequest, raw_request: Request
) -> Response:
async def chat_completions(request: ChatCompletionRequest, raw_request: Request) -> Response:
"""
Endpoint for chat completions that mimics OpenAI's API structure.
"""
Expand Down
1 change: 1 addition & 0 deletions lm_proxy/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ class OpenAIHTTPException(Exception):
"""
OpenAI API-compatible exception.
"""

# HTTPException from FastAPI is not used directly to provide error response format
# compatible with OpenAI API.

Expand Down
5 changes: 1 addition & 4 deletions lm_proxy/handlers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
from .rate_limiter import RateLimiter
from .forward_http_headers import HTTPHeadersForwarder

__all__ = [
'RateLimiter',
'HTTPHeadersForwarder'
]
__all__ = ["RateLimiter", "HTTPHeadersForwarder"]
1 change: 1 addition & 0 deletions lm_proxy/handlers/forward_http_headers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
HTTP headers forwarder LM-Proxy.
"""

from dataclasses import dataclass, field

from lm_proxy.base_types import RequestContext
Expand Down
7 changes: 3 additions & 4 deletions lm_proxy/handlers/rate_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

Provides sliding window rate limiting per API key / IP address / connection / user group / global.
"""

import threading
import time
from dataclasses import dataclass, field
Expand Down Expand Up @@ -33,6 +34,7 @@ class RateLimiter:
per: Scope for rate limiting
("api_key", "connection", "group", "ip", "global").
"""

max_requests: int = 60
window_seconds: float = 60.0
per: RateLimitScope = RateLimitScope.API_KEY
Expand Down Expand Up @@ -68,10 +70,7 @@ async def __call__(self, ctx: RequestContext) -> None:

with self._lock:
if len(self._buckets) > self.max_buckets:
self._buckets = {
k: v for k, v in self._buckets.items()
if v and v[-1] > cutoff
}
self._buckets = {k: v for k, v in self._buckets.items() if v and v[-1] > cutoff}

timestamps = [t for t in self._buckets.get(key, []) if t > cutoff]

Expand Down
6 changes: 6 additions & 0 deletions lm_proxy/loggers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""LLM Request logging."""

import abc
import json
import os
Expand All @@ -11,13 +12,15 @@

class AbstractLogEntryTransformer(abc.ABC): # pylint: disable=too-few-public-methods
"""Transform RequestContext into a dictionary of logged attributes."""

@abc.abstractmethod
def __call__(self, request_context: RequestContext) -> dict:
raise NotImplementedError()


class AbstractLogWriter(abc.ABC): # pylint: disable=too-few-public-methods
"""Writes the logged data to a destination."""

@abc.abstractmethod
def __call__(self, logged_data: dict):
raise NotImplementedError()
Expand All @@ -29,6 +32,7 @@ class LogEntryTransformer(AbstractLogEntryTransformer): # pylint: disable=too-f
The mapping is provided as keyword arguments, where keys are the names of the
logged attributes, and values are the paths to the attributes in RequestContext.
"""

def __init__(self, **kwargs):
self.mapping = kwargs

Expand All @@ -42,6 +46,7 @@ def __call__(self, request_context: RequestContext) -> dict:
@dataclass
class BaseLogger:
"""Base LLM request logger."""

log_writer: AbstractLogWriter | str | dict
entry_transformer: AbstractLogEntryTransformer | str | dict = field(default=None)

Expand Down Expand Up @@ -69,6 +74,7 @@ def __call__(self, request_context: RequestContext):
@dataclass
class JsonLogWriter(AbstractLogWriter):
"""Writes logged data to a JSON file."""

file_name: str

def __post_init__(self):
Expand Down
15 changes: 6 additions & 9 deletions lm_proxy/models_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,18 @@ async def models(request: Request) -> JSONResponse:
is_model_name = not ("*" in model_pattern or "?" in model_pattern)
if not is_model_name:
if env.config.model_listing_mode != ModelListingMode.AS_IS:
if (
env.config.model_listing_mode
== ModelListingMode.IGNORE_WILDCARDS
):
if env.config.model_listing_mode == ModelListingMode.IGNORE_WILDCARDS:
continue
raise NotImplementedError(
f"'{env.config.model_listing_mode}' model listing mode "
f"is not implemented yet"
)
model_data = {
"id": model_pattern,
"object": "model",
"created": 0,
"owned_by": connection_name,
}
"id": model_pattern,
"object": "model",
"created": 0,
"owned_by": connection_name,
}

if aux_info := env.config.model_info.get(model_pattern):
model_data.update(aux_info)
Expand Down
9 changes: 5 additions & 4 deletions multi-build.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,17 @@ def replace_name(old_names: list[str], new_names: list[str], files: list[str] =
p = Path(path)
p.write_text(
re.sub(
fr'(?<![\\/\w]){old_name}\b',
rf"(?<![\\/\w]){old_name}\b",
new_name,
p.read_text(encoding="utf-8"),
flags=re.M
), encoding="utf-8"
flags=re.M,
),
encoding="utf-8",
)


prev = NAMES[0]
for nxt in NAMES[1:]+[NAMES[0]]:
for nxt in NAMES[1:] + [NAMES[0]]:
print(f"Building for project name: {nxt[0]}...")
replace_name(prev, nxt)
subprocess.run(["poetry", "build"], check=True)
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,5 @@ asyncio_mode = "auto"
testpaths = [
"tests",
]
[tool.black]
line-length = 100
Loading