Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ ENV APP_BUILD_HASH=${BUILD_HASH}
RUN npm run build

######## WebUI backend ########
FROM python:3.11.14-slim-bookworm AS base
FROM python:3.11-slim-bookworm AS base

# Use args
ARG USE_CUDA
Expand Down
50 changes: 1 addition & 49 deletions backend/open_webui/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from open_webui.env import (
DATA_DIR,
DATABASE_URL,
ENABLE_DB_MIGRATIONS,
ENV,
REDIS_URL,
REDIS_KEY_PREFIX,
Expand Down Expand Up @@ -68,8 +67,7 @@ def run_migrations():
log.exception(f"Error running migrations: {e}")


if ENABLE_DB_MIGRATIONS:
run_migrations()
run_migrations()


class Config(Base):
Expand Down Expand Up @@ -2373,51 +2371,6 @@ class BannerModel(BaseModel):
except Exception:
PGVECTOR_IVFFLAT_LISTS = 100

# openGauss
OPENGAUSS_DB_URL = os.environ.get("OPENGAUSS_DB_URL", DATABASE_URL)

OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH = int(
os.environ.get("OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH", "1536")
)

OPENGAUSS_POOL_SIZE = os.environ.get("OPENGAUSS_POOL_SIZE", None)

if OPENGAUSS_POOL_SIZE != None:
try:
OPENGAUSS_POOL_SIZE = int(OPENGAUSS_POOL_SIZE)
except Exception:
OPENGAUSS_POOL_SIZE = None

OPENGAUSS_POOL_MAX_OVERFLOW = os.environ.get("OPENGAUSS_POOL_MAX_OVERFLOW", 0)

if OPENGAUSS_POOL_MAX_OVERFLOW == "":
OPENGAUSS_POOL_MAX_OVERFLOW = 0
else:
try:
OPENGAUSS_POOL_MAX_OVERFLOW = int(OPENGAUSS_POOL_MAX_OVERFLOW)
except Exception:
OPENGAUSS_POOL_MAX_OVERFLOW = 0

OPENGAUSS_POOL_TIMEOUT = os.environ.get("OPENGAUSS_POOL_TIMEOUT", 30)

if OPENGAUSS_POOL_TIMEOUT == "":
OPENGAUSS_POOL_TIMEOUT = 30
else:
try:
OPENGAUSS_POOL_TIMEOUT = int(OPENGAUSS_POOL_TIMEOUT)
except Exception:
OPENGAUSS_POOL_TIMEOUT = 30

OPENGAUSS_POOL_RECYCLE = os.environ.get("OPENGAUSS_POOL_RECYCLE", 3600)

if OPENGAUSS_POOL_RECYCLE == "":
OPENGAUSS_POOL_RECYCLE = 3600
else:
try:
OPENGAUSS_POOL_RECYCLE = int(OPENGAUSS_POOL_RECYCLE)
except Exception:
OPENGAUSS_POOL_RECYCLE = 3600

# Pinecone
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)
Expand Down Expand Up @@ -3737,7 +3690,6 @@ class BannerModel(BaseModel):
os.getenv("WHISPER_MODEL", "base"),
)

WHISPER_COMPUTE_TYPE = os.getenv("WHISPER_COMPUTE_TYPE", "int8")
WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
WHISPER_MODEL_AUTO_UPDATE = (
not OFFLINE_MODE
Expand Down
4 changes: 2 additions & 2 deletions backend/open_webui/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@
SRC_LOG_LEVELS = {} # Legacy variable, do not remove

WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
if WEBUI_NAME != "Open WebUI":
WEBUI_NAME += " (Open WebUI)"
# if WEBUI_NAME != "Open WebUI":
# WEBUI_NAME += " (Open WebUI)"

WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"

Expand Down
3 changes: 1 addition & 2 deletions backend/open_webui/internal/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,7 @@ def handle_peewee_migration(DATABASE_URL):
assert db.is_closed(), "Database connection is still open."


if ENABLE_DB_MIGRATIONS:
handle_peewee_migration(DATABASE_URL)
handle_peewee_migration(DATABASE_URL)


SQLALCHEMY_DATABASE_URL = DATABASE_URL
Expand Down
15 changes: 4 additions & 11 deletions backend/open_webui/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,7 @@
get_rf,
)


from sqlalchemy.orm import Session
from open_webui.internal.db import ScopedSession, engine, get_session
from open_webui.internal.db import Session, ScopedSession, engine

from open_webui.models.functions import Functions
from open_webui.models.models import Models
Expand Down Expand Up @@ -2315,13 +2313,8 @@ async def oauth_login(provider: str, request: Request):
# - Email addresses are considered unique, so we fail registration if the email address is already taken
@app.get("/oauth/{provider}/login/callback")
@app.get("/oauth/{provider}/callback") # Legacy endpoint
async def oauth_login_callback(
provider: str,
request: Request,
response: Response,
db: Session = Depends(get_session),
):
return await oauth_manager.handle_callback(request, provider, response, db=db)
async def oauth_login_callback(provider: str, request: Request, response: Response):
return await oauth_manager.handle_callback(request, provider, response)


@app.get("/manifest.json")
Expand Down Expand Up @@ -2380,7 +2373,7 @@ async def healthcheck():

@app.get("/health/db")
async def healthcheck_with_db():
ScopedSession.execute(text("SELECT 1;")).all()
Session.execute(text("SELECT 1;")).all()
return {"status": True}


Expand Down
30 changes: 14 additions & 16 deletions backend/open_webui/models/auths.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
import uuid
from typing import Optional

from sqlalchemy.orm import Session
from open_webui.internal.db import Base, JSONField, get_db, get_db_context
from open_webui.internal.db import Base, get_db, Session
from open_webui.models.users import UserModel, UserProfileImageResponse, Users
from pydantic import BaseModel
from sqlalchemy import Boolean, Column, String, Text
Expand Down Expand Up @@ -88,9 +87,8 @@ def insert_new_auth(
profile_image_url: str = "/user.png",
role: str = "pending",
oauth: Optional[dict] = None,
db: Optional[Session] = None,
) -> Optional[UserModel]:
with get_db_context(db) as db:
with get_db() as db:
log.info("insert_new_auth")

id = str(uuid.uuid4())
Expand All @@ -102,7 +100,7 @@ def insert_new_auth(
db.add(result)

user = Users.insert_new_user(
id, name, email, profile_image_url, role, oauth=oauth, db=db
id, name, email, profile_image_url, role, oauth=oauth
)

db.commit()
Expand All @@ -114,16 +112,16 @@ def insert_new_auth(
return None

def authenticate_user(
self, email: str, verify_password: callable, db: Optional[Session] = None
self, email: str, verify_password: callable
) -> Optional[UserModel]:
log.info(f"authenticate_user: {email}")

user = Users.get_user_by_email(email, db=db)
user = Users.get_user_by_email(email)
if not user:
return None

try:
with get_db_context(db) as db:
with get_db() as db:
auth = db.query(Auth).filter_by(id=user.id, active=True).first()
if auth:
if verify_password(auth.password):
Expand All @@ -144,7 +142,7 @@ def authenticate_user_by_api_key(
return None

try:
user = Users.get_user_by_api_key(api_key, db=db)
user = Users.get_user_by_api_key(api_key)
return user if user else None
except Exception:
return False
Expand All @@ -154,10 +152,10 @@ def authenticate_user_by_email(
) -> Optional[UserModel]:
log.info(f"authenticate_user_by_email: {email}")
try:
with get_db_context(db) as db:
with get_db() as db:
auth = db.query(Auth).filter_by(email=email, active=True).first()
if auth:
user = Users.get_user_by_id(auth.id, db=db)
user = Users.get_user_by_id(auth.id)
return user
except Exception:
return None
Expand All @@ -166,7 +164,7 @@ def update_user_password_by_id(
self, id: str, new_password: str, db: Optional[Session] = None
) -> bool:
try:
with get_db_context(db) as db:
with get_db() as db:
result = (
db.query(Auth).filter_by(id=id).update({"password": new_password})
)
Expand All @@ -179,18 +177,18 @@ def update_email_by_id(
self, id: str, email: str, db: Optional[Session] = None
) -> bool:
try:
with get_db_context(db) as db:
with get_db() as db:
result = db.query(Auth).filter_by(id=id).update({"email": email})
db.commit()
return True if result == 1 else False
except Exception:
return False

def delete_auth_by_id(self, id: str, db: Optional[Session] = None) -> bool:
def delete_auth_by_id(self, id: str) -> bool:
try:
with get_db_context(db) as db:
with get_db() as db:
# Delete User
result = Users.delete_user_by_id(id, db=db)
result = Users.delete_user_by_id(id)

if result:
db.query(Auth).filter_by(id=id).delete()
Expand Down
54 changes: 22 additions & 32 deletions backend/open_webui/models/notes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
from typing import Optional
from functools import lru_cache

from sqlalchemy.orm import Session
from open_webui.internal.db import Base, get_db, get_db_context
from open_webui.internal.db import Base, get_db
from open_webui.models.groups import Groups
from open_webui.utils.access_control import has_access
from open_webui.models.users import User, UserModel, Users, UserResponse
Expand Down Expand Up @@ -212,9 +211,11 @@ def _has_permission(self, db, query, filter: dict, permission: str = "read"):
return query

def insert_new_note(
self, user_id: str, form_data: NoteForm, db: Optional[Session] = None
self,
form_data: NoteForm,
user_id: str,
) -> Optional[NoteModel]:
with get_db_context(db) as db:
with get_db() as db:
note = NoteModel(
**{
"id": str(uuid.uuid4()),
Expand All @@ -232,9 +233,9 @@ def insert_new_note(
return note

def get_notes(
self, skip: int = 0, limit: int = 50, db: Optional[Session] = None
self, skip: Optional[int] = None, limit: Optional[int] = None
) -> list[NoteModel]:
with get_db_context(db) as db:
with get_db() as db:
query = db.query(Note).order_by(Note.updated_at.desc())
if skip is not None:
query = query.offset(skip)
Expand All @@ -244,14 +245,9 @@ def get_notes(
return [NoteModel.model_validate(note) for note in notes]

def search_notes(
self,
user_id: str,
filter: dict = {},
skip: int = 0,
limit: int = 30,
db: Optional[Session] = None,
self, user_id: str, filter: dict = {}, skip: int = 0, limit: int = 30
) -> NoteListResponse:
with get_db_context(db) as db:
with get_db() as db:
query = db.query(Note, User).outerjoin(User, User.id == Note.user_id)
if filter:
query_key = filter.get("query")
Expand Down Expand Up @@ -345,13 +341,12 @@ def get_notes_by_user_id(
self,
user_id: str,
permission: str = "read",
skip: int = 0,
limit: int = 50,
db: Optional[Session] = None,
skip: Optional[int] = None,
limit: Optional[int] = None,
) -> list[NoteModel]:
with get_db_context(db) as db:
with get_db() as db:
user_group_ids = [
group.id for group in Groups.get_groups_by_member_id(user_id, db=db)
group.id for group in Groups.get_groups_by_member_id(user_id)
]

query = db.query(Note).order_by(Note.updated_at.desc())
Expand All @@ -367,17 +362,15 @@ def get_notes_by_user_id(
notes = query.all()
return [NoteModel.model_validate(note) for note in notes]

def get_note_by_id(
self, id: str, db: Optional[Session] = None
) -> Optional[NoteModel]:
with get_db_context(db) as db:
def get_note_by_id(self, id: str) -> Optional[NoteModel]:
with get_db() as db:
note = db.query(Note).filter(Note.id == id).first()
return NoteModel.model_validate(note) if note else None

def update_note_by_id(
self, id: str, form_data: NoteUpdateForm, db: Optional[Session] = None
self, id: str, form_data: NoteUpdateForm
) -> Optional[NoteModel]:
with get_db_context(db) as db:
with get_db() as db:
note = db.query(Note).filter(Note.id == id).first()
if not note:
return None
Expand All @@ -399,14 +392,11 @@ def update_note_by_id(
db.commit()
return NoteModel.model_validate(note) if note else None

def delete_note_by_id(self, id: str, db: Optional[Session] = None) -> bool:
try:
with get_db_context(db) as db:
db.query(Note).filter(Note.id == id).delete()
db.commit()
return True
except Exception:
return False
def delete_note_by_id(self, id: str):
with get_db() as db:
db.query(Note).filter(Note.id == id).delete()
db.commit()
return True


Notes = NoteTable()
Loading