Skip to content

Commit 39c6ecb

Browse files
feat: add audit logger middleware and custom logger using loguru
1 parent cf4bbcd commit 39c6ecb

File tree

5 files changed

+434
-0
lines changed

5 files changed

+434
-0
lines changed

backend/open_webui/env.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,3 +419,25 @@ def parse_section(section):
419419

420420
if OFFLINE_MODE:
    os.environ["HF_HUB_OFFLINE"] = "1"

####################################
# AUDIT LOGGING
####################################
# Only the literal string "true" (case-insensitive) turns audit logging on.
ENABLE_AUDIT_LOGS = os.getenv("ENABLE_AUDIT_LOGS", "false").lower() == "true"
# Where to store log file
AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log"
# Maximum size of a file before rotating into a new log file
AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB")
# METADATA | REQUEST | REQUEST_RESPONSE
AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "REQUEST_RESPONSE").upper()
# Maximum number of body bytes to keep per audited request/response.
# An unset, empty or non-integer value falls back to 2048.
try:
    MAX_BODY_LOG_SIZE = int(os.getenv("MAX_BODY_LOG_SIZE") or 2048)
except ValueError:
    MAX_BODY_LOG_SIZE = 2048

# Comma separated list for urls to exclude from audit.
# Each entry is trimmed and its leading "/" removed so that "/chats" and
# "chats" are equivalent.
AUDIT_EXCLUDED_PATHS = [
    path.strip().lstrip("/")
    for path in os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders").split(",")
]
# Drop empty entries (empty variable, trailing comma, ",,"): an empty
# alternative in the exclusion regex built from this list would match every
# /api/... path and silently disable auditing altogether.
AUDIT_EXCLUDED_PATHS = [path for path in AUDIT_EXCLUDED_PATHS if path]

backend/open_webui/main.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@
4545
from starlette.responses import Response, StreamingResponse
4646

4747

48+
from open_webui.utils import logger
49+
from open_webui.utils.audit import AuditLevel, AuditLoggingMiddleware
50+
from open_webui.utils.logger import start_logger
4851
from open_webui.socket.main import (
4952
app as socket_app,
5053
periodic_usage_pool_cleanup,
@@ -298,8 +301,11 @@
298301
reset_config,
299302
)
300303
from open_webui.env import (
304+
AUDIT_EXCLUDED_PATHS,
305+
AUDIT_LOG_LEVEL,
301306
CHANGELOG,
302307
GLOBAL_LOG_LEVEL,
308+
MAX_BODY_LOG_SIZE,
303309
SAFE_MODE,
304310
SRC_LOG_LEVELS,
305311
VERSION,
@@ -384,6 +390,7 @@ async def get_response(self, path: str, scope):
384390

385391
@asynccontextmanager
386392
async def lifespan(app: FastAPI):
393+
start_logger()
387394
if RESET_CONFIG_ON_START:
388395
reset_config()
389396

@@ -879,6 +886,19 @@ async def inspect_websocket(request: Request, call_next):
879886
app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"])


# Translate the configured audit level string into an AuditLevel member.
# An unrecognised value is reported and treated as "no auditing".
# NOTE(review): `logger` in this module comes from
# `from open_webui.utils import logger`, which looks like the logger *module*
# rather than a logger object -- confirm it actually exposes `error()`.
try:
    audit_level = AuditLevel(AUDIT_LOG_LEVEL)
except ValueError as e:
    logger.error(f"Invalid audit level: {AUDIT_LOG_LEVEL}. Error: {e}")
    audit_level = AuditLevel.NONE

# The middleware is only installed when auditing is actually requested, so a
# disabled configuration adds no per-request overhead.
if audit_level is not AuditLevel.NONE:
    app.add_middleware(
        AuditLoggingMiddleware,
        audit_level=audit_level,
        excluded_paths=AUDIT_EXCLUDED_PATHS,
        max_body_size=MAX_BODY_LOG_SIZE,
    )
882902
##################################
883903
#
884904
# Chat Endpoints

backend/open_webui/utils/audit.py

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
from contextlib import asynccontextmanager
2+
from dataclasses import asdict, dataclass
3+
from enum import Enum
4+
import re
5+
from typing import (
6+
TYPE_CHECKING,
7+
Any,
8+
AsyncGenerator,
9+
Dict,
10+
MutableMapping,
11+
Optional,
12+
cast,
13+
)
14+
import uuid
15+
16+
from asgiref.typing import (
17+
ASGI3Application,
18+
ASGIReceiveCallable,
19+
ASGIReceiveEvent,
20+
ASGISendCallable,
21+
ASGISendEvent,
22+
Scope as ASGIScope,
23+
)
24+
from loguru import logger
25+
from starlette.requests import Request
26+
27+
from open_webui.env import AUDIT_LOG_LEVEL, MAX_BODY_LOG_SIZE
28+
from open_webui.utils.auth import get_current_user, get_http_authorization_cred
29+
from open_webui.models.users import UserModel
30+
31+
32+
if TYPE_CHECKING:
33+
from loguru import Logger
34+
35+
36+
@dataclass(frozen=True)
class AuditLogEntry:
    """Immutable record of a single audited HTTP request.

    Fields are grouped by the audit level that is meant to populate them;
    everything past ``request_uri`` is optional and defaults to ``None``.
    """

    # -- `Metadata` audit level ------------------------------------------
    # Unique identifier of this log entry.
    id: str
    # Selected attributes of the user who issued the request.
    user: dict[str, Any]
    # Name of the audit level in effect when the entry was produced.
    audit_level: str
    # HTTP method of the request.
    verb: str
    # Full URL of the request.
    request_uri: str
    user_agent: Optional[str] = None
    source_ip: Optional[str] = None

    # -- `Request` audit level -------------------------------------------
    # Captured request payload.
    request_object: Any = None

    # -- `Request Response` audit level ----------------------------------
    # Captured response payload and HTTP status code.
    response_object: Any = None
    response_status_code: Optional[int] = None
52+
53+
class AuditLevel(str, Enum):
    """Audit verbosity levels.

    Members are also ``str`` instances, so they compare equal to their plain
    string values and can be built directly from a configuration string.
    """

    # No auditing.
    NONE = "NONE"
    # Request metadata only.
    METADATA = "METADATA"
    # Metadata plus the request body.
    REQUEST = "REQUEST"
    # Metadata plus both request and response bodies.
    REQUEST_RESPONSE = "REQUEST_RESPONSE"
59+
60+
class AuditLogger:
    """Thin wrapper that emits audit entries through a Loguru logger.

    The wrapped logger is bound with ``auditable=True`` so that audit records
    can be told apart from ordinary application log records.

    Parameters:
        logger (Logger): The Loguru logger to derive the audit logger from.
    """

    def __init__(self, logger: "Logger"):
        # bind() returns a new logger; the one passed in is left untouched.
        self.logger = logger.bind(auditable=True)

    def write(
        self,
        audit_entry: AuditLogEntry,
        *,
        log_level: str = "INFO",
        extra: Optional[dict] = None,
    ):
        """Emit ``audit_entry`` as one log record.

        The entry's fields are passed to the logger as keyword arguments and
        the message text itself is left empty.

        Parameters:
            audit_entry: The entry to record.
            log_level: Loguru level name to log at.
            extra: Optional additional data, stored under the ``"extra"`` key
                when non-empty.
        """
        fields = asdict(audit_entry)
        if extra:
            fields["extra"] = extra

        self.logger.log(log_level, "", **fields)
90+
91+
class AuditContext:
    """
    Per-request scratch space for the audit middleware.

    Accumulates the request and response payloads as they stream through,
    keeping at most ``max_body_size`` bytes of each so that large bodies do
    not inflate memory usage.

    Attributes:
        request_body (bytearray): Captured prefix of the request payload.
        response_body (bytearray): Captured prefix of the response payload.
        max_body_size (int): Upper bound, in bytes, for each captured body.
        metadata (Dict[str, Any]): Free-form additional audit data.
    """

    def __init__(self, max_body_size: int = MAX_BODY_LOG_SIZE):
        self.max_body_size = max_body_size
        self.metadata: Dict[str, Any] = {}
        self.request_body = bytearray()
        self.response_body = bytearray()

    @staticmethod
    def _append_capped(buffer: bytearray, chunk: bytes, limit: int) -> None:
        """Append as much of ``chunk`` as still fits within ``limit`` bytes."""
        room = limit - len(buffer)
        if room > 0:
            buffer.extend(chunk[:room])

    def add_request_chunk(self, chunk: bytes):
        """Record a piece of the request body, truncating at the size cap."""
        self._append_capped(self.request_body, chunk, self.max_body_size)

    def add_response_chunk(self, chunk: bytes):
        """Record a piece of the response body, truncating at the size cap."""
        self._append_capped(self.response_body, chunk, self.max_body_size)
120+
121+
class AuditLoggingMiddleware:
    """
    ASGI middleware that writes one structured audit entry per audited request.

    Only HTTP requests using one of ``AUDITED_METHODS`` that carry an
    ``Authorization`` header and do not target an excluded path are audited.
    Depending on the audit level, the request and response bodies are
    captured (truncated to ``max_body_size`` bytes each) while the request is
    being processed, and the entry is logged once the wrapped application has
    finished, whether it returned normally or raised.

    Parameters:
        app: The wrapped ASGI application.
        excluded_paths: Resource names (e.g. ``"chats"``) that are not audited
            when they appear directly after ``/api/`` or ``/api/v1/``.
            Surrounding whitespace and leading slashes are ignored, empty
            entries are dropped, and entries are matched literally.
        max_body_size: Maximum number of bytes kept of each body.
        audit_level: How much of each request/response is recorded.
    """

    AUDITED_METHODS = {"PUT", "PATCH", "DELETE", "POST"}

    def __init__(
        self,
        app: ASGI3Application,
        *,
        excluded_paths: Optional[list[str]] = None,
        max_body_size: int = MAX_BODY_LOG_SIZE,
        audit_level: AuditLevel = AuditLevel.NONE,
    ) -> None:
        self.app = app
        self.audit_logger = AuditLogger(logger)
        self.excluded_paths = excluded_paths or []
        self.max_body_size = max_body_size
        self.audit_level = audit_level
        # Compiled once here instead of on every request.
        self._excluded_pattern = self._compile_excluded_pattern(self.excluded_paths)

    @staticmethod
    def _compile_excluded_pattern(
        excluded_paths: list[str],
    ) -> "Optional[re.Pattern[str]]":
        """Build the regex matching excluded URL paths, or ``None`` if there are none.

        The pattern matches either ``/api/<resource>/...`` (for the endpoint
        ``/api/chat`` case) or ``/api/v1/<resource>/...``.
        """
        fragments = [
            re.escape(cleaned)
            for cleaned in (path.strip().lstrip("/") for path in excluded_paths)
            if cleaned
        ]
        if not fragments:
            # An empty alternation would match *every* /api/... path and
            # therefore exclude everything; no exclusions means no pattern.
            return None
        return re.compile(r"^/api(?:/v1)?/(?:" + "|".join(fragments) + r")\b")

    async def __call__(
        self,
        scope: ASGIScope,
        receive: ASGIReceiveCallable,
        send: ASGISendCallable,
    ) -> None:
        """ASGI entry point: pass the call through, auditing it when applicable."""
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        request = Request(scope=cast(MutableMapping, scope))

        if self._should_skip_auditing(request):
            return await self.app(scope, receive, send)

        async with self._audit_context(request) as context:

            async def send_wrapper(message: ASGISendEvent) -> None:
                # Observe outgoing messages before forwarding them unchanged.
                if self.audit_level == AuditLevel.REQUEST_RESPONSE:
                    await self._capture_response(message, context)

                await send(message)

            async def receive_wrapper() -> ASGIReceiveEvent:
                # Observe incoming messages and hand them on unchanged.
                message = await receive()

                if self.audit_level in (
                    AuditLevel.REQUEST,
                    AuditLevel.REQUEST_RESPONSE,
                ):
                    await self._capture_request(message, context)

                return message

            await self.app(scope, receive_wrapper, send_wrapper)

    @asynccontextmanager
    async def _audit_context(
        self, request: Request
    ) -> AsyncGenerator[AuditContext, None]:
        """
        Async context manager that yields a fresh capture context and makes
        sure an audit log entry is recorded after the request is processed,
        even when the wrapped application raises.
        """
        # Honour the size limit this middleware was configured with.
        context = AuditContext(self.max_body_size)
        try:
            yield context
        finally:
            await self._log_audit_entry(request, context)

    async def _get_authenticated_user(self, request: Request) -> UserModel:
        """Resolve the user identified by the request's ``Authorization`` header.

        Raises:
            ValueError: If the request has no ``Authorization`` header.
        """
        auth_header = request.headers.get("Authorization")
        if not auth_header:
            # Explicit check rather than `assert`, which is stripped under -O.
            raise ValueError("Missing Authorization header")

        # NOTE(review): get_current_user is called without `await`; confirm it
        # is a synchronous function.
        user = get_current_user(request, get_http_authorization_cred(auth_header))

        return user

    def _should_skip_auditing(self, request: Request) -> bool:
        """Return ``True`` when the request must not be audited.

        A request is skipped when its method is not in ``AUDITED_METHODS``,
        auditing is disabled (on this instance or via ``AUDIT_LOG_LEVEL``),
        it carries no ``Authorization`` header, or its path is excluded.
        """
        if (
            request.method not in self.AUDITED_METHODS
            or self.audit_level == AuditLevel.NONE
            or AUDIT_LOG_LEVEL == "NONE"
            or not request.headers.get("authorization")
        ):
            return True

        pattern = self._excluded_pattern
        return pattern is not None and pattern.match(request.url.path) is not None

    async def _capture_request(self, message: ASGIReceiveEvent, context: AuditContext):
        """Store the body carried by an ``http.request`` message, if any."""
        if message["type"] == "http.request":
            body = message.get("body", b"")
            context.add_request_chunk(body)

    async def _capture_response(self, message: ASGISendEvent, context: AuditContext):
        """Store the status code or body chunk carried by a response message."""
        if message["type"] == "http.response.start":
            context.metadata["response_status_code"] = message["status"]

        elif message["type"] == "http.response.body":
            body = message.get("body", b"")
            context.add_response_chunk(body)

    async def _log_audit_entry(self, request: Request, context: AuditContext):
        """Assemble the audit entry for ``request`` and write it to the audit log.

        Any failure (including failing to resolve the user) is logged and
        swallowed so that auditing can never break the request itself.
        """
        try:
            user = await self._get_authenticated_user(request)

            entry = AuditLogEntry(
                id=str(uuid.uuid4()),
                user=user.model_dump(include={"id", "name", "email", "role"}),
                audit_level=self.audit_level.value,
                verb=request.method,
                request_uri=str(request.url),
                response_status_code=context.metadata.get("response_status_code", None),
                source_ip=request.client.host if request.client else None,
                user_agent=request.headers.get("user-agent"),
                # Bodies may have been cut mid-character by the size cap,
                # hence errors="replace".
                request_object=context.request_body.decode("utf-8", errors="replace"),
                response_object=context.response_body.decode("utf-8", errors="replace"),
            )

            self.audit_logger.write(entry)
        except Exception as e:
            logger.error(f"Failed to log audit entry: {str(e)}")

0 commit comments

Comments
 (0)