diff --git a/Dockerfile b/Dockerfile
index ac81a3943ca..d8dcd307ef2 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -43,7 +43,7 @@ ENV APP_BUILD_HASH=${BUILD_HASH}
 RUN npm run build
 
 ######## WebUI backend ########
-FROM python:3.11.14-slim-bookworm AS base
+FROM python:3.11-slim-bookworm AS base
 
 # Use args
 ARG USE_CUDA
diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py
index dab5b6cfe82..babc744c3de 100644
--- a/backend/open_webui/config.py
+++ b/backend/open_webui/config.py
@@ -19,7 +19,6 @@ from open_webui.env import (
     DATA_DIR,
     DATABASE_URL,
-    ENABLE_DB_MIGRATIONS,
     ENV,
     REDIS_URL,
     REDIS_KEY_PREFIX,
@@ -68,8 +67,7 @@ def run_migrations():
         log.exception(f"Error running migrations: {e}")
 
 
-if ENABLE_DB_MIGRATIONS:
-    run_migrations()
+run_migrations()
 
 
 class Config(Base):
@@ -2373,51 +2371,6 @@ class BannerModel(BaseModel):
 except Exception:
     PGVECTOR_IVFFLAT_LISTS = 100
 
-# openGauss
-OPENGAUSS_DB_URL = os.environ.get("OPENGAUSS_DB_URL", DATABASE_URL)
-
-OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH = int(
-    os.environ.get("OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH", "1536")
-)
-
-OPENGAUSS_POOL_SIZE = os.environ.get("OPENGAUSS_POOL_SIZE", None)
-
-if OPENGAUSS_POOL_SIZE != None:
-    try:
-        OPENGAUSS_POOL_SIZE = int(OPENGAUSS_POOL_SIZE)
-    except Exception:
-        OPENGAUSS_POOL_SIZE = None
-
-OPENGAUSS_POOL_MAX_OVERFLOW = os.environ.get("OPENGAUSS_POOL_MAX_OVERFLOW", 0)
-
-if OPENGAUSS_POOL_MAX_OVERFLOW == "":
-    OPENGAUSS_POOL_MAX_OVERFLOW = 0
-else:
-    try:
-        OPENGAUSS_POOL_MAX_OVERFLOW = int(OPENGAUSS_POOL_MAX_OVERFLOW)
-    except Exception:
-        OPENGAUSS_POOL_MAX_OVERFLOW = 0
-
-OPENGAUSS_POOL_TIMEOUT = os.environ.get("OPENGAUSS_POOL_TIMEOUT", 30)
-
-if OPENGAUSS_POOL_TIMEOUT == "":
-    OPENGAUSS_POOL_TIMEOUT = 30
-else:
-    try:
-        OPENGAUSS_POOL_TIMEOUT = int(OPENGAUSS_POOL_TIMEOUT)
-    except Exception:
-        OPENGAUSS_POOL_TIMEOUT = 30
-
-OPENGAUSS_POOL_RECYCLE = os.environ.get("OPENGAUSS_POOL_RECYCLE", 3600)
-
-if OPENGAUSS_POOL_RECYCLE == "":
-    OPENGAUSS_POOL_RECYCLE = 3600
-else:
-    try:
-        OPENGAUSS_POOL_RECYCLE = int(OPENGAUSS_POOL_RECYCLE)
-    except Exception:
-        OPENGAUSS_POOL_RECYCLE = 3600
-
 # Pinecone
 PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None)
 PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None)
@@ -3737,7 +3690,6 @@ class BannerModel(BaseModel):
     os.getenv("WHISPER_MODEL", "base"),
 )
 
-WHISPER_COMPUTE_TYPE = os.getenv("WHISPER_COMPUTE_TYPE", "int8")
 WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
 WHISPER_MODEL_AUTO_UPDATE = (
     not OFFLINE_MODE
diff --git a/backend/open_webui/internal/db.py b/backend/open_webui/internal/db.py
index 6050e37fa30..832914bcd8c 100644
--- a/backend/open_webui/internal/db.py
+++ b/backend/open_webui/internal/db.py
@@ -77,8 +77,7 @@ def handle_peewee_migration(DATABASE_URL):
         assert db.is_closed(), "Database connection is still open."
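Note on the session-handling change that runs through the model and router hunks below: the optional `db: Optional[Session]` parameters are dropped and `get_db_context(db)` calls are replaced with plain `get_db()`, so each method opens its own session instead of optionally reusing one supplied by the caller. The definition of `get_db_context` is not part of this patch (it remains imported in `models/users.py`); a minimal, self-contained sketch of the pattern being retired, assuming the helper simply reuses a provided session and otherwise falls back to `get_db()`, is:

```python
from contextlib import contextmanager
from typing import Iterator, Optional

from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker

# Stand-in engine/session factory for illustration only.
engine = create_engine("sqlite://")
SessionLocal = sessionmaker(bind=engine)


@contextmanager
def get_db() -> Iterator[Session]:
    """Open a fresh session and make sure it gets closed."""
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()


@contextmanager
def get_db_context(db: Optional[Session] = None) -> Iterator[Session]:
    """Assumed behavior of the helper being retired: reuse the caller's
    session if one is passed, otherwise behave exactly like get_db()."""
    if db is not None:
        # Caller owns this session: do not close (or commit) it here.
        yield db
    else:
        with get_db() as session:
            yield session


# Usage: a caller-provided session is reused; without one, a new session is opened.
with get_db() as outer:
    with get_db_context(outer) as inner:
        assert inner is outer
```

After this patch, only the second branch survives: the model methods call `get_db()` directly, and the FastAPI routes no longer thread a `db: Session = Depends(get_session)` argument through to the model layer.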
-if ENABLE_DB_MIGRATIONS: - handle_peewee_migration(DATABASE_URL) +handle_peewee_migration(DATABASE_URL) SQLALCHEMY_DATABASE_URL = DATABASE_URL diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 82a87e3fd9d..e07be5bd7b9 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -102,9 +102,7 @@ get_rf, ) - -from sqlalchemy.orm import Session -from open_webui.internal.db import ScopedSession, engine, get_session +from open_webui.internal.db import Session, ScopedSession, engine from open_webui.models.functions import Functions from open_webui.models.models import Models @@ -2315,13 +2313,8 @@ async def oauth_login(provider: str, request: Request): # - Email addresses are considered unique, so we fail registration if the email address is already taken @app.get("/oauth/{provider}/login/callback") @app.get("/oauth/{provider}/callback") # Legacy endpoint -async def oauth_login_callback( - provider: str, - request: Request, - response: Response, - db: Session = Depends(get_session), -): - return await oauth_manager.handle_callback(request, provider, response, db=db) +async def oauth_login_callback(provider: str, request: Request, response: Response): + return await oauth_manager.handle_callback(request, provider, response) @app.get("/manifest.json") @@ -2380,7 +2373,7 @@ async def healthcheck(): @app.get("/health/db") async def healthcheck_with_db(): - ScopedSession.execute(text("SELECT 1;")).all() + Session.execute(text("SELECT 1;")).all() return {"status": True} diff --git a/backend/open_webui/models/auths.py b/backend/open_webui/models/auths.py index 93f17dff115..24ff95b91f5 100644 --- a/backend/open_webui/models/auths.py +++ b/backend/open_webui/models/auths.py @@ -2,8 +2,7 @@ import uuid from typing import Optional -from sqlalchemy.orm import Session -from open_webui.internal.db import Base, JSONField, get_db, get_db_context +from open_webui.internal.db import Base, get_db, Session from open_webui.models.users import UserModel, UserProfileImageResponse, Users from pydantic import BaseModel from sqlalchemy import Boolean, Column, String, Text @@ -88,9 +87,8 @@ def insert_new_auth( profile_image_url: str = "/user.png", role: str = "pending", oauth: Optional[dict] = None, - db: Optional[Session] = None, ) -> Optional[UserModel]: - with get_db_context(db) as db: + with get_db() as db: log.info("insert_new_auth") id = str(uuid.uuid4()) @@ -102,7 +100,7 @@ def insert_new_auth( db.add(result) user = Users.insert_new_user( - id, name, email, profile_image_url, role, oauth=oauth, db=db + id, name, email, profile_image_url, role, oauth=oauth ) db.commit() @@ -114,16 +112,16 @@ def insert_new_auth( return None def authenticate_user( - self, email: str, verify_password: callable, db: Optional[Session] = None + self, email: str, verify_password: callable ) -> Optional[UserModel]: log.info(f"authenticate_user: {email}") - user = Users.get_user_by_email(email, db=db) + user = Users.get_user_by_email(email) if not user: return None try: - with get_db_context(db) as db: + with get_db() as db: auth = db.query(Auth).filter_by(id=user.id, active=True).first() if auth: if verify_password(auth.password): @@ -144,7 +142,7 @@ def authenticate_user_by_api_key( return None try: - user = Users.get_user_by_api_key(api_key, db=db) + user = Users.get_user_by_api_key(api_key) return user if user else None except Exception: return False @@ -154,10 +152,10 @@ def authenticate_user_by_email( ) -> Optional[UserModel]: log.info(f"authenticate_user_by_email: {email}") try: - with 
get_db_context(db) as db: + with get_db() as db: auth = db.query(Auth).filter_by(email=email, active=True).first() if auth: - user = Users.get_user_by_id(auth.id, db=db) + user = Users.get_user_by_id(auth.id) return user except Exception: return None @@ -166,7 +164,7 @@ def update_user_password_by_id( self, id: str, new_password: str, db: Optional[Session] = None ) -> bool: try: - with get_db_context(db) as db: + with get_db() as db: result = ( db.query(Auth).filter_by(id=id).update({"password": new_password}) ) @@ -179,18 +177,18 @@ def update_email_by_id( self, id: str, email: str, db: Optional[Session] = None ) -> bool: try: - with get_db_context(db) as db: + with get_db() as db: result = db.query(Auth).filter_by(id=id).update({"email": email}) db.commit() return True if result == 1 else False except Exception: return False - def delete_auth_by_id(self, id: str, db: Optional[Session] = None) -> bool: + def delete_auth_by_id(self, id: str) -> bool: try: - with get_db_context(db) as db: + with get_db() as db: # Delete User - result = Users.delete_user_by_id(id, db=db) + result = Users.delete_user_by_id(id) if result: db.query(Auth).filter_by(id=id).delete() diff --git a/backend/open_webui/models/notes.py b/backend/open_webui/models/notes.py index bd235307858..77d91aa3db0 100644 --- a/backend/open_webui/models/notes.py +++ b/backend/open_webui/models/notes.py @@ -4,8 +4,7 @@ from typing import Optional from functools import lru_cache -from sqlalchemy.orm import Session -from open_webui.internal.db import Base, get_db, get_db_context +from open_webui.internal.db import Base, get_db from open_webui.models.groups import Groups from open_webui.utils.access_control import has_access from open_webui.models.users import User, UserModel, Users, UserResponse @@ -212,9 +211,11 @@ def _has_permission(self, db, query, filter: dict, permission: str = "read"): return query def insert_new_note( - self, user_id: str, form_data: NoteForm, db: Optional[Session] = None + self, + form_data: NoteForm, + user_id: str, ) -> Optional[NoteModel]: - with get_db_context(db) as db: + with get_db() as db: note = NoteModel( **{ "id": str(uuid.uuid4()), @@ -232,9 +233,9 @@ def insert_new_note( return note def get_notes( - self, skip: int = 0, limit: int = 50, db: Optional[Session] = None + self, skip: Optional[int] = None, limit: Optional[int] = None ) -> list[NoteModel]: - with get_db_context(db) as db: + with get_db() as db: query = db.query(Note).order_by(Note.updated_at.desc()) if skip is not None: query = query.offset(skip) @@ -244,14 +245,9 @@ def get_notes( return [NoteModel.model_validate(note) for note in notes] def search_notes( - self, - user_id: str, - filter: dict = {}, - skip: int = 0, - limit: int = 30, - db: Optional[Session] = None, + self, user_id: str, filter: dict = {}, skip: int = 0, limit: int = 30 ) -> NoteListResponse: - with get_db_context(db) as db: + with get_db() as db: query = db.query(Note, User).outerjoin(User, User.id == Note.user_id) if filter: query_key = filter.get("query") @@ -345,13 +341,12 @@ def get_notes_by_user_id( self, user_id: str, permission: str = "read", - skip: int = 0, - limit: int = 50, - db: Optional[Session] = None, + skip: Optional[int] = None, + limit: Optional[int] = None, ) -> list[NoteModel]: - with get_db_context(db) as db: + with get_db() as db: user_group_ids = [ - group.id for group in Groups.get_groups_by_member_id(user_id, db=db) + group.id for group in Groups.get_groups_by_member_id(user_id) ] query = db.query(Note).order_by(Note.updated_at.desc()) @@ 
-367,17 +362,15 @@ def get_notes_by_user_id( notes = query.all() return [NoteModel.model_validate(note) for note in notes] - def get_note_by_id( - self, id: str, db: Optional[Session] = None - ) -> Optional[NoteModel]: - with get_db_context(db) as db: + def get_note_by_id(self, id: str) -> Optional[NoteModel]: + with get_db() as db: note = db.query(Note).filter(Note.id == id).first() return NoteModel.model_validate(note) if note else None def update_note_by_id( - self, id: str, form_data: NoteUpdateForm, db: Optional[Session] = None + self, id: str, form_data: NoteUpdateForm ) -> Optional[NoteModel]: - with get_db_context(db) as db: + with get_db() as db: note = db.query(Note).filter(Note.id == id).first() if not note: return None @@ -399,14 +392,11 @@ def update_note_by_id( db.commit() return NoteModel.model_validate(note) if note else None - def delete_note_by_id(self, id: str, db: Optional[Session] = None) -> bool: - try: - with get_db_context(db) as db: - db.query(Note).filter(Note.id == id).delete() - db.commit() - return True - except Exception: - return False + def delete_note_by_id(self, id: str): + with get_db() as db: + db.query(Note).filter(Note.id == id).delete() + db.commit() + return True Notes = NoteTable() diff --git a/backend/open_webui/models/users.py b/backend/open_webui/models/users.py index 0d36d94b8f7..5a86a018461 100644 --- a/backend/open_webui/models/users.py +++ b/backend/open_webui/models/users.py @@ -1,8 +1,7 @@ import time from typing import Optional -from sqlalchemy.orm import Session -from open_webui.internal.db import Base, JSONField, get_db, get_db_context +from open_webui.internal.db import Base, JSONField, get_db, get_db_context, Session from open_webui.env import DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL @@ -244,9 +243,8 @@ def insert_new_user( profile_image_url: str = "/user.png", role: str = "pending", oauth: Optional[dict] = None, - db: Optional[Session] = None, ) -> Optional[UserModel]: - with get_db_context(db) as db: + with get_db() as db: user = UserModel( **{ "id": id, @@ -273,7 +271,7 @@ def get_user_by_id( self, id: str, db: Optional[Session] = None ) -> Optional[UserModel]: try: - with get_db_context(db) as db: + with get_db() as db: user = db.query(User).filter_by(id=id).first() return UserModel.model_validate(user) except Exception: @@ -283,7 +281,7 @@ def get_user_by_api_key( self, api_key: str, db: Optional[Session] = None ) -> Optional[UserModel]: try: - with get_db_context(db) as db: + with get_db() as db: user = ( db.query(User) .join(ApiKey, User.id == ApiKey.user_id) @@ -298,7 +296,7 @@ def get_user_by_email( self, email: str, db: Optional[Session] = None ) -> Optional[UserModel]: try: - with get_db_context(db) as db: + with get_db() as db: user = db.query(User).filter_by(email=email).first() return UserModel.model_validate(user) except Exception: @@ -308,7 +306,7 @@ def get_user_by_oauth_sub( self, provider: str, sub: str, db: Optional[Session] = None ) -> Optional[UserModel]: try: - with get_db_context(db) as db: # type: Session + with get_db() as db: # type: Session dialect_name = db.bind.dialect.name query = db.query(User) @@ -330,9 +328,8 @@ def get_users( filter: Optional[dict] = None, skip: Optional[int] = None, limit: Optional[int] = None, - db: Optional[Session] = None, ) -> dict: - with get_db_context(db) as db: + with get_db() as db: # Join GroupMember so we can order by group_id when requested query = db.query(User) @@ -482,17 +479,17 @@ def get_users_by_user_ids( users = db.query(User).filter(User.id.in_(user_ids)).all() 
return [UserModel.model_validate(user) for user in users] - def get_num_users(self, db: Optional[Session] = None) -> Optional[int]: - with get_db_context(db) as db: + def get_num_users(self) -> Optional[int]: + with get_db() as db: return db.query(User).count() - def has_users(self, db: Optional[Session] = None) -> bool: - with get_db_context(db) as db: + def has_users(self) -> bool: + with get_db() as db: return db.query(db.query(User).exists()).scalar() - def get_first_user(self, db: Optional[Session] = None) -> UserModel: + def get_first_user(self) -> UserModel: try: - with get_db_context(db) as db: + with get_db() as db: user = db.query(User).order_by(User.created_at).first() return UserModel.model_validate(user) except Exception: @@ -502,7 +499,7 @@ def get_user_webhook_url_by_id( self, id: str, db: Optional[Session] = None ) -> Optional[str]: try: - with get_db_context(db) as db: + with get_db() as db: user = db.query(User).filter_by(id=id).first() if user.settings is None: @@ -516,8 +513,8 @@ def get_user_webhook_url_by_id( except Exception: return None - def get_num_users_active_today(self, db: Optional[Session] = None) -> Optional[int]: - with get_db_context(db) as db: + def get_num_users_active_today(self) -> Optional[int]: + with get_db() as db: current_timestamp = int(datetime.datetime.now().timestamp()) today_midnight_timestamp = current_timestamp - (current_timestamp % 86400) query = db.query(User).filter( @@ -529,7 +526,7 @@ def update_user_role_by_id( self, id: str, role: str, db: Optional[Session] = None ) -> Optional[UserModel]: try: - with get_db_context(db) as db: + with get_db() as db: db.query(User).filter_by(id=id).update({"role": role}) db.commit() user = db.query(User).filter_by(id=id).first() @@ -538,10 +535,10 @@ def update_user_role_by_id( return None def update_user_status_by_id( - self, id: str, form_data: UserStatus, db: Optional[Session] = None + self, id: str, form_data: UserStatus ) -> Optional[UserModel]: try: - with get_db_context(db) as db: + with get_db() as db: db.query(User).filter_by(id=id).update( {**form_data.model_dump(exclude_none=True)} ) @@ -553,10 +550,10 @@ def update_user_status_by_id( return None def update_user_profile_image_url_by_id( - self, id: str, profile_image_url: str, db: Optional[Session] = None + self, id: str, profile_image_url: str ) -> Optional[UserModel]: try: - with get_db_context(db) as db: + with get_db() as db: db.query(User).filter_by(id=id).update( {"profile_image_url": profile_image_url} ) @@ -572,7 +569,7 @@ def update_last_active_by_id( self, id: str, db: Optional[Session] = None ) -> Optional[UserModel]: try: - with get_db_context(db) as db: + with get_db() as db: db.query(User).filter_by(id=id).update( {"last_active_at": int(time.time())} ) @@ -584,7 +581,7 @@ def update_last_active_by_id( return None def update_user_oauth_by_id( - self, id: str, provider: str, sub: str, db: Optional[Session] = None + self, id: str, provider: str, sub: str ) -> Optional[UserModel]: """ Update or insert an OAuth provider/sub pair into the user's oauth JSON field. 
@@ -595,7 +592,7 @@ def update_user_oauth_by_id( } """ try: - with get_db_context(db) as db: + with get_db() as db: user = db.query(User).filter_by(id=id).first() if not user: return None @@ -619,7 +616,7 @@ def update_user_by_id( self, id: str, updated: dict, db: Optional[Session] = None ) -> Optional[UserModel]: try: - with get_db_context(db) as db: + with get_db() as db: db.query(User).filter_by(id=id).update(updated) db.commit() @@ -654,15 +651,15 @@ def update_user_settings_by_id( except Exception: return None - def delete_user_by_id(self, id: str, db: Optional[Session] = None) -> bool: + def delete_user_by_id(self, id: str) -> bool: try: # Remove User from Groups Groups.remove_user_from_all_groups(id) # Delete User Chats - result = Chats.delete_chats_by_user_id(id, db=db) + result = Chats.delete_chats_by_user_id(id) if result: - with get_db_context(db) as db: + with get_db() as db: # Delete User db.query(User).filter_by(id=id).delete() db.commit() @@ -677,7 +674,7 @@ def get_user_api_key_by_id( self, id: str, db: Optional[Session] = None ) -> Optional[str]: try: - with get_db_context(db) as db: + with get_db() as db: api_key = db.query(ApiKey).filter_by(user_id=id).first() return api_key.key if api_key else None except Exception: @@ -687,7 +684,7 @@ def update_user_api_key_by_id( self, id: str, api_key: str, db: Optional[Session] = None ) -> bool: try: - with get_db_context(db) as db: + with get_db() as db: db.query(ApiKey).filter_by(user_id=id).delete() db.commit() @@ -707,9 +704,9 @@ def update_user_api_key_by_id( except Exception: return False - def delete_user_api_key_by_id(self, id: str, db: Optional[Session] = None) -> bool: + def delete_user_api_key_by_id(self, id: str) -> bool: try: - with get_db_context(db) as db: + with get_db() as db: db.query(ApiKey).filter_by(user_id=id).delete() db.commit() return True @@ -723,16 +720,16 @@ def get_valid_user_ids( users = db.query(User).filter(User.id.in_(user_ids)).all() return [user.id for user in users] - def get_super_admin_user(self, db: Optional[Session] = None) -> Optional[UserModel]: - with get_db_context(db) as db: + def get_super_admin_user(self) -> Optional[UserModel]: + with get_db() as db: user = db.query(User).filter_by(role="admin").first() if user: return UserModel.model_validate(user) else: return None - def get_active_user_count(self, db: Optional[Session] = None) -> int: - with get_db_context(db) as db: + def get_active_user_count(self) -> int: + with get_db() as db: # Consider user active if last_active_at within the last 3 minutes three_minutes_ago = int(time.time()) - 180 count = ( @@ -740,8 +737,8 @@ def get_active_user_count(self, db: Optional[Session] = None) -> int: ) return count - def is_user_active(self, user_id: str, db: Optional[Session] = None) -> bool: - with get_db_context(db) as db: + def is_user_active(self, user_id: str) -> bool: + with get_db() as db: user = db.query(User).filter_by(id=user_id).first() if user and user.last_active_at: # Consider user active if last_active_at within the last 3 minutes diff --git a/backend/open_webui/retrieval/vector/dbs/pgvector.py b/backend/open_webui/retrieval/vector/dbs/pgvector.py index 15430db1143..48f1b19b1ff 100644 --- a/backend/open_webui/retrieval/vector/dbs/pgvector.py +++ b/backend/open_webui/retrieval/vector/dbs/pgvector.py @@ -90,9 +90,9 @@ def __init__(self) -> None: # if no pgvector uri, use the existing database connection if not PGVECTOR_DB_URL: - from open_webui.internal.db import ScopedSession + from open_webui.internal.db import Session - 
self.session = ScopedSession + self.session = Session else: if isinstance(PGVECTOR_POOL_SIZE, int): if PGVECTOR_POOL_SIZE > 0: diff --git a/backend/open_webui/retrieval/vector/factory.py b/backend/open_webui/retrieval/vector/factory.py index 68595fb5956..b843e0926d0 100644 --- a/backend/open_webui/retrieval/vector/factory.py +++ b/backend/open_webui/retrieval/vector/factory.py @@ -53,10 +53,6 @@ def get_vector(vector_type: str) -> VectorDBBase: from open_webui.retrieval.vector.dbs.pgvector import PgvectorClient return PgvectorClient() - case VectorType.OPENGAUSS: - from open_webui.retrieval.vector.dbs.opengauss import OpenGaussClient - - return OpenGaussClient() case VectorType.ELASTICSEARCH: from open_webui.retrieval.vector.dbs.elasticsearch import ( ElasticsearchClient, diff --git a/backend/open_webui/retrieval/vector/type.py b/backend/open_webui/retrieval/vector/type.py index de20133fce4..292cad1e785 100644 --- a/backend/open_webui/retrieval/vector/type.py +++ b/backend/open_webui/retrieval/vector/type.py @@ -12,4 +12,3 @@ class VectorType(StrEnum): ORACLE23AI = "oracle23ai" S3VECTOR = "s3vector" WEAVIATE = "weaviate" - OPENGAUSS = "opengauss" diff --git a/backend/open_webui/routers/audio.py b/backend/open_webui/routers/audio.py index 52e0182cad1..6a440219652 100644 --- a/backend/open_webui/routers/audio.py +++ b/backend/open_webui/routers/audio.py @@ -39,7 +39,6 @@ from open_webui.utils.headers import include_user_info_headers from open_webui.config import ( WHISPER_MODEL_AUTO_UPDATE, - WHISPER_COMPUTE_TYPE, WHISPER_MODEL_DIR, WHISPER_VAD_FILTER, CACHE_DIR, @@ -133,7 +132,7 @@ def set_faster_whisper_model(model: str, auto_update: bool = False): faster_whisper_kwargs = { "model_size_or_path": model, "device": DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu", - "compute_type": WHISPER_COMPUTE_TYPE, + "compute_type": "int8", "download_root": WHISPER_MODEL_DIR, "local_files_only": not auto_update, } diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py index 30d4ebe4cc3..761537af7a7 100644 --- a/backend/open_webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -105,10 +105,7 @@ class SessionUserInfoResponse(SessionUserResponse, UserStatus): @router.get("/", response_model=SessionUserInfoResponse) async def get_session_user( - request: Request, - response: Response, - user=Depends(get_current_user), - db: Session = Depends(get_session), + request: Request, response: Response, user=Depends(get_current_user) ): auth_header = request.headers.get("Authorization") @@ -142,7 +139,7 @@ async def get_session_user( ) user_permissions = get_permissions( - user.id, request.app.state.config.USER_PERMISSIONS, db=db + user.id, request.app.state.config.USER_PERMISSIONS ) return { @@ -171,15 +168,12 @@ async def get_session_user( @router.post("/update/profile", response_model=UserProfileImageResponse) async def update_profile( - form_data: UpdateProfileForm, - session_user=Depends(get_verified_user), - db: Session = Depends(get_session), + form_data: UpdateProfileForm, session_user=Depends(get_verified_user) ): if session_user: user = Users.update_user_by_id( session_user.id, form_data.model_dump(), - db=db, ) if user: return user @@ -222,17 +216,13 @@ async def update_timezone( @router.post("/update/password", response_model=bool) async def update_password( - form_data: UpdatePasswordForm, - session_user=Depends(get_current_user), - db: Session = Depends(get_session), + form_data: UpdatePasswordForm, session_user=Depends(get_current_user) ): 
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER: raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED) if session_user: user = Auths.authenticate_user( - session_user.email, - lambda pw: verify_password(form_data.password, pw), - db=db, + session_user.email, lambda pw: verify_password(form_data.password, pw) ) if user: @@ -241,7 +231,7 @@ async def update_password( except Exception as e: raise HTTPException(400, detail=str(e)) hashed = get_password_hash(form_data.new_password) - return Auths.update_user_password_by_id(user.id, hashed, db=db) + return Auths.update_user_password_by_id(user.id, hashed) else: raise HTTPException(400, detail=ERROR_MESSAGES.INCORRECT_PASSWORD) else: @@ -252,12 +242,7 @@ async def update_password( # LDAP Authentication ############################ @router.post("/ldap", response_model=SessionUserResponse) -async def ldap_auth( - request: Request, - response: Response, - form_data: LdapForm, - db: Session = Depends(get_session), -): +async def ldap_auth(request: Request, response: Response, form_data: LdapForm): # Security checks FIRST - before loading any config if not request.app.state.config.ENABLE_LDAP: raise HTTPException(400, detail="LDAP authentication is not enabled") @@ -443,12 +428,12 @@ async def ldap_auth( if not connection_user.bind(): raise HTTPException(400, "Authentication failed.") - user = Users.get_user_by_email(email, db=db) + user = Users.get_user_by_email(email) if not user: try: role = ( "admin" - if not Users.has_users(db=db) + if not Users.has_users() else request.app.state.config.DEFAULT_USER_ROLE ) @@ -457,7 +442,6 @@ async def ldap_auth( password=str(uuid.uuid4()), name=cn, role=role, - db=db, ) if not user: @@ -468,7 +452,6 @@ async def ldap_auth( apply_default_group_assignment( request.app.state.config.DEFAULT_GROUP_ID, user.id, - db=db, ) except HTTPException: @@ -479,7 +462,7 @@ async def ldap_auth( 500, detail="Internal error occurred during LDAP user creation." 
) - user = Auths.authenticate_user_by_email(email, db=db) + user = Auths.authenticate_user_by_email(email) if user: expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN) @@ -509,7 +492,7 @@ async def ldap_auth( ) user_permissions = get_permissions( - user.id, request.app.state.config.USER_PERMISSIONS, db=db + user.id, request.app.state.config.USER_PERMISSIONS ) if ( @@ -518,9 +501,9 @@ async def ldap_auth( and user_groups ): if ENABLE_LDAP_GROUP_CREATION: - Groups.create_groups_by_group_names(user.id, user_groups, db=db) + Groups.create_groups_by_group_names(user.id, user_groups) try: - Groups.sync_groups_by_group_names(user.id, user_groups, db=db) + Groups.sync_groups_by_group_names(user.id, user_groups) log.info( f"Successfully synced groups for user {user.id}: {user_groups}" ) @@ -553,12 +536,7 @@ async def ldap_auth( @router.post("/signin", response_model=SessionUserResponse) -async def signin( - request: Request, - response: Response, - form_data: SigninForm, - db: Session = Depends(get_session), -): +async def signin(request: Request, response: Response, form_data: SigninForm): if not ENABLE_PASSWORD_AUTH: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -579,15 +557,14 @@ async def signin( except Exception as e: pass - if not Users.get_user_by_email(email.lower(), db=db): + if not Users.get_user_by_email(email.lower()): await signup( request, response, SignupForm(email=email, password=str(uuid.uuid4()), name=name), - db=db, ) - user = Auths.authenticate_user_by_email(email, db=db) + user = Auths.authenticate_user_by_email(email) if WEBUI_AUTH_TRUSTED_GROUPS_HEADER and user and user.role != "admin": group_names = request.headers.get( WEBUI_AUTH_TRUSTED_GROUPS_HEADER, "" @@ -595,33 +572,28 @@ async def signin( group_names = [name.strip() for name in group_names if name.strip()] if group_names: - Groups.sync_groups_by_group_names(user.id, group_names, db=db) + Groups.sync_groups_by_group_names(user.id, group_names) elif WEBUI_AUTH == False: admin_email = "admin@localhost" admin_password = "admin" - if Users.get_user_by_email(admin_email.lower(), db=db): + if Users.get_user_by_email(admin_email.lower()): user = Auths.authenticate_user( - admin_email.lower(), - lambda pw: verify_password(admin_password, pw), - db=db, + admin_email.lower(), lambda pw: verify_password(admin_password, pw) ) else: - if Users.has_users(db=db): + if Users.has_users(): raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS) await signup( request, response, SignupForm(email=admin_email, password=admin_password, name="User"), - db=db, ) user = Auths.authenticate_user( - admin_email.lower(), - lambda pw: verify_password(admin_password, pw), - db=db, + admin_email.lower(), lambda pw: verify_password(admin_password, pw) ) else: if signin_rate_limiter.is_limited(form_data.email.lower()): @@ -640,9 +612,7 @@ async def signin( form_data.password = password_bytes.decode("utf-8", errors="ignore") user = Auths.authenticate_user( - form_data.email.lower(), - lambda pw: verify_password(form_data.password, pw), - db=db, + form_data.email.lower(), lambda pw: verify_password(form_data.password, pw) ) if user: @@ -674,7 +644,7 @@ async def signin( ) user_permissions = get_permissions( - user.id, request.app.state.config.USER_PERMISSIONS, db=db + user.id, request.app.state.config.USER_PERMISSIONS ) return { @@ -698,13 +668,8 @@ async def signin( @router.post("/signup", response_model=SessionUserResponse) -async def signup( - request: Request, - response: Response, - form_data: SignupForm, - 
db: Session = Depends(get_session), -): - has_users = Users.has_users(db=db) +async def signup(request: Request, response: Response, form_data: SignupForm): + has_users = Users.has_users() if WEBUI_AUTH: if ( @@ -726,7 +691,7 @@ async def signup( status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT ) - if Users.get_user_by_email(form_data.email.lower(), db=db): + if Users.get_user_by_email(form_data.email.lower()): raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) try: @@ -744,7 +709,6 @@ async def signup( form_data.name, form_data.profile_image_url, role, - db=db, ) if user: @@ -787,7 +751,7 @@ async def signup( ) user_permissions = get_permissions( - user.id, request.app.state.config.USER_PERMISSIONS, db=db + user.id, request.app.state.config.USER_PERMISSIONS ) if not has_users: @@ -797,7 +761,6 @@ async def signup( apply_default_group_assignment( request.app.state.config.DEFAULT_GROUP_ID, user.id, - db=db, ) return { @@ -819,9 +782,7 @@ async def signup( @router.get("/signout") -async def signout( - request: Request, response: Response, db: Session = Depends(get_session) -): +async def signout(request: Request, response: Response): # get auth token from headers or cookies token = None @@ -843,7 +804,7 @@ async def signout( if oauth_session_id: response.delete_cookie("oauth_session_id") - session = OAuthSessions.get_session_by_id(oauth_session_id, db=db) + session = OAuthSessions.get_session_by_id(oauth_session_id) oauth_server_metadata_url = ( request.app.state.oauth_manager.get_server_metadata_url(session.provider) if session @@ -906,17 +867,14 @@ async def signout( @router.post("/add", response_model=SigninResponse) async def add_user( - request: Request, - form_data: AddUserForm, - user=Depends(get_admin_user), - db: Session = Depends(get_session), + request: Request, form_data: AddUserForm, user=Depends(get_admin_user) ): if not validate_email_format(form_data.email.lower()): raise HTTPException( status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT ) - if Users.get_user_by_email(form_data.email.lower(), db=db): + if Users.get_user_by_email(form_data.email.lower()): raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) try: @@ -932,14 +890,12 @@ async def add_user( form_data.name, form_data.profile_image_url, form_data.role, - db=db, ) if user: apply_default_group_assignment( request.app.state.config.DEFAULT_GROUP_ID, user.id, - db=db, ) token = create_token(data={"id": user.id}) @@ -967,9 +923,7 @@ async def add_user( @router.get("/admin/details") -async def get_admin_details( - request: Request, user=Depends(get_current_user), db: Session = Depends(get_session) -): +async def get_admin_details(request: Request, user=Depends(get_current_user)): if request.app.state.config.SHOW_ADMIN_DETAILS: admin_email = request.app.state.config.ADMIN_EMAIL admin_name = None @@ -977,11 +931,11 @@ async def get_admin_details( log.info(f"Admin details - Email: {admin_email}, Name: {admin_name}") if admin_email: - admin = Users.get_user_by_email(admin_email, db=db) + admin = Users.get_user_by_email(admin_email) if admin: admin_name = admin.name else: - admin = Users.get_first_user(db=db) + admin = Users.get_first_user() if admin: admin_email = admin.email admin_name = admin.name @@ -1241,9 +1195,7 @@ async def update_ldap_config( # create api key @router.post("/api_key", response_model=ApiKey) -async def generate_api_key( - request: Request, user=Depends(get_current_user), db: Session = Depends(get_session) -): +async def generate_api_key(request: 
Request, user=Depends(get_current_user)): if not request.app.state.config.ENABLE_API_KEYS or not has_permission( user.id, "features.api_keys", request.app.state.config.USER_PERMISSIONS ): @@ -1253,7 +1205,7 @@ async def generate_api_key( ) api_key = create_api_key() - success = Users.update_user_api_key_by_id(user.id, api_key, db=db) + success = Users.update_user_api_key_by_id(user.id, api_key) if success: return { @@ -1265,18 +1217,14 @@ async def generate_api_key( # delete api key @router.delete("/api_key", response_model=bool) -async def delete_api_key( - user=Depends(get_current_user), db: Session = Depends(get_session) -): - return Users.delete_user_api_key_by_id(user.id, db=db) +async def delete_api_key(user=Depends(get_current_user)): + return Users.delete_user_api_key_by_id(user.id) # get api key @router.get("/api_key", response_model=ApiKey) -async def get_api_key( - user=Depends(get_current_user), db: Session = Depends(get_session) -): - api_key = Users.get_user_api_key_by_id(user.id, db=db) +async def get_api_key(user=Depends(get_current_user)): + api_key = Users.get_user_api_key_by_id(user.id) if api_key: return { "api_key": api_key, diff --git a/backend/open_webui/routers/folders.py b/backend/open_webui/routers/folders.py index 1c9b2229cf0..32911fa5094 100644 --- a/backend/open_webui/routers/folders.py +++ b/backend/open_webui/routers/folders.py @@ -22,8 +22,6 @@ from open_webui.config import UPLOAD_DIR from open_webui.constants import ERROR_MESSAGES -from open_webui.internal.db import get_session -from sqlalchemy.orm import Session from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request @@ -46,11 +44,7 @@ @router.get("/", response_model=list[FolderNameIdResponse]) -async def get_folders( - request: Request, - user=Depends(get_verified_user), - db: Session = Depends(get_session), -): +async def get_folders(request: Request, user=Depends(get_verified_user)): if request.app.state.config.ENABLE_FOLDERS is False: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, @@ -61,23 +55,22 @@ async def get_folders( user.id, "features.folders", request.app.state.config.USER_PERMISSIONS, - db=db, ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - folders = Folders.get_folders_by_user_id(user.id, db=db) + folders = Folders.get_folders_by_user_id(user.id) # Verify folder data integrity folder_list = [] for folder in folders: if folder.parent_id and not Folders.get_folder_by_id_and_user_id( - folder.parent_id, user.id, db=db + folder.parent_id, user.id ): folder = Folders.update_folder_parent_id_by_id_and_user_id( - folder.id, user.id, None, db=db + folder.id, user.id, None ) if folder.data: @@ -87,12 +80,12 @@ async def get_folders( if file.get("type") == "file": if Files.check_access_by_user_id( - file.get("id"), user.id, "read", db=db + file.get("id"), user.id, "read" ): valid_files.append(file) elif file.get("type") == "collection": if Knowledges.check_access_by_user_id( - file.get("id"), user.id, "read", db=db + file.get("id"), user.id, "read" ): valid_files.append(file) else: @@ -100,7 +93,7 @@ async def get_folders( folder.data["files"] = valid_files Folders.update_folder_by_id_and_user_id( - folder.id, user.id, FolderUpdateForm(data=folder.data), db=db + folder.id, user.id, FolderUpdateForm(data=folder.data) ) folder_list.append(FolderNameIdResponse(**folder.model_dump())) @@ -114,13 +107,9 @@ async def get_folders( @router.post("/") -def create_folder( - form_data: FolderForm, - 
user=Depends(get_verified_user), - db: Session = Depends(get_session), -): +def create_folder(form_data: FolderForm, user=Depends(get_verified_user)): folder = Folders.get_folder_by_parent_id_and_user_id_and_name( - None, user.id, form_data.name, db=db + None, user.id, form_data.name ) if folder: @@ -130,7 +119,7 @@ def create_folder( ) try: - folder = Folders.insert_new_folder(user.id, form_data, db=db) + folder = Folders.insert_new_folder(user.id, form_data) return folder except Exception as e: log.exception(e) @@ -147,10 +136,8 @@ def create_folder( @router.get("/{id}", response_model=Optional[FolderModel]) -async def get_folder_by_id( - id: str, user=Depends(get_verified_user), db: Session = Depends(get_session) -): - folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db) +async def get_folder_by_id(id: str, user=Depends(get_verified_user)): + folder = Folders.get_folder_by_id_and_user_id(id, user.id) if folder: return folder else: @@ -167,18 +154,15 @@ async def get_folder_by_id( @router.post("/{id}/update") async def update_folder_name_by_id( - id: str, - form_data: FolderUpdateForm, - user=Depends(get_verified_user), - db: Session = Depends(get_session), + id: str, form_data: FolderUpdateForm, user=Depends(get_verified_user) ): - folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db) + folder = Folders.get_folder_by_id_and_user_id(id, user.id) if folder: if form_data.name is not None: # Check if folder with same name exists existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name( - folder.parent_id, user.id, form_data.name, db=db + folder.parent_id, user.id, form_data.name ) if existing_folder and existing_folder.id != id: raise HTTPException( @@ -187,9 +171,7 @@ async def update_folder_name_by_id( ) try: - folder = Folders.update_folder_by_id_and_user_id( - id, user.id, form_data, db=db - ) + folder = Folders.update_folder_by_id_and_user_id(id, user.id, form_data) return folder except Exception as e: log.exception(e) @@ -216,15 +198,12 @@ class FolderParentIdForm(BaseModel): @router.post("/{id}/update/parent") async def update_folder_parent_id_by_id( - id: str, - form_data: FolderParentIdForm, - user=Depends(get_verified_user), - db: Session = Depends(get_session), + id: str, form_data: FolderParentIdForm, user=Depends(get_verified_user) ): - folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db) + folder = Folders.get_folder_by_id_and_user_id(id, user.id) if folder: existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name( - form_data.parent_id, user.id, folder.name, db=db + form_data.parent_id, user.id, folder.name ) if existing_folder: @@ -235,7 +214,7 @@ async def update_folder_parent_id_by_id( try: folder = Folders.update_folder_parent_id_by_id_and_user_id( - id, user.id, form_data.parent_id, db=db + id, user.id, form_data.parent_id ) return folder except Exception as e: @@ -263,16 +242,13 @@ class FolderIsExpandedForm(BaseModel): @router.post("/{id}/update/expanded") async def update_folder_is_expanded_by_id( - id: str, - form_data: FolderIsExpandedForm, - user=Depends(get_verified_user), - db: Session = Depends(get_session), + id: str, form_data: FolderIsExpandedForm, user=Depends(get_verified_user) ): - folder = Folders.get_folder_by_id_and_user_id(id, user.id, db=db) + folder = Folders.get_folder_by_id_and_user_id(id, user.id) if folder: try: folder = Folders.update_folder_is_expanded_by_id_and_user_id( - id, user.id, form_data.is_expanded, db=db + id, user.id, form_data.is_expanded ) return folder except 
Exception as e: @@ -300,11 +276,10 @@ async def delete_folder_by_id( id: str, delete_contents: Optional[bool] = True, user=Depends(get_verified_user), - db: Session = Depends(get_session), ): - if Chats.count_chats_by_folder_id_and_user_id(id, user.id, db=db): + if Chats.count_chats_by_folder_id_and_user_id(id, user.id): chat_delete_permission = has_permission( - user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS, db=db + user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS ) if user.role != "admin" and not chat_delete_permission: raise HTTPException( @@ -313,21 +288,19 @@ async def delete_folder_by_id( ) folders = [] - folders.append(Folders.get_folder_by_id_and_user_id(id, user.id, db=db)) + folders.append(Folders.get_folder_by_id_and_user_id(id, user.id)) while folders: folder = folders.pop() if folder: try: - folder_ids = Folders.delete_folder_by_id_and_user_id(id, user.id, db=db) + folder_ids = Folders.delete_folder_by_id_and_user_id(id, user.id) for folder_id in folder_ids: if delete_contents: - Chats.delete_chats_by_user_id_and_folder_id( - user.id, folder_id, db=db - ) + Chats.delete_chats_by_user_id_and_folder_id(user.id, folder_id) else: Chats.move_chats_by_user_id_and_folder_id( - user.id, folder_id, None, db=db + user.id, folder_id, None ) return True @@ -341,7 +314,7 @@ async def delete_folder_by_id( finally: # Get all subfolders subfolders = Folders.get_folders_by_parent_id_and_user_id( - folder.id, user.id, db=db + folder.id, user.id ) folders.extend(subfolders) diff --git a/backend/open_webui/routers/groups.py b/backend/open_webui/routers/groups.py index cc0cb8f5a3d..a04aa358dcd 100755 --- a/backend/open_webui/routers/groups.py +++ b/backend/open_webui/routers/groups.py @@ -16,9 +16,6 @@ from open_webui.constants import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status -from open_webui.internal.db import get_session -from sqlalchemy.orm import Session - from open_webui.utils.auth import get_admin_user, get_verified_user @@ -32,11 +29,7 @@ @router.get("/", response_model=list[GroupResponse]) -async def get_groups( - share: Optional[bool] = None, - user=Depends(get_verified_user), - db: Session = Depends(get_session), -): +async def get_groups(share: Optional[bool] = None, user=Depends(get_verified_user)): filter = {} @@ -46,7 +39,7 @@ async def get_groups( if share is not None: filter["share"] = share - groups = Groups.get_groups(filter=filter, db=db) + groups = Groups.get_groups(filter=filter) return groups @@ -57,17 +50,13 @@ async def get_groups( @router.post("/create", response_model=Optional[GroupResponse]) -async def create_new_group( - form_data: GroupForm, - user=Depends(get_admin_user), - db: Session = Depends(get_session), -): +async def create_new_group(form_data: GroupForm, user=Depends(get_admin_user)): try: - group = Groups.insert_new_group(user.id, form_data, db=db) + group = Groups.insert_new_group(user.id, form_data) if group: return GroupResponse( **group.model_dump(), - member_count=Groups.get_group_member_count_by_id(group.id, db=db), + member_count=Groups.get_group_member_count_by_id(group.id), ) else: raise HTTPException( @@ -88,14 +77,12 @@ async def create_new_group( @router.get("/id/{id}", response_model=Optional[GroupResponse]) -async def get_group_by_id( - id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) -): - group = Groups.get_group_by_id(id, db=db) +async def get_group_by_id(id: str, user=Depends(get_admin_user)): + group = Groups.get_group_by_id(id) if 
group: return GroupResponse( **group.model_dump(), - member_count=Groups.get_group_member_count_by_id(group.id, db=db), + member_count=Groups.get_group_member_count_by_id(group.id), ) else: raise HTTPException( @@ -115,15 +102,13 @@ class GroupExportResponse(GroupResponse): @router.get("/id/{id}/export", response_model=Optional[GroupExportResponse]) -async def export_group_by_id( - id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) -): - group = Groups.get_group_by_id(id, db=db) +async def export_group_by_id(id: str, user=Depends(get_admin_user)): + group = Groups.get_group_by_id(id) if group: return GroupExportResponse( **group.model_dump(), - member_count=Groups.get_group_member_count_by_id(group.id, db=db), - user_ids=Groups.get_group_user_ids_by_id(group.id, db=db), + member_count=Groups.get_group_member_count_by_id(group.id), + user_ids=Groups.get_group_user_ids_by_id(group.id), ) else: raise HTTPException( @@ -138,11 +123,9 @@ async def export_group_by_id( @router.post("/id/{id}/users", response_model=list[UserInfoResponse]) -async def get_users_in_group( - id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) -): +async def get_users_in_group(id: str, user=Depends(get_admin_user)): try: - users = Users.get_users_by_group_id(id, db=db) + users = Users.get_users_by_group_id(id) return users except Exception as e: log.exception(f"Error adding users to group {id}: {e}") @@ -159,17 +142,14 @@ async def get_users_in_group( @router.post("/id/{id}/update", response_model=Optional[GroupResponse]) async def update_group_by_id( - id: str, - form_data: GroupUpdateForm, - user=Depends(get_admin_user), - db: Session = Depends(get_session), + id: str, form_data: GroupUpdateForm, user=Depends(get_admin_user) ): try: - group = Groups.update_group_by_id(id, form_data, db=db) + group = Groups.update_group_by_id(id, form_data) if group: return GroupResponse( **group.model_dump(), - member_count=Groups.get_group_member_count_by_id(group.id, db=db), + member_count=Groups.get_group_member_count_by_id(group.id), ) else: raise HTTPException( @@ -191,20 +171,17 @@ async def update_group_by_id( @router.post("/id/{id}/users/add", response_model=Optional[GroupResponse]) async def add_user_to_group( - id: str, - form_data: UserIdsForm, - user=Depends(get_admin_user), - db: Session = Depends(get_session), + id: str, form_data: UserIdsForm, user=Depends(get_admin_user) ): try: if form_data.user_ids: - form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids, db=db) + form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids) - group = Groups.add_users_to_group(id, form_data.user_ids, db=db) + group = Groups.add_users_to_group(id, form_data.user_ids) if group: return GroupResponse( **group.model_dump(), - member_count=Groups.get_group_member_count_by_id(group.id, db=db), + member_count=Groups.get_group_member_count_by_id(group.id), ) else: raise HTTPException( @@ -221,17 +198,14 @@ async def add_user_to_group( @router.post("/id/{id}/users/remove", response_model=Optional[GroupResponse]) async def remove_users_from_group( - id: str, - form_data: UserIdsForm, - user=Depends(get_admin_user), - db: Session = Depends(get_session), + id: str, form_data: UserIdsForm, user=Depends(get_admin_user) ): try: - group = Groups.remove_users_from_group(id, form_data.user_ids, db=db) + group = Groups.remove_users_from_group(id, form_data.user_ids) if group: return GroupResponse( **group.model_dump(), - member_count=Groups.get_group_member_count_by_id(group.id, db=db), + 
member_count=Groups.get_group_member_count_by_id(group.id), ) else: raise HTTPException( @@ -252,11 +226,9 @@ async def remove_users_from_group( @router.delete("/id/{id}/delete", response_model=bool) -async def delete_group_by_id( - id: str, user=Depends(get_admin_user), db: Session = Depends(get_session) -): +async def delete_group_by_id(id: str, user=Depends(get_admin_user)): try: - result = Groups.delete_group_by_id(id, db=db) + result = Groups.delete_group_by_id(id) if result: return result else: diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py index 0fc6930b81f..1737d490f2b 100644 --- a/backend/open_webui/routers/images.py +++ b/backend/open_webui/routers/images.py @@ -23,8 +23,6 @@ from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_permission from open_webui.utils.headers import include_user_info_headers -from open_webui.internal.db import get_session -from sqlalchemy.orm import Session from open_webui.utils.images.comfyui import ( ComfyUICreateImageForm, ComfyUIEditImageForm, @@ -500,7 +498,7 @@ def get_image_data(data: str, headers=None): return None, None -def upload_image(request, image_data, content_type, metadata, user, db=None): +def upload_image(request, image_data, content_type, metadata, user): image_format = mimetypes.guess_extension(content_type) file = UploadFile( file=io.BytesIO(image_data), @@ -528,7 +526,6 @@ def upload_image(request, image_data, content_type, metadata, user, db=None): message_id=message_id, file_ids=[file_item.id], user_id=user.id, - db=db, ) url = request.app.url_path_for("get_file_content_by_id", id=file_item.id) diff --git a/backend/open_webui/routers/scim.py b/backend/open_webui/routers/scim.py index 9070256770b..bd2fd3d4f78 100644 --- a/backend/open_webui/routers/scim.py +++ b/backend/open_webui/routers/scim.py @@ -25,10 +25,6 @@ ) from open_webui.constants import ERROR_MESSAGES - -from sqlalchemy.orm import Session -from open_webui.internal.db import get_session - log = logging.getLogger(__name__) router = APIRouter() @@ -300,7 +296,7 @@ def get_scim_auth( ) -def user_to_scim(user: UserModel, request: Request, db=None) -> SCIMUser: +def user_to_scim(user: UserModel, request: Request) -> SCIMUser: """Convert internal User model to SCIM User""" # Parse display name into name components name_parts = user.name.split(" ", 1) if user.name else ["", ""] @@ -308,7 +304,7 @@ def user_to_scim(user: UserModel, request: Request, db=None) -> SCIMUser: family_name = name_parts[1] if len(name_parts) > 1 else "" # Get user's groups - user_groups = Groups.get_groups_by_member_id(user.id, db=db) + user_groups = Groups.get_groups_by_member_id(user.id) groups = [ { "value": group.id, @@ -349,13 +345,13 @@ def user_to_scim(user: UserModel, request: Request, db=None) -> SCIMUser: ) -def group_to_scim(group: GroupModel, request: Request, db=None) -> SCIMGroup: +def group_to_scim(group: GroupModel, request: Request) -> SCIMGroup: """Convert internal Group model to SCIM Group""" - member_ids = Groups.get_group_user_ids_by_id(group.id, db) or [] + member_ids = Groups.get_group_user_ids_by_id(group.id) members = [] for user_id in member_ids: - user = Users.get_user_by_id(user_id, db=db) + user = Users.get_user_by_id(user_id) if user: members.append( SCIMGroupMember( @@ -487,7 +483,6 @@ async def get_users( count: int = Query(20, ge=1, le=100), filter: Optional[str] = None, _: bool = Depends(get_scim_auth), - db: Session = Depends(get_session), ): """List SCIM Users""" 
skip = startIndex - 1 @@ -499,20 +494,20 @@ async def get_users( # In production, you'd want a more robust filter parser if "userName eq" in filter: email = filter.split('"')[1] - user = Users.get_user_by_email(email, db=db) + user = Users.get_user_by_email(email) users_list = [user] if user else [] total = 1 if user else 0 else: - response = Users.get_users(skip=skip, limit=limit, db=db) + response = Users.get_users(skip=skip, limit=limit) users_list = response["users"] total = response["total"] else: - response = Users.get_users(skip=skip, limit=limit, db=db) + response = Users.get_users(skip=skip, limit=limit) users_list = response["users"] total = response["total"] # Convert to SCIM format - scim_users = [user_to_scim(user, request, db=db) for user in users_list] + scim_users = [user_to_scim(user, request) for user in users_list] return SCIMListResponse( totalResults=total, @@ -527,16 +522,15 @@ async def get_user( user_id: str, request: Request, _: bool = Depends(get_scim_auth), - db: Session = Depends(get_session), ): """Get SCIM User by ID""" - user = Users.get_user_by_id(user_id, db=db) + user = Users.get_user_by_id(user_id) if not user: return scim_error( status_code=status.HTTP_404_NOT_FOUND, detail=f"User {user_id} not found" ) - return user_to_scim(user, request, db=db) + return user_to_scim(user, request) @router.post("/Users", response_model=SCIMUser, status_code=status.HTTP_201_CREATED) @@ -544,11 +538,10 @@ async def create_user( request: Request, user_data: SCIMUserCreateRequest, _: bool = Depends(get_scim_auth), - db: Session = Depends(get_session), ): """Create SCIM User""" # Check if user already exists - existing_user = Users.get_user_by_email(user_data.userName, db=db) + existing_user = Users.get_user_by_email(user_data.userName) if existing_user: raise HTTPException( status_code=status.HTTP_409_CONFLICT, @@ -579,7 +572,6 @@ async def create_user( email=email, profile_image_url=profile_image, role="user" if user_data.active else "pending", - db=db, ) if not new_user: @@ -588,7 +580,7 @@ async def create_user( detail="Failed to create user", ) - return user_to_scim(new_user, request, db=db) + return user_to_scim(new_user, request) @router.put("/Users/{user_id}", response_model=SCIMUser) @@ -597,10 +589,9 @@ async def update_user( request: Request, user_data: SCIMUserUpdateRequest, _: bool = Depends(get_scim_auth), - db: Session = Depends(get_session), ): """Update SCIM User (full update)""" - user = Users.get_user_by_id(user_id, db=db) + user = Users.get_user_by_id(user_id) if not user: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -633,14 +624,14 @@ async def update_user( update_data["profile_image_url"] = user_data.photos[0].value # Update user - updated_user = Users.update_user_by_id(user_id, update_data, db=db) + updated_user = Users.update_user_by_id(user_id, update_data) if not updated_user: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to update user", ) - return user_to_scim(updated_user, request, db=db) + return user_to_scim(updated_user, request) @router.patch("/Users/{user_id}", response_model=SCIMUser) @@ -649,10 +640,9 @@ async def patch_user( request: Request, patch_data: SCIMPatchRequest, _: bool = Depends(get_scim_auth), - db: Session = Depends(get_session), ): """Update SCIM User (partial update)""" - user = Users.get_user_by_id(user_id, db=db) + user = Users.get_user_by_id(user_id) if not user: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -680,7 +670,7 @@ async def patch_user( 
# Update user if update_data: - updated_user = Users.update_user_by_id(user_id, update_data, db=db) + updated_user = Users.update_user_by_id(user_id, update_data) if not updated_user: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -689,7 +679,7 @@ async def patch_user( else: updated_user = user - return user_to_scim(updated_user, request, db=db) + return user_to_scim(updated_user, request) @router.delete("/Users/{user_id}", status_code=status.HTTP_204_NO_CONTENT) @@ -697,17 +687,16 @@ async def delete_user( user_id: str, request: Request, _: bool = Depends(get_scim_auth), - db: Session = Depends(get_session), ): """Delete SCIM User""" - user = Users.get_user_by_id(user_id, db=db) + user = Users.get_user_by_id(user_id) if not user: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"User {user_id} not found", ) - success = Users.delete_user_by_id(user_id, db=db) + success = Users.delete_user_by_id(user_id) if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -725,11 +714,10 @@ async def get_groups( count: int = Query(20, ge=1, le=100), filter: Optional[str] = None, _: bool = Depends(get_scim_auth), - db: Session = Depends(get_session), ): """List SCIM Groups""" # Get all groups - groups_list = Groups.get_all_groups(db=db) + groups_list = Groups.get_all_groups() # Apply pagination total = len(groups_list) @@ -738,7 +726,7 @@ async def get_groups( paginated_groups = groups_list[start:end] # Convert to SCIM format - scim_groups = [group_to_scim(group, request, db=db) for group in paginated_groups] + scim_groups = [group_to_scim(group, request) for group in paginated_groups] return SCIMListResponse( totalResults=total, @@ -753,17 +741,16 @@ async def get_group( group_id: str, request: Request, _: bool = Depends(get_scim_auth), - db: Session = Depends(get_session), ): """Get SCIM Group by ID""" - group = Groups.get_group_by_id(group_id, db=db) + group = Groups.get_group_by_id(group_id) if not group: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Group {group_id} not found", ) - return group_to_scim(group, request, db=db) + return group_to_scim(group, request) @router.post("/Groups", response_model=SCIMGroup, status_code=status.HTTP_201_CREATED) @@ -771,7 +758,6 @@ async def create_group( request: Request, group_data: SCIMGroupCreateRequest, _: bool = Depends(get_scim_auth), - db: Session = Depends(get_session), ): """Create SCIM Group""" # Extract member IDs @@ -789,14 +775,14 @@ async def create_group( ) # Need to get the creating user's ID - we'll use the first admin - admin_user = Users.get_super_admin_user(db=db) + admin_user = Users.get_super_admin_user() if not admin_user: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="No admin user found", ) - new_group = Groups.insert_new_group(admin_user.id, form, db=db) + new_group = Groups.insert_new_group(admin_user.id, form) if not new_group: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -812,12 +798,12 @@ async def create_group( description=new_group.description, ) - Groups.update_group_by_id(new_group.id, update_form, db=db) - Groups.set_group_user_ids_by_id(new_group.id, member_ids, db=db) + Groups.update_group_by_id(new_group.id, update_form) + Groups.set_group_user_ids_by_id(new_group.id, member_ids) - new_group = Groups.get_group_by_id(new_group.id, db=db) + new_group = Groups.get_group_by_id(new_group.id) - return group_to_scim(new_group, request, db=db) + return 
group_to_scim(new_group, request) @router.put("/Groups/{group_id}", response_model=SCIMGroup) @@ -826,10 +812,9 @@ async def update_group( request: Request, group_data: SCIMGroupUpdateRequest, _: bool = Depends(get_scim_auth), - db: Session = Depends(get_session), ): """Update SCIM Group (full update)""" - group = Groups.get_group_by_id(group_id, db=db) + group = Groups.get_group_by_id(group_id) if not group: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -847,17 +832,17 @@ async def update_group( # Handle members if provided if group_data.members is not None: member_ids = [member.value for member in group_data.members] - Groups.set_group_user_ids_by_id(group_id, member_ids, db=db) + Groups.set_group_user_ids_by_id(group_id, member_ids) # Update group - updated_group = Groups.update_group_by_id(group_id, update_form, db=db) + updated_group = Groups.update_group_by_id(group_id, update_form) if not updated_group: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to update group", ) - return group_to_scim(updated_group, request, db=db) + return group_to_scim(updated_group, request) @router.patch("/Groups/{group_id}", response_model=SCIMGroup) @@ -866,10 +851,9 @@ async def patch_group( request: Request, patch_data: SCIMPatchRequest, _: bool = Depends(get_scim_auth), - db: Session = Depends(get_session), ): """Update SCIM Group (partial update)""" - group = Groups.get_group_by_id(group_id, db=db) + group = Groups.get_group_by_id(group_id) if not group: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -894,7 +878,7 @@ async def patch_group( elif path == "members": # Replace all members Groups.set_group_user_ids_by_id( - group_id, [member["value"] for member in value], db=db + group_id, [member["value"] for member in value] ) elif op == "add": @@ -903,24 +887,22 @@ async def patch_group( if isinstance(value, list): for member in value: if isinstance(member, dict) and "value" in member: - Groups.add_users_to_group( - group_id, [member["value"]], db=db - ) + Groups.add_users_to_group(group_id, [member["value"]]) elif op == "remove": if path and path.startswith("members[value eq"): # Remove specific member member_id = path.split('"')[1] - Groups.remove_users_from_group(group_id, [member_id], db=db) + Groups.remove_users_from_group(group_id, [member_id]) # Update group - updated_group = Groups.update_group_by_id(group_id, update_form, db=db) + updated_group = Groups.update_group_by_id(group_id, update_form) if not updated_group: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to update group", ) - return group_to_scim(updated_group, request, db=db) + return group_to_scim(updated_group, request) @router.delete("/Groups/{group_id}", status_code=status.HTTP_204_NO_CONTENT) @@ -928,17 +910,16 @@ async def delete_group( group_id: str, request: Request, _: bool = Depends(get_scim_auth), - db: Session = Depends(get_session), ): """Delete SCIM Group""" - group = Groups.get_group_by_id(group_id, db=db) + group = Groups.get_group_by_id(group_id) if not group: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"Group {group_id} not found", ) - success = Groups.delete_group_by_id(group_id, db=db) + success = Groups.delete_group_by_id(group_id) if not success: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, diff --git a/backend/open_webui/test/apps/webui/routers/test_chats.py b/backend/open_webui/test/apps/webui/routers/test_chats.py new file mode 100644 index 
00000000000..a36a01fb149 --- /dev/null +++ b/backend/open_webui/test/apps/webui/routers/test_chats.py @@ -0,0 +1,236 @@ +import uuid + +from test.util.abstract_integration_test import AbstractPostgresTest +from test.util.mock_user import mock_webui_user + + +class TestChats(AbstractPostgresTest): + BASE_PATH = "/api/v1/chats" + + def setup_class(cls): + super().setup_class() + + def setup_method(self): + super().setup_method() + from open_webui.models.chats import ChatForm, Chats + + self.chats = Chats + self.chats.insert_new_chat( + "2", + ChatForm( + **{ + "chat": { + "name": "chat1", + "description": "chat1 description", + "tags": ["tag1", "tag2"], + "history": {"currentId": "1", "messages": []}, + } + } + ), + ) + + def test_get_session_user_chat_list(self): + with mock_webui_user(id="2"): + response = self.fast_api_client.get(self.create_url("/")) + assert response.status_code == 200 + first_chat = response.json()[0] + assert first_chat["id"] is not None + assert first_chat["title"] == "New Chat" + assert first_chat["created_at"] is not None + assert first_chat["updated_at"] is not None + + def test_delete_all_user_chats(self): + with mock_webui_user(id="2"): + response = self.fast_api_client.delete(self.create_url("/")) + assert response.status_code == 200 + assert len(self.chats.get_chats()) == 0 + + def test_get_user_chat_list_by_user_id(self): + with mock_webui_user(id="3"): + response = self.fast_api_client.get(self.create_url("/list/user/2")) + assert response.status_code == 200 + first_chat = response.json()[0] + assert first_chat["id"] is not None + assert first_chat["title"] == "New Chat" + assert first_chat["created_at"] is not None + assert first_chat["updated_at"] is not None + + def test_create_new_chat(self): + with mock_webui_user(id="2"): + response = self.fast_api_client.post( + self.create_url("/new"), + json={ + "chat": { + "name": "chat2", + "description": "chat2 description", + "tags": ["tag1", "tag2"], + } + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["archived"] is False + assert data["chat"] == { + "name": "chat2", + "description": "chat2 description", + "tags": ["tag1", "tag2"], + } + assert data["user_id"] == "2" + assert data["id"] is not None + assert data["share_id"] is None + assert data["title"] == "New Chat" + assert data["updated_at"] is not None + assert data["created_at"] is not None + assert len(self.chats.get_chats()) == 2 + + def test_get_user_chats(self): + self.test_get_session_user_chat_list() + + def test_get_user_archived_chats(self): + self.chats.archive_all_chats_by_user_id("2") + from open_webui.internal.db import Session + + Session.commit() + with mock_webui_user(id="2"): + response = self.fast_api_client.get(self.create_url("/all/archived")) + assert response.status_code == 200 + first_chat = response.json()[0] + assert first_chat["id"] is not None + assert first_chat["title"] == "New Chat" + assert first_chat["created_at"] is not None + assert first_chat["updated_at"] is not None + + def test_get_all_user_chats_in_db(self): + with mock_webui_user(id="4"): + response = self.fast_api_client.get(self.create_url("/all/db")) + assert response.status_code == 200 + assert len(response.json()) == 1 + + def test_get_archived_session_user_chat_list(self): + self.test_get_user_archived_chats() + + def test_archive_all_chats(self): + with mock_webui_user(id="2"): + response = self.fast_api_client.post(self.create_url("/archive/all")) + assert response.status_code == 200 + assert 
len(self.chats.get_archived_chats_by_user_id("2")) == 1 + + def test_get_shared_chat_by_id(self): + chat_id = self.chats.get_chats()[0].id + self.chats.update_chat_share_id_by_id(chat_id, chat_id) + with mock_webui_user(id="2"): + response = self.fast_api_client.get(self.create_url(f"/share/{chat_id}")) + assert response.status_code == 200 + data = response.json() + assert data["id"] == chat_id + assert data["chat"] == { + "name": "chat1", + "description": "chat1 description", + "tags": ["tag1", "tag2"], + "history": {"currentId": "1", "messages": []}, + } + assert data["id"] == chat_id + assert data["share_id"] == chat_id + assert data["title"] == "New Chat" + + def test_get_chat_by_id(self): + chat_id = self.chats.get_chats()[0].id + with mock_webui_user(id="2"): + response = self.fast_api_client.get(self.create_url(f"/{chat_id}")) + assert response.status_code == 200 + data = response.json() + assert data["id"] == chat_id + assert data["chat"] == { + "name": "chat1", + "description": "chat1 description", + "tags": ["tag1", "tag2"], + "history": {"currentId": "1", "messages": []}, + } + assert data["share_id"] is None + assert data["title"] == "New Chat" + assert data["user_id"] == "2" + + def test_update_chat_by_id(self): + chat_id = self.chats.get_chats()[0].id + with mock_webui_user(id="2"): + response = self.fast_api_client.post( + self.create_url(f"/{chat_id}"), + json={ + "chat": { + "name": "chat2", + "description": "chat2 description", + "tags": ["tag2", "tag4"], + "title": "Just another title", + } + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["id"] == chat_id + assert data["chat"] == { + "name": "chat2", + "title": "Just another title", + "description": "chat2 description", + "tags": ["tag2", "tag4"], + "history": {"currentId": "1", "messages": []}, + } + assert data["share_id"] is None + assert data["title"] == "Just another title" + assert data["user_id"] == "2" + + def test_delete_chat_by_id(self): + chat_id = self.chats.get_chats()[0].id + with mock_webui_user(id="2"): + response = self.fast_api_client.delete(self.create_url(f"/{chat_id}")) + assert response.status_code == 200 + assert response.json() is True + + def test_clone_chat_by_id(self): + chat_id = self.chats.get_chats()[0].id + with mock_webui_user(id="2"): + response = self.fast_api_client.get(self.create_url(f"/{chat_id}/clone")) + + assert response.status_code == 200 + data = response.json() + assert data["id"] != chat_id + assert data["chat"] == { + "branchPointMessageId": "1", + "description": "chat1 description", + "history": {"currentId": "1", "messages": []}, + "name": "chat1", + "originalChatId": chat_id, + "tags": ["tag1", "tag2"], + "title": "Clone of New Chat", + } + assert data["share_id"] is None + assert data["title"] == "Clone of New Chat" + assert data["user_id"] == "2" + + def test_archive_chat_by_id(self): + chat_id = self.chats.get_chats()[0].id + with mock_webui_user(id="2"): + response = self.fast_api_client.get(self.create_url(f"/{chat_id}/archive")) + assert response.status_code == 200 + + chat = self.chats.get_chat_by_id(chat_id) + assert chat.archived is True + + def test_share_chat_by_id(self): + chat_id = self.chats.get_chats()[0].id + with mock_webui_user(id="2"): + response = self.fast_api_client.post(self.create_url(f"/{chat_id}/share")) + assert response.status_code == 200 + + chat = self.chats.get_chat_by_id(chat_id) + assert chat.share_id is not None + + def test_delete_shared_chat_by_id(self): + chat_id = self.chats.get_chats()[0].id + 
share_id = str(uuid.uuid4()) + self.chats.update_chat_share_id_by_id(chat_id, share_id) + with mock_webui_user(id="2"): + response = self.fast_api_client.delete(self.create_url(f"/{chat_id}/share")) + assert response.status_code + + chat = self.chats.get_chat_by_id(chat_id) + assert chat.share_id is None diff --git a/backend/open_webui/test/test_oauth_google_groups.py b/backend/open_webui/test/test_oauth_google_groups.py new file mode 100644 index 00000000000..9bc1de9af25 --- /dev/null +++ b/backend/open_webui/test/test_oauth_google_groups.py @@ -0,0 +1,266 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +import aiohttp +from open_webui.utils.oauth import OAuthManager +from open_webui.config import AppConfig + + +class TestOAuthGoogleGroups: + """Basic tests for Google OAuth Groups functionality""" + + def setup_method(self): + """Setup test fixtures""" + self.oauth_manager = OAuthManager(app=MagicMock()) + + @pytest.mark.asyncio + async def test_fetch_google_groups_success(self): + """Test successful Google groups fetching with proper aiohttp mocking""" + # Mock response data from Google Cloud Identity API + mock_response_data = { + "memberships": [ + { + "groupKey": {"id": "admin@company.com"}, + "group": "groups/123", + "displayName": "Admin Group" + }, + { + "groupKey": {"id": "users@company.com"}, + "group": "groups/456", + "displayName": "Users Group" + } + ] + } + + # Create properly structured async mocks + mock_response = MagicMock() + mock_response.status = 200 + mock_response.json = AsyncMock(return_value=mock_response_data) + + # Mock the async context manager for session.get() + mock_get_context = MagicMock() + mock_get_context.__aenter__ = AsyncMock(return_value=mock_response) + mock_get_context.__aexit__ = AsyncMock(return_value=None) + + # Mock the session + mock_session = MagicMock() + mock_session.get = MagicMock(return_value=mock_get_context) + + # Mock the async context manager for ClientSession + mock_session_context = MagicMock() + mock_session_context.__aenter__ = AsyncMock(return_value=mock_session) + mock_session_context.__aexit__ = AsyncMock(return_value=None) + + with patch("aiohttp.ClientSession", return_value=mock_session_context): + groups = await self.oauth_manager._fetch_google_groups_via_cloud_identity( + access_token="test_token", + user_email="user@company.com" + ) + + # Verify the results + assert groups == ["admin@company.com", "users@company.com"] + + # Verify the HTTP call was made correctly + mock_session.get.assert_called_once() + call_args = mock_session.get.call_args + + # Check the URL contains the user email (URL encoded) + url_arg = call_args[0][0] # First positional argument + assert "user%40company.com" in url_arg # @ is encoded as %40 + assert "searchTransitiveGroups" in url_arg + + # Check headers contain the bearer token + headers_arg = call_args[1]["headers"] # headers keyword argument + assert headers_arg["Authorization"] == "Bearer test_token" + assert headers_arg["Content-Type"] == "application/json" + + @pytest.mark.asyncio + async def test_fetch_google_groups_api_error(self): + """Test handling of API errors when fetching groups""" + # Mock failed response + mock_response = MagicMock() + mock_response.status = 403 + mock_response.text = AsyncMock(return_value="Permission denied") + + # Mock the async context manager for session.get() + mock_get_context = MagicMock() + mock_get_context.__aenter__ = AsyncMock(return_value=mock_response) + mock_get_context.__aexit__ = AsyncMock(return_value=None) + + # Mock the 
session + mock_session = MagicMock() + mock_session.get = MagicMock(return_value=mock_get_context) + + # Mock the async context manager for ClientSession + mock_session_context = MagicMock() + mock_session_context.__aenter__ = AsyncMock(return_value=mock_session) + mock_session_context.__aexit__ = AsyncMock(return_value=None) + + with patch("aiohttp.ClientSession", return_value=mock_session_context): + groups = await self.oauth_manager._fetch_google_groups_via_cloud_identity( + access_token="test_token", + user_email="user@company.com" + ) + + # Should return empty list on error + assert groups == [] + + @pytest.mark.asyncio + async def test_fetch_google_groups_network_error(self): + """Test handling of network errors when fetching groups""" + # Mock the session that raises an exception when get() is called + mock_session = MagicMock() + mock_session.get.side_effect = aiohttp.ClientError("Network error") + + # Mock the async context manager for ClientSession + mock_session_context = MagicMock() + mock_session_context.__aenter__ = AsyncMock(return_value=mock_session) + mock_session_context.__aexit__ = AsyncMock(return_value=None) + + with patch("aiohttp.ClientSession", return_value=mock_session_context): + groups = await self.oauth_manager._fetch_google_groups_via_cloud_identity( + access_token="test_token", + user_email="user@company.com" + ) + + # Should return empty list on network error + assert groups == [] + + @pytest.mark.asyncio + async def test_get_user_role_with_google_groups(self): + """Test role assignment using Google groups""" + # Mock configuration + mock_config = MagicMock() + mock_config.ENABLE_OAUTH_ROLE_MANAGEMENT = True + mock_config.OAUTH_ROLES_CLAIM = "groups" + mock_config.OAUTH_ALLOWED_ROLES = ["users@company.com"] + mock_config.OAUTH_ADMIN_ROLES = ["admin@company.com"] + mock_config.DEFAULT_USER_ROLE = "pending" + mock_config.OAUTH_EMAIL_CLAIM = "email" + + user_data = {"email": "user@company.com"} + + # Mock Google OAuth scope check and Users class + with patch("open_webui.utils.oauth.auth_manager_config", mock_config), \ + patch("open_webui.utils.oauth.GOOGLE_OAUTH_SCOPE") as mock_scope, \ + patch("open_webui.utils.oauth.Users") as mock_users, \ + patch.object(self.oauth_manager, "_fetch_google_groups_via_cloud_identity") as mock_fetch: + + mock_scope.value = "openid email profile https://www.googleapis.com/auth/cloud-identity.groups.readonly" + mock_fetch.return_value = ["admin@company.com", "users@company.com"] + mock_users.get_num_users.return_value = 5 # Not first user + + role = await self.oauth_manager.get_user_role( + user=None, + user_data=user_data, + provider="google", + access_token="test_token" + ) + + # Should assign admin role since user is in admin group + assert role == "admin" + mock_fetch.assert_called_once_with("test_token", "user@company.com") + + @pytest.mark.asyncio + async def test_get_user_role_fallback_to_claims(self): + """Test fallback to traditional claims when Google groups fail""" + mock_config = MagicMock() + mock_config.ENABLE_OAUTH_ROLE_MANAGEMENT = True + mock_config.OAUTH_ROLES_CLAIM = "groups" + mock_config.OAUTH_ALLOWED_ROLES = ["users"] + mock_config.OAUTH_ADMIN_ROLES = ["admin"] + mock_config.DEFAULT_USER_ROLE = "pending" + mock_config.OAUTH_EMAIL_CLAIM = "email" + + user_data = { + "email": "user@company.com", + "groups": ["users"] + } + + with patch("open_webui.utils.oauth.auth_manager_config", mock_config), \ + patch("open_webui.utils.oauth.GOOGLE_OAUTH_SCOPE") as mock_scope, \ + patch("open_webui.utils.oauth.Users") as 
mock_users, \ + patch.object(self.oauth_manager, "_fetch_google_groups_via_cloud_identity") as mock_fetch: + + # Mock scope without Cloud Identity + mock_scope.value = "openid email profile" + mock_users.get_num_users.return_value = 5 # Not first user + + role = await self.oauth_manager.get_user_role( + user=None, + user_data=user_data, + provider="google", + access_token="test_token" + ) + + # Should use traditional claims since Cloud Identity scope not present + assert role == "user" + mock_fetch.assert_not_called() + + @pytest.mark.asyncio + async def test_get_user_role_non_google_provider(self): + """Test that non-Google providers use traditional claims""" + mock_config = MagicMock() + mock_config.ENABLE_OAUTH_ROLE_MANAGEMENT = True + mock_config.OAUTH_ROLES_CLAIM = "roles" + mock_config.OAUTH_ALLOWED_ROLES = ["user"] + mock_config.OAUTH_ADMIN_ROLES = ["admin"] + mock_config.DEFAULT_USER_ROLE = "pending" + + user_data = {"roles": ["user"]} + + with patch("open_webui.utils.oauth.auth_manager_config", mock_config), \ + patch("open_webui.utils.oauth.Users") as mock_users, \ + patch.object(self.oauth_manager, "_fetch_google_groups_via_cloud_identity") as mock_fetch: + + mock_users.get_num_users.return_value = 5 # Not first user + + role = await self.oauth_manager.get_user_role( + user=None, + user_data=user_data, + provider="microsoft", + access_token="test_token" + ) + + # Should use traditional claims for non-Google providers + assert role == "user" + mock_fetch.assert_not_called() + + @pytest.mark.asyncio + async def test_update_user_groups_with_google_groups(self): + """Test group management using Google groups from user_data""" + mock_config = MagicMock() + mock_config.OAUTH_GROUPS_CLAIM = "groups" + mock_config.OAUTH_BLOCKED_GROUPS = "[]" + mock_config.ENABLE_OAUTH_GROUP_CREATION = False + + # Mock user with Google groups data + mock_user = MagicMock() + mock_user.id = "user123" + + user_data = { + "google_groups": ["developers@company.com", "employees@company.com"] + } + + # Mock existing groups and user groups + mock_existing_group = MagicMock() + mock_existing_group.name = "developers@company.com" + mock_existing_group.id = "group1" + mock_existing_group.user_ids = [] + mock_existing_group.permissions = {"read": True} + mock_existing_group.description = "Developers group" + + with patch("open_webui.utils.oauth.auth_manager_config", mock_config), \ + patch("open_webui.utils.oauth.Groups") as mock_groups: + + mock_groups.get_groups_by_member_id.return_value = [] + mock_groups.get_groups.return_value = [mock_existing_group] + + await self.oauth_manager.update_user_groups( + user=mock_user, + user_data=user_data, + default_permissions={"read": True} + ) + + # Should use Google groups instead of traditional claims + mock_groups.get_groups_by_member_id.assert_called_once_with("user123") + mock_groups.update_group_by_id.assert_called() diff --git a/backend/open_webui/test/util/abstract_integration_test.py b/backend/open_webui/test/util/abstract_integration_test.py new file mode 100644 index 00000000000..e8492befb64 --- /dev/null +++ b/backend/open_webui/test/util/abstract_integration_test.py @@ -0,0 +1,161 @@ +import logging +import os +import time + +import docker +import pytest +from docker import DockerClient +from pytest_docker.plugin import get_docker_ip +from fastapi.testclient import TestClient +from sqlalchemy import text, create_engine + + +log = logging.getLogger(__name__) + + +def get_fast_api_client(): + from main import app + + with TestClient(app) as c: + return c + + 
+class AbstractIntegrationTest: + BASE_PATH = None + + def create_url(self, path="", query_params=None): + if self.BASE_PATH is None: + raise Exception("BASE_PATH is not set") + parts = self.BASE_PATH.split("/") + parts = [part.strip() for part in parts if part.strip() != ""] + path_parts = path.split("/") + path_parts = [part.strip() for part in path_parts if part.strip() != ""] + query_parts = "" + if query_params: + query_parts = "&".join( + [f"{key}={value}" for key, value in query_params.items()] + ) + query_parts = f"?{query_parts}" + return "/".join(parts + path_parts) + query_parts + + @classmethod + def setup_class(cls): + pass + + def setup_method(self): + pass + + @classmethod + def teardown_class(cls): + pass + + def teardown_method(self): + pass + + +class AbstractPostgresTest(AbstractIntegrationTest): + DOCKER_CONTAINER_NAME = "postgres-test-container-will-get-deleted" + docker_client: DockerClient + + @classmethod + def _create_db_url(cls, env_vars_postgres: dict) -> str: + host = get_docker_ip() + user = env_vars_postgres["POSTGRES_USER"] + pw = env_vars_postgres["POSTGRES_PASSWORD"] + port = 8081 + db = env_vars_postgres["POSTGRES_DB"] + return f"postgresql://{user}:{pw}@{host}:{port}/{db}" + + @classmethod + def setup_class(cls): + super().setup_class() + try: + env_vars_postgres = { + "POSTGRES_USER": "user", + "POSTGRES_PASSWORD": "example", + "POSTGRES_DB": "openwebui", + } + cls.docker_client = docker.from_env() + cls.docker_client.containers.run( + "postgres:16.2", + detach=True, + environment=env_vars_postgres, + name=cls.DOCKER_CONTAINER_NAME, + ports={5432: ("0.0.0.0", 8081)}, + command="postgres -c log_statement=all", + ) + time.sleep(0.5) + + database_url = cls._create_db_url(env_vars_postgres) + os.environ["DATABASE_URL"] = database_url + retries = 10 + db = None + while retries > 0: + try: + from open_webui.config import OPEN_WEBUI_DIR + + db = create_engine(database_url, pool_pre_ping=True) + db = db.connect() + log.info("postgres is ready!") + break + except Exception as e: + log.warning(e) + time.sleep(3) + retries -= 1 + + if db: + # import must be after setting env! 
+ cls.fast_api_client = get_fast_api_client() + db.close() + else: + raise Exception("Could not connect to Postgres") + except Exception as ex: + log.error(ex) + cls.teardown_class() + pytest.fail(f"Could not setup test environment: {ex}") + + def _check_db_connection(self): + from open_webui.internal.db import Session + + retries = 10 + while retries > 0: + try: + Session.execute(text("SELECT 1")) + Session.commit() + break + except Exception as e: + Session.rollback() + log.warning(e) + time.sleep(3) + retries -= 1 + + def setup_method(self): + super().setup_method() + self._check_db_connection() + + @classmethod + def teardown_class(cls) -> None: + super().teardown_class() + cls.docker_client.containers.get(cls.DOCKER_CONTAINER_NAME).remove(force=True) + + def teardown_method(self): + from open_webui.internal.db import Session + + # rollback everything not yet committed + Session.commit() + + # truncate all tables + tables = [ + "auth", + "chat", + "chatidtag", + "document", + "memory", + "model", + "prompt", + "tag", + '"user"', + ] + for table in tables: + Session.execute(text(f"TRUNCATE TABLE {table}")) + Session.commit() diff --git a/backend/open_webui/test/util/mock_user.py b/backend/open_webui/test/util/mock_user.py new file mode 100644 index 00000000000..7ce64dffa99 --- /dev/null +++ b/backend/open_webui/test/util/mock_user.py @@ -0,0 +1,45 @@ +from contextlib import contextmanager + +from fastapi import FastAPI + + +@contextmanager +def mock_webui_user(**kwargs): + from open_webui.routers.webui import app + + with mock_user(app, **kwargs): + yield + + +@contextmanager +def mock_user(app: FastAPI, **kwargs): + from open_webui.utils.auth import ( + get_current_user, + get_verified_user, + get_admin_user, + get_current_user_by_api_key, + ) + from open_webui.models.users import User + + def create_user(): + user_parameters = { + "id": "1", + "name": "John Doe", + "email": "john.doe@openwebui.com", + "role": "user", + "profile_image_url": "/user.png", + "last_active_at": 1627351200, + "updated_at": 1627351200, + "created_at": 162735120, + **kwargs, + } + return User(**user_parameters) + + app.dependency_overrides = { + get_current_user: create_user, + get_verified_user: create_user, + get_admin_user: create_user, + get_current_user_by_api_key: create_user, + } + yield + app.dependency_overrides = {} diff --git a/backend/open_webui/utils/access_control.py b/backend/open_webui/utils/access_control.py index 7784f6efd7b..97d0b414913 100644 --- a/backend/open_webui/utils/access_control.py +++ b/backend/open_webui/utils/access_control.py @@ -28,7 +28,6 @@ def fill_missing_permissions( def get_permissions( user_id: str, default_permissions: Dict[str, Any], - db: Optional[Any] = None, ) -> Dict[str, Any]: """ Get all permissions for a user by combining the permissions of all groups the user is a member of. 
@@ -54,7 +53,7 @@ def combine_permissions( ) # Use the most permissive value (True > False) return permissions - user_groups = Groups.get_groups_by_member_id(user_id, db=db) + user_groups = Groups.get_groups_by_member_id(user_id) # Deep copy default permissions to avoid modifying the original dict permissions = json.loads(json.dumps(default_permissions)) @@ -73,7 +72,6 @@ def has_permission( user_id: str, permission_key: str, default_permissions: Dict[str, Any] = {}, - db: Optional[Any] = None, ) -> bool: """ Check if a user has a specific permission by checking the group permissions @@ -94,7 +92,7 @@ def get_permission(permissions: Dict[str, Any], keys: List[str]) -> bool: permission_hierarchy = permission_key.split(".") # Retrieve user group permissions - user_groups = Groups.get_groups_by_member_id(user_id, db=db) + user_groups = Groups.get_groups_by_member_id(user_id) for group in user_groups: if get_permission(group.permissions or {}, permission_hierarchy): @@ -129,7 +127,6 @@ def has_access( access_control: Optional[dict] = None, user_group_ids: Optional[Set[str]] = None, strict: bool = True, - db: Optional[Any] = None, ) -> bool: if access_control is None: if strict: @@ -138,7 +135,7 @@ def has_access( return True if user_group_ids is None: - user_groups = Groups.get_groups_by_member_id(user_id, db=db) + user_groups = Groups.get_groups_by_member_id(user_id) user_group_ids = {group.id for group in user_groups} permitted_ids = get_permitted_group_and_user_ids(type, access_control) @@ -155,10 +152,10 @@ def has_access( # Get all users with access to a resource def get_users_with_access( - type: str = "write", access_control: Optional[dict] = None, db: Optional[Any] = None + type: str = "write", access_control: Optional[dict] = None ) -> list[UserModel]: if access_control is None: - result = Users.get_users(filter={"roles": ["!pending"]}, db=db) + result = Users.get_users(filter={"roles": ["!pending"]}) return result.get("users", []) permitted_ids = get_permitted_group_and_user_ids(type, access_control) @@ -170,8 +167,8 @@ def get_users_with_access( user_ids_with_access = set(permitted_user_ids) - group_user_ids_map = Groups.get_group_user_ids_by_ids(permitted_group_ids, db=db) + group_user_ids_map = Groups.get_group_user_ids_by_ids(permitted_group_ids) for user_ids in group_user_ids_map.values(): user_ids_with_access.update(user_ids) - return Users.get_users_by_user_ids(list(user_ids_with_access), db=db) + return Users.get_users_by_user_ids(list(user_ids_with_access)) diff --git a/backend/open_webui/utils/groups.py b/backend/open_webui/utils/groups.py index 26fc5d8434f..0f15f27e2cd 100644 --- a/backend/open_webui/utils/groups.py +++ b/backend/open_webui/utils/groups.py @@ -7,7 +7,6 @@ def apply_default_group_assignment( default_group_id: str, user_id: str, - db=None, ) -> None: """ Apply default group assignment to a user if default_group_id is provided. 
@@ -18,7 +17,7 @@ def apply_default_group_assignment( """ if default_group_id: try: - Groups.add_users_to_group(default_group_id, [user_id], db=db) + Groups.add_users_to_group(default_group_id, [user_id]) except Exception as e: log.error( f"Failed to add user {user_id} to default group {default_group_id}: {e}" diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index c849eb25a82..a2413b327ab 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -7,6 +7,7 @@ import urllib import uuid import json +from urllib.parse import quote from datetime import datetime, timedelta import re @@ -58,6 +59,7 @@ OAUTH_AUDIENCE, WEBHOOK_URL, JWT_EXPIRES_IN, + GOOGLE_OAUTH_SCOPE, AppConfig, ) from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES @@ -635,27 +637,24 @@ async def _preflight_authorization_url( def get_client(self, client_id): if client_id not in self.clients: self.ensure_client_from_config(client_id) - client = self.clients.get(client_id) return client["client"] if client else None def get_client_info(self, client_id): if client_id not in self.clients: self.ensure_client_from_config(client_id) - client = self.clients.get(client_id) return client["client_info"] if client else None def get_server_metadata_url(self, client_id): - client = self.get_client(client_id) - if not client: - return None - - return ( - client._server_metadata_url - if hasattr(client, "_server_metadata_url") - else None - ) + if client_id in self.clients: + client = self.clients[client_id] + return ( + client._server_metadata_url + if hasattr(client, "_server_metadata_url") + else None + ) + return None async def get_oauth_token( self, user_id: str, client_id: str, force_refresh: bool = False @@ -1117,7 +1116,7 @@ async def _perform_token_refresh(self, session) -> dict: log.error(f"Exception during token refresh for provider {provider}: {e}") return None - def get_user_role(self, user, user_data): + async def get_user_role(self, user, user_data, provider=None, access_token=None): user_count = Users.get_num_users() if user and user_count == 1: # If the user is the only user, assign the role "admin" - actually repairs role for single user on login @@ -1161,6 +1160,47 @@ def get_user_role(self, user, user_data): elif isinstance(claim_data, int): oauth_roles = [str(claim_data)] + # Check if this is Google OAuth with Cloud Identity scope + if ( + provider == "google" + and access_token + and "https://www.googleapis.com/auth/cloud-identity.groups.readonly" + in GOOGLE_OAUTH_SCOPE.value + ): + + log.debug( + "Google OAuth with Cloud Identity scope detected - fetching groups via API" + ) + user_email = user_data.get(auth_manager_config.OAUTH_EMAIL_CLAIM, "") + if user_email: + try: + google_groups = ( + await self._fetch_google_groups_via_cloud_identity( + access_token, user_email + ) + ) + # Store groups in user_data for potential group management later + if "google_groups" not in user_data: + user_data["google_groups"] = google_groups + + # Use Google groups as oauth_roles for role determination + oauth_roles = google_groups + log.debug(f"Using Google groups as roles: {oauth_roles}") + except Exception as e: + log.error(f"Failed to fetch Google groups: {e}") + # Fall back to default behavior with claims + oauth_roles = [] + + # If not using Google groups or Google groups fetch failed, use traditional claims method + if not oauth_roles: + # Next block extracts the roles from the user data, accepting nested claims of any depth + if oauth_claim and 
oauth_allowed_roles and oauth_admin_roles: + claim_data = user_data + nested_claims = oauth_claim.split(".") + for nested_claim in nested_claims: + claim_data = claim_data.get(nested_claim, {}) + oauth_roles = claim_data if isinstance(claim_data, list) else [] + log.debug(f"Oauth Roles claim: {oauth_claim}") log.debug(f"User roles from oauth: {oauth_roles}") log.debug(f"Accepted user roles: {oauth_allowed_roles}") @@ -1178,7 +1218,9 @@ def get_user_role(self, user, user_data): for admin_role in oauth_admin_roles: # If the user has any of the admin roles, assign the role "admin" if admin_role in oauth_roles: - log.debug("Assigned user the admin role") + log.debug( + f"Assigned user the admin role based on group: {admin_role}" + ) role = "admin" break else: @@ -1191,7 +1233,88 @@ def get_user_role(self, user, user_data): return role - def update_user_groups(self, user, user_data, default_permissions, db=None): + async def _fetch_google_groups_via_cloud_identity( + self, access_token: str, user_email: str + ) -> list[str]: + """ + Fetch Google Workspace groups for a user via Cloud Identity API. + + Args: + access_token: OAuth access token with cloud-identity.groups.readonly scope + user_email: User's email address + + Returns: + List of group email addresses the user belongs to + """ + groups = [] + base_url = "https://content-cloudidentity.googleapis.com/v1/groups/-/memberships:searchTransitiveGroups" + + # Create the query string with proper URL encoding + query_string = f"member_key_id == '{user_email}' && 'cloudidentity.googleapis.com/groups.security' in labels" + encoded_query = quote(query_string) + + headers = { + "Authorization": f"Bearer {access_token}", + "Content-Type": "application/json", + } + + page_token = "" + + try: + async with aiohttp.ClientSession(trust_env=True) as session: + while True: + # Build URL with query parameter + url = f"{base_url}?query={encoded_query}" + + # Add page token to URL if present + if page_token: + url += f"&pageToken={quote(page_token)}" + + log.debug("Fetching Google groups via Cloud Identity API") + + async with session.get( + url, headers=headers, ssl=AIOHTTP_CLIENT_SESSION_SSL + ) as resp: + if resp.status == 200: + data = await resp.json() + + # Extract group emails from memberships + memberships = data.get("memberships", []) + log.debug(f"Found {len(memberships)} memberships") + for membership in memberships: + group_key = membership.get("groupKey", {}) + group_email = group_key.get("id", "") + if group_email: + groups.append(group_email) + log.debug(f"Found group membership: {group_email}") + + # Check for next page + page_token = data.get("nextPageToken", "") + if not page_token: + break + else: + error_text = await resp.text() + log.error( + f"Failed to fetch Google groups (status {resp.status})" + ) + # Log error details without sensitive information + try: + error_json = json.loads(error_text) + if "error" in error_json: + log.error(f"API error: {error_json['error'].get('message', 'Unknown error')}") + except json.JSONDecodeError: + log.error("Error response contains non-JSON data") + break + + except Exception as e: + log.error(f"Error fetching Google groups via Cloud Identity API: {e}") + + log.info(f"Retrieved {len(groups)} Google groups for user {user_email}") + return groups + + async def update_user_groups( + self, user, user_data, default_permissions, provider=None, access_token=None + ): log.debug("Running OAUTH Group management") oauth_claim = auth_manager_config.OAUTH_GROUPS_CLAIM @@ -1202,28 +1325,30 @@ def 
update_user_groups(self, user, user_data, default_permissions, db=None): blocked_groups = [] user_oauth_groups = [] - # Nested claim search for groups claim - if oauth_claim: - claim_data = user_data - nested_claims = oauth_claim.split(".") - for nested_claim in nested_claims: - claim_data = claim_data.get(nested_claim, {}) - - if isinstance(claim_data, list): - user_oauth_groups = claim_data - elif isinstance(claim_data, str): - # Split by the configured separator if present - if OAUTH_GROUPS_SEPARATOR in claim_data: - user_oauth_groups = claim_data.split(OAUTH_GROUPS_SEPARATOR) - else: + + # Check if Google groups were fetched via Cloud Identity API + if "google_groups" in user_data: + log.debug( + "Using Google groups from Cloud Identity API for group management" + ) + user_oauth_groups = user_data["google_groups"] + else: + # Nested claim search for groups claim (traditional method) + if oauth_claim: + claim_data = user_data + nested_claims = oauth_claim.split(".") + for nested_claim in nested_claims: + claim_data = claim_data.get(nested_claim, {}) + + if isinstance(claim_data, list): + user_oauth_groups = claim_data + elif isinstance(claim_data, str): user_oauth_groups = [claim_data] - else: - user_oauth_groups = [] + else: + user_oauth_groups = [] - user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id( - user.id, db=db - ) - all_available_groups: list[GroupModel] = Groups.get_all_groups(db=db) + user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id) + all_available_groups: list[GroupModel] = Groups.get_all_groups() # Create groups if they don't exist and creation is enabled if auth_manager_config.ENABLE_OAUTH_GROUP_CREATION: @@ -1249,7 +1374,7 @@ def update_user_groups(self, user, user_data, default_permissions, db=None): ) # Use determined creator ID (admin or fallback to current user) created_group = Groups.insert_new_group( - creator_id, new_group_form, db=db + creator_id, new_group_form ) if created_group: log.info( @@ -1267,7 +1392,7 @@ def update_user_groups(self, user, user_data, default_permissions, db=None): # Refresh the list of all available groups if any were created if groups_created: - all_available_groups = Groups.get_all_groups(db=db) + all_available_groups = Groups.get_all_groups() log.debug("Refreshed list of all available groups after creation.") log.debug(f"Oauth Groups claim: {oauth_claim}") @@ -1288,7 +1413,7 @@ def update_user_groups(self, user, user_data, default_permissions, db=None): log.debug( f"Removing user from group {group_model.name} as it is no longer in their oauth groups" ) - Groups.remove_users_from_group(group_model.id, [user.id], db=db) + Groups.remove_users_from_group(group_model.id, [user.id]) # In case a group is created, but perms are never assigned to the group by hitting "save" group_permissions = group_model.permissions @@ -1303,7 +1428,6 @@ def update_user_groups(self, user, user_data, default_permissions, db=None): permissions=group_permissions, ), overwrite=False, - db=db, ) # Add user to new groups @@ -1319,7 +1443,7 @@ def update_user_groups(self, user, user_data, default_permissions, db=None): f"Adding user to group {group_model.name} as it was found in their oauth groups" ) - Groups.add_users_to_group(group_model.id, [user.id], db=db) + Groups.add_users_to_group(group_model.id, [user.id]) # In case a group is created, but perms are never assigned to the group by hitting "save" group_permissions = group_model.permissions @@ -1334,7 +1458,6 @@ def update_user_groups(self, user, user_data, 
default_permissions, db=None): permissions=group_permissions, ), overwrite=False, - db=db, ) async def _process_picture_url( @@ -1399,7 +1522,7 @@ async def handle_login(self, request, provider): return await client.authorize_redirect(request, redirect_uri, **kwargs) - async def handle_callback(self, request, provider, response, db=None): + async def handle_callback(self, request, provider, response): if provider not in OAUTH_PROVIDERS: raise HTTPException(404) @@ -1427,9 +1550,8 @@ async def handle_callback(self, request, provider, response, db=None): exc_info=True, ) raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) - - # Try to get userinfo from the token first, some providers include it there user_data: UserInfo = token.get("userinfo") + if ( (not user_data) or (auth_manager_config.OAUTH_EMAIL_CLAIM not in user_data) @@ -1515,8 +1637,7 @@ async def handle_callback(self, request, provider, response, db=None): # If allowed domains are configured, check if the email domain is in the list if ( "*" not in auth_manager_config.OAUTH_ALLOWED_DOMAINS - and email.split("@")[-1] - not in auth_manager_config.OAUTH_ALLOWED_DOMAINS + and email.split("@")[-1] not in auth_manager_config.OAUTH_ALLOWED_DOMAINS ): log.warning( f"OAuth callback failed, e-mail domain is not in the list of allowed domains: {user_data}" @@ -1524,20 +1645,23 @@ async def handle_callback(self, request, provider, response, db=None): raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED) # Check if the user exists - user = Users.get_user_by_oauth_sub(provider, sub, db=db) + user = Users.get_user_by_oauth_sub(provider, sub) if not user: # If the user does not exist, check if merging is enabled if auth_manager_config.OAUTH_MERGE_ACCOUNTS_BY_EMAIL: # Check if the user exists by email - user = Users.get_user_by_email(email, db=db) + user = Users.get_user_by_email(email) if user: # Update the user with the new oauth sub - Users.update_user_oauth_by_id(user.id, provider, sub, db=db) + Users.update_user_oauth_by_id(user.id, provider, sub) if user: - determined_role = self.get_user_role(user, user_data) + determined_role = await self.get_user_role( + user, user_data, provider, token.get("access_token") + ) if user.role != determined_role: - Users.update_user_role_by_id(user.id, determined_role, db=db) + Users.update_user_role_by_id(user.id, determined_role) + # Update the user object in memory as well, # to avoid problems with the ENABLE_OAUTH_GROUP_MANAGEMENT check below user.role = determined_role @@ -1546,36 +1670,36 @@ async def handle_callback(self, request, provider, response, db=None): picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM if picture_claim: new_picture_url = user_data.get( - picture_claim, - OAUTH_PROVIDERS[provider].get("picture_url", ""), + picture_claim, OAUTH_PROVIDERS[provider].get("picture_url", "") ) processed_picture_url = await self._process_picture_url( new_picture_url, token.get("access_token") ) if processed_picture_url != user.profile_image_url: Users.update_user_profile_image_url_by_id( - user.id, processed_picture_url, db=db + user.id, processed_picture_url ) log.debug(f"Updated profile picture for user {user.email}") - else: + + if not user: # If the user does not exist, check if signups are enabled if auth_manager_config.ENABLE_OAUTH_SIGNUP: # Check if an existing user with the same email already exists - existing_user = Users.get_user_by_email(email, db=db) + existing_user = Users.get_user_by_email(email) if existing_user: raise HTTPException(400, 
detail=ERROR_MESSAGES.EMAIL_TAKEN) picture_claim = auth_manager_config.OAUTH_PICTURE_CLAIM if picture_claim: picture_url = user_data.get( - picture_claim, - OAUTH_PROVIDERS[provider].get("picture_url", ""), + picture_claim, OAUTH_PROVIDERS[provider].get("picture_url", "") ) picture_url = await self._process_picture_url( picture_url, token.get("access_token") ) else: picture_url = "/user.png" + username_claim = auth_manager_config.OAUTH_USERNAME_CLAIM name = user_data.get(username_claim) @@ -1583,6 +1707,10 @@ async def handle_callback(self, request, provider, response, db=None): log.warning("Username claim is missing, using email as name") name = email + role = await self.get_user_role( + None, user_data, provider, token.get("access_token") + ) + user = Auths.insert_new_auth( email=email, password=get_password_hash( @@ -1590,9 +1718,8 @@ async def handle_callback(self, request, provider, response, db=None): ), # Random password, not used name=name, profile_image_url=picture_url, - role=self.get_user_role(None, user_data), + role=role, oauth=oauth_data, - db=db, ) if auth_manager_config.WEBHOOK_URL: @@ -1608,7 +1735,8 @@ async def handle_callback(self, request, provider, response, db=None): ) apply_default_group_assignment( - request.app.state.config.DEFAULT_GROUP_ID, user.id, db=db + request.app.state.config.DEFAULT_GROUP_ID, + user.id, ) else: @@ -1621,15 +1749,17 @@ async def handle_callback(self, request, provider, response, db=None): data={"id": user.id}, expires_delta=parse_duration(auth_manager_config.JWT_EXPIRES_IN), ) + if ( - auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT - and user.role != "admin" + auth_manager_config.ENABLE_OAUTH_GROUP_MANAGEMENT + and user.role != "admin" ): - self.update_user_groups( + await self.update_user_groups( user=user, user_data=user_data, default_permissions=request.app.state.config.USER_PERMISSIONS, - db=db, + provider=provider, + access_token=token.get("access_token"), ) except Exception as e: @@ -1680,16 +1810,15 @@ async def handle_callback(self, request, provider, response, db=None): token["expires_at"] = datetime.now().timestamp() + token["expires_in"] # Clean up any existing sessions for this user/provider first - sessions = OAuthSessions.get_sessions_by_user_id(user.id, db=db) + sessions = OAuthSessions.get_sessions_by_user_id(user.id) for session in sessions: if session.provider == provider: - OAuthSessions.delete_session_by_id(session.id, db=db) + OAuthSessions.delete_session_by_id(session.id) session = OAuthSessions.create_session( user_id=user.id, provider=provider, token=token, - db=db, ) response.set_cookie( diff --git a/backend/requirements.txt b/backend/requirements.txt index 51f0a8a1ae8..5f8b902bff6 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -43,6 +43,7 @@ mcp==1.25.0 openai anthropic google-genai==1.56.0 +google-generativeai==0.8.6 langchain==1.2.0 langchain-community==0.4.1 diff --git a/docker-compose.yaml b/docker-compose.yaml index 349734a9392..1e8999b0393 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -9,9 +9,10 @@ services: image: ollama/ollama:${OLLAMA_DOCKER_TAG-latest} open-webui: - build: - context: . - dockerfile: Dockerfile + network_mode: "host" + #build: + # context: . 
+ # dockerfile: Dockerfile image: ghcr.io/open-webui/open-webui:${WEBUI_DOCKER_TAG-main} container_name: open-webui volumes: @@ -19,10 +20,12 @@ services: depends_on: - ollama ports: - - ${OPEN_WEBUI_PORT-3000}:8080 + - ${OPEN_WEBUI_PORT-3000}:8080 environment: - - 'OLLAMA_BASE_URL=http://ollama:11434' - - 'WEBUI_SECRET_KEY=' + - 'OLLAMA_BASE_URL=http://ollama:11434' +# - 'WEBUI_SECRET_KEY=' + #env_file: + # - .env extra_hosts: - host.docker.internal:host-gateway restart: unless-stopped diff --git a/docs/oauth-google-groups.md b/docs/oauth-google-groups.md new file mode 100644 index 00000000000..40bc62ba5bf --- /dev/null +++ b/docs/oauth-google-groups.md @@ -0,0 +1,95 @@ +# Google OAuth with Cloud Identity Groups Support + +This example demonstrates how to configure Open WebUI to use Google OAuth with Cloud Identity API for group-based role management. + +## Configuration + +### Environment Variables + +```bash +# Google OAuth Configuration +GOOGLE_CLIENT_ID="your-google-client-id.apps.googleusercontent.com" +GOOGLE_CLIENT_SECRET="your-google-client-secret" + +# IMPORTANT: Include the Cloud Identity Groups scope +GOOGLE_OAUTH_SCOPE="openid email profile https://www.googleapis.com/auth/cloud-identity.groups.readonly" + +# Enable OAuth features +ENABLE_OAUTH_SIGNUP=true +ENABLE_OAUTH_ROLE_MANAGEMENT=true +ENABLE_OAUTH_GROUP_MANAGEMENT=true + +# Configure admin roles using Google group emails +OAUTH_ADMIN_ROLES="admin@yourcompany.com,superadmin@yourcompany.com" +OAUTH_ALLOWED_ROLES="users@yourcompany.com,employees@yourcompany.com" + +# Optional: Configure group creation +ENABLE_OAUTH_GROUP_CREATION=true +``` + +## How It Works + +1. **Scope Detection**: When a user logs in with Google OAuth, the system checks if the `https://www.googleapis.com/auth/cloud-identity.groups.readonly` scope is present in `GOOGLE_OAUTH_SCOPE`. + +2. **Groups Fetching**: If the scope is present, the system uses the Google Cloud Identity API to fetch all groups the user belongs to, instead of relying on claims in the OAuth token. + +3. **Role Assignment**: + - If the user belongs to any group listed in `OAUTH_ADMIN_ROLES`, they get admin privileges + - If the user belongs to any group listed in `OAUTH_ALLOWED_ROLES`, they get user privileges + - Default role is applied if no matching groups are found + +4. **Group Management**: If `ENABLE_OAUTH_GROUP_MANAGEMENT` is enabled, Open WebUI groups are synchronized with Google Workspace groups. + +## Google Cloud Console Setup + +1. **Enable APIs**: + - Cloud Identity API + - Cloud Identity Groups API + +2. **OAuth 2.0 Setup**: + - Create OAuth 2.0 credentials + - Add authorized redirect URIs + - Configure consent screen + +3. **Required Scopes**: + ``` + openid + email + profile + https://www.googleapis.com/auth/cloud-identity.groups.readonly + ``` + +## Example Groups Structure + +``` +Your Google Workspace: +├── admin@yourcompany.com (Admin group) +├── superadmin@yourcompany.com (Super admin group) +├── users@yourcompany.com (Regular users) +├── employees@yourcompany.com (All employees) +└── developers@yourcompany.com (Development team) +``` + +## Fallback Behavior + +If the Cloud Identity scope is not present or the API call fails, the system falls back to the traditional method of reading roles from OAuth token claims. 
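+## Example: Groups Lookup Request
+
+For reference, the following is a minimal sketch of the lookup that the backend helper performs against the Cloud Identity `searchTransitiveGroups` endpoint. The access token and e-mail address are placeholders, and the sketch omits the pagination handling done in the real implementation:
+
+```python
+import asyncio
+from urllib.parse import quote
+
+import aiohttp
+
+# Placeholder values; in Open WebUI these come from the OAuth callback.
+ACCESS_TOKEN = "ya29.example-token"
+USER_EMAIL = "user@yourcompany.com"
+
+BASE_URL = (
+    "https://content-cloudidentity.googleapis.com/v1/"
+    "groups/-/memberships:searchTransitiveGroups"
+)
+
+
+async def fetch_groups() -> list[str]:
+    # Same query the backend builds: security groups this member belongs to.
+    query = (
+        f"member_key_id == '{USER_EMAIL}' && "
+        "'cloudidentity.googleapis.com/groups.security' in labels"
+    )
+    headers = {"Authorization": f"Bearer {ACCESS_TOKEN}"}
+    url = f"{BASE_URL}?query={quote(query)}"
+
+    groups: list[str] = []
+    async with aiohttp.ClientSession() as session:
+        async with session.get(url, headers=headers) as resp:
+            resp.raise_for_status()
+            data = await resp.json()
+            # Each membership carries the group's e-mail in groupKey.id.
+            for membership in data.get("memberships", []):
+                group_email = membership.get("groupKey", {}).get("id", "")
+                if group_email:
+                    groups.append(group_email)
+    return groups
+
+
+print(asyncio.run(fetch_groups()))
+```
+
+A token that carries the `cloud-identity.groups.readonly` scope should yield the same group e-mail addresses that Open WebUI then matches against `OAUTH_ADMIN_ROLES` and `OAUTH_ALLOWED_ROLES`.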
+ +## Security Considerations + +- The Cloud Identity API requires proper authentication and authorization +- Only users with appropriate permissions can access group membership information +- Groups are fetched server-side, not exposed to the client +- Access tokens are handled securely and not logged + +## Troubleshooting + +1. **Groups not detected**: Ensure the Cloud Identity API is enabled and the OAuth client has the required scope +2. **Permission denied**: Verify the service account or OAuth client has Cloud Identity API access +3. **No admin role**: Check that the user belongs to a group listed in `OAUTH_ADMIN_ROLES` + +## Benefits Over Token Claims + +- **Real-time**: Groups are fetched fresh on each login +- **Complete**: Gets all group memberships, including nested groups +- **Accurate**: No dependency on ID token size limits +- **Flexible**: Can handle complex group hierarchies in Google Workspace \ No newline at end of file diff --git a/src/lib/apis/files/index.ts b/src/lib/apis/files/index.ts index 44af669fa1a..07042c4adea 100644 --- a/src/lib/apis/files/index.ts +++ b/src/lib/apis/files/index.ts @@ -252,7 +252,7 @@ export const getFileContentById = async (id: string) => { }) .then(async (res) => { if (!res.ok) throw await res.json(); - return await res.arrayBuffer(); + return await res.blob(); }) .catch((err) => { error = err.detail; diff --git a/src/lib/apis/index.ts b/src/lib/apis/index.ts index 8f35fbf8815..d4254ecf291 100644 --- a/src/lib/apis/index.ts +++ b/src/lib/apis/index.ts @@ -1,7 +1,10 @@ -import { WEBUI_BASE_URL } from '$lib/constants'; +import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants'; import { convertOpenApiToToolPayload } from '$lib/utils'; import { getOpenAIModelsDirect } from './openai'; +import { parse } from 'yaml'; +import { toast } from 'svelte-sonner'; + export const getModels = async ( token: string = '', connections: object | null = null, @@ -313,7 +316,7 @@ export const getToolServerData = async (token: string, url: string) => { // Check if URL ends with .yaml or .yml to determine format if (url.toLowerCase().endsWith('.yaml') || url.toLowerCase().endsWith('.yml')) { if (!res.ok) throw await res.text(); - const [text, { parse }] = await Promise.all([res.text(), import('yaml')]); + const text = await res.text(); return parse(text); } else { if (!res.ok) throw await res.json(); diff --git a/src/lib/components/admin/Settings/Models.svelte b/src/lib/components/admin/Settings/Models.svelte index d107e09f704..1c1fc6512b4 100644 --- a/src/lib/components/admin/Settings/Models.svelte +++ b/src/lib/components/admin/Settings/Models.svelte @@ -17,7 +17,6 @@ } from '$lib/apis/models'; import { copyToClipboard } from '$lib/utils'; import { page } from '$app/stores'; - import { updateUserSettings } from '$lib/apis/users'; import { getModels } from '$lib/apis'; import Search from '$lib/components/icons/Search.svelte'; @@ -219,19 +218,6 @@ saveAs(blob, `${model.id}-${Date.now()}.json`); }; - const pinModelHandler = async (modelId) => { - let pinnedModels = $settings?.pinnedModels ?? 
-		[];
-
-		if (pinnedModels.includes(modelId)) {
-			pinnedModels = pinnedModels.filter((id) => id !== modelId);
-		} else {
-			pinnedModels = [...new Set([...pinnedModels, modelId])];
-		}
-
-		settings.set({ ...$settings, pinnedModels: pinnedModels });
-		await updateUserSettings(localStorage.token, { ui: $settings });
-	};
-
 	onMount(async () => {
 		await init();
 		const id = $page.url.searchParams.get('id');
@@ -441,9 +427,6 @@
 					hideHandler={() => {
 						hideModelHandler(model);
 					}}
-					pinModelHandler={() => {
-						pinModelHandler(model.id);
-					}}
 					copyLinkHandler={() => {
 						copyLinkHandler(model);
 					}}
diff --git a/src/lib/components/admin/Settings/Models/ModelMenu.svelte b/src/lib/components/admin/Settings/Models/ModelMenu.svelte
index d4cd48a37dd..b7e694b1658 100644
--- a/src/lib/components/admin/Settings/Models/ModelMenu.svelte
+++ b/src/lib/components/admin/Settings/Models/ModelMenu.svelte
@@ -13,10 +13,8 @@
 	import DocumentDuplicate from '$lib/components/icons/DocumentDuplicate.svelte';
 	import Download from '$lib/components/icons/Download.svelte';
 	import ArrowUpCircle from '$lib/components/icons/ArrowUpCircle.svelte';
-	import Pin from '$lib/components/icons/Pin.svelte';
-	import PinSlash from '$lib/components/icons/PinSlash.svelte';
-	import { config, settings } from '$lib/stores';
+	import { config } from '$lib/stores';
 	import Link from '$lib/components/icons/Link.svelte';
 	const i18n = getContext('i18n');
@@ -26,7 +24,6 @@
 	export let exportHandler: Function;
 	export let hideHandler: Function;
-	export let pinModelHandler: Function;
 	export let copyLinkHandler: Function;
 	export let cloneHandler: Function;
@@ -107,27 +104,6 @@
-			{
-				pinModelHandler(model?.id);
-			}}
-		>
-			{#if ($settings?.pinnedModels ?? []).includes(model?.id)}
-
-			{:else}
-
-			{/if}
-
-
-			{#if ($settings?.pinnedModels ?? []).includes(model?.id)}
-				{$i18n.t('Hide from Sidebar')}
-			{:else}
-				{$i18n.t('Keep in Sidebar')}
-			{/if}
-
-
 			{
diff --git a/src/lib/components/chat/ChatControls/Embeds.svelte b/src/lib/components/chat/ChatControls/Embeds.svelte
index 126124bc69b..e15c86c8bd2 100644
--- a/src/lib/components/chat/ChatControls/Embeds.svelte
+++ b/src/lib/components/chat/ChatControls/Embeds.svelte
@@ -6,7 +6,7 @@
 	export let overlay = false;
-	const getSrcUrl = (url: string, chatId?: string, messageId?: string, sourceId: string) => {
+	const getSrcUrl = (url: string, chatId?: string, messageId?: string) => {
 		try {
 			const parsed = new URL(url);
@@ -18,10 +18,6 @@
 				parsed.searchParams.set('message_id', messageId);
 			}
-			if (sourceId) {
-				parsed.searchParams.set('source_id', sourceId);
-			}
-
 			return parsed.toString();
 		} catch {
 			// Fallback for relative URLs or invalid input
@@ -30,7 +26,6 @@
 			if (chatId) parts.push(`chat_id=${encodeURIComponent(chatId)}`);
 			if (messageId) parts.push(`message_id=${encodeURIComponent(messageId)}`);
-			if (sourceId) parts.push(`source_id=${encodeURIComponent(sourceId)}`);
 			if (parts.length === 0) return url;
@@ -73,7 +68,7 @@
 	{/if}
diff --git a/src/lib/components/chat/Messages/Citations.svelte b/src/lib/components/chat/Messages/Citations.svelte
index 2db74581bdc..2799059b072 100644
--- a/src/lib/components/chat/Messages/Citations.svelte
+++ b/src/lib/components/chat/Messages/Citations.svelte
@@ -23,26 +23,12 @@
 	let selectedCitation: any = null;
-	export const showSourceModal = (sourceId) => {
-		let index;
-		let suffix = null;
+	export const showSourceModal = (sourceIdx) => {
+		if (citations[sourceIdx]) {
+			console.log('Showing citation modal for:', citations[sourceIdx]);
-		if (typeof sourceId === 'string') {
-			const output = sourceId.split('#');
-			index = parseInt(output[0]) - 1;
-
-			if (output.length > 1) {
-				suffix = output[1];
-			}
-		} else {
-			index = sourceId - 1;
-		}
-
-		if (citations[index]) {
-			console.log('Showing citation modal for:', citations[index]);
-
-			if (citations[index]?.source?.embed_url) {
-				const embedUrl = citations[index].source.embed_url;
+			if (citations[sourceIdx]?.source?.embed_url) {
+				const embedUrl = citations[sourceIdx].source.embed_url;
 				if (embedUrl) {
 					if (readOnly) {
 						// Open in new tab if readOnly
@@ -53,19 +39,18 @@
 						showEmbeds.set(true);
 						embed.set({
 							url: embedUrl,
-							title: citations[index]?.source?.name || 'Embedded Content',
-							source: citations[index],
+							title: citations[sourceIdx]?.source?.name || 'Embedded Content',
+							source: citations[sourceIdx],
 							chatId: chatId,
-							messageId: id,
-							sourceId: sourceId
+							messageId: id
 						});
 					}
 				} else {
-					selectedCitation = citations[index];
+					selectedCitation = citations[sourceIdx];
 					showCitationModal = true;
 				}
 			} else {
-				selectedCitation = citations[index];
+				selectedCitation = citations[sourceIdx];
 				showCitationModal = true;
 			}
 		}
diff --git a/src/lib/components/chat/Messages/Markdown/SourceToken.svelte b/src/lib/components/chat/Messages/Markdown/SourceToken.svelte
index ac2b84cdcdc..7da6d8f89f9 100644
--- a/src/lib/components/chat/Messages/Markdown/SourceToken.svelte
+++ b/src/lib/components/chat/Messages/Markdown/SourceToken.svelte
@@ -41,9 +41,7 @@
 {#if sourceIds}
 	{#if (token?.ids ?? []).length == 1}
-		{@const id = token.ids[0]}
-		{@const identifier = token.citationIdentifiers ? token.citationIdentifiers[0] : id - 1}
-
+
 	{:else}
@@ -67,11 +65,9 @@
 			el={containerElement}
 		>
-			{#each token.citationIdentifiers ?? token.ids as identifier}
-				{@const id =
-					typeof identifier === 'string' ? parseInt(identifier.split('#')[0]) : identifier}
+			{#each token.ids as sourceId}
-
+
 			{/each}
diff --git a/src/lib/components/chat/ModelSelector/ModelItemMenu.svelte b/src/lib/components/chat/ModelSelector/ModelItemMenu.svelte
index 5f795a67091..64e79dbcbb6 100644
--- a/src/lib/components/chat/ModelSelector/ModelItemMenu.svelte
+++ b/src/lib/components/chat/ModelSelector/ModelItemMenu.svelte
@@ -5,10 +5,9 @@
 	import { getContext } from 'svelte';
 	import Tooltip from '$lib/components/common/Tooltip.svelte';
-	import Pin from '$lib/components/icons/Pin.svelte';
-	import PinSlash from '$lib/components/icons/PinSlash.svelte';
-	import DocumentDuplicate from '$lib/components/icons/DocumentDuplicate.svelte';
 	import Link from '$lib/components/icons/Link.svelte';
+	import Eye from '$lib/components/icons/Eye.svelte';
+	import EyeSlash from '$lib/components/icons/EyeSlash.svelte';
 	import { settings } from '$lib/stores';
 	const i18n = getContext('i18n');
@@ -64,9 +63,9 @@
 				}}
 			>
 				{#if ($settings?.pinnedModels ?? []).includes(model?.id)}
-
+
 				{:else}
-
+
 				{/if}
diff --git a/src/lib/components/layout/Sidebar.svelte b/src/lib/components/layout/Sidebar.svelte
index 03591d2b71f..594be3ff38c 100644
--- a/src/lib/components/layout/Sidebar.svelte
+++ b/src/lib/components/layout/Sidebar.svelte
@@ -64,6 +64,7 @@
 	import Note from '../icons/Note.svelte';
 	import { slide } from 'svelte/transition';
 	import HotkeyHint from '../common/HotkeyHint.svelte';
+	import { key } from 'vega';
 	const BREAKPOINT = 768;
diff --git a/src/lib/components/workspace/Models.svelte b/src/lib/components/workspace/Models.svelte
index 44f91be079f..fc45ae54eae 100644
--- a/src/lib/components/workspace/Models.svelte
+++ b/src/lib/components/workspace/Models.svelte
@@ -24,7 +24,6 @@
 	import { getModels } from '$lib/apis';
 	import { getGroups } from '$lib/apis/groups';
-	import { updateUserSettings } from '$lib/apis/users';
 	import { capitalizeFirstLetter, copyToClipboard } from '$lib/utils';
@@ -217,19 +216,6 @@
 		saveAs(blob, `${model.id}-${Date.now()}.json`);
 	};
-	const pinModelHandler = async (modelId) => {
-		let pinnedModels = $settings?.pinnedModels ?? [];
-
-		if (pinnedModels.includes(modelId)) {
-			pinnedModels = pinnedModels.filter((id) => id !== modelId);
-		} else {
-			pinnedModels = [...new Set([...pinnedModels, modelId])];
-		}
-
-		settings.set({ ...$settings, pinnedModels: pinnedModels });
-		await updateUserSettings(localStorage.token, { ui: $settings });
-	};
-
 	onMount(async () => {
 		viewOption = localStorage.workspaceViewOption ?? '';
 		page = 1;
@@ -563,9 +549,6 @@
 					hideHandler={() => {
 						hideModelHandler(model);
 					}}
-					pinModelHandler={() => {
-						pinModelHandler(model.id);
-					}}
 					copyLinkHandler={() => {
 						copyLinkHandler(model);
 					}}
diff --git a/src/lib/components/workspace/Models/DefaultFiltersSelector.svelte b/src/lib/components/workspace/Models/DefaultFiltersSelector.svelte
index d03928a56ca..9b67d543d40 100644
--- a/src/lib/components/workspace/Models/DefaultFiltersSelector.svelte
+++ b/src/lib/components/workspace/Models/DefaultFiltersSelector.svelte
@@ -24,7 +24,7 @@
-		{$i18n.t('Default Filters')}
+		{$i18n.t('Default Filters')}
diff --git a/src/lib/components/workspace/Models/Knowledge/KnowledgeSelector.svelte b/src/lib/components/workspace/Models/Knowledge/KnowledgeSelector.svelte
index fa50e9047ff..fff1be2fb60 100644
--- a/src/lib/components/workspace/Models/Knowledge/KnowledgeSelector.svelte
+++ b/src/lib/components/workspace/Models/Knowledge/KnowledgeSelector.svelte
@@ -112,7 +112,7 @@
-
-
-
-
-			{
-				filesInputElement.click();
-			}}
-		>
-			{#if info.meta.profile_image_url}
-
-			{:else}
-
-			{/if}
-
-
-
-
-
-
-
-
-
-
+
+
+			{
+				filesInputElement.click();
+			}}
+		>
+			{#if info.meta.profile_image_url}
+
-
+			{:else}
+
-
+			{/if}
+
-
+
-
-
-
-			{
-				info.meta.profile_image_url = `${WEBUI_BASE_URL}/static/favicon.png`;
-			}}
-			type="button"
-		>
-			{$i18n.t('Reset Image')}
+
+
+
+
+
+
+
+
+
+
-
-
-
-
+
+
-
-
-
-
-
-
-
-
-
-			{
-				showAccessControlModal = true;
-			}}
-		>
-
-
-
-			{$i18n.t('Access')}
-
-
-	{#if preset}
-
-
-			{$i18n.t('Base Model (From)')}
-
+
+			{
+				showAccessControlModal = true;
+			}}
+		>
+
-
-
-			{$i18n.t('Select a base model')}
-			{#each $models.filter((m) => (model ? m.id !== model.id : true) && !m?.preset && m?.owned_by !== 'arena' && !(m?.direct ?? false)) as model}
-				{model.name}
-			{/each}
-
+
+			{$i18n.t('Access')}
-
-	{/if}
+
+
+
+	{#if preset}
-
-
-			{$i18n.t('Description')}
-
+
+			{$i18n.t('Base Model (From)')}
+
-			{
-				enableDescription = !enableDescription;
-			}}
+
+
-			{#if !enableDescription}
-				{$i18n.t('Default')}
-			{:else}
-				{$i18n.t('Custom')}
-			{/if}
-
+			{$i18n.t('Select a base model')}
+			{#each $models.filter((m) => (model ? m.id !== model.id : true) && !m?.preset && m?.owned_by !== 'arena' && !(m?.direct ?? false)) as model}
+				{model.name}
+			{/each}
+
-
-	{#if enableDescription}
-
-	{/if}
+	{/if}
-
-
-			{
-				const tagName = e.detail;
-				info.meta.tags = info.meta.tags.filter((tag) => tag.name !== tagName);
-			}}
-			on:add={(e) => {
-				const tagName = e.detail;
-				if (!(info?.meta?.tags ?? null)) {
-					info.meta.tags = [{ name: tagName }];
-				} else {
-					info.meta.tags = [...info.meta.tags, { name: tagName }];
-				}
-			}}
-		/>
+
+
+
+			{$i18n.t('Description')}
+
+			{
+				enableDescription = !enableDescription;
+			}}
+		>
+			{#if !enableDescription}
+				{$i18n.t('Default')}
+			{:else}
+				{$i18n.t('Custom')}
+			{/if}
+
+
+
+	{#if enableDescription}
+
-	{/if}
+
+
+
+
+			{
+				const tagName = e.detail;
+				info.meta.tags = info.meta.tags.filter((tag) => tag.name !== tagName);
+			}}
+			on:add={(e) => {
+				const tagName = e.detail;
+				if (!(info?.meta?.tags ?? null)) {
+					info.meta.tags = [{ name: tagName }];
+				} else {
+					info.meta.tags = [...info.meta.tags, { name: tagName }];
+				}
+			}}
+		/>
+
+
@@ -713,10 +713,14 @@
 	{/if}
+
+
+
+
diff --git a/src/lib/components/workspace/Models/ModelMenu.svelte b/src/lib/components/workspace/Models/ModelMenu.svelte
index 30712eef5c7..825250e6afb 100644
--- a/src/lib/components/workspace/Models/ModelMenu.svelte
+++ b/src/lib/components/workspace/Models/ModelMenu.svelte
@@ -13,10 +13,8 @@
 	import DocumentDuplicate from '$lib/components/icons/DocumentDuplicate.svelte';
 	import Download from '$lib/components/icons/Download.svelte';
 	import ArrowUpCircle from '$lib/components/icons/ArrowUpCircle.svelte';
-	import Pin from '$lib/components/icons/Pin.svelte';
-	import PinSlash from '$lib/components/icons/PinSlash.svelte';
-	import { config, user as currentUser, settings } from '$lib/stores';
+	import { config, user as currentUser } from '$lib/stores';
 	import Link from '$lib/components/icons/Link.svelte';
 	const i18n = getContext('i18n');
@@ -31,7 +29,6 @@
 	export let copyLinkHandler: Function;
 	export let hideHandler: Function;
-	export let pinModelHandler: Function;
 	export let deleteHandler: Function;
 	export let onClose: Function;
@@ -127,27 +124,6 @@
-			{
-				pinModelHandler(model?.id);
-			}}
-		>
-			{#if ($settings?.pinnedModels ?? []).includes(model?.id)}
-
-			{:else}
-
-			{/if}
-
-
-			{#if ($settings?.pinnedModels ?? []).includes(model?.id)}
-				{$i18n.t('Hide from Sidebar')}
-			{:else}
-				{$i18n.t('Keep in Sidebar')}
-			{/if}
-
-
 			{
diff --git a/src/lib/i18n/locales/pt-BR/translation.json b/src/lib/i18n/locales/pt-BR/translation.json
index b6d2966bbce..89f12727e24 100644
--- a/src/lib/i18n/locales/pt-BR/translation.json
+++ b/src/lib/i18n/locales/pt-BR/translation.json
@@ -12,7 +12,7 @@
 	"{{COUNT}} Available Tools": "{{COUNT}} Ferramentas disponíveis",
 	"{{COUNT}} characters": "{{COUNT}} caracteres",
 	"{{COUNT}} extracted lines": "{{COUNT}} linhas extraídas",
-	"{{COUNT}} files": "{{COUNT}} arquivos",
+	"{{COUNT}} files": "{COUNT}} arquivos",
 	"{{COUNT}} hidden lines": "{{COUNT}} linhas ocultas",
 	"{{COUNT}} Replies": "{{COUNT}} Respostas",
 	"{{COUNT}} Rows": "{{COUNT}} Linhas",