diff --git a/.env.example b/.env.example index c38bf88bfb9..35ea12a8856 100644 --- a/.env.example +++ b/.env.example @@ -7,6 +7,15 @@ OPENAI_API_KEY='' # AUTOMATIC1111_BASE_URL="http://localhost:7860" +# For production, you should only need one host as +# fastapi serves the svelte-kit built frontend and backend from the same host and port. +# To test with CORS locally, you can set something like +# CORS_ALLOW_ORIGIN='http://localhost:5173;http://localhost:8080' +CORS_ALLOW_ORIGIN='*' + +# For production you should set this to match the proxy configuration (127.0.0.1) +FORWARDED_ALLOW_IPS='*' + # DO NOT TRACK SCARF_NO_ANALYTICS=true DO_NOT_TRACK=true diff --git a/.gitattributes b/.gitattributes index 526c8a38d4a..bf368a4c6ca 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1 +1,49 @@ -*.sh text eol=lf \ No newline at end of file +# TypeScript +*.ts text eol=lf +*.tsx text eol=lf + +# JavaScript +*.js text eol=lf +*.jsx text eol=lf +*.mjs text eol=lf +*.cjs text eol=lf + +# Svelte +*.svelte text eol=lf + +# HTML/CSS +*.html text eol=lf +*.css text eol=lf +*.scss text eol=lf +*.less text eol=lf + +# Config files and JSON +*.json text eol=lf +*.jsonc text eol=lf +*.yml text eol=lf +*.yaml text eol=lf +*.toml text eol=lf + +# Shell scripts +*.sh text eol=lf + +# Markdown & docs +*.md text eol=lf +*.mdx text eol=lf +*.txt text eol=lf + +# Git-related +.gitattributes text eol=lf +.gitignore text eol=lf + +# Prettier and other dotfiles +.prettierrc text eol=lf +.prettierignore text eol=lf +.eslintrc text eol=lf +.eslintignore text eol=lf +.stylelintrc text eol=lf +.editorconfig text eol=lf + +# Misc +*.env text eol=lf +*.lock text eol=lf \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/bug_report.yaml b/.github/ISSUE_TEMPLATE/bug_report.yaml index feecd16c747..5be1ac21b3c 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yaml +++ b/.github/ISSUE_TEMPLATE/bug_report.yaml @@ -11,7 +11,9 @@ body: ## Important Notes - - **Before submitting a bug report**: 
Please check the [Issues](https://github.com/open-webui/open-webui/issues) or [Discussions](https://github.com/open-webui/open-webui/discussions) sections to see if a similar issue has already been reported. If unsure, start a discussion first, as this helps us efficiently focus on improving the project. + - **Before submitting a bug report**: Please check the [Issues](https://github.com/open-webui/open-webui/issues) and [Discussions](https://github.com/open-webui/open-webui/discussions) sections to see if a similar issue has already been reported. If unsure, start a discussion first, as this helps us efficiently focus on improving the project. Duplicates may be closed without notice. **Please search for existing issues and discussions.** + + - Check for open, **but also for (recently) CLOSED issues** as the issue you are trying to report **might already have been fixed!** - **Respectful collaboration**: Open WebUI is a volunteer-driven project with a single maintainer and contributors who also have full-time jobs. Please be constructive and respectful in your communication. @@ -25,7 +27,9 @@ body: label: Check Existing Issues description: Confirm that you’ve checked for existing reports before submitting a new one. options: - - label: I have searched the existing issues and discussions. + - label: I have searched for any existing and/or related issues. + required: true + - label: I have searched for any existing and/or related discussions. required: true - label: I am using the latest version of Open WebUI. 
required: true @@ -47,7 +51,7 @@ body: id: open-webui-version attributes: label: Open WebUI Version - description: Specify the version (e.g., v0.3.11) + description: Specify the version (e.g., v0.6.26) validations: required: true @@ -63,7 +67,7 @@ body: id: operating-system attributes: label: Operating System - description: Specify the OS (e.g., Windows 10, macOS Sonoma, Ubuntu 22.04) + description: Specify the OS (e.g., Windows 10, macOS Sonoma, Ubuntu 22.04, Debian 12) validations: required: true @@ -89,9 +93,20 @@ body: required: true - label: I have included the Docker container logs. required: true - - label: I have listed steps to reproduce the bug in detail. + - label: I have **provided every relevant configuration, setting, and environment variable used in my setup.** + required: true + - label: I have clearly **listed every relevant configuration, custom setting, environment variable, and command-line option that influences my setup** (such as Docker Compose overrides, .env values, browser settings, authentication configurations, etc). + required: true + - label: | + I have documented **step-by-step reproduction instructions that are precise, sequential, and leave nothing to interpretation**. My steps: + - Start with the initial platform/version/OS and dependencies used, + - Specify exact install/launch/configure commands, + - List URLs visited, user input (incl. example values/emails/passwords if needed), + - Describe all options and toggles enabled or changed, + - Include any files or environmental changes, + - Identify the expected and actual result at each stage, + - Ensure any reasonably skilled user can follow and hit the same issue. required: true - - type: textarea id: expected-behavior attributes: @@ -112,15 +127,26 @@ body: id: reproduction-steps attributes: label: Steps to Reproduce - description: Providing clear, step-by-step instructions helps us reproduce and fix the issue faster. If we can't reproduce it, we can't fix it. 
+ description: | + Please provide a **very detailed, step-by-step guide** to reproduce the issue. Your instructions should be so clear and precise that anyone can follow them without guesswork. Include every relevant detail—settings, configuration options, exact commands used, values entered, and any prerequisites or environment variables. + **If full reproduction steps and all relevant settings are not provided, your issue may not be addressed.** + **If your reproduction steps are incomplete, lacking detail, or not reproducible, your issue cannot be addressed.** + placeholder: | - 1. Go to '...' - 2. Click on '...' - 3. Scroll down to '...' - 4. See the error message '...' + Example (include every detail): + 1. Start with a clean Ubuntu 22.04 install. + 2. Install Docker v24.0.5 and start the service. + 3. Clone the Open WebUI repo (git clone ...). + 4. Use the Docker Compose file without modifications. + 5. Open browser Chrome 115.0 in incognito mode. + 6. Go to http://localhost:8080 and log in with user "test@example.com". + 7. Set the language to "English" and theme to "Dark". + 8. Attempt to connect to Ollama at "http://localhost:11434". + 9. Observe that the error message "Connection refused" appears at the top right. + + Please list each step carefully and include all relevant configuration, settings, and options. validations: required: true - - type: textarea id: logs-screenshots attributes: @@ -142,5 +168,5 @@ body: attributes: value: | ## Note - If the bug report is incomplete or does not follow instructions, it may not be addressed. Ensure that you've followed all the **README.md** and **troubleshooting.md** guidelines, and provide all necessary information for us to reproduce the issue. 
+ **If the bug report is incomplete, does not follow instructions, or is lacking details, it may not be addressed.** Ensure that you've followed all the **README.md** and **troubleshooting.md** guidelines, and provide all necessary information for us to reproduce the issue. Thank you for contributing to Open WebUI! diff --git a/.github/ISSUE_TEMPLATE/feature_request.yaml b/.github/ISSUE_TEMPLATE/feature_request.yaml index 2a326f65e46..4f159f4faa4 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.yaml +++ b/.github/ISSUE_TEMPLATE/feature_request.yaml @@ -8,8 +8,9 @@ body: value: | ## Important Notes ### Before submitting - Please check the [Issues](https://github.com/open-webui/open-webui/issues) or [Discussions](https://github.com/open-webui/open-webui/discussions) to see if a similar request has been posted. + Please check the open AND closed [Issues](https://github.com/open-webui/open-webui/issues) AND [Discussions](https://github.com/open-webui/open-webui/discussions) to see if a similar request has been posted. It's likely we're already tracking it! If you’re unsure, start a discussion post first. + If your feature request might impact others in the community, consider opening a discussion instead and evaluate whether and how to implement it. This will help us efficiently focus on improving the project. ### Collaborate respectfully @@ -35,7 +36,7 @@ body: label: Check Existing Issues description: Please confirm that you've checked for existing similar requests options: - - label: I have searched the existing issues and discussions. + - label: I have searched all existing open AND closed issues and discussions for similar requests. I have found none that is comparable to my request. 
required: true - type: textarea id: problem-description diff --git a/.github/dependabot.yml b/.github/dependabot.yml index ed93957ea4a..1c83fd305bb 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -12,12 +12,6 @@ updates: interval: monthly target-branch: 'dev' - - package-ecosystem: npm - directory: '/' - schedule: - interval: monthly - target-branch: 'dev' - - package-ecosystem: 'github-actions' directory: '/' schedule: diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 7f603cb10c5..0ec871f328f 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -4,14 +4,15 @@ **Before submitting, make sure you've checked the following:** -- [ ] **Target branch:** Please verify that the pull request targets the `dev` branch. +- [ ] **Target branch:** Verify that the pull request targets the `dev` branch. Not targeting the `dev` branch may lead to immediate closure of the PR. - [ ] **Description:** Provide a concise description of the changes made in this pull request. - [ ] **Changelog:** Ensure a changelog entry following the format of [Keep a Changelog](https://keepachangelog.com/) is added at the bottom of the PR description. -- [ ] **Documentation:** Have you updated relevant documentation [Open WebUI Docs](https://github.com/open-webui/docs), or other documentation sources? +- [ ] **Documentation:** If necessary, update relevant documentation [Open WebUI Docs](https://github.com/open-webui/docs) like environment variables, the tutorials, or other documentation sources. - [ ] **Dependencies:** Are there any new dependencies? Have you updated the dependency versions in the documentation? -- [ ] **Testing:** Have you written and run sufficient tests to validate the changes? +- [ ] **Testing:** Perform manual tests to verify the implemented fix/feature works as intended AND does not break any other functionality. 
Take this as an opportunity to take screenshots of the feature/fix and include them in the PR description. +- [ ] **Agentic AI Code:** Confirm this Pull Request is **not written by any AI Agent** or has at least gone through additional human review **and** manual testing. If any AI Agent is the co-author of this PR, it may lead to immediate closure of the PR. - [ ] **Code review:** Have you performed a self-review of your code, addressing any coding standard issues and ensuring adherence to the project's coding standards? -- [ ] **Prefix:** To clearly categorize this pull request, prefix the pull request title using one of the following: +- [ ] **Title Prefix:** To clearly categorize this pull request, prefix the pull request title using one of the following: - **BREAKING CHANGE**: Significant changes that may affect compatibility - **build**: Changes that affect the build system or external dependencies - **ci**: Changes to our continuous integration processes or workflows @@ -73,4 +74,4 @@ ### Contributor License Agreement -By submitting this pull request, I confirm that I have read and fully agree to the [Contributor License Agreement (CLA)](/CONTRIBUTOR_LICENSE_AGREEMENT), and I am providing my contributions under its terms. +By submitting this pull request, I confirm that I have read and fully agree to the [Contributor License Agreement (CLA)](https://github.com/open-webui/open-webui/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT), and I am providing my contributions under its terms. 
diff --git a/.github/workflows/build-release.yml b/.github/workflows/build-release.yml index 443d904199d..019fbb6baea 100644 --- a/.github/workflows/build-release.yml +++ b/.github/workflows/build-release.yml @@ -11,7 +11,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Check for changes in package.json run: | @@ -36,7 +36,7 @@ jobs: echo "::set-output name=content::$CHANGELOG_ESCAPED" - name: Create GitHub release - uses: actions/github-script@v7 + uses: actions/github-script@v8 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | @@ -61,7 +61,7 @@ jobs: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Trigger Docker build workflow - uses: actions/github-script@v7 + uses: actions/github-script@v8 with: script: | github.rest.actions.createWorkflowDispatch({ diff --git a/.github/workflows/deploy-to-hf-spaces.yml b/.github/workflows/deploy-to-hf-spaces.yml index 7fc66acf5c4..a30046af895 100644 --- a/.github/workflows/deploy-to-hf-spaces.yml +++ b/.github/workflows/deploy-to-hf-spaces.yml @@ -27,7 +27,7 @@ jobs: HF_TOKEN: ${{ secrets.HF_TOKEN }} steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: lfs: true diff --git a/.github/workflows/docker-build.yaml b/.github/workflows/docker-build.yaml index e61a69f33ae..a8f9266e9d7 100644 --- a/.github/workflows/docker-build.yaml +++ b/.github/workflows/docker-build.yaml @@ -14,16 +14,18 @@ env: jobs: build-main-image: - runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }} + runs-on: ${{ matrix.runner }} permissions: contents: read packages: write strategy: fail-fast: false matrix: - platform: - - linux/amd64 - - linux/arm64 + include: + - platform: linux/amd64 + runner: ubuntu-latest + - platform: linux/arm64 + runner: ubuntu-24.04-arm steps: # GitHub Packages requires the entire repository name to be in lowercase @@ -41,7 +43,7 @@ jobs: echo "PLATFORM_PAIR=${platform//\//-}" >> 
$GITHUB_ENV - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Set up QEMU uses: docker/setup-qemu-action@v3 @@ -111,16 +113,18 @@ jobs: retention-days: 1 build-cuda-image: - runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }} + runs-on: ${{ matrix.runner }} permissions: contents: read packages: write strategy: fail-fast: false matrix: - platform: - - linux/amd64 - - linux/arm64 + include: + - platform: linux/amd64 + runner: ubuntu-latest + - platform: linux/arm64 + runner: ubuntu-24.04-arm steps: # GitHub Packages requires the entire repository name to be in lowercase @@ -138,7 +142,7 @@ jobs: echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Set up QEMU uses: docker/setup-qemu-action@v3 @@ -210,17 +214,122 @@ jobs: if-no-files-found: error retention-days: 1 + build-cuda126-image: + runs-on: ${{ matrix.runner }} + permissions: + contents: read + packages: write + strategy: + fail-fast: false + matrix: + include: + - platform: linux/amd64 + runner: ubuntu-latest + - platform: linux/arm64 + runner: ubuntu-24.04-arm + + steps: + # GitHub Packages requires the entire repository name to be in lowercase + # although the repository owner has a lowercase username, this prevents some people from running actions after forking + - name: Set repository and image name to lowercase + run: | + echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV} + echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV} + env: + IMAGE_NAME: '${{ github.repository }}' + + - name: Prepare + run: | + platform=${{ matrix.platform }} + echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV + + - name: Checkout repository + uses: actions/checkout@v5 + + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to the Container registry + 
uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract metadata for Docker images (cuda126 tag) + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.FULL_IMAGE_NAME }} + tags: | + type=ref,event=branch + type=ref,event=tag + type=sha,prefix=git- + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=cuda126 + flavor: | + latest=${{ github.ref == 'refs/heads/main' }} + suffix=-cuda126,onlatest=true + + - name: Extract metadata for Docker cache + id: cache-meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.FULL_IMAGE_NAME }} + tags: | + type=ref,event=branch + ${{ github.ref_type == 'tag' && 'type=raw,value=main' || '' }} + flavor: | + prefix=cache-cuda126-${{ matrix.platform }}- + latest=false + + - name: Build Docker image (cuda126) + uses: docker/build-push-action@v5 + id: build + with: + context: . 
+ push: true + platforms: ${{ matrix.platform }} + labels: ${{ steps.meta.outputs.labels }} + outputs: type=image,name=${{ env.FULL_IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true + cache-from: type=registry,ref=${{ steps.cache-meta.outputs.tags }} + cache-to: type=registry,ref=${{ steps.cache-meta.outputs.tags }},mode=max + build-args: | + BUILD_HASH=${{ github.sha }} + USE_CUDA=true + USE_CUDA_VER=cu126 + + - name: Export digest + run: | + mkdir -p /tmp/digests + digest="${{ steps.build.outputs.digest }}" + touch "/tmp/digests/${digest#sha256:}" + + - name: Upload digest + uses: actions/upload-artifact@v4 + with: + name: digests-cuda126-${{ env.PLATFORM_PAIR }} + path: /tmp/digests/* + if-no-files-found: error + retention-days: 1 + build-ollama-image: - runs-on: ${{ matrix.platform == 'linux/arm64' && 'ubuntu-24.04-arm' || 'ubuntu-latest' }} + runs-on: ${{ matrix.runner }} permissions: contents: read packages: write strategy: fail-fast: false matrix: - platform: - - linux/amd64 - - linux/arm64 + include: + - platform: linux/amd64 + runner: ubuntu-latest + - platform: linux/arm64 + runner: ubuntu-24.04-arm steps: # GitHub Packages requires the entire repository name to be in lowercase @@ -238,7 +347,7 @@ jobs: echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Set up QEMU uses: docker/setup-qemu-action@v3 @@ -310,6 +419,108 @@ jobs: if-no-files-found: error retention-days: 1 + build-slim-image: + runs-on: ${{ matrix.runner }} + permissions: + contents: read + packages: write + strategy: + fail-fast: false + matrix: + include: + - platform: linux/amd64 + runner: ubuntu-latest + - platform: linux/arm64 + runner: ubuntu-24.04-arm + + steps: + # GitHub Packages requires the entire repository name to be in lowercase + # although the repository owner has a lowercase username, this prevents some people from running actions after forking + - name: Set repository 
and image name to lowercase + run: | + echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV} + echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV} + env: + IMAGE_NAME: '${{ github.repository }}' + + - name: Prepare + run: | + platform=${{ matrix.platform }} + echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV + + - name: Checkout repository + uses: actions/checkout@v5 + + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to the Container registry + uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract metadata for Docker images (slim tag) + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.FULL_IMAGE_NAME }} + tags: | + type=ref,event=branch + type=ref,event=tag + type=sha,prefix=git- + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=slim + flavor: | + latest=${{ github.ref == 'refs/heads/main' }} + suffix=-slim,onlatest=true + + - name: Extract metadata for Docker cache + id: cache-meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.FULL_IMAGE_NAME }} + tags: | + type=ref,event=branch + ${{ github.ref_type == 'tag' && 'type=raw,value=main' || '' }} + flavor: | + prefix=cache-slim-${{ matrix.platform }}- + latest=false + + - name: Build Docker image (slim) + uses: docker/build-push-action@v5 + id: build + with: + context: . 
+ push: true + platforms: ${{ matrix.platform }} + labels: ${{ steps.meta.outputs.labels }} + outputs: type=image,name=${{ env.FULL_IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true + cache-from: type=registry,ref=${{ steps.cache-meta.outputs.tags }} + cache-to: type=registry,ref=${{ steps.cache-meta.outputs.tags }},mode=max + build-args: | + BUILD_HASH=${{ github.sha }} + USE_SLIM=true + + - name: Export digest + run: | + mkdir -p /tmp/digests + digest="${{ steps.build.outputs.digest }}" + touch "/tmp/digests/${digest#sha256:}" + + - name: Upload digest + uses: actions/upload-artifact@v4 + with: + name: digests-slim-${{ env.PLATFORM_PAIR }} + path: /tmp/digests/* + if-no-files-found: error + retention-days: 1 + merge-main-images: runs-on: ubuntu-latest needs: [build-main-image] @@ -324,7 +535,7 @@ jobs: IMAGE_NAME: '${{ github.repository }}' - name: Download digests - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v5 with: pattern: digests-main-* path: /tmp/digests @@ -378,7 +589,7 @@ jobs: IMAGE_NAME: '${{ github.repository }}' - name: Download digests - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v5 with: pattern: digests-cuda-* path: /tmp/digests @@ -420,6 +631,62 @@ jobs: run: | docker buildx imagetools inspect ${{ env.FULL_IMAGE_NAME }}:${{ steps.meta.outputs.version }} + merge-cuda126-images: + runs-on: ubuntu-latest + needs: [build-cuda126-image] + steps: + # GitHub Packages requires the entire repository name to be in lowercase + # although the repository owner has a lowercase username, this prevents some people from running actions after forking + - name: Set repository and image name to lowercase + run: | + echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV} + echo "FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV} + env: + IMAGE_NAME: '${{ github.repository }}' + + - name: Download digests + uses: actions/download-artifact@v5 + with: + pattern: digests-cuda126-* + path: /tmp/digests + 
merge-multiple: true + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to the Container registry + uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract metadata for Docker images (default latest tag) + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.FULL_IMAGE_NAME }} + tags: | + type=ref,event=branch + type=ref,event=tag + type=sha,prefix=git- + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=cuda126 + flavor: | + latest=${{ github.ref == 'refs/heads/main' }} + suffix=-cuda126,onlatest=true + + - name: Create manifest list and push + working-directory: /tmp/digests + run: | + docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \ + $(printf '${{ env.FULL_IMAGE_NAME }}@sha256:%s ' *) + + - name: Inspect image + run: | + docker buildx imagetools inspect ${{ env.FULL_IMAGE_NAME }}:${{ steps.meta.outputs.version }} + merge-ollama-images: runs-on: ubuntu-latest needs: [build-ollama-image] @@ -434,7 +701,7 @@ jobs: IMAGE_NAME: '${{ github.repository }}' - name: Download digests - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v5 with: pattern: digests-ollama-* path: /tmp/digests @@ -475,3 +742,59 @@ jobs: - name: Inspect image run: | docker buildx imagetools inspect ${{ env.FULL_IMAGE_NAME }}:${{ steps.meta.outputs.version }} + + merge-slim-images: + runs-on: ubuntu-latest + needs: [build-slim-image] + steps: + # GitHub Packages requires the entire repository name to be in lowercase + # although the repository owner has a lowercase username, this prevents some people from running actions after forking + - name: Set repository and image name to lowercase + run: | + echo "IMAGE_NAME=${IMAGE_NAME,,}" >>${GITHUB_ENV} + echo 
"FULL_IMAGE_NAME=ghcr.io/${IMAGE_NAME,,}" >>${GITHUB_ENV} + env: + IMAGE_NAME: '${{ github.repository }}' + + - name: Download digests + uses: actions/download-artifact@v5 + with: + pattern: digests-slim-* + path: /tmp/digests + merge-multiple: true + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to the Container registry + uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract metadata for Docker images (default slim tag) + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.FULL_IMAGE_NAME }} + tags: | + type=ref,event=branch + type=ref,event=tag + type=sha,prefix=git- + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + type=raw,enable=${{ github.ref == 'refs/heads/main' }},prefix=,suffix=,value=slim + flavor: | + latest=${{ github.ref == 'refs/heads/main' }} + suffix=-slim,onlatest=true + + - name: Create manifest list and push + working-directory: /tmp/digests + run: | + docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) 
| join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \ + $(printf '${{ env.FULL_IMAGE_NAME }}@sha256:%s ' *) + + - name: Inspect image + run: | + docker buildx imagetools inspect ${{ env.FULL_IMAGE_NAME }}:${{ steps.meta.outputs.version }} diff --git a/.github/workflows/format-backend.yaml b/.github/workflows/format-backend.yaml index 1bcdd92c1db..562e6aa1c13 100644 --- a/.github/workflows/format-backend.yaml +++ b/.github/workflows/format-backend.yaml @@ -30,10 +30,10 @@ jobs: - 3.12.x steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v5 - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '${{ matrix.python-version }}' diff --git a/.github/workflows/format-build-frontend.yaml b/.github/workflows/format-build-frontend.yaml index 9a007581ffe..eaa1072fbc4 100644 --- a/.github/workflows/format-build-frontend.yaml +++ b/.github/workflows/format-build-frontend.yaml @@ -24,15 +24,15 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout Repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Setup Node.js - uses: actions/setup-node@v4 + uses: actions/setup-node@v5 with: node-version: '22' - name: Install Dependencies - run: npm install + run: npm install --force - name: Format Frontend run: npm run format @@ -51,15 +51,15 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout Repository - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Setup Node.js - uses: actions/setup-node@v4 + uses: actions/setup-node@v5 with: node-version: '22' - name: Install Dependencies - run: npm ci + run: npm ci --force - name: Run vitest run: npm run test:frontend diff --git a/.github/workflows/release-pypi.yml b/.github/workflows/release-pypi.yml index fd1adab3a93..9995ccedae0 100644 --- a/.github/workflows/release-pypi.yml +++ b/.github/workflows/release-pypi.yml @@ -16,15 +16,15 @@ jobs: id-token: write steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: 
actions/checkout@v5 with: fetch-depth: 0 - name: Install Git run: sudo apt-get update && sudo apt-get install -y git - - uses: actions/setup-node@v4 + - uses: actions/setup-node@v5 with: node-version: 22 - - uses: actions/setup-python@v5 + - uses: actions/setup-python@v6 with: python-version: 3.11 - name: Build diff --git a/.gitignore b/.gitignore index 32271f8087e..07494bd151c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +x.py +yarn.lock .DS_Store node_modules /build @@ -12,7 +14,8 @@ vite.config.ts.timestamp-* __pycache__/ *.py[cod] *$py.class - +.nvmrc +CLAUDE.md # C extensions *.so diff --git a/.prettierrc b/.prettierrc index a77fddea909..22558729f47 100644 --- a/.prettierrc +++ b/.prettierrc @@ -5,5 +5,6 @@ "printWidth": 100, "plugins": ["prettier-plugin-svelte"], "pluginSearchDirs": ["."], - "overrides": [{ "files": "*.svelte", "options": { "parser": "svelte" } }] + "overrides": [{ "files": "*.svelte", "options": { "parser": "svelte" } }], + "endOfLine": "lf" } diff --git a/CHANGELOG.md b/CHANGELOG.md index d17f2cf2e74..38e3e2be4de 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,893 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.6.34] - 2025-10-16 + +### Added + +- 📄 MinerU is now supported as a document parser backend, with support for both local and managed API deployments. [#18306](https://github.com/open-webui/open-webui/pull/18306) +- 🔒 JWT token expiration default is now set to 4 weeks instead of never expiring, with security warnings displayed in backend logs and admin UI when set to unlimited. [#18261](https://github.com/open-webui/open-webui/pull/18261), [#18262](https://github.com/open-webui/open-webui/pull/18262) +- ⚡ Page loading performance is improved by preventing unnecessary API requests when sidebar folders are not expanded. 
[#18179](https://github.com/open-webui/open-webui/pull/18179), [#17476](https://github.com/open-webui/open-webui/issues/17476) +- 📁 File hash values are now included in the knowledge endpoint response, enabling efficient file synchronization through hash comparison. [#18284](https://github.com/open-webui/open-webui/pull/18284), [#18283](https://github.com/open-webui/open-webui/issues/18283) +- 🎨 Chat dialog scrollbar visibility is improved by increasing its width, making it easier to use for navigation. [#18369](https://github.com/open-webui/open-webui/pull/18369), [#11782](https://github.com/open-webui/open-webui/issues/11782) +- 🔄 Various improvements were implemented across the frontend and backend to enhance performance, stability, and security. +- 🌐 Translations for Catalan, Chinese, Czech, Finnish, German, Kabyle, Korean, Portuguese (Brazil), Spanish, Thai, and Turkish were enhanced and expanded. + +### Fixed + +- 📚 Focused retrieval mode now works correctly, preventing the system from forcing full context mode and loading all documents in a knowledge base regardless of settings. [#18133](https://github.com/open-webui/open-webui/issues/18133) +- 🔧 Filter inlet functions now correctly execute on tool call continuations, ensuring parameter persistence throughout tool interactions. [#18222](https://github.com/open-webui/open-webui/issues/18222) +- 🛠️ External tool servers now properly support DELETE requests with body data. [#18289](https://github.com/open-webui/open-webui/pull/18289), [#18287](https://github.com/open-webui/open-webui/issues/18287) +- 🗄️ Oracle23ai vector database client now correctly handles variable initialization, resolving UnboundLocalError when retrieving items from collections. [#18356](https://github.com/open-webui/open-webui/issues/18356) +- 🔧 Model auto-pull functionality now works correctly even when user settings remain unmodified. 
[#18324](https://github.com/open-webui/open-webui/pull/18324) +- 🎨 Duplicate HTML content in artifacts is now prevented by improving code block detection logic. [#18195](https://github.com/open-webui/open-webui/pull/18195), [#6154](https://github.com/open-webui/open-webui/issues/6154) +- 💬 Pinned chats now appear in the Reference Chats list and can be referenced in conversations. [#18288](https://github.com/open-webui/open-webui/issues/18288) +- 📝 Misleading knowledge base warning text in documents settings is clarified to correctly instruct users about reindexing vectors. [#18263](https://github.com/open-webui/open-webui/pull/18263) +- 🔔 Toast notifications can now be dismissed even when a modal is open. [#18260](https://github.com/open-webui/open-webui/pull/18260) +- 🔘 The "Chats" button in the sidebar now correctly toggles chat list visibility without navigating away from the current page. [#18232](https://github.com/open-webui/open-webui/pull/18232) +- 🎯 The Integrations menu no longer closes prematurely when clicking outside the Valves modal. [#18310](https://github.com/open-webui/open-webui/pull/18310) +- 🛠️ Tool ID display issues where "undefined" was incorrectly shown in the interface are now resolved. [#18178](https://github.com/open-webui/open-webui/pull/18178) +- 🛠️ Model management issues caused by excessively long model IDs are now prevented through validation that limits model IDs to 256 characters. [#18125](https://github.com/open-webui/open-webui/issues/18125) + +## [0.6.33] - 2025-10-08 + +### Added + +- 🎨 Workspace interface received a comprehensive redesign across Models, Knowledge, Prompts, and Tools sections, featuring reorganized controls, view filters for created vs shared items, tag selectors, improved visual hierarchy, and streamlined import/export functionality. 
[Commit](https://github.com/open-webui/open-webui/commit/2c59a288603d8c5f004f223ee00fef37cc763a8e), [Commit](https://github.com/open-webui/open-webui/commit/6050c86ab6ef6b8c96dd3f99c62a6867011b67a4), [Commit](https://github.com/open-webui/open-webui/commit/96ecb47bc71c072aa34ef2be10781b042bef4e8c), [Commit](https://github.com/open-webui/open-webui/commit/2250d102b28075a9611696e911536547abb8b38a), [Commit](https://github.com/open-webui/open-webui/commit/23c8f6d507bfee75ab0015a3e2972d5c26f7e9bf), [Commit](https://github.com/open-webui/open-webui/commit/a743b16728c6ae24b8befbc2d7f24eb9e20c4ad5) +- 🛠️ Functions admin interface received a comprehensive redesign with creator attribution display, ownership filters for created vs shared items, improved organization, and refined styling. [Commit](https://github.com/open-webui/open-webui/commit/f5e1a42f51acc0b9d5b63a33c1ca2e42470239c1) +- ⚡ Page initialization performance is significantly improved through parallel data loading and optimized folder API calls, reducing initial page load time. [#17559](https://github.com/open-webui/open-webui/pull/17559), [#17889](https://github.com/open-webui/open-webui/pull/17889) +- ⚡ Chat overview component is now dynamically loaded on demand, reducing initial page bundle size by approximately 470KB and improving first-screen loading speed. [#17595](https://github.com/open-webui/open-webui/pull/17595) +- 📁 Folders can now be attached to chats using the "#" command, automatically expanding to include all files within the folder for streamlined knowledge base integration. [Commit](https://github.com/open-webui/open-webui/commit/d2cb78179d66dc85188172a08622d4c97a2ea1ee) +- 📱 Progressive Web App now supports Android share target functionality, allowing users to share web pages, YouTube videos, and text directly to Open WebUI from the system share menu. 
[#17633](https://github.com/open-webui/open-webui/pull/17633), [#17125](https://github.com/open-webui/open-webui/issues/17125) +- 🗄️ Redis session storage is now available as an experimental option for OAuth authentication flows via the ENABLE_STAR_SESSIONS_MIDDLEWARE environment variable, providing shared session state across multi-replica deployments to address CSRF errors, though currently only basic Redis setups are supported. [#17223](https://github.com/open-webui/open-webui/pull/17223), [#15373](https://github.com/open-webui/open-webui/issues/15373), [Docs:Commit](https://github.com/open-webui/docs/commit/14052347f165d1b597615370373d7289ce44c7f9) +- 📊 Vega and Vega-Lite chart visualization renderers are now supported in code blocks, enabling inline rendering of data visualizations with automatic compilation of Vega-Lite specifications. [#18033](https://github.com/open-webui/open-webui/pull/18033), [#18040](https://github.com/open-webui/open-webui/pull/18040), [#18022](https://github.com/open-webui/open-webui/issues/18022) +- 🔗 OpenAI connections now support custom HTTP headers, enabling users to configure authentication and routing headers for specific deployment requirements. [#18021](https://github.com/open-webui/open-webui/pull/18021), [#9732](https://github.com/open-webui/open-webui/discussions/9732) +- 🔐 OpenID Connect authentication now supports OIDC providers without email scope via the ENABLE_OAUTH_WITHOUT_EMAIL environment variable, enabling compatibility with identity providers that don't expose email addresses. [#18047](https://github.com/open-webui/open-webui/pull/18047), [#18045](https://github.com/open-webui/open-webui/issues/18045) +- 🤖 Ollama model management modal now features individual model update cancellation, comprehensive tooltips for all buttons, and streamlined notification behavior to reduce toast spam. 
[#16863](https://github.com/open-webui/open-webui/pull/16863) +- ☁️ OneDrive file picker now includes search functionality and "My Organization" pivot for business accounts, enabling easier file discovery across organizational content. [#17930](https://github.com/open-webui/open-webui/pull/17930), [#17929](https://github.com/open-webui/open-webui/issues/17929) +- 📊 Chat overview flow diagram now supports toggling between vertical and horizontal layout orientations for improved visualization flexibility. [#17941](https://github.com/open-webui/open-webui/pull/17941) +- 🔊 OpenAI Text-to-Speech engine now supports additional parameters, allowing users to customize TTS behavior with provider-specific options via JSON configuration. [#17985](https://github.com/open-webui/open-webui/issues/17985), [#17188](https://github.com/open-webui/open-webui/pull/17188) +- 🛠️ Tool server list now displays server name, URL, and type (OpenAPI or MCP) for easier identification and management. [#18062](https://github.com/open-webui/open-webui/issues/18062) +- 📁 Folders now remember the last selected model, automatically applying it when starting new chats within that folder. [#17836](https://github.com/open-webui/open-webui/issues/17836) +- 🔢 Ollama embedding endpoint now supports the optional dimensions parameter for controlling embedding output size, compatible with Ollama v0.11.11 and later. [#17942](https://github.com/open-webui/open-webui/pull/17942) +- ⚡ Workspace knowledge page load time is improved by removing redundant API calls, enhancing overall responsiveness. [#18057](https://github.com/open-webui/open-webui/pull/18057) +- ⚡ File metadata query performance is enhanced by selecting only relevant columns instead of retrieving entire records, reducing database overhead. [#18013](https://github.com/open-webui/open-webui/pull/18013) +- 📄 Note PDF exports now include titles and properly render in dark mode with appropriate background colors. 
[Commit](https://github.com/open-webui/open-webui/commit/216fb5c3db1a223ffe6e72d97aa9551fe0e2d028) +- 📄 Docling document extraction now supports additional parameters for VLM pipeline configuration, enabling customized vision model settings. [#17363](https://github.com/open-webui/open-webui/pull/17363) +- ⚙️ Server startup script now supports passing arbitrary arguments to uvicorn, enabling custom server configuration options. [#17919](https://github.com/open-webui/open-webui/pull/17919), [#17918](https://github.com/open-webui/open-webui/issues/17918) +- 🔄 Various improvements were implemented across the frontend and backend to enhance performance, stability, and security. +- 🌐 Translations for German, Danish, Spanish, Korean, Portuguese (Brazil), Simplified Chinese, and Traditional Chinese were enhanced and expanded. + +### Fixed + +- 💬 System prompts are no longer duplicated in chat requests, eliminating confusion and excessive token usage caused by repeated instructions being sent to models. [#17198](https://github.com/open-webui/open-webui/issues/17198), [#16855](https://github.com/open-webui/open-webui/issues/16855) +- 🔐 MCP OAuth 2.1 authentication now complies with the standard by implementing PKCE with S256 code challenge method and explicitly passing client credentials during token authorization, resolving "code_challenge: Field required" and "client_id: Field required" errors when connecting to OAuth-secured MCP servers. [Commit](https://github.com/open-webui/open-webui/commit/911a114ad459f5deebd97543c13c2b90196efb54), [#18010](https://github.com/open-webui/open-webui/issues/18010), [#18087](https://github.com/open-webui/open-webui/pull/18087) +- 🔐 OAuth signup flow now handles password hashing correctly by migrating from passlib to native bcrypt, preventing failures when passwords exceed 72 bytes. 
[#17917](https://github.com/open-webui/open-webui/issues/17917) +- 🔐 OAuth token refresh errors are resolved by properly registering and storing OAuth clients, fixing "Constructor parameter should be str" exceptions for Google, Microsoft, and OIDC providers. [#17829](https://github.com/open-webui/open-webui/issues/17829) +- 🔐 OAuth server metadata URL is now correctly accessed via the proper attribute, fixing automatic token refresh and logout functionality for Microsoft OAuth provider when OPENID_PROVIDER_URL is not set. [#18065](https://github.com/open-webui/open-webui/pull/18065) +- 🔐 OAuth credential decryption failures now allow the application to start gracefully with clear error messages instead of crashing, preventing complete service outages when WEBUI_SECRET_KEY mismatches occur during database migrations or environment changes. [#18094](https://github.com/open-webui/open-webui/pull/18094), [#18092](https://github.com/open-webui/open-webui/issues/18092) +- 🔐 OAuth 2.1 server discovery now correctly attempts all configured discovery URLs in sequence instead of only trying the first URL. [#17906](https://github.com/open-webui/open-webui/pull/17906), [#17904](https://github.com/open-webui/open-webui/issues/17904), [#18026](https://github.com/open-webui/open-webui/pull/18026) +- 🔐 Login redirect now correctly honors the redirect query parameter after authentication, ensuring users are returned to their intended destination with query parameters intact instead of defaulting to the homepage. [#18071](https://github.com/open-webui/open-webui/issues/18071) +- ☁️ OneDrive Business integration authentication regression is resolved, ensuring the popup now properly triggers when connecting to OneDrive accounts. 
[#17902](https://github.com/open-webui/open-webui/pull/17902), [#17825](https://github.com/open-webui/open-webui/discussions/17825), [#17816](https://github.com/open-webui/open-webui/issues/17816) +- 👥 Default group settings now persist correctly after page navigation, ensuring configuration changes are properly saved and retained. [#17899](https://github.com/open-webui/open-webui/issues/17899), [#18003](https://github.com/open-webui/open-webui/issues/18003) +- 📁 Folder data integrity is now verified on retrieval, automatically fixing orphaned folders with invalid parent references and ensuring proper cascading deletion of nested folder structures. [Commit](https://github.com/open-webui/open-webui/commit/5448618dd5ea181b9635b77040cef60926a902ff) +- 🗄️ Redis Sentinel and Redis Cluster configurations with the experimental ENABLE_STAR_SESSIONS_MIDDLEWARE feature are now properly isolated by making the feature opt-in only, preventing ReadOnlyError failures when connecting to read replicas in multi-node Redis deployments. [#18073](https://github.com/open-webui/open-webui/issues/18073) +- 📊 Mermaid and Vega diagram rendering now displays error toast notifications when syntax errors are detected, helping users identify and fix diagram issues instead of silently failing. [#18068](https://github.com/open-webui/open-webui/pull/18068) +- 🤖 Reasoning models that return reasoning_content instead of content no longer cause NoneType errors during chat title generation, follow-up suggestions, and tag generation. [#18080](https://github.com/open-webui/open-webui/pull/18080) +- 📚 Citation rendering now correctly handles multiple source references in a single bracket, parsing formats like [1,2] and [1, 2] into separate clickable citation links. 
[#18120](https://github.com/open-webui/open-webui/pull/18120) +- 🔍 Web search now handles individual source failures gracefully, continuing to process remaining sources instead of failing entirely when a single URL is unreachable or returns an error. [Commit](https://github.com/open-webui/open-webui/commit/e000494e488090c5f66989a2b3f89d3eaeb7946b), [Commit](https://github.com/open-webui/open-webui/commit/53e98620bff38ab9280aee5165af0a704bdd99b9) +- 🔍 Hybrid search with reranking now handles empty result sets gracefully instead of crashing with ValueError when all results are filtered out due to relevance thresholds. [#18096](https://github.com/open-webui/open-webui/issues/18096) +- 🔍 Reranking models without defined padding tokens now work correctly by automatically falling back to eos_token_id as pad_token_id, fixing "Cannot handle batch sizes > 1" errors for models like Qwen3-Reranker. [#18108](https://github.com/open-webui/open-webui/pull/18108), [#16027](https://github.com/open-webui/open-webui/discussions/16027) +- 🔍 Model selector search now correctly returns results for non-admin users by dynamically updating the search index when the model list changes, fixing a race condition that caused empty search results. [#17996](https://github.com/open-webui/open-webui/pull/17996), [#17960](https://github.com/open-webui/open-webui/pull/17960) +- ⚡ Task model function calling performance is improved by excluding base64 image data from payloads, significantly reducing token count and memory usage when images are present in conversations. [#17897](https://github.com/open-webui/open-webui/pull/17897) +- 🤖 Text selection "Ask" action now correctly recognizes and uses local models configured via direct connections instead of only showing external provider models. 
[#17896](https://github.com/open-webui/open-webui/issues/17896) +- 🛑 Task cancellation API now returns accurate response status, correctly reporting successful cancellations instead of incorrectly indicating failures. [#17920](https://github.com/open-webui/open-webui/issues/17920) +- 💬 Follow-up query suggestions are now generated and displayed in temporary chats, matching the behavior of saved chats. [#14987](https://github.com/open-webui/open-webui/issues/14987) +- 🔊 Azure Text-to-Speech now properly escapes special characters like ampersands in SSML, preventing HTTP 400 errors and ensuring audio generation succeeds for all text content. [#17962](https://github.com/open-webui/open-webui/issues/17962) +- 🛠️ OpenAPI tool server calls with optional parameters now execute successfully even when no arguments are provided, removing the incorrect requirement for a request body. [#18036](https://github.com/open-webui/open-webui/issues/18036) +- 🛠️ MCP mode tool server connections no longer incorrectly validate the OpenAPI path field, allowing seamless switching between OpenAPI and MCP connection types. [#17989](https://github.com/open-webui/open-webui/pull/17989), [#17988](https://github.com/open-webui/open-webui/issues/17988) +- 🛠️ Third-party tool responses containing non-UTF8 or invalid byte sequences are now handled gracefully without causing request failures. [#17882](https://github.com/open-webui/open-webui/pull/17882) +- 🎨 Workspace filter dropdown now correctly renders model tags as strings instead of displaying individual characters, fixing broken filtering interface when models have multiple tags. [#18034](https://github.com/open-webui/open-webui/issues/18034) +- ⌨️ Ctrl+Enter keyboard shortcut now correctly sends messages in mobile and narrow browser views on Chrome instead of inserting newlines. 
[#17975](https://github.com/open-webui/open-webui/issues/17975) +- ⌨️ Tab characters are now preserved when pasting code or formatted text into the chat input box in plain text mode. [#17958](https://github.com/open-webui/open-webui/issues/17958) +- 📋 Text selection copying from the chat input box now correctly copies only the selected text instead of the entire textbox content. [#17911](https://github.com/open-webui/open-webui/issues/17911) +- 🔍 Web search query logging now uses debug level instead of info level, preventing user search queries from appearing in production logs. [#17888](https://github.com/open-webui/open-webui/pull/17888) +- 📝 Debug print statements in middleware were removed to prevent excessive log pollution and respect configured logging levels. [#17943](https://github.com/open-webui/open-webui/issues/17943) + +### Changed + +- 🗄️ Milvus vector database dependency is updated from pymilvus 2.5.0 to 2.6.2, ensuring compatibility with newer Milvus versions but requiring users on older Milvus instances to either upgrade their database or manually downgrade the pymilvus package. [#18066](https://github.com/open-webui/open-webui/pull/18066) + +## [0.6.32] - 2025-09-29 + +### Added + +- ⚡ JSON model import moved to backend processing for significant performance improvements when importing large model files. [#17871](https://github.com/open-webui/open-webui/pull/17871) +- ⚠️ Visual warnings for group permissions that display when a permission is disabled in a group but remains enabled in the default user role, clarifying inheritance behavior for administrators. [#17848](https://github.com/open-webui/open-webui/pull/17848) +- 🗄️ Milvus multi-tenancy mode using shared collections with resource ID filtering for improved scalability, mirroring the existing Qdrant implementation and configurable via ENABLE_MILVUS_MULTITENANCY_MODE environment variable. 
[#17837](https://github.com/open-webui/open-webui/pull/17837) +- 🛠️ Enhanced tool result processing with improved error handling, better MCP tool result handling, and performance improvements for embedded UI components. [Commit](https://github.com/open-webui/open-webui/commit/4f06f29348b2c9d71c87d1bbe5b748a368f5101f) +- 👥 New user groups now automatically inherit default group permissions, streamlining the admin setup process by eliminating manual permission configuration. [#17843](https://github.com/open-webui/open-webui/pull/17843) +- 🗂️ Bulk unarchive functionality for all chats, providing a single backend endpoint to efficiently restore all archived chats at once. [#17857](https://github.com/open-webui/open-webui/pull/17857) +- 🏷️ Browser tab title toggle setting allows users to control whether chat titles appear in the browser tab or display only "Open WebUI". [#17851](https://github.com/open-webui/open-webui/pull/17851) +- 💬 Reply-to-message functionality in channels, allowing users to reply directly to specific messages with visual threading and context display. [Commit](https://github.com/open-webui/open-webui/commit/1a18928c94903ad1f1f0391b8ade042c3e60205b) +- 🔧 Tool server import and export functionality, allowing direct upload of openapi.json and openapi.yaml files as an alternative to URL-based configuration. [#14446](https://github.com/open-webui/open-webui/issues/14446) +- 🔧 User valve configuration for Functions is now available in the integration menu, providing consistent management alongside Tools. [#17784](https://github.com/open-webui/open-webui/issues/17784) +- 🔐 Admin permission toggle for controlling public sharing of notes, configurable via USER_PERMISSIONS_NOTES_ALLOW_PUBLIC_SHARING environment variable. 
[#17801](https://github.com/open-webui/open-webui/pull/17801), [Docs:#715](https://github.com/open-webui/docs/pull/715) +- 🗄️ DISKANN index type support for Milvus vector database with configurable maximum degree and search list size parameters. [#17770](https://github.com/open-webui/open-webui/pull/17770), [Docs:Commit](https://github.com/open-webui/docs/commit/cec50ab4d4b659558ca1ccd4b5e6fc024f05fb83) +- 🔄 Various improvements were implemented across the frontend and backend to enhance performance, stability, and security. +- 🌐 Translations for Chinese (Simplified & Traditional) and Bosnian (Latin) were enhanced and expanded. + +### Fixed + +- 🛠️ MCP tool calls are now correctly routed to the appropriate server when multiple streamable-http MCP servers are enabled, preventing "Tool not found" errors. [#17817](https://github.com/open-webui/open-webui/issues/17817) +- 🛠️ External tool servers (OpenAPI/MCP) now properly process and return tool results to the model, restoring functionality that was broken in v0.6.31. [#17764](https://github.com/open-webui/open-webui/issues/17764) +- 🔧 User valve detection now correctly identifies valves in imported tool code, ensuring gear icons appear in the integrations menu for all tools with user valves. [#17765](https://github.com/open-webui/open-webui/issues/17765) +- 🔐 MCP OAuth discovery now correctly handles multi-tenant configurations by including subpaths in metadata URL discovery. [#17768](https://github.com/open-webui/open-webui/issues/17768) +- 🗄️ Milvus query operations now correctly use -1 instead of None for unlimited queries, preventing TypeError exceptions. [#17769](https://github.com/open-webui/open-webui/pull/17769), [#17088](https://github.com/open-webui/open-webui/issues/17088) +- 📁 File upload error messages are now displayed when files are modified during upload, preventing user confusion on Android and Windows devices. 
[#17777](https://github.com/open-webui/open-webui/pull/17777) +- 🎨 MessageInput Integrations button hover effect now displays correctly with proper visual feedback. [#17767](https://github.com/open-webui/open-webui/pull/17767) +- 🎯 "Set as default" label positioning is fixed to ensure it remains clickable in all scenarios, including multi-model configurations. [#17779](https://github.com/open-webui/open-webui/pull/17779) +- 🎛️ Floating buttons now correctly retrieve message context by using the proper messageId parameter in createMessagesList calls. [#17823](https://github.com/open-webui/open-webui/pull/17823) +- 📌 Pinned chats are now properly cleared from the sidebar after archiving all chats, ensuring UI consistency without requiring a page refresh. [#17832](https://github.com/open-webui/open-webui/pull/17832) +- 🗑️ Delete confirmation modals now properly truncate long names for Notes, Prompts, Tools, and Functions to prevent modal overflow. [#17812](https://github.com/open-webui/open-webui/pull/17812) +- 🌐 Internationalization function calls now use proper Svelte store subscription syntax, preventing "i18n.t is not a function" errors on the model creation page. [#17819](https://github.com/open-webui/open-webui/pull/17819) +- 🎨 Playground chat interface button layout is corrected to prevent vertical text rendering for Assistant/User role buttons. [#17819](https://github.com/open-webui/open-webui/pull/17819) +- 🏷️ UI text truncation is improved across multiple components including usernames in admin panels, arena model names, model tags, and filter tags to prevent layout overflow issues. 
[#17805](https://github.com/open-webui/open-webui/pull/17805), [#17803](https://github.com/open-webui/open-webui/pull/17803), [#17791](https://github.com/open-webui/open-webui/pull/17791), [#17796](https://github.com/open-webui/open-webui/pull/17796) + +## [0.6.31] - 2025-09-25 + +### Added + +- 🔌 MCP (streamable HTTP) server support was added alongside existing OpenAPI server integration, allowing users to connect both server types through an improved server configuration interface. [#15932](https://github.com/open-webui/open-webui/issues/15932), [#16651](https://github.com/open-webui/open-webui/pull/16651), [Commit](https://github.com/open-webui/open-webui/commit/fd7385c3921eb59af76a26f4c475aedb38ce2406), [Commit](https://github.com/open-webui/open-webui/commit/777e81f7a8aca957a359d51df8388e5af4721a68), [Commit](https://github.com/open-webui/open-webui/commit/de7f7b3d855641450f8e5aac34fbae0665e0b80e), [Commit](https://github.com/open-webui/open-webui/commit/f1bbf3a91e4713039364b790e886e59b401572d0), [Commit](https://github.com/open-webui/open-webui/commit/c55afc42559c32a6f0c8beb0f1bb18e9360ab8af), [Commit](https://github.com/open-webui/open-webui/commit/61f20acf61f4fe30c0e5b0180949f6e1a8cf6524) +- 🔐 To enable MCP server authentication, OAuth 2.1 dynamic client registration was implemented with secure automatic client registration, encrypted session management, and seamless authentication flows.
[Commit](https://github.com/open-webui/open-webui/commit/972be4eda5a394c111e849075f94099c9c0dd9aa), [Commit](https://github.com/open-webui/open-webui/commit/77e971dd9fbeee806e2864e686df5ec75e82104b), [Commit](https://github.com/open-webui/open-webui/commit/879abd7feea3692a2f157da4a458d30f27217508), [Commit](https://github.com/open-webui/open-webui/commit/422d38fd114b1ebd8a7dbb114d64e14791e67d7a), [Docs:#709](https://github.com/open-webui/docs/pull/709) +- 🛠️ External & Built-In Tools can now support rich UI element embedding ([Docs](https://docs.openwebui.com/features/plugin/tools/development)), allowing tools to return HTML content and interactive iframes that display directly within chat conversations with configurable security settings. [Commit](https://github.com/open-webui/open-webui/commit/07c5b25bc8b63173f406feb3ba183d375fedee6a), [Commit](https://github.com/open-webui/open-webui/commit/a5d8882bba7933a2c2c31c0a1405aba507c370bb), [Commit](https://github.com/open-webui/open-webui/commit/7be5b7f50f498de97359003609fc5993a172f084), [Commit](https://github.com/open-webui/open-webui/commit/a89ffccd7e96705a4a40e845289f4fcf9c4ae596) +- 📝 Note editor now supports drag-and-drop reordering of list items with visual drag handles, making list organization more intuitive and efficient. [Commit](https://github.com/open-webui/open-webui/commit/e4e97e727e9b4971f1c363b1280ca3a101599d88), [Commit](https://github.com/open-webui/open-webui/commit/aeb5288a3c7a6e9e0a47b807cc52f870c1b7dbe6) +- 🔍 Search modal was enhanced with quick action buttons for starting new conversations and creating notes, with intelligent content pre-population from search queries. 
[Commit](https://github.com/open-webui/open-webui/commit/aa6f63a335e172fec1dc94b2056541f52c1167a6), [Commit](https://github.com/open-webui/open-webui/commit/612a52d7bb7dbe9fa0bbbc8ac0a552d2b9801146), [Commit](https://github.com/open-webui/open-webui/commit/b03529b006f3148e895b1094584e1ab129ecac5b) +- 🛠️ Tool user valve configuration interface was added to the integrations menu, displaying clickable gear icon buttons with tooltips for tools that support user-specific settings, making personal tool configurations easily accessible. [Commit](https://github.com/open-webui/open-webui/commit/27d61307cdce97ed11a05ec13fc300249d6022cd) +- 👥 Channel access control was enhanced to require write permissions for posting, editing, and deleting messages, while read-only users can view content but cannot contribute. [#17543](https://github.com/open-webui/open-webui/pull/17543) +- 💬 Channel models now support image processing, allowing AI assistants to view and analyze images shared in conversation threads. [Commit](https://github.com/open-webui/open-webui/commit/9f0010e234a6f40782a66021435d3c02b9c23639) +- 🌐 Attach Webpage button was added to the message input menu, providing a user-friendly modal interface for attaching web content and YouTube videos as an alternative to the existing URL syntax. [#17534](https://github.com/open-webui/open-webui/pull/17534) +- 🔐 Redis session storage support was added for OAuth redirects, providing better state handling in multi-pod Kubernetes deployments and resolving CSRF mismatch errors. [#17223](https://github.com/open-webui/open-webui/pull/17223), [#15373](https://github.com/open-webui/open-webui/issues/15373) +- 🔍 Ollama Cloud web search integration was added as a new search engine option, providing access to web search functionality through Ollama's cloud infrastructure. 
[Commit](https://github.com/open-webui/open-webui/commit/e06489d92baca095b8f376fbef223298c7772579), [Commit](https://github.com/open-webui/open-webui/commit/4b6d34438bcfc45463dc7a9cb984794b32c1f0a1), [Commit](https://github.com/open-webui/open-webui/commit/05c46008da85357dc6890b846789dfaa59f4a520), [Commit](https://github.com/open-webui/open-webui/commit/fe65fe0b97ec5a8fff71592ff04a25c8e123d108), [Docs:#708](https://github.com/open-webui/docs/pull/708) +- 🔍 Perplexity Websearch API integration was added as a new search engine option, providing access to the new websearch functionality provided by Perplexity. [#17756](https://github.com/open-webui/open-webui/issues/17756), [Commit](https://github.com/open-webui/open-webui/pull/17747/commits/7f411dd5cc1c29733216f79e99eeeed0406a2afe) +- ☁️ OneDrive integration was improved to support separate client IDs for personal and business authentication, enabling both integrations to work simultaneously. [#17619](https://github.com/open-webui/open-webui/pull/17619), [Docs](https://docs.openwebui.com/tutorials/integrations/onedrive-sharepoint), [Docs](https://docs.openwebui.com/getting-started/env-configuration/#onedrive) +- 📝 Pending user overlay content now supports markdown formatting, enabling rich text display for custom messages similar to banner functionality. [#17681](https://github.com/open-webui/open-webui/pull/17681) +- 🎨 Image generation model selection was centralized to enable dynamic model override in function calls, allowing pipes and tools to specify different models than the global default while maintaining backward compatibility. [#17689](https://github.com/open-webui/open-webui/pull/17689) +- 🎨 Interface design was modernized with updated visual styling, improved spacing, and refined component layouts across modals, sidebar, settings, and navigation elements. 
[Commit](https://github.com/open-webui/open-webui/commit/27a91cc80a24bda0a3a188bc3120a8ab57b00881), [Commit](https://github.com/open-webui/open-webui/commit/4ad743098615f9c58daa9df392f31109aeceeb16), [Commit](https://github.com/open-webui/open-webui/commit/fd7385c3921eb59af76a26f4c475aedb38ce2406) +- 📊 Notes query performance was optimized through database-level filtering and separated access control logic, reducing memory usage and eliminating N+1 query problems for better scalability. [#17607](https://github.com/open-webui/open-webui/pull/17607), [Commit](https://github.com/open-webui/open-webui/pull/17747/commits/da661756fa7eec754270e6dd8c67cbf74a28a17f) +- ⚡ Page loading performance was optimized by deferring API requests until components are actually opened, including ChangelogModal, ModelSelector, RecursiveFolder, ArchivedChatsModal, and SearchModal. [#17542](https://github.com/open-webui/open-webui/pull/17542), [#17555](https://github.com/open-webui/open-webui/pull/17555), [#17557](https://github.com/open-webui/open-webui/pull/17557), [#17541](https://github.com/open-webui/open-webui/pull/17541), [#17640](https://github.com/open-webui/open-webui/pull/17640) +- ⚡ Bundle size was reduced by 1.58MB through optimized highlight.js language support, improving page loading speed and reducing bandwidth usage. [#17645](https://github.com/open-webui/open-webui/pull/17645) +- ⚡ Editor collaboration functionality was refactored to reduce package size by 390KB and minimize compilation errors, improving build performance and reliability. [#17593](https://github.com/open-webui/open-webui/pull/17593) +- ♿ Enhanced user interface accessibility through the addition of unique element IDs, improving targeting for testing, styling, and assistive technologies while providing better semantic markup for screen readers and accessibility tools.
[#17746](https://github.com/open-webui/open-webui/pull/17746) +- 🔄 Various improvements were implemented across the frontend and backend to enhance performance, stability, and security. +- 🌐 Translations for Portuguese (Brazil), Chinese (Simplified and Traditional), Korean, Irish, Spanish, Finnish, French, Kabyle, Russian, and Catalan were enhanced and improved. + +### Fixed + +- 🛡️ SVG content security was enhanced by implementing DOMPurify sanitization to prevent XSS attacks through malicious SVG elements, ensuring safe rendering of user-generated SVG content. [Commit](https://github.com/open-webui/open-webui/pull/17747/commits/750a659a9fee7687e667d9d755e17b8a0c77d557) +- ☁️ OneDrive attachment menu rendering issues were resolved by restructuring the submenu interface from dropdown to tabbed navigation, preventing menu items from being hidden or clipped due to overflow constraints. [#17554](https://github.com/open-webui/open-webui/issues/17554), [Commit](https://github.com/open-webui/open-webui/pull/17747/commits/90e4b49b881b644465831cc3028bb44f0f7a2196) +- 💬 Attached conversation references now persist throughout the entire chat session, ensuring models can continue querying referenced conversations after multiple conversation turns. [#17750](https://github.com/open-webui/open-webui/issues/17750) +- 🔍 Search modal text box focus issues after pinning or unpinning chats were resolved, allowing users to properly exit the search interface by clicking outside the text box. [#17743](https://github.com/open-webui/open-webui/issues/17743) +- 🔍 Search function chat list is now properly updated in real-time when chats are created or deleted, eliminating stale search results and preview loading failures. [#17741](https://github.com/open-webui/open-webui/issues/17741) +- 💬 Chat jitter and delayed code block expansion in multi-model sessions were resolved by reverting dynamic CodeEditor loading, restoring stable rendering behavior. 
[#17715](https://github.com/open-webui/open-webui/pull/17715), [#17684](https://github.com/open-webui/open-webui/issues/17684) +- 📎 File upload handling was improved to properly recognize uploaded files even when no accompanying text message is provided, resolving issues where attachments were ignored in custom prompts. [#17492](https://github.com/open-webui/open-webui/issues/17492) +- 💬 Chat conversation referencing within projects was restored by including foldered chats in the reference menu, allowing users to properly quote conversations from within their project scope. [#17530](https://github.com/open-webui/open-webui/issues/17530) +- 🔍 RAG query generation is now skipped when all attached files are set to full context mode, preventing unnecessary retrieval operations and improving system efficiency. [#17744](https://github.com/open-webui/open-webui/pull/17744) +- 💾 Memory leaks in file handling and HTTP connections are prevented through proper resource cleanup, ensuring stable memory usage during large file downloads and processing operations. [#17608](https://github.com/open-webui/open-webui/pull/17608) +- 🔐 OAuth access token refresh errors are resolved by properly implementing async/await patterns, preventing "coroutine object has no attribute get" failures during token expiry. [#17585](https://github.com/open-webui/open-webui/issues/17585), [#17678](https://github.com/open-webui/open-webui/issues/17678) +- ⚙️ Valve behavior was improved to properly handle default values and array types, ensuring only explicitly set values are persisted while maintaining consistent distinction between custom and default valve states. [#17664](https://github.com/open-webui/open-webui/pull/17664) +- 🔍 Hybrid search functionality was enhanced to handle inconsistent parameter types and prevent failures when collection results are None, empty, or in unexpected formats. 
[#17617](https://github.com/open-webui/open-webui/pull/17617) +- 📁 Empty folder deletion is now allowed regardless of chat deletion permission restrictions, resolving cases where users couldn't remove folders after deleting all contained chats. [#17683](https://github.com/open-webui/open-webui/pull/17683) +- 📝 Rich text editor console errors were resolved by adding proper error handling when the TipTap editor view is not available or not yet mounted. [#17697](https://github.com/open-webui/open-webui/issues/17697) +- 🗒️ Hidden models are now properly excluded from the notes section dropdown and default model selection, preventing users from accessing models they shouldn't see. [#17722](https://github.com/open-webui/open-webui/pull/17722) +- 🖼️ AI-generated image download filenames now use a clean, translatable "Generated Image" format instead of potentially problematic response text, improving file management and compatibility. [#17721](https://github.com/open-webui/open-webui/pull/17721) +- 🎨 Toggle switch display issues in the Integrations interface are fixed, preventing background highlighting and obscuring on hover. [#17564](https://github.com/open-webui/open-webui/issues/17564) + +### Changed + +- 👥 Channel permissions now require write access for message posting, editing, and deletion, with existing user groups defaulting to read-only access and requiring manual admin migration to write permissions for full participation. +- ☁️ OneDrive environment variable configuration was updated to use separate ONEDRIVE_CLIENT_ID_PERSONAL and ONEDRIVE_CLIENT_ID_BUSINESS variables for better client ID separation, while maintaining backward compatibility with the legacy ONEDRIVE_CLIENT_ID variable. 
[Docs](https://docs.openwebui.com/tutorials/integrations/onedrive-sharepoint), [Docs](https://docs.openwebui.com/getting-started/env-configuration/#onedrive) + +## [0.6.30] - 2025-09-17 + +### Added + +- 🔑 Microsoft Entra ID authentication type support was added for Azure OpenAI connections, enabling enhanced security and streamlined authentication workflows. + +### Fixed + +- ☁️ OneDrive integration was fixed after recent breakage, restoring reliable account connectivity and file access. + +## [0.6.29] - 2025-09-17 + +### Added + +- 🎨 The chat input menu has been completely overhauled with a revolutionary new design, consolidating attachments under a unified '+' button, organizing integrations into a streamlined options menu, and introducing powerful, interactive selectors for attaching chats, notes, and knowledge base items. [Commit](https://github.com/open-webui/open-webui/commit/a68342d5a887e36695e21f8c2aec593b159654ff), [Commit](https://github.com/open-webui/open-webui/commit/96b8aaf83ff341fef432649366bc5155bac6cf20), [Commit](https://github.com/open-webui/open-webui/commit/4977e6d50f7b931372c96dd5979ca635d58aeb78), [Commit](https://github.com/open-webui/open-webui/commit/d973db829f7ec98b8f8fe7d3b2822d588e79f94e), [Commit](https://github.com/open-webui/open-webui/commit/d4c628de09654df76653ad9bce9cb3263e2f27c8), [Commit](https://github.com/open-webui/open-webui/commit/cd740f436db4ea308dbede14ef7ff56e8126f51b), [Commit](https://github.com/open-webui/open-webui/commit/5c2db102d06b5c18beb248d795682ff422e9b6d1), [Commit](https://github.com/open-webui/open-webui/commit/031cf38655a1a2973194d2eaa0fbbd17aca8ee92), [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/3ed0a6d11fea1a054e0bc8aa8dfbe417c7c53e51), [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/eadec9e86e01bc8f9fb90dfe7a7ae4fc3bfa6420), [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/c03ca7270e64e3a002d321237160c0ddaf2bb129), 
[Commit](https://github.com/open-webui/open-webui/pull/17420/commits/b53ddfbd19aa94e9cbf7210acb31c3cfafafa5fe), [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/c923461882fcde30ae297a95e91176c95b9b72e1) +- 🤖 AI models can now be mentioned in channels to automatically generate responses, enabling multi-model conversations where mentioned models participate directly in threaded discussions with full context awareness. [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/4fe97d8794ee18e087790caab9e5d82886006145) +- 💬 The Channels feature now utilizes the modern rich text editor, including support for '/', '@', and '#' command suggestions. [Commit](https://github.com/open-webui/open-webui/commit/06c1426e14ac0dfaf723485dbbc9723a4d89aba9), [Commit](https://github.com/open-webui/open-webui/commit/02f7c3258b62970ce79716f75d15467a96565054) +- 📎 Channel message input now supports direct paste functionality for images and files from the clipboard, streamlining content sharing workflows. [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/6549fc839f86c40c26c2ef4dedcaf763a9304418) +- ⚙️ Models can now be configured with default features (Web Search, Image Generation) and filters that automatically activate when a user selects the model. [Commit](https://github.com/open-webui/open-webui/commit/9a555478273355a5177bfc7f7211c64778e4c8de), [Commit](https://github.com/open-webui/open-webui/commit/384a53b339820068e92f7eaea0d9f3e0536c19c2), [Commit](https://github.com/open-webui/open-webui/commit/d7f43bfc1a30c065def8c50d77c2579c1a3c5c67), [Commit](https://github.com/open-webui/open-webui/commit/6a67a2217cc5946ad771e479e3a37ac213210748) +- 💬 The ability to reference other chats as context within a conversation was added via the attachment menu. 
[Commit](https://github.com/open-webui/open-webui/commit/e097bbdf11ae4975c622e086df00d054291cdeb3), [Commit](https://github.com/open-webui/open-webui/commit/f3cd2ffb18e7dedbe88430f9ae7caa6b3cfd79d0), [Commit](https://github.com/open-webui/open-webui/commit/74263c872c5d574a9bb0944d7984f748dc772dba), [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/aa8ab349ed2fcb46d1cf994b9c0de2ec2ea35d0d), [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/025eef754f0d46789981defd473d001e3b1d0ca2) +- 🎨 The command suggestion UI for prompts ('/'), models ('@'), and knowledge ('#') was completely overhauled with a more responsive and keyboard-navigable interface. [Commit](https://github.com/open-webui/open-webui/commit/6b69c4da0fb9329ccf7024483960e070cf52ccab), [Commit](https://github.com/open-webui/open-webui/commit/06a6855f844456eceaa4d410c93379460e208202), [Commit](https://github.com/open-webui/open-webui/commit/c55f5578280b936cf581a743df3703e3db1afd54), [Commit](https://github.com/open-webui/open-webui/commit/f68d1ba394d4423d369f827894cde99d760b2402) +- 👥 User and channel suggestions were added to the mention system, enabling '@' mentions for users and models, and '#' mentions for channels with searchable user lookup and clickable navigation. [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/bbd1d2b58c89b35daea234f1fc9208f2af840899), [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/aef1e06f0bb72065a25579c982dd49157e320268), [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/779db74d7e9b7b00d099b7d65cfbc8a831e74690) +- 📁 Folder functionality was enhanced with custom background image support, improved drag-and-drop capabilities for moving folders to root level, and better menu interactions. 
[Commit](https://github.com/open-webui/open-webui/pull/17420/commits/2a234829f5dfdfde27fdfd30591caa908340efb4), [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/2b1ee8b0dc5f7c0caaafdd218f20705059fa72e2), [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/b1e5bc8e490745f701909c19b6a444b67c04660e), [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/3e584132686372dfeef187596a7c557aa5f48308) +- ☁️ OneDrive integration configuration now supports selecting between personal and work/school account types via ENABLE_ONEDRIVE_PERSONAL and ENABLE_ONEDRIVE_BUSINESS environment variables. [#17354](https://github.com/open-webui/open-webui/pull/17354), [Commit](https://github.com/open-webui/open-webui/commit/e1e3009a30f9808ce06582d81a60e391f5ca09ec), [Docs:#697](https://github.com/open-webui/docs/pull/697) +- ⚡ Mermaid.js is now dynamically loaded on demand, significantly reducing first-screen loading time and improving initial page performance. [#17476](https://github.com/open-webui/open-webui/issues/17476), [#17477](https://github.com/open-webui/open-webui/pull/17477) +- ⚡ Azure MSAL browser library is now dynamically loaded on demand, reducing initial bundle size by 730KB and improving first-screen loading speed. [#17479](https://github.com/open-webui/open-webui/pull/17479) +- ⚡ CodeEditor component is now dynamically loaded on demand, reducing initial bundle size by 1MB and improving first-screen loading speed. [#17498](https://github.com/open-webui/open-webui/pull/17498) +- ⚡ Hugging Face Transformers library is now dynamically loaded on demand, reducing initial bundle size by 1.9MB and improving first-screen loading speed. [#17499](https://github.com/open-webui/open-webui/pull/17499) +- ⚡ jsPDF and html2canvas-pro libraries are now dynamically loaded on demand, reducing initial bundle size by 980KB and improving first-screen loading speed. 
[#17502](https://github.com/open-webui/open-webui/pull/17502) +- ⚡ Leaflet mapping library is now dynamically loaded on demand, reducing initial bundle size by 454KB and improving first-screen loading speed. [#17503](https://github.com/open-webui/open-webui/pull/17503) +- 📊 OpenTelemetry metrics collection was enhanced to properly handle HTTP 500 errors and ensure metrics are recorded even during exceptions. [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/b14617a653c6bdcfd3102c12f971924fd1faf572) +- 🔒 OAuth token retrieval logic was refactored, improving the reliability and consistency of authentication handling across the backend. [Commit](https://github.com/open-webui/open-webui/commit/6c0a5fa91cdbf6ffb74667ee61ca96bebfdfbc50) +- 💻 Code block output processing was improved to handle Python execution results more reliably, along with refined visual styling and button layouts. [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/0e5320c39e308ff97f2ca9e289618af12479eb6e) +- ⚡ Message input processing was optimized to skip unnecessary text variable handling when input is empty, improving performance. [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/e1386fe80b77126a12dabc4ad058abe9b024b275) +- 📄 Individual chat PDF export was added to the sidebar chat menu, allowing users to export single conversations as PDF documents with both stylized and plain text options. [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/d041d58bb619689cd04a391b4f8191b23941ca62) +- 🛠️ Function validation was enhanced with improved valve validation and better error handling during function loading and synchronization. 
[Commit](https://github.com/open-webui/open-webui/pull/17420/commits/e66e0526ed6a116323285f79f44237538b6c75e6), [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/8edfd29102e0a61777b23d3575eaa30be37b59a5) +- 🔔 Notification toast interaction was enhanced with drag detection to prevent accidental clicks and added keyboard support for accessibility. [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/621e7679c427b6f0efa85f95235319238bf171ad) +- 🗓️ Improved date and time formatting dynamically adapts to the selected language, ensuring consistent localization across the UI. [#17409](https://github.com/open-webui/open-webui/pull/17409), [Commit](https://github.com/open-webui/open-webui/commit/2227f24bd6d861b1fad8d2cabacf7d62ce137d0c) +- 🔒 Feishu SSO integration was added, allowing users to authenticate via Feishu. [#17284](https://github.com/open-webui/open-webui/pull/17284), [Docs:#685](https://github.com/open-webui/docs/pull/685) +- 🔠 Toggle filters in the chat input options menu are now sorted alphabetically for easier navigation. [Commit](https://github.com/open-webui/open-webui/commit/ca853ca4656180487afcd84230d214f91db52533) +- 🎨 Long chat titles in the sidebar are now truncated to prevent text overflow and maintain a clean layout. [#17356](https://github.com/open-webui/open-webui/pull/17356) +- 🎨 Temporary chat interface design was refined with improved layout and visual consistency. [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/67549dcadd670285d491bd41daf3d081a70fd094), [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/2ca34217e68f3b439899c75881dfb050f49c9eb2), [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/fb02ec52a5df3f58b53db4ab3a995c15f83503cd) +- 🎨 Download icon consistency was improved across the entire interface by standardizing the icon component used in menus, functions, tools, and export features. 
[Commit](https://github.com/open-webui/open-webui/pull/17420/commits/596be451ece7e11b5cd25465d49670c27a1cb33f) +- 🎨 Settings interface was enhanced with improved iconography, and the 'Chats' section was reorganized into 'Data Controls' for better clarity. [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/8bf0b40fdd978b5af6548a6e1fb3aabd90bcd5cd) +- 🔄 Various improvements were implemented across the frontend and backend to enhance performance, stability, and security. +- 🌐 Translations for Finnish, German, Kabyle, Portuguese (Brazil), Simplified Chinese, Spanish (Spain), and Traditional Chinese (Taiwan) were enhanced and expanded. + +### Fixed + +- 📚 Knowledge base permission logic was corrected to ensure private collection owners can access their own content when embedding bypass is enabled. [#17432](https://github.com/open-webui/open-webui/issues/17432), [Commit](https://github.com/open-webui/open-webui/commit/a51f0c30ec1472d71487eab3e15d0351a2716b12) +- ⚙️ Connection URL editing in Admin Settings now properly saves changes instead of reverting to original values, fixing issues with both Ollama and OpenAI-compatible endpoints. [#17435](https://github.com/open-webui/open-webui/issues/17435), [Commit](https://github.com/open-webui/open-webui/commit/e4c864de7eb0d577843a80688677ce3659d1f81f) +- 📊 Usage information collection from Google models was corrected to handle providers that send usage data alongside content chunks instead of separately. [#17421](https://github.com/open-webui/open-webui/pull/17421), [Commit](https://github.com/open-webui/open-webui/commit/c2f98a4cd29ed738f395fef09c42ab8e73cd46a0) +- ⚙️ Settings modal scrolling issue was resolved by moving image compression controls to a dedicated modal, preventing the main settings from scrolling out of view. 
[#17474](https://github.com/open-webui/open-webui/issues/17474), [Commit](https://github.com/open-webui/open-webui/commit/fed5615c19b0045a55b0be426b468a57bfda4b66) +- 📁 Folder click behavior was improved to prevent accidental actions by implementing proper double-click detection and timing delays for folder expansion and selection. [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/19e3214997170eea6ee92452e8c778e04a28e396) +- 🔐 Access control component reliability was improved with better null checking and error handling for group permissions and private access scenarios. [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/c8780a7f934c5e49a21b438f2f30232f83cf75d2), [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/32015c392dbc6b7367a6a91d9e173e675ea3402c) +- 🔗 The citation modal now correctly displays and links to external web page sources in addition to internal documents. [Commit](https://github.com/open-webui/open-webui/commit/9208a84185a7e59524f00a7576667d493c3ac7d4) +- 🔗 Web and YouTube attachment handling was fixed, ensuring their content is now reliably processed and included in the chat context for retrieval. [Commit](https://github.com/open-webui/open-webui/commit/210197fd438b52080cda5d6ce3d47b92cdc264c8) +- 📂 Large file upload failures are resolved by correcting the processing logic for scenarios where document embedding is bypassed. [Commit](https://github.com/open-webui/open-webui/commit/051b6daa8299fd332503bd584563556e2ae6adab) +- 🌐 Rich text input placeholder text now correctly updates when the interface language is switched, ensuring proper localization. [#17473](https://github.com/open-webui/open-webui/pull/17473), [Commit](https://github.com/open-webui/open-webui/commit/77358031f5077e6efe5cc08d8d4e5831c7cd1cd9) +- 📊 Llama.cpp server timing metrics are now correctly parsed and displayed by fixing a typo in the response handling. 
[#17350](https://github.com/open-webui/open-webui/issues/17350), [Commit](https://github.com/open-webui/open-webui/commit/cf72f5503f39834b9da44ebbb426a3674dad0caa) +- 🛠️ Filter functions with file_handler configuration now properly handle messages without file attachments, preventing runtime errors. [#17423](https://github.com/open-webui/open-webui/pull/17423) +- 🔔 Channel notification delivery was fixed to properly handle background task execution and user access checking. [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/1077b2ac8b96e49c2ad2620e76eb65bbb2a3a1f3) + +### Changed + +- 📝 Prompt template variables are now optional by default instead of being forced as required, allowing flexible workflows with optional metadata fields. [#17447](https://github.com/open-webui/open-webui/issues/17447), [Commit](https://github.com/open-webui/open-webui/commit/d5824b1b495fcf86e57171769bcec2a0f698b070), [Docs:#696](https://github.com/open-webui/docs/pull/696) +- 🛠️ Direct external tool servers now require explicit user selection from the input interface instead of being automatically included in conversations, providing better control over tool usage. [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/0f04227c34ca32746c43a9323e2df32299fcb6af), [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/99bba12de279dd55c55ded35b2e4f819af1c9ab5) +- 📺 Widescreen mode option was removed from Channels interface, with all channel layouts now using full-width display. [Commit](https://github.com/open-webui/open-webui/pull/17420/commits/d46b7b8f1b99a8054b55031fe935c8a16d5ec956) +- 🎛️ The plain textarea input option was deprecated, and the custom text editor is now the standard for all chat inputs. 
[Commit](https://github.com/open-webui/open-webui/commit/153afd832ccd12a1e5fd99b085008d080872c161) + +## [0.6.28] - 2025-09-10 + +### Added + +- 🔍 The "@" command for model selection now supports real-time search and filtering, improving usability and aligning its behavior with other input commands. [#17307](https://github.com/open-webui/open-webui/issues/17307), [Commit](https://github.com/open-webui/open-webui/commit/f2a09c71499489ee71599af4a179e7518aaf658b) +- 🛠️ External tool server data handling is now more robust, automatically attempting to parse specifications as JSON before falling back to YAML, regardless of the URL extension. [Commit](https://github.com/open-webui/open-webui/commit/774c0056bde88ed4831422efa81506488e3d6641) +- 🎯 The "Title" field is now automatically focused when creating a new chat folder, streamlining the folder creation process. [#17315](https://github.com/open-webui/open-webui/issues/17315), [Commit](https://github.com/open-webui/open-webui/commit/c51a651a2d5e2a27546416666812e9b92205562d) +- 🔄 Various improvements were implemented across the frontend and backend to enhance performance, stability, and security. +- 🌐 Brazilian Portuguese and Simplified Chinese translations were expanded and refined. + +### Fixed + +- 🔊 A regression affecting Text-to-Speech for local providers using the OpenAI engine was fixed by reverting a URL joining change. [#17316](https://github.com/open-webui/open-webui/issues/17316), [Commit](https://github.com/open-webui/open-webui/commit/8339f59cdfc63f2d58c8e26933d1bf1438479d75) +- 🪧 A regression was fixed where the input modal for prompts with placeholders would not open, causing the raw prompt text to be pasted into the chat input field instead. 
[#17325](https://github.com/open-webui/open-webui/issues/17325), [Commit](https://github.com/open-webui/open-webui/commit/d5cb65527eaa4831459a4c7dbf187daa9c0525ae) +- 🔑 An issue was resolved where modified connection keys in the OpenAIConnection component did not take effect. [#17324](https://github.com/open-webui/open-webui/pull/17324) + +## [0.6.27] - 2025-09-09 + +### Added + +- 📁 Emoji folder icons were added, allowing users to personalize workspace organization with visual cues, including improved chevron display. [Commit](https://github.com/open-webui/open-webui/pull/17070/commits/1588f42fe777ad5d807e3f2fc8dbbc47a8db87c0), [Commit](https://github.com/open-webui/open-webui/pull/17070/commits/b70c0f36c0f5bbfc2a767429984d6fba1a7bb26c), [Commit](https://github.com/open-webui/open-webui/pull/17070/commits/11dea8795bfce42aa5d8d58ef316ded05173bd87), [Commit](https://github.com/open-webui/open-webui/pull/17070/commits/c0a47169fa059154d5f5a9ea6b94f9a66d82f255) +- 📁 The 'Search Collection' input field now dynamically displays the total number of files within the knowledge base. [Commit](https://github.com/open-webui/open-webui/pull/17070/commits/fbbe1117ae4c9c8fec6499d790eee275818eccc5) +- ☁️ A provider toggle in connection settings now allows users to manually specify Azure OpenAI deployments. [Commit](https://github.com/open-webui/open-webui/pull/17070/commits/5bdd334b74fbd154085f2d590f4afdba32469c8a) +- ⚡ Model list caching performance was optimized by fixing cache key generation to reduce redundant API calls. [#17158](https://github.com/open-webui/open-webui/pull/17158) +- 🎨 Azure OpenAI image generation is now supported, with configurations for IMAGES_OPENAI_API_VERSION via environment variable and admin UI. 
[#17147](https://github.com/open-webui/open-webui/pull/17147), [#16274](https://github.com/open-webui/open-webui/discussions/16274), [Docs:#679](https://github.com/open-webui/docs/pull/679) +- ⚡ Comprehensive N+1 query performance is optimized by reducing database queries from 1+N to 1+1 patterns across major listing endpoints. [#17165](https://github.com/open-webui/open-webui/pull/17165), [#17160](https://github.com/open-webui/open-webui/pull/17160), [#17161](https://github.com/open-webui/open-webui/pull/17161), [#17162](https://github.com/open-webui/open-webui/pull/17162), [#17159](https://github.com/open-webui/open-webui/pull/17159), [#17166](https://github.com/open-webui/open-webui/pull/17166) +- ⚡ The PDF.js library is now dynamically loaded, significantly reducing initial page load size and improving responsiveness. [#17222](https://github.com/open-webui/open-webui/pull/17222) +- ⚡ The heic2any library is now dynamically loaded across various message input components, including channels, for faster page loads. [#17225](https://github.com/open-webui/open-webui/pull/17225), [#17229](https://github.com/open-webui/open-webui/pull/17229) +- 📚 The knowledge API now supports a "delete_file" query parameter, allowing configurable file deletion behavior. [Commit](https://github.com/open-webui/open-webui/pull/17070/commits/22c4ef4fb096498066b73befe993ae3a82f7a8e7) +- 📊 Llama.cpp timing statistics are now integrated into the usage field for comprehensive model performance metrics. [Commit](https://github.com/open-webui/open-webui/pull/17070/commits/e830b4959ecd4b2795e29e53026984a58a7696a9) +- 🗄️ The PGVECTOR_CREATE_EXTENSION environment variable now allows control over automatic pgvector extension creation. 
[Commit](https://github.com/open-webui/open-webui/pull/17070/commits/c2b4976c82d335ed524bd80dc914b5e2f5bfbd9e), [Commit](https://github.com/open-webui/open-webui/pull/17070/commits/b45219c8b15b48d5ee3d42983e1107bbcefbab01), [Docs:#672](https://github.com/open-webui/docs/pull/672) +- 🔒 Comprehensive server-side OAuth token management was implemented, securely storing encrypted tokens in a new database table and introducing an automatic refresh mechanism, enabling seamless and secure forwarding of valid user-specific OAuth tokens to downstream services, including OpenAI-compatible endpoints and external tool servers via the new "system_oauth" authentication type, resolving long-standing issues such as large token size limitations, stale/expired tokens, and reliable token propagation, and enhancing overall security by minimizing client-side token exposure, configurable via "ENABLE_OAUTH_ID_TOKEN_COOKIE" and "OAUTH_SESSION_TOKEN_ENCRYPTION_KEY" environment variables. [Docs:#683](https://github.com/open-webui/docs/pull/683), [#17210](https://github.com/open-webui/open-webui/pull/17210), [#8957](https://github.com/open-webui/open-webui/discussions/8957), [#11029](https://github.com/open-webui/open-webui/discussions/11029), [#17178](https://github.com/open-webui/open-webui/issues/17178), [#17183](https://github.com/open-webui/open-webui/issues/17183), [Commit](https://github.com/open-webui/open-webui/commit/217f4daef09b36d3d4cc4681e11d3ebd9984a1a5), [Commit](https://github.com/open-webui/open-webui/commit/fc11e4384fe98fac659e10596f67c23483578867), [Commit](https://github.com/open-webui/open-webui/commit/f11bdc6ab5dd5682bb3e27166e77581f5b8af3e0), [Commit](https://github.com/open-webui/open-webui/commit/f71834720e623761d972d4d740e9bbd90a3a86c6), [Commit](https://github.com/open-webui/open-webui/commit/b5bb6ae177dcdc4e8274d7e5ffa50bc8099fd466), [Commit](https://github.com/open-webui/open-webui/commit/b786d1e3f3308ef4f0f95d7130ddbcaaca4fc927), 
[Commit](https://github.com/open-webui/open-webui/commit/8a9f8627017bd0a74cbd647891552b26e56aabb7), [Commit](https://github.com/open-webui/open-webui/commit/30d1dc2c60e303756120fe1c5538968c4e6139f4), [Commit](https://github.com/open-webui/open-webui/commit/2b2d123531eb3f42c0e940593832a64e2806240d), [Commit](https://github.com/open-webui/open-webui/commit/6f6412dd16c63c2bb4df79a96b814bf69cb3f880) +- 🔒 Conditional Permission Hardening for OpenShift Deployments: Added a build argument to enable optional permission hardening for OpenShift and container environments. [Commit](https://github.com/open-webui/open-webui/pull/17070/commits/0ebe4f8f8490451ac8e85a4846f010854d9b54e5) +- 👥 Regex pattern support is added for OAuth blocked groups, allowing more flexible group filtering rules. [Commit](https://github.com/open-webui/open-webui/pull/17070/commits/df66e21472646648d008ebb22b0e8d5424d491df) +- 💬 Web search result display was enhanced to include titles and favicons, providing a clearer overview of search sources. 
[Commit](https://github.com/open-webui/open-webui/pull/17070/commits/33f04a771455e3fabf8f0e8ebb994ae7f41b8ed4), [Commit](https://github.com/open-webui/open-webui/pull/17070/commits/0a85dd4bca23022729eafdbc82c8c139fa365af2), [Commit](https://github.com/open-webui/open-webui/pull/17070/commits/16090bc2721fde492afa2c4af5927e2b668527e1), [#17197](https://github.com/open-webui/open-webui/pull/17197), [#14179](https://github.com/open-webui/open-webui/issues/14179), [Commit](https://github.com/open-webui/open-webui/pull/17070/commits/1cdb7aed1ee9bf81f2fd0404be52dcfa64f8ed4f), [Commit](https://github.com/open-webui/open-webui/pull/17070/commits/f2525ebc447c008cf7269ef20ce04fa456f302c4), [Commit](https://github.com/open-webui/open-webui/pull/17070/commits/7f523de408ede4075349d8de71ae0214b7e1a62e), [Commit](https://github.com/open-webui/open-webui/pull/17070/commits/3d37e4a42d344051ae715ab59bd7b5718e46c343), [Commit](https://github.com/open-webui/open-webui/pull/17070/commits/cd5e2be27b613314aadda6107089331783987985), [Commit](https://github.com/open-webui/open-webui/pull/17070/commits/6dc0df247347aede2762fe2065cf30275fd137ae) +- 💬 A new setting was added to control whether clicking a suggested prompt automatically sends the message or only inserts the text. [#17192](https://github.com/open-webui/open-webui/issues/17192), [Commit](https://github.com/open-webui/open-webui/commit/e023a98f11fc52feb21e4065ec707cc98e50c7d3) +- 🔄 Various improvements were implemented across the frontend and backend to enhance performance, stability, and security. +- 🌐 Translations for Portuguese (Brazil), Simplified Chinese, Catalan, and Spanish were enhanced and expanded. + +### Fixed + +- 🔍 Hybrid search functionality now correctly handles lexical-semantic weight labels and avoids errors when BM25 weight is zero. 
[#17049](https://github.com/open-webui/open-webui/pull/17049), [#17046](https://github.com/open-webui/open-webui/issues/17046) +- 🛑 Task stopping errors are prevented by gracefully handling multiple stop requests for the same task. [#17195](https://github.com/open-webui/open-webui/pull/17195) +- 🐍 Code execution package detection precision is improved in Pyodide to prevent unnecessary package inclusions. [Commit](https://github.com/open-webui/open-webui/pull/17070/commits/bbe116795860a81a647d9567e0d9cb1950650095) +- 🛠️ Tool message format API compliance is fixed by ensuring content fields in tool call responses contain valid string values instead of null. [Commit](https://github.com/open-webui/open-webui/pull/17070/commits/37bf0087e5b8a324009c9d06b304027df351ea6b) +- 📱 Mobile app config API authentication now supports Authorization header token verification with cookie fallback for iOS and Android requests. [#17175](https://github.com/open-webui/open-webui/pull/17175) +- 💾 Knowledge file save race conditions are prevented by serializing API calls and adding an "isSaving" guard. [#17137](https://github.com/open-webui/open-webui/pull/17137), [Commit](https://github.com/open-webui/open-webui/pull/17070/commits/4ca936f0bf9813bee11ec8aea41d7e34fb6b16a9) +- 🔐 The SSO login button visibility is restored for OIDC PKCE authentication without a client secret. [#17012](https://github.com/open-webui/open-webui/pull/17012) +- 🔊 Text-to-Speech (TTS) API requests now use proper URL joining methods, ensuring reliable functionality regardless of trailing slashes in the base URL. [#17061](https://github.com/open-webui/open-webui/pull/17061) +- 🛡️ Admin account creation on Hugging Face Spaces now correctly detects the configured port, resolving issues with custom port deployments. [#17064](https://github.com/open-webui/open-webui/pull/17064) +- 📁 Unicode filename support is improved for external document loaders by properly URL-encoding filenames in HTTP headers. 
[#17013](https://github.com/open-webui/open-webui/pull/17013), [#17000](https://github.com/open-webui/open-webui/issues/17000) +- 🔗 Web page and YouTube attachments are now correctly processed by setting their type as "text" and using collection names for accurate content retrieval. [Commit](https://github.com/open-webui/open-webui/pull/17070/commits/487979859a6ffcfd60468f523822cdf838fbef5b) +- ✍️ Message input composition event handling is fixed to properly manage text input for multilingual users using Input Method Editors (IME). [#17085](https://github.com/open-webui/open-webui/pull/17085) +- 💬 Follow-up tooltip duplication is removed, streamlining the user interface and preventing visual clutter. [#17186](https://github.com/open-webui/open-webui/pull/17186) +- 🎨 Chat button text display is corrected by preventing clipping of descending characters and removing unnecessary capitalization. [#17191](https://github.com/open-webui/open-webui/pull/17191) +- 🧠 RAG Loop/Error with Gemma 3.1 2B Instruct is fixed by correctly unwrapping unexpected single-item list responses from models. [Commit](https://github.com/open-webui/open-webui/pull/17070/commits/1bc9711afd2b72cd07c4e539a83783868733767c), [#17213](https://github.com/open-webui/open-webui/issues/17213) +- 🖼️ HEIC conversion failures are resolved, improving robustness of image handling. [#17225](https://github.com/open-webui/open-webui/pull/17225) +- 📦 The slim Docker image size regression has been fixed by refining the build process to correctly exclude components when USE_SLIM=true. [#16997](https://github.com/open-webui/open-webui/issues/16997), [Commit](https://github.com/open-webui/open-webui/commit/be373e9fd42ac73b0302bdb487e16dbeae178b4e), [Commit](https://github.com/open-webui/open-webui/commit/0ebe4f8f8490451ac8e85a4846f010854d9b54e5) +- 📁 Knowledge base update validation errors are resolved, ensuring seamless management via UI or API. 
[#17244](https://github.com/open-webui/open-webui/issues/17244), [Commit](https://github.com/open-webui/open-webui/commit/9aac1489080a5c9441e89b1a56de0d3a672bc5fb) +- 🔐 Resolved a security issue where a global web search setting overrode model-specific restrictions, ensuring model-level settings are now correctly prioritized. [#17151](https://github.com/open-webui/open-webui/issues/17151), [Commit](https://github.com/open-webui/open-webui/commit/9368d0ac751ec3072d5a96712b80a9b20a642ce6) +- 🔐 OAuth redirect reliability is improved by robustly preserving the intended redirect path using session storage. [#17235](https://github.com/open-webui/open-webui/issues/17235), [Commit](https://github.com/open-webui/open-webui/pull/17070/commits/4f2b821088367da18374027919594365c7a3f459), [#15575](https://github.com/open-webui/open-webui/pull/15575), [Commit](https://github.com/open-webui/open-webui/pull/17070/commits/d9f97c832c556fae4b116759da0177bf4fe619de) +- 🔐 Fixed a security vulnerability where knowledge base access within chat folders persisted after permissions were revoked. [#17182](https://github.com/open-webui/open-webui/issues/17182), [Commit](https://github.com/open-webui/open-webui/commit/40e40d1dddf9ca937e99af41c8ca038dbc93a7e6) +- 🔒 OIDC access denied errors are now displayed as user-friendly toast notifications instead of raw JSON. [#17208](https://github.com/open-webui/open-webui/issues/17208), [Commit](https://github.com/open-webui/open-webui/commit/3d6d050ad82d360adc42d6e9f42e8faf8d13c9f4) +- 💬 Chat exception handling is enhanced to prevent system instability during message generation and ensure graceful error recovery. 
[Commit](https://github.com/open-webui/open-webui/pull/17070/commits/f56889c5c7f0cf1a501c05d35dfa614e4f8b6958) +- 🔒 Static asset authentication is improved by adding crossorigin="use-credentials" attributes to all link elements, enabling proper cookie forwarding for proxy environments and authenticated requests to favicon, manifest, and stylesheet resources. [#17280](https://github.com/open-webui/open-webui/pull/17280), [Commit](https://github.com/open-webui/open-webui/commit/f17d8b5d19e1a05df7d63f53e939c99772a59c1e) + +### Changed + +- 🛠️ Renamed "Tools" to "External Tools" across the UI for clearer distinction between built-in and external functionalities. [Commit](https://github.com/open-webui/open-webui/pull/17070/commits/0bca4e230ef276bec468889e3be036242ad11086f) +- 🛡️ Default permission validation for message regeneration and deletion actions is enhanced to provide more restrictive access controls, improving chat security and user data protection. [#17285](https://github.com/open-webui/open-webui/pull/17285) + +## [0.6.26] - 2025-08-28 + +### Added + +- 🛂 **Granular Chat Interaction Permissions**: Added fine-grained permission controls for individual chat actions including "Continue Response", "Regenerate Response", "Rate Response", and "Delete Messages". Administrators can now configure these permissions per user group or set system defaults via environment variables, providing enhanced security and governance by preventing potential system prompt leakage through response continuation and enabling precise control over user interactions with AI responses. +- 🧠 **Custom Reasoning Tags Configuration**: Added configurable reasoning tag detection for AI model responses, allowing administrators and users to customize how the system identifies and processes reasoning content. 
Users can now define custom reasoning tag pairs, use default tags like "think" and "reasoning", or disable reasoning detection entirely through the Advanced Parameters interface, providing enhanced control over AI thought process visibility. +- 📱 **Pull-to-Refresh Support**: Added pull-to-refresh functionality allowing users to easily refresh the interface by pulling down on the navbar area. This resolves timeout issues that occurred when temporarily switching away from the app during long AI response generations, eliminating the need to close and relaunch the PWA. +- 📁 **Configurable File Upload Processing Mode**: Added "process_in_background" query parameter to the file upload API endpoint, allowing clients to choose between asynchronous (default) and synchronous file processing. Setting "process_in_background=false" forces the upload request to wait until extraction and embedding complete, returning immediately usable files and simplifying integration for backend API consumers that prefer blocking calls over polling workflows. +- 🔐 **Azure Document Intelligence DefaultAzureCredential Support**: Added support for authenticating with Azure Document Intelligence using DefaultAzureCredential in addition to API key authentication, enabling seamless integration with Azure Entra ID and managed identity authentication for enterprise Azure environments. +- 🔐 **Authentication Bootstrapping Enhancements**: Added "ENABLE_INITIAL_ADMIN_SIGNUP" environment variable and "?form=true" URL parameter to enable initial admin user creation and forced login form display in SSO-only deployments. This resolves bootstrap issues where administrators couldn't create the first user when login forms were disabled, allowing proper initialization of SSO-configured deployments without requiring temporary configuration changes. +- ⚡ **Query Generation Caching**: Added "ENABLE_QUERIES_CACHE" environment variable to enable request-scoped caching of generated search queries.
When both web search and file retrieval are active, queries generated for web search are automatically reused for file retrieval, eliminating duplicate LLM API calls and reducing token usage and costs while maintaining search quality. +- 🔧 **Configurable Tool Call Retry Limit**: Added "CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES" environment variable to control the maximum number of sequential tool calls allowed before safety stopping a chat session. This replaces the previous hardcoded limit of 10, enabling administrators to configure higher limits for complex workflows requiring extensive tool interactions. +- 📦 **Slim Docker Image Variant**: Added new slim Docker image option via "USE_SLIM" build argument that excludes embedded AI models and Ollama installation, reducing image size by approximately 1GB. This variant enables faster image pulls and deployments for environments where AI models are managed externally, particularly beneficial for auto-scaling clusters and distributed deployments. +- 🗂️ **Shift-to-Delete Functionality for Workspace Prompts**: Added keyboard shortcut support for quick prompt deletion on the Workspace Prompts page. Hold Shift and hover over any prompt to reveal a trash icon for instant deletion, bringing consistent interaction patterns across all workspace sections (Models, Tools, Functions, and now Prompts) and streamlining prompt management workflows. +- ♿ **Accessibility Enhancements**: Enhanced user interface accessibility with improved keyboard navigation, ARIA labels, and screen reader compatibility across key platform components. +- 📄 **Optimized PDF Export for Smaller File Size**: PDF exports are now significantly optimized, producing much smaller files for faster downloads and easier sharing or archiving of your chats and documents. 
+- 📦 **Slimmed Default Install with Optional Full Dependencies**: Installing Open WebUI via pip now defaults to a slimmer package; PostgreSQL support is no longer included by default—simply use 'pip install open-webui[all]' to include all optional dependencies for full feature compatibility. +- 🔄 **General Backend Refactoring**: Implemented various backend improvements to enhance performance, stability, and security, ensuring a more resilient and reliable platform for all users. +- 🌐 **Localization & Internationalization Improvements**: Enhanced and expanded translations for Finnish, Spanish, Japanese, Polish, Portuguese (Brazil), and Chinese, including missing translations and typo corrections, providing a more natural and professional user experience for speakers of these languages across the entire interface. + +### Fixed + +- ⚠️ **Chat Error Feedback Restored**: Fixed an issue where various backend errors (tool server failures, API connection issues, malformed responses) would cause chats to load indefinitely without providing user feedback. The system now properly displays error messages when failures occur during chat generation, allowing users to understand issues and retry as needed instead of waiting indefinitely. +- 🖼️ **Image Generation Steps Setting Visibility Fixed**: Fixed a UI issue where the "Set Steps" configuration option was incorrectly displayed for OpenAI and Gemini image generation engines that don't support this parameter. The setting is now only visible for compatible engines like ComfyUI and Automatic1111, eliminating user confusion about non-functional configuration options. +- 📄 **Datalab Marker API Document Loader Fixed**: Fixed broken Datalab Marker API document loader functionality by correcting URL path handling for both hosted Datalab services and self-hosted Marker servers. 
Removed hardcoded "/marker" paths from the loader code and restored proper default URL structure, ensuring PDF and document processing works correctly with both deployment types. +- 📋 **Citation Error Handling Improved**: Fixed an issue where malformed citation or source objects from external tools, pipes, or filters would cause JavaScript errors and make the chat interface completely unresponsive. The system now gracefully handles missing or undefined citation properties, allowing conversations to load properly even when tools generate defective citation events. +- 👥 **Group User Add API Endpoint Fixed**: Fixed an issue where the "/api/v1/groups/id/{group_id}/users/add" API endpoint would accept requests without errors but fail to actually add users to groups. The system now properly initializes and deduplicates user ID lists, ensuring users are correctly added to and removed from groups via API calls. +- 🛠️ **External Tool Server Error Handling Improved**: Fixed an issue where unreachable or misconfigured external tool servers would cause JavaScript errors and prevent the interface from loading properly. The system now gracefully handles connection failures, displays appropriate error messages, and filters out inaccessible servers while maintaining full functionality for working connections. +- 📋 **Code Block Copy Button Content Fixed**: Fixed an issue where the copy button in code blocks would copy the original AI-generated code instead of any user-edited content, ensuring the copy function always captures the currently displayed code as modified by users. +- 📄 **PDF Export Content Mismatch Fixed**: Resolved an issue where exporting a PDF of one chat while viewing another chat would incorrectly generate the PDF using the currently viewed chat's content instead of the intended chat's content. 
Additionally optimized the PDF generation algorithm with improved canvas slicing, better memory management, and enhanced image quality, while removing the problematic PDF export option from individual chat menus to prevent further confusion. +- 🖱️ **Windows Sidebar Cursor Icon Corrected**: Fixed confusing cursor icons on Windows systems where sidebar toggle buttons displayed resize cursors (ew-resize) instead of appropriate pointer cursors. The sidebar buttons now show standard pointer cursors on Windows, eliminating user confusion about whether the buttons expand/collapse the sidebar or resize it. +- 📝 **Safari IME Composition Bug Fixed**: Resolved an issue where pressing Enter while composing Chinese text using Input Method Editors (IMEs) on Safari would prematurely send messages instead of completing text composition. The system now properly detects composition states and ignores keydown events that occur immediately after composition ends, ensuring smooth multilingual text input across all browsers. +- 🔍 **Hybrid Search Parameter Handling Fixed**: Fixed an issue where the "hybrid" parameter in collection query requests was not being properly evaluated, causing the system to ignore user-specified hybrid search preferences and only check global configuration. Additionally resolved a division by zero error that occurred in hybrid search when BM25Retriever was called with empty document lists, ensuring robust search functionality across all collection states. +- 💬 **RTL Text Orientation in Messages Fixed**: Fixed text alignment issues in user messages and AI responses for Right-to-Left languages, ensuring proper text direction based on user language settings. Code blocks now consistently use Left-to-Right orientation regardless of the user's language preference, maintaining code readability across all supported languages. 
+- 📁 **File Content Preview in Modal Restored**: Fixed an issue where clicking on uploaded files would display an empty preview modal, even when the files were successfully processed and available for AI context. File content now displays correctly in the preview modal, ensuring users can verify and review their uploaded documents as intended. +- 🌐 **Playwright Timeout Configuration Corrected**: Fixed an issue where Playwright timeout values were incorrectly converted from milliseconds to seconds with an additional 1000x multiplier, causing excessively long web loading timeouts. The timeout parameter now correctly uses the configured millisecond values as intended, ensuring responsive web search and document loading operations. + +### Changed + +- 🔄 **Follow-Up Question Language Constraint Removed**: Follow-up question suggestions no longer strictly adhere to the chat's primary language setting, allowing for more flexible and diverse suggestion generation that may include questions in different languages based on conversation context and relevance rather than enforced language matching. + +## [0.6.25] - 2025-08-22 + +### Fixed + +- 🖼️ **Image Generation Reliability Restored**: Fixed a key issue causing image generation failures. +- 🏆 **Reranking Functionality Restored**: Resolved errors with rerank feature. + +## [0.6.24] - 2025-08-21 + +### Added + +- ♿ **High Contrast Mode in Chat Messages**: Implemented enhanced High Contrast Mode support for chat messages, making text and important details easier to read and improving accessibility for users with visual preferences or requirements. +- 🌎 **Localization & Internationalization Improvements**: Enhanced and expanded translations for a more natural and professional user experience for speakers of these languages across the entire interface. 
+ +### Fixed + +- 🖼️ **ComfyUI Image Generation Restored**: Fixed a critical bug where ComfyUI-based image generation was not functioning, ensuring users can once again effortlessly create and interact with AI-generated visuals in their workflows. +- 🛠️ **Tool Server Loading and Visibility Restored**: Resolved an issue where connected tool servers were not loading or visible, restoring seamless integration and uninterrupted access to all external and custom tools directly within the platform. +- 🛡️ **Redis User Session Reliability**: Fixed a problem affecting the saving of user sessions in Redis, ensuring reliable login sessions, stable authentication, and secure multi-user environments. + +## [0.6.23] - 2025-08-21 + +### Added + +- ⚡ **Asynchronous Chat Payload Processing**: Refactored the chat completion pipeline to return a response immediately for streaming requests involving web search or tool calls. This enables users to stop ongoing generations promptly and prevents network timeouts during lengthy preprocessing phases, thus significantly improving user experience and responsiveness. +- 📁 **Asynchronous File Upload with Polling**: Implemented an asynchronous file upload process with frontend polling to resolve gateway timeouts and improve reliability when uploading large files. This ensures that even lengthy file processing, such as embedding or transcription, does not block the user interface or lead to connection timeouts, providing a smoother experience for all file operations. +- 📈 **Database Performance Indexes and Migration Script**: Introduced new database indexes on the "chat", "tag", and "function" tables to significantly enhance query performance for SQLite and PostgreSQL installations. For existing deployments, a new Alembic migration script is included to seamlessly apply these indexes, ensuring faster filtering and sorting operations across the platform.
+- ✨ **Enhanced Database Performance Options**: Introduced new configurable options to significantly improve database performance, especially for SQLite. This includes "DATABASE_ENABLE_SQLITE_WAL" to enable SQLite WAL (Write-Ahead Logging) mode for concurrent operations, and "DATABASE_DEDUPLICATE_INTERVAL" which, in conjunction with a new deduplication mechanism, reduces redundant updates to "user.last_active_at", minimizing write conflicts across all database types. +- 💾 **Save Temporary Chats Button**: Introduced a new 'Save Chat' button for conversations initiated in temporary mode. This allows users to permanently save valuable temporary conversations to their chat history, providing greater flexibility and ensuring important discussions are not lost. +- 📂 **Chat Movement Options in Menu**: Added the ability to move chats directly to folders from the chat menu. This enhances chat organization and allows users to manage their conversations more efficiently by relocating them between folders with ease. +- 💬 **Language-Aware Follow-Up Suggestions**: Enhanced the AI's follow-up question generation to dynamically adapt to the primary language of the current chat. Follow-up prompts will now be suggested in the same language the user and AI are conversing in, ensuring more natural and contextually relevant interactions. +- 👤 **Expanded User Profile Details**: Introduced new user profile fields including username, bio, gender, and date of birth, allowing for more comprehensive user customization and information management. This enhancement includes corresponding updates to the database schema, API, and user interface for seamless integration. +- 👥 **Direct Navigation to User Groups from User Edit**: Enhanced the user edit modal to include a direct link to the associated user group. This allows administrators to quickly navigate from a user's profile to their group settings, streamlining user and group management workflows. 
+- 🔧 **Enhanced External Tool Server Compatibility**: Improved handling of responses from external tool servers, allowing both the backend and frontend to process plain text content in addition to JSON, ensuring greater flexibility and integration with diverse tool outputs. +- 🗣️ **Enhanced Audio Transcription Language Fallback and Deepgram Support**: Implemented a robust language fallback mechanism for both OpenAI and Deepgram Speech-to-Text (STT) API calls. If a specified language parameter is not supported by the model or provider, the system will now intelligently retry the transcription without the language parameter or with a default, ensuring greater reliability and preventing failed API calls. This also specifically adds and refines support for the audio language parameter in Deepgram API integrations. +- ⚡ **Optimized Hybrid Search Performance for BM25 Weight Configuration**: Enhanced hybrid search to significantly improve performance when the BM25 weight is set to 0 or less. This optimization intelligently disables unnecessary collection retrieval and BM25 ranking calculations, leading to faster search results without impacting accuracy for configurations that do not utilize lexical search contributions. +- 🔒 **Configurable Code Interpreter Module Blacklist**: Introduced the "CODE_INTERPRETER_BLACKLISTED_MODULES" environment variable, allowing administrators to specify Python modules that are forbidden from being imported or executed within the code interpreter. This significantly enhances the security posture by mitigating risks associated with arbitrary code execution, such as unauthorized data access, system manipulation, or outbound connections. +- 🔐 **Enhanced OAuth Role Claim Handling**: Improved compatibility with diverse OAuth providers by allowing role claims to be supplied as single strings or integers, in addition to arrays. 
The system now automatically normalizes these single-value claims into arrays for consistent processing, streamlining integration with identity providers that format role data differently. +- ⚙️ **Configurable Tool Call Timeout**: Introduced the "AIOHTTP_CLIENT_TIMEOUT" environment variable, allowing administrators to specify custom timeout durations for external tool calls, which is crucial for integrations with tools that have varying or extended response times. +- 🛠️ **Improved Tool Callable Generation for Google genai SDK**: Enhanced the creation of tool callables to directly support native function calling within the Google 'genai' SDK. This refactoring ensures proper signature inference and removes extraneous parameters, enabling seamless integration for advanced AI workflows using Google's generative AI models. +- ✨ **Dynamic Loading of 'kokoro-js'**: Implemented dynamic loading for the 'kokoro-js' library, preventing failures and improving compatibility on older iOS browsers that may not support direct imports or certain modern JavaScript APIs like 'DecompressionStream'. +- 🖥️ **Improved Command List Visibility on Small Screens**: Resolved an issue where the top items in command lists (e.g., Knowledge Base, Models, Prompts) were hidden or overlapped by the header on smaller screen sizes or specific browser zoom levels. The command option lists now dynamically adjust their height, ensuring all items are fully visible and accessible with proper scrolling. +- 📦 **Improved Docker Image Compatibility for Arbitrary UIDs**: Fixed issues preventing the Open WebUI container from running in environments with arbitrary User IDs (UIDs), such as OpenShift's restricted Security Context Constraints (SCC). The Dockerfile has been updated to correctly set file system permissions for "/app" and "/root" directories, ensuring they are writable by processes running with a supplemental GID 0, thus resolving permission errors for Python libraries and application caches. 
+- ♿ **Accessibility Enhancements**: Significantly improved the semantic structure of chat messages by using "section", "h2", "ul", and "li" HTML tags, and enhanced screen reader compatibility by explicitly hiding decorative images with "aria-hidden" attributes. This refactoring provides clearer structural context and improves overall accessibility and web standards compliance for the conversation flow. +- 🌐 **Localization & Internationalization Improvements**: Significantly expanded internationalization support throughout the user interface, translating numerous user-facing strings in toast messages, placeholders, and other UI elements. This, alongside continuous refinement and expansion of translations for languages including Brazilian Portuguese, Kabyle (Taqbaylit), Czech, Finnish, Chinese (Simplified), Chinese (Traditional), and German, and general fixes for several other translation files, further enhances linguistic coverage and user experience. + +### Fixed + +- 🛡️ **Resolved Critical OIDC SSO Login Failure**: Fixed a critical issue where OIDC Single Sign-On (SSO) logins failed due to an error in setting the authentication token as a cookie during the redirect process. This ensures reliable and seamless authentication for users utilizing OIDC providers, restoring full login functionality that was impacted by previous security hardening. +- ⚡ **Prevented UI Blocking by Unreachable Webhooks**: Resolved a critical performance and user experience issue where synchronous webhook calls to unreachable or slow endpoints would block the entire user interface for all users. Webhook requests are now processed asynchronously using "aiohttp", ensuring that the UI remains responsive and functional even if webhook delivery encounters delays or failures. +- 🔒 **Password Change Option Hidden for Externally Authenticated Users**: Resolved an issue where the password change dialog was visible to users authenticated via external methods (e.g., LDAP, OIDC, Trusted Header). 
The option to change a password in user settings is now correctly hidden for these users, as their passwords are managed externally, streamlining the user interface and preventing confusion. +- 💬 **Resolved Temporary Chat and Permission Enforcement Issues**: Fixed a bug where temporary chats (identified by "chat_id = local") incorrectly triggered database checks, leading to 404 errors. This also resolves the issue where the 'USER_PERMISSIONS_CHAT_TEMPORARY_ENFORCED' setting was not functioning as intended, ensuring temporary chat mode now works correctly for user roles. +- 🔐 **Admin Model Visibility for Administrators**: Fixed an issue where private models remained visible and usable for administrators in the chat model selector, even when the intended privacy setting ("ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS" - now renamed to "BYPASS_ADMIN_ACCESS_CONTROL") was disabled. This ensures consistent enforcement of model access controls and adherence to the principle of least privilege. +- 🔍 **Clarified Web Search Engine Label for DDGS**: Addressed user confusion and inaccurate labeling by renaming "duckduckgo" to "DDGS" (Dux Distributed Global Search) in the web search engine selector. This clarifies that the system utilizes DDGS, a metasearch library that aggregates results from various search providers, accurately reflecting its underlying functionality rather than implying exclusive use of DuckDuckGo's search engine. +- 🛠️ **Improved Settings UI Reactivity and Visibility**: Resolved an issue where settings tabs for 'Connections' and 'Tools' did not dynamically update their visibility based on global administrative feature flags (e.g., 'enable_direct_connections'). The UI now reactively shows or hides these sections, ensuring a consistent and clear experience when administrators control feature availability.
+- 🎚️ **Restored Model and Banner Reordering Functionality**: Fixed a bug that prevented administrators from reordering models in the Admin Panel's 'Models' settings and banners in the 'Interface' settings via drag-and-drop. The sortable functionality has been restored, allowing for proper customization of display order. +- 📝 **Restored Custom Pending User Overlay Visibility**: Fixed an issue where the custom title and description configured for pending users were not visible. The application now correctly exposes these UI configuration settings to pending users, ensuring that the custom onboarding messages are displayed as intended. +- 📥 **Fixed Community Function Import Compatibility**: Resolved an issue that prevented the successful import of function files downloaded from openwebui.com due to schema differences. The system now correctly processes these files, allowing for seamless integration of community-contributed functions. +- 📦 **Fixed Stale Ollama Version in Docker Images**: Resolved an issue where the Ollama installation within Docker images could become stale due to caching during the build process. The Dockerfile now includes a mechanism to invalidate the build cache for the Ollama installation step, ensuring that the latest version of Ollama is always installed. +- 🗄️ **Improved Milvus Query Handling for Large Datasets**: Fixed a "MilvusException" that occurred when attempting to query more than 16384 entries from a Milvus collection. The query logic has been refactored to use "query_iterator()", enabling efficient fetching of larger result sets in batches and resolving the previous limitation on the number of entries that could be retrieved. +- 🐛 **Restored Message Toolbar Icons for Empty Messages with Files**: Fixed an issue where the edit, copy, and delete icons were not displayed on user messages that contained an attached file but no text content. 
This ensures full interaction capabilities for all message types, allowing users to manage their messages consistently. +- 💬 **Resolved Streaming Interruption for Kimi-Dev Models**: Fixed an issue where streaming responses from Kimi-Dev models would halt prematurely upon encountering specific 'thinking' tokens (◁think▷, ◁/think▷). The system now correctly processes these tokens, ensuring uninterrupted streaming and proper handling of hidden or collapsible thinking sections. +- 🔍 **Enhanced Knowledge Base Search Functionality**: Improved the search capability within the 'Knowledge' section of the Workspace. Previously, searching for knowledge bases required exact term matches or starting with the first letter. Now, the search algorithm has been refined to allow broader, less exact matches, making it easier and more intuitive to find relevant knowledge bases. +- 📝 **Resolved Chinese Input 'Enter' Key Issue (macOS & iOS Safari)**: Fixed a bug where pressing the 'Enter' key during text composition with Input Method Editors (IMEs) on macOS and iOS Safari browsers would prematurely send the message. The system now robustly handles the composition state by addressing a 'compositionend' event bug specific to Safari, ensuring a smooth and expected typing experience for users of various languages, including Chinese and Korean. +- 🔐 **Resolved OAUTH_GROUPS_CLAIM Configuration Issue**: Fixed a bug where the "OAUTH_GROUPS_CLAIM" environment variable was not correctly parsed due to a typo in the configuration file. This ensures that OAuth group management features, including automatic group creation, now correctly utilize the specified claim from the identity provider, allowing for seamless integration with external user directories like Keycloak. 
+- 🗄️ **Resolved Azure PostgreSQL pgvector Extension Permissions**: Fixed an issue preventing the creation of "pgvector" and "pgcrypto" extensions on Azure PostgreSQL Flexible Servers due to permission limitations (e.g., 'Only members of "azure_pg_admin" are allowed to use "CREATE EXTENSION"'). The extension creation process now includes a conditional check, ensuring seamless deployment and compatibility with Azure PostgreSQL environments even with restricted database user permissions. +- 🛠️ **Improved Backend Path Resolution and Alembic Stability**: Fixed issues causing Alembic database migrations to fail due to incorrect path resolution within the application. By implementing canonical path resolution for core directories and refining Alembic configuration, the robustness and correctness of internal pathing have been significantly enhanced, ensuring reliable database operations. +- 📊 **Resolved Arena Model Identification in Feedback History**: Fixed an issue where the model used for feedback in arena settings was incorrectly reported as 'arena-model' in the evaluation history. The system now correctly logs and displays the actual model ID that received the feedback, restoring clarity and enabling proper analysis of model performance in arena environments. +- 🎨 **Resolved Icon Overlap in 'Her' Theme**: Fixed a visual glitch in the 'Her' theme where icons would overlap on the loading screen and certain icons appeared incongruous. The display has been corrected to ensure proper visual presentation and theme consistency. +- 🛠️ **Resolved Model Sorting TypeError with Null Names**: Fixed a "TypeError" that occurred in the "/api/models" endpoint when sorting models with null or missing names. The model sorting logic has been improved to gracefully handle such edge cases by ensuring that model IDs and names are treated as empty strings if their values are null or undefined, preventing comparison errors and improving API stability. 
+- 💬 **Resolved Silently Dropped Streaming Response Chunks**: Fixed an issue where the final partial chunks of streaming chat responses could be silently dropped, leading to incomplete message delivery. The system now reliably flushes any pending delta data upon stream termination, early breaks (e.g., code interpreter tags), or connection closure, ensuring complete and accurate response delivery. +- 📱 **Disabled Overscroll for iOS Frontend**: Fixed an issue where overscrolling was enabled on iOS devices, causing unexpected scrolling behavior over fixed or sticky elements within the PWA. Overscroll has now been disabled, providing a more native application-like experience for iOS users. +- 📝 **Resolved Code Block Input Issue with Shift+Enter**: Fixed a bug where typing three backticks followed by a language and then pressing Shift+Enter would cause the code block prefix to disappear, preventing proper code formatting. The system now correctly preserves the code block syntax, ensuring consistent behavior for multi-line code input. +- 🛠️ **Improved OpenAI Model List Handling for Null Names**: Fixed an edge case where some OpenAI-compatible API providers might return models with a null value for their 'name' field. This could lead to issues like broken model list sorting. The system now gracefully handles these instances by removing the null 'name' key, ensuring stable model retrieval and display. +- 🔍 **Resolved DDGS Concurrent Request Configuration**: Fixed an issue where the configured number of concurrent requests was not being honored for the DDGS (Dux Distributed Global Search) metasearch engine. The system now correctly applies the specified concurrency setting, improving efficiency for web searches. +- 🛠️ **Improved Tool List Synchronization in Multi-Replica Deployments**: Resolved an issue where tool updates were not consistently reflected across all instances in multi-replica environments, leading to stale tool lists for users on other replicas.
The tool list in the message input menu is now automatically refreshed each time it is accessed, ensuring all users always see the most current set of available tools. +- 🛠️ **Resolved Duplicate Tool Name Collision**: Fixed an issue where tools with identical names from different external servers were silently removed, preventing their simultaneous use. The system now correctly handles tool name collisions by internally prefixing tools with their server identifier, allowing multiple instances of similarly named tools from different servers to be active and usable by LLMs. +- 🖼️ **Resolved Image Generation API Size Parameter Issue**: Fixed a bug where the "/api/v1/images/generations" API endpoint did not correctly apply the 'size' parameter specified in the request payload for image generation. The system now properly honors the requested image dimensions (e.g., '1980x1080'), ensuring that generated images match the user's explicit size preference rather than defaulting to settings. +- 🗄️ **Resolved S3 Vector Upload Limitations**: Fixed an issue that prevented uploading more than 500 vectors to S3 Vector buckets due to API limitations, which resulted in a "ValidationException". S3 vector uploads are now batched in groups of 500, ensuring successful processing of larger datasets. +- 🛠️ **Fixed Tool Installation Error During Startup**: Resolved a "NoneType" error that occurred during tool installation at startup when 'tool.user' was unexpectedly null. The system now includes a check to ensure 'tool.user' exists before attempting to access its properties, preventing crashes and ensuring robust tool initialization. +- 🛠️ **Improved Azure OpenAI GPT-5 Parameter Handling**: Fixed an issue with Azure OpenAI SDK parameter handling to correctly support GPT-5 models. The 'max_tokens' parameter is now appropriately converted to 'max_completion_tokens' for GPT-5 models, ensuring consistent behavior and proper function execution similar to existing o-series models. 
+- 🐛 **Resolved Exception with Missing Group Permissions**: Fixed an exception that occurred in the access control logic when group permission objects were missing or null. The system now correctly handles cases where groups may not have explicit permission definitions, ensuring that 'None' checks prevent errors and maintain application stability when processing user permissions. +- 🛠️ **Improved OpenAI API Base URL Handling**: Fixed an issue where a trailing slash in the 'OPENAI_API_BASE_URL' configuration could lead to models not being detected or the endpoint failing. The system now automatically removes trailing slashes from the configured URL, ensuring robust and consistent connections to OpenAI-compatible APIs. +- 🖼️ **Resolved S3-Compatible Storage Upload Failures**: Fixed an issue where uploads to S3-compatible storage providers would fail with an "XAmzContentSHA256Mismatch" error. The system now correctly handles checksum calculations, ensuring reliable file and image uploads to S3-compatible services. +- 🌐 **Corrected 'Releases' Link**: Fixed an issue where the 'Releases' button in the user menu directed to an incorrect URL, now correctly linking to the Open WebUI GitHub releases page. +- 🛠️ **Resolved Model Sorting Errors with Null or Undefined Names**: Fixed multiple "TypeError" instances that occurred when attempting to sort model lists where model names were null or undefined. The sorting logic across various UI components (including Ollama model selection, leaderboard, and admin model settings) has been made more robust by gracefully handling absent model names, preventing crashes and ensuring consistent alphabetical sorting based on available name or ID. +- 🎨 **Resolved Banner Dismissal Issue with Iteration IDs**: Fixed a bug where dismissing banners could lead to unintended multiple banner dismissals or other incorrect behavior, especially when banners lacked unique iteration IDs. 
Unique IDs are now assigned during banner iteration, ensuring proper individual dismissal and consistent display behavior. + +### Changed + +- 🛂 **Environment Variable for Admin Access Control**: The environment variable "ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS" has been renamed to "BYPASS_ADMIN_ACCESS_CONTROL". This new name more accurately reflects its function as a control to allow administrators to bypass model access restrictions. Users are encouraged to update their configurations to use the new variable name; existing configurations using the old name will still be honored for backward compatibility. +- 🗂️ **Core Directory Path Resolution Updated**: The internal mechanism for resolving core application directory paths ("OPEN_WEBUI_DIR", "BACKEND_DIR", "BASE_DIR") has been updated to use canonical resolution via "Path().resolve()". This change improves path reliability but may require adjustments for any external scripts or configurations that previously relied on specific non-canonical path interpretations. +- 🗃️ **Database Performance Options**: New database performance options, "DATABASE_ENABLE_SQLITE_WAL" and "DATABASE_DEDUPLICATE_INTERVAL", are now available. If "DATABASE_ENABLE_SQLITE_WAL" is enabled, SQLite will operate in WAL mode, which may alter SQLite's file locking behavior. If "DATABASE_DEDUPLICATE_INTERVAL" is set to a non-zero value, the "user.last_active_at" timestamp will be updated less frequently, leading to slightly less real-time accuracy for this specific field but significantly reducing database write conflicts and improving overall performance. Both options are disabled by default. +- 🌐 **Renamed Web Search Concurrency Setting**: The environment variable "WEB_SEARCH_CONCURRENT_REQUESTS" has been renamed to "WEB_LOADER_CONCURRENT_REQUESTS". This change clarifies its scope, explicitly applying to the concurrency of the web loader component (which fetches content from search results) rather than the initial search engine query. 
Users relying on the old environment variable name for configuring web search concurrency must update their configurations to use "WEB_LOADER_CONCURRENT_REQUESTS". + +## [0.6.22] - 2025-08-11 + +### Added + +- 🔗 **OpenAI API '/v1' Endpoint Compatibility**: Enhanced API compatibility by supporting requests to paths like '/v1/models', '/v1/embeddings', and '/v1/chat/completions'. This allows Open WebUI to integrate more seamlessly with tools that expect OpenAI's '/v1' API structure. +- 🪄 **Toggle for Guided Response Regeneration Menu**: Introduced a new setting in 'Interface' settings, providing the ability to enable or disable the expanded guided response regeneration menu. This offers users more control over their chat workflow and interface preferences. +- ✨ **General UI/UX Enhancements**: Implemented various user interface and experience improvements, including more rounded corners for cards in the Knowledge, Prompts, and Tools sections, and minor layout adjustments within the chat Navbar for improved visual consistency. +- 🌐 **Localization & Internationalization Improvements**: Introduced support for the Kabyle (Taqbaylit) language and refined and expanded translations for Chinese, expanding the platform's linguistic coverage. + +### Fixed + +- 🐞 **OpenAI Error Message Propagation**: Resolved an issue where specific OpenAI API errors (e.g., 'Organization Not Verified') were obscured by generic 'JSONResponse' iterable errors. The system now correctly propagates detailed and actionable error messages from OpenAI to the user. +- 🌲 **Pinecone Insert Issue**: Fixed a bug that prevented proper insertion of items into Pinecone vector databases. +- 📦 **S3 Vector Issue**: Resolved a bug where s3vector functionality failed due to incorrect import paths. +- 🏠 **Landing Page Option Setting Not Working**: Fixed an issue where the landing page option in settings was not functioning as intended.
+ +## [0.6.21] - 2025-08-10 + +### Added + +- 👥 **User Groups in Edit Modal**: Added display of user groups information in the user edit modal, allowing administrators to view and manage group memberships directly when editing a user. + +### Fixed + +- 🐞 **Chat Completion 'model_id' Error**: Resolved a critical issue where chat completions failed with an "undefined model_id" error after upgrading to version 0.6.20, ensuring all models now function correctly and reliably. +- 🛠️ **Audit Log User Information Logging**: Fixed an issue where user information was not being correctly logged in the audit trail due to an unreflected function prototype change, ensuring complete logging for administrative oversight. +- 🛠️ **OpenTelemetry Configuration Consistency**: Fixed an issue where OpenTelemetry metric and log exporters' 'insecure' settings did not correctly default to the general OpenTelemetry 'insecure' flag, ensuring consistent security configurations across all OpenTelemetry exports. +- 📝 **Reply Input Content Display**: Fixed an issue where replying to a message incorrectly displayed '{{INPUT_CONTENT}}' instead of the actual message content, ensuring proper content display in replies. +- 🌐 **Localization & Internationalization Improvements**: Refined and expanded translations for Catalan, Korean, Spanish and Irish, ensuring a more fluent and native experience for global users. + +## [0.6.20] - 2025-08-10 + +### Fixed + +- 🛠️ **Quick Actions "Add" Behavior**: Fixed a bug where using the "Add" button in Quick Actions would add the resulting message as the very first message in the chat, instead of appending it to the latest message. + +## [0.6.19] - 2025-08-09 + +### Added + +- ✨ **Modernized Sidebar and Major UI Refinements**: The main navigation sidebar has been completely redesigned with a modern, cleaner aesthetic, featuring a sticky header and footer to keep key controls accessible. 
Core sidebar logic, like the pinned models list, was also refactored into dedicated components for better performance and maintainability. +- 🪄 **Guided Response Regeneration**: The "Regenerate" button has been transformed into a powerful new menu. You can now guide the AI's next attempt by suggesting changes in a text prompt, or use one-click options like "Try Again," "Add Details," or "More Concise" to instantly refine and reshape the response to better fit your needs. +- 🛠️ **Improved Tool Call Handling for GPT-OSS Models**: Implemented robust handling for tool calls specifically for GPT-OSS models, ensuring proper function execution and integration. +- 🛑 **Stop Button for Merge Responses**: Added a dedicated stop button to immediately halt the generation of merged AI responses, providing users with more control over ongoing outputs. +- 🔄 **Experimental SCIM 2.0 Support**: Implemented SCIM 2.0 (System for Cross-domain Identity Management) protocol support, enabling enterprise-grade automated user and group provisioning from identity providers like Okta, Azure AD, and Google Workspace for seamless user lifecycle management. Configuration is managed securely via environment variables. +- 🗂️ **Amazon S3 Vector Support**: You can now use Amazon S3 Vector as a high-performance vector database for your Retrieval-Augmented Generation (RAG) workflows. This provides a scalable, cloud-native storage option for users deeply integrated into the AWS ecosystem, simplifying infrastructure and enabling enterprise-scale knowledge management. +- 🗄️ **Oracle 23ai Vector Search Support**: Added support for Oracle 23ai's new vector search capabilities as a supported vector database, providing a robust and scalable option for managing large-scale documents and integrating vector search with existing business data at the database level. 
+- ⚡ **Qdrant Performance and Configuration Enhancements**: The Qdrant client has been significantly improved with faster data retrieval logic for 'get' and 'query' operations. New environment variables ('QDRANT_TIMEOUT', 'QDRANT_HNSW_M') provide administrators with finer control over query timeouts and HNSW index parameters, enabling better performance tuning for large-scale deployments. +- 🔐 **Encrypted SQLite Database with SQLCipher**: You can now encrypt your entire SQLite database at rest using SQLCipher. By setting the 'DATABASE_TYPE' to 'sqlite+sqlcipher' and providing a 'DATABASE_PASSWORD', all data is transparently encrypted, providing an essential security layer for protecting sensitive information in self-hosted deployments. Note that this requires additional system libraries and the 'sqlcipher3-wheels' Python package. +- 🚀 **Efficient Redis Connection Management**: Implemented a shared connection pool cache to reuse Redis connections, dramatically reducing the number of active clients. This prevents connection exhaustion errors, improves performance, and ensures greater stability in high-concurrency deployments and those using Redis Sentinel. +- ⚡ **Batched Response Streaming for High Performance**: Dramatically improve performance and stability during high-speed response streaming by batching multiple tokens together before sending them to the client. A new 'Stream Delta Chunk Size' advanced parameter can be set per-model or in user/chat settings, significantly reducing CPU load on the server, Redis, and client, and preventing connection issues in high-concurrency environments. +- ⚙️ **Global Batched Streaming Configuration**: Administrators can now set a system-wide default for response streaming using the new 'CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE' environment variable. This allows for global performance tuning, while still letting per-model or per-chat settings override the default for more granular control. 
+- 🔎 **Advanced Chat Search with Status Filters**: Quickly find any conversation with powerful new search filters. You can now instantly narrow down your chats using prefixes like 'pinned:true', 'shared:true', and 'archived:true' directly in the search bar. An intelligent dropdown menu assists you by suggesting available filter options as you type, streamlining your workflow and making chat management more efficient than ever. +- 🛂 **Granular Chat Controls Permissions**: Administrators can now manage chat settings with greater detail. The main "Chat Controls" permission now acts as a master switch, while new granular toggles for "Valves", "System Prompts", and "Advanced Parameters" allow for more specific control over which sections are visible to users inside the panel. +- ✍️ **Formatting Toolbar for Chat Input**: Introduced a dedicated formatting toolbar for the rich text chat input field, providing users with more accessible options for text styling and editing, configurable via interface settings. +- 📑 **Tabbed View for Multi-Model Responses**: You can now enable a new tabbed interface to view responses from multiple models. Instead of side-scrolling cards, this compact view organizes each model's response into its own tab, making it easier to compare outputs and saving vertical space. This feature can be toggled on or off in Interface settings. +- ↕️ **Reorder Pinned Models via Drag-and-Drop**: You can now organize your pinned models in the sidebar by simply dragging and dropping them into your preferred order. This custom layout is saved automatically, giving you more flexible control over your workspace. +- 📌 **Quick Model Unpin Shortcut**: You can now quickly unpin a model by holding the Shift key and hovering over it to reveal an instant unpin button, streamlining your workspace customization. +- ⚡ **Improved Chat Input Performance**: The chat input is now significantly more responsive, especially when pasting or typing large amounts of text. 
This was achieved by implementing a debounce mechanism for the auto-save feature, which prevents UI lag and ensures a smooth, uninterrupted typing experience. +- ✍️ **Customizable Floating Quick Actions with Tool Support**: Take full control of your text interaction workflow with new customizable floating quick actions. In Settings, you can create, edit, or disable these actions and even integrate tools using the '{{TOOL:tool_id}}' syntax in your prompts, enabling powerful one-click automations on selected text. This is in addition to using placeholders like '{{CONTENT}}' and '{{INPUT_CONTENT}}' for custom text transformations. +- 🔒 **Admin Workspace Privacy Control**: Introduced the 'ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS' environment variable (defaults to 'True') allowing administrators to control their access privileges to workspace items (Knowledge, Models, Prompts, Tools). When disabled, administrators adhere to the same access control rules as regular users, enhancing data separation for multi-tenant deployments. +- 🗄️ **Comprehensive Model Configuration Management**: Administrators can now export the entire model configuration to a file and use a new declarative sync endpoint to manage models in bulk. This powerful feature enables seamless backups, migrations, and state replication across multiple instances. +- 📦 **Native Redis Cluster Mode Support**: Added full support for connecting to Redis in cluster mode, allowing for scalable and highly available Redis deployments beyond Sentinel-managed setups. New environment variables 'REDIS_CLUSTER' and 'WEBSOCKET_REDIS_CLUSTER' enable the use of 'redis.cluster.RedisCluster' clients. +- 📊 **Granular OpenTelemetry Metrics Configuration**: Introduced dedicated environment variables and enhanced configuration options for OpenTelemetry metrics, allowing for separate OTLP endpoints, basic authentication credentials, and protocol (HTTP/gRPC) specifically for metrics export, independent of trace settings. 
This provides greater flexibility for integrating with diverse observability stacks. +- 🪵 **Granular OpenTelemetry Logging Configuration**: Enhanced the OpenTelemetry logging integration by introducing dedicated environment variables for logs, allowing separate OTLP endpoints, basic authentication credentials, and protocol (HTTP/gRPC) specifically for log export, independent of general OTel settings. The application's default Python logger now leverages this configuration to automatically send logs to your OTel endpoint when enabled via 'ENABLE_OTEL_LOGS'. +- 📁 **Enhanced Folder Chat Management with Sorting and Time Blocks**: The chat list within folders now supports comprehensive sorting options by title and updated time, along with intelligent time-based grouping (e.g., "Today," "Yesterday") similar to the main chat view, making navigation and organization of project-specific conversations significantly easier. +- ⚙️ **Configurable Datalab Marker API & Advanced Processing Options**: Enhanced Datalab Marker API integration, allowing administrators to configure custom API base URLs for self-hosting and to specify comprehensive processing options via a new 'additional_config' JSON parameter. This replaces the deprecated language selection feature and provides granular control over document extraction, with streamlined API endpoint resolution for more robust self-hosted deployments. +- 🧑‍💼 **Export All Users to CSV**: Administrators can now export a complete list of all users to a CSV file directly from the Admin Panel's database settings. This provides a simple, one-click way to generate user data for auditing, reporting, or management purposes. +- 🛂 **Customizable OAuth 'sub' Claim**: Administrators can now use the 'OAUTH_SUB_CLAIM_OVERRIDE' environment variable to specify which claim from the identity provider should be used as the unique user identifier ('sub'). 
This provides greater flexibility and control for complex enterprise authentication setups where modifying the IDP's default claims is not possible. +- 👁️ **Password Visibility Toggle for Input Fields**: Password fields across the application (login, registration, user management, and account settings) now utilize a new 'SensitiveInput' component, providing a consistent toggle to reveal/hide passwords for improved usability and security. +- 🛂 **Optional "Confirm Password" on Sign-Up**: To help prevent password typos during account creation, administrators can now enable a "Confirm Password" field on the sign-up page. This feature is disabled by default and can be activated via an environment variable for enhanced user experience. +- 💬 **View Full Chat from User Feedback**: Administrators can now easily navigate to the full conversation associated with a user feedback entry directly from the feedback modal, streamlining the review and troubleshooting process. +- 🎚️ **Intuitive Hybrid Search BM25-Weight Slider**: The numerical input for the BM25-Weight parameter in Hybrid Search has been replaced with an interactive slider, offering a more intuitive way to adjust the balance between lexical and semantic search. A "Default/Custom" toggle and clearer labels enhance usability and understanding of this key parameter. +- ⚙️ **Enhanced Bulk Function Synchronization**: The API endpoint for synchronizing functions has been significantly improved to reliably handle bulk updates. This ensures that importing and managing large libraries of functions is more robust and error-free for administrators. +- 🖼️ **Option to Disable Image Compression in Channels**: Introduced a new setting under Interface options to allow users to force-disable image compression specifically for images posted in channels, ensuring higher resolution for critical visual content. 
+- 🔗 **Custom CORS Scheme Support**: Introduced a new environment variable 'CORS_ALLOW_CUSTOM_SCHEME' that allows administrators to define custom URL schemes (e.g., 'app://') for CORS origins, enabling greater flexibility for local development or desktop client integrations. +- ♿ **Translatable and Accessible Banners**: Enhanced banner elements with translatable badge text and proper ARIA attributes (aria-label, aria-hidden) for SVG icons, significantly improving accessibility and screen reader compatibility. +- ⚠️ **OAuth Configuration Warning for Missing OPENID_PROVIDER_URL**: Added a proactive startup warning that notifies administrators when OAuth providers (Google, Microsoft, or GitHub) are configured but the essential 'OPENID_PROVIDER_URL' environment variable is missing. This prevents silent OAuth logout failures and guides administrators to complete their setup correctly. +- ♿ **Major Accessibility Enhancements**: Key parts of the interface have been made significantly more accessible. The user profile menu is now fully navigable via keyboard, essential controls in the Playground now include proper ARIA labels for screen readers, and decorative images have been hidden from assistive technologies to reduce audio clutter. Menu buttons also feature enhanced accessibility with 'aria-label', 'aria-hidden' for SVGs, and 'aria-pressed' for toggle buttons. +- ⚙️ **General Backend Refactoring**: Implemented various backend improvements to enhance performance, stability, and security, ensuring a more resilient and reliable platform for all users, including refining logging output to be cleaner and more efficient by conditionally including 'extra_json' fields and improving consistent metadata handling in vector database operations, and laying preliminary scaffolding for future analytics features. 
+- 🌐 **Localization & Internationalization Improvements**: Refined and expanded translations for Catalan, Danish, Korean, Persian, Polish, Simplified Chinese, and Spanish, ensuring a more fluent and native experience for global users across all supported languages. + +### Fixed + +- 🛡️ **Hardened Channel Message Security**: Fixed a key permission flaw that allowed users with channel access to edit or delete messages belonging to others. The system now correctly enforces that users can only modify their own messages, protecting data integrity in shared channels. +- 🛡️ **Hardened OAuth Security by Removing JWT from URL**: Fixed a critical security vulnerability where the authentication token was exposed in the URL after a successful OAuth login. The token is now transferred via a browser cookie, preventing potential leaks through browser history or server logs and protecting user sessions. +- 🛡️ **Hardened Chat Completion API Security**: The chat completion API endpoint now includes an explicit ownership check, ensuring non-admin users cannot access chats that do not belong to them and preventing potential unauthorized access. +- 🛠️ **Resilient Model Loading**: Fixed an issue where a failure in loading the model list (e.g., from a misconfigured provider) would prevent the entire user interface, including the admin panel, from loading. The application now gracefully handles these errors, ensuring the UI remains accessible. +- 🔒 **Resolved FIPS Self-Test Failure**: Fixed a critical issue that prevented Open WebUI from running on FIPS-compliant systems, specifically resolving the "FATAL FIPS SELFTEST FAILURE" error related to OpenSSL and SentenceTransformers, restoring compatibility with secure environments. +- 📦 **Redis Cluster Connection Restored**: Fixed an issue where the backend was unable to connect to Redis in cluster mode, now ensuring seamless integration with scalable Redis cluster deployments. 
+- 📦 **PGVector Connection Stability**: Fixed an issue where read-only operations could leave database transactions idle, preventing potential connection errors and improving overall database stability and resource management. +- 🛠️ **OpenAPI Tool Integration for Array Parameters Fixed**: Resolved a critical bug where external tools using array parameters (e.g., for tags) would fail when used with OpenAI models. The system now correctly generates the required 'items' property in the function schema, restoring functionality and preventing '400 Bad Request' errors. +- 🛠️ **Tool Creation for Users Restored**: Fixed a bug in the code editor where status messages were incorrectly prepended to tool scripts, causing a syntax error upon saving. All authorized users can now reliably create and save new tools. +- 📁 **Folder Knowledge Processing Restored**: Fixed a bug where files uploaded to folder and model knowledge bases were not being extracted or analyzed for Retrieval-Augmented Generation (RAG) when the 'Max Upload Count' setting was empty, ensuring seamless document processing and knowledge augmentation. +- 🧠 **Custom Model Knowledge Base Updates Recognized**: Fixed a bug where custom models linked to knowledge bases did not automatically recognize newly added files to those knowledge bases. Models now correctly incorporate the latest information from updated knowledge collections. +- 📦 **Comprehensive Redis Key Prefixing**: Corrected hardcoded prefixes to ensure the REDIS_KEY_PREFIX is now respected across all WebSocket and task management keys. This prevents data collisions in multi-instance deployments and improves compatibility with Redis cluster mode. +- ✨ **More Descriptive OpenAI Router Errors**: The OpenAI-compatible API router now propagates detailed upstream error messages instead of returning a generic 'Bad Request'.
This provides clear, actionable feedback for developers and API users, making it significantly easier to debug and resolve issues with model requests. +- 🔐 **Hardened OIDC Signout Flow**: The OpenID Connect signout process now verifies that the 'OPENID_PROVIDER_URL' is configured before attempting to communicate with it, preventing potential errors and ensuring a more reliable logout experience. +- 🍓 **Raspberry Pi Compatibility Restored**: Pinned the pyarrow library to version 20.0.0, resolving an "Illegal Instruction" crash on ARM-based devices like the Raspberry Pi and ensuring stable operation on this hardware. +- 📁 **Folder System Prompt Variables Restored**: Fixed a bug where prompt variables (e.g., '{{CURRENT_DATETIME}}') were not being rendered in Folder-level System Prompts. This restores an important capability for creating dynamic, context-aware instructions for all chats within a project folder. +- 📝 **Note Access in Knowledge Retrieval Fixed**: Corrected a permission oversight in knowledge retrieval, ensuring users can always use their own notes as a source for RAG without needing explicit sharing permissions. +- 🤖 **Title Generation Compatibility for GPT-5 Models**: Added support for 'gpt-5' models in the payload handler, which correctly converts the deprecated 'max_tokens' parameter to 'max_completion_tokens'. This resolves title generation failures and ensures seamless operation with the latest generation of models. +- ⚙️ **Correct API 'finish_reason' in Streaming Responses**: Fixed an issue where intermediate 'reasoning_content' chunks in streaming API responses incorrectly reported a 'finish_reason' of 'stop'. The 'finish_reason' is now correctly set to 'null' for these chunks, ensuring compatibility with third-party applications that rely on this field. +- 📈 **Evaluation Pages Stability**: Resolved a crash on the Leaderboard and Feedbacks pages when processing legacy feedback entries that were missing a 'rating' field. 
The system now gracefully handles this older data, ensuring both pages load reliably for all users. +- 🤝 **Reliable Collaborative Session Cleanup**: Fixed an asynchronous bug in the real-time collaboration engine that prevented document sessions from being properly cleaned up after all users had left. This ensures greater stability and resource management for features like Collaborative Notes. +- 🧠 **Enhanced Memory Stability and Security**: Refactored memory update and delete operations to strictly enforce user ownership, preventing potential data integrity issues. Additionally, improved error handling for memory queries now provides clearer feedback when no memories exist. +- 🧑‍⚖️ **Restored Admin Access to User Feedback**: Fixed a permission issue that blocked administrators from viewing or editing user feedback they didn't create, ensuring they can properly manage all evaluations across the platform. +- 🔐 **PGVector Encryption Fix for Metadata**: Corrected a SQL syntax error in the experimental 'PGVECTOR_PGCRYPTO' feature that prevented encrypted metadata from being saved. Document uploads to encrypted PGVector collections now work as intended. +- 🔍 **Serply Web Search Integration Restored**: Fixed an issue where incorrect parameters were passed to the Serply web search engine, restoring its functionality for RAG and web search workflows. +- 🔍 **Resilient Web Search Processing**: Web search retrieval now gracefully handles search results that are missing a 'snippet', preventing crashes and ensuring that RAG workflows complete successfully even with incomplete data from search engines. +- 🖼️ **Table Pasting in Rich Text Input Displayed Correctly**: Fixed an issue where pasting table text into the rich text input would incorrectly display it as code. Tables are now properly rendered as expected, improving content formatting and user experience. 
+- ✍️ **Rich Text Input TypeError Resolution**: Addressed a potential 'TypeError: ue.getWordAtDocPos is not a function' in 'MessageInput.svelte' by refactoring how the 'getWordAtDocPos' function is accessed and referenced from 'RichTextInput.svelte', ensuring stable rich text input behavior, especially after production restarts. +- ✏️ **Manual Code Block Creation in Chat Restored**: Fixed an issue where typing three backticks and then pressing Shift+Enter would incorrectly remove the backticks when "Enter to Send" mode was active. This ensures users can reliably create multi-line code blocks manually. +- 🎨 **Consistent Dark Mode Background**: Fixed an issue where the application background could incorrectly flash or remain white during page loads and refreshes in dark mode, ensuring a seamless and consistent visual experience. +- 🎨 **'Her' Theme Rendering Fixed**: Corrected a bug that caused the "Her" theme to incorrectly render as a dark theme in some situations. The theme now reliably applies its intended light appearance across all sessions. +- 📜 **Corrected Markdown Table Line Break Rendering**: Fixed an issue where line breaks ('<br>') within Markdown tables were displayed as raw HTML instead of being rendered correctly. This ensures that tables with multi-line cell content are now displayed as intended. +- 🚦 **Corrected App Configuration for Pending Users**: Fixed an issue where users awaiting approval could incorrectly load the full application interface, leading to a confusing or broken UI. This ensures that only fully approved users receive the standard app 'config', resulting in a smoother and more reliable onboarding experience. +- 🔄 **Chat Cloning Now Includes Tags, Folder Status, and Pinned Status**: When cloning a chat or shared chat, its associated tags, folder organization, and pinned status are now correctly replicated, ensuring consistent chat management. +- ⚙️ **Enhanced Backend Reliability**: Resolved a potential crash in knowledge base retrieval when referencing a deleted note. Additionally, chat processing was refactored to ensure model information is saved more reliably, enhancing overall system stability. +- ⚙️ **Floating 'Ask/Explain' Modal Stability**: Fixed an issue that spammed the console with errors when navigating away while a model was generating a response in the floating 'Ask' or 'Explain' modals. In-flight requests are now properly cancelled, improving application stability. +- ⚡ **Optimized User Count Checks**: Improved performance for user count and existence checks across the application by replacing resource-intensive 'COUNT' queries with more efficient 'EXISTS' queries, reducing database load. +- 🔐 **Hardened OpenTelemetry Exporter Configuration**: The OTLP HTTP exporter no longer uses a potentially insecure explicit flag, improving security by relying on the connection URL's protocol (HTTP/HTTPS) to ensure transport safety. 
+- 📱 **Mobile User Menu Closing Behavior Fixed**: Resolved an issue where the user menu would remain open on mobile devices after selecting an option, ensuring the menu correctly closes and returns focus to the main interface for a smoother mobile experience. +- 📱 **OnBoarding Page Display Fixed on Mobile**: Resolved an issue where buttons on the OnBoarding page were not consistently visible on certain mobile browsers, ensuring a functional and complete user experience across devices. +- ↕️ **Improved Pinned Models Drag-and-Drop Behavior**: The drag-and-drop functionality for reordering pinned models is now explicitly disabled on mobile devices, ensuring better usability and preventing potential UI conflicts or unexpected behavior. +- 📱 **PWA Rotation Behavior Corrected**: The Progressive Web App now correctly respects the device's screen orientation lock, preventing unwanted rotation and ensuring a more native mobile experience. +- ✏️ **Improved Chat Title Editing Behavior**: Changes to a chat title are now reliably saved when the user clicks away or presses Enter, replacing a less intuitive behavior that could accidentally discard edits. This makes renaming chats a smoother and more predictable experience. +- ✏️ **Underscores Allowed in Prompt Commands**: Fixed the validation for prompt commands to correctly allow the use of underscores ('\_'), aligning with documentation examples and improving flexibility in naming custom prompts. +- 💡 **Title Generation Button Behavior Fixed**: Resolved an issue where clicking the "Generate Title" button while editing a chat or note title would incorrectly save the title before generation could start. The focus is now managed correctly, ensuring a smooth and predictable user experience. +- ✏️ **Consistent Chat Input Height**: Fixed a minor visual bug where the chat input field's height would change slightly when toggling the "Rich Text Input for Chat" setting, ensuring a more stable and consistent layout. 
+- 🙈 **Admin UI Toggle Stability**: Fixed a visual glitch in the Admin settings where toggle switches could briefly display an incorrect state on page load, ensuring the UI always accurately reflects the saved settings. +- 🙈 **Community Sharing Button Visibility**: The "Share to Community" button on the feedback page is now correctly hidden when the Enable Community Sharing feature is disabled in the admin settings, ensuring the UI respects the configured sharing policy. +- 🙈 **"Help Us Translate" Link Visibility**: The "Help us translate" link in settings is now correctly hidden in deployments with specific license configurations, ensuring a cleaner interface for enterprise users. +- 🔗 **Robust Tool Server URL Handling**: Fixed an issue where providing a full URL for a tool server's OpenAPI specification resulted in an invalid path. The system now correctly handles both absolute URLs and relative paths, improving configuration flexibility. +- 🔧 **Improved Azure URL Detection**: The logic for identifying Azure OpenAI endpoints has been made more robust, ensuring all valid Azure URLs are now correctly detected for a smoother connection setup. +- ⚙️ **Corrected Direct Connection Save Logic**: Fixed a bug in the Admin Connections settings page by removing a redundant save action for 'Direct Connections', leading to more reliable and predictable behavior when updating settings. +- 🔗 **Corrected "Discover" Links**: The "Discover" links for models, prompts, tools, and functions now point to their specific, relevant pages on openwebui.com, improving content discovery for users. +- ⏱️ **Refined Display of AI Thought Duration**: Adjusted the display logic for AI thought (reasoning) durations to more accurately show very short thought times as "less than a second," improving clarity in AI process feedback. +- 📜 **Markdown Line Break Rendering Refinement**: Improved handling of line breaks within Markdown rendering for better visual consistency. 
+- 🛠️ **Corrected OpenTelemetry Docker Compose Example**: The docker-compose.otel.yaml file has been fixed and enhanced by removing duplicates, adding necessary environment variables, and hardening security settings, ensuring a more reliable out-of-box observability setup. +- 🛠️ **Development Script CORS Fix**: Corrected the CORS origin URL in the local development script (dev.sh) by removing the trailing slash, ensuring a more reliable and consistent setup for developers. +- ⬆️ **OpenTelemetry Libraries Updated**: Upgraded all OpenTelemetry-related libraries to their latest versions, ensuring better performance, stability, and compatibility for observability. + +### Changed + +- ❗ **Docling Integration Upgraded to v1 API (Breaking Change)**: The integration with the Docling document processing engine has been updated to its new, stable '/v1' API. This is required for compatibility with Docling version 1.0.0 and newer. As a result, older versions of Docling are no longer supported. Users who rely on Docling for document ingestion **must upgrade** their docling-serve instance to ensure continued functionality. +- 🗣️ **Admin-First Whisper Language Priority**: The global WHISPER_LANGUAGE setting now acts as a strict override for audio transcriptions. If set, it will be used for all speech-to-text tasks, ignoring any language specified by the user on a per-request basis. This gives administrators more control over transcription consistency. +- ✂️ **Datalab Marker API Language Selection Removed**: The separate language selection option for the Datalab Marker API has been removed, as its functionality is now integrated and superseded by the more comprehensive 'additional_config' parameter. Users should transition to using 'additional_config' for relevant language and processing settings. 
+- 📄 **Documentation and Releases Links Visibility**: The "Documentation" and "Releases" links in the user menu are now visible only to admin users, streamlining the user interface for non-admin roles. + +## [0.6.18] - 2025-07-19 + +### Fixed + +- 🚑 **Users Not Loading in Groups**: Resolved an issue where the user list was not displaying within user groups, restoring full visibility and management of group memberships for teams and admins. + +## [0.6.17] - 2025-07-19 + +### Added + +- 📂 **Dedicated Folder View with Chat List**: Clicking a folder now reveals a brand-new landing page showcasing a list of all chats within that folder, making navigation simpler and giving teams immediate visibility into project-specific conversations. +- 🆕 **Streamlined Folder Creation Modal**: Creating a new folder is now a seamless, unified experience with a dedicated modal that visually and functionally matches the edit folder flow, making workspace organization more intuitive and error-free for all users. +- 🗃️ **Direct File Uploads to Folder Knowledge**: You can now upload files straight to a folder’s knowledge—empowering you to enrich project spaces by adding resources and documents directly, without the need to pre-create knowledge bases beforehand. +- 🔎 **Chat Preview in Search**: When searching chats, instantly preview results in context without having to open them—making discovery, auditing, and recall dramatically quicker, especially in large, active teams. +- 🖼️ **Image Upload and Inline Insertion in Notes**: Notes now support inserting images directly among your text, letting you create rich, visually structured documentation, brainstorms, or reports in a more natural and engaging way—no more images just as attachments. 
+- 📱 **Enhanced Note Selection Editing and Q&A**: Select any portion of your notes to either edit just the highlighted part or ask focused questions about that content—streamlining workflows, boosting productivity, and making reviews or AI-powered enhancements more targeted. +- 📝 **Copy Notes as Rich Text**: Copy entire notes—including all formatting, images, and structure—directly as rich text for seamless pasting into emails, reports, or other tools, maintaining clarity and consistency outside the WebUI. +- ⚡ **Fade-In Streaming Text Experience**: Live-generated responses now elegantly fade in as the AI streams them, creating a more natural and visually engaging reading experience; easily toggled off in Interface settings if you prefer static displays. +- 🔄 **Settings for Follow-Up Prompts**: Fine-tune your follow-up prompt experience—with new controls, you can choose to keep them visible or have them inserted directly into the message input instead of auto-submitting, giving you more flexibility and control over your workflow. +- 🔗 **Prompt Variable Documentation Quick Link**: Access documentation for prompt variables in one click from the prompt editor modal—shortening the learning curve and making advanced prompt-building more accessible. +- 📈 **Active and Total User Metrics for Telemetry**: Gain valuable insights into usage patterns and platform engagement with new metrics tracking active and total users—enhancing auditability and planning for large organizations. +- 🏷️ **Traceability with Log Trace and Span IDs**: Each log entry now carries detailed trace and span IDs, making it much easier for admins to pinpoint and resolve issues across distributed systems or in complex troubleshooting. +- 👥 **User Group Add/Remove Endpoints**: Effortlessly add or remove users from groups with new, improved endpoints—giving admins and team leads faster, clearer control over collaboration and permissions. 
+- ⚙️ **Note Settings and Controls Streamlined**: The main “Settings” for notes are now simply called “Controls”, and note files now reside in a dedicated controls section, decluttering navigation and making it easier to find and configure note-related options. +- 🚀 **Faster Admin User Page Loads**: The user list endpoint for admins has been optimized to exclude heavy profile images, speeding up load times for large teams and reducing waiting during administrative tasks. +- 📡 **Chat ID Header Forwarding**: Ollama and OpenAI router requests now include the chat ID in request headers, enabling better request correlation and debugging capabilities across AI model integrations. +- 🧠 **Enhanced Reasoning Tag Processing**: Improved and expanded reasoning tag parsing to handle various tag formats more robustly, including standard XML-style tags and custom delimiters, ensuring better AI reasoning transparency and debugging capabilities. +- 🔐 **OAuth Token Endpoint Authentication Method**: Added configurable OAuth token endpoint authentication method support, providing enhanced flexibility and security options for enterprise OAuth integrations and identity provider compatibility. +- 🛡️ **Redis Sentinel High Availability Support**: Comprehensive Redis Sentinel failover implementation with automatic master discovery, intelligent retry logic for connection failures, and seamless operation during master node outages—eliminating single points of failure and ensuring continuous service availability in production deployments. +- 🌐 **Localization & Internationalization Improvements**: Refined and expanded translations for Simplified Chinese, Traditional Chinese, French, German, Korean, and Polish, ensuring a more fluent and native experience for global users across all supported languages. 
+ +### Fixed + +- 🏷️ **Hybrid Search Functionality Restored**: Hybrid search now works seamlessly again—enabling more accurate, relevant, and comprehensive knowledge discovery across all RAG-powered workflows. +- 🚦 **Note Chat - Edit Button Disabled During AI Generation**: The edit button when chatting with a note is now disabled while the AI is responding—preventing accidental edits and ensuring workflow clarity during chat sessions. +- 🧹 **Cleaner Database Credentials**: Database connection no longer duplicates ‘@’ in credentials, preventing potential connection issues and ensuring smoother, more reliable integrations. +- 🧑‍💻 **File Deletion Now Removes Related Vector Data**: When files are deleted from storage, they are now purged from the vector database as well, ensuring clean data management and preventing clutter or stale search results. +- 📁 **Files Modal Translation Issues Fixed**: All modal dialog strings—including “Using Entire Document” and “Using Focused Retrieval”—are now fully translated for a more consistent and localized UI experience. +- 🚫 **Drag-and-Drop File Upload Disabled for Unsupported Models**: File upload by drag-and-drop is disabled when using models that do not support attachments—removing confusion and preventing workflow interruptions. +- 🔑 **Ollama Tool Calls Now Reliable**: Fixed issues with Ollama-based tool calls, ensuring uninterrupted AI augmentation and tool use for every chat. +- 📄 **MIME Type Help String Correction**: Cleaned up mimetype help text by removing extraneous characters, providing clearer guidance for file upload configurations. +- 📝 **Note Editor Permission Fix**: Removed unnecessary admin-only restriction from note chat functionality, allowing all authorized users to access note editing features as intended. +- 📋 **Chat Sources Handling Improved**: Fixed sources handling logic to prevent duplicate source assignments in chat messages, ensuring cleaner and more accurate source attribution during conversations. 
+- 😀 **Emoji Generation Error Handling**: Improved error handling in audio router and fixed metadata structure for emoji generation tasks, preventing crashes and ensuring more reliable emoji generation functionality. +- 🔒 **Folder System Prompt Permission Enforcement**: System prompt fields in folder edit modal are now properly hidden for users without system prompt permissions, ensuring consistent security policy enforcement across all folder management interfaces. +- 🌐 **WebSocket Redis Lock Timeout Type Conversion**: Fixed proper integer type conversion for WebSocket Redis lock timeout configuration with robust error handling, preventing potential configuration errors and ensuring stable WebSocket connections. +- 📦 **PostHog Dependency Added**: Added PostHog 5.4.0 library to resolve ChromaDB compatibility issues, ensuring stable vector database operations and preventing library version conflicts during deployment. + +### Changed + +- 👀 **Tiptap Editor Upgraded to v3**: The underlying rich text editor has been updated for future-proofing, though some supporting libraries remain on v2 for compatibility. For now, please install dependencies using 'npm install --force' to avoid installation errors. +- 🚫 **Removed Redundant or Unused Strings and Elements**: Miscellaneous unused, duplicate, or obsolete code and translations have been cleaned up to maintain a streamlined and high-performance experience. + +## [0.6.16] - 2025-07-14 + +### Added + +- 🗂️ **Folders as Projects**: Organize your workflow with folder-based projects—set folder-level system prompts and associate custom knowledge, bringing seamless, context-rich management to teams and users handling multiple initiatives or clients. +- 📁 **Instant Folder-Based Chat Creation**: Start a new chat directly from any folder; just click and your new conversation is automatically embedded in the right project context—no more manual dragging or setup, saving time and eliminating mistakes. 
+- 🧩 **Prompt Variables with Automatic Input Modal**: Prompts containing variables now display a clean, auto-generated input modal that **autofocuses on the first field** for instant value entry—just select the prompt and fill in exactly what’s needed, reducing friction and guesswork. +- 🔡 **Variable Input Typing in Prompts**: Define input types for prompt variables (e.g., text, textarea, number, select, color, date, map and more), giving everyone a clearer and more precise prompt-building experience for advanced automation or workflows. +- 🚀 **Base Model List Caching**: Cache your base model list to speed up model selection and reduce repeated API calls; toggle this in Admin Settings > Connections for responsive model management even in large or multi-provider setups. +- ⏱️ **Configurable Model List Cache TTL**: Take control over model list caching with the new MODEL_LIST_CACHE_TTL environment variable. Set a custom cache duration in seconds to balance performance and freshness, reducing API requests in stable environments or ensuring rapid updates when models change frequently. +- 🔖 **Reference Notes as Knowledge or in Chats**: Use any note as knowledge for a model or folder, or reference it directly from chat—integrate living documentation into your Retrieval Augmented Generation workflows or discussions, bridging knowledge and action. +- 📝 **Chat Directly with Notes (Experimental)**: Ask questions about any note, and directly edit or update notes from within a chat—unlock direct AI-powered brainstorming, summarization, and cleanup, like having your own collaborative AI canvas. +- 🤝 **Collaborative Notes with Multi-User Editing**: Share notes with others and collaborate live—multiple users can edit a note in real-time, boosting cooperative knowledge building and workflow documentation. +- 🛡️ **Collaborative Note Permissions**: Control who can view or edit each note with robust sharing permissions, ensuring privacy or collaboration per your organizational needs. 
+- 🔗 **Copy Link to Notes**: Quickly copy and share direct links to notes for easier knowledge transfer within your team or external collaborators. +- 📋 **Task List Support in Notes**: Add, organize, and manage checklists or tasks inside your notes—plan projects, track to-dos, and keep everything actionable in a single space. +- 🧠 **AI-Generated Note Titles**: Instantly generate relevant and concise titles for your notes using AI—keep your knowledge library organized without tedious manual editing. +- 🔄 **Full Undo/Redo Support in Notes**: Effortlessly undo or redo your latest note changes—never fear mistakes or accidental edits while collaborating or writing. +- 📝 **Enhanced Note Word/Character Counter**: Always know the size of your notes with built-in counters, making it easier to adhere to length guidelines for shared or published content. +- 🖊️ **Floating & Bubble Formatting Menus in Note Editor**: Access text formatting tools through both a floating menu and an intuitive bubble menu directly in the note editor—making rich text editing faster, more discoverable, and easier than ever. +- ✍️ **Rich Text Prompt Insertion**: A new setting allows prompts to be inserted directly into the chat box as fully-formatted rich text, preserving Markdown elements like headings, lists, and bold text for a more intuitive and visually consistent editing experience. +- 🌐 **Configurable Database URL**: WebUI now supports more flexible database configuration via new environment variables—making deployment and scaling simpler across various infrastructure setups. +- 🎛️ **Completely Frontend-Handled File Upload in Temporary Chats**: When using temporary chats, file extraction now occurs fully in your browser with zero files sent to the backend, further strengthening privacy and giving you instant feedback. 
+- 🔄 **Enhanced Banner and Chat Command Visibility**: Banner handling and command feedback in chat are now clearer and more contextually visible, making alerts, suggestions, and automation easier to spot and interact with for all users. +- 📱 **Mobile Experience Polished**: The "new chat" button is back in mobile, plus core navigation and input controls have been smoothed out for better usability on phones and tablets. +- 📄 **OpenDocument Text (.odt) Support**: Seamlessly upload and process .odt files from open-source office suites like LibreOffice and OpenOffice, expanding your ability to build knowledge from a wider range of document formats. +- 📑 **Enhanced Markdown Document Splitting**: Improve knowledge retrieval from Markdown files with a new header-aware splitting strategy. This method intelligently chunks documents based on their header structure, preserving the original context and hierarchy for more accurate and relevant RAG results. +- 📚 **Full Context Mode for Knowledge Bases**: When adding a knowledge base to a folder or custom model, you can now toggle full context mode for the entire knowledge base. This bypasses the usual chunking and retrieval process, making it perfect for leaner knowledge bases. +- 🕰️ **Configurable OAuth Timeout**: Enhance login reliability by setting a custom timeout (OAUTH_TIMEOUT) for all OAuth providers (Google, Microsoft, GitHub, OIDC), preventing authentication failures on slow or restricted networks. +- 🎨 **Accessibility & High-Contrast Theme Enhancements**: Major accessibility overhaul with significant updates to the high-contrast theme. Improved focus visibility, ARIA labels, and semantic HTML ensure core components like the chat interface and model selector are fully compliant and readable for visually impaired users. 
+- ↕️ **Resizable System Prompt Fields**: Conveniently resize system prompt input fields to comfortably view and edit lengthy or complex instructions, improving the user experience for advanced model configuration. +- 🔧 **Granular Update Check Control**: Gain finer control over outbound connections with the new ENABLE_VERSION_UPDATE_CHECK flag. This allows administrators to disable version update checks independently of the full OFFLINE_MODE, perfect for environments with restricted internet access that still need to download embedding models. +- 🗃️ **Configurable Qdrant Collection Prefix**: Enhance scalability by setting a custom QDRANT_COLLECTION_PREFIX. This allows multiple Open WebUI instances to share a single Qdrant cluster safely, ensuring complete data isolation between separate deployments without conflicts. +- ⚙️ **Improved Default Database Performance**: Enhanced out-of-the-box performance by setting smarter database connection pooling defaults, reducing API response times for users on non-SQLite databases without requiring manual configuration. +- 🔧 **Configurable Redis Key Prefix**: Added support for the REDIS_KEY_PREFIX environment variable, allowing multiple Open WebUI instances to share a Redis cluster with isolated key namespaces for improved multi-tenancy. +- ➡️ **Forward User Context to Reranker**: For advanced RAG integrations, user information (ID, name, email, role) can now be forwarded as HTTP headers to external reranking services, enabling personalized results or per-user access control. +- ⚙️ **PGVector Connection Pooling**: Enhance performance and stability for PGVector-based RAG by enabling and configuring the database connection pool. New environment variables allow fine-tuning of pool size, timeout, and overflow settings to handle high-concurrency workloads efficiently. 
+- ⚙️ **General Backend Refactoring**: Extensive refactoring delivers a faster, more reliable, and robust backend experience—improving chat speed, model management, and day-to-day reliability. +- 🌍 **Expanded & Improved Translations**: Enjoy a more accessible and intuitive experience thanks to comprehensive updates and enhancements for Chinese (Simplified and Traditional), German, French, Catalan, Irish, and Spanish translations throughout the interface. + +### Fixed + +- 🛠️ **Rich Text Input Stability and Performance**: Multiple improvements ensure faster, cleaner text editing and rendering with reduced glitches—especially supporting links, color picking, checkbox controls, and code blocks in notes and chats. +- 📷 **Seamless iPhone Image Uploads**: Effortlessly upload photos from iPhones and other devices using HEIC format—images are now correctly recognized and processed, eliminating compatibility issues. +- 🔄 **Audio MIME Type Registration**: Issues with audio file content types have been resolved, guaranteeing smoother, error-free uploads and playback for transcription or note attachments. +- 🖍️ **Input Commands Now Always Visible**: Input commands (like prompts or knowledge) dynamically adjust their height on small screens, ensuring nothing is cut off and every tool remains easily accessible. +- 🛑 **Tool Result Rendering**: Fixed display problems with tool results, providing fast, clear feedback when using external or internal tools. +- 🗂️ **Table Alignment in Markdown**: Markdown tables are now rendered and aligned as expected, keeping reports and documentation readable. +- 🖼️ **Thread Image Handling**: Fixed an issue where messages containing only images in threads weren’t displayed correctly. +- 🗝️ **Note Access Control Security**: Tightened access control logic for notes to guarantee that shared or collaborative notes respect all user permissions and privacy safeguards. 
+- 🧾 **Ollama API Compatibility**: Fixed model parameter naming in the API to ensure uninterrupted compatibility for all Ollama endpoints. +- 🛠️ **Detection for 'text/html' Files**: Files loaded with docling/tika are now reliably detected as the correct type, improving knowledge ingestion and document parsing. +- 🔐 **OAuth Login Stability**: Resolved a critical OAuth bug that caused login failures on subsequent attempts after logging out. The user session is now completely cleared on logout, ensuring reliable and secure authentication across all supported providers (Google, Microsoft, GitHub, OIDC). +- 🚪 **OAuth Logout and Redirect Reliability**: The OAuth logout process has been made more robust. Logout requests now correctly use proxy environment variables, ensuring they succeed in corporate networks. Additionally, the custom WEBUI_AUTH_SIGNOUT_REDIRECT_URL is now properly respected for all OAuth/OIDC configurations, ensuring a seamless sign-out experience. +- 📜 **Banner Newline Rendering**: Banners now correctly render newline characters, ensuring that multi-line announcements and messages are displayed with their intended formatting. +- ℹ️ **Consistent Model Description Rendering**: Model descriptions now render Markdown correctly in the main chat interface, matching the formatting seen in the model selection dropdown for a consistent user experience. +- 🔄 **Offline Mode Update Check Display**: Corrected a UI bug where the "Checking for Updates..." message would display indefinitely when the application was set to offline mode. +- 🛠️ **Tool Result Encoding**: Fixed a bug where tool calls returning non-ASCII characters would fail, ensuring robust handling of international text and special characters in tool outputs. 
+ +## [0.6.15] - 2025-06-16 + +### Added + +- 🖼️ **Global Image Compression Option**: Effortlessly set image compression globally so all image uploads and outputs are optimized, speeding up load times and saving bandwidth—perfect for teams dealing with large files or limited network resources. +- 🎤 **Custom Speech-to-Text Content-Type for Transcription**: Define custom content types for audio transcription, ensuring compatibility with diverse audio sources and unlocking smoother, more accurate transcriptions in advanced setups. +- 🗂️ **LDAP Group Synchronization (Experimental)**: Automatically sync user groups from your LDAP directory directly into Open WebUI for seamless enterprise access management—simplifies identity integration and governance across your organization. +- 📈 **OpenTelemetry Metrics via OTLP Exporter (Experimental)**: Gain enterprise-grade analytics and monitor your AI usage in real time with experimental OpenTelemetry Metrics support—connect to any OTLP-compatible backend for instant insights into performance, load, and user interactions. +- 🕰️ **See User Message Timestamps on Hover (Chat Bubble UI)**: Effortlessly check when any user message was sent by hovering over it in Chat Bubble mode—no more switching screens or digging through logs for context. +- 🗂️ **Leaderboard Sorting Options**: Sort the leaderboard directly in the UI for a clearer, more actionable view of top performers, models, or tools—making analysis and recognition quick and easy for teams. +- 🏆 **Evaluation Details Modal in Feedbacks and Leaderboard**: Dive deeper with new modals that display detailed evaluation information when reviewing feedbacks and leaderboard rankings—accelerates learning, progress tracking, and quality improvement. +- 🔄 **Support for Multiple Pages in External Document Loaders**: Effortlessly extract and work with content spanning multiple pages in external documents, giving you complete flexibility for in-depth research and document workflows. 
+- 🌐 **New Accessibility Enhancements Across the Interface**: Benefit from significant accessibility improvements—tab navigation, ARIA roles/labels, better high-contrast text/modes, accessible modals, and more—making Open WebUI more usable and equitable for everyone, including those using assistive technologies. +- ⚡ **Performance & Stability Upgrades Across Frontend and Backend**: Enjoy a smoother, more reliable experience with numerous behind-the-scenes optimizations and refactoring on both frontend and backend—resulting in faster load times, fewer errors, and even greater stability throughout your workflows. +- 🌏 **Updated and Expanded Localizations**: Enjoy improved, up-to-date translations for Finnish, German (now with model pinning features), Korean, Russian, Simplified Chinese, Spanish, and more—making every interaction smoother, clearer, and more intuitive for international users. + +### Fixed + +- 🦾 **Ollama Error Messages More Descriptive**: Receive clearer, more actionable error messages when something goes wrong with Ollama models—making troubleshooting and user support faster and more effective. +- 🌐 **Bypass Webloader Now Works as Expected**: Resolved an issue where the "bypass webloader" feature failed to function correctly, ensuring web search bypasses operate smoothly and reliably for lighter, faster query results. +- 🔍 **Prevent Redundant Documents in Citation List**: The expanded citation list no longer shows duplicate documents, offering a cleaner, easier-to-digest reference experience when reviewing sources in knowledge and research workflows. +- 🛡️ **Trusted Header Email Matching is Now Case-Insensitive**: Fixed a critical authentication issue where email case sensitivity could cause secure headers to mismatch, ensuring robust, seamless login and session management in all environments. 
+- ⚙️ **Direct Tool Server Input Accepts Empty Strings**: You can now submit direct tool server commands without unexpected errors when passing empty-string values, improving integration and automation efficiency. +- 📄 **Citation Page Number for Page 1 is Now Displayed**: Corrected an oversight where references for page 1 documents were missing the page number; citations are now always accurate and fully visible. +- 📒 **Notes Access Restored**: Fixed an issue where some users could not access their notes—everyone can now view and manage their notes reliably, ensuring seamless documentation and workflow continuity. +- 🛑 **OAuth Callback Double-Slash Issue Resolved**: Fixed rare cases where an extra slash in OAuth callbacks caused failed logins or redirects, making third-party login integrations more reliable. + +### Changed + +- 🔑 **Dedicated Permission for System Prompts**: System prompt access is now controlled by its own specific permission instead of being grouped with general chat controls, empowering admins with finer-grained management over who can view or modify system prompts for enhanced security and workflow customization. +- 🛠️ **YouTube Transcript API and python-pptx Updated**: Enjoy better performance, reliability, and broader compatibility thanks to underlying library upgrades—less friction with media-rich and presentation workflows. + +### Removed + +- 🗑️ **Console Logging Disabled in Production**: All 'console.log' and 'console.debug' statements are now disabled in production, guaranteeing improved security and cleaner browser logs for end users by removing extraneous technical output. + +## [0.6.14] - 2025-06-10 + +### Added + +- 🤖 **Automatic "Follow Up" Suggestions**: Open WebUI now intelligently generates actionable "Follow Up" suggestions automatically with each message you send, helping you stay productive and inspired without interrupting your flow; you can always disable this in Settings if you prefer a distraction-free experience. 
+- 🧩 **OpenAI-Compatible Embeddings Endpoint**: Introducing a fully OpenAI-style '/api/embeddings' endpoint—now you can plug in OpenAI-style embeddings workflows with zero hassle, making integrations with external tools and platforms seamless and familiar. +- ↗️ **Model Pinning for Quick Access**: Pin your favorite or most-used models to the sidebar for instant selection—no more scrolling through long model lists; your go-to models are always visible and ready for fast access. +- 📌 **Selector Model Item Menu**: Each model in the selector now features a menu where you can easily pin/unpin to the sidebar and copy a direct link—simplifying collaboration and staying organized in even the busiest environments. +- 🛑 **Reliable Stop for Ongoing Chats in Multi-Replica Setups**: Stopping or cancelling an in-progress chat now works reliably even in clustered deployments—ensuring every user can interrupt AI output at any time, no matter your scale. +- 🧠 **'Think' Parameter for Ollama Models**: Leverage new 'think' parameter support for Ollama—giving you advanced control over the AI reasoning process and letting you further tune model behavior for your unique use cases. +- 💬 **Picture Description Modes for Docling**: Customize how images are described/extracted by Docling Loader for smarter, more detailed, and workflow-tailored image understanding in your document pipelines. +- 🛠 **Settings Modal Deep Linking**: Every tab in Settings now has its own route—making direct navigation and sharing of precise settings faster and more intuitive. +- 🎤 **Audio HTML Component Token**: Easily embed and play audio directly in your chats, improving voice-based workflows and making audio content instantly accessible and manageable from any conversation. +- 🔑 **Support for Secret Key File**: Now you can specify 'WEBUI_SECRET_KEY_FILE' for more secure and flexible key management—ideal for advanced deployments and tighter security standards. 
+- 💡 **Clarity When Cloning Prompts**: Cloned workspace prompts are clearly labelled with "(Clone)" and IDs have "-clone", keeping your prompt library organized and preventing accidental overwrites. +- 📝 **Dedicated User Role Edit Modal**: Updating user roles now reliably opens a dedicated edit user modal instead of cycling through roles—making it safer and clearer to manage team permissions. +- 🏞️ **Better Handling & Storage of Interpreter-Generated Images**: Code interpreter-generated images are now centrally stored and reliably loaded from the database or cloud storage, ensuring your artifacts are always available. +- 🚀 **Pinecone & Vector Search Optimizations**: Applied latest best practices from Pinecone for smarter timeouts, intelligent retry control, improved connection pooling, faster DNS, and concurrent batch handling—giving you more reliable, faster document search and RAG performance without manual tweaks. +- ⚙️ **Ollama Advanced Parameters Unified**: 'keep_alive' and 'format' options are now integrated into the advanced params section—edit everything from the model editor for flexible model control. +- 🛠️ **CUDA 12.6 Docker Image Support**: Deploy to NVIDIA GPUs with capability 7.0 and below (e.g., V100, GTX1080) via new cuda126 image—broadening your hardware options for scalable AI workloads. +- 🔒 **Experimental Table-Level PGVector Data Encryption**: Activate pgcrypto encryption support for pgvector to secure your vector search table contents, giving organizations enhanced compliance and data protection—perfect for enterprise or regulated environments. +- 👁 **Accessibility Upgrades Across Interface**: Chat buttons and close controls are now labelled and structured for optimal accessibility support, ensuring smoother operation with assistive technologies. +- 🎨 **High-Contrast Mode Expansions**: High-contrast accessibility mode now also applies to menu items, tabs, and search input fields, offering a more readable experience for all users. 
+- 🛠️ **Tooltip & Translation Clarity**: Improved translation and tooltip clarity, especially over radio buttons, making the UI more understandable for all users. +- 🔠 **Global Localization & Translation Improvements**: Hefty upgrades to Traditional Chinese, Simplified Chinese, Hebrew, Russian, Irish, German, and Danish translation packs—making the platform feel native and intuitive for even more users worldwide. +- ⚡ **General Backend Stability & Security Enhancements**: Refined numerous backend routines to minimize memory use, improve performance, and streamline integration with external APIs—making the entire platform more robust and secure for daily work. + +### Fixed + +- 🏷 **Feedback Score Display Improved**: Addressed overflow and visibility issues with feedback scores for more readable and accessible evaluations. +- 🗂 **Admin Settings Model Edits Apply Immediately**: Changes made in the Model Editor within Admin Settings now take effect instantly, eliminating confusion during model management. +- 🔄 **Assigned Tools Update Instantly on New Chats**: Models assigned with specific tools now consistently update and are available in every new chat—making tool workflows more predictable and robust. +- 🛠 **Document Settings Saved Only on User Action**: Document settings now save only when you press the Save button, reducing accidental changes and ensuring greater control. +- 🔊 **Voice Recording on Older iOS Devices Restored**: Voice input is now fully functional on older iOS devices, keeping voice workflows accessible to all users. +- 🔒 **Trusted Email Header Session Security**: User sessions now strictly verify the trusted email header matches the logged-in user's email, ensuring secure authentication and preventing accidental session switching. +- 🔒 **Consistent User Signout on Email Mismatch**: When the trusted email in the header changes, you will now be properly signed out and redirected, safeguarding your session's integrity. 
+- 🛠 **General Error & Content Validation Improvements**: Smarter error handling means clearer messages and fewer unnecessary retries—making batch uploads, document handling, and knowledge indexing more resilient. +- 🕵️ **Better Feedback on Chat Title Edits**: Error messages now show clearly if problems occur while editing chat titles. + +## [0.6.13] - 2025-05-30 + +### Added + +- 🟦 **Azure OpenAI Embedding Support**: You can now select Azure OpenAI endpoints for text embeddings, unlocking seamless integration with enterprise-scale Azure AI for powerful RAG and knowledge workflows—no more workarounds, connect and scale effortlessly. +- 🧩 **Smarter Custom Parameter Handling**: Instantly enjoy more flexible model setup—any JSON pasted into custom parameter fields is now parsed automatically, so you can define rich, nested parameters without tedious manual adjustment. This streamlines advanced configuration for all models and accelerates experimentation. +- ⚙️ **General Backend Refactoring**: Significant backend improvements deliver a cleaner codebase for better maintainability, faster performance, and even greater platform reliability—making all your workflows run more smoothly. +- 🌏 **Localization Upgrades**: Experience highly improved user interface translations and clarity in Simplified, Traditional Chinese, Korean, and Finnish, offering a more natural, accurate, and accessible experience for global users. + +### Fixed + +- 🛡️ **Robust Message Handling on Chat Load**: Fixed an issue where chat pages could fail to load if a referenced message was missing or undefined; now, chats always load smoothly and missing IDs no longer disrupt your workflow. +- 📝 **Correct Prompt Access Control**: Ensured that the prompt access controls register properly, restoring reliable permissioning and safeguarding your prompt workflows. 
+- 🛠 **Open WebUI-Specific Params No Longer Sent to Models**: Fixed a bug that sent internal WebUI parameters to APIs, ensuring only intended model options are transmitted—restoring predictable, error-free model operation. +- 🧠 **Refined Memory Error Handling**: Enhanced stability during memory-related operations, so even uncommon memory errors are gracefully managed without disrupting your session—resulting in a more reliable, worry-free experience. + +## [0.6.12] - 2025-05-29 + +### Added + +- 🧩 **Custom Advanced Model Parameters**: You can now add your own tailor-made advanced parameters to any model, empowering you to fine-tune behavior and unlock greater flexibility beyond just the built-in options—accelerate your experimentation. +- 🪧 **Datalab Marker API Content Extraction Support**: Seamlessly extract content from files and documents using the Datalab Marker API directly in your workflows, enabling more robust structured data extraction for RAG and document processing with just a simple engine switch in the UI. +- ⚡ **Parallelized Base Model Fetching**: Experience noticeably faster startup and model refresh times—base model data now loads in parallel, drastically shortening delays in busy or large-scale deployments. +- 🧠 **Efficient Function Loading and Caching**: Functions are now only reloaded if their content changes, preventing unnecessary duplicate loads, saving bandwidth, and boosting performance. +- 🌍 **Localization & Translation Enhancements**: Improved and expanded Simplified, Traditional Chinese, and Russian translations, providing smoother, more accurate, and context-aware experiences for global users. + +### Fixed + +- 💬 **Stable Message Input Box**: Fixed an issue where the message input box would shift unexpectedly (especially on mobile or with screen reader support), ensuring a smooth and reliable typing experience for every user. 
+- 🔊 **Reliable Read Aloud (Text-to-Speech)**: Read aloud now works seamlessly across messages, so users depending on TTS for accessibility or multitasking will experience uninterrupted and clear voice playback. +- 🖼 **Image Preview and Download Restored**: Fixed problems with image preview and downloads, ensuring frictionless creation, previewing, and downloading of images in your chats—no more interruptions in creative or documentation workflows. +- 📱 **Improved Mobile Styling for Workspace Capabilities**: Capabilities management is now readable and easy-to-use even on mobile devices, empowering admins and users to manage access quickly on the go. +- 🔁 **/api/v1/retrieval/query/collection Endpoint Reliability**: Queries to retrieval collections now return the expected results, bolstering the reliability of your knowledge workflows and citation-ready responses. + +### Removed + +- 🧹 **Duplicate CSS Elements**: Streamlined the UI by removing redundant CSS, reducing clutter and improving load times for a smoother visual experience. + +## [0.6.11] - 2025-05-27 + +### Added + +- 🟢 **Ollama Model Status Indicator in Model Selector**: Instantly see which Ollama models are currently loaded with a clear indicator in the model selector, helping you stay organized and optimize local model usage. +- 🗑️ **Unload Ollama Model Directly from Model Selector**: Easily release memory and resources by unloading any loaded Ollama model right in the model selector—streamline hardware management without switching pages. +- 🗣️ **User-Configurable Speech-to-Text Language Setting**: Improve transcription accuracy by letting individual users explicitly set their preferred STT language in their settings—ideal for multilingual teams and clear audio capture. +- ⚡ **Granular Audio Playback Speed Control**: Instead of just presets, you can now choose granular audio speed using a numeric input, giving you complete control over playback pace in transcriptions and media reviews. 
+- 📦 **GZip, Brotli, ZStd Compression Middleware**: Enjoy significantly faster page loads and reduced bandwidth usage with new server-side compression—giving users a snappier, more efficient experience. +- 🏷️ **Configurable Weight for BM25 in Hybrid Search**: Fine-tune search relevance by adjusting the weight for BM25 inside hybrid search from the UI, letting you tailor knowledge search results to your workflow. +- 🧪 **Bypass File Creation with CTRL + SHIFT + V**: When “Paste Large Text as File” is enabled, use CTRL + SHIFT + V to skip the file creation dialog and instantly upload text as a file—perfect for rapid document prep. +- 🌐 **Bypass Web Loader in Web Search**: Choose to bypass web content loading and use snippets directly in web search for faster, more reliable results when page loads are slow or blocked. +- 🚀 **Environment Variable: WEBUI_AUTH_TRUSTED_GROUPS_HEADER**: Now sync and manage user groups directly via trusted HTTP header, unlocking smoother single sign-on and identity integrations for organizations. +- 🏢 **Workspace Models Visibility Controls**: You can now hide workspace-level models from both the model selector and shared environments—keep your team focused and reduce clutter from rarely-used endpoints. +- 🛡️ **Copy Model Link**: You can now copy a direct link to any model—including those hidden from the selector—making sharing and onboarding others more seamless. +- 🔗 **Load Function Directly from URL**: Simplify custom function management—just paste any GitHub function URL into Open WebUI and import new functions in seconds. +- ⚙️ **Custom Name/Description for External Tool Servers**: Personalize and clarify external tool servers by assigning custom names and descriptions, making it easier to manage integrations in large-scale workspaces. +- 🌍 **Custom OpenAPI JSON URL Support for Tool Servers**: Supports specifying any custom OpenAPI JSON URL, unlocking more flexible integration with any backend for tool calls. 
+- 📊 **Source Field Now Displays in Non-Streaming Responses with Attachments**: When files or knowledge are attached, the "source" field now appears for all responses, even in non-streaming mode—enabling improved citation workflow. +- 🎛 **Pinned Chats**: Reduced payload size on pinned chat requests—leading to faster load times and less data usage, especially on busy instances. +- 🛠 **Import/Export Default Prompt Suggestions**: Enjoy one-click import/export of prompt suggestions, making it much easier to share, reuse, and manage best practices across teams or deployments. +- 🍰 **Banners Now Sortable from Admin Settings**: Quickly re-order or prioritize banners, letting you highlight the most critical info for your team. +- 🛠 **Advanced Chat Parameters—Clearer Ollama Support Labels**: Parameters and advanced settings now explicitly indicate if they are Ollama-specific, reducing confusion and improving setup accuracy. +- 🤏 **Scroll Bar Thumb Improved for Better Visibility**: Enhanced scrollbar styling makes navigation more accessible and visually intuitive. +- 🗄️ **Modal Redesign for Archived and User Chat Listings**: Clean, modern modal interface for browsing archived and user-specific chats makes locating conversations faster and more pleasant. +- 📝 **Add/Edit Memory Modal UX**: Memory modals are now larger and have resizable input fields, supporting easier editing of long or complex memory content. +- 🏆 **Translation & Localization Enhancements**: Major upgrades to Chinese (Simplified & Traditional), Korean, Russian, German, Danish, Finnish—not just fixing typos, but consistency, tone, and terminology for a more natural native-language experience. +- ⚡ **General Backend Stability & Security Enhancements**: Various backend refinements ensure a more resilient, reliable, and secure platform for smoother operation and peace of mind. 
+ +### Fixed + +- 🖼️ **Image Generation with Allowed File Extensions Now Works Reliably**: Ensure seamless image generation even when strict file extension rules are set—no more blocked creative workflows due to technical hiccups. +- 🗂 **Remove Leading Dot for File Extension Check**: Fixed an issue where file validation failed because of a leading dot, making file uploads and knowledge management more robust. +- 🏷️ **Correct Local/External Model Classification**: The platform now accurately distinguishes between local and external models—preventing local models from showing up as external (and vice versa)—ensuring seamless setup, clarity, and management of your AI model endpoints. +- 📄 **External Document Loader Now Functions as Intended**: External document loaders are reliably invoked, ensuring smoother knowledge ingestion from external sources—expanding your RAG and knowledge workflows. +- 🎯 **Correct Handling of Toggle Filters**: Toggle filters are now robustly managed, preventing accidental auto-activation and ensuring user preferences are always respected. +- 🗃 **S3 Tagging Character Restrictions Fixed**: Tags for files in S3 now automatically meet Amazon’s allowed character set, avoiding upload errors and ensuring cross-cloud compatibility. +- 🛡️ **Authentication Now Uses Password Hash When Duplicate Emails Exist**: Ensures account security and prevents access issues if duplicate emails are present in your system. + +### Changed + +- 🧩 **Admin Settings: OAuth Redirects Now Use WEBUI_URL**: The OAuth redirect URL is now based on the explicitly set WEBUI_URL, ensuring single sign-on and identity provider integrations always send users to the correct frontend. + +### Removed + +- 💡 **Duplicate/Typo Component Removals**: Obsolete components have been cleaned up, reducing confusion and improving overall code quality for the team. 
+- 🚫 **Streaming Upsert in Pinecone Removed**: Removed streaming upsert references for better compatibility and future-proofing with latest Pinecone SDK updates. + ## [0.6.10] - 2025-05-19 ### Added diff --git a/Dockerfile b/Dockerfile index d7de72f0155..ad393338d81 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,6 +3,8 @@ # use build args in the docker build command with --build-arg="BUILDARG=true" ARG USE_CUDA=false ARG USE_OLLAMA=false +ARG USE_SLIM=false +ARG USE_PERMISSION_HARDENING=false # Tested with cu117 for CUDA 11 and cu121 for CUDA 12 (default) ARG USE_CUDA_VER=cu128 # any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers @@ -24,13 +26,16 @@ ARG GID=0 FROM --platform=$BUILDPLATFORM node:22-alpine3.20 AS build ARG BUILD_HASH +# Set Node.js options (heap limit Allocation failed - JavaScript heap out of memory) +# ENV NODE_OPTIONS="--max-old-space-size=4096" + WORKDIR /app # to store git revision in build RUN apk add --no-cache git COPY package.json package-lock.json ./ -RUN npm ci +RUN npm ci --force COPY . . 
ENV APP_BUILD_HASH=${BUILD_HASH} @@ -43,6 +48,8 @@ FROM python:3.11-slim-bookworm AS base ARG USE_CUDA ARG USE_OLLAMA ARG USE_CUDA_VER +ARG USE_SLIM +ARG USE_PERMISSION_HARDENING ARG USE_EMBEDDING_MODEL ARG USE_RERANKING_MODEL ARG UID @@ -54,6 +61,7 @@ ENV ENV=prod \ # pass build args to the build USE_OLLAMA_DOCKER=${USE_OLLAMA} \ USE_CUDA_DOCKER=${USE_CUDA} \ + USE_SLIM_DOCKER=${USE_SLIM} \ USE_CUDA_DOCKER_VER=${USE_CUDA_VER} \ USE_EMBEDDING_MODEL_DOCKER=${USE_EMBEDDING_MODEL} \ USE_RERANKING_MODEL_DOCKER=${USE_RERANKING_MODEL} @@ -108,29 +116,13 @@ RUN echo -n 00000000-0000-0000-0000-000000000000 > $HOME/.cache/chroma/telemetry # Make sure the user has access to the app and root directory RUN chown -R $UID:$GID /app $HOME -RUN if [ "$USE_OLLAMA" = "true" ]; then \ - apt-get update && \ - # Install pandoc and netcat - apt-get install -y --no-install-recommends git build-essential pandoc netcat-openbsd curl && \ - apt-get install -y --no-install-recommends gcc python3-dev && \ - # for RAG OCR - apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \ - # install helper tools - apt-get install -y --no-install-recommends curl jq && \ - # install ollama - curl -fsSL https://ollama.com/install.sh | sh && \ - # cleanup - rm -rf /var/lib/apt/lists/*; \ - else \ - apt-get update && \ - # Install pandoc, netcat and gcc - apt-get install -y --no-install-recommends git build-essential pandoc gcc netcat-openbsd curl jq && \ - apt-get install -y --no-install-recommends gcc python3-dev && \ - # for RAG OCR - apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \ - # cleanup - rm -rf /var/lib/apt/lists/*; \ - fi +# Install common system dependencies +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + git build-essential pandoc gcc netcat-openbsd curl jq \ + python3-dev \ + ffmpeg libsm6 libxext6 \ + && rm -rf /var/lib/apt/lists/* # install python dependencies COPY --chown=$UID:$GID ./backend/requirements.txt ./requirements.txt 
@@ -146,13 +138,22 @@ RUN pip3 install --no-cache-dir uv && \ else \ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \ uv pip install --system -r requirements.txt --no-cache-dir && \ + if [ "$USE_SLIM" != "true" ]; then \ python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \ python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \ python -c "import os; import tiktoken; tiktoken.get_encoding(os.environ['TIKTOKEN_ENCODING_NAME'])"; \ fi; \ - chown -R $UID:$GID /app/backend/data/ - + fi; \ + mkdir -p /app/backend/data && chown -R $UID:$GID /app/backend/data/ && \ + rm -rf /var/lib/apt/lists/*; +# Install Ollama if requested +RUN if [ "$USE_OLLAMA" = "true" ]; then \ + date +%s > /tmp/ollama_build_hash && \ + echo "Cache broken at timestamp: `cat /tmp/ollama_build_hash`" && \ + curl -fsSL https://ollama.com/install.sh | sh && \ + rm -rf /var/lib/apt/lists/*; \ + fi # copy embedding weight from build # RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2 @@ -170,6 +171,17 @@ EXPOSE 8080 HEALTHCHECK CMD curl --silent --fail http://localhost:${PORT:-8080}/health | jq -ne 'input.status == true' || exit 1 +# Minimal, atomic permission hardening for OpenShift (arbitrary UID): +# - Group 0 owns /app and /root +# - Directories are group-writable and have SGID so new files inherit GID 0 +RUN if [ "$USE_PERMISSION_HARDENING" = "true" ]; then \ + set -eux; \ + chgrp -R 0 /app /root || true; \ + chmod -R g+rwX /app /root || true; \ + find /app -type d -exec chmod g+s {} + || true; \ + find /root -type d -exec chmod g+s {} + || true; \ + fi + USER $UID:$GID ARG BUILD_HASH diff --git a/LICENSE_HISTORY b/LICENSE_HISTORY new file mode 100644 index 00000000000..a9eb5e259d6 --- /dev/null 
+++ b/LICENSE_HISTORY @@ -0,0 +1,53 @@ +All code and materials created before commit `60d84a3aae9802339705826e9095e272e3c83623` are subject to the following copyright and license: + +Copyright (c) 2023-2025 Timothy Jaeryang Baek +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +All code and materials created before commit `a76068d69cd59568b920dfab85dc573dbbb8f131` are subject to the following copyright and license: + +MIT License + +Copyright (c) 2023 Timothy Jaeryang Baek + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/LICENSE_NOTICE b/LICENSE_NOTICE new file mode 100644 index 00000000000..4e00d46d9ac --- /dev/null +++ b/LICENSE_NOTICE @@ -0,0 +1,11 @@ +# Open WebUI Multi-License Notice + +This repository contains code governed by multiple licenses based on the date and origin of contribution: + +1. All code committed prior to commit a76068d69cd59568b920dfab85dc573dbbb8f131 is licensed under the MIT License (see LICENSE_HISTORY). + +2. All code committed from commit a76068d69cd59568b920dfab85dc573dbbb8f131 up to and including commit 60d84a3aae9802339705826e9095e272e3c83623 is licensed under the BSD 3-Clause License (see LICENSE_HISTORY). + +3. 
All code contributed or modified after commit 60d84a3aae9802339705826e9095e272e3c83623 is licensed under the Open WebUI License (see LICENSE). + +For details on which commits are covered by which license, refer to LICENSE_HISTORY. diff --git a/README.md b/README.md index 8445b5a3921..49c0a8d9d3e 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,8 @@ **Open WebUI is an [extensible](https://docs.openwebui.com/features/plugin/), feature-rich, and user-friendly self-hosted AI platform designed to operate entirely offline.** It supports various LLM runners like **Ollama** and **OpenAI-compatible APIs**, with **built-in inference engine** for RAG, making it a **powerful AI deployment solution**. +Passionate about open-source AI? [Join our team →](https://careers.openwebui.com/) + ![Open WebUI Demo](./demo.gif) > [!TIP] @@ -29,6 +31,8 @@ For more information, be sure to check out our [Open WebUI Documentation](https: - 🛡️ **Granular Permissions and User Groups**: By allowing administrators to create detailed user roles and permissions, we ensure a secure user environment. This granularity not only enhances security but also allows for customized user experiences, fostering a sense of ownership and responsibility amongst users. +- 🔄 **SCIM 2.0 Support**: Enterprise-grade user and group provisioning through SCIM 2.0 protocol, enabling seamless integration with identity providers like Okta, Azure AD, and Google Workspace for automated user lifecycle management. + - 📱 **Responsive Design**: Enjoy a seamless experience across Desktop PC, Laptop, and Mobile devices. - 📱 **Progressive Web App (PWA) for Mobile**: Enjoy a native app-like experience on your mobile device with our PWA, providing offline access on localhost and a seamless user interface. @@ -66,14 +70,34 @@ Want to learn more about Open WebUI's features? Check out our [Open WebUI docume #### Emerald - + + + + + + + +
+ + Tailscale + + + Tailscale • Connect self-hosted AI to any device with Tailscale +
+ + Warp + + + Warp • The intelligent terminal for developers
@@ -171,6 +195,8 @@ After installation, you can access Open WebUI at [http://localhost:3000](http:// We offer various installation alternatives, including non-Docker native installation methods, Docker Compose, Kustomize, and Helm. Visit our [Open WebUI Documentation](https://docs.openwebui.com/getting-started/) or join our [Discord community](https://discord.gg/5rJgQTnV4s) for comprehensive guidance. +Look at the [Local Development Guide](https://docs.openwebui.com/getting-started/advanced-topics/development) for instructions on setting up a local development environment. + ### Troubleshooting Encountering connection issues? Our [Open WebUI Documentation](https://docs.openwebui.com/troubleshooting/) has got you covered. For further assistance and to join our vibrant community, visit the [Open WebUI Discord](https://discord.gg/5rJgQTnV4s). @@ -222,7 +248,7 @@ Discover upcoming features on our roadmap in the [Open WebUI Documentation](http ## License 📜 -This project is licensed under the [Open WebUI License](LICENSE), a revised BSD-3-Clause license. You receive all the same rights as the classic BSD-3 license: you can use, modify, and distribute the software, including in proprietary and commercial products, with minimal restrictions. The only additional requirement is to preserve the "Open WebUI" branding, as detailed in the LICENSE file. For full terms, see the [LICENSE](LICENSE) document. 📄 +This project contains code under multiple licenses. The current codebase includes components licensed under the Open WebUI License with an additional requirement to preserve the "Open WebUI" branding, as well as prior contributions under their respective original licenses. For a detailed record of license changes and the applicable terms for each section of the code, please refer to [LICENSE_HISTORY](./LICENSE_HISTORY). For complete and updated licensing details, please see the [LICENSE](./LICENSE) and [LICENSE_HISTORY](./LICENSE_HISTORY) files. 
## Support 💬 diff --git a/backend/dev.sh b/backend/dev.sh index 5449ab77777..042fbd9efa1 100755 --- a/backend/dev.sh +++ b/backend/dev.sh @@ -1,2 +1,3 @@ +export CORS_ALLOW_ORIGIN="http://localhost:5173;http://localhost:8080" PORT="${PORT:-8080}" -uvicorn open_webui.main:app --port $PORT --host 0.0.0.0 --forwarded-allow-ips '*' --reload \ No newline at end of file +uvicorn open_webui.main:app --port $PORT --host 0.0.0.0 --forwarded-allow-ips '*' --reload diff --git a/backend/open_webui/alembic.ini b/backend/open_webui/alembic.ini index 4eff85f0c62..dccd8a3c123 100644 --- a/backend/open_webui/alembic.ini +++ b/backend/open_webui/alembic.ini @@ -10,7 +10,7 @@ script_location = migrations # sys.path path, will be prepended to sys.path if present. # defaults to the current working directory. -prepend_sys_path = . +prepend_sys_path = .. # timezone to use when rendering the date within the migration file # as well as the filename. diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index b1955b056d2..f7926abe85b 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -7,18 +7,21 @@ from datetime import datetime from pathlib import Path -from typing import Generic, Optional, TypeVar +from typing import Generic, Union, Optional, TypeVar from urllib.parse import urlparse import requests from pydantic import BaseModel from sqlalchemy import JSON, Column, DateTime, Integer, func +from authlib.integrations.starlette_client import OAuth + from open_webui.env import ( DATA_DIR, DATABASE_URL, ENV, REDIS_URL, + REDIS_KEY_PREFIX, REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT, FRONTEND_BUILD_DIR, @@ -165,9 +168,19 @@ def __init__(self, env_name: str, config_path: str, env_value: T): self.config_path = config_path self.env_value = env_value self.config_value = get_config_value(config_path) + if self.config_value is not None and ENABLE_PERSISTENT_CONFIG: - log.info(f"'{env_name}' loaded from the latest database entry") - self.value = 
self.config_value + if ( + self.config_path.startswith("oauth.") + and not ENABLE_OAUTH_PERSISTENT_CONFIG + ): + log.info( + f"Skipping loading of '{env_name}' as OAuth persistent config is disabled" + ) + self.value = env_value + else: + log.info(f"'{env_name}' loaded from the latest database entry") + self.value = self.config_value else: self.value = env_value @@ -209,19 +222,32 @@ def save(self): class AppConfig: + _redis: Union[redis.Redis, redis.cluster.RedisCluster] = None + _redis_key_prefix: str + _state: dict[str, PersistentConfig] - _redis: Optional[redis.Redis] = None def __init__( - self, redis_url: Optional[str] = None, redis_sentinels: Optional[list] = [] + self, + redis_url: Optional[str] = None, + redis_sentinels: Optional[list] = [], + redis_cluster: Optional[bool] = False, + redis_key_prefix: str = "open-webui", ): - super().__setattr__("_state", {}) if redis_url: + super().__setattr__("_redis_key_prefix", redis_key_prefix) super().__setattr__( "_redis", - get_redis_connection(redis_url, redis_sentinels, decode_responses=True), + get_redis_connection( + redis_url, + redis_sentinels, + redis_cluster, + decode_responses=True, + ), ) + super().__setattr__("_state", {}) + def __setattr__(self, key, value): if isinstance(value, PersistentConfig): self._state[key] = value @@ -230,7 +256,7 @@ def __setattr__(self, key, value): self._state[key].save() if self._redis: - redis_key = f"open-webui:config:{key}" + redis_key = f"{self._redis_key_prefix}:config:{key}" self._redis.set(redis_key, json.dumps(self._state[key].value)) def __getattr__(self, key): @@ -239,7 +265,7 @@ def __getattr__(self, key): # If Redis is available, check for an updated value if self._redis: - redis_key = f"open-webui:config:{key}" + redis_key = f"{self._redis_key_prefix}:config:{key}" redis_value = self._redis.get(redis_key) if redis_value is not None: @@ -281,13 +307,22 @@ def __getattr__(self, key): JWT_EXPIRES_IN = PersistentConfig( - "JWT_EXPIRES_IN", "auth.jwt_expiry", 
os.environ.get("JWT_EXPIRES_IN", "-1") + "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "4w") ) +if JWT_EXPIRES_IN.value == "-1": + log.warning( + "⚠️ SECURITY WARNING: JWT_EXPIRES_IN is set to '-1'\n" + " See: https://docs.openwebui.com/getting-started/env-configuration\n" + ) + #################################### # OAuth config #################################### +ENABLE_OAUTH_PERSISTENT_CONFIG = ( + os.environ.get("ENABLE_OAUTH_PERSISTENT_CONFIG", "False").lower() == "true" +) ENABLE_OAUTH_SIGNUP = PersistentConfig( "ENABLE_OAUTH_SIGNUP", @@ -347,6 +382,24 @@ def __getattr__(self, key): os.environ.get("MICROSOFT_CLIENT_TENANT_ID", ""), ) +MICROSOFT_CLIENT_LOGIN_BASE_URL = PersistentConfig( + "MICROSOFT_CLIENT_LOGIN_BASE_URL", + "oauth.microsoft.login_base_url", + os.environ.get( + "MICROSOFT_CLIENT_LOGIN_BASE_URL", "https://login.microsoftonline.com" + ), +) + +MICROSOFT_CLIENT_PICTURE_URL = PersistentConfig( + "MICROSOFT_CLIENT_PICTURE_URL", + "oauth.microsoft.picture_url", + os.environ.get( + "MICROSOFT_CLIENT_PICTURE_URL", + "https://graph.microsoft.com/v1.0/me/photo/$value", + ), +) + + MICROSOFT_OAUTH_SCOPE = PersistentConfig( "MICROSOFT_OAUTH_SCOPE", "oauth.microsoft.scope", @@ -413,6 +466,18 @@ def __getattr__(self, key): os.environ.get("OAUTH_SCOPES", "openid email profile"), ) +OAUTH_TIMEOUT = PersistentConfig( + "OAUTH_TIMEOUT", + "oauth.oidc.oauth_timeout", + os.environ.get("OAUTH_TIMEOUT", ""), +) + +OAUTH_TOKEN_ENDPOINT_AUTH_METHOD = PersistentConfig( + "OAUTH_TOKEN_ENDPOINT_AUTH_METHOD", + "oauth.oidc.token_endpoint_auth_method", + os.environ.get("OAUTH_TOKEN_ENDPOINT_AUTH_METHOD", None), +) + OAUTH_CODE_CHALLENGE_METHOD = PersistentConfig( "OAUTH_CODE_CHALLENGE_METHOD", "oauth.oidc.code_challenge_method", @@ -425,6 +490,12 @@ def __getattr__(self, key): os.environ.get("OAUTH_PROVIDER_NAME", "SSO"), ) +OAUTH_SUB_CLAIM = PersistentConfig( + "OAUTH_SUB_CLAIM", + "oauth.oidc.sub_claim", + os.environ.get("OAUTH_SUB_CLAIM", 
None), +) + OAUTH_USERNAME_CLAIM = PersistentConfig( "OAUTH_USERNAME_CLAIM", "oauth.oidc.username_claim", @@ -447,7 +518,31 @@ def __getattr__(self, key): OAUTH_GROUPS_CLAIM = PersistentConfig( "OAUTH_GROUPS_CLAIM", "oauth.oidc.group_claim", - os.environ.get("OAUTH_GROUP_CLAIM", "groups"), + os.environ.get("OAUTH_GROUPS_CLAIM", os.environ.get("OAUTH_GROUP_CLAIM", "groups")), +) + +FEISHU_CLIENT_ID = PersistentConfig( + "FEISHU_CLIENT_ID", + "oauth.feishu.client_id", + os.environ.get("FEISHU_CLIENT_ID", ""), +) + +FEISHU_CLIENT_SECRET = PersistentConfig( + "FEISHU_CLIENT_SECRET", + "oauth.feishu.client_secret", + os.environ.get("FEISHU_CLIENT_SECRET", ""), +) + +FEISHU_OAUTH_SCOPE = PersistentConfig( + "FEISHU_OAUTH_SCOPE", + "oauth.feishu.scope", + os.environ.get("FEISHU_OAUTH_SCOPE", "contact:user.base:readonly"), +) + +FEISHU_REDIRECT_URI = PersistentConfig( + "FEISHU_REDIRECT_URI", + "oauth.feishu.redirect_uri", + os.environ.get("FEISHU_REDIRECT_URI", ""), ) ENABLE_OAUTH_ROLE_MANAGEMENT = PersistentConfig( @@ -516,15 +611,23 @@ def load_oauth_providers(): OAUTH_PROVIDERS.clear() if GOOGLE_CLIENT_ID.value and GOOGLE_CLIENT_SECRET.value: - def google_oauth_register(client): - client.register( + def google_oauth_register(oauth: OAuth): + client = oauth.register( name="google", client_id=GOOGLE_CLIENT_ID.value, client_secret=GOOGLE_CLIENT_SECRET.value, server_metadata_url="https://accounts.google.com/.well-known/openid-configuration", - client_kwargs={"scope": GOOGLE_OAUTH_SCOPE.value}, + client_kwargs={ + "scope": GOOGLE_OAUTH_SCOPE.value, + **( + {"timeout": int(OAUTH_TIMEOUT.value)} + if OAUTH_TIMEOUT.value + else {} + ), + }, redirect_uri=GOOGLE_REDIRECT_URI.value, ) + return client OAUTH_PROVIDERS["google"] = { "redirect_uri": GOOGLE_REDIRECT_URI.value, @@ -537,28 +640,34 @@ def google_oauth_register(client): and MICROSOFT_CLIENT_TENANT_ID.value ): - def microsoft_oauth_register(client): - client.register( + def microsoft_oauth_register(oauth: OAuth): + client 
= oauth.register( name="microsoft", client_id=MICROSOFT_CLIENT_ID.value, client_secret=MICROSOFT_CLIENT_SECRET.value, - server_metadata_url=f"https://login.microsoftonline.com/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration?appid={MICROSOFT_CLIENT_ID.value}", + server_metadata_url=f"{MICROSOFT_CLIENT_LOGIN_BASE_URL.value}/{MICROSOFT_CLIENT_TENANT_ID.value}/v2.0/.well-known/openid-configuration?appid={MICROSOFT_CLIENT_ID.value}", client_kwargs={ "scope": MICROSOFT_OAUTH_SCOPE.value, + **( + {"timeout": int(OAUTH_TIMEOUT.value)} + if OAUTH_TIMEOUT.value + else {} + ), }, redirect_uri=MICROSOFT_REDIRECT_URI.value, ) + return client OAUTH_PROVIDERS["microsoft"] = { "redirect_uri": MICROSOFT_REDIRECT_URI.value, - "picture_url": "https://graph.microsoft.com/v1.0/me/photo/$value", + "picture_url": MICROSOFT_CLIENT_PICTURE_URL.value, "register": microsoft_oauth_register, } if GITHUB_CLIENT_ID.value and GITHUB_CLIENT_SECRET.value: - def github_oauth_register(client): - client.register( + def github_oauth_register(oauth: OAuth): + client = oauth.register( name="github", client_id=GITHUB_CLIENT_ID.value, client_secret=GITHUB_CLIENT_SECRET.value, @@ -566,9 +675,17 @@ def github_oauth_register(client): authorize_url="https://github.com/login/oauth/authorize", api_base_url="https://api.github.com", userinfo_endpoint="https://api.github.com/user", - client_kwargs={"scope": GITHUB_CLIENT_SCOPE.value}, + client_kwargs={ + "scope": GITHUB_CLIENT_SCOPE.value, + **( + {"timeout": int(OAUTH_TIMEOUT.value)} + if OAUTH_TIMEOUT.value + else {} + ), + }, redirect_uri=GITHUB_CLIENT_REDIRECT_URI.value, ) + return client OAUTH_PROVIDERS["github"] = { "redirect_uri": GITHUB_CLIENT_REDIRECT_URI.value, @@ -578,13 +695,23 @@ def github_oauth_register(client): if ( OAUTH_CLIENT_ID.value - and OAUTH_CLIENT_SECRET.value + and (OAUTH_CLIENT_SECRET.value or OAUTH_CODE_CHALLENGE_METHOD.value) and OPENID_PROVIDER_URL.value ): - def oidc_oauth_register(client): + def 
oidc_oauth_register(oauth: OAuth): client_kwargs = { "scope": OAUTH_SCOPES.value, + **( + { + "token_endpoint_auth_method": OAUTH_TOKEN_ENDPOINT_AUTH_METHOD.value + } + if OAUTH_TOKEN_ENDPOINT_AUTH_METHOD.value + else {} + ), + **( + {"timeout": int(OAUTH_TIMEOUT.value)} if OAUTH_TIMEOUT.value else {} + ), } if ( @@ -598,7 +725,7 @@ def oidc_oauth_register(client): % ("S256", OAUTH_CODE_CHALLENGE_METHOD.value) ) - client.register( + client = oauth.register( name="oidc", client_id=OAUTH_CLIENT_ID.value, client_secret=OAUTH_CLIENT_SECRET.value, @@ -606,6 +733,7 @@ def oidc_oauth_register(client): client_kwargs=client_kwargs, redirect_uri=OPENID_REDIRECT_URI.value, ) + return client OAUTH_PROVIDERS["oidc"] = { "name": OAUTH_PROVIDER_NAME.value, @@ -613,6 +741,53 @@ def oidc_oauth_register(client): "register": oidc_oauth_register, } + if FEISHU_CLIENT_ID.value and FEISHU_CLIENT_SECRET.value: + + def feishu_oauth_register(oauth: OAuth): + client = oauth.register( + name="feishu", + client_id=FEISHU_CLIENT_ID.value, + client_secret=FEISHU_CLIENT_SECRET.value, + access_token_url="https://open.feishu.cn/open-apis/authen/v2/oauth/token", + authorize_url="https://accounts.feishu.cn/open-apis/authen/v1/authorize", + api_base_url="https://open.feishu.cn/open-apis", + userinfo_endpoint="https://open.feishu.cn/open-apis/authen/v1/user_info", + client_kwargs={ + "scope": FEISHU_OAUTH_SCOPE.value, + **( + {"timeout": int(OAUTH_TIMEOUT.value)} + if OAUTH_TIMEOUT.value + else {} + ), + }, + redirect_uri=FEISHU_REDIRECT_URI.value, + ) + return client + + OAUTH_PROVIDERS["feishu"] = { + "register": feishu_oauth_register, + "sub_claim": "user_id", + } + + configured_providers = [] + if GOOGLE_CLIENT_ID.value: + configured_providers.append("Google") + if MICROSOFT_CLIENT_ID.value: + configured_providers.append("Microsoft") + if GITHUB_CLIENT_ID.value: + configured_providers.append("GitHub") + if FEISHU_CLIENT_ID.value: + configured_providers.append("Feishu") + + if configured_providers 
and not OPENID_PROVIDER_URL.value: + provider_list = ", ".join(configured_providers) + log.warning( + f"⚠️ OAuth providers configured ({provider_list}) but OPENID_PROVIDER_URL not set - logout will not work!" + ) + log.warning( + f"Set OPENID_PROVIDER_URL to your OAuth provider's OpenID Connect discovery endpoint to fix logout functionality." + ) + load_oauth_providers() @@ -622,6 +797,17 @@ def oidc_oauth_register(client): STATIC_DIR = Path(os.getenv("STATIC_DIR", OPEN_WEBUI_DIR / "static")).resolve() +try: + if STATIC_DIR.exists(): + for item in STATIC_DIR.iterdir(): + if item.is_file() or item.is_symlink(): + try: + item.unlink() + except Exception as e: + pass +except Exception as e: + pass + for file_path in (FRONTEND_BUILD_DIR / "static").glob("**/*"): if file_path.is_file(): target_path = STATIC_DIR / file_path.relative_to( @@ -701,12 +887,6 @@ def oidc_oauth_register(client): pass -#################################### -# LICENSE_KEY -#################################### - -LICENSE_KEY = os.environ.get("LICENSE_KEY", "") - #################################### # STORAGE PROVIDER #################################### @@ -757,7 +937,7 @@ def oidc_oauth_register(client): ENABLE_DIRECT_CONNECTIONS = PersistentConfig( "ENABLE_DIRECT_CONNECTIONS", "direct.enable", - os.environ.get("ENABLE_DIRECT_CONNECTIONS", "True").lower() == "true", + os.environ.get("ENABLE_DIRECT_CONNECTIONS", "False").lower() == "true", ) #################################### @@ -839,6 +1019,9 @@ def oidc_oauth_register(client): if OPENAI_API_BASE_URL == "": OPENAI_API_BASE_URL = "https://api.openai.com/v1" +else: + if OPENAI_API_BASE_URL.endswith("/"): + OPENAI_API_BASE_URL = OPENAI_API_BASE_URL[:-1] OPENAI_API_KEYS = os.environ.get("OPENAI_API_KEYS", "") OPENAI_API_KEYS = OPENAI_API_KEYS if OPENAI_API_KEYS != "" else OPENAI_API_KEY @@ -877,6 +1060,18 @@ def oidc_oauth_register(client): pass OPENAI_API_BASE_URL = "https://api.openai.com/v1" + +#################################### +# MODELS 
+#################################### + +ENABLE_BASE_MODELS_CACHE = PersistentConfig( + "ENABLE_BASE_MODELS_CACHE", + "models.base_models_cache", + os.environ.get("ENABLE_BASE_MODELS_CACHE", "False").lower() == "true", +) + + #################################### # TOOL_SERVERS #################################### @@ -901,9 +1096,7 @@ def oidc_oauth_register(client): #################################### -WEBUI_URL = PersistentConfig( - "WEBUI_URL", "webui.url", os.environ.get("WEBUI_URL", "http://localhost:3000") -) +WEBUI_URL = PersistentConfig("WEBUI_URL", "webui.url", os.environ.get("WEBUI_URL", "")) ENABLE_SIGNUP = PersistentConfig( @@ -1035,6 +1228,11 @@ def oidc_oauth_register(client): == "true" ) +USER_PERMISSIONS_NOTES_ALLOW_PUBLIC_SHARING = ( + os.environ.get("USER_PERMISSIONS_NOTES_ALLOW_PUBLIC_SHARING", "False").lower() + == "true" +) + USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_PUBLIC_SHARING = ( os.environ.get( "USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_PUBLIC_SHARING", "False" @@ -1061,6 +1259,18 @@ def oidc_oauth_register(client): os.environ.get("USER_PERMISSIONS_CHAT_CONTROLS", "True").lower() == "true" ) +USER_PERMISSIONS_CHAT_VALVES = ( + os.environ.get("USER_PERMISSIONS_CHAT_VALVES", "True").lower() == "true" +) + +USER_PERMISSIONS_CHAT_SYSTEM_PROMPT = ( + os.environ.get("USER_PERMISSIONS_CHAT_SYSTEM_PROMPT", "True").lower() == "true" +) + +USER_PERMISSIONS_CHAT_PARAMS = ( + os.environ.get("USER_PERMISSIONS_CHAT_PARAMS", "True").lower() == "true" +) + USER_PERMISSIONS_CHAT_FILE_UPLOAD = ( os.environ.get("USER_PERMISSIONS_CHAT_FILE_UPLOAD", "True").lower() == "true" ) @@ -1069,6 +1279,23 @@ def oidc_oauth_register(client): os.environ.get("USER_PERMISSIONS_CHAT_DELETE", "True").lower() == "true" ) +USER_PERMISSIONS_CHAT_DELETE_MESSAGE = ( + os.environ.get("USER_PERMISSIONS_CHAT_DELETE_MESSAGE", "True").lower() == "true" +) + +USER_PERMISSIONS_CHAT_CONTINUE_RESPONSE = ( + os.environ.get("USER_PERMISSIONS_CHAT_CONTINUE_RESPONSE", "True").lower() == 
"true" +) + +USER_PERMISSIONS_CHAT_REGENERATE_RESPONSE = ( + os.environ.get("USER_PERMISSIONS_CHAT_REGENERATE_RESPONSE", "True").lower() + == "true" +) + +USER_PERMISSIONS_CHAT_RATE_RESPONSE = ( + os.environ.get("USER_PERMISSIONS_CHAT_RATE_RESPONSE", "True").lower() == "true" +) + USER_PERMISSIONS_CHAT_EDIT = ( os.environ.get("USER_PERMISSIONS_CHAT_EDIT", "True").lower() == "true" ) @@ -1143,11 +1370,19 @@ def oidc_oauth_register(client): "public_knowledge": USER_PERMISSIONS_WORKSPACE_KNOWLEDGE_ALLOW_PUBLIC_SHARING, "public_prompts": USER_PERMISSIONS_WORKSPACE_PROMPTS_ALLOW_PUBLIC_SHARING, "public_tools": USER_PERMISSIONS_WORKSPACE_TOOLS_ALLOW_PUBLIC_SHARING, + "public_notes": USER_PERMISSIONS_NOTES_ALLOW_PUBLIC_SHARING, }, "chat": { "controls": USER_PERMISSIONS_CHAT_CONTROLS, + "valves": USER_PERMISSIONS_CHAT_VALVES, + "system_prompt": USER_PERMISSIONS_CHAT_SYSTEM_PROMPT, + "params": USER_PERMISSIONS_CHAT_PARAMS, "file_upload": USER_PERMISSIONS_CHAT_FILE_UPLOAD, "delete": USER_PERMISSIONS_CHAT_DELETE, + "delete_message": USER_PERMISSIONS_CHAT_DELETE_MESSAGE, + "continue_response": USER_PERMISSIONS_CHAT_CONTINUE_RESPONSE, + "regenerate_response": USER_PERMISSIONS_CHAT_REGENERATE_RESPONSE, + "rate_response": USER_PERMISSIONS_CHAT_RATE_RESPONSE, "edit": USER_PERMISSIONS_CHAT_EDIT, "share": USER_PERMISSIONS_CHAT_SHARE, "export": USER_PERMISSIONS_CHAT_EXPORT, @@ -1212,6 +1447,18 @@ def oidc_oauth_register(client): ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true" +ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS = ( + os.environ.get("ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS", "True").lower() == "true" +) + +BYPASS_ADMIN_ACCESS_CONTROL = ( + os.environ.get( + "BYPASS_ADMIN_ACCESS_CONTROL", + os.environ.get("ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS", "True"), + ).lower() + == "true" +) + ENABLE_ADMIN_CHAT_ACCESS = ( os.environ.get("ENABLE_ADMIN_CHAT_ACCESS", "True").lower() == "true" ) @@ -1247,19 +1494,14 @@ def oidc_oauth_register(client): 
THREAD_POOL_SIZE = None -def validate_cors_origins(origins): - for origin in origins: - if origin != "*": - validate_cors_origin(origin) - - def validate_cors_origin(origin): parsed_url = urlparse(origin) - # Check if the scheme is either http or https - if parsed_url.scheme not in ["http", "https"]: + # Check if the scheme is either http or https, or a custom scheme + schemes = ["http", "https"] + CORS_ALLOW_CUSTOM_SCHEME + if parsed_url.scheme not in schemes: raise ValueError( - f"Invalid scheme in CORS_ALLOW_ORIGIN: '{origin}'. Only 'http' and 'https' are allowed." + f"Invalid scheme in CORS_ALLOW_ORIGIN: '{origin}'. Only 'http' and 'https' and CORS_ALLOW_CUSTOM_SCHEME are allowed." ) # Ensure that the netloc (domain + port) is present, indicating it's a valid URL @@ -1272,16 +1514,22 @@ def validate_cors_origin(origin): # To test CORS_ALLOW_ORIGIN locally, you can set something like # CORS_ALLOW_ORIGIN=http://localhost:5173;http://localhost:8080 # in your .env file depending on your frontend port, 5173 in this case. -CORS_ALLOW_ORIGIN = os.environ.get( - "CORS_ALLOW_ORIGIN", "*;http://localhost:5173;http://localhost:8080" -).split(";") +CORS_ALLOW_ORIGIN = os.environ.get("CORS_ALLOW_ORIGIN", "*").split(";") + +# Allows custom URL schemes (e.g., app://) to be used as origins for CORS. +# Useful for local development or desktop clients with schemes like app:// or other custom protocols. +# Provide a semicolon-separated list of allowed schemes in the environment variable CORS_ALLOW_CUSTOM_SCHEME. +CORS_ALLOW_CUSTOM_SCHEME = os.environ.get("CORS_ALLOW_CUSTOM_SCHEME", "").split(";") -if "*" in CORS_ALLOW_ORIGIN: +if CORS_ALLOW_ORIGIN == ["*"]: log.warning( "\n\nWARNING: CORS_ALLOW_ORIGIN IS SET TO '*' - NOT RECOMMENDED FOR PRODUCTION DEPLOYMENTS.\n" ) - -validate_cors_origins(CORS_ALLOW_ORIGIN) +else: + # You have to pick between a single wildcard or a list of origins. + # Doing both will result in CORS errors in the browser. 
+ for origin in CORS_ALLOW_ORIGIN: + validate_cors_origin(origin) class BannerModel(BaseModel): @@ -1413,6 +1661,35 @@ class BannerModel(BaseModel): {{MESSAGES:END:6}} """ + +FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = PersistentConfig( + "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE", + "task.follow_up.prompt_template", + os.environ.get("FOLLOW_UP_GENERATION_PROMPT_TEMPLATE", ""), +) + +DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = """### Task: +Suggest 3-5 relevant follow-up questions or prompts that the user might naturally ask next in this conversation as a **user**, based on the chat history, to help continue or deepen the discussion. +### Guidelines: +- Write all follow-up questions from the user’s point of view, directed to the assistant. +- Make questions concise, clear, and directly related to the discussed topic(s). +- Only suggest follow-ups that make sense given the chat content and do not repeat what was already covered. +- If the conversation is very short or not specific, suggest more general (but relevant) follow-ups the user might ask. +- Use the conversation's primary language; default to English if multilingual. +- Response must be a JSON array of strings, no extra text or formatting. 
+### Output: +JSON format: { "follow_ups": ["Question 1?", "Question 2?", "Question 3?"] } +### Chat History: + +{{MESSAGES:END:6}} +""" + +ENABLE_FOLLOW_UP_GENERATION = PersistentConfig( + "ENABLE_FOLLOW_UP_GENERATION", + "task.follow_up.enable", + os.environ.get("ENABLE_FOLLOW_UP_GENERATION", "True").lower() == "true", +) + ENABLE_TAGS_GENERATION = PersistentConfig( "ENABLE_TAGS_GENERATION", "task.tags.enable", @@ -1684,6 +1961,11 @@ class BannerModel(BaseModel): ), ) +CODE_INTERPRETER_BLOCKED_MODULES = [ + library.strip() + for library in os.environ.get("CODE_INTERPRETER_BLOCKED_MODULES", "").split(",") + if library.strip() +] DEFAULT_CODE_INTERPRETER_PROMPT = """ #### Tools Available @@ -1734,26 +2016,36 @@ class BannerModel(BaseModel): # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2) # Milvus - MILVUS_URI = os.environ.get("MILVUS_URI", f"{DATA_DIR}/vector_db/milvus.db") MILVUS_DB = os.environ.get("MILVUS_DB", "default") MILVUS_TOKEN = os.environ.get("MILVUS_TOKEN", None) - MILVUS_INDEX_TYPE = os.environ.get("MILVUS_INDEX_TYPE", "HNSW") MILVUS_METRIC_TYPE = os.environ.get("MILVUS_METRIC_TYPE", "COSINE") MILVUS_HNSW_M = int(os.environ.get("MILVUS_HNSW_M", "16")) MILVUS_HNSW_EFCONSTRUCTION = int(os.environ.get("MILVUS_HNSW_EFCONSTRUCTION", "100")) MILVUS_IVF_FLAT_NLIST = int(os.environ.get("MILVUS_IVF_FLAT_NLIST", "128")) +MILVUS_DISKANN_MAX_DEGREE = int(os.environ.get("MILVUS_DISKANN_MAX_DEGREE", "56")) +MILVUS_DISKANN_SEARCH_LIST_SIZE = int( + os.environ.get("MILVUS_DISKANN_SEARCH_LIST_SIZE", "100") +) +ENABLE_MILVUS_MULTITENANCY_MODE = ( + os.environ.get("ENABLE_MILVUS_MULTITENANCY_MODE", "false").lower() == "true" +) +# Hyphens not allowed, need to use underscores in collection names +MILVUS_COLLECTION_PREFIX = os.environ.get("MILVUS_COLLECTION_PREFIX", "open_webui") # Qdrant QDRANT_URI = 
os.environ.get("QDRANT_URI", None) QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", None) QDRANT_ON_DISK = os.environ.get("QDRANT_ON_DISK", "false").lower() == "true" -QDRANT_PREFER_GRPC = os.environ.get("QDRANT_PREFER_GRPC", "False").lower() == "true" +QDRANT_PREFER_GRPC = os.environ.get("QDRANT_PREFER_GRPC", "false").lower() == "true" QDRANT_GRPC_PORT = int(os.environ.get("QDRANT_GRPC_PORT", "6334")) +QDRANT_TIMEOUT = int(os.environ.get("QDRANT_TIMEOUT", "5")) +QDRANT_HNSW_M = int(os.environ.get("QDRANT_HNSW_M", "16")) ENABLE_QDRANT_MULTITENANCY_MODE = ( - os.environ.get("ENABLE_QDRANT_MULTITENANCY_MODE", "false").lower() == "true" + os.environ.get("ENABLE_QDRANT_MULTITENANCY_MODE", "true").lower() == "true" ) +QDRANT_COLLECTION_PREFIX = os.environ.get("QDRANT_COLLECTION_PREFIX", "open-webui") # OpenSearch OPENSEARCH_URI = os.environ.get("OPENSEARCH_URI", "https://localhost:9200") @@ -1785,6 +2077,55 @@ class BannerModel(BaseModel): os.environ.get("PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH", "1536") ) +PGVECTOR_CREATE_EXTENSION = ( + os.getenv("PGVECTOR_CREATE_EXTENSION", "true").lower() == "true" +) +PGVECTOR_PGCRYPTO = os.getenv("PGVECTOR_PGCRYPTO", "false").lower() == "true" +PGVECTOR_PGCRYPTO_KEY = os.getenv("PGVECTOR_PGCRYPTO_KEY", None) +if PGVECTOR_PGCRYPTO and not PGVECTOR_PGCRYPTO_KEY: + raise ValueError( + "PGVECTOR_PGCRYPTO is enabled but PGVECTOR_PGCRYPTO_KEY is not set. Please provide a valid key." 
+ ) + + +PGVECTOR_POOL_SIZE = os.environ.get("PGVECTOR_POOL_SIZE", None) + +if PGVECTOR_POOL_SIZE != None: + try: + PGVECTOR_POOL_SIZE = int(PGVECTOR_POOL_SIZE) + except Exception: + PGVECTOR_POOL_SIZE = None + +PGVECTOR_POOL_MAX_OVERFLOW = os.environ.get("PGVECTOR_POOL_MAX_OVERFLOW", 0) + +if PGVECTOR_POOL_MAX_OVERFLOW == "": + PGVECTOR_POOL_MAX_OVERFLOW = 0 +else: + try: + PGVECTOR_POOL_MAX_OVERFLOW = int(PGVECTOR_POOL_MAX_OVERFLOW) + except Exception: + PGVECTOR_POOL_MAX_OVERFLOW = 0 + +PGVECTOR_POOL_TIMEOUT = os.environ.get("PGVECTOR_POOL_TIMEOUT", 30) + +if PGVECTOR_POOL_TIMEOUT == "": + PGVECTOR_POOL_TIMEOUT = 30 +else: + try: + PGVECTOR_POOL_TIMEOUT = int(PGVECTOR_POOL_TIMEOUT) + except Exception: + PGVECTOR_POOL_TIMEOUT = 30 + +PGVECTOR_POOL_RECYCLE = os.environ.get("PGVECTOR_POOL_RECYCLE", 3600) + +if PGVECTOR_POOL_RECYCLE == "": + PGVECTOR_POOL_RECYCLE = 3600 +else: + try: + PGVECTOR_POOL_RECYCLE = int(PGVECTOR_POOL_RECYCLE) + except Exception: + PGVECTOR_POOL_RECYCLE = 3600 + # Pinecone PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY", None) PINECONE_ENVIRONMENT = os.environ.get("PINECONE_ENVIRONMENT", None) @@ -1793,6 +2134,37 @@ class BannerModel(BaseModel): PINECONE_METRIC = os.getenv("PINECONE_METRIC", "cosine") PINECONE_CLOUD = os.getenv("PINECONE_CLOUD", "aws") # or "gcp" or "azure" +# ORACLE23AI (Oracle23ai Vector Search) + +ORACLE_DB_USE_WALLET = os.environ.get("ORACLE_DB_USE_WALLET", "false").lower() == "true" +ORACLE_DB_USER = os.environ.get("ORACLE_DB_USER", None) # +ORACLE_DB_PASSWORD = os.environ.get("ORACLE_DB_PASSWORD", None) # +ORACLE_DB_DSN = os.environ.get("ORACLE_DB_DSN", None) # +ORACLE_WALLET_DIR = os.environ.get("ORACLE_WALLET_DIR", None) +ORACLE_WALLET_PASSWORD = os.environ.get("ORACLE_WALLET_PASSWORD", None) +ORACLE_VECTOR_LENGTH = os.environ.get("ORACLE_VECTOR_LENGTH", 768) + +ORACLE_DB_POOL_MIN = int(os.environ.get("ORACLE_DB_POOL_MIN", 2)) +ORACLE_DB_POOL_MAX = int(os.environ.get("ORACLE_DB_POOL_MAX", 10)) 
+ORACLE_DB_POOL_INCREMENT = int(os.environ.get("ORACLE_DB_POOL_INCREMENT", 1)) + + +if VECTOR_DB == "oracle23ai": + if not ORACLE_DB_USER or not ORACLE_DB_PASSWORD or not ORACLE_DB_DSN: + raise ValueError( + "Oracle23ai requires setting ORACLE_DB_USER, ORACLE_DB_PASSWORD, and ORACLE_DB_DSN." + ) + if ORACLE_DB_USE_WALLET and (not ORACLE_WALLET_DIR or not ORACLE_WALLET_PASSWORD): + raise ValueError( + "Oracle23ai requires setting ORACLE_WALLET_DIR and ORACLE_WALLET_PASSWORD when using wallet authentication." + ) + +log.info(f"VECTOR_DB: {VECTOR_DB}") + +# S3 Vector +S3_VECTOR_BUCKET_NAME = os.environ.get("S3_VECTOR_BUCKET_NAME", None) +S3_VECTOR_REGION = os.environ.get("S3_VECTOR_REGION", None) + #################################### # Information Retrieval (RAG) #################################### @@ -1823,10 +2195,20 @@ class BannerModel(BaseModel): os.getenv("ENABLE_ONEDRIVE_INTEGRATION", "False").lower() == "true", ) -ONEDRIVE_CLIENT_ID = PersistentConfig( - "ONEDRIVE_CLIENT_ID", - "onedrive.client_id", - os.environ.get("ONEDRIVE_CLIENT_ID", ""), + +ENABLE_ONEDRIVE_PERSONAL = ( + os.environ.get("ENABLE_ONEDRIVE_PERSONAL", "True").lower() == "true" +) +ENABLE_ONEDRIVE_BUSINESS = ( + os.environ.get("ENABLE_ONEDRIVE_BUSINESS", "True").lower() == "true" +) + +ONEDRIVE_CLIENT_ID = os.environ.get("ONEDRIVE_CLIENT_ID", "") +ONEDRIVE_CLIENT_ID_PERSONAL = os.environ.get( + "ONEDRIVE_CLIENT_ID_PERSONAL", ONEDRIVE_CLIENT_ID +) +ONEDRIVE_CLIENT_ID_BUSINESS = os.environ.get( + "ONEDRIVE_CLIENT_ID_BUSINESS", ONEDRIVE_CLIENT_ID ) ONEDRIVE_SHAREPOINT_URL = PersistentConfig( @@ -1848,6 +2230,103 @@ class BannerModel(BaseModel): os.environ.get("CONTENT_EXTRACTION_ENGINE", "").lower(), ) +DATALAB_MARKER_API_KEY = PersistentConfig( + "DATALAB_MARKER_API_KEY", + "rag.datalab_marker_api_key", + os.environ.get("DATALAB_MARKER_API_KEY", ""), +) + +DATALAB_MARKER_API_BASE_URL = PersistentConfig( + "DATALAB_MARKER_API_BASE_URL", + "rag.datalab_marker_api_base_url", + 
os.environ.get("DATALAB_MARKER_API_BASE_URL", ""), +) + +DATALAB_MARKER_ADDITIONAL_CONFIG = PersistentConfig( + "DATALAB_MARKER_ADDITIONAL_CONFIG", + "rag.datalab_marker_additional_config", + os.environ.get("DATALAB_MARKER_ADDITIONAL_CONFIG", ""), +) + +DATALAB_MARKER_USE_LLM = PersistentConfig( + "DATALAB_MARKER_USE_LLM", + "rag.DATALAB_MARKER_USE_LLM", + os.environ.get("DATALAB_MARKER_USE_LLM", "false").lower() == "true", +) + +DATALAB_MARKER_SKIP_CACHE = PersistentConfig( + "DATALAB_MARKER_SKIP_CACHE", + "rag.datalab_marker_skip_cache", + os.environ.get("DATALAB_MARKER_SKIP_CACHE", "false").lower() == "true", +) + +DATALAB_MARKER_FORCE_OCR = PersistentConfig( + "DATALAB_MARKER_FORCE_OCR", + "rag.datalab_marker_force_ocr", + os.environ.get("DATALAB_MARKER_FORCE_OCR", "false").lower() == "true", +) + +DATALAB_MARKER_PAGINATE = PersistentConfig( + "DATALAB_MARKER_PAGINATE", + "rag.datalab_marker_paginate", + os.environ.get("DATALAB_MARKER_PAGINATE", "false").lower() == "true", +) + +DATALAB_MARKER_STRIP_EXISTING_OCR = PersistentConfig( + "DATALAB_MARKER_STRIP_EXISTING_OCR", + "rag.datalab_marker_strip_existing_ocr", + os.environ.get("DATALAB_MARKER_STRIP_EXISTING_OCR", "false").lower() == "true", +) + +DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION = PersistentConfig( + "DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", + "rag.datalab_marker_disable_image_extraction", + os.environ.get("DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", "false").lower() + == "true", +) + +DATALAB_MARKER_FORMAT_LINES = PersistentConfig( + "DATALAB_MARKER_FORMAT_LINES", + "rag.datalab_marker_format_lines", + os.environ.get("DATALAB_MARKER_FORMAT_LINES", "false").lower() == "true", +) + +DATALAB_MARKER_OUTPUT_FORMAT = PersistentConfig( + "DATALAB_MARKER_OUTPUT_FORMAT", + "rag.datalab_marker_output_format", + os.environ.get("DATALAB_MARKER_OUTPUT_FORMAT", "markdown"), +) + +MINERU_API_MODE = PersistentConfig( + "MINERU_API_MODE", + "rag.mineru_api_mode", + os.environ.get("MINERU_API_MODE", "local"), # "local" 
or "cloud" +) + +MINERU_API_URL = PersistentConfig( + "MINERU_API_URL", + "rag.mineru_api_url", + os.environ.get("MINERU_API_URL", "http://localhost:8000"), +) + +MINERU_API_KEY = PersistentConfig( + "MINERU_API_KEY", + "rag.mineru_api_key", + os.environ.get("MINERU_API_KEY", ""), +) + +mineru_params = os.getenv("MINERU_PARAMS", "") +try: + mineru_params = json.loads(mineru_params) +except json.JSONDecodeError: + mineru_params = {} + +MINERU_PARAMS = PersistentConfig( + "MINERU_PARAMS", + "rag.mineru_params", + mineru_params, +) + EXTERNAL_DOCUMENT_LOADER_URL = PersistentConfig( "EXTERNAL_DOCUMENT_LOADER_URL", "rag.external_document_loader_url", @@ -1872,6 +2351,30 @@ class BannerModel(BaseModel): os.getenv("DOCLING_SERVER_URL", "http://docling:5001"), ) +docling_params = os.getenv("DOCLING_PARAMS", "") +try: + docling_params = json.loads(docling_params) +except json.JSONDecodeError: + docling_params = {} + +DOCLING_PARAMS = PersistentConfig( + "DOCLING_PARAMS", + "rag.docling_params", + docling_params, +) + +DOCLING_DO_OCR = PersistentConfig( + "DOCLING_DO_OCR", + "rag.docling_do_ocr", + os.getenv("DOCLING_DO_OCR", "True").lower() == "true", +) + +DOCLING_FORCE_OCR = PersistentConfig( + "DOCLING_FORCE_OCR", + "rag.docling_force_ocr", + os.getenv("DOCLING_FORCE_OCR", "False").lower() == "true", +) + DOCLING_OCR_ENGINE = PersistentConfig( "DOCLING_OCR_ENGINE", "rag.docling_ocr_engine", @@ -1884,12 +2387,64 @@ class BannerModel(BaseModel): os.getenv("DOCLING_OCR_LANG", "eng,fra,deu,spa"), ) +DOCLING_PDF_BACKEND = PersistentConfig( + "DOCLING_PDF_BACKEND", + "rag.docling_pdf_backend", + os.getenv("DOCLING_PDF_BACKEND", "dlparse_v4"), +) + +DOCLING_TABLE_MODE = PersistentConfig( + "DOCLING_TABLE_MODE", + "rag.docling_table_mode", + os.getenv("DOCLING_TABLE_MODE", "accurate"), +) + +DOCLING_PIPELINE = PersistentConfig( + "DOCLING_PIPELINE", + "rag.docling_pipeline", + os.getenv("DOCLING_PIPELINE", "standard"), +) + DOCLING_DO_PICTURE_DESCRIPTION = PersistentConfig( 
"DOCLING_DO_PICTURE_DESCRIPTION", "rag.docling_do_picture_description", os.getenv("DOCLING_DO_PICTURE_DESCRIPTION", "False").lower() == "true", ) +DOCLING_PICTURE_DESCRIPTION_MODE = PersistentConfig( + "DOCLING_PICTURE_DESCRIPTION_MODE", + "rag.docling_picture_description_mode", + os.getenv("DOCLING_PICTURE_DESCRIPTION_MODE", ""), +) + + +docling_picture_description_local = os.getenv("DOCLING_PICTURE_DESCRIPTION_LOCAL", "") +try: + docling_picture_description_local = json.loads(docling_picture_description_local) +except json.JSONDecodeError: + docling_picture_description_local = {} + + +DOCLING_PICTURE_DESCRIPTION_LOCAL = PersistentConfig( + "DOCLING_PICTURE_DESCRIPTION_LOCAL", + "rag.docling_picture_description_local", + docling_picture_description_local, +) + +docling_picture_description_api = os.getenv("DOCLING_PICTURE_DESCRIPTION_API", "") +try: + docling_picture_description_api = json.loads(docling_picture_description_api) +except json.JSONDecodeError: + docling_picture_description_api = {} + + +DOCLING_PICTURE_DESCRIPTION_API = PersistentConfig( + "DOCLING_PICTURE_DESCRIPTION_API", + "rag.docling_picture_description_api", + docling_picture_description_api, +) + + DOCUMENT_INTELLIGENCE_ENDPOINT = PersistentConfig( "DOCUMENT_INTELLIGENCE_ENDPOINT", "rag.document_intelligence_endpoint", @@ -1928,6 +2483,11 @@ class BannerModel(BaseModel): "rag.relevance_threshold", float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")), ) +RAG_HYBRID_BM25_WEIGHT = PersistentConfig( + "RAG_HYBRID_BM25_WEIGHT", + "rag.hybrid_bm25_weight", + float(os.environ.get("RAG_HYBRID_BM25_WEIGHT", "0.5")), +) ENABLE_RAG_HYBRID_SEARCH = PersistentConfig( "ENABLE_RAG_HYBRID_SEARCH", @@ -1961,6 +2521,27 @@ class BannerModel(BaseModel): ), ) +FILE_IMAGE_COMPRESSION_WIDTH = PersistentConfig( + "FILE_IMAGE_COMPRESSION_WIDTH", + "file.image_compression_width", + ( + int(os.environ.get("FILE_IMAGE_COMPRESSION_WIDTH")) + if os.environ.get("FILE_IMAGE_COMPRESSION_WIDTH") + else None + ), +) + 
+FILE_IMAGE_COMPRESSION_HEIGHT = PersistentConfig( + "FILE_IMAGE_COMPRESSION_HEIGHT", + "file.image_compression_height", + ( + int(os.environ.get("FILE_IMAGE_COMPRESSION_HEIGHT")) + if os.environ.get("FILE_IMAGE_COMPRESSION_HEIGHT") + else None + ), +) + + RAG_ALLOWED_FILE_EXTENSIONS = PersistentConfig( "RAG_ALLOWED_FILE_EXTENSIONS", "rag.file.allowed_extensions", @@ -2124,6 +2705,22 @@ class BannerModel(BaseModel): os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY), ) +RAG_AZURE_OPENAI_BASE_URL = PersistentConfig( + "RAG_AZURE_OPENAI_BASE_URL", + "rag.azure_openai.base_url", + os.getenv("RAG_AZURE_OPENAI_BASE_URL", ""), +) +RAG_AZURE_OPENAI_API_KEY = PersistentConfig( + "RAG_AZURE_OPENAI_API_KEY", + "rag.azure_openai.api_key", + os.getenv("RAG_AZURE_OPENAI_API_KEY", ""), +) +RAG_AZURE_OPENAI_API_VERSION = PersistentConfig( + "RAG_AZURE_OPENAI_API_VERSION", + "rag.azure_openai.api_version", + os.getenv("RAG_AZURE_OPENAI_API_VERSION", ""), +) + RAG_OLLAMA_BASE_URL = PersistentConfig( "RAG_OLLAMA_BASE_URL", "rag.ollama.url", @@ -2177,6 +2774,12 @@ class BannerModel(BaseModel): ) +BYPASS_WEB_SEARCH_WEB_LOADER = PersistentConfig( + "BYPASS_WEB_SEARCH_WEB_LOADER", + "rag.web.search.bypass_web_loader", + os.getenv("BYPASS_WEB_SEARCH_WEB_LOADER", "False").lower() == "true", +) + WEB_SEARCH_RESULT_COUNT = PersistentConfig( "WEB_SEARCH_RESULT_COUNT", "rag.web.search.result_count", @@ -2202,12 +2805,21 @@ class BannerModel(BaseModel): int(os.getenv("WEB_SEARCH_CONCURRENT_REQUESTS", "10")), ) + WEB_LOADER_ENGINE = PersistentConfig( "WEB_LOADER_ENGINE", "rag.web.loader.engine", os.environ.get("WEB_LOADER_ENGINE", ""), ) + +WEB_LOADER_CONCURRENT_REQUESTS = PersistentConfig( + "WEB_LOADER_CONCURRENT_REQUESTS", + "rag.web.loader.concurrent_requests", + int(os.getenv("WEB_LOADER_CONCURRENT_REQUESTS", "10")), +) + + ENABLE_WEB_LOADER_SSL_VERIFICATION = PersistentConfig( "ENABLE_WEB_LOADER_SSL_VERIFICATION", "rag.web.loader.ssl_verification", @@ -2221,6 +2833,12 @@ class 
BannerModel(BaseModel): ) +OLLAMA_CLOUD_WEB_SEARCH_API_KEY = PersistentConfig( + "OLLAMA_CLOUD_WEB_SEARCH_API_KEY", + "rag.web.search.ollama_cloud_api_key", + os.getenv("OLLAMA_CLOUD_API_KEY", ""), +) + SEARXNG_QUERY_URL = PersistentConfig( "SEARXNG_QUERY_URL", "rag.web.search.searxng_query_url", @@ -2361,6 +2979,18 @@ class BannerModel(BaseModel): os.getenv("PERPLEXITY_API_KEY", ""), ) +PERPLEXITY_MODEL = PersistentConfig( + "PERPLEXITY_MODEL", + "rag.web.search.perplexity_model", + os.getenv("PERPLEXITY_MODEL", "sonar"), +) + +PERPLEXITY_SEARCH_CONTEXT_USAGE = PersistentConfig( + "PERPLEXITY_SEARCH_CONTEXT_USAGE", + "rag.web.search.perplexity_search_context_usage", + os.getenv("PERPLEXITY_SEARCH_CONTEXT_USAGE", "medium"), +) + SOUGOU_API_SID = PersistentConfig( "SOUGOU_API_SID", "rag.web.search.sougou_api_sid", @@ -2637,6 +3267,12 @@ class BannerModel(BaseModel): "image_generation.openai.api_base_url", os.getenv("IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL), ) +IMAGES_OPENAI_API_VERSION = PersistentConfig( + "IMAGES_OPENAI_API_VERSION", + "image_generation.openai.api_version", + os.getenv("IMAGES_OPENAI_API_VERSION", ""), +) + IMAGES_OPENAI_API_KEY = PersistentConfig( "IMAGES_OPENAI_API_KEY", "image_generation.openai.api_key", @@ -2725,6 +3361,18 @@ class BannerModel(BaseModel): os.getenv("AUDIO_STT_MODEL", ""), ) +AUDIO_STT_SUPPORTED_CONTENT_TYPES = PersistentConfig( + "AUDIO_STT_SUPPORTED_CONTENT_TYPES", + "audio.stt.supported_content_types", + [ + content_type.strip() + for content_type in os.environ.get( + "AUDIO_STT_SUPPORTED_CONTENT_TYPES", "" + ).split(",") + if content_type.strip() + ], +) + AUDIO_STT_AZURE_API_KEY = PersistentConfig( "AUDIO_STT_AZURE_API_KEY", "audio.stt.azure.api_key", @@ -2766,6 +3414,19 @@ class BannerModel(BaseModel): os.getenv("AUDIO_TTS_OPENAI_API_KEY", OPENAI_API_KEY), ) +audio_tts_openai_params = os.getenv("AUDIO_TTS_OPENAI_PARAMS", "") +try: + audio_tts_openai_params = json.loads(audio_tts_openai_params) +except 
json.JSONDecodeError: + audio_tts_openai_params = {} + +AUDIO_TTS_OPENAI_PARAMS = PersistentConfig( + "AUDIO_TTS_OPENAI_PARAMS", + "audio.tts.openai.params", + audio_tts_openai_params, +) + + AUDIO_TTS_API_KEY = PersistentConfig( "AUDIO_TTS_API_KEY", "audio.tts.api_key", @@ -2899,3 +3560,22 @@ class BannerModel(BaseModel): LDAP_CIPHERS = PersistentConfig( "LDAP_CIPHERS", "ldap.server.ciphers", os.environ.get("LDAP_CIPHERS", "ALL") ) + +# For LDAP Group Management +ENABLE_LDAP_GROUP_MANAGEMENT = PersistentConfig( + "ENABLE_LDAP_GROUP_MANAGEMENT", + "ldap.group.enable_management", + os.environ.get("ENABLE_LDAP_GROUP_MANAGEMENT", "False").lower() == "true", +) + +ENABLE_LDAP_GROUP_CREATION = PersistentConfig( + "ENABLE_LDAP_GROUP_CREATION", + "ldap.group.enable_creation", + os.environ.get("ENABLE_LDAP_GROUP_CREATION", "False").lower() == "true", +) + +LDAP_ATTRIBUTE_FOR_GROUPS = PersistentConfig( + "LDAP_ATTRIBUTE_FOR_GROUPS", + "ldap.server.attribute_for_groups", + os.environ.get("LDAP_ATTRIBUTE_FOR_GROUPS", "memberOf"), +) diff --git a/backend/open_webui/constants.py b/backend/open_webui/constants.py index 95c54a0d270..6d63295ab8d 100644 --- a/backend/open_webui/constants.py +++ b/backend/open_webui/constants.py @@ -38,6 +38,7 @@ def __str__(self) -> str: ID_TAKEN = "Uh-oh! This id is already registered. Please choose another id string." MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string." NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string." + MODEL_ID_TOO_LONG = "The model id is too long. Please make sure your model id is less than 256 characters long." INVALID_TOKEN = ( "Your session has expired or the token is invalid. Please sign in again." 
@@ -111,6 +112,7 @@ def __str__(self) -> str: DEFAULT = lambda task="": f"{task if task else 'generation'}" TITLE_GENERATION = "title_generation" + FOLLOW_UP_GENERATION = "follow_up_generation" TAGS_GENERATION = "tags_generation" EMOJI_GENERATION = "emoji_generation" QUERY_GENERATION = "query_generation" diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index 59557349e3a..8f9c1fbc445 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -5,7 +5,9 @@ import pkgutil import sys import shutil +from uuid import uuid4 from pathlib import Path +from cryptography.hazmat.primitives import serialization import markdown from bs4 import BeautifulSoup @@ -15,14 +17,17 @@ # Load .env file #################################### -OPEN_WEBUI_DIR = Path(__file__).parent # the path containing this file -print(OPEN_WEBUI_DIR) +# Use .resolve() to get the canonical path, removing any '..' or '.' components +ENV_FILE_PATH = Path(__file__).resolve() -BACKEND_DIR = OPEN_WEBUI_DIR.parent # the path containing this file -BASE_DIR = BACKEND_DIR.parent # the path containing the backend/ +# OPEN_WEBUI_DIR should be the directory where env.py resides (open_webui/) +OPEN_WEBUI_DIR = ENV_FILE_PATH.parent -print(BACKEND_DIR) -print(BASE_DIR) +# BACKEND_DIR is the parent of OPEN_WEBUI_DIR (backend/) +BACKEND_DIR = OPEN_WEBUI_DIR.parent + +# BASE_DIR is the parent of BACKEND_DIR (open-webui-dev/) +BASE_DIR = BACKEND_DIR.parent try: from dotenv import find_dotenv, load_dotenv @@ -130,6 +135,7 @@ PACKAGE_DATA = {"version": "0.0.0"} VERSION = PACKAGE_DATA["version"] +INSTANCE_ID = os.environ.get("INSTANCE_ID", str(uuid4())) # Function to parse each section @@ -197,6 +203,7 @@ def parse_section(section): SAFE_MODE = os.environ.get("SAFE_MODE", "false").lower() == "true" + #################################### # ENABLE_FORWARD_USER_INFO_HEADERS #################################### @@ -205,6 +212,11 @@ def parse_section(section): 
os.environ.get("ENABLE_FORWARD_USER_INFO_HEADERS", "False").lower() == "true" ) +# Experimental feature, may be removed in future +ENABLE_STAR_SESSIONS_MIDDLEWARE = ( + os.environ.get("ENABLE_STAR_SESSIONS_MIDDLEWARE", "False").lower() == "true" +) + #################################### # WEBUI_BUILD_HASH #################################### @@ -264,21 +276,43 @@ def parse_section(section): DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db") +DATABASE_TYPE = os.environ.get("DATABASE_TYPE") +DATABASE_USER = os.environ.get("DATABASE_USER") +DATABASE_PASSWORD = os.environ.get("DATABASE_PASSWORD") + +DATABASE_CRED = "" +if DATABASE_USER: + DATABASE_CRED += f"{DATABASE_USER}" +if DATABASE_PASSWORD: + DATABASE_CRED += f":{DATABASE_PASSWORD}" + +DB_VARS = { + "db_type": DATABASE_TYPE, + "db_cred": DATABASE_CRED, + "db_host": os.environ.get("DATABASE_HOST"), + "db_port": os.environ.get("DATABASE_PORT"), + "db_name": os.environ.get("DATABASE_NAME"), +} + +if all(DB_VARS.values()): + DATABASE_URL = f"{DB_VARS['db_type']}://{DB_VARS['db_cred']}@{DB_VARS['db_host']}:{DB_VARS['db_port']}/{DB_VARS['db_name']}" +elif DATABASE_TYPE == "sqlite+sqlcipher" and not os.environ.get("DATABASE_URL"): + # Handle SQLCipher with local file when DATABASE_URL wasn't explicitly set + DATABASE_URL = f"sqlite+sqlcipher:///{DATA_DIR}/webui.db" + # Replace the postgres:// with postgresql:// if "postgres://" in DATABASE_URL: DATABASE_URL = DATABASE_URL.replace("postgres://", "postgresql://") DATABASE_SCHEMA = os.environ.get("DATABASE_SCHEMA", None) -DATABASE_POOL_SIZE = os.environ.get("DATABASE_POOL_SIZE", 0) +DATABASE_POOL_SIZE = os.environ.get("DATABASE_POOL_SIZE", None) -if DATABASE_POOL_SIZE == "": - DATABASE_POOL_SIZE = 0 -else: +if DATABASE_POOL_SIZE != None: try: DATABASE_POOL_SIZE = int(DATABASE_POOL_SIZE) except Exception: - DATABASE_POOL_SIZE = 0 + DATABASE_POOL_SIZE = None DATABASE_POOL_MAX_OVERFLOW = os.environ.get("DATABASE_POOL_MAX_OVERFLOW", 0) @@ -310,6 
+344,21 @@ def parse_section(section): except Exception: DATABASE_POOL_RECYCLE = 3600 +DATABASE_ENABLE_SQLITE_WAL = ( + os.environ.get("DATABASE_ENABLE_SQLITE_WAL", "False").lower() == "true" +) + +DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL = os.environ.get( + "DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL", None +) +if DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL is not None: + try: + DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL = float( + DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL + ) + except Exception: + DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL = 0.0 + RESET_CONFIG_ON_START = ( os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true" ) @@ -318,14 +367,29 @@ def parse_section(section): os.environ.get("ENABLE_REALTIME_CHAT_SAVE", "False").lower() == "true" ) +ENABLE_QUERIES_CACHE = os.environ.get("ENABLE_QUERIES_CACHE", "False").lower() == "true" + #################################### # REDIS #################################### REDIS_URL = os.environ.get("REDIS_URL", "") +REDIS_CLUSTER = os.environ.get("REDIS_CLUSTER", "False").lower() == "true" + +REDIS_KEY_PREFIX = os.environ.get("REDIS_KEY_PREFIX", "open-webui") + REDIS_SENTINEL_HOSTS = os.environ.get("REDIS_SENTINEL_HOSTS", "") REDIS_SENTINEL_PORT = os.environ.get("REDIS_SENTINEL_PORT", "26379") +# Maximum number of retries for Redis operations when using Sentinel fail-over +REDIS_SENTINEL_MAX_RETRY_COUNT = os.environ.get("REDIS_SENTINEL_MAX_RETRY_COUNT", "2") +try: + REDIS_SENTINEL_MAX_RETRY_COUNT = int(REDIS_SENTINEL_MAX_RETRY_COUNT) + if REDIS_SENTINEL_MAX_RETRY_COUNT < 1: + REDIS_SENTINEL_MAX_RETRY_COUNT = 2 +except ValueError: + REDIS_SENTINEL_MAX_RETRY_COUNT = 2 + #################################### # UVICORN WORKERS #################################### @@ -345,10 +409,22 @@ def parse_section(section): #################################### WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true" + +ENABLE_INITIAL_ADMIN_SIGNUP = ( + os.environ.get("ENABLE_INITIAL_ADMIN_SIGNUP", 
"False").lower() == "true" +) +ENABLE_SIGNUP_PASSWORD_CONFIRMATION = ( + os.environ.get("ENABLE_SIGNUP_PASSWORD_CONFIRMATION", "False").lower() == "true" +) + WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get( "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None ) WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None) +WEBUI_AUTH_TRUSTED_GROUPS_HEADER = os.environ.get( + "WEBUI_AUTH_TRUSTED_GROUPS_HEADER", None +) + BYPASS_MODEL_ACCESS_CONTROL = ( os.environ.get("BYPASS_MODEL_ACCESS_CONTROL", "False").lower() == "true" @@ -390,19 +466,136 @@ def parse_section(section): if WEBUI_AUTH and WEBUI_SECRET_KEY == "": raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND) +ENABLE_COMPRESSION_MIDDLEWARE = ( + os.environ.get("ENABLE_COMPRESSION_MIDDLEWARE", "True").lower() == "true" +) + +#################################### +# OAUTH Configuration +#################################### +ENABLE_OAUTH_EMAIL_FALLBACK = ( + os.environ.get("ENABLE_OAUTH_EMAIL_FALLBACK", "False").lower() == "true" +) + +ENABLE_OAUTH_ID_TOKEN_COOKIE = ( + os.environ.get("ENABLE_OAUTH_ID_TOKEN_COOKIE", "True").lower() == "true" +) + +OAUTH_CLIENT_INFO_ENCRYPTION_KEY = os.environ.get( + "OAUTH_CLIENT_INFO_ENCRYPTION_KEY", WEBUI_SECRET_KEY +) + +OAUTH_SESSION_TOKEN_ENCRYPTION_KEY = os.environ.get( + "OAUTH_SESSION_TOKEN_ENCRYPTION_KEY", WEBUI_SECRET_KEY +) + +#################################### +# SCIM Configuration +#################################### + +SCIM_ENABLED = os.environ.get("SCIM_ENABLED", "False").lower() == "true" +SCIM_TOKEN = os.environ.get("SCIM_TOKEN", "") + +#################################### +# LICENSE_KEY +#################################### + +LICENSE_KEY = os.environ.get("LICENSE_KEY", "") + +LICENSE_BLOB = None +LICENSE_BLOB_PATH = os.environ.get("LICENSE_BLOB_PATH", DATA_DIR / "l.data") +if LICENSE_BLOB_PATH and os.path.exists(LICENSE_BLOB_PATH): + with open(LICENSE_BLOB_PATH, "rb") as f: + LICENSE_BLOB = f.read() + +LICENSE_PUBLIC_KEY = 
os.environ.get("LICENSE_PUBLIC_KEY", "") + +pk = None +if LICENSE_PUBLIC_KEY: + pk = serialization.load_pem_public_key( + f""" +-----BEGIN PUBLIC KEY----- +{LICENSE_PUBLIC_KEY} +-----END PUBLIC KEY----- +""".encode( + "utf-8" + ) + ) + + +#################################### +# MODELS +#################################### + +MODELS_CACHE_TTL = os.environ.get("MODELS_CACHE_TTL", "1") +if MODELS_CACHE_TTL == "": + MODELS_CACHE_TTL = None +else: + try: + MODELS_CACHE_TTL = int(MODELS_CACHE_TTL) + except Exception: + MODELS_CACHE_TTL = 1 + + +#################################### +# CHAT +#################################### + +CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = os.environ.get( + "CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE", "1" +) + +if CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE == "": + CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = 1 +else: + try: + CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = int( + CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE + ) + except Exception: + CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE = 1 + + +CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = os.environ.get( + "CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES", "30" +) + +if CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES == "": + CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = 30 +else: + try: + CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = int(CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES) + except Exception: + CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES = 30 + + +#################################### +# WEBSOCKET SUPPORT +#################################### + ENABLE_WEBSOCKET_SUPPORT = ( os.environ.get("ENABLE_WEBSOCKET_SUPPORT", "True").lower() == "true" ) + WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "") WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL) -WEBSOCKET_REDIS_LOCK_TIMEOUT = os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", 60) +WEBSOCKET_REDIS_CLUSTER = ( + os.environ.get("WEBSOCKET_REDIS_CLUSTER", str(REDIS_CLUSTER)).lower() == "true" +) -WEBSOCKET_SENTINEL_HOSTS = os.environ.get("WEBSOCKET_SENTINEL_HOSTS", "") +websocket_redis_lock_timeout = 
os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", "60") +try: + WEBSOCKET_REDIS_LOCK_TIMEOUT = int(websocket_redis_lock_timeout) +except ValueError: + WEBSOCKET_REDIS_LOCK_TIMEOUT = 60 + +WEBSOCKET_SENTINEL_HOSTS = os.environ.get("WEBSOCKET_SENTINEL_HOSTS", "") WEBSOCKET_SENTINEL_PORT = os.environ.get("WEBSOCKET_SENTINEL_PORT", "26379") + AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "") if AIOHTTP_CLIENT_TIMEOUT == "": @@ -500,11 +693,14 @@ def parse_section(section): # OFFLINE_MODE #################################### +ENABLE_VERSION_UPDATE_CHECK = ( + os.environ.get("ENABLE_VERSION_UPDATE_CHECK", "true").lower() == "true" +) OFFLINE_MODE = os.environ.get("OFFLINE_MODE", "false").lower() == "true" if OFFLINE_MODE: os.environ["HF_HUB_OFFLINE"] = "1" - + ENABLE_VERSION_UPDATE_CHECK = False #################################### # AUDIT LOGGING @@ -513,6 +709,14 @@ def parse_section(section): AUDIT_LOGS_FILE_PATH = f"{DATA_DIR}/audit.log" # Maximum size of a file before rotating into a new log file AUDIT_LOG_FILE_ROTATION_SIZE = os.getenv("AUDIT_LOG_FILE_ROTATION_SIZE", "10MB") + +# Comma separated list of logger names to use for audit logging +# Default is "uvicorn.access" which is the access log for Uvicorn +# You can add more logger names to this list if you want to capture more logs +AUDIT_UVICORN_LOGGER_NAMES = os.getenv( + "AUDIT_UVICORN_LOGGER_NAMES", "uvicorn.access" +).split(",") + # METADATA | REQUEST | REQUEST_RESPONSE AUDIT_LOG_LEVEL = os.getenv("AUDIT_LOG_LEVEL", "NONE").upper() try: @@ -533,9 +737,34 @@ def parse_section(section): #################################### ENABLE_OTEL = os.environ.get("ENABLE_OTEL", "False").lower() == "true" +ENABLE_OTEL_TRACES = os.environ.get("ENABLE_OTEL_TRACES", "False").lower() == "true" +ENABLE_OTEL_METRICS = os.environ.get("ENABLE_OTEL_METRICS", "False").lower() == "true" +ENABLE_OTEL_LOGS = os.environ.get("ENABLE_OTEL_LOGS", "False").lower() == "true" + OTEL_EXPORTER_OTLP_ENDPOINT = os.environ.get( 
"OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317" ) +OTEL_METRICS_EXPORTER_OTLP_ENDPOINT = os.environ.get( + "OTEL_METRICS_EXPORTER_OTLP_ENDPOINT", OTEL_EXPORTER_OTLP_ENDPOINT +) +OTEL_LOGS_EXPORTER_OTLP_ENDPOINT = os.environ.get( + "OTEL_LOGS_EXPORTER_OTLP_ENDPOINT", OTEL_EXPORTER_OTLP_ENDPOINT +) +OTEL_EXPORTER_OTLP_INSECURE = ( + os.environ.get("OTEL_EXPORTER_OTLP_INSECURE", "False").lower() == "true" +) +OTEL_METRICS_EXPORTER_OTLP_INSECURE = ( + os.environ.get( + "OTEL_METRICS_EXPORTER_OTLP_INSECURE", str(OTEL_EXPORTER_OTLP_INSECURE) + ).lower() + == "true" +) +OTEL_LOGS_EXPORTER_OTLP_INSECURE = ( + os.environ.get( + "OTEL_LOGS_EXPORTER_OTLP_INSECURE", str(OTEL_EXPORTER_OTLP_INSECURE) + ).lower() + == "true" +) OTEL_SERVICE_NAME = os.environ.get("OTEL_SERVICE_NAME", "open-webui") OTEL_RESOURCE_ATTRIBUTES = os.environ.get( "OTEL_RESOURCE_ATTRIBUTES", "" @@ -543,6 +772,33 @@ def parse_section(section): OTEL_TRACES_SAMPLER = os.environ.get( "OTEL_TRACES_SAMPLER", "parentbased_always_on" ).lower() +OTEL_BASIC_AUTH_USERNAME = os.environ.get("OTEL_BASIC_AUTH_USERNAME", "") +OTEL_BASIC_AUTH_PASSWORD = os.environ.get("OTEL_BASIC_AUTH_PASSWORD", "") + +OTEL_METRICS_BASIC_AUTH_USERNAME = os.environ.get( + "OTEL_METRICS_BASIC_AUTH_USERNAME", OTEL_BASIC_AUTH_USERNAME +) +OTEL_METRICS_BASIC_AUTH_PASSWORD = os.environ.get( + "OTEL_METRICS_BASIC_AUTH_PASSWORD", OTEL_BASIC_AUTH_PASSWORD +) +OTEL_LOGS_BASIC_AUTH_USERNAME = os.environ.get( + "OTEL_LOGS_BASIC_AUTH_USERNAME", OTEL_BASIC_AUTH_USERNAME +) +OTEL_LOGS_BASIC_AUTH_PASSWORD = os.environ.get( + "OTEL_LOGS_BASIC_AUTH_PASSWORD", OTEL_BASIC_AUTH_PASSWORD +) + +OTEL_OTLP_SPAN_EXPORTER = os.environ.get( + "OTEL_OTLP_SPAN_EXPORTER", "grpc" +).lower() # grpc or http + +OTEL_METRICS_OTLP_SPAN_EXPORTER = os.environ.get( + "OTEL_METRICS_OTLP_SPAN_EXPORTER", OTEL_OTLP_SPAN_EXPORTER +).lower() # grpc or http + +OTEL_LOGS_OTLP_SPAN_EXPORTER = os.environ.get( + "OTEL_LOGS_OTLP_SPAN_EXPORTER", OTEL_OTLP_SPAN_EXPORTER +).lower() # 
grpc or http #################################### # TOOLS/FUNCTIONS PIP OPTIONS diff --git a/backend/open_webui/functions.py b/backend/open_webui/functions.py index 340b60ba47d..316efe18e7f 100644 --- a/backend/open_webui/functions.py +++ b/backend/open_webui/functions.py @@ -19,16 +19,21 @@ from starlette.responses import Response, StreamingResponse +from open_webui.constants import ERROR_MESSAGES from open_webui.socket.main import ( get_event_call, get_event_emitter, ) +from open_webui.models.users import UserModel from open_webui.models.functions import Functions from open_webui.models.models import Models -from open_webui.utils.plugin import load_function_module_by_id +from open_webui.utils.plugin import ( + load_function_module_by_id, + get_function_module_from_cache, +) from open_webui.utils.tools import get_tools from open_webui.utils.access_control import has_access @@ -43,7 +48,7 @@ ) from open_webui.utils.payload import ( apply_model_params_to_body_openai, - apply_model_system_prompt_to_body, + apply_system_prompt_to_body, ) @@ -53,16 +58,23 @@ def get_function_module_by_id(request: Request, pipe_id: str): - # Check if function is already loaded - if pipe_id not in request.app.state.FUNCTIONS: - function_module, _, _ = load_function_module_by_id(pipe_id) - request.app.state.FUNCTIONS[pipe_id] = function_module - else: - function_module = request.app.state.FUNCTIONS[pipe_id] + function_module, _, _ = get_function_module_from_cache(request, pipe_id) if hasattr(function_module, "valves") and hasattr(function_module, "Valves"): + Valves = function_module.Valves valves = Functions.get_function_valves_by_id(pipe_id) - function_module.valves = function_module.Valves(**(valves if valves else {})) + + if valves: + try: + function_module.valves = Valves( + **{k: v for k, v in valves.items() if v is not None} + ) + except Exception as e: + log.exception(f"Error loading valves for function {pipe_id}: {e}") + raise e + else: + function_module.valves = Valves() + 
return function_module @@ -71,65 +83,75 @@ async def get_function_models(request): pipe_models = [] for pipe in pipes: - function_module = get_function_module_by_id(request, pipe.id) + try: + function_module = get_function_module_by_id(request, pipe.id) - # Check if function is a manifold - if hasattr(function_module, "pipes"): - sub_pipes = [] + has_user_valves = False + if hasattr(function_module, "UserValves"): + has_user_valves = True - # Handle pipes being a list, sync function, or async function - try: - if callable(function_module.pipes): - if asyncio.iscoroutinefunction(function_module.pipes): - sub_pipes = await function_module.pipes() - else: - sub_pipes = function_module.pipes() - else: - sub_pipes = function_module.pipes - except Exception as e: - log.exception(e) + # Check if function is a manifold + if hasattr(function_module, "pipes"): sub_pipes = [] - log.debug( - f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}" - ) - - for p in sub_pipes: - sub_pipe_id = f'{pipe.id}.{p["id"]}' - sub_pipe_name = p["name"] + # Handle pipes being a list, sync function, or async function + try: + if callable(function_module.pipes): + if asyncio.iscoroutinefunction(function_module.pipes): + sub_pipes = await function_module.pipes() + else: + sub_pipes = function_module.pipes() + else: + sub_pipes = function_module.pipes + except Exception as e: + log.exception(e) + sub_pipes = [] - if hasattr(function_module, "name"): - sub_pipe_name = f"{function_module.name}{sub_pipe_name}" + log.debug( + f"get_function_models: function '{pipe.id}' is a manifold of {sub_pipes}" + ) - pipe_flag = {"type": pipe.type} + for p in sub_pipes: + sub_pipe_id = f'{pipe.id}.{p["id"]}' + sub_pipe_name = p["name"] + + if hasattr(function_module, "name"): + sub_pipe_name = f"{function_module.name}{sub_pipe_name}" + + pipe_flag = {"type": pipe.type} + + pipe_models.append( + { + "id": sub_pipe_id, + "name": sub_pipe_name, + "object": "model", + "created": pipe.created_at, + 
"owned_by": "openai", + "pipe": pipe_flag, + "has_user_valves": has_user_valves, + } + ) + else: + pipe_flag = {"type": "pipe"} + + log.debug( + f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}" + ) pipe_models.append( { - "id": sub_pipe_id, - "name": sub_pipe_name, + "id": pipe.id, + "name": pipe.name, "object": "model", "created": pipe.created_at, "owned_by": "openai", "pipe": pipe_flag, + "has_user_valves": has_user_valves, } ) - else: - pipe_flag = {"type": "pipe"} - - log.debug( - f"get_function_models: function '{pipe.id}' is a single pipe {{ 'id': {pipe.id}, 'name': {pipe.name} }}" - ) - - pipe_models.append( - { - "id": pipe.id, - "name": pipe.name, - "object": "model", - "created": pipe.created_at, - "owned_by": "openai", - "pipe": pipe_flag, - } - ) + except Exception as e: + log.exception(e) + continue return pipe_models @@ -220,6 +242,16 @@ def get_function_params(function_module, form_data, user, extra_params=None): __task__ = metadata.get("task", None) __task_body__ = metadata.get("task_body", None) + oauth_token = None + try: + if request.cookies.get("oauth_session_id", None): + oauth_token = await request.app.state.oauth_manager.get_oauth_token( + user.id, + request.cookies.get("oauth_session_id", None), + ) + except Exception as e: + log.error(f"Error getting OAuth token: {e}") + extra_params = { "__event_emitter__": __event_emitter__, "__event_call__": __event_call__, @@ -229,16 +261,12 @@ def get_function_params(function_module, form_data, user, extra_params=None): "__task__": __task__, "__task_body__": __task_body__, "__files__": files, - "__user__": { - "id": user.id, - "email": user.email, - "name": user.name, - "role": user.role, - }, + "__user__": user.model_dump() if isinstance(user, UserModel) else {}, "__metadata__": metadata, + "__oauth_token__": oauth_token, "__request__": request, } - extra_params["__tools__"] = get_tools( + extra_params["__tools__"] = await get_tools( request, 
tool_ids, user, @@ -255,8 +283,11 @@ def get_function_params(function_module, form_data, user, extra_params=None): form_data["model"] = model_info.base_model_id params = model_info.params.model_dump() - form_data = apply_model_params_to_body_openai(params, form_data) - form_data = apply_model_system_prompt_to_body(params, form_data, metadata, user) + + if params: + system = params.pop("system", None) + form_data = apply_model_params_to_body_openai(params, form_data) + form_data = apply_system_prompt_to_body(system, form_data, metadata, user) pipe_id = get_pipe_id(form_data) function_module = get_function_module_by_id(request, pipe_id) diff --git a/backend/open_webui/internal/db.py b/backend/open_webui/internal/db.py index 840f571cc91..b6913d87b09 100644 --- a/backend/open_webui/internal/db.py +++ b/backend/open_webui/internal/db.py @@ -1,3 +1,4 @@ +import os import json import logging from contextlib import contextmanager @@ -13,9 +14,10 @@ DATABASE_POOL_RECYCLE, DATABASE_POOL_SIZE, DATABASE_POOL_TIMEOUT, + DATABASE_ENABLE_SQLITE_WAL, ) from peewee_migrate import Router -from sqlalchemy import Dialect, create_engine, MetaData, types +from sqlalchemy import Dialect, create_engine, MetaData, event, types from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.pool import QueuePool, NullPool @@ -62,6 +64,9 @@ def handle_peewee_migration(DATABASE_URL): except Exception as e: log.error(f"Failed to initialize the database connection: {e}") + log.warning( + "Hint: If your database password contains special characters, you may need to URL-encode it." 
+ ) raise finally: # Properly closing the database connection @@ -76,25 +81,68 @@ def handle_peewee_migration(DATABASE_URL): SQLALCHEMY_DATABASE_URL = DATABASE_URL -if "sqlite" in SQLALCHEMY_DATABASE_URL: + +# Handle SQLCipher URLs +if SQLALCHEMY_DATABASE_URL.startswith("sqlite+sqlcipher://"): + database_password = os.environ.get("DATABASE_PASSWORD") + if not database_password or database_password.strip() == "": + raise ValueError( + "DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs" + ) + + # Extract database path from SQLCipher URL + db_path = SQLALCHEMY_DATABASE_URL.replace("sqlite+sqlcipher://", "") + if db_path.startswith("/"): + db_path = db_path[1:] # Remove leading slash for relative paths + + # Create a custom creator function that uses sqlcipher3 + def create_sqlcipher_connection(): + import sqlcipher3 + + conn = sqlcipher3.connect(db_path, check_same_thread=False) + conn.execute(f"PRAGMA key = '{database_password}'") + return conn + + engine = create_engine( + "sqlite://", # Dummy URL since we're using creator + creator=create_sqlcipher_connection, + echo=False, + ) + + log.info("Connected to encrypted SQLite database using SQLCipher") + +elif "sqlite" in SQLALCHEMY_DATABASE_URL: engine = create_engine( SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} ) + + def on_connect(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + if DATABASE_ENABLE_SQLITE_WAL: + cursor.execute("PRAGMA journal_mode=WAL") + else: + cursor.execute("PRAGMA journal_mode=DELETE") + cursor.close() + + event.listen(engine, "connect", on_connect) else: - if DATABASE_POOL_SIZE > 0: - engine = create_engine( - SQLALCHEMY_DATABASE_URL, - pool_size=DATABASE_POOL_SIZE, - max_overflow=DATABASE_POOL_MAX_OVERFLOW, - pool_timeout=DATABASE_POOL_TIMEOUT, - pool_recycle=DATABASE_POOL_RECYCLE, - pool_pre_ping=True, - poolclass=QueuePool, - ) + if isinstance(DATABASE_POOL_SIZE, int): + if DATABASE_POOL_SIZE > 0: + engine = create_engine( + 
SQLALCHEMY_DATABASE_URL, + pool_size=DATABASE_POOL_SIZE, + max_overflow=DATABASE_POOL_MAX_OVERFLOW, + pool_timeout=DATABASE_POOL_TIMEOUT, + pool_recycle=DATABASE_POOL_RECYCLE, + pool_pre_ping=True, + poolclass=QueuePool, + ) + else: + engine = create_engine( + SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool + ) else: - engine = create_engine( - SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, poolclass=NullPool - ) + engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True) SessionLocal = sessionmaker( diff --git a/backend/open_webui/internal/wrappers.py b/backend/open_webui/internal/wrappers.py index ccc62b9a574..554a5effdd2 100644 --- a/backend/open_webui/internal/wrappers.py +++ b/backend/open_webui/internal/wrappers.py @@ -1,4 +1,5 @@ import logging +import os from contextvars import ContextVar from open_webui.env import SRC_LOG_LEVELS @@ -43,24 +44,47 @@ class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase): def register_connection(db_url): - db = connect(db_url, unquote_password=True) - if isinstance(db, PostgresqlDatabase): - # Enable autoconnect for SQLite databases, managed by Peewee - db.autoconnect = True - db.reuse_if_open = True - log.info("Connected to PostgreSQL database") + # Check if using SQLCipher protocol + if db_url.startswith("sqlite+sqlcipher://"): + database_password = os.environ.get("DATABASE_PASSWORD") + if not database_password or database_password.strip() == "": + raise ValueError( + "DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs" + ) + from playhouse.sqlcipher_ext import SqlCipherDatabase - # Get the connection details - connection = parse(db_url, unquote_password=True) + # Parse the database path from SQLCipher URL + # Convert sqlite+sqlcipher:///path/to/db.sqlite to /path/to/db.sqlite + db_path = db_url.replace("sqlite+sqlcipher://", "") + if db_path.startswith("/"): + db_path = db_path[1:] # Remove leading slash for relative paths - # Use our custom database class 
that supports reconnection - db = ReconnectingPostgresqlDatabase(**connection) - db.connect(reuse_if_open=True) - elif isinstance(db, SqliteDatabase): - # Enable autoconnect for SQLite databases, managed by Peewee + # Use Peewee's native SqlCipherDatabase with encryption + db = SqlCipherDatabase(db_path, passphrase=database_password) db.autoconnect = True db.reuse_if_open = True - log.info("Connected to SQLite database") + log.info("Connected to encrypted SQLite database using SQLCipher") + else: - raise ValueError("Unsupported database connection") + # Standard database connection (existing logic) + db = connect(db_url, unquote_user=True, unquote_password=True) + if isinstance(db, PostgresqlDatabase): + # Enable autoconnect for SQLite databases, managed by Peewee + db.autoconnect = True + db.reuse_if_open = True + log.info("Connected to PostgreSQL database") + + # Get the connection details + connection = parse(db_url, unquote_user=True, unquote_password=True) + + # Use our custom database class that supports reconnection + db = ReconnectingPostgresqlDatabase(**connection) + db.connect(reuse_if_open=True) + elif isinstance(db, SqliteDatabase): + # Enable autoconnect for SQLite databases, managed by Peewee + db.autoconnect = True + db.reuse_if_open = True + log.info("Connected to SQLite database") + else: + raise ValueError("Unsupported database connection") return db diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index a5aee4bb829..9998af0e73d 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -8,6 +8,9 @@ import sys import time import random +import re +from uuid import uuid4 + from contextlib import asynccontextmanager from urllib.parse import urlencode, parse_qs, urlparse @@ -19,6 +22,7 @@ import aiohttp import anyio.to_thread import requests +from redis import Redis from fastapi import ( @@ -33,18 +37,25 @@ applications, BackgroundTasks, ) - from fastapi.openapi.docs import get_swagger_ui_html from 
fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse, RedirectResponse +from fastapi.responses import FileResponse, JSONResponse, RedirectResponse from fastapi.staticfiles import StaticFiles +from starlette_compress import CompressMiddleware + from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.sessions import SessionMiddleware from starlette.responses import Response, StreamingResponse +from starlette.datastructures import Headers +from starsessions import ( + SessionMiddleware as StarSessionsMiddleware, + SessionAutoloadMiddleware, +) +from starsessions.stores.redis import RedisStore from open_webui.utils import logger from open_webui.utils.audit import AuditLevel, AuditLoggingMiddleware @@ -52,6 +63,9 @@ from open_webui.socket.main import ( app as socket_app, periodic_usage_pool_cleanup, + get_event_emitter, + get_models_in_use, + get_active_user_ids, ) from open_webui.routers import ( audio, @@ -78,10 +92,12 @@ tools, users, utils, + scim, ) from open_webui.routers.retrieval import ( get_embedding_function, + get_reranking_function, get_ef, get_rf, ) @@ -94,21 +110,19 @@ from open_webui.models.chats import Chats from open_webui.config import ( - LICENSE_KEY, # Ollama ENABLE_OLLAMA_API, OLLAMA_BASE_URLS, OLLAMA_API_CONFIGS, # OpenAI ENABLE_OPENAI_API, - ONEDRIVE_CLIENT_ID, - ONEDRIVE_SHAREPOINT_URL, - ONEDRIVE_SHAREPOINT_TENANT_ID, OPENAI_API_BASE_URLS, OPENAI_API_KEYS, OPENAI_API_CONFIGS, # Direct Connections ENABLE_DIRECT_CONNECTIONS, + # Model list + ENABLE_BASE_MODELS_CACHE, # Thread pool size for FastAPI/AnyIO THREAD_POOL_SIZE, # Tool Server Configs @@ -146,12 +160,14 @@ IMAGE_SIZE, IMAGE_STEPS, IMAGES_OPENAI_API_BASE_URL, + IMAGES_OPENAI_API_VERSION, IMAGES_OPENAI_API_KEY, IMAGES_GEMINI_API_BASE_URL, IMAGES_GEMINI_API_KEY, # Audio AUDIO_STT_ENGINE, AUDIO_STT_MODEL, + AUDIO_STT_SUPPORTED_CONTENT_TYPES, 
AUDIO_STT_OPENAI_API_BASE_URL, AUDIO_STT_OPENAI_API_KEY, AUDIO_STT_AZURE_API_KEY, @@ -159,13 +175,14 @@ AUDIO_STT_AZURE_LOCALES, AUDIO_STT_AZURE_BASE_URL, AUDIO_STT_AZURE_MAX_SPEAKERS, - AUDIO_TTS_API_KEY, AUDIO_TTS_ENGINE, AUDIO_TTS_MODEL, + AUDIO_TTS_VOICE, AUDIO_TTS_OPENAI_API_BASE_URL, AUDIO_TTS_OPENAI_API_KEY, + AUDIO_TTS_OPENAI_PARAMS, + AUDIO_TTS_API_KEY, AUDIO_TTS_SPLIT_ON, - AUDIO_TTS_VOICE, AUDIO_TTS_AZURE_SPEECH_REGION, AUDIO_TTS_AZURE_SPEECH_BASE_URL, AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT, @@ -174,6 +191,7 @@ FIRECRAWL_API_BASE_URL, FIRECRAWL_API_KEY, WEB_LOADER_ENGINE, + WEB_LOADER_CONCURRENT_REQUESTS, WHISPER_MODEL, WHISPER_VAD_FILTER, WHISPER_LANGUAGE, @@ -196,29 +214,59 @@ RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, RAG_EMBEDDING_ENGINE, RAG_EMBEDDING_BATCH_SIZE, + RAG_TOP_K, + RAG_TOP_K_RERANKER, RAG_RELEVANCE_THRESHOLD, + RAG_HYBRID_BM25_WEIGHT, RAG_ALLOWED_FILE_EXTENSIONS, RAG_FILE_MAX_COUNT, RAG_FILE_MAX_SIZE, + FILE_IMAGE_COMPRESSION_WIDTH, + FILE_IMAGE_COMPRESSION_HEIGHT, RAG_OPENAI_API_BASE_URL, RAG_OPENAI_API_KEY, + RAG_AZURE_OPENAI_BASE_URL, + RAG_AZURE_OPENAI_API_KEY, + RAG_AZURE_OPENAI_API_VERSION, RAG_OLLAMA_BASE_URL, RAG_OLLAMA_API_KEY, CHUNK_OVERLAP, CHUNK_SIZE, CONTENT_EXTRACTION_ENGINE, + DATALAB_MARKER_API_KEY, + DATALAB_MARKER_API_BASE_URL, + DATALAB_MARKER_ADDITIONAL_CONFIG, + DATALAB_MARKER_SKIP_CACHE, + DATALAB_MARKER_FORCE_OCR, + DATALAB_MARKER_PAGINATE, + DATALAB_MARKER_STRIP_EXISTING_OCR, + DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION, + DATALAB_MARKER_FORMAT_LINES, + DATALAB_MARKER_OUTPUT_FORMAT, + MINERU_API_MODE, + MINERU_API_URL, + MINERU_API_KEY, + MINERU_PARAMS, + DATALAB_MARKER_USE_LLM, EXTERNAL_DOCUMENT_LOADER_URL, EXTERNAL_DOCUMENT_LOADER_API_KEY, TIKA_SERVER_URL, DOCLING_SERVER_URL, + DOCLING_PARAMS, + DOCLING_DO_OCR, + DOCLING_FORCE_OCR, DOCLING_OCR_ENGINE, DOCLING_OCR_LANG, + DOCLING_PDF_BACKEND, + DOCLING_TABLE_MODE, + DOCLING_PIPELINE, DOCLING_DO_PICTURE_DESCRIPTION, + DOCLING_PICTURE_DESCRIPTION_MODE, + 
DOCLING_PICTURE_DESCRIPTION_LOCAL, + DOCLING_PICTURE_DESCRIPTION_API, DOCUMENT_INTELLIGENCE_ENDPOINT, DOCUMENT_INTELLIGENCE_KEY, MISTRAL_OCR_API_KEY, - RAG_TOP_K, - RAG_TOP_K_RERANKER, RAG_TEXT_SPLITTER, TIKTOKEN_ENCODING_NAME, PDF_EXTRACT_IMAGES, @@ -228,10 +276,12 @@ ENABLE_WEB_SEARCH, WEB_SEARCH_ENGINE, BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL, + BYPASS_WEB_SEARCH_WEB_LOADER, WEB_SEARCH_RESULT_COUNT, WEB_SEARCH_CONCURRENT_REQUESTS, WEB_SEARCH_TRUST_ENV, WEB_SEARCH_DOMAIN_FILTER_LIST, + OLLAMA_CLOUD_WEB_SEARCH_API_KEY, JINA_API_KEY, SEARCHAPI_API_KEY, SEARCHAPI_ENGINE, @@ -252,6 +302,8 @@ BRAVE_SEARCH_API_KEY, EXA_API_KEY, PERPLEXITY_API_KEY, + PERPLEXITY_MODEL, + PERPLEXITY_SEARCH_CONTEXT_USAGE, SOUGOU_API_SID, SOUGOU_API_SK, KAGI_SEARCH_API_KEY, @@ -261,14 +313,17 @@ GOOGLE_PSE_ENGINE_ID, GOOGLE_DRIVE_CLIENT_ID, GOOGLE_DRIVE_API_KEY, - ONEDRIVE_CLIENT_ID, + ENABLE_ONEDRIVE_INTEGRATION, + ONEDRIVE_CLIENT_ID_PERSONAL, + ONEDRIVE_CLIENT_ID_BUSINESS, ONEDRIVE_SHAREPOINT_URL, ONEDRIVE_SHAREPOINT_TENANT_ID, + ENABLE_ONEDRIVE_PERSONAL, + ENABLE_ONEDRIVE_BUSINESS, ENABLE_RAG_HYBRID_SEARCH, ENABLE_RAG_LOCAL_WEB_FETCH, ENABLE_WEB_LOADER_SSL_VERIFICATION, ENABLE_GOOGLE_DRIVE_INTEGRATION, - ENABLE_ONEDRIVE_INTEGRATION, UPLOAD_DIR, EXTERNAL_WEB_SEARCH_URL, EXTERNAL_WEB_SEARCH_API_KEY, @@ -293,6 +348,7 @@ ENABLE_MESSAGE_RATING, ENABLE_USER_WEBHOOKS, ENABLE_EVALUATION_ARENA_MODELS, + BYPASS_ADMIN_ACCESS_CONTROL, USER_PERMISSIONS, DEFAULT_USER_ROLE, PENDING_USER_OVERLAY_CONTENT, @@ -325,6 +381,10 @@ LDAP_CA_CERT_FILE, LDAP_VALIDATE_CERT, LDAP_CIPHERS, + # LDAP Group Management + ENABLE_LDAP_GROUP_MANAGEMENT, + ENABLE_LDAP_GROUP_CREATION, + LDAP_ATTRIBUTE_FOR_GROUPS, # Misc ENV, CACHE_DIR, @@ -337,16 +397,19 @@ RESPONSE_WATERMARK, # Admin ENABLE_ADMIN_CHAT_ACCESS, + BYPASS_ADMIN_ACCESS_CONTROL, ENABLE_ADMIN_EXPORT, # Tasks TASK_MODEL, TASK_MODEL_EXTERNAL, ENABLE_TAGS_GENERATION, ENABLE_TITLE_GENERATION, + ENABLE_FOLLOW_UP_GENERATION, ENABLE_SEARCH_QUERY_GENERATION, 
ENABLE_RETRIEVAL_QUERY_GENERATION, ENABLE_AUTOCOMPLETE_GENERATION, TITLE_GENERATION_PROMPT_TEMPLATE, + FOLLOW_UP_GENERATION_PROMPT_TEMPLATE, TAGS_GENERATION_PROMPT_TEMPLATE, IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, @@ -357,10 +420,13 @@ reset_config, ) from open_webui.env import ( + LICENSE_KEY, AUDIT_EXCLUDED_PATHS, AUDIT_LOG_LEVEL, CHANGELOG, REDIS_URL, + REDIS_CLUSTER, + REDIS_KEY_PREFIX, REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT, GLOBAL_LOG_LEVEL, @@ -368,20 +434,27 @@ SAFE_MODE, SRC_LOG_LEVELS, VERSION, + INSTANCE_ID, WEBUI_BUILD_HASH, WEBUI_SECRET_KEY, WEBUI_SESSION_COOKIE_SAME_SITE, WEBUI_SESSION_COOKIE_SECURE, + ENABLE_SIGNUP_PASSWORD_CONFIRMATION, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER, WEBUI_AUTH_SIGNOUT_REDIRECT_URL, + # SCIM + SCIM_ENABLED, + SCIM_TOKEN, + ENABLE_COMPRESSION_MIDDLEWARE, ENABLE_WEBSOCKET_SUPPORT, BYPASS_MODEL_ACCESS_CONTROL, RESET_CONFIG_ON_START, - OFFLINE_MODE, + ENABLE_VERSION_UPDATE_CHECK, ENABLE_OTEL, EXTERNAL_PWA_MANIFEST_URL, AIOHTTP_CLIENT_SESSION_SSL, + ENABLE_STAR_SESSIONS_MIDDLEWARE, ) @@ -389,12 +462,14 @@ get_all_models, get_all_base_models, check_model_access, + get_filtered_models, ) from open_webui.utils.chat import ( generate_chat_completion as chat_completion_handler, chat_completed as chat_completed_handler, chat_action as chat_action_handler, ) +from open_webui.utils.embeddings import generate_embeddings from open_webui.utils.middleware import process_chat_payload, process_chat_response from open_webui.utils.access_control import has_access @@ -406,11 +481,19 @@ get_verified_user, ) from open_webui.utils.plugin import install_tool_and_function_dependencies -from open_webui.utils.oauth import OAuthManager +from open_webui.utils.oauth import ( + OAuthManager, + OAuthClientManager, + decrypt_data, + OAuthClientInformationFull, +) from open_webui.utils.security_headers import SecurityHeadersMiddleware +from open_webui.utils.redis import get_redis_connection 
from open_webui.tasks import ( - list_task_ids_by_chat_id, + redis_task_command_listener, + list_task_ids_by_item_id, + create_task, stop_task, list_tasks, ) # Import from tasks.py @@ -418,6 +501,9 @@ from open_webui.utils.redis import get_sentinels_from_env +from open_webui.constants import ERROR_MESSAGES + + if SAFE_MODE: print("SAFE MODE ENABLED") Functions.deactivate_all_functions() @@ -461,7 +547,9 @@ async def get_response(self, path: str, scope): @asynccontextmanager async def lifespan(app: FastAPI): + app.state.instance_id = INSTANCE_ID start_logger() + if RESET_CONFIG_ON_START: reset_config() @@ -473,14 +561,52 @@ async def lifespan(app: FastAPI): log.info("Installing external dependencies of functions and tools...") install_tool_and_function_dependencies() + app.state.redis = get_redis_connection( + redis_url=REDIS_URL, + redis_sentinels=get_sentinels_from_env( + REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT + ), + redis_cluster=REDIS_CLUSTER, + async_mode=True, + ) + + if app.state.redis is not None: + app.state.redis_task_command_listener = asyncio.create_task( + redis_task_command_listener(app) + ) + if THREAD_POOL_SIZE and THREAD_POOL_SIZE > 0: limiter = anyio.to_thread.current_default_thread_limiter() limiter.total_tokens = THREAD_POOL_SIZE asyncio.create_task(periodic_usage_pool_cleanup()) + if app.state.config.ENABLE_BASE_MODELS_CACHE: + await get_all_models( + Request( + # Creating a mock request object to pass to get_all_models + { + "type": "http", + "asgi.version": "3.0", + "asgi.spec_version": "2.0", + "method": "GET", + "path": "/internal", + "query_string": b"", + "headers": Headers({}).raw, + "client": ("127.0.0.1", 12345), + "server": ("127.0.0.1", 80), + "scheme": "http", + "app": app, + } + ), + None, + ) + yield + if hasattr(app.state, "redis_task_command_listener"): + app.state.redis_task_command_listener.cancel() + app = FastAPI( title="Open WebUI", @@ -490,12 +616,22 @@ async def lifespan(app: FastAPI): lifespan=lifespan, ) +# For Open 
WebUI OIDC/OAuth2 oauth_manager = OAuthManager(app) +app.state.oauth_manager = oauth_manager +# For Integrations +oauth_client_manager = OAuthClientManager(app) +app.state.oauth_client_manager = oauth_client_manager + +app.state.instance_id = None app.state.config = AppConfig( redis_url=REDIS_URL, redis_sentinels=get_sentinels_from_env(REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_PORT), + redis_cluster=REDIS_CLUSTER, + redis_key_prefix=REDIS_KEY_PREFIX, ) +app.state.redis = None app.state.WEBUI_NAME = WEBUI_NAME app.state.LICENSE_METADATA = None @@ -556,6 +692,24 @@ async def lifespan(app: FastAPI): app.state.config.ENABLE_DIRECT_CONNECTIONS = ENABLE_DIRECT_CONNECTIONS +######################################## +# +# SCIM +# +######################################## + +app.state.SCIM_ENABLED = SCIM_ENABLED +app.state.SCIM_TOKEN = SCIM_TOKEN + +######################################## +# +# MODELS +# +######################################## + +app.state.config.ENABLE_BASE_MODELS_CACHE = ENABLE_BASE_MODELS_CACHE +app.state.BASE_MODELS = [] + ######################################## # # WEBUI @@ -626,6 +780,11 @@ async def lifespan(app: FastAPI): app.state.config.LDAP_VALIDATE_CERT = LDAP_VALIDATE_CERT app.state.config.LDAP_CIPHERS = LDAP_CIPHERS +# For LDAP Group Management +app.state.config.ENABLE_LDAP_GROUP_MANAGEMENT = ENABLE_LDAP_GROUP_MANAGEMENT +app.state.config.ENABLE_LDAP_GROUP_CREATION = ENABLE_LDAP_GROUP_CREATION +app.state.config.LDAP_ATTRIBUTE_FOR_GROUPS = LDAP_ATTRIBUTE_FOR_GROUPS + app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER @@ -633,8 +792,12 @@ async def lifespan(app: FastAPI): app.state.EXTERNAL_PWA_MANIFEST_URL = EXTERNAL_PWA_MANIFEST_URL app.state.USER_COUNT = None + app.state.TOOLS = {} +app.state.TOOL_CONTENTS = {} + app.state.FUNCTIONS = {} +app.state.FUNCTION_CONTENTS = {} ######################################## # @@ -646,9 +809,14 @@ async def lifespan(app: 
FastAPI): app.state.config.TOP_K = RAG_TOP_K app.state.config.TOP_K_RERANKER = RAG_TOP_K_RERANKER app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD +app.state.config.HYBRID_BM25_WEIGHT = RAG_HYBRID_BM25_WEIGHT + + app.state.config.ALLOWED_FILE_EXTENSIONS = RAG_ALLOWED_FILE_EXTENSIONS app.state.config.FILE_MAX_SIZE = RAG_FILE_MAX_SIZE app.state.config.FILE_MAX_COUNT = RAG_FILE_MAX_COUNT +app.state.config.FILE_IMAGE_COMPRESSION_WIDTH = FILE_IMAGE_COMPRESSION_WIDTH +app.state.config.FILE_IMAGE_COMPRESSION_HEIGHT = FILE_IMAGE_COMPRESSION_HEIGHT app.state.config.RAG_FULL_CONTEXT = RAG_FULL_CONTEXT @@ -657,16 +825,42 @@ async def lifespan(app: FastAPI): app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION = ENABLE_WEB_LOADER_SSL_VERIFICATION app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE +app.state.config.DATALAB_MARKER_API_KEY = DATALAB_MARKER_API_KEY +app.state.config.DATALAB_MARKER_API_BASE_URL = DATALAB_MARKER_API_BASE_URL +app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG = DATALAB_MARKER_ADDITIONAL_CONFIG +app.state.config.DATALAB_MARKER_SKIP_CACHE = DATALAB_MARKER_SKIP_CACHE +app.state.config.DATALAB_MARKER_FORCE_OCR = DATALAB_MARKER_FORCE_OCR +app.state.config.DATALAB_MARKER_PAGINATE = DATALAB_MARKER_PAGINATE +app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR = DATALAB_MARKER_STRIP_EXISTING_OCR +app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION = ( + DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION +) +app.state.config.DATALAB_MARKER_FORMAT_LINES = DATALAB_MARKER_FORMAT_LINES +app.state.config.DATALAB_MARKER_USE_LLM = DATALAB_MARKER_USE_LLM +app.state.config.DATALAB_MARKER_OUTPUT_FORMAT = DATALAB_MARKER_OUTPUT_FORMAT app.state.config.EXTERNAL_DOCUMENT_LOADER_URL = EXTERNAL_DOCUMENT_LOADER_URL app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY = EXTERNAL_DOCUMENT_LOADER_API_KEY app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL app.state.config.DOCLING_SERVER_URL = DOCLING_SERVER_URL +app.state.config.DOCLING_PARAMS = 
DOCLING_PARAMS +app.state.config.DOCLING_DO_OCR = DOCLING_DO_OCR +app.state.config.DOCLING_FORCE_OCR = DOCLING_FORCE_OCR app.state.config.DOCLING_OCR_ENGINE = DOCLING_OCR_ENGINE app.state.config.DOCLING_OCR_LANG = DOCLING_OCR_LANG +app.state.config.DOCLING_PDF_BACKEND = DOCLING_PDF_BACKEND +app.state.config.DOCLING_TABLE_MODE = DOCLING_TABLE_MODE +app.state.config.DOCLING_PIPELINE = DOCLING_PIPELINE app.state.config.DOCLING_DO_PICTURE_DESCRIPTION = DOCLING_DO_PICTURE_DESCRIPTION +app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE = DOCLING_PICTURE_DESCRIPTION_MODE +app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL = DOCLING_PICTURE_DESCRIPTION_LOCAL +app.state.config.DOCLING_PICTURE_DESCRIPTION_API = DOCLING_PICTURE_DESCRIPTION_API app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = DOCUMENT_INTELLIGENCE_ENDPOINT app.state.config.DOCUMENT_INTELLIGENCE_KEY = DOCUMENT_INTELLIGENCE_KEY app.state.config.MISTRAL_OCR_API_KEY = MISTRAL_OCR_API_KEY +app.state.config.MINERU_API_MODE = MINERU_API_MODE +app.state.config.MINERU_API_URL = MINERU_API_URL +app.state.config.MINERU_API_KEY = MINERU_API_KEY +app.state.config.MINERU_PARAMS = MINERU_PARAMS app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME @@ -688,6 +882,10 @@ async def lifespan(app: FastAPI): app.state.config.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL app.state.config.RAG_OPENAI_API_KEY = RAG_OPENAI_API_KEY +app.state.config.RAG_AZURE_OPENAI_BASE_URL = RAG_AZURE_OPENAI_BASE_URL +app.state.config.RAG_AZURE_OPENAI_API_KEY = RAG_AZURE_OPENAI_API_KEY +app.state.config.RAG_AZURE_OPENAI_API_VERSION = RAG_AZURE_OPENAI_API_VERSION + app.state.config.RAG_OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL app.state.config.RAG_OLLAMA_API_KEY = RAG_OLLAMA_API_KEY @@ -702,14 +900,20 @@ async def lifespan(app: FastAPI): app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST = WEB_SEARCH_DOMAIN_FILTER_LIST app.state.config.WEB_SEARCH_RESULT_COUNT = WEB_SEARCH_RESULT_COUNT 
app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS = WEB_SEARCH_CONCURRENT_REQUESTS + app.state.config.WEB_LOADER_ENGINE = WEB_LOADER_ENGINE +app.state.config.WEB_LOADER_CONCURRENT_REQUESTS = WEB_LOADER_CONCURRENT_REQUESTS + app.state.config.WEB_SEARCH_TRUST_ENV = WEB_SEARCH_TRUST_ENV app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = ( BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL ) +app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER = BYPASS_WEB_SEARCH_WEB_LOADER app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION = ENABLE_GOOGLE_DRIVE_INTEGRATION app.state.config.ENABLE_ONEDRIVE_INTEGRATION = ENABLE_ONEDRIVE_INTEGRATION + +app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY = OLLAMA_CLOUD_WEB_SEARCH_API_KEY app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL app.state.config.YACY_QUERY_URL = YACY_QUERY_URL app.state.config.YACY_USERNAME = YACY_USERNAME @@ -734,6 +938,8 @@ async def lifespan(app: FastAPI): app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY = BING_SEARCH_V7_SUBSCRIPTION_KEY app.state.config.EXA_API_KEY = EXA_API_KEY app.state.config.PERPLEXITY_API_KEY = PERPLEXITY_API_KEY +app.state.config.PERPLEXITY_MODEL = PERPLEXITY_MODEL +app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE = PERPLEXITY_SEARCH_CONTEXT_USAGE app.state.config.SOUGOU_API_SID = SOUGOU_API_SID app.state.config.SOUGOU_API_SK = SOUGOU_API_SK app.state.config.EXTERNAL_WEB_SEARCH_URL = EXTERNAL_WEB_SEARCH_URL @@ -749,6 +955,7 @@ async def lifespan(app: FastAPI): app.state.config.TAVILY_EXTRACT_DEPTH = TAVILY_EXTRACT_DEPTH app.state.EMBEDDING_FUNCTION = None +app.state.RERANKING_FUNCTION = None app.state.ef = None app.state.rf = None @@ -761,14 +968,19 @@ async def lifespan(app: FastAPI): app.state.config.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE, ) - - app.state.rf = get_rf( - app.state.config.RAG_RERANKING_ENGINE, - app.state.config.RAG_RERANKING_MODEL, - app.state.config.RAG_EXTERNAL_RERANKER_URL, - app.state.config.RAG_EXTERNAL_RERANKER_API_KEY, - RAG_RERANKING_MODEL_AUTO_UPDATE, - 
) + if ( + app.state.config.ENABLE_RAG_HYBRID_SEARCH + and not app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL + ): + app.state.rf = get_rf( + app.state.config.RAG_RERANKING_ENGINE, + app.state.config.RAG_RERANKING_MODEL, + app.state.config.RAG_EXTERNAL_RERANKER_URL, + app.state.config.RAG_EXTERNAL_RERANKER_API_KEY, + RAG_RERANKING_MODEL_AUTO_UPDATE, + ) + else: + app.state.rf = None except Exception as e: log.error(f"Error updating models: {e}") pass @@ -777,18 +989,37 @@ async def lifespan(app: FastAPI): app.state.EMBEDDING_FUNCTION = get_embedding_function( app.state.config.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_MODEL, - app.state.ef, - ( + embedding_function=app.state.ef, + url=( app.state.config.RAG_OPENAI_API_BASE_URL if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.RAG_OLLAMA_BASE_URL + else ( + app.state.config.RAG_OLLAMA_BASE_URL + if app.state.config.RAG_EMBEDDING_ENGINE == "ollama" + else app.state.config.RAG_AZURE_OPENAI_BASE_URL + ) ), - ( + key=( app.state.config.RAG_OPENAI_API_KEY if app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else app.state.config.RAG_OLLAMA_API_KEY + else ( + app.state.config.RAG_OLLAMA_API_KEY + if app.state.config.RAG_EMBEDDING_ENGINE == "ollama" + else app.state.config.RAG_AZURE_OPENAI_API_KEY + ) + ), + embedding_batch_size=app.state.config.RAG_EMBEDDING_BATCH_SIZE, + azure_api_version=( + app.state.config.RAG_AZURE_OPENAI_API_VERSION + if app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai" + else None ), - app.state.config.RAG_EMBEDDING_BATCH_SIZE, +) + +app.state.RERANKING_FUNCTION = get_reranking_function( + app.state.config.RAG_RERANKING_ENGINE, + app.state.config.RAG_RERANKING_MODEL, + reranking_function=app.state.rf, ) ######################################## @@ -832,6 +1063,7 @@ async def lifespan(app: FastAPI): app.state.config.ENABLE_IMAGE_PROMPT_GENERATION = ENABLE_IMAGE_PROMPT_GENERATION app.state.config.IMAGES_OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL 
+app.state.config.IMAGES_OPENAI_API_VERSION = IMAGES_OPENAI_API_VERSION app.state.config.IMAGES_OPENAI_API_KEY = IMAGES_OPENAI_API_KEY app.state.config.IMAGES_GEMINI_API_BASE_URL = IMAGES_GEMINI_API_BASE_URL @@ -859,10 +1091,12 @@ async def lifespan(app: FastAPI): # ######################################## -app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL -app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY app.state.config.STT_ENGINE = AUDIO_STT_ENGINE app.state.config.STT_MODEL = AUDIO_STT_MODEL +app.state.config.STT_SUPPORTED_CONTENT_TYPES = AUDIO_STT_SUPPORTED_CONTENT_TYPES + +app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL +app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY app.state.config.WHISPER_MODEL = WHISPER_MODEL app.state.config.WHISPER_VAD_FILTER = WHISPER_VAD_FILTER @@ -874,11 +1108,15 @@ async def lifespan(app: FastAPI): app.state.config.AUDIO_STT_AZURE_BASE_URL = AUDIO_STT_AZURE_BASE_URL app.state.config.AUDIO_STT_AZURE_MAX_SPEAKERS = AUDIO_STT_AZURE_MAX_SPEAKERS -app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL -app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE + app.state.config.TTS_MODEL = AUDIO_TTS_MODEL app.state.config.TTS_VOICE = AUDIO_TTS_VOICE + +app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL +app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY +app.state.config.TTS_OPENAI_PARAMS = AUDIO_TTS_OPENAI_PARAMS + app.state.config.TTS_API_KEY = AUDIO_TTS_API_KEY app.state.config.TTS_SPLIT_ON = AUDIO_TTS_SPLIT_ON @@ -909,6 +1147,7 @@ async def lifespan(app: FastAPI): app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = ENABLE_AUTOCOMPLETE_GENERATION app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION app.state.config.ENABLE_TITLE_GENERATION = ENABLE_TITLE_GENERATION +app.state.config.ENABLE_FOLLOW_UP_GENERATION = ENABLE_FOLLOW_UP_GENERATION 
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE @@ -916,6 +1155,9 @@ async def lifespan(app: FastAPI): app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = ( IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE ) +app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = ( + FOLLOW_UP_GENERATION_PROMPT_TEMPLATE +) app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = ( TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE @@ -945,12 +1187,32 @@ async def dispatch(self, request: Request, call_next): path = request.url.path query_params = dict(parse_qs(urlparse(str(request.url)).query)) + redirect_params = {} + # Check for the specific watch path and the presence of 'v' parameter if path.endswith("/watch") and "v" in query_params: # Extract the first 'v' parameter - video_id = query_params["v"][0] - encoded_video_id = urlencode({"youtube": video_id}) - redirect_url = f"/?{encoded_video_id}" + youtube_video_id = query_params["v"][0] + redirect_params["youtube"] = youtube_video_id + + if "shared" in query_params and len(query_params["shared"]) > 0: + # PWA share_target support + + text = query_params["shared"][0] + if text: + urls = re.match(r"https://\S+", text) + if urls: + from open_webui.retrieval.loaders.youtube import _parse_video_id + + if youtube_video_id := _parse_video_id(urls[0]): + redirect_params["youtube"] = youtube_video_id + else: + redirect_params["load-url"] = urls[0] + else: + redirect_params["q"] = text + + if redirect_params: + redirect_url = f"/?{urlencode(redirect_params)}" return RedirectResponse(url=redirect_url) # Proceed with the normal flow of other requests @@ -959,6 +1221,9 @@ async def dispatch(self, request: Request, call_next): # Add the middleware to the app +if ENABLE_COMPRESSION_MIDDLEWARE: + app.add_middleware(CompressMiddleware) + app.add_middleware(RedirectMiddleware) app.add_middleware(SecurityHeadersMiddleware) @@ -1052,6 +1317,10 @@ async def inspect_websocket(request: Request, call_next): ) 
app.include_router(utils.router, prefix="/api/v1/utils", tags=["utils"]) +# SCIM 2.0 API for identity management +if SCIM_ENABLED: + app.include_router(scim.router, prefix="/api/v1/scim/v2", tags=["scim"]) + try: audit_level = AuditLevel(AUDIT_LOG_LEVEL) @@ -1074,31 +1343,11 @@ async def inspect_websocket(request: Request, call_next): @app.get("/api/models") -async def get_models(request: Request, user=Depends(get_verified_user)): - def get_filtered_models(models, user): - filtered_models = [] - for model in models: - if model.get("arena"): - if has_access( - user.id, - type="read", - access_control=model.get("info", {}) - .get("meta", {}) - .get("access_control", {}), - ): - filtered_models.append(model) - continue - - model_info = Models.get_model_by_id(model["id"]) - if model_info: - if user.id == model_info.user_id or has_access( - user.id, type="read", access_control=model_info.access_control - ): - filtered_models.append(model) - - return filtered_models - - all_models = await get_all_models(request, user=user) +@app.get("/api/v1/models") # Experimental: Compatibility with OpenAI API +async def get_models( + request: Request, refresh: bool = False, user=Depends(get_verified_user) +): + all_models = await get_all_models(request, refresh=refresh, user=user) models = [] for model in all_models: @@ -1127,15 +1376,16 @@ def get_filtered_models(models, user): model_order_dict = {model_id: i for i, model_id in enumerate(model_order_list)} # Sort models by order list priority, with fallback for those not in the list models.sort( - key=lambda x: (model_order_dict.get(x["id"], float("inf")), x["name"]) + key=lambda model: ( + model_order_dict.get(model.get("id", ""), float("inf")), + (model.get("name", "") or ""), + ) ) - # Filter out models that the user does not have access to - if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL: - models = get_filtered_models(models, user) + models = get_filtered_models(models, user) log.debug( - f"/api/models returned 
filtered models accessible to the user: {json.dumps([model['id'] for model in models])}" + f"/api/models returned filtered models accessible to the user: {json.dumps([model.get('id') for model in models])}" ) return {"data": models} @@ -1146,7 +1396,40 @@ async def get_base_models(request: Request, user=Depends(get_admin_user)): return {"data": models} +################################## +# Embeddings +################################## + + +@app.post("/api/embeddings") +@app.post("/api/v1/embeddings") # Experimental: Compatibility with OpenAI API +async def embeddings( + request: Request, form_data: dict, user=Depends(get_verified_user) +): + """ + OpenAI-compatible embeddings endpoint. + + This handler: + - Performs user/model checks and dispatches to the correct backend. + - Supports OpenAI, Ollama, arena models, pipelines, and any compatible provider. + + Args: + request (Request): Request context. + form_data (dict): OpenAI-like payload (e.g., {"model": "...", "input": [...]}) + user (UserModel): Authenticated user. + + Returns: + dict: OpenAI-compatible embeddings response. 
+ """ + # Make sure models are loaded in app state + if not request.app.state.MODELS: + await get_all_models(request, user=user) + # Use generic dispatcher in utils.embeddings + return await generate_embeddings(request, form_data, user) + + @app.post("/api/chat/completions") +@app.post("/api/v1/chat/completions") # Experimental: Compatibility with OpenAI API async def chat_completion( request: Request, form_data: dict, @@ -1155,13 +1438,13 @@ async def chat_completion( if not request.app.state.MODELS: await get_all_models(request, user=user) + model_id = form_data.get("model", None) model_item = form_data.pop("model_item", {}) tasks = form_data.pop("background_tasks", None) metadata = {} try: if not model_item.get("direct", False): - model_id = form_data.get("model", None) if model_id not in request.app.state.MODELS: raise Exception("Model not found") @@ -1169,7 +1452,9 @@ async def chat_completion( model_info = Models.get_model_by_id(model_id) # Check if user has access to the model - if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user": + if not BYPASS_MODEL_ACCESS_CONTROL and ( + user.role != "admin" or not BYPASS_ADMIN_ACCESS_CONTROL + ): try: check_model_access(user, model) except Exception as e: @@ -1181,6 +1466,23 @@ async def chat_completion( request.state.direct = True request.state.model = model + model_info_params = ( + model_info.params.model_dump() if model_info and model_info.params else {} + ) + + # Chat Params + stream_delta_chunk_size = form_data.get("params", {}).get( + "stream_delta_chunk_size" + ) + reasoning_tags = form_data.get("params", {}).get("reasoning_tags") + + # Model Params + if model_info_params.get("stream_delta_chunk_size"): + stream_delta_chunk_size = model_info_params.get("stream_delta_chunk_size") + + if model_info_params.get("reasoning_tags") is not None: + reasoning_tags = model_info_params.get("reasoning_tags") + metadata = { "user_id": user.id, "chat_id": form_data.pop("chat_id", None), @@ -1194,53 +1496,121 @@ async 
def chat_completion( "variables": form_data.get("variables", {}), "model": model, "direct": model_item.get("direct", False), - **( - {"function_calling": "native"} - if form_data.get("params", {}).get("function_calling") == "native" - or ( - model_info - and model_info.params.model_dump().get("function_calling") - == "native" - ) - else {} - ), + "params": { + "stream_delta_chunk_size": stream_delta_chunk_size, + "reasoning_tags": reasoning_tags, + "function_calling": ( + "native" + if ( + form_data.get("params", {}).get("function_calling") == "native" + or model_info_params.get("function_calling") == "native" + ) + else "default" + ), + }, } + if metadata.get("chat_id") and (user and user.role != "admin"): + if not metadata["chat_id"].startswith("local:"): + chat = Chats.get_chat_by_id_and_user_id(metadata["chat_id"], user.id) + if chat is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.DEFAULT(), + ) + request.state.metadata = metadata form_data["metadata"] = metadata - form_data, metadata, events = await process_chat_payload( - request, form_data, user, metadata, model - ) - except Exception as e: - log.debug(f"Error processing chat payload: {e}") - if metadata.get("chat_id") and metadata.get("message_id"): - # Update the chat message with the error - Chats.upsert_message_to_chat_by_id_and_message_id( - metadata["chat_id"], - metadata["message_id"], - { - "error": {"content": str(e)}, - }, - ) - + log.debug(f"Error processing chat metadata: {e}") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=str(e), ) - try: - response = await chat_completion_handler(request, form_data, user) + async def process_chat(request, form_data, user, metadata, model): + try: + form_data, metadata, events = await process_chat_payload( + request, form_data, user, metadata, model + ) - return await process_chat_response( - request, response, form_data, user, metadata, model, events, tasks - ) - except Exception as e: - 
raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=str(e), + response = await chat_completion_handler(request, form_data, user) + if metadata.get("chat_id") and metadata.get("message_id"): + try: + if not metadata["chat_id"].startswith("local:"): + Chats.upsert_message_to_chat_by_id_and_message_id( + metadata["chat_id"], + metadata["message_id"], + { + "model": model_id, + }, + ) + except: + pass + + return await process_chat_response( + request, response, form_data, user, metadata, model, events, tasks + ) + except asyncio.CancelledError: + log.info("Chat processing was cancelled") + try: + event_emitter = get_event_emitter(metadata) + await event_emitter( + {"type": "chat:tasks:cancel"}, + ) + except Exception as e: + pass + except Exception as e: + log.debug(f"Error processing chat payload: {e}") + if metadata.get("chat_id") and metadata.get("message_id"): + # Update the chat message with the error + try: + if not metadata["chat_id"].startswith("local:"): + Chats.upsert_message_to_chat_by_id_and_message_id( + metadata["chat_id"], + metadata["message_id"], + { + "error": {"content": str(e)}, + }, + ) + + event_emitter = get_event_emitter(metadata) + await event_emitter( + { + "type": "chat:message:error", + "data": {"error": {"content": str(e)}}, + } + ) + await event_emitter( + {"type": "chat:tasks:cancel"}, + ) + + except: + pass + finally: + try: + if mcp_clients := metadata.get("mcp_clients"): + for client in mcp_clients.values(): + await client.disconnect() + except Exception as e: + log.debug(f"Error cleaning up: {e}") + pass + + if ( + metadata.get("session_id") + and metadata.get("chat_id") + and metadata.get("message_id") + ): + # Asynchronous Chat Processing + task_id, _ = await create_task( + request.app.state.redis, + process_chat(request, form_data, user, metadata, model), + id=metadata["chat_id"], ) + return {"status": True, "task_id": task_id} + else: + return await process_chat(request, form_data, user, metadata, model) # 
Alias for chat_completion (Legacy) @@ -1287,28 +1657,32 @@ async def chat_action( @app.post("/api/tasks/stop/{task_id}") -async def stop_task_endpoint(task_id: str, user=Depends(get_verified_user)): +async def stop_task_endpoint( + request: Request, task_id: str, user=Depends(get_verified_user) +): try: - result = await stop_task(task_id) + result = await stop_task(request.app.state.redis, task_id) return result except ValueError as e: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) @app.get("/api/tasks") -async def list_tasks_endpoint(user=Depends(get_verified_user)): - return {"tasks": list_tasks()} +async def list_tasks_endpoint(request: Request, user=Depends(get_verified_user)): + return {"tasks": await list_tasks(request.app.state.redis)} @app.get("/api/tasks/chat/{chat_id}") -async def list_tasks_by_chat_id_endpoint(chat_id: str, user=Depends(get_verified_user)): +async def list_tasks_by_chat_id_endpoint( + request: Request, chat_id: str, user=Depends(get_verified_user) +): chat = Chats.get_chat_by_id(chat_id) if chat is None or chat.user_id != user.id: return {"task_ids": []} - task_ids = list_task_ids_by_chat_id(chat_id) + task_ids = await list_task_ids_by_item_id(request.app.state.redis, chat_id) - print(f"Task IDs for chat {chat_id}: {task_ids}") + log.debug(f"Task IDs for chat {chat_id}: {task_ids}") return {"task_ids": task_ids} @@ -1322,8 +1696,18 @@ async def list_tasks_by_chat_id_endpoint(chat_id: str, user=Depends(get_verified @app.get("/api/config") async def get_app_config(request: Request): user = None - if "token" in request.cookies: + token = None + + auth_header = request.headers.get("Authorization") + if auth_header: + cred = get_http_authorization_cred(auth_header) + if cred: + token = cred.credentials + + if not token and "token" in request.cookies: token = request.cookies.get("token") + + if token: try: data = decode_token(token) except Exception as e: @@ -1356,11 +1740,13 @@ async def get_app_config(request: 
Request): "features": { "auth": WEBUI_AUTH, "auth_trusted_header": bool(app.state.AUTH_TRUSTED_EMAIL_HEADER), + "enable_signup_password_confirmation": ENABLE_SIGNUP_PASSWORD_CONFIRMATION, "enable_ldap": app.state.config.ENABLE_LDAP, "enable_api_key": app.state.config.ENABLE_API_KEY, "enable_signup": app.state.config.ENABLE_SIGNUP, "enable_login_form": app.state.config.ENABLE_LOGIN_FORM, "enable_websocket": ENABLE_WEBSOCKET_SUPPORT, + "enable_version_update_check": ENABLE_VERSION_UPDATE_CHECK, **( { "enable_direct_connections": app.state.config.ENABLE_DIRECT_CONNECTIONS, @@ -1378,6 +1764,14 @@ async def get_app_config(request: Request): "enable_admin_chat_access": ENABLE_ADMIN_CHAT_ACCESS, "enable_google_drive_integration": app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, "enable_onedrive_integration": app.state.config.ENABLE_ONEDRIVE_INTEGRATION, + **( + { + "enable_onedrive_personal": ENABLE_ONEDRIVE_PERSONAL, + "enable_onedrive_business": ENABLE_ONEDRIVE_BUSINESS, + } + if app.state.config.ENABLE_ONEDRIVE_INTEGRATION + else {} + ), } if user is not None else {} @@ -1404,6 +1798,10 @@ async def get_app_config(request: Request): "file": { "max_size": app.state.config.FILE_MAX_SIZE, "max_count": app.state.config.FILE_MAX_COUNT, + "image_compression": { + "width": app.state.config.FILE_IMAGE_COMPRESSION_WIDTH, + "height": app.state.config.FILE_IMAGE_COMPRESSION_HEIGHT, + }, }, "permissions": {**app.state.config.USER_PERMISSIONS}, "google_drive": { @@ -1411,7 +1809,8 @@ async def get_app_config(request: Request): "api_key": GOOGLE_DRIVE_API_KEY.value, }, "onedrive": { - "client_id": ONEDRIVE_CLIENT_ID.value, + "client_id_personal": ONEDRIVE_CLIENT_ID_PERSONAL, + "client_id_business": ONEDRIVE_CLIENT_ID_BUSINESS, "sharepoint_url": ONEDRIVE_SHAREPOINT_URL.value, "sharepoint_tenant_id": ONEDRIVE_SHAREPOINT_TENANT_ID.value, }, @@ -1429,8 +1828,33 @@ async def get_app_config(request: Request): else {} ), } - if user is not None - else {} + if user is not None and 
(user.role in ["admin", "user"]) + else { + **( + { + "ui": { + "pending_user_overlay_title": app.state.config.PENDING_USER_OVERLAY_TITLE, + "pending_user_overlay_content": app.state.config.PENDING_USER_OVERLAY_CONTENT, + } + } + if user and user.role == "pending" + else {} + ), + **( + { + "metadata": { + "login_footer": app.state.LICENSE_METADATA.get( + "login_footer", "" + ), + "auth_logo_position": app.state.LICENSE_METADATA.get( + "auth_logo_position", "" + ), + } + } + if app.state.LICENSE_METADATA + else {} + ), + } ), } @@ -1462,9 +1886,9 @@ async def get_app_version(): @app.get("/api/version/updates") async def get_app_latest_release_version(user=Depends(get_verified_user)): - if OFFLINE_MODE: + if not ENABLE_VERSION_UPDATE_CHECK: log.debug( - f"Offline mode is enabled, returning current version as latest version" + f"Version update check is disabled, returning current version as latest version" ) return {"current": VERSION, "latest": VERSION} try: @@ -1489,21 +1913,100 @@ async def get_app_changelog(): return {key: CHANGELOG[key] for idx, key in enumerate(CHANGELOG) if idx < 5} +@app.get("/api/usage") +async def get_current_usage(user=Depends(get_verified_user)): + """ + Get current usage statistics for Open WebUI. + This is an experimental endpoint and subject to change. 
+ """ + try: + return {"model_ids": get_models_in_use(), "user_ids": get_active_user_ids()} + except Exception as e: + log.error(f"Error getting usage statistics: {e}") + raise HTTPException(status_code=500, detail="Internal Server Error") + + ############################ # OAuth Login & Callback ############################ -# SessionMiddleware is used by authlib for oauth -if len(OAUTH_PROVIDERS) > 0: + +# Initialize OAuth client manager with any MCP tool servers using OAuth 2.1 +if len(app.state.config.TOOL_SERVER_CONNECTIONS) > 0: + for tool_server_connection in app.state.config.TOOL_SERVER_CONNECTIONS: + if tool_server_connection.get("type", "openapi") == "mcp": + server_id = tool_server_connection.get("info", {}).get("id") + auth_type = tool_server_connection.get("auth_type", "none") + if server_id and auth_type == "oauth_2.1": + oauth_client_info = tool_server_connection.get("info", {}).get( + "oauth_client_info", "" + ) + + try: + oauth_client_info = decrypt_data(oauth_client_info) + app.state.oauth_client_manager.add_client( + f"mcp:{server_id}", + OAuthClientInformationFull(**oauth_client_info), + ) + except Exception as e: + log.error( + f"Error adding OAuth client for MCP tool server {server_id}: {e}" + ) + pass + +try: + if ENABLE_STAR_SESSIONS_MIDDLEWARE: + redis_session_store = RedisStore( + url=REDIS_URL, + prefix=(f"{REDIS_KEY_PREFIX}:session:" if REDIS_KEY_PREFIX else "session:"), + ) + + app.add_middleware(SessionAutoloadMiddleware) + app.add_middleware( + StarSessionsMiddleware, + store=redis_session_store, + cookie_name="owui-session", + cookie_same_site=WEBUI_SESSION_COOKIE_SAME_SITE, + cookie_https_only=WEBUI_SESSION_COOKIE_SECURE, + ) + log.info("Using Redis for session") + else: + raise ValueError("No Redis URL provided") +except Exception as e: app.add_middleware( SessionMiddleware, secret_key=WEBUI_SECRET_KEY, - session_cookie="oui-session", + session_cookie="owui-session", same_site=WEBUI_SESSION_COOKIE_SAME_SITE, 
https_only=WEBUI_SESSION_COOKIE_SECURE, ) +@app.get("/oauth/clients/{client_id}/authorize") +async def oauth_client_authorize( + client_id: str, + request: Request, + response: Response, + user=Depends(get_verified_user), +): + return await oauth_client_manager.handle_authorize(request, client_id=client_id) + + +@app.get("/oauth/clients/{client_id}/callback") +async def oauth_client_callback( + client_id: str, + request: Request, + response: Response, + user=Depends(get_verified_user), +): + return await oauth_client_manager.handle_callback( + request, + client_id=client_id, + user_id=user.id if user else None, + response=response, + ) + + @app.get("/oauth/{provider}/login") async def oauth_login(provider: str, request: Request): return await oauth_manager.handle_login(request, provider) @@ -1515,8 +2018,9 @@ async def oauth_login(provider: str, request: Request): # - This is considered insecure in general, as OAuth providers do not always verify email addresses # 3. If there is no user, and ENABLE_OAUTH_SIGNUP is true, create a user # - Email addresses are considered unique, so we fail registration if the email address is already taken -@app.get("/oauth/{provider}/callback") -async def oauth_callback(provider: str, request: Request, response: Response): +@app.get("/oauth/{provider}/login/callback") +@app.get("/oauth/{provider}/callback") # Legacy endpoint +async def oauth_login_callback(provider: str, request: Request, response: Response): return await oauth_manager.handle_callback(request, provider, response) @@ -1528,11 +2032,10 @@ async def get_manifest_json(): return { "name": app.state.WEBUI_NAME, "short_name": app.state.WEBUI_NAME, - "description": "Open WebUI is an open, extensible, user-friendly interface for AI that adapts to your workflow.", + "description": f"{app.state.WEBUI_NAME} is an open, extensible, user-friendly interface for AI that adapts to your workflow.", "start_url": "/", "display": "standalone", "background_color": "#343541", - 
"orientation": "any", "icons": [ { "src": "/static/logo.png", @@ -1547,6 +2050,11 @@ async def get_manifest_json(): "purpose": "maskable", }, ], + "share_target": { + "action": "/", + "method": "GET", + "params": {"text": "shared"}, + }, } @@ -1577,7 +2085,20 @@ async def healthcheck_with_db(): app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") -app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache") + + +@app.get("/cache/{path:path}") +async def serve_cache_file( + path: str, + user=Depends(get_verified_user), +): + file_path = os.path.abspath(os.path.join(CACHE_DIR, path)) + # prevent path traversal + if not file_path.startswith(os.path.abspath(CACHE_DIR)): + raise HTTPException(status_code=404, detail="File not found") + if not os.path.isfile(file_path): + raise HTTPException(status_code=404, detail="File not found") + return FileResponse(file_path) def swagger_ui_html(*args, **kwargs): diff --git a/backend/open_webui/migrations/env.py b/backend/open_webui/migrations/env.py index 12888164717..7db92512820 100644 --- a/backend/open_webui/migrations/env.py +++ b/backend/open_webui/migrations/env.py @@ -2,8 +2,8 @@ from alembic import context from open_webui.models.auths import Auth -from open_webui.env import DATABASE_URL -from sqlalchemy import engine_from_config, pool +from open_webui.env import DATABASE_URL, DATABASE_PASSWORD +from sqlalchemy import engine_from_config, pool, create_engine # this is the Alembic Config object, which provides # access to the values within the .ini file in use. @@ -62,11 +62,38 @@ def run_migrations_online() -> None: and associate a connection with the context. 
""" - connectable = engine_from_config( - config.get_section(config.config_ini_section, {}), - prefix="sqlalchemy.", - poolclass=pool.NullPool, - ) + # Handle SQLCipher URLs + if DB_URL and DB_URL.startswith("sqlite+sqlcipher://"): + if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == "": + raise ValueError( + "DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs" + ) + + # Extract database path from SQLCipher URL + db_path = DB_URL.replace("sqlite+sqlcipher://", "") + if db_path.startswith("/"): + db_path = db_path[1:] # Remove leading slash for relative paths + + # Create a custom creator function that uses sqlcipher3 + def create_sqlcipher_connection(): + import sqlcipher3 + + conn = sqlcipher3.connect(db_path, check_same_thread=False) + conn.execute(f"PRAGMA key = '{DATABASE_PASSWORD}'") + return conn + + connectable = create_engine( + "sqlite://", # Dummy URL since we're using creator + creator=create_sqlcipher_connection, + echo=False, + ) + else: + # Standard database connection (existing logic) + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) with connectable.connect() as connection: context.configure(connection=connection, target_metadata=target_metadata) diff --git a/backend/open_webui/migrations/versions/018012973d35_add_indexes.py b/backend/open_webui/migrations/versions/018012973d35_add_indexes.py new file mode 100644 index 00000000000..29af4271088 --- /dev/null +++ b/backend/open_webui/migrations/versions/018012973d35_add_indexes.py @@ -0,0 +1,46 @@ +"""Add indexes + +Revision ID: 018012973d35 +Revises: d31026856c01 +Create Date: 2025-08-13 03:00:00.000000 + +""" + +from alembic import op +import sqlalchemy as sa + +revision = "018012973d35" +down_revision = "d31026856c01" +branch_labels = None +depends_on = None + + +def upgrade(): + # Chat table indexes + op.create_index("folder_id_idx", "chat", ["folder_id"]) + 
op.create_index("user_id_pinned_idx", "chat", ["user_id", "pinned"]) + op.create_index("user_id_archived_idx", "chat", ["user_id", "archived"]) + op.create_index("updated_at_user_id_idx", "chat", ["updated_at", "user_id"]) + op.create_index("folder_id_user_id_idx", "chat", ["folder_id", "user_id"]) + + # Tag table index + op.create_index("user_id_idx", "tag", ["user_id"]) + + # Function table index + op.create_index("is_global_idx", "function", ["is_global"]) + + +def downgrade(): + # Chat table indexes + op.drop_index("folder_id_idx", table_name="chat") + op.drop_index("user_id_pinned_idx", table_name="chat") + op.drop_index("user_id_archived_idx", table_name="chat") + op.drop_index("updated_at_user_id_idx", table_name="chat") + op.drop_index("folder_id_user_id_idx", table_name="chat") + + # Tag table index + op.drop_index("user_id_idx", table_name="tag") + + # Function table index + + op.drop_index("is_global_idx", table_name="function") diff --git a/backend/open_webui/migrations/versions/38d63c18f30f_add_oauth_session_table.py b/backend/open_webui/migrations/versions/38d63c18f30f_add_oauth_session_table.py new file mode 100644 index 00000000000..8ead6db6d4a --- /dev/null +++ b/backend/open_webui/migrations/versions/38d63c18f30f_add_oauth_session_table.py @@ -0,0 +1,52 @@ +"""Add oauth_session table + +Revision ID: 38d63c18f30f +Revises: 3af16a1c9fb6 +Create Date: 2025-09-08 14:19:59.583921 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision: str = "38d63c18f30f" +down_revision: Union[str, None] = "3af16a1c9fb6" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Create oauth_session table + op.create_table( + "oauth_session", + sa.Column("id", sa.Text(), nullable=False), + sa.Column("user_id", sa.Text(), nullable=False), + sa.Column("provider", sa.Text(), nullable=False), + sa.Column("token", sa.Text(), nullable=False), + sa.Column("expires_at", sa.BigInteger(), nullable=False), + sa.Column("created_at", sa.BigInteger(), nullable=False), + sa.Column("updated_at", sa.BigInteger(), nullable=False), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), + ) + + # Create indexes for better performance + op.create_index("idx_oauth_session_user_id", "oauth_session", ["user_id"]) + op.create_index("idx_oauth_session_expires_at", "oauth_session", ["expires_at"]) + op.create_index( + "idx_oauth_session_user_provider", "oauth_session", ["user_id", "provider"] + ) + + +def downgrade() -> None: + # Drop indexes first + op.drop_index("idx_oauth_session_user_provider", table_name="oauth_session") + op.drop_index("idx_oauth_session_expires_at", table_name="oauth_session") + op.drop_index("idx_oauth_session_user_id", table_name="oauth_session") + + # Drop the table + op.drop_table("oauth_session") diff --git a/backend/open_webui/migrations/versions/3af16a1c9fb6_update_user_table.py b/backend/open_webui/migrations/versions/3af16a1c9fb6_update_user_table.py new file mode 100644 index 00000000000..ab980f27ce3 --- /dev/null +++ b/backend/open_webui/migrations/versions/3af16a1c9fb6_update_user_table.py @@ -0,0 +1,32 @@ +"""update user table + +Revision ID: 3af16a1c9fb6 +Revises: 018012973d35 +Create Date: 2025-08-21 02:07:18.078283 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. 
+revision: str = "3af16a1c9fb6" +down_revision: Union[str, None] = "018012973d35" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column("user", sa.Column("username", sa.String(length=50), nullable=True)) + op.add_column("user", sa.Column("bio", sa.Text(), nullable=True)) + op.add_column("user", sa.Column("gender", sa.Text(), nullable=True)) + op.add_column("user", sa.Column("date_of_birth", sa.Date(), nullable=True)) + + +def downgrade() -> None: + op.drop_column("user", "username") + op.drop_column("user", "bio") + op.drop_column("user", "gender") + op.drop_column("user", "date_of_birth") diff --git a/backend/open_webui/migrations/versions/a5c220713937_add_reply_to_id_column_to_message.py b/backend/open_webui/migrations/versions/a5c220713937_add_reply_to_id_column_to_message.py new file mode 100644 index 00000000000..dd2b7d1a680 --- /dev/null +++ b/backend/open_webui/migrations/versions/a5c220713937_add_reply_to_id_column_to_message.py @@ -0,0 +1,34 @@ +"""Add reply_to_id column to message + +Revision ID: a5c220713937 +Revises: 38d63c18f30f +Create Date: 2025-09-27 02:24:18.058455 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. 
+revision: str = "a5c220713937" +down_revision: Union[str, None] = "38d63c18f30f" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Add 'reply_to_id' column to the 'message' table for replying to messages + op.add_column( + "message", + sa.Column("reply_to_id", sa.Text(), nullable=True), + ) + pass + + +def downgrade() -> None: + # Remove 'reply_to_id' column from the 'message' table + op.drop_column("message", "reply_to_id") + + pass diff --git a/backend/open_webui/migrations/versions/d31026856c01_update_folder_table_data.py b/backend/open_webui/migrations/versions/d31026856c01_update_folder_table_data.py new file mode 100644 index 00000000000..3c916964e92 --- /dev/null +++ b/backend/open_webui/migrations/versions/d31026856c01_update_folder_table_data.py @@ -0,0 +1,23 @@ +"""Update folder table data + +Revision ID: d31026856c01 +Revises: 9f0c9cd09105 +Create Date: 2025-07-13 03:00:00.000000 + +""" + +from alembic import op +import sqlalchemy as sa + +revision = "d31026856c01" +down_revision = "9f0c9cd09105" +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column("folder", sa.Column("data", sa.JSON(), nullable=True)) + + +def downgrade(): + op.drop_column("folder", "data") diff --git a/backend/open_webui/models/auths.py b/backend/open_webui/models/auths.py index f07c36c734f..6517e21345a 100644 --- a/backend/open_webui/models/auths.py +++ b/backend/open_webui/models/auths.py @@ -73,11 +73,6 @@ class ProfileImageUrlForm(BaseModel): profile_image_url: str -class UpdateProfileForm(BaseModel): - profile_image_url: str - name: str - - class UpdatePasswordForm(BaseModel): password: str new_password: str @@ -129,12 +124,16 @@ def insert_new_auth( def authenticate_user(self, email: str, password: str) -> Optional[UserModel]: log.info(f"authenticate_user: {email}") + + user = Users.get_user_by_email(email) + if not user: + return None + try: with get_db() as db: - 
auth = db.query(Auth).filter_by(email=email, active=True).first() + auth = db.query(Auth).filter_by(id=user.id, active=True).first() if auth: if verify_password(password, auth.password): - user = Users.get_user_by_id(auth.id) return user else: return None @@ -155,8 +154,8 @@ def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]: except Exception: return False - def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]: - log.info(f"authenticate_user_by_trusted_header: {email}") + def authenticate_user_by_email(self, email: str) -> Optional[UserModel]: + log.info(f"authenticate_user_by_email: {email}") try: with get_db() as db: auth = db.query(Auth).filter_by(email=email, active=True).first() diff --git a/backend/open_webui/models/channels.py b/backend/open_webui/models/channels.py index 92f238c3a02..e75266be781 100644 --- a/backend/open_webui/models/channels.py +++ b/backend/open_webui/models/channels.py @@ -57,6 +57,10 @@ class ChannelModel(BaseModel): #################### +class ChannelResponse(ChannelModel): + write_access: bool = False + + class ChannelForm(BaseModel): name: str description: Optional[str] = None diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py index 4b4f3719765..cfcbc004b70 100644 --- a/backend/open_webui/models/chats.py +++ b/backend/open_webui/models/chats.py @@ -6,12 +6,14 @@ from open_webui.internal.db import Base, get_db from open_webui.models.tags import TagModel, Tag, Tags +from open_webui.models.folders import Folders from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON +from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON, Index from sqlalchemy import or_, func, select, and_, text from sqlalchemy.sql import exists +from sqlalchemy.sql.expression import bindparam #################### # Chat DB Schema @@ -39,6 +41,20 @@ class Chat(Base): 
meta = Column(JSON, server_default="{}") folder_id = Column(Text, nullable=True) + __table_args__ = ( + # Performance indexes for common queries + # WHERE folder_id = ... + Index("folder_id_idx", "folder_id"), + # WHERE user_id = ... AND pinned = ... + Index("user_id_pinned_idx", "user_id", "pinned"), + # WHERE user_id = ... AND archived = ... + Index("user_id_archived_idx", "user_id", "archived"), + # WHERE user_id = ... ORDER BY updated_at DESC + Index("updated_at_user_id_idx", "updated_at", "user_id"), + # WHERE folder_id = ... AND user_id = ... + Index("folder_id_user_id_idx", "folder_id", "user_id"), + ) + class ChatModel(BaseModel): model_config = ConfigDict(from_attributes=True) @@ -66,12 +82,14 @@ class ChatModel(BaseModel): class ChatForm(BaseModel): chat: dict + folder_id: Optional[str] = None class ChatImportForm(ChatForm): meta: Optional[dict] = {} pinned: Optional[bool] = False - folder_id: Optional[str] = None + created_at: Optional[int] = None + updated_at: Optional[int] = None class ChatTitleMessagesForm(BaseModel): @@ -118,6 +136,7 @@ def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatMod else "New Chat" ), "chat": form_data.chat, + "folder_id": form_data.folder_id, "created_at": int(time.time()), "updated_at": int(time.time()), } @@ -147,8 +166,16 @@ def import_chat( "meta": form_data.meta, "pinned": form_data.pinned, "folder_id": form_data.folder_id, - "created_at": int(time.time()), - "updated_at": int(time.time()), + "created_at": ( + form_data.created_at + if form_data.created_at + else int(time.time()) + ), + "updated_at": ( + form_data.updated_at + if form_data.updated_at + else int(time.time()) + ), } ) @@ -209,7 +236,7 @@ def get_chat_title_by_id(self, id: str) -> Optional[str]: return chat.chat.get("title", "New Chat") - def get_messages_by_chat_id(self, id: str) -> Optional[dict]: + def get_messages_map_by_chat_id(self, id: str) -> Optional[dict]: chat = self.get_chat_by_id(id) if chat is None: return None @@ 
-232,6 +259,10 @@ def upsert_message_to_chat_by_id_and_message_id( if chat is None: return None + # Sanitize message content for null characters before upserting + if isinstance(message.get("content"), str): + message["content"] = message["content"].replace("\x00", "") + chat = chat.chat history = chat.get("history", {}) @@ -280,6 +311,9 @@ def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: "user_id": f"shared-{chat_id}", "title": chat.title, "chat": chat.chat, + "meta": chat.meta, + "pinned": chat.pinned, + "folder_id": chat.folder_id, "created_at": chat.created_at, "updated_at": int(time.time()), } @@ -311,7 +345,9 @@ def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]: shared_chat.title = chat.title shared_chat.chat = chat.chat - + shared_chat.meta = chat.meta + shared_chat.pinned = chat.pinned + shared_chat.folder_id = chat.folder_id shared_chat.updated_at = int(time.time()) db.commit() db.refresh(shared_chat) @@ -330,6 +366,15 @@ def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool: except Exception: return False + def unarchive_all_chats_by_user_id(self, user_id: str) -> bool: + try: + with get_db() as db: + db.query(Chat).filter_by(user_id=user_id).update({"archived": False}) + db.commit() + return True + except Exception: + return False + def update_chat_share_id_by_id( self, id: str, share_id: Optional[str] ) -> Optional[ChatModel]: @@ -377,22 +422,47 @@ def archive_all_chats_by_user_id(self, user_id: str) -> bool: return False def get_archived_chat_list_by_user_id( - self, user_id: str, skip: int = 0, limit: int = 50 + self, + user_id: str, + filter: Optional[dict] = None, + skip: int = 0, + limit: int = 50, ) -> list[ChatModel]: + with get_db() as db: - all_chats = ( - db.query(Chat) - .filter_by(user_id=user_id, archived=True) - .order_by(Chat.updated_at.desc()) - # .limit(limit).offset(skip) - .all() - ) + query = db.query(Chat).filter_by(user_id=user_id, archived=True) + + if filter: + 
query_key = filter.get("query") + if query_key: + query = query.filter(Chat.title.ilike(f"%{query_key}%")) + + order_by = filter.get("order_by") + direction = filter.get("direction") + + if order_by and direction and getattr(Chat, order_by): + if direction.lower() == "asc": + query = query.order_by(getattr(Chat, order_by).asc()) + elif direction.lower() == "desc": + query = query.order_by(getattr(Chat, order_by).desc()) + else: + raise ValueError("Invalid direction for ordering") + else: + query = query.order_by(Chat.updated_at.desc()) + + if skip: + query = query.offset(skip) + if limit: + query = query.limit(limit) + + all_chats = query.all() return [ChatModel.model_validate(chat) for chat in all_chats] def get_chat_list_by_user_id( self, user_id: str, include_archived: bool = False, + filter: Optional[dict] = None, skip: int = 0, limit: int = 50, ) -> list[ChatModel]: @@ -401,7 +471,23 @@ def get_chat_list_by_user_id( if not include_archived: query = query.filter_by(archived=False) - query = query.order_by(Chat.updated_at.desc()) + if filter: + query_key = filter.get("query") + if query_key: + query = query.filter(Chat.title.ilike(f"%{query_key}%")) + + order_by = filter.get("order_by") + direction = filter.get("direction") + + if order_by and direction and getattr(Chat, order_by): + if direction.lower() == "asc": + query = query.order_by(getattr(Chat, order_by).asc()) + elif direction.lower() == "desc": + query = query.order_by(getattr(Chat, order_by).desc()) + else: + raise ValueError("Invalid direction for ordering") + else: + query = query.order_by(Chat.updated_at.desc()) if skip: query = query.offset(skip) @@ -415,12 +501,19 @@ def get_chat_title_id_list_by_user_id( self, user_id: str, include_archived: bool = False, + include_folders: bool = False, + include_pinned: bool = False, skip: Optional[int] = None, limit: Optional[int] = None, ) -> list[ChatTitleIdResponse]: with get_db() as db: - query = 
db.query(Chat).filter_by(user_id=user_id).filter_by(folder_id=None) - query = query.filter(or_(Chat.pinned == False, Chat.pinned == None)) + query = db.query(Chat).filter_by(user_id=user_id) + + if not include_folders: + query = query.filter_by(folder_id=None) + + if not include_pinned: + query = query.filter(or_(Chat.pinned == False, Chat.pinned == None)) if not include_archived: query = query.filter_by(archived=False) @@ -539,10 +632,12 @@ def get_chats_by_user_id_and_search_text( """ Filters chats based on a search query using Python, allowing pagination using skip and limit. """ - search_text = search_text.lower().strip() + search_text = search_text.replace("\u0000", "").lower().strip() if not search_text: - return self.get_chat_list_by_user_id(user_id, include_archived, skip, limit) + return self.get_chat_list_by_user_id( + user_id, include_archived, filter={}, skip=skip, limit=limit + ) search_text_words = search_text.split(" ") @@ -553,8 +648,45 @@ def get_chats_by_user_id_and_search_text( if word.startswith("tag:") ] + # Extract folder names - handle spaces and case insensitivity + folders = Folders.search_folders_by_names( + user_id, + [ + word.replace("folder:", "") + for word in search_text_words + if word.startswith("folder:") + ], + ) + folder_ids = [folder.id for folder in folders] + + is_pinned = None + if "pinned:true" in search_text_words: + is_pinned = True + elif "pinned:false" in search_text_words: + is_pinned = False + + is_archived = None + if "archived:true" in search_text_words: + is_archived = True + elif "archived:false" in search_text_words: + is_archived = False + + is_shared = None + if "shared:true" in search_text_words: + is_shared = True + elif "shared:false" in search_text_words: + is_shared = False + search_text_words = [ - word for word in search_text_words if not word.startswith("tag:") + word + for word in search_text_words + if ( + not word.startswith("tag:") + and not word.startswith("folder:") + and not 
word.startswith("pinned:") + and not word.startswith("archived:") + and not word.startswith("shared:") + ) ] search_text = " ".join(search_text_words) @@ -562,30 +694,41 @@ def get_chats_by_user_id_and_search_text( with get_db() as db: query = db.query(Chat).filter(Chat.user_id == user_id) - if not include_archived: + if is_archived is not None: + query = query.filter(Chat.archived == is_archived) + elif not include_archived: query = query.filter(Chat.archived == False) + if is_pinned is not None: + query = query.filter(Chat.pinned == is_pinned) + + if is_shared is not None: + if is_shared: + query = query.filter(Chat.share_id.isnot(None)) + else: + query = query.filter(Chat.share_id.is_(None)) + + if folder_ids: + query = query.filter(Chat.folder_id.in_(folder_ids)) + query = query.order_by(Chat.updated_at.desc()) # Check if the database dialect is either 'sqlite' or 'postgresql' dialect_name = db.bind.dialect.name if dialect_name == "sqlite": # SQLite case: using JSON1 extension for JSON searching + sqlite_content_sql = ( + "EXISTS (" + " SELECT 1 " + " FROM json_each(Chat.chat, '$.messages') AS message " + " WHERE LOWER(message.value->>'content') LIKE '%' || :content_key || '%'" + ")" + ) + sqlite_content_clause = text(sqlite_content_sql) query = query.filter( - ( - Chat.title.ilike( - f"%{search_text}%" - ) # Case-insensitive search in title - | text( - """ - EXISTS ( - SELECT 1 - FROM json_each(Chat.chat, '$.messages') AS message - WHERE LOWER(message.value->>'content') LIKE '%' || :search_text || '%' - ) - """ - ) - ).params(search_text=search_text) + or_( + Chat.title.ilike(bindparam("title_key")), sqlite_content_clause + ).params(title_key=f"%{search_text}%", content_key=search_text) ) # Check if there are any tags to filter, it should have all the tags @@ -620,21 +763,19 @@ def get_chats_by_user_id_and_search_text( elif dialect_name == "postgresql": # PostgreSQL relies on proper JSON query for search + postgres_content_sql = ( + "EXISTS (" + " SELECT 1 " + 
" FROM json_array_elements(Chat.chat->'messages') AS message " + " WHERE LOWER(message->>'content') LIKE '%' || :content_key || '%'" + ")" + ) + postgres_content_clause = text(postgres_content_sql) query = query.filter( - ( - Chat.title.ilike( - f"%{search_text}%" - ) # Case-insensitive search in title - | text( - """ - EXISTS ( - SELECT 1 - FROM json_array_elements(Chat.chat->'messages') AS message - WHERE LOWER(message->>'content') LIKE '%' || :search_text || '%' - ) - """ - ) - ).params(search_text=search_text) + or_( + Chat.title.ilike(bindparam("title_key")), + postgres_content_clause, + ).params(title_key=f"%{search_text}%", content_key=search_text) ) # Check if there are any tags to filter, it should have all the tags @@ -680,7 +821,7 @@ def get_chats_by_user_id_and_search_text( return [ChatModel.model_validate(chat) for chat in all_chats] def get_chats_by_folder_id_and_user_id( - self, folder_id: str, user_id: str + self, folder_id: str, user_id: str, skip: int = 0, limit: int = 60 ) -> list[ChatModel]: with get_db() as db: query = db.query(Chat).filter_by(folder_id=folder_id, user_id=user_id) @@ -689,6 +830,11 @@ def get_chats_by_folder_id_and_user_id( query = query.order_by(Chat.updated_at.desc()) + if skip: + query = query.offset(skip) + if limit: + query = query.limit(limit) + all_chats = query.all() return [ChatModel.model_validate(chat) for chat in all_chats] @@ -818,6 +964,16 @@ def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> in return count + def count_chats_by_folder_id_and_user_id(self, folder_id: str, user_id: str) -> int: + with get_db() as db: + query = db.query(Chat).filter_by(user_id=user_id) + + query = query.filter_by(folder_id=folder_id) + count = query.count() + + log.info(f"Count of chats for folder '{folder_id}': {count}") + return count + def delete_tag_by_id_and_user_id_and_tag_name( self, id: str, user_id: str, tag_name: str ) -> bool: diff --git a/backend/open_webui/models/files.py 
b/backend/open_webui/models/files.py index 6f1511cd137..171810fde7b 100644 --- a/backend/open_webui/models/files.py +++ b/backend/open_webui/models/files.py @@ -82,6 +82,7 @@ class FileModelResponse(BaseModel): class FileMetadataResponse(BaseModel): id: str + hash: Optional[str] = None meta: dict created_at: int # timestamp in epoch updated_at: int # timestamp in epoch @@ -130,12 +131,24 @@ def get_file_by_id(self, id: str) -> Optional[FileModel]: except Exception: return None + def get_file_by_id_and_user_id(self, id: str, user_id: str) -> Optional[FileModel]: + with get_db() as db: + try: + file = db.query(File).filter_by(id=id, user_id=user_id).first() + if file: + return FileModel.model_validate(file) + else: + return None + except Exception: + return None + def get_file_metadata_by_id(self, id: str) -> Optional[FileMetadataResponse]: with get_db() as db: try: file = db.get(File, id) return FileMetadataResponse( id=file.id, + hash=file.hash, meta=file.meta, created_at=file.created_at, updated_at=file.updated_at, @@ -147,6 +160,15 @@ def get_files(self) -> list[FileModel]: with get_db() as db: return [FileModel.model_validate(file) for file in db.query(File).all()] + def check_access_by_user_id(self, id, user_id, permission="write") -> bool: + file = self.get_file_by_id(id) + if not file: + return False + if file.user_id == user_id: + return True + # Implement additional access control logic here as needed + return False + def get_files_by_ids(self, ids: list[str]) -> list[FileModel]: with get_db() as db: return [ @@ -162,11 +184,14 @@ def get_file_metadatas_by_ids(self, ids: list[str]) -> list[FileMetadataResponse return [ FileMetadataResponse( id=file.id, + hash=file.hash, meta=file.meta, created_at=file.created_at, updated_at=file.updated_at, ) - for file in db.query(File) + for file in db.query( + File.id, File.hash, File.meta, File.created_at, File.updated_at + ) .filter(File.id.in_(ids)) .order_by(File.updated_at.desc()) .all() diff --git 
a/backend/open_webui/models/folders.py b/backend/open_webui/models/folders.py index 1c97de26c96..45f82470809 100644 --- a/backend/open_webui/models/folders.py +++ b/backend/open_webui/models/folders.py @@ -2,14 +2,14 @@ import time import uuid from typing import Optional +import re -from open_webui.internal.db import Base, get_db -from open_webui.models.chats import Chats -from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, Text, JSON, Boolean -from open_webui.utils.access_control import get_permissions +from sqlalchemy import BigInteger, Column, Text, JSON, Boolean, func + +from open_webui.internal.db import Base, get_db +from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) @@ -29,6 +29,7 @@ class Folder(Base): name = Column(Text) items = Column(JSON, nullable=True) meta = Column(JSON, nullable=True) + data = Column(JSON, nullable=True) is_expanded = Column(Boolean, default=False) created_at = Column(BigInteger) updated_at = Column(BigInteger) @@ -41,6 +42,7 @@ class FolderModel(BaseModel): name: str items: Optional[dict] = None meta: Optional[dict] = None + data: Optional[dict] = None is_expanded: bool = False created_at: int updated_at: int @@ -48,6 +50,20 @@ class FolderModel(BaseModel): model_config = ConfigDict(from_attributes=True) +class FolderMetadataResponse(BaseModel): + icon: Optional[str] = None + + +class FolderNameIdResponse(BaseModel): + id: str + name: str + meta: Optional[FolderMetadataResponse] = None + parent_id: Optional[str] = None + is_expanded: bool = False + created_at: int + updated_at: int + + #################### # Forms #################### @@ -55,12 +71,21 @@ class FolderModel(BaseModel): class FolderForm(BaseModel): name: str + data: Optional[dict] = None + meta: Optional[dict] = None + model_config = ConfigDict(extra="allow") + + +class FolderUpdateForm(BaseModel): + name: Optional[str] = None + data: Optional[dict] = None + meta: 
Optional[dict] = None model_config = ConfigDict(extra="allow") class FolderTable: def insert_new_folder( - self, user_id: str, name: str, parent_id: Optional[str] = None + self, user_id: str, form_data: FolderForm, parent_id: Optional[str] = None ) -> Optional[FolderModel]: with get_db() as db: id = str(uuid.uuid4()) @@ -68,7 +93,7 @@ def insert_new_folder( **{ "id": id, "user_id": user_id, - "name": name, + **(form_data.model_dump(exclude_unset=True) or {}), "parent_id": parent_id, "created_at": int(time.time()), "updated_at": int(time.time()), @@ -103,7 +128,7 @@ def get_folder_by_id_and_user_id( def get_children_folders_by_id_and_user_id( self, id: str, user_id: str - ) -> Optional[FolderModel]: + ) -> Optional[list[FolderModel]]: try: with get_db() as db: folders = [] @@ -187,8 +212,8 @@ def update_folder_parent_id_by_id_and_user_id( log.error(f"update_folder: {e}") return - def update_folder_name_by_id_and_user_id( - self, id: str, user_id: str, name: str + def update_folder_by_id_and_user_id( + self, id: str, user_id: str, form_data: FolderUpdateForm ) -> Optional[FolderModel]: try: with get_db() as db: @@ -197,18 +222,35 @@ def update_folder_name_by_id_and_user_id( if not folder: return None + form_data = form_data.model_dump(exclude_unset=True) + existing_folder = ( db.query(Folder) - .filter_by(name=name, parent_id=folder.parent_id, user_id=user_id) + .filter_by( + name=form_data.get("name"), + parent_id=folder.parent_id, + user_id=user_id, + ) .first() ) - if existing_folder: + if existing_folder and existing_folder.id != id: return None - folder.name = name - folder.updated_at = int(time.time()) + folder.name = form_data.get("name", folder.name) + if "data" in form_data: + folder.data = { + **(folder.data or {}), + **form_data["data"], + } + if "meta" in form_data: + folder.meta = { + **(folder.meta or {}), + **form_data["meta"], + } + + folder.updated_at = int(time.time()) db.commit() return FolderModel.model_validate(folder) @@ -236,18 +278,15 @@ def 
update_folder_is_expanded_by_id_and_user_id( log.error(f"update_folder: {e}") return - def delete_folder_by_id_and_user_id( - self, id: str, user_id: str, delete_chats=True - ) -> bool: + def delete_folder_by_id_and_user_id(self, id: str, user_id: str) -> list[str]: try: + folder_ids = [] with get_db() as db: folder = db.query(Folder).filter_by(id=id, user_id=user_id).first() if not folder: - return False + return folder_ids - if delete_chats: - # Delete all chats in the folder - Chats.delete_chats_by_user_id_and_folder_id(user_id, folder.id) + folder_ids.append(folder.id) # Delete all children folders def delete_children(folder): @@ -255,12 +294,9 @@ def delete_children(folder): folder.id, user_id ) for folder_child in folder_children: - if delete_chats: - Chats.delete_chats_by_user_id_and_folder_id( - user_id, folder_child.id - ) delete_children(folder_child) + folder_ids.append(folder_child.id) folder = db.query(Folder).filter_by(id=folder_child.id).first() db.delete(folder) @@ -269,10 +305,62 @@ def delete_children(folder): delete_children(folder) db.delete(folder) db.commit() - return True + return folder_ids except Exception as e: log.error(f"delete_folder: {e}") - return False + return [] + + def normalize_folder_name(self, name: str) -> str: + # Replace _ and space with a single space, lower case, collapse multiple spaces + name = re.sub(r"[\s_]+", " ", name) + return name.strip().lower() + + def search_folders_by_names( + self, user_id: str, queries: list[str] + ) -> list[FolderModel]: + """ + Search for folders for a user where the name matches any of the queries, treating _ and space as equivalent, case-insensitive. 
+ """ + normalized_queries = [self.normalize_folder_name(q) for q in queries] + if not normalized_queries: + return [] + + results = {} + with get_db() as db: + folders = db.query(Folder).filter_by(user_id=user_id).all() + for folder in folders: + if self.normalize_folder_name(folder.name) in normalized_queries: + results[folder.id] = FolderModel.model_validate(folder) + + # get children folders + children = self.get_children_folders_by_id_and_user_id( + folder.id, user_id + ) + for child in children: + results[child.id] = child + + # Return the results as a list + if not results: + return [] + else: + results = list(results.values()) + return results + + def search_folders_by_name_contains( + self, user_id: str, query: str + ) -> list[FolderModel]: + """ + Partial match: normalized name contains (as substring) the normalized query. + """ + normalized_query = self.normalize_folder_name(query) + results = [] + with get_db() as db: + folders = db.query(Folder).filter_by(user_id=user_id).all() + for folder in folders: + norm_name = self.normalize_folder_name(folder.name) + if normalized_query in norm_name: + results.append(FolderModel.model_validate(folder)) + return results Folders = FolderTable() diff --git a/backend/open_webui/models/functions.py b/backend/open_webui/models/functions.py index 8cbfc5de7d2..2020a296335 100644 --- a/backend/open_webui/models/functions.py +++ b/backend/open_webui/models/functions.py @@ -3,10 +3,10 @@ from typing import Optional from open_webui.internal.db import Base, JSONField, get_db -from open_webui.models.users import Users +from open_webui.models.users import Users, UserModel from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Boolean, Column, String, Text +from sqlalchemy import BigInteger, Boolean, Column, String, Text, Index log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -31,10 +31,13 @@ class Function(Base): updated_at = 
Column(BigInteger) created_at = Column(BigInteger) + __table_args__ = (Index("is_global_idx", "is_global"),) + class FunctionMeta(BaseModel): description: Optional[str] = None manifest: Optional[dict] = {} + model_config = ConfigDict(extra="allow") class FunctionModel(BaseModel): @@ -52,11 +55,31 @@ class FunctionModel(BaseModel): model_config = ConfigDict(from_attributes=True) +class FunctionWithValvesModel(BaseModel): + id: str + user_id: str + name: str + type: str + content: str + meta: FunctionMeta + valves: Optional[dict] = None + is_active: bool = False + is_global: bool = False + updated_at: int # timestamp in epoch + created_at: int # timestamp in epoch + + model_config = ConfigDict(from_attributes=True) + + #################### # Forms #################### +class FunctionUserResponse(FunctionModel): + user: Optional[UserModel] = None + + class FunctionResponse(BaseModel): id: str user_id: str @@ -108,6 +131,54 @@ def insert_new_function( log.exception(f"Error creating a new function: {e}") return None + def sync_functions( + self, user_id: str, functions: list[FunctionWithValvesModel] + ) -> list[FunctionWithValvesModel]: + # Synchronize functions for a user by updating existing ones, inserting new ones, and removing those that are no longer present. 
+ try: + with get_db() as db: + # Get existing functions + existing_functions = db.query(Function).all() + existing_ids = {func.id for func in existing_functions} + + # Prepare a set of new function IDs + new_function_ids = {func.id for func in functions} + + # Update or insert functions + for func in functions: + if func.id in existing_ids: + db.query(Function).filter_by(id=func.id).update( + { + **func.model_dump(), + "user_id": user_id, + "updated_at": int(time.time()), + } + ) + else: + new_func = Function( + **{ + **func.model_dump(), + "user_id": user_id, + "updated_at": int(time.time()), + } + ) + db.add(new_func) + + # Remove functions that are no longer present + for func in existing_functions: + if func.id not in new_function_ids: + db.delete(func) + + db.commit() + + return [ + FunctionModel.model_validate(func) + for func in db.query(Function).all() + ] + except Exception as e: + log.exception(f"Error syncing functions for user {user_id}: {e}") + return [] + def get_function_by_id(self, id: str) -> Optional[FunctionModel]: try: with get_db() as db: @@ -116,19 +187,48 @@ def get_function_by_id(self, id: str) -> Optional[FunctionModel]: except Exception: return None - def get_functions(self, active_only=False) -> list[FunctionModel]: + def get_functions( + self, active_only=False, include_valves=False + ) -> list[FunctionModel | FunctionWithValvesModel]: with get_db() as db: if active_only: + functions = db.query(Function).filter_by(is_active=True).all() + + else: + functions = db.query(Function).all() + + if include_valves: return [ - FunctionModel.model_validate(function) - for function in db.query(Function).filter_by(is_active=True).all() + FunctionWithValvesModel.model_validate(function) + for function in functions ] else: return [ - FunctionModel.model_validate(function) - for function in db.query(Function).all() + FunctionModel.model_validate(function) for function in functions ] + def get_function_list(self) -> list[FunctionUserResponse]: + with 
get_db() as db: + functions = db.query(Function).order_by(Function.updated_at.desc()).all() + user_ids = list(set(func.user_id for func in functions)) + + users = Users.get_users_by_user_ids(user_ids) if user_ids else [] + users_dict = {user.id: user for user in users} + + return [ + FunctionUserResponse.model_validate( + { + **FunctionModel.model_validate(func).model_dump(), + "user": ( + users_dict.get(func.user_id).model_dump() + if func.user_id in users_dict + else None + ), + } + ) + for func in functions + ] + def get_functions_by_type( self, type: str, active_only=False ) -> list[FunctionModel]: @@ -187,6 +287,29 @@ def update_function_valves_by_id( except Exception: return None + def update_function_metadata_by_id( + self, id: str, metadata: dict + ) -> Optional[FunctionModel]: + with get_db() as db: + try: + function = db.get(Function, id) + + if function: + if function.meta: + function.meta = {**function.meta, **metadata} + else: + function.meta = metadata + + function.updated_at = int(time.time()) + db.commit() + db.refresh(function) + return self.get_function_by_id(id) + else: + return None + except Exception as e: + log.exception(f"Error updating function metadata by id {id}: {e}") + return None + def get_user_valves_by_id_and_user_id( self, id: str, user_id: str ) -> Optional[dict]: @@ -202,9 +325,7 @@ def get_user_valves_by_id_and_user_id( return user_settings["functions"]["valves"].get(id, {}) except Exception as e: - log.exception( - f"Error getting user values by id {id} and user id {user_id}: {e}" - ) + log.exception(f"Error getting user values by id {id} and user id {user_id}") return None def update_user_valves_by_id_and_user_id( diff --git a/backend/open_webui/models/groups.py b/backend/open_webui/models/groups.py index 763340fbcb6..a09b2b73f96 100644 --- a/backend/open_webui/models/groups.py +++ b/backend/open_webui/models/groups.py @@ -83,10 +83,14 @@ class GroupForm(BaseModel): permissions: Optional[dict] = None -class 
GroupUpdateForm(GroupForm): +class UserIdsForm(BaseModel): user_ids: Optional[list[str]] = None +class GroupUpdateForm(GroupForm, UserIdsForm): + pass + + class GroupTable: def insert_new_group( self, user_id: str, form_data: GroupForm @@ -207,5 +211,131 @@ def remove_user_from_all_groups(self, user_id: str) -> bool: except Exception: return False + def create_groups_by_group_names( + self, user_id: str, group_names: list[str] + ) -> list[GroupModel]: + + # check for existing groups + existing_groups = self.get_groups() + existing_group_names = {group.name for group in existing_groups} + + new_groups = [] + + with get_db() as db: + for group_name in group_names: + if group_name not in existing_group_names: + new_group = GroupModel( + id=str(uuid.uuid4()), + user_id=user_id, + name=group_name, + description="", + created_at=int(time.time()), + updated_at=int(time.time()), + ) + try: + result = Group(**new_group.model_dump()) + db.add(result) + db.commit() + db.refresh(result) + new_groups.append(GroupModel.model_validate(result)) + except Exception as e: + log.exception(e) + continue + return new_groups + + def sync_groups_by_group_names(self, user_id: str, group_names: list[str]) -> bool: + with get_db() as db: + try: + groups = db.query(Group).filter(Group.name.in_(group_names)).all() + group_ids = [group.id for group in groups] + + # Remove user from groups not in the new list + existing_groups = self.get_groups_by_member_id(user_id) + + for group in existing_groups: + if group.id not in group_ids: + group.user_ids.remove(user_id) + db.query(Group).filter_by(id=group.id).update( + { + "user_ids": group.user_ids, + "updated_at": int(time.time()), + } + ) + + # Add user to new groups + for group in groups: + if user_id not in group.user_ids: + group.user_ids.append(user_id) + db.query(Group).filter_by(id=group.id).update( + { + "user_ids": group.user_ids, + "updated_at": int(time.time()), + } + ) + + db.commit() + return True + except Exception as e: + 
log.exception(e) + return False + + def add_users_to_group( + self, id: str, user_ids: Optional[list[str]] = None + ) -> Optional[GroupModel]: + try: + with get_db() as db: + group = db.query(Group).filter_by(id=id).first() + if not group: + return None + + group_user_ids = group.user_ids + if not group_user_ids or not isinstance(group_user_ids, list): + group_user_ids = [] + + group_user_ids = list(set(group_user_ids)) # Deduplicate + + for user_id in user_ids: + if user_id not in group_user_ids: + group_user_ids.append(user_id) + + group.user_ids = group_user_ids + group.updated_at = int(time.time()) + db.commit() + db.refresh(group) + return GroupModel.model_validate(group) + except Exception as e: + log.exception(e) + return None + + def remove_users_from_group( + self, id: str, user_ids: Optional[list[str]] = None + ) -> Optional[GroupModel]: + try: + with get_db() as db: + group = db.query(Group).filter_by(id=id).first() + if not group: + return None + + group_user_ids = group.user_ids + + if not group_user_ids or not isinstance(group_user_ids, list): + return GroupModel.model_validate(group) + + group_user_ids = list(set(group_user_ids)) # Deduplicate + + for user_id in user_ids: + if user_id in group_user_ids: + group_user_ids.remove(user_id) + + group.user_ids = group_user_ids + group.updated_at = int(time.time()) + + db.commit() + db.refresh(group) + return GroupModel.model_validate(group) + except Exception as e: + log.exception(e) + return None + Groups = GroupTable() diff --git a/backend/open_webui/models/knowledge.py b/backend/open_webui/models/knowledge.py index bed3d5542e7..cfef77e2375 100644 --- a/backend/open_webui/models/knowledge.py +++ b/backend/open_webui/models/knowledge.py @@ -8,6 +8,7 @@ from open_webui.env import SRC_LOG_LEVELS from open_webui.models.files import FileMetadataResponse +from open_webui.models.groups import Groups from open_webui.models.users import Users, UserResponse @@ -128,11 +129,18 @@ def insert_new_knowledge( def 
get_knowledge_bases(self) -> list[KnowledgeUserModel]: with get_db() as db: - knowledge_bases = [] - for knowledge in ( + all_knowledge = ( db.query(Knowledge).order_by(Knowledge.updated_at.desc()).all() - ): - user = Users.get_user_by_id(knowledge.user_id) + ) + + user_ids = list(set(knowledge.user_id for knowledge in all_knowledge)) + + users = Users.get_users_by_user_ids(user_ids) if user_ids else [] + users_dict = {user.id: user for user in users} + + knowledge_bases = [] + for knowledge in all_knowledge: + user = users_dict.get(knowledge.user_id) knowledge_bases.append( KnowledgeUserModel.model_validate( { @@ -143,15 +151,27 @@ def get_knowledge_bases(self) -> list[KnowledgeUserModel]: ) return knowledge_bases + def check_access_by_user_id(self, id, user_id, permission="write") -> bool: + knowledge = self.get_knowledge_by_id(id) + if not knowledge: + return False + if knowledge.user_id == user_id: + return True + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)} + return has_access(user_id, permission, knowledge.access_control, user_group_ids) + def get_knowledge_bases_by_user_id( self, user_id: str, permission: str = "write" ) -> list[KnowledgeUserModel]: knowledge_bases = self.get_knowledge_bases() + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)} return [ knowledge_base for knowledge_base in knowledge_bases if knowledge_base.user_id == user_id - or has_access(user_id, permission, knowledge_base.access_control) + or has_access( + user_id, permission, knowledge_base.access_control, user_group_ids + ) ] def get_knowledge_by_id(self, id: str) -> Optional[KnowledgeModel]: diff --git a/backend/open_webui/models/memories.py b/backend/open_webui/models/memories.py index 8b10a77cf99..253371c6800 100644 --- a/backend/open_webui/models/memories.py +++ b/backend/open_webui/models/memories.py @@ -71,9 +71,13 @@ def update_memory_by_id_and_user_id( ) -> Optional[MemoryModel]: with get_db() as db: try: - 
db.query(Memory).filter_by(id=id, user_id=user_id).update( - {"content": content, "updated_at": int(time.time())} - ) + memory = db.get(Memory, id) + if not memory or memory.user_id != user_id: + return None + + memory.content = content + memory.updated_at = int(time.time()) + db.commit() return self.get_memory_by_id(id) except Exception: @@ -127,7 +131,12 @@ def delete_memories_by_user_id(self, user_id: str) -> bool: def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool: with get_db() as db: try: - db.query(Memory).filter_by(id=id, user_id=user_id).delete() + memory = db.get(Memory, id) + if not memory or memory.user_id != user_id: + return None + + # Delete the memory + db.delete(memory) db.commit() return True diff --git a/backend/open_webui/models/messages.py b/backend/open_webui/models/messages.py index a27ae525198..8b0027b8e78 100644 --- a/backend/open_webui/models/messages.py +++ b/backend/open_webui/models/messages.py @@ -5,6 +5,7 @@ from open_webui.internal.db import Base, get_db from open_webui.models.tags import TagModel, Tag, Tags +from open_webui.models.users import Users, UserNameResponse from pydantic import BaseModel, ConfigDict @@ -43,6 +44,7 @@ class Message(Base): user_id = Column(Text) channel_id = Column(Text, nullable=True) + reply_to_id = Column(Text, nullable=True) parent_id = Column(Text, nullable=True) content = Column(Text) @@ -60,6 +62,7 @@ class MessageModel(BaseModel): user_id: str channel_id: Optional[str] = None + reply_to_id: Optional[str] = None parent_id: Optional[str] = None content: str @@ -77,6 +80,7 @@ class MessageModel(BaseModel): class MessageForm(BaseModel): content: str + reply_to_id: Optional[str] = None parent_id: Optional[str] = None data: Optional[dict] = None meta: Optional[dict] = None @@ -88,7 +92,15 @@ class Reactions(BaseModel): count: int -class MessageResponse(MessageModel): +class MessageUserResponse(MessageModel): + user: Optional[UserNameResponse] = None + + +class 
MessageReplyToResponse(MessageUserResponse): + reply_to_message: Optional[MessageUserResponse] = None + + +class MessageResponse(MessageReplyToResponse): latest_reply_at: Optional[int] reply_count: int reactions: list[Reactions] @@ -107,6 +119,7 @@ def insert_new_message( "id": id, "user_id": user_id, "channel_id": channel_id, + "reply_to_id": form_data.reply_to_id, "parent_id": form_data.parent_id, "content": form_data.content, "data": form_data.data, @@ -128,19 +141,32 @@ def get_message_by_id(self, id: str) -> Optional[MessageResponse]: if not message: return None + reply_to_message = ( + self.get_message_by_id(message.reply_to_id) + if message.reply_to_id + else None + ) + reactions = self.get_reactions_by_message_id(id) - replies = self.get_replies_by_message_id(id) + thread_replies = self.get_thread_replies_by_message_id(id) - return MessageResponse( - **{ + user = Users.get_user_by_id(message.user_id) + return MessageResponse.model_validate( + { **MessageModel.model_validate(message).model_dump(), - "latest_reply_at": replies[0].created_at if replies else None, - "reply_count": len(replies), + "user": user.model_dump() if user else None, + "reply_to_message": ( + reply_to_message.model_dump() if reply_to_message else None + ), + "latest_reply_at": ( + thread_replies[0].created_at if thread_replies else None + ), + "reply_count": len(thread_replies), "reactions": reactions, } ) - def get_replies_by_message_id(self, id: str) -> list[MessageModel]: + def get_thread_replies_by_message_id(self, id: str) -> list[MessageReplyToResponse]: with get_db() as db: all_messages = ( db.query(Message) @@ -148,7 +174,27 @@ def get_replies_by_message_id(self, id: str) -> list[MessageModel]: .order_by(Message.created_at.desc()) .all() ) - return [MessageModel.model_validate(message) for message in all_messages] + + messages = [] + for message in all_messages: + reply_to_message = ( + self.get_message_by_id(message.reply_to_id) + if message.reply_to_id + else None + ) + 
messages.append( + MessageReplyToResponse.model_validate( + { + **MessageModel.model_validate(message).model_dump(), + "reply_to_message": ( + reply_to_message.model_dump() + if reply_to_message + else None + ), + } + ) + ) + return messages def get_reply_user_ids_by_message_id(self, id: str) -> list[str]: with get_db() as db: @@ -159,7 +205,7 @@ def get_reply_user_ids_by_message_id(self, id: str) -> list[str]: def get_messages_by_channel_id( self, channel_id: str, skip: int = 0, limit: int = 50 - ) -> list[MessageModel]: + ) -> list[MessageReplyToResponse]: with get_db() as db: all_messages = ( db.query(Message) @@ -169,11 +215,31 @@ def get_messages_by_channel_id( .limit(limit) .all() ) - return [MessageModel.model_validate(message) for message in all_messages] + + messages = [] + for message in all_messages: + reply_to_message = ( + self.get_message_by_id(message.reply_to_id) + if message.reply_to_id + else None + ) + messages.append( + MessageReplyToResponse.model_validate( + { + **MessageModel.model_validate(message).model_dump(), + "reply_to_message": ( + reply_to_message.model_dump() + if reply_to_message + else None + ), + } + ) + ) + return messages def get_messages_by_parent_id( self, channel_id: str, parent_id: str, skip: int = 0, limit: int = 50 - ) -> list[MessageModel]: + ) -> list[MessageReplyToResponse]: with get_db() as db: message = db.get(Message, parent_id) @@ -193,7 +259,26 @@ def get_messages_by_parent_id( if len(all_messages) < limit: all_messages.append(message) - return [MessageModel.model_validate(message) for message in all_messages] + messages = [] + for message in all_messages: + reply_to_message = ( + self.get_message_by_id(message.reply_to_id) + if message.reply_to_id + else None + ) + messages.append( + MessageReplyToResponse.model_validate( + { + **MessageModel.model_validate(message).model_dump(), + "reply_to_message": ( + reply_to_message.model_dump() + if reply_to_message + else None + ), + } + ) + ) + return messages def 
update_message_by_id( self, id: str, form_data: MessageForm @@ -201,8 +286,14 @@ def update_message_by_id( with get_db() as db: message = db.get(Message, id) message.content = form_data.content - message.data = form_data.data - message.meta = form_data.meta + message.data = { + **(message.data if message.data else {}), + **(form_data.data if form_data.data else {}), + } + message.meta = { + **(message.meta if message.meta else {}), + **(form_data.meta if form_data.meta else {}), + } message.updated_at = int(time.time_ns()) db.commit() db.refresh(message) diff --git a/backend/open_webui/models/models.py b/backend/open_webui/models/models.py index 7df8d8656b6..93dafe0f052 100755 --- a/backend/open_webui/models/models.py +++ b/backend/open_webui/models/models.py @@ -5,6 +5,7 @@ from open_webui.internal.db import Base, JSONField, get_db from open_webui.env import SRC_LOG_LEVELS +from open_webui.models.groups import Groups from open_webui.models.users import Users, UserResponse @@ -175,9 +176,16 @@ def get_all_models(self) -> list[ModelModel]: def get_models(self) -> list[ModelUserResponse]: with get_db() as db: + all_models = db.query(Model).filter(Model.base_model_id != None).all() + + user_ids = list(set(model.user_id for model in all_models)) + + users = Users.get_users_by_user_ids(user_ids) if user_ids else [] + users_dict = {user.id: user for user in users} + models = [] - for model in db.query(Model).filter(Model.base_model_id != None).all(): - user = Users.get_user_by_id(model.user_id) + for model in all_models: + user = users_dict.get(model.user_id) models.append( ModelUserResponse.model_validate( { @@ -199,11 +207,12 @@ def get_models_by_user_id( self, user_id: str, permission: str = "write" ) -> list[ModelUserResponse]: models = self.get_models() + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)} return [ model for model in models if model.user_id == user_id - or has_access(user_id, permission, model.access_control) + or 
has_access(user_id, permission, model.access_control, user_group_ids) ] def get_model_by_id(self, id: str) -> Optional[ModelModel]: @@ -269,5 +278,49 @@ def delete_all_models(self) -> bool: except Exception: return False + def sync_models(self, user_id: str, models: list[ModelModel]) -> list[ModelModel]: + try: + with get_db() as db: + # Get existing models + existing_models = db.query(Model).all() + existing_ids = {model.id for model in existing_models} + + # Prepare a set of new model IDs + new_model_ids = {model.id for model in models} + + # Update or insert models + for model in models: + if model.id in existing_ids: + db.query(Model).filter_by(id=model.id).update( + { + **model.model_dump(), + "user_id": user_id, + "updated_at": int(time.time()), + } + ) + else: + new_model = Model( + **{ + **model.model_dump(), + "user_id": user_id, + "updated_at": int(time.time()), + } + ) + db.add(new_model) + + # Remove models that are no longer present + for model in existing_models: + if model.id not in new_model_ids: + db.delete(model) + + db.commit() + + return [ + ModelModel.model_validate(model) for model in db.query(Model).all() + ] + except Exception as e: + log.exception(f"Error syncing models for user {user_id}: {e}") + return [] + Models = ModelsTable() diff --git a/backend/open_webui/models/notes.py b/backend/open_webui/models/notes.py index 114ccdc574d..f1b11f071e1 100644 --- a/backend/open_webui/models/notes.py +++ b/backend/open_webui/models/notes.py @@ -2,8 +2,10 @@ import time import uuid from typing import Optional +from functools import lru_cache from open_webui.internal.db import Base, get_db +from open_webui.models.groups import Groups from open_webui.utils.access_control import has_access from open_webui.models.users import Users, UserResponse @@ -62,6 +64,13 @@ class NoteForm(BaseModel): access_control: Optional[dict] = None +class NoteUpdateForm(BaseModel): + title: Optional[str] = None + data: Optional[dict] = None + meta: Optional[dict] = None + 
access_control: Optional[dict] = None + + class NoteUserResponse(NoteModel): user: Optional[UserResponse] = None @@ -89,37 +98,111 @@ def insert_new_note( db.commit() return note - def get_notes(self) -> list[NoteModel]: + def get_notes( + self, skip: Optional[int] = None, limit: Optional[int] = None + ) -> list[NoteModel]: with get_db() as db: - notes = db.query(Note).order_by(Note.updated_at.desc()).all() + query = db.query(Note).order_by(Note.updated_at.desc()) + if skip is not None: + query = query.offset(skip) + if limit is not None: + query = query.limit(limit) + notes = query.all() return [NoteModel.model_validate(note) for note in notes] def get_notes_by_user_id( - self, user_id: str, permission: str = "write" + self, + user_id: str, + skip: Optional[int] = None, + limit: Optional[int] = None, ) -> list[NoteModel]: - notes = self.get_notes() - return [ - note - for note in notes - if note.user_id == user_id - or has_access(user_id, permission, note.access_control) - ] + with get_db() as db: + query = db.query(Note).filter(Note.user_id == user_id) + query = query.order_by(Note.updated_at.desc()) + + if skip is not None: + query = query.offset(skip) + if limit is not None: + query = query.limit(limit) + + notes = query.all() + return [NoteModel.model_validate(note) for note in notes] + + def get_notes_by_permission( + self, + user_id: str, + permission: str = "write", + skip: Optional[int] = None, + limit: Optional[int] = None, + ) -> list[NoteModel]: + with get_db() as db: + user_groups = Groups.get_groups_by_member_id(user_id) + user_group_ids = {group.id for group in user_groups} + + # Order newest-first. We stream to keep memory usage low. 
+ query = ( + db.query(Note) + .order_by(Note.updated_at.desc()) + .execution_options(stream_results=True) + .yield_per(256) + ) + + results: list[NoteModel] = [] + n_skipped = 0 + + for note in query: + # Fast-pass #1: owner + if note.user_id == user_id: + permitted = True + # Fast-pass #2: public/open + elif note.access_control is None: + # Technically this should mean public access for both read and write, but we'll only do read for now + # We might want to change this behavior later + permitted = permission == "read" + else: + permitted = has_access( + user_id, permission, note.access_control, user_group_ids + ) + + if not permitted: + continue + + # Apply skip AFTER permission filtering so it counts only accessible notes + if skip and n_skipped < skip: + n_skipped += 1 + continue + + results.append(NoteModel.model_validate(note)) + if limit is not None and len(results) >= limit: + break + + return results def get_note_by_id(self, id: str) -> Optional[NoteModel]: with get_db() as db: note = db.query(Note).filter(Note.id == id).first() return NoteModel.model_validate(note) if note else None - def update_note_by_id(self, id: str, form_data: NoteForm) -> Optional[NoteModel]: + def update_note_by_id( + self, id: str, form_data: NoteUpdateForm + ) -> Optional[NoteModel]: with get_db() as db: note = db.query(Note).filter(Note.id == id).first() if not note: return None - note.title = form_data.title - note.data = form_data.data - note.meta = form_data.meta - note.access_control = form_data.access_control + form_data = form_data.model_dump(exclude_unset=True) + + if "title" in form_data: + note.title = form_data["title"] + if "data" in form_data: + note.data = {**note.data, **form_data["data"]} + if "meta" in form_data: + note.meta = {**note.meta, **form_data["meta"]} + + if "access_control" in form_data: + note.access_control = form_data["access_control"] + note.updated_at = int(time.time_ns()) db.commit() diff --git a/backend/open_webui/models/oauth_sessions.py 
b/backend/open_webui/models/oauth_sessions.py new file mode 100644 index 00000000000..81ce2203842 --- /dev/null +++ b/backend/open_webui/models/oauth_sessions.py @@ -0,0 +1,266 @@ +import time +import logging +import uuid +from typing import Optional, List +import base64 +import hashlib +import json + +from cryptography.fernet import Fernet + +from open_webui.internal.db import Base, get_db +from open_webui.env import SRC_LOG_LEVELS, OAUTH_SESSION_TOKEN_ENCRYPTION_KEY + +from pydantic import BaseModel, ConfigDict +from sqlalchemy import BigInteger, Column, String, Text, Index + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MODELS"]) + +#################### +# DB MODEL +#################### + + +class OAuthSession(Base): + __tablename__ = "oauth_session" + + id = Column(Text, primary_key=True) + user_id = Column(Text, nullable=False) + provider = Column(Text, nullable=False) + token = Column( + Text, nullable=False + ) # JSON with access_token, id_token, refresh_token + expires_at = Column(BigInteger, nullable=False) + created_at = Column(BigInteger, nullable=False) + updated_at = Column(BigInteger, nullable=False) + + # Add indexes for better performance + __table_args__ = ( + Index("idx_oauth_session_user_id", "user_id"), + Index("idx_oauth_session_expires_at", "expires_at"), + Index("idx_oauth_session_user_provider", "user_id", "provider"), + ) + + +class OAuthSessionModel(BaseModel): + id: str + user_id: str + provider: str + token: dict + expires_at: int # timestamp in epoch + created_at: int # timestamp in epoch + updated_at: int # timestamp in epoch + + model_config = ConfigDict(from_attributes=True) + + +#################### +# Forms +#################### + + +class OAuthSessionResponse(BaseModel): + id: str + user_id: str + provider: str + expires_at: int + + +class OAuthSessionTable: + def __init__(self): + self.encryption_key = OAUTH_SESSION_TOKEN_ENCRYPTION_KEY + if not self.encryption_key: + raise 
Exception("OAUTH_SESSION_TOKEN_ENCRYPTION_KEY is not set") + + # check if encryption key is in the right format for Fernet (32 url-safe base64-encoded bytes) + if len(self.encryption_key) != 44: + key_bytes = hashlib.sha256(self.encryption_key.encode()).digest() + self.encryption_key = base64.urlsafe_b64encode(key_bytes) + else: + self.encryption_key = self.encryption_key.encode() + + try: + self.fernet = Fernet(self.encryption_key) + except Exception as e: + log.error(f"Error initializing Fernet with provided key: {e}") + raise + + def _encrypt_token(self, token) -> str: + """Encrypt OAuth tokens for storage""" + try: + token_json = json.dumps(token) + encrypted = self.fernet.encrypt(token_json.encode()).decode() + return encrypted + except Exception as e: + log.error(f"Error encrypting tokens: {e}") + raise + + def _decrypt_token(self, token: str): + """Decrypt OAuth tokens from storage""" + try: + decrypted = self.fernet.decrypt(token.encode()).decode() + return json.loads(decrypted) + except Exception as e: + log.error(f"Error decrypting tokens: {e}") + raise + + def create_session( + self, + user_id: str, + provider: str, + token: dict, + ) -> Optional[OAuthSessionModel]: + """Create a new OAuth session""" + try: + with get_db() as db: + current_time = int(time.time()) + id = str(uuid.uuid4()) + + result = OAuthSession( + **{ + "id": id, + "user_id": user_id, + "provider": provider, + "token": self._encrypt_token(token), + "expires_at": token.get("expires_at"), + "created_at": current_time, + "updated_at": current_time, + } + ) + + db.add(result) + db.commit() + db.refresh(result) + + if result: + result.token = token # Return decrypted token + return OAuthSessionModel.model_validate(result) + else: + return None + except Exception as e: + log.error(f"Error creating OAuth session: {e}") + return None + + def get_session_by_id(self, session_id: str) -> Optional[OAuthSessionModel]: + """Get OAuth session by ID""" + try: + with get_db() as db: + session = 
db.query(OAuthSession).filter_by(id=session_id).first() + if session: + session.token = self._decrypt_token(session.token) + return OAuthSessionModel.model_validate(session) + + return None + except Exception as e: + log.error(f"Error getting OAuth session by ID: {e}") + return None + + def get_session_by_id_and_user_id( + self, session_id: str, user_id: str + ) -> Optional[OAuthSessionModel]: + """Get OAuth session by ID and user ID""" + try: + with get_db() as db: + session = ( + db.query(OAuthSession) + .filter_by(id=session_id, user_id=user_id) + .first() + ) + if session: + session.token = self._decrypt_token(session.token) + return OAuthSessionModel.model_validate(session) + + return None + except Exception as e: + log.error(f"Error getting OAuth session by ID: {e}") + return None + + def get_session_by_provider_and_user_id( + self, provider: str, user_id: str + ) -> Optional[OAuthSessionModel]: + """Get OAuth session by provider and user ID""" + try: + with get_db() as db: + session = ( + db.query(OAuthSession) + .filter_by(provider=provider, user_id=user_id) + .first() + ) + if session: + session.token = self._decrypt_token(session.token) + return OAuthSessionModel.model_validate(session) + + return None + except Exception as e: + log.error(f"Error getting OAuth session by provider and user ID: {e}") + return None + + def get_sessions_by_user_id(self, user_id: str) -> List[OAuthSessionModel]: + """Get all OAuth sessions for a user""" + try: + with get_db() as db: + sessions = db.query(OAuthSession).filter_by(user_id=user_id).all() + + results = [] + for session in sessions: + session.token = self._decrypt_token(session.token) + results.append(OAuthSessionModel.model_validate(session)) + + return results + + except Exception as e: + log.error(f"Error getting OAuth sessions by user ID: {e}") + return [] + + def update_session_by_id( + self, session_id: str, token: dict + ) -> Optional[OAuthSessionModel]: + """Update OAuth session tokens""" + try: + with 
get_db() as db: + current_time = int(time.time()) + + db.query(OAuthSession).filter_by(id=session_id).update( + { + "token": self._encrypt_token(token), + "expires_at": token.get("expires_at"), + "updated_at": current_time, + } + ) + db.commit() + session = db.query(OAuthSession).filter_by(id=session_id).first() + + if session: + session.token = self._decrypt_token(session.token) + return OAuthSessionModel.model_validate(session) + + return None + except Exception as e: + log.error(f"Error updating OAuth session tokens: {e}") + return None + + def delete_session_by_id(self, session_id: str) -> bool: + """Delete an OAuth session""" + try: + with get_db() as db: + result = db.query(OAuthSession).filter_by(id=session_id).delete() + db.commit() + return result > 0 + except Exception as e: + log.error(f"Error deleting OAuth session: {e}") + return False + + def delete_sessions_by_user_id(self, user_id: str) -> bool: + """Delete all OAuth sessions for a user""" + try: + with get_db() as db: + result = db.query(OAuthSession).filter_by(user_id=user_id).delete() + db.commit() + return True + except Exception as e: + log.error(f"Error deleting OAuth sessions by user ID: {e}") + return False + + +OAuthSessions = OAuthSessionTable() diff --git a/backend/open_webui/models/prompts.py b/backend/open_webui/models/prompts.py index 8ef4cd2bec6..7502f34ccd7 100644 --- a/backend/open_webui/models/prompts.py +++ b/backend/open_webui/models/prompts.py @@ -2,6 +2,7 @@ from typing import Optional from open_webui.internal.db import Base, get_db +from open_webui.models.groups import Groups from open_webui.models.users import Users, UserResponse from pydantic import BaseModel, ConfigDict @@ -103,10 +104,16 @@ def get_prompt_by_command(self, command: str) -> Optional[PromptModel]: def get_prompts(self) -> list[PromptUserResponse]: with get_db() as db: - prompts = [] + all_prompts = db.query(Prompt).order_by(Prompt.timestamp.desc()).all() + + user_ids = list(set(prompt.user_id for prompt in 
all_prompts)) - for prompt in db.query(Prompt).order_by(Prompt.timestamp.desc()).all(): - user = Users.get_user_by_id(prompt.user_id) + users = Users.get_users_by_user_ids(user_ids) if user_ids else [] + users_dict = {user.id: user for user in users} + + prompts = [] + for prompt in all_prompts: + user = users_dict.get(prompt.user_id) prompts.append( PromptUserResponse.model_validate( { @@ -122,12 +129,13 @@ def get_prompts_by_user_id( self, user_id: str, permission: str = "write" ) -> list[PromptUserResponse]: prompts = self.get_prompts() + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)} return [ prompt for prompt in prompts if prompt.user_id == user_id - or has_access(user_id, permission, prompt.access_control) + or has_access(user_id, permission, prompt.access_control, user_group_ids) ] def update_prompt_by_command( diff --git a/backend/open_webui/models/tags.py b/backend/open_webui/models/tags.py index 279dc624d52..e1cbb68a0b3 100644 --- a/backend/open_webui/models/tags.py +++ b/backend/open_webui/models/tags.py @@ -8,7 +8,7 @@ from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, String, JSON, PrimaryKeyConstraint +from sqlalchemy import BigInteger, Column, String, JSON, PrimaryKeyConstraint, Index log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -24,6 +24,11 @@ class Tag(Base): user_id = Column(String) meta = Column(JSON, nullable=True) + __table_args__ = ( + PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"), + Index("user_id_idx", "user_id"), + ) + # Unique constraint ensuring (id, user_id) is unique, not just the `id` column __table_args__ = (PrimaryKeyConstraint("id", "user_id", name="pk_id_user_id"),) diff --git a/backend/open_webui/models/tools.py b/backend/open_webui/models/tools.py index 68a83ea42c8..48f84b3ac4d 100644 --- a/backend/open_webui/models/tools.py +++ b/backend/open_webui/models/tools.py @@ -4,6 
+4,8 @@ from open_webui.internal.db import Base, JSONField, get_db from open_webui.models.users import Users, UserResponse +from open_webui.models.groups import Groups + from open_webui.env import SRC_LOG_LEVELS from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text, JSON @@ -93,6 +95,8 @@ class ToolResponse(BaseModel): class ToolUserResponse(ToolResponse): user: Optional[UserResponse] = None + model_config = ConfigDict(extra="allow") + class ToolForm(BaseModel): id: str @@ -144,9 +148,16 @@ def get_tool_by_id(self, id: str) -> Optional[ToolModel]: def get_tools(self) -> list[ToolUserModel]: with get_db() as db: + all_tools = db.query(Tool).order_by(Tool.updated_at.desc()).all() + + user_ids = list(set(tool.user_id for tool in all_tools)) + + users = Users.get_users_by_user_ids(user_ids) if user_ids else [] + users_dict = {user.id: user for user in users} + tools = [] - for tool in db.query(Tool).order_by(Tool.updated_at.desc()).all(): - user = Users.get_user_by_id(tool.user_id) + for tool in all_tools: + user = users_dict.get(tool.user_id) tools.append( ToolUserModel.model_validate( { @@ -161,12 +172,13 @@ def get_tools_by_user_id( self, user_id: str, permission: str = "write" ) -> list[ToolUserModel]: tools = self.get_tools() + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user_id)} return [ tool for tool in tools if tool.user_id == user_id - or has_access(user_id, permission, tool.access_control) + or has_access(user_id, permission, tool.access_control, user_group_ids) ] def get_tool_valves_by_id(self, id: str) -> Optional[dict]: @@ -175,7 +187,7 @@ def get_tool_valves_by_id(self, id: str) -> Optional[dict]: tool = db.get(Tool, id) return tool.valves if tool.valves else {} except Exception as e: - log.exception(f"Error getting tool valves by id {id}: {e}") + log.exception(f"Error getting tool valves by id {id}") return None def update_tool_valves_by_id(self, id: str, valves: dict) -> 
Optional[ToolValves]: diff --git a/backend/open_webui/models/users.py b/backend/open_webui/models/users.py index 3222aa27a67..05000744dd4 100644 --- a/backend/open_webui/models/users.py +++ b/backend/open_webui/models/users.py @@ -4,14 +4,17 @@ from open_webui.internal.db import Base, JSONField, get_db +from open_webui.env import DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL from open_webui.models.chats import Chats from open_webui.models.groups import Groups +from open_webui.utils.misc import throttle from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, String, Text +from sqlalchemy import BigInteger, Column, String, Text, Date from sqlalchemy import or_ +import datetime #################### # User DB Schema @@ -23,20 +26,28 @@ class User(Base): id = Column(String, primary_key=True) name = Column(String) + email = Column(String) + username = Column(String(50), nullable=True) + role = Column(String) profile_image_url = Column(Text) - last_active_at = Column(BigInteger) - updated_at = Column(BigInteger) - created_at = Column(BigInteger) + bio = Column(Text, nullable=True) + gender = Column(Text, nullable=True) + date_of_birth = Column(Date, nullable=True) - api_key = Column(String, nullable=True, unique=True) - settings = Column(JSONField, nullable=True) info = Column(JSONField, nullable=True) + settings = Column(JSONField, nullable=True) + api_key = Column(String, nullable=True, unique=True) oauth_sub = Column(Text, unique=True) + last_active_at = Column(BigInteger) + + updated_at = Column(BigInteger) + created_at = Column(BigInteger) + class UserSettings(BaseModel): ui: Optional[dict] = {} @@ -47,20 +58,27 @@ class UserSettings(BaseModel): class UserModel(BaseModel): id: str name: str + email: str + username: Optional[str] = None + role: str = "pending" profile_image_url: str - last_active_at: int # timestamp in epoch - updated_at: int # timestamp in epoch - created_at: int # timestamp in epoch + bio: Optional[str] = None + gender: 
Optional[str] = None + date_of_birth: Optional[datetime.date] = None - api_key: Optional[str] = None - settings: Optional[UserSettings] = None info: Optional[dict] = None + settings: Optional[UserSettings] = None + api_key: Optional[str] = None oauth_sub: Optional[str] = None + last_active_at: int # timestamp in epoch + updated_at: int # timestamp in epoch + created_at: int # timestamp in epoch + model_config = ConfigDict(from_attributes=True) @@ -69,11 +87,41 @@ class UserModel(BaseModel): #################### +class UpdateProfileForm(BaseModel): + profile_image_url: str + name: str + bio: Optional[str] = None + gender: Optional[str] = None + date_of_birth: Optional[datetime.date] = None + + class UserListResponse(BaseModel): users: list[UserModel] total: int +class UserInfoResponse(BaseModel): + id: str + name: str + email: str + role: str + + +class UserIdNameResponse(BaseModel): + id: str + name: str + + +class UserInfoListResponse(BaseModel): + users: list[UserInfoResponse] + total: int + + +class UserIdNameListResponse(BaseModel): + users: list[UserIdNameResponse] + total: int + + class UserResponse(BaseModel): id: str name: str @@ -95,6 +143,7 @@ class UserRoleUpdateForm(BaseModel): class UserUpdateForm(BaseModel): + role: str name: str email: str profile_image_url: str @@ -171,7 +220,7 @@ def get_users( filter: Optional[dict] = None, skip: Optional[int] = None, limit: Optional[int] = None, - ) -> UserListResponse: + ) -> dict: with get_db() as db: query = db.query(User) @@ -245,6 +294,10 @@ def get_num_users(self) -> Optional[int]: with get_db() as db: return db.query(User).count() + def has_users(self) -> bool: + with get_db() as db: + return db.query(db.query(User).exists()).scalar() + def get_first_user(self) -> UserModel: try: with get_db() as db: @@ -294,6 +347,7 @@ def update_user_profile_image_url_by_id( except Exception: return None + @throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL) def update_user_last_active_by_id(self, id: str) -> 
Optional[UserModel]: try: with get_db() as db: @@ -329,7 +383,8 @@ def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: user = db.query(User).filter_by(id=id).first() return UserModel.model_validate(user) # return UserModel(**user.dict()) - except Exception: + except Exception as e: + print(e) return None def update_user_settings_by_id(self, id: str, updated: dict) -> Optional[UserModel]: @@ -369,7 +424,7 @@ def delete_user_by_id(self, id: str) -> bool: except Exception: return False - def update_user_api_key_by_id(self, id: str, api_key: str) -> str: + def update_user_api_key_by_id(self, id: str, api_key: str) -> bool: try: with get_db() as db: result = db.query(User).filter_by(id=id).update({"api_key": api_key}) diff --git a/backend/open_webui/retrieval/loaders/datalab_marker.py b/backend/open_webui/retrieval/loaders/datalab_marker.py new file mode 100644 index 00000000000..8d14be0a400 --- /dev/null +++ b/backend/open_webui/retrieval/loaders/datalab_marker.py @@ -0,0 +1,278 @@ +import os +import time +import requests +import logging +import json +from typing import List, Optional +from langchain_core.documents import Document +from fastapi import HTTPException, status + +log = logging.getLogger(__name__) + + +class DatalabMarkerLoader: + def __init__( + self, + file_path: str, + api_key: str, + api_base_url: str, + additional_config: Optional[str] = None, + use_llm: bool = False, + skip_cache: bool = False, + force_ocr: bool = False, + paginate: bool = False, + strip_existing_ocr: bool = False, + disable_image_extraction: bool = False, + format_lines: bool = False, + output_format: str = None, + ): + self.file_path = file_path + self.api_key = api_key + self.api_base_url = api_base_url + self.additional_config = additional_config + self.use_llm = use_llm + self.skip_cache = skip_cache + self.force_ocr = force_ocr + self.paginate = paginate + self.strip_existing_ocr = strip_existing_ocr + self.disable_image_extraction = 
disable_image_extraction + self.format_lines = format_lines + self.output_format = output_format + + def _get_mime_type(self, filename: str) -> str: + ext = filename.rsplit(".", 1)[-1].lower() + mime_map = { + "pdf": "application/pdf", + "xls": "application/vnd.ms-excel", + "xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + "ods": "application/vnd.oasis.opendocument.spreadsheet", + "doc": "application/msword", + "docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "odt": "application/vnd.oasis.opendocument.text", + "ppt": "application/vnd.ms-powerpoint", + "pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation", + "odp": "application/vnd.oasis.opendocument.presentation", + "html": "text/html", + "epub": "application/epub+zip", + "png": "image/png", + "jpeg": "image/jpeg", + "jpg": "image/jpeg", + "webp": "image/webp", + "gif": "image/gif", + "tiff": "image/tiff", + } + return mime_map.get(ext, "application/octet-stream") + + def check_marker_request_status(self, request_id: str) -> dict: + url = f"{self.api_base_url}/{request_id}" + headers = {"X-Api-Key": self.api_key} + try: + response = requests.get(url, headers=headers) + response.raise_for_status() + result = response.json() + log.info(f"Marker API status check for request {request_id}: {result}") + return result + except requests.HTTPError as e: + log.error(f"Error checking Marker request status: {e}") + raise HTTPException( + status.HTTP_502_BAD_GATEWAY, + detail=f"Failed to check Marker request: {e}", + ) + except ValueError as e: + log.error(f"Invalid JSON checking Marker request: {e}") + raise HTTPException( + status.HTTP_502_BAD_GATEWAY, detail=f"Invalid JSON: {e}" + ) + + def load(self) -> List[Document]: + filename = os.path.basename(self.file_path) + mime_type = self._get_mime_type(filename) + headers = {"X-Api-Key": self.api_key} + + form_data = { + "use_llm": str(self.use_llm).lower(), + "skip_cache": 
str(self.skip_cache).lower(), + "force_ocr": str(self.force_ocr).lower(), + "paginate": str(self.paginate).lower(), + "strip_existing_ocr": str(self.strip_existing_ocr).lower(), + "disable_image_extraction": str(self.disable_image_extraction).lower(), + "format_lines": str(self.format_lines).lower(), + "output_format": self.output_format, + } + + if self.additional_config and self.additional_config.strip(): + form_data["additional_config"] = self.additional_config + + log.info( + f"Datalab Marker POST request parameters: {{'filename': '{filename}', 'mime_type': '{mime_type}', **{form_data}}}" + ) + + try: + with open(self.file_path, "rb") as f: + files = {"file": (filename, f, mime_type)} + response = requests.post( + f"{self.api_base_url}", + data=form_data, + files=files, + headers=headers, + ) + response.raise_for_status() + result = response.json() + except FileNotFoundError: + raise HTTPException( + status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}" + ) + except requests.HTTPError as e: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, + detail=f"Datalab Marker request failed: {e}", + ) + except ValueError as e: + raise HTTPException( + status.HTTP_502_BAD_GATEWAY, detail=f"Invalid JSON response: {e}" + ) + except Exception as e: + raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) + + if not result.get("success"): + raise HTTPException( + status.HTTP_400_BAD_REQUEST, + detail=f"Datalab Marker request failed: {result.get('error', 'Unknown error')}", + ) + + check_url = result.get("request_check_url") + request_id = result.get("request_id") + + # Check if this is a direct response (self-hosted) or polling response (DataLab) + if check_url: + # DataLab polling pattern + for _ in range(300): # Up to 10 minutes + time.sleep(2) + try: + poll_response = requests.get(check_url, headers=headers) + poll_response.raise_for_status() + poll_result = poll_response.json() + except (requests.HTTPError, ValueError) as e: + 
raw_body = poll_response.text + log.error(f"Polling error: {e}, response body: {raw_body}") + raise HTTPException( + status.HTTP_502_BAD_GATEWAY, detail=f"Polling failed: {e}" + ) + + status_val = poll_result.get("status") + success_val = poll_result.get("success") + + if status_val == "complete": + summary = { + k: poll_result.get(k) + for k in ( + "status", + "output_format", + "success", + "error", + "page_count", + "total_cost", + ) + } + log.info( + f"Marker processing completed successfully: {json.dumps(summary, indent=2)}" + ) + break + + if status_val == "failed" or success_val is False: + log.error( + f"Marker poll failed full response: {json.dumps(poll_result, indent=2)}" + ) + error_msg = ( + poll_result.get("error") + or "Marker returned failure without error message" + ) + raise HTTPException( + status.HTTP_400_BAD_REQUEST, + detail=f"Marker processing failed: {error_msg}", + ) + else: + raise HTTPException( + status.HTTP_504_GATEWAY_TIMEOUT, + detail="Marker processing timed out", + ) + + if not poll_result.get("success", False): + error_msg = poll_result.get("error") or "Unknown processing error" + raise HTTPException( + status.HTTP_400_BAD_REQUEST, + detail=f"Final processing failed: {error_msg}", + ) + + # DataLab format - content in format-specific fields + content_key = self.output_format.lower() + raw_content = poll_result.get(content_key) + final_result = poll_result + else: + # Self-hosted direct response - content in "output" field + if "output" in result: + log.info("Self-hosted Marker returned direct response without polling") + raw_content = result.get("output") + final_result = result + else: + available_fields = ( + list(result.keys()) + if isinstance(result, dict) + else "non-dict response" + ) + raise HTTPException( + status.HTTP_502_BAD_GATEWAY, + detail=f"Custom Marker endpoint returned success but no 'output' field found. Available fields: {available_fields}. 
Expected either 'request_check_url' for polling or 'output' field for direct response.", + ) + + if self.output_format.lower() == "json": + full_text = json.dumps(raw_content, indent=2) + elif self.output_format.lower() in {"markdown", "html"}: + full_text = str(raw_content).strip() + else: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, + detail=f"Unsupported output format: {self.output_format}", + ) + + if not full_text: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, + detail="Marker returned empty content", + ) + + marker_output_dir = os.path.join("/app/backend/data/uploads", "marker_output") + os.makedirs(marker_output_dir, exist_ok=True) + + file_ext_map = {"markdown": "md", "json": "json", "html": "html"} + file_ext = file_ext_map.get(self.output_format.lower(), "txt") + output_filename = f"{os.path.splitext(filename)[0]}.{file_ext}" + output_path = os.path.join(marker_output_dir, output_filename) + + try: + with open(output_path, "w", encoding="utf-8") as f: + f.write(full_text) + log.info(f"Saved Marker output to: {output_path}") + except Exception as e: + log.warning(f"Failed to write marker output to disk: {e}") + + metadata = { + "source": filename, + "output_format": final_result.get("output_format", self.output_format), + "page_count": final_result.get("page_count", 0), + "processed_with_llm": self.use_llm, + "request_id": request_id or "", + } + + images = final_result.get("images", {}) + if images: + metadata["image_count"] = len(images) + metadata["images"] = json.dumps(list(images.keys())) + + for k, v in metadata.items(): + if isinstance(v, (dict, list)): + metadata[k] = json.dumps(v) + elif v is None: + metadata[k] = "" + + return [Document(page_content=full_text, metadata=metadata)] diff --git a/backend/open_webui/retrieval/loaders/external_document.py b/backend/open_webui/retrieval/loaders/external_document.py index 6119da3791b..1be2ca3f249 100644 --- a/backend/open_webui/retrieval/loaders/external_document.py +++ 
b/backend/open_webui/retrieval/loaders/external_document.py @@ -1,6 +1,7 @@ import requests -import logging +import logging, os from typing import Iterator, List, Union +from urllib.parse import quote from langchain_core.document_loaders import BaseLoader from langchain_core.documents import Document @@ -25,7 +26,7 @@ def __init__( self.file_path = file_path self.mime_type = mime_type - def load(self) -> list[Document]: + def load(self) -> List[Document]: with open(self.file_path, "rb") as f: data = f.read() @@ -36,23 +37,48 @@ def load(self) -> list[Document]: if self.api_key is not None: headers["Authorization"] = f"Bearer {self.api_key}" + try: + headers["X-Filename"] = quote(os.path.basename(self.file_path)) + except: + pass + url = self.url if url.endswith("/"): url = url[:-1] - r = requests.put(f"{url}/process", data=data, headers=headers) + try: + response = requests.put(f"{url}/process", data=data, headers=headers) + except Exception as e: + log.error(f"Error connecting to endpoint: {e}") + raise Exception(f"Error connecting to endpoint: {e}") + + if response.ok: - if r.ok: - res = r.json() + response_data = response.json() + if response_data: + if isinstance(response_data, dict): + return [ + Document( + page_content=response_data.get("page_content"), + metadata=response_data.get("metadata"), + ) + ] + elif isinstance(response_data, list): + documents = [] + for document in response_data: + documents.append( + Document( + page_content=document.get("page_content"), + metadata=document.get("metadata"), + ) + ) + return documents + else: + raise Exception("Error loading document: Unable to parse content") - if res: - return [ - Document( - page_content=res.get("page_content"), - metadata=res.get("metadata"), - ) - ] else: raise Exception("Error loading document: No content returned") else: - raise Exception(f"Error loading document: {r.status_code} {r.text}") + raise Exception( + f"Error loading document: {response.status_code} {response.text}" + ) diff --git 
a/backend/open_webui/retrieval/loaders/main.py b/backend/open_webui/retrieval/loaders/main.py index c5f0b4e5e5d..2ef1d75e026 100644 --- a/backend/open_webui/retrieval/loaders/main.py +++ b/backend/open_webui/retrieval/loaders/main.py @@ -2,7 +2,9 @@ import logging import ftfy import sys +import json +from azure.identity import DefaultAzureCredential from langchain_community.document_loaders import ( AzureAIDocumentIntelligenceLoader, BSHTMLLoader, @@ -13,7 +15,7 @@ TextLoader, UnstructuredEPubLoader, UnstructuredExcelLoader, - UnstructuredMarkdownLoader, + UnstructuredODTLoader, UnstructuredPowerPointLoader, UnstructuredRSTLoader, UnstructuredXMLLoader, @@ -21,9 +23,12 @@ ) from langchain_core.documents import Document - from open_webui.retrieval.loaders.external_document import ExternalDocumentLoader + from open_webui.retrieval.loaders.mistral import MistralLoader +from open_webui.retrieval.loaders.datalab_marker import DatalabMarkerLoader +from open_webui.retrieval.loaders.mineru import MinerULoader + from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL @@ -74,7 +79,6 @@ "swift", "vue", "svelte", - "msg", "ex", "exs", "erl", @@ -145,18 +149,41 @@ def load(self) -> list[Document]: ) } - params = { - "image_export_mode": "placeholder", - "table_mode": "accurate", - } + params = {"image_export_mode": "placeholder"} if self.params: - if self.params.get("do_picture_classification"): - params["do_picture_classification"] = self.params.get( - "do_picture_classification" + if self.params.get("do_picture_description"): + params["do_picture_description"] = self.params.get( + "do_picture_description" ) - if self.params.get("ocr_engine") and self.params.get("ocr_lang"): + picture_description_mode = self.params.get( + "picture_description_mode", "" + ).lower() + + if picture_description_mode == "local" and self.params.get( + "picture_description_local", {} + ): + params["picture_description_local"] = json.dumps( + self.params.get("picture_description_local", {}) + ) + 
+ elif picture_description_mode == "api" and self.params.get( + "picture_description_api", {} + ): + params["picture_description_api"] = json.dumps( + self.params.get("picture_description_api", {}) + ) + + params["do_ocr"] = self.params.get("do_ocr") + + params["force_ocr"] = self.params.get("force_ocr") + + if ( + self.params.get("do_ocr") + and self.params.get("ocr_engine") + and self.params.get("ocr_lang") + ): params["ocr_engine"] = self.params.get("ocr_engine") params["ocr_lang"] = [ lang.strip() @@ -164,7 +191,16 @@ def load(self) -> list[Document]: if lang.strip() ] - endpoint = f"{self.url}/v1alpha/convert/file" + if self.params.get("pdf_backend"): + params["pdf_backend"] = self.params.get("pdf_backend") + + if self.params.get("table_mode"): + params["table_mode"] = self.params.get("table_mode") + + if self.params.get("pipeline"): + params["pipeline"] = self.params.get("pipeline") + + endpoint = f"{self.url}/v1/convert/file" r = requests.post(endpoint, files=files, data=params) if r.ok: @@ -209,7 +245,10 @@ def load( def _is_text_file(self, file_ext: str, file_content_type: str) -> bool: return file_ext in known_source_ext or ( - file_content_type and file_content_type.find("text/") >= 0 + file_content_type + and file_content_type.find("text/") >= 0 + # Avoid text/html files being detected as text + and not file_content_type.find("html") >= 0 ) def _get_loader(self, filename: str, file_content_type: str, file_path: str): @@ -226,7 +265,7 @@ def _get_loader(self, filename: str, file_content_type: str, file_path: str): api_key=self.kwargs.get("EXTERNAL_DOCUMENT_LOADER_API_KEY"), mime_type=file_content_type, ) - if self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"): + elif self.engine == "tika" and self.kwargs.get("TIKA_SERVER_URL"): if self._is_text_file(file_ext, file_content_type): loader = TextLoader(file_path, autodetect_encoding=True) else: @@ -236,42 +275,114 @@ def _get_loader(self, filename: str, file_content_type: str, file_path: str): 
mime_type=file_content_type, extract_images=self.kwargs.get("PDF_EXTRACT_IMAGES"), ) + elif ( + self.engine == "datalab_marker" + and self.kwargs.get("DATALAB_MARKER_API_KEY") + and file_ext + in [ + "pdf", + "xls", + "xlsx", + "ods", + "doc", + "docx", + "odt", + "ppt", + "pptx", + "odp", + "html", + "epub", + "png", + "jpeg", + "jpg", + "webp", + "gif", + "tiff", + ] + ): + api_base_url = self.kwargs.get("DATALAB_MARKER_API_BASE_URL", "") + if not api_base_url or api_base_url.strip() == "": + api_base_url = "https://www.datalab.to/api/v1/marker" # https://github.com/open-webui/open-webui/pull/16867#issuecomment-3218424349 + + loader = DatalabMarkerLoader( + file_path=file_path, + api_key=self.kwargs["DATALAB_MARKER_API_KEY"], + api_base_url=api_base_url, + additional_config=self.kwargs.get("DATALAB_MARKER_ADDITIONAL_CONFIG"), + use_llm=self.kwargs.get("DATALAB_MARKER_USE_LLM", False), + skip_cache=self.kwargs.get("DATALAB_MARKER_SKIP_CACHE", False), + force_ocr=self.kwargs.get("DATALAB_MARKER_FORCE_OCR", False), + paginate=self.kwargs.get("DATALAB_MARKER_PAGINATE", False), + strip_existing_ocr=self.kwargs.get( + "DATALAB_MARKER_STRIP_EXISTING_OCR", False + ), + disable_image_extraction=self.kwargs.get( + "DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION", False + ), + format_lines=self.kwargs.get("DATALAB_MARKER_FORMAT_LINES", False), + output_format=self.kwargs.get( + "DATALAB_MARKER_OUTPUT_FORMAT", "markdown" + ), + ) elif self.engine == "docling" and self.kwargs.get("DOCLING_SERVER_URL"): if self._is_text_file(file_ext, file_content_type): loader = TextLoader(file_path, autodetect_encoding=True) else: + # Build params for DoclingLoader + params = self.kwargs.get("DOCLING_PARAMS", {}) + if not isinstance(params, dict): + try: + params = json.loads(params) + except json.JSONDecodeError: + log.error("Invalid DOCLING_PARAMS format, expected JSON object") + params = {} + loader = DoclingLoader( url=self.kwargs.get("DOCLING_SERVER_URL"), file_path=file_path, 
mime_type=file_content_type, - params={ - "ocr_engine": self.kwargs.get("DOCLING_OCR_ENGINE"), - "ocr_lang": self.kwargs.get("DOCLING_OCR_LANG"), - "do_picture_classification": self.kwargs.get( - "DOCLING_DO_PICTURE_DESCRIPTION" - ), - }, + params=params, ) elif ( self.engine == "document_intelligence" and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != "" - and self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY") != "" and ( - file_ext in ["pdf", "xls", "xlsx", "docx", "ppt", "pptx"] + file_ext in ["pdf", "docx", "ppt", "pptx"] or file_content_type in [ - "application/vnd.ms-excel", - "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", "application/vnd.openxmlformats-officedocument.wordprocessingml.document", "application/vnd.ms-powerpoint", "application/vnd.openxmlformats-officedocument.presentationml.presentation", ] ) ): - loader = AzureAIDocumentIntelligenceLoader( + if self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY") != "": + loader = AzureAIDocumentIntelligenceLoader( + file_path=file_path, + api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"), + api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"), + ) + else: + loader = AzureAIDocumentIntelligenceLoader( + file_path=file_path, + api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"), + azure_credential=DefaultAzureCredential(), + ) + elif self.engine == "mineru" and file_ext in [ + "pdf", + "doc", + "docx", + "ppt", + "pptx", + "xls", + "xlsx", + ]: + loader = MinerULoader( file_path=file_path, - api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"), - api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"), + api_mode=self.kwargs.get("MINERU_API_MODE", "local"), + api_url=self.kwargs.get("MINERU_API_URL", "http://localhost:8000"), + api_key=self.kwargs.get("MINERU_API_KEY", ""), + params=self.kwargs.get("MINERU_PARAMS", {}), ) elif ( self.engine == "mistral_ocr" @@ -326,6 +437,8 @@ def _get_loader(self, filename: str, file_content_type: str, file_path: str): loader = 
class MinerULoader:
    """
    Convert a document to markdown with MinerU and wrap it in a ``Document``.

    Two API modes are supported:

    * ``"local"`` - a self-hosted MinerU server. The file is POSTed to
      ``{api_url}/file_parse`` and the markdown is read from that response.
    * ``"cloud"`` - the managed MinerU service. The file is uploaded to a
      presigned URL, the batch is polled until it finishes, and the markdown
      is read from the result ZIP.

    Invalid constructor arguments raise ``ValueError``. Failures while talking
    to MinerU are raised as ``fastapi.HTTPException``.
    """

    # Cloud polling cadence: 300 polls with a 2 s sleep = 10 minutes of waiting
    # (the time spent inside each HTTP request comes on top of that).
    POLL_INTERVAL_SECONDS = 2
    MAX_POLL_ITERATIONS = 300

    def __init__(
        self,
        file_path: str,
        api_mode: str = "local",
        api_url: str = "http://localhost:8000",
        api_key: str = "",
        params: Optional[dict] = None,
    ):
        """
        Args:
            file_path: Path of the local file to parse.
            api_mode: ``"local"`` or ``"cloud"`` (case-insensitive).
            api_url: Base URL of the MinerU API; a trailing ``/`` is stripped.
            api_key: Bearer token; required in cloud mode.
            params: Optional overrides. Recognised keys: ``enable_ocr``,
                ``enable_formula``, ``enable_table``, ``language``,
                ``model_version``, ``page_ranges``.

        Raises:
            ValueError: If ``api_mode`` is unknown, or cloud mode is selected
                without an API key.
        """
        self.file_path = file_path
        # A None value (e.g. an unset configuration entry) falls back to the
        # signature default instead of crashing on .lower() / .rstrip().
        self.api_mode = (api_mode or "local").lower()
        self.api_url = (api_url or "http://localhost:8000").rstrip("/")
        self.api_key = api_key or ""

        # Parse params dict with defaults
        params = params or {}
        self.enable_ocr = params.get("enable_ocr", False)
        self.enable_formula = params.get("enable_formula", True)
        self.enable_table = params.get("enable_table", True)
        self.language = params.get("language", "en")
        self.model_version = params.get("model_version", "pipeline")
        self.page_ranges = params.get("page_ranges", "")

        # Validate API mode
        if self.api_mode not in ("local", "cloud"):
            raise ValueError(
                f"Invalid API mode: {self.api_mode}. Must be 'local' or 'cloud'"
            )

        # Validate Cloud API requirements
        if self.api_mode == "cloud" and not self.api_key:
            raise ValueError("API key is required for Cloud API mode")

    @staticmethod
    def _http_error_detail(
        prefix: str, error: "requests.HTTPError", msg_key: Optional[str] = None
    ) -> str:
        """
        Build a readable message for a failed HTTP call.

        Appends the JSON error body (or only its ``msg_key`` entry when that
        key is given and present), falling back to the raw response text when
        the body is not JSON.
        """
        detail = f"{prefix}: {error}"
        response = error.response
        if response is None:
            return detail
        try:
            payload = response.json()
            if msg_key is not None and isinstance(payload, dict):
                payload = payload.get(msg_key, payload)
            detail += f" - {payload}"
        except Exception:
            detail += f" - {response.text}"
        return detail

    @staticmethod
    def _unwrap_cloud_payload(result) -> dict:
        """
        Validate a Cloud API JSON envelope and return its ``data`` object.

        Raises:
            HTTPException: 502 if the envelope is not a JSON object, 400 if
                its ``code`` field is not ``0``.
        """
        if not isinstance(result, dict):
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY,
                detail="MinerU Cloud API returned an unexpected response format",
            )
        if result.get("code") != 0:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail=f"MinerU Cloud API error: {result.get('msg', 'Unknown error')}",
            )
        # "data": null must not reach the .get() calls made by the callers.
        return result.get("data") or {}

    def load(self) -> List["Document"]:
        """
        Parse the document and return it as a single-element list.

        Routes to the Cloud or Local API based on ``api_mode``; any exception
        is logged and re-raised unchanged.
        """
        try:
            if self.api_mode == "cloud":
                return self._load_cloud_api()
            return self._load_local_api()
        except Exception as e:
            log.error(f"Error loading document with MinerU: {e}")
            raise

    def _load_local_api(self) -> List["Document"]:
        """
        Load the document through the Local API (one synchronous request).

        Posts the file to ``/file_parse`` and reads ``md_content`` from the
        first entry of the response's ``results`` mapping.

        Raises:
            HTTPException: 404 if the file is missing, 504 on timeout, 400 on
                an HTTP error status or empty output, 502 on a malformed
                response, 500 on any other request failure.
        """
        log.info(f"Using MinerU Local API at {self.api_url}")

        filename = os.path.basename(self.file_path)

        # Build form data for Local API
        form_data = {
            "return_md": "true",
            "formula_enable": str(self.enable_formula).lower(),
            "table_enable": str(self.enable_table).lower(),
            # Parse method based on OCR setting
            "parse_method": "ocr" if self.enable_ocr else "auto",
            # Backend/model version (Local API uses "backend" parameter)
            "backend": (
                "vlm-vllm-engine" if self.model_version == "vlm" else "pipeline"
            ),
        }

        # Language configuration (sent under the Local API's "lang_list" field)
        if self.language:
            form_data["lang_list"] = self.language

        # Page ranges are only forwarded in cloud mode; here they are ignored
        # and a warning is logged so the user knows.
        if self.page_ranges:
            log.warning(
                f"Page ranges '{self.page_ranges}' specified but Local API uses different format. "
                "Consider using start_page_id/end_page_id parameters if needed."
            )

        try:
            with open(self.file_path, "rb") as f:
                files = {"files": (filename, f, "application/octet-stream")}

                log.info(f"Sending file to MinerU Local API: {filename}")
                log.debug(f"Local API parameters: {form_data}")

                response = requests.post(
                    f"{self.api_url}/file_parse",
                    data=form_data,
                    files=files,
                    timeout=300,  # 5 minute timeout for large documents
                )
                response.raise_for_status()

        except FileNotFoundError:
            raise HTTPException(
                status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
            )
        except requests.Timeout:
            raise HTTPException(
                status.HTTP_504_GATEWAY_TIMEOUT,
                detail="MinerU Local API request timed out",
            )
        except requests.HTTPError as e:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail=self._http_error_detail("MinerU Local API request failed", e),
            )
        except Exception as e:
            raise HTTPException(
                status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Error calling MinerU Local API: {str(e)}",
            )

        # Parse response
        try:
            result = response.json()
        except ValueError as e:
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY,
                detail=f"Invalid JSON response from MinerU Local API: {e}",
            )

        # Extract markdown content from response
        if not isinstance(result, dict) or "results" not in result:
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY,
                detail="MinerU Local API response missing 'results' field",
            )

        results = result["results"]
        if not results:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail="MinerU returned empty results",
            )
        if not isinstance(results, dict):
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY,
                detail="MinerU Local API returned 'results' in an unexpected format",
            )

        # Get the first (and typically only) result
        file_result = next(iter(results.values()))
        markdown_content = (
            file_result.get("md_content", "") if isinstance(file_result, dict) else ""
        )

        if not markdown_content:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail="MinerU returned empty markdown content",
            )

        log.info(f"Successfully parsed document with MinerU Local API: {filename}")

        # Create metadata
        metadata = {
            "source": filename,
            "api_mode": "local",
            "backend": result.get("backend", "unknown"),
            "version": result.get("version", "unknown"),
        }

        return [Document(page_content=markdown_content, metadata=metadata)]

    def _load_cloud_api(self) -> List["Document"]:
        """
        Load the document through the Cloud API (upload, poll, download).

        Raises:
            HTTPException: Propagated from the individual steps; 502 if the
                finished result carries no ``full_zip_url``.
        """
        log.info(f"Using MinerU Cloud API at {self.api_url}")

        filename = os.path.basename(self.file_path)

        # Step 1: Request presigned upload URL
        batch_id, upload_url = self._request_upload_url(filename)

        # Step 2: Upload file to presigned URL
        self._upload_to_presigned_url(upload_url)

        # Step 3: Poll for results
        result = self._poll_batch_status(batch_id, filename)

        # Step 4: Download and extract markdown from ZIP
        zip_url = result.get("full_zip_url")
        if not zip_url:
            # Previously a bare KeyError; report it like every other
            # malformed-upstream-response case.
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY,
                detail="MinerU Cloud API result missing 'full_zip_url'",
            )
        markdown_content = self._download_and_extract_zip(zip_url, filename)

        log.info(f"Successfully parsed document with MinerU Cloud API: {filename}")

        # Create metadata
        metadata = {
            "source": filename,
            "api_mode": "cloud",
            "batch_id": batch_id,
        }

        return [Document(page_content=markdown_content, metadata=metadata)]

    def _request_upload_url(self, filename: str) -> tuple:
        """
        Request a presigned upload URL from the Cloud API.

        Returns:
            ``(batch_id, upload_url)``.

        Raises:
            HTTPException: 400 on an HTTP or API-level error, 502 on a
                malformed response, 500 on any other request failure.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        # Build request body
        file_entry = {"name": filename, "is_ocr": self.enable_ocr}
        # Add page ranges if specified
        if self.page_ranges:
            file_entry["page_ranges"] = self.page_ranges

        request_body = {
            "enable_formula": self.enable_formula,
            "enable_table": self.enable_table,
            "language": self.language,
            "model_version": self.model_version,
            "files": [file_entry],
        }

        log.info(f"Requesting upload URL for: {filename}")
        log.debug(f"Cloud API request body: {request_body}")

        try:
            response = requests.post(
                f"{self.api_url}/file-urls/batch",
                headers=headers,
                json=request_body,
                timeout=30,
            )
            response.raise_for_status()
        except requests.HTTPError as e:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail=self._http_error_detail(
                    "Failed to request upload URL", e, msg_key="msg"
                ),
            )
        except Exception as e:
            raise HTTPException(
                status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Error requesting upload URL: {str(e)}",
            )

        try:
            result = response.json()
        except ValueError as e:
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY,
                detail=f"Invalid JSON response: {e}",
            )

        # Check for API error response and unwrap the payload
        data = self._unwrap_cloud_payload(result)
        batch_id = data.get("batch_id")
        file_urls = data.get("file_urls") or []

        if not batch_id or not file_urls:
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY,
                detail="MinerU Cloud API response missing batch_id or file_urls",
            )

        upload_url = file_urls[0]
        log.info(f"Received upload URL for batch: {batch_id}")

        return batch_id, upload_url

    def _upload_to_presigned_url(self, upload_url: str) -> None:
        """
        PUT the file to the presigned URL (no Authorization header is sent).

        Raises:
            HTTPException: 404 if the file is missing, 504 on timeout, 400 on
                an HTTP error status, 500 on any other failure.
        """
        log.info("Uploading file to presigned URL")

        try:
            with open(self.file_path, "rb") as f:
                response = requests.put(
                    upload_url,
                    data=f,
                    timeout=300,  # 5 minute timeout for large files
                )
                response.raise_for_status()
        except FileNotFoundError:
            raise HTTPException(
                status.HTTP_404_NOT_FOUND, detail=f"File not found: {self.file_path}"
            )
        except requests.Timeout:
            raise HTTPException(
                status.HTTP_504_GATEWAY_TIMEOUT,
                detail="File upload to presigned URL timed out",
            )
        except requests.HTTPError as e:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail=f"Failed to upload file to presigned URL: {e}",
            )
        except Exception as e:
            raise HTTPException(
                status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Error uploading file: {str(e)}",
            )

        log.info("File uploaded successfully")

    def _poll_batch_status(self, batch_id: str, filename: str) -> dict:
        """
        Poll the batch until this file reaches a terminal state.

        Returns:
            The ``extract_result`` entry whose ``file_name`` equals
            ``filename`` once its state is ``"done"``.

        Raises:
            HTTPException: 400 on an HTTP/API error or a ``"failed"`` state,
                502 on a malformed response or when the file is absent from
                the batch, 504 when the polling budget is exhausted, 500 on
                any other request failure.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
        }

        max_iterations = self.MAX_POLL_ITERATIONS
        poll_interval = self.POLL_INTERVAL_SECONDS

        log.info(f"Polling batch status: {batch_id}")

        for iteration in range(max_iterations):
            try:
                response = requests.get(
                    f"{self.api_url}/extract-results/batch/{batch_id}",
                    headers=headers,
                    timeout=30,
                )
                response.raise_for_status()
            except requests.HTTPError as e:
                raise HTTPException(
                    status.HTTP_400_BAD_REQUEST,
                    detail=self._http_error_detail(
                        "Failed to poll batch status", e, msg_key="msg"
                    ),
                )
            except Exception as e:
                raise HTTPException(
                    status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail=f"Error polling batch status: {str(e)}",
                )

            try:
                result = response.json()
            except ValueError as e:
                raise HTTPException(
                    status.HTTP_502_BAD_GATEWAY,
                    detail=f"Invalid JSON response while polling: {e}",
                )

            # Check for API error response and unwrap the payload
            data = self._unwrap_cloud_payload(result)
            extract_result = data.get("extract_result") or []

            # Find our file in the batch results
            file_result = next(
                (
                    item
                    for item in extract_result
                    if isinstance(item, dict) and item.get("file_name") == filename
                ),
                None,
            )

            # NOTE(review): a missing entry is treated as fatal right away;
            # confirm the API always lists the file from the first poll on.
            if not file_result:
                raise HTTPException(
                    status.HTTP_502_BAD_GATEWAY,
                    detail=f"File {filename} not found in batch results",
                )

            state = file_result.get("state")

            if state == "done":
                log.info(f"Processing complete for {filename}")
                return file_result
            if state == "failed":
                error_msg = file_result.get("err_msg", "Unknown error")
                raise HTTPException(
                    status.HTTP_400_BAD_REQUEST,
                    detail=f"MinerU processing failed: {error_msg}",
                )

            if state in ("waiting-file", "pending", "running", "converting"):
                # Still processing; log only every 10th poll to limit noise
                if iteration % 10 == 0:
                    log.info(
                        f"Processing status: {state} (iteration {iteration + 1}/{max_iterations})"
                    )
            else:
                log.warning(f"Unknown state: {state}")
            time.sleep(poll_interval)

        # Timeout
        waited_minutes = (max_iterations * poll_interval) // 60
        raise HTTPException(
            status.HTTP_504_GATEWAY_TIMEOUT,
            detail=f"MinerU processing timed out after {waited_minutes} minutes",
        )

    def _download_and_extract_zip(self, zip_url: str, filename: str) -> str:
        """
        Download the result ZIP and return the first non-empty markdown file.

        The archive and its extracted contents live inside one temporary
        directory that is removed on every exit path. Directories and files
        are visited in sorted order so the chosen ``.md`` file is
        deterministic.

        Args:
            zip_url: URL of the result archive.
            filename: Name of the source document (currently unused; kept for
                interface compatibility).

        Raises:
            HTTPException: 400 on an HTTP error status or empty markdown, 502
                on an invalid archive or one without a readable ``.md`` file,
                500 on any other failure.
        """
        log.info(f"Downloading results from: {zip_url}")

        try:
            response = requests.get(zip_url, timeout=60)
            response.raise_for_status()
        except requests.HTTPError as e:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail=f"Failed to download results ZIP: {e}",
            )
        except Exception as e:
            raise HTTPException(
                status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Error downloading results: {str(e)}",
            )

        markdown_content = None
        try:
            with tempfile.TemporaryDirectory() as tmp_dir:
                zip_path = os.path.join(tmp_dir, "result.zip")
                extract_dir = os.path.join(tmp_dir, "extracted")
                os.makedirs(extract_dir)

                with open(zip_path, "wb") as zip_file:
                    zip_file.write(response.content)

                with zipfile.ZipFile(zip_path, "r") as zip_ref:
                    zip_ref.extractall(extract_dir)

                # Walk the extracted tree, remembering every file for the
                # error message and stopping at the first non-empty .md file.
                all_files = []
                for root, dirs, files in os.walk(extract_dir):
                    dirs.sort()  # in-place sort fixes the traversal order
                    for name in sorted(files):
                        full_path = os.path.join(root, name)
                        all_files.append(full_path)
                        if not name.endswith(".md"):
                            continue
                        log.info(f"Found markdown file at: {full_path}")
                        try:
                            with open(full_path, "r", encoding="utf-8") as f:
                                markdown_content = f.read()
                        except Exception as e:
                            log.warning(f"Failed to read {full_path}: {e}")
                            continue
                        if markdown_content:
                            break
                    if markdown_content:
                        break

                if markdown_content is None:
                    log.error(f"Available files in ZIP: {all_files}")
                    # Try to provide more helpful error message
                    md_files = [f for f in all_files if f.endswith(".md")]
                    if md_files:
                        error_msg = (
                            f"Found .md files but couldn't read them: {md_files}"
                        )
                    else:
                        error_msg = (
                            f"No .md files found in ZIP. Available files: {all_files}"
                        )
                    raise HTTPException(
                        status.HTTP_502_BAD_GATEWAY,
                        detail=error_msg,
                    )

        except HTTPException:
            # Keep the specific status chosen above instead of letting the
            # catch-all below turn it into a generic 500.
            raise
        except zipfile.BadZipFile as e:
            raise HTTPException(
                status.HTTP_502_BAD_GATEWAY,
                detail=f"Invalid ZIP file received: {e}",
            )
        except Exception as e:
            raise HTTPException(
                status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Error extracting ZIP: {str(e)}",
            )

        if not markdown_content:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail="Extracted markdown content is empty",
            )

        log.info(
            f"Successfully extracted markdown content ({len(markdown_content)} characters)"
        )
        return markdown_content
+ + Performance Optimizations: + - Differentiated timeouts for different operations + - Intelligent retry logic with exponential backoff + - Memory-efficient file streaming for large files + - Connection pooling and keepalive optimization + - Semaphore-based concurrency control for batch processing + - Enhanced error handling with retryable error classification """ BASE_API_URL = "https://api.mistral.ai/v1" - def __init__(self, api_key: str, file_path: str): + def __init__( + self, + api_key: str, + file_path: str, + timeout: int = 300, # 5 minutes default + max_retries: int = 3, + enable_debug_logging: bool = False, + ): """ - Initializes the loader. + Initializes the loader with enhanced features. Args: api_key: Your Mistral API key. file_path: The local path to the PDF file to process. + timeout: Request timeout in seconds. + max_retries: Maximum number of retry attempts. + enable_debug_logging: Enable detailed debug logs. """ if not api_key: raise ValueError("API key cannot be empty.") @@ -34,7 +57,46 @@ def __init__(self, api_key: str, file_path: str): self.api_key = api_key self.file_path = file_path - self.headers = {"Authorization": f"Bearer {self.api_key}"} + self.timeout = timeout + self.max_retries = max_retries + self.debug = enable_debug_logging + + # PERFORMANCE OPTIMIZATION: Differentiated timeouts for different operations + # This prevents long-running OCR operations from affecting quick operations + # and improves user experience by failing fast on operations that should be quick + self.upload_timeout = min( + timeout, 120 + ) # Cap upload at 2 minutes - prevents hanging on large files + self.url_timeout = ( + 30 # URL requests should be fast - fail quickly if API is slow + ) + self.ocr_timeout = ( + timeout # OCR can take the full timeout - this is the heavy operation + ) + self.cleanup_timeout = ( + 30 # Cleanup should be quick - don't hang on file deletion + ) + + # PERFORMANCE OPTIMIZATION: Pre-compute file info to avoid repeated filesystem 
calls + # This avoids multiple os.path.basename() and os.path.getsize() calls during processing + self.file_name = os.path.basename(file_path) + self.file_size = os.path.getsize(file_path) + + # ENHANCEMENT: Added User-Agent for better API tracking and debugging + self.headers = { + "Authorization": f"Bearer {self.api_key}", + "User-Agent": "OpenWebUI-MistralLoader/2.0", # Helps API provider track usage + } + + def _debug_log(self, message: str, *args) -> None: + """ + PERFORMANCE OPTIMIZATION: Conditional debug logging for performance. + + Only processes debug messages when debug mode is enabled, avoiding + string formatting overhead in production environments. + """ + if self.debug: + log.debug(message, *args) def _handle_response(self, response: requests.Response) -> Dict[str, Any]: """Checks response status and returns JSON content.""" @@ -54,24 +116,154 @@ def _handle_response(self, response: requests.Response) -> Dict[str, Any]: log.error(f"JSON decode error: {json_err} - Response: {response.text}") raise # Re-raise after logging + async def _handle_response_async( + self, response: aiohttp.ClientResponse + ) -> Dict[str, Any]: + """Async version of response handling with better error info.""" + try: + response.raise_for_status() + + # Check content type + content_type = response.headers.get("content-type", "") + if "application/json" not in content_type: + if response.status == 204: + return {} + text = await response.text() + raise ValueError( + f"Unexpected content type: {content_type}, body: {text[:200]}..." 
+ ) + + return await response.json() + + except aiohttp.ClientResponseError as e: + error_text = await response.text() if response else "No response" + log.error(f"HTTP {e.status}: {e.message} - Response: {error_text[:500]}") + raise + except aiohttp.ClientError as e: + log.error(f"Client error: {e}") + raise + except Exception as e: + log.error(f"Unexpected error processing response: {e}") + raise + + def _is_retryable_error(self, error: Exception) -> bool: + """ + ENHANCEMENT: Intelligent error classification for retry logic. + + Determines if an error is retryable based on its type and status code. + This prevents wasting time retrying errors that will never succeed + (like authentication errors) while ensuring transient errors are retried. + + Retryable errors: + - Network connection errors (temporary network issues) + - Timeouts (server might be temporarily overloaded) + - Server errors (5xx status codes - server-side issues) + - Rate limiting (429 status - temporary throttling) + + Non-retryable errors: + - Authentication errors (401, 403 - won't fix with retry) + - Bad request errors (400 - malformed request) + - Not found errors (404 - resource doesn't exist) + """ + if isinstance(error, requests.exceptions.ConnectionError): + return True # Network issues are usually temporary + if isinstance(error, requests.exceptions.Timeout): + return True # Timeouts might resolve on retry + if isinstance(error, requests.exceptions.HTTPError): + # Only retry on server errors (5xx) or rate limits (429) + if hasattr(error, "response") and error.response is not None: + status_code = error.response.status_code + return status_code >= 500 or status_code == 429 + return False + if isinstance( + error, (aiohttp.ClientConnectionError, aiohttp.ServerTimeoutError) + ): + return True # Async network/timeout errors are retryable + if isinstance(error, aiohttp.ClientResponseError): + return error.status >= 500 or error.status == 429 + return False # All other errors are non-retryable 
+ + def _retry_request_sync(self, request_func, *args, **kwargs): + """ + ENHANCEMENT: Synchronous retry logic with intelligent error classification. + + Uses exponential backoff with jitter to avoid thundering herd problems. + The wait time increases exponentially but is capped at 30 seconds to + prevent excessive delays. Only retries errors that are likely to succeed + on subsequent attempts. + """ + for attempt in range(self.max_retries): + try: + return request_func(*args, **kwargs) + except Exception as e: + if attempt == self.max_retries - 1 or not self._is_retryable_error(e): + raise + + # PERFORMANCE OPTIMIZATION: Exponential backoff with cap + # Prevents overwhelming the server while ensuring reasonable retry delays + wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds + log.warning( + f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. " + f"Retrying in {wait_time}s..." + ) + time.sleep(wait_time) + + async def _retry_request_async(self, request_func, *args, **kwargs): + """ + ENHANCEMENT: Async retry logic with intelligent error classification. + + Async version of retry logic that doesn't block the event loop during + wait periods. Uses the same exponential backoff strategy as sync version. + """ + for attempt in range(self.max_retries): + try: + return await request_func(*args, **kwargs) + except Exception as e: + if attempt == self.max_retries - 1 or not self._is_retryable_error(e): + raise + + # PERFORMANCE OPTIMIZATION: Non-blocking exponential backoff + wait_time = min((2**attempt) + 0.5, 30) # Cap at 30 seconds + log.warning( + f"Retryable error (attempt {attempt + 1}/{self.max_retries}): {e}. " + f"Retrying in {wait_time}s..." + ) + await asyncio.sleep(wait_time) # Non-blocking wait + def _upload_file(self) -> str: - """Uploads the file to Mistral for OCR processing.""" + """ + PERFORMANCE OPTIMIZATION: Enhanced file upload with streaming consideration. + + Uploads the file to Mistral for OCR processing (sync version). 
+ Uses context manager for file handling to ensure proper resource cleanup. + Although streaming is not enabled for this endpoint, the file is opened + in a context manager to minimize memory usage duration. + """ log.info("Uploading file to Mistral API") url = f"{self.BASE_API_URL}/files" - file_name = os.path.basename(self.file_path) - try: + def upload_request(): + # MEMORY OPTIMIZATION: Use context manager to minimize file handle lifetime + # This ensures the file is closed immediately after reading, reducing memory usage with open(self.file_path, "rb") as f: - files = {"file": (file_name, f, "application/pdf")} + files = {"file": (self.file_name, f, "application/pdf")} data = {"purpose": "ocr"} - upload_headers = self.headers.copy() # Avoid modifying self.headers - + # NOTE: stream=False is required for this endpoint + # The Mistral API doesn't support chunked uploads for this endpoint response = requests.post( - url, headers=upload_headers, files=files, data=data + url, + headers=self.headers, + files=files, + data=data, + timeout=self.upload_timeout, # Use specialized upload timeout + stream=False, # Keep as False for this endpoint ) - response_data = self._handle_response(response) + return self._handle_response(response) + + try: + response_data = self._retry_request_sync(upload_request) file_id = response_data.get("id") if not file_id: raise ValueError("File ID not found in upload response.") @@ -81,16 +273,66 @@ def _upload_file(self) -> str: log.error(f"Failed to upload file: {e}") raise + async def _upload_file_async(self, session: aiohttp.ClientSession) -> str: + """Async file upload with streaming for better memory efficiency.""" + url = f"{self.BASE_API_URL}/files" + + async def upload_request(): + # Create multipart writer for streaming upload + writer = aiohttp.MultipartWriter("form-data") + + # Add purpose field + purpose_part = writer.append("ocr") + purpose_part.set_content_disposition("form-data", name="purpose") + + # Add file part with 
streaming + file_part = writer.append_payload( + aiohttp.streams.FilePayload( + self.file_path, + filename=self.file_name, + content_type="application/pdf", + ) + ) + file_part.set_content_disposition( + "form-data", name="file", filename=self.file_name + ) + + self._debug_log( + f"Uploading file: {self.file_name} ({self.file_size:,} bytes)" + ) + + async with session.post( + url, + data=writer, + headers=self.headers, + timeout=aiohttp.ClientTimeout(total=self.upload_timeout), + ) as response: + return await self._handle_response_async(response) + + response_data = await self._retry_request_async(upload_request) + + file_id = response_data.get("id") + if not file_id: + raise ValueError("File ID not found in upload response.") + + log.info(f"File uploaded successfully. File ID: {file_id}") + return file_id + def _get_signed_url(self, file_id: str) -> str: - """Retrieves a temporary signed URL for the uploaded file.""" + """Retrieves a temporary signed URL for the uploaded file (sync version).""" log.info(f"Getting signed URL for file ID: {file_id}") url = f"{self.BASE_API_URL}/files/{file_id}/url" params = {"expiry": 1} signed_url_headers = {**self.headers, "Accept": "application/json"} + def url_request(): + response = requests.get( + url, headers=signed_url_headers, params=params, timeout=self.url_timeout + ) + return self._handle_response(response) + try: - response = requests.get(url, headers=signed_url_headers, params=params) - response_data = self._handle_response(response) + response_data = self._retry_request_sync(url_request) signed_url = response_data.get("url") if not signed_url: raise ValueError("Signed URL not found in response.") @@ -100,8 +342,36 @@ def _get_signed_url(self, file_id: str) -> str: log.error(f"Failed to get signed URL: {e}") raise + async def _get_signed_url_async( + self, session: aiohttp.ClientSession, file_id: str + ) -> str: + """Async signed URL retrieval.""" + url = f"{self.BASE_API_URL}/files/{file_id}/url" + params = {"expiry": 
1} + + headers = {**self.headers, "Accept": "application/json"} + + async def url_request(): + self._debug_log(f"Getting signed URL for file ID: {file_id}") + async with session.get( + url, + headers=headers, + params=params, + timeout=aiohttp.ClientTimeout(total=self.url_timeout), + ) as response: + return await self._handle_response_async(response) + + response_data = await self._retry_request_async(url_request) + + signed_url = response_data.get("url") + if not signed_url: + raise ValueError("Signed URL not found in response.") + + self._debug_log("Signed URL received successfully") + return signed_url + def _process_ocr(self, signed_url: str) -> Dict[str, Any]: - """Sends the signed URL to the OCR endpoint for processing.""" + """Sends the signed URL to the OCR endpoint for processing (sync version).""" log.info("Processing OCR via Mistral API") url = f"{self.BASE_API_URL}/ocr" ocr_headers = { @@ -118,43 +388,218 @@ def _process_ocr(self, signed_url: str) -> Dict[str, Any]: "include_image_base64": False, } + def ocr_request(): + response = requests.post( + url, headers=ocr_headers, json=payload, timeout=self.ocr_timeout + ) + return self._handle_response(response) + try: - response = requests.post(url, headers=ocr_headers, json=payload) - ocr_response = self._handle_response(response) + ocr_response = self._retry_request_sync(ocr_request) log.info("OCR processing done.") - log.debug("OCR response: %s", ocr_response) + self._debug_log("OCR response: %s", ocr_response) return ocr_response except Exception as e: log.error(f"Failed during OCR processing: {e}") raise + async def _process_ocr_async( + self, session: aiohttp.ClientSession, signed_url: str + ) -> Dict[str, Any]: + """Async OCR processing with timing metrics.""" + url = f"{self.BASE_API_URL}/ocr" + + headers = { + **self.headers, + "Content-Type": "application/json", + "Accept": "application/json", + } + + payload = { + "model": "mistral-ocr-latest", + "document": { + "type": "document_url", + 
"document_url": signed_url, + }, + "include_image_base64": False, + } + + async def ocr_request(): + log.info("Starting OCR processing via Mistral API") + start_time = time.time() + + async with session.post( + url, + json=payload, + headers=headers, + timeout=aiohttp.ClientTimeout(total=self.ocr_timeout), + ) as response: + ocr_response = await self._handle_response_async(response) + + processing_time = time.time() - start_time + log.info(f"OCR processing completed in {processing_time:.2f}s") + + return ocr_response + + return await self._retry_request_async(ocr_request) + def _delete_file(self, file_id: str) -> None: - """Deletes the file from Mistral storage.""" + """Deletes the file from Mistral storage (sync version).""" log.info(f"Deleting uploaded file ID: {file_id}") url = f"{self.BASE_API_URL}/files/{file_id}" - # No specific Accept header needed, default or Authorization is usually sufficient try: - response = requests.delete(url, headers=self.headers) - delete_response = self._handle_response( - response - ) # Check status, ignore response body unless needed - log.info( - f"File deleted successfully: {delete_response}" - ) # Log the response if available + response = requests.delete( + url, headers=self.headers, timeout=self.cleanup_timeout + ) + delete_response = self._handle_response(response) + log.info(f"File deleted successfully: {delete_response}") except Exception as e: # Log error but don't necessarily halt execution if deletion fails log.error(f"Failed to delete file ID {file_id}: {e}") - # Depending on requirements, you might choose to raise the error here + + async def _delete_file_async( + self, session: aiohttp.ClientSession, file_id: str + ) -> None: + """Async file deletion with error tolerance.""" + try: + + async def delete_request(): + self._debug_log(f"Deleting file ID: {file_id}") + async with session.delete( + url=f"{self.BASE_API_URL}/files/{file_id}", + headers=self.headers, + timeout=aiohttp.ClientTimeout( + 
total=self.cleanup_timeout + ), # Shorter timeout for cleanup + ) as response: + return await self._handle_response_async(response) + + await self._retry_request_async(delete_request) + self._debug_log(f"File {file_id} deleted successfully") + + except Exception as e: + # Don't fail the entire process if cleanup fails + log.warning(f"Failed to delete file ID {file_id}: {e}") + + @asynccontextmanager + async def _get_session(self): + """Context manager for HTTP session with optimized settings.""" + connector = aiohttp.TCPConnector( + limit=20, # Increased total connection limit for better throughput + limit_per_host=10, # Increased per-host limit for API endpoints + ttl_dns_cache=600, # Longer DNS cache TTL (10 minutes) + use_dns_cache=True, + keepalive_timeout=60, # Increased keepalive for connection reuse + enable_cleanup_closed=True, + force_close=False, # Allow connection reuse + resolver=aiohttp.AsyncResolver(), # Use async DNS resolver + ) + + timeout = aiohttp.ClientTimeout( + total=self.timeout, + connect=30, # Connection timeout + sock_read=60, # Socket read timeout + ) + + async with aiohttp.ClientSession( + connector=connector, + timeout=timeout, + headers={"User-Agent": "OpenWebUI-MistralLoader/2.0"}, + raise_for_status=False, # We handle status codes manually + trust_env=True, + ) as session: + yield session + + def _process_results(self, ocr_response: Dict[str, Any]) -> List[Document]: + """Process OCR results into Document objects with enhanced metadata and memory efficiency.""" + pages_data = ocr_response.get("pages") + if not pages_data: + log.warning("No pages found in OCR response.") + return [ + Document( + page_content="No text content found", + metadata={"error": "no_pages", "file_name": self.file_name}, + ) + ] + + documents = [] + total_pages = len(pages_data) + skipped_pages = 0 + + # Process pages in a memory-efficient way + for page_data in pages_data: + page_content = page_data.get("markdown") + page_index = page_data.get("index") # API 
uses 0-based index + + if page_content is None or page_index is None: + skipped_pages += 1 + self._debug_log( + f"Skipping page due to missing 'markdown' or 'index'. Data keys: {list(page_data.keys())}" + ) + continue + + # Clean up content efficiently with early exit for empty content + if isinstance(page_content, str): + cleaned_content = page_content.strip() + else: + cleaned_content = str(page_content).strip() + + if not cleaned_content: + skipped_pages += 1 + self._debug_log(f"Skipping empty page {page_index}") + continue + + # Create document with optimized metadata + documents.append( + Document( + page_content=cleaned_content, + metadata={ + "page": page_index, # 0-based index from API + "page_label": page_index + 1, # 1-based label for convenience + "total_pages": total_pages, + "file_name": self.file_name, + "file_size": self.file_size, + "processing_engine": "mistral-ocr", + "content_length": len(cleaned_content), + }, + ) + ) + + if skipped_pages > 0: + log.info( + f"Processed {len(documents)} pages, skipped {skipped_pages} empty/invalid pages" + ) + + if not documents: + # Case where pages existed but none had valid markdown/index + log.warning( + "OCR response contained pages, but none had valid content/index." + ) + return [ + Document( + page_content="No valid text content found in document", + metadata={ + "error": "no_valid_pages", + "total_pages": total_pages, + "file_name": self.file_name, + }, + ) + ] + + return documents def load(self) -> List[Document]: """ Executes the full OCR workflow: upload, get URL, process OCR, delete file. + Synchronous version for backward compatibility. Returns: A list of Document objects, one for each page processed. """ file_id = None + start_time = time.time() + try: # 1. Upload file file_id = self._upload_file() @@ -166,53 +611,30 @@ def load(self) -> List[Document]: ocr_response = self._process_ocr(signed_url) # 4. 
Process results - pages_data = ocr_response.get("pages") - if not pages_data: - log.warning("No pages found in OCR response.") - return [Document(page_content="No text content found", metadata={})] - - documents = [] - total_pages = len(pages_data) - for page_data in pages_data: - page_content = page_data.get("markdown") - page_index = page_data.get("index") # API uses 0-based index - - if page_content is not None and page_index is not None: - documents.append( - Document( - page_content=page_content, - metadata={ - "page": page_index, # 0-based index from API - "page_label": page_index - + 1, # 1-based label for convenience - "total_pages": total_pages, - # Add other relevant metadata from page_data if available/needed - # e.g., page_data.get('width'), page_data.get('height') - }, - ) - ) - else: - log.warning( - f"Skipping page due to missing 'markdown' or 'index'. Data: {page_data}" - ) + documents = self._process_results(ocr_response) - if not documents: - # Case where pages existed but none had valid markdown/index - log.warning( - "OCR response contained pages, but none had valid content/index." - ) - return [ - Document( - page_content="No text content found in valid pages", metadata={} - ) - ] + total_time = time.time() - start_time + log.info( + f"Sync OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents" + ) return documents except Exception as e: - log.error(f"An error occurred during the loading process: {e}") - # Return an empty list or a specific error document on failure - return [Document(page_content=f"Error during processing: {e}", metadata={})] + total_time = time.time() - start_time + log.error( + f"An error occurred during the loading process after {total_time:.2f}s: {e}" + ) + # Return an error document on failure + return [ + Document( + page_content=f"Error during processing: {e}", + metadata={ + "error": "processing_failed", + "file_name": self.file_name, + }, + ) + ] finally: # 5. 
Delete file (attempt even if prior steps failed after upload) if file_id: @@ -223,3 +645,124 @@ def load(self) -> List[Document]: log.error( f"Cleanup error: Could not delete file ID {file_id}. Reason: {del_e}" ) + + async def load_async(self) -> List[Document]: + """ + Asynchronous OCR workflow execution with optimized performance. + + Returns: + A list of Document objects, one for each page processed. + """ + file_id = None + start_time = time.time() + + try: + async with self._get_session() as session: + # 1. Upload file with streaming + file_id = await self._upload_file_async(session) + + # 2. Get signed URL + signed_url = await self._get_signed_url_async(session, file_id) + + # 3. Process OCR + ocr_response = await self._process_ocr_async(session, signed_url) + + # 4. Process results + documents = self._process_results(ocr_response) + + total_time = time.time() - start_time + log.info( + f"Async OCR workflow completed in {total_time:.2f}s, produced {len(documents)} documents" + ) + + return documents + + except Exception as e: + total_time = time.time() - start_time + log.error(f"Async OCR workflow failed after {total_time:.2f}s: {e}") + return [ + Document( + page_content=f"Error during OCR processing: {e}", + metadata={ + "error": "processing_failed", + "file_name": self.file_name, + }, + ) + ] + finally: + # 5. Cleanup - always attempt file deletion + if file_id: + try: + async with self._get_session() as session: + await self._delete_file_async(session, file_id) + except Exception as cleanup_error: + log.error(f"Cleanup failed for file ID {file_id}: {cleanup_error}") + + @staticmethod + async def load_multiple_async( + loaders: List["MistralLoader"], + max_concurrent: int = 5, # Limit concurrent requests + ) -> List[List[Document]]: + """ + Process multiple files concurrently with controlled concurrency. 
+ + Args: + loaders: List of MistralLoader instances + max_concurrent: Maximum number of concurrent requests + + Returns: + List of document lists, one for each loader + """ + if not loaders: + return [] + + log.info( + f"Starting concurrent processing of {len(loaders)} files with max {max_concurrent} concurrent" + ) + start_time = time.time() + + # Use semaphore to control concurrency + semaphore = asyncio.Semaphore(max_concurrent) + + async def process_with_semaphore(loader: "MistralLoader") -> List[Document]: + async with semaphore: + return await loader.load_async() + + # Process all files with controlled concurrency + tasks = [process_with_semaphore(loader) for loader in loaders] + results = await asyncio.gather(*tasks, return_exceptions=True) + + # Handle any exceptions in results + processed_results = [] + for i, result in enumerate(results): + if isinstance(result, Exception): + log.error(f"File {i} failed: {result}") + processed_results.append( + [ + Document( + page_content=f"Error processing file: {result}", + metadata={ + "error": "batch_processing_failed", + "file_index": i, + }, + ) + ] + ) + else: + processed_results.append(result) + + # MONITORING: Log comprehensive batch processing statistics + total_time = time.time() - start_time + total_docs = sum(len(docs) for docs in processed_results) + success_count = sum( + 1 for result in results if not isinstance(result, Exception) + ) + failure_count = len(results) - success_count + + log.info( + f"Batch processing completed in {total_time:.2f}s: " + f"{success_count} files succeeded, {failure_count} files failed, " + f"produced {total_docs} total documents" + ) + + return processed_results diff --git a/backend/open_webui/retrieval/loaders/youtube.py b/backend/open_webui/retrieval/loaders/youtube.py index d908cc8cb50..da17eaef651 100644 --- a/backend/open_webui/retrieval/loaders/youtube.py +++ b/backend/open_webui/retrieval/loaders/youtube.py @@ -1,4 +1,5 @@ import logging +from xml.etree.ElementTree 
import ParseError from typing import Any, Dict, Generator, List, Optional, Sequence, Union from urllib.parse import parse_qs, urlparse @@ -93,15 +94,13 @@ def load(self) -> List[Document]: "http": self.proxy_url, "https": self.proxy_url, } - # Don't log complete URL because it might contain secrets log.debug(f"Using proxy URL: {self.proxy_url[:14]}...") else: youtube_proxies = None + transcript_api = YouTubeTranscriptApi(proxy_config=youtube_proxies) try: - transcript_list = YouTubeTranscriptApi.list_transcripts( - self.video_id, proxies=youtube_proxies - ) + transcript_list = transcript_api.list(self.video_id) except Exception as e: log.exception("Loading YouTube transcript failed") return [] @@ -110,11 +109,37 @@ def load(self) -> List[Document]: for lang in self.language: try: transcript = transcript_list.find_transcript([lang]) + if transcript.is_generated: + log.debug(f"Found generated transcript for language '{lang}'") + try: + transcript = transcript_list.find_manually_created_transcript( + [lang] + ) + log.debug(f"Found manual transcript for language '{lang}'") + except NoTranscriptFound: + log.debug( + f"No manual transcript found for language '{lang}', using generated" + ) + pass + log.debug(f"Found transcript for language '{lang}'") - transcript_pieces: List[Dict[str, Any]] = transcript.fetch() + try: + transcript_pieces: List[Dict[str, Any]] = transcript.fetch() + except ParseError: + log.debug(f"Empty or invalid transcript for language '{lang}'") + continue + + if not transcript_pieces: + log.debug(f"Empty transcript for language '{lang}'") + continue + transcript_text = " ".join( map( - lambda transcript_piece: transcript_piece.text.strip(" "), + lambda transcript_piece: ( + transcript_piece.text.strip(" ") + if hasattr(transcript_piece, "text") + else "" + ), transcript_pieces, ) ) @@ -131,6 +156,11 @@ def load(self) -> List[Document]: log.warning( f"No transcript found for any of the specified languages: {languages_tried}. 
Verify if the video has transcripts, add more languages if needed." ) - raise NoTranscriptFound( - f"No transcript found for any supported language. Verify if the video has transcripts, add more languages if needed." - ) + raise NoTranscriptFound(self.video_id, self.language, list(transcript_list)) + + async def aload(self) -> Generator[Document, None, None]: + """Asynchronously load YouTube transcripts into `Document` objects.""" + import asyncio + + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, self.load) diff --git a/backend/open_webui/retrieval/models/base_reranker.py b/backend/open_webui/retrieval/models/base_reranker.py new file mode 100644 index 00000000000..6be7a5649b8 --- /dev/null +++ b/backend/open_webui/retrieval/models/base_reranker.py @@ -0,0 +1,8 @@ +from abc import ABC, abstractmethod +from typing import Optional, List, Tuple + + +class BaseReranker(ABC): + @abstractmethod + def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]: + pass diff --git a/backend/open_webui/retrieval/models/colbert.py b/backend/open_webui/retrieval/models/colbert.py index 5b7499fd18b..7ec888437a0 100644 --- a/backend/open_webui/retrieval/models/colbert.py +++ b/backend/open_webui/retrieval/models/colbert.py @@ -7,11 +7,13 @@ from open_webui.env import SRC_LOG_LEVELS +from open_webui.retrieval.models.base_reranker import BaseReranker + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) -class ColBERT: +class ColBERT(BaseReranker): def __init__(self, name, **kwargs) -> None: log.info("ColBERT: Loading model", name) self.device = "cuda" if torch.cuda.is_available() else "cpu" diff --git a/backend/open_webui/retrieval/models/external.py b/backend/open_webui/retrieval/models/external.py index 187d66e384e..a9be526b6d1 100644 --- a/backend/open_webui/retrieval/models/external.py +++ b/backend/open_webui/retrieval/models/external.py @@ -1,14 +1,18 @@ import logging import requests from typing import Optional, 
List, Tuple +from urllib.parse import quote + + +from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS +from open_webui.retrieval.models.base_reranker import BaseReranker -from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) -class ExternalReranker: +class ExternalReranker(BaseReranker): def __init__( self, api_key: str, @@ -19,7 +23,9 @@ def __init__( self.url = url self.model = model - def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]: + def predict( + self, sentences: List[Tuple[str, str]], user=None + ) -> Optional[List[float]]: query = sentences[0][0] docs = [i[1] for i in sentences] @@ -39,6 +45,16 @@ def predict(self, sentences: List[Tuple[str, str]]) -> Optional[List[float]]: headers={ "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}", + **( + { + "X-OpenWebUI-User-Name": quote(user.name, safe=" "), + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), }, json=payload, ) diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index a132d720133..69aee29ac2f 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -5,7 +5,10 @@ import requests import hashlib from concurrent.futures import ThreadPoolExecutor +import time +import re +from urllib.parse import quote from huggingface_hub import snapshot_download from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever from langchain_community.retrievers import BM25Retriever @@ -14,10 +17,20 @@ from open_webui.config import VECTOR_DB from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT + from open_webui.models.users import UserModel from open_webui.models.files import Files +from open_webui.models.knowledge import Knowledges + +from 
open_webui.models.chats import Chats +from open_webui.models.notes import Notes from open_webui.retrieval.vector.main import GetResult +from open_webui.utils.access_control import has_access +from open_webui.utils.misc import get_message_list + +from open_webui.retrieval.web.utils import get_web_loader +from open_webui.retrieval.loaders.youtube import YoutubeLoader from open_webui.env import ( @@ -41,6 +54,33 @@ from langchain_core.retrievers import BaseRetriever +def is_youtube_url(url: str) -> bool: + youtube_regex = r"^(https?://)?(www\.)?(youtube\.com|youtu\.be)/.+$" + return re.match(youtube_regex, url) is not None + + +def get_loader(request, url: str): + if is_youtube_url(url): + return YoutubeLoader( + url, + language=request.app.state.config.YOUTUBE_LOADER_LANGUAGE, + proxy_url=request.app.state.config.YOUTUBE_LOADER_PROXY_URL, + ) + else: + return get_web_loader( + url, + verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION, + requests_per_second=request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS, + ) + + +def get_content_from_url(request, url: str) -> str: + loader = get_loader(request, url) + docs = loader.load() + content = " ".join([doc.page_content for doc in docs]) + return content, docs + + class VectorSearchRetriever(BaseRetriever): collection_name: Any embedding_function: Any @@ -116,9 +156,21 @@ def query_doc_with_hybrid_search( reranking_function, k_reranker: int, r: float, + hybrid_bm25_weight: float, ) -> dict: try: + if ( + not collection_result + or not hasattr(collection_result, "documents") + or not collection_result.documents + or len(collection_result.documents) == 0 + or not collection_result.documents[0] + ): + log.warning(f"query_doc_with_hybrid_search:no_docs {collection_name}") + return {"documents": [], "metadatas": [], "distances": []} + log.debug(f"query_doc_with_hybrid_search:doc {collection_name}") + bm25_retriever = BM25Retriever.from_texts( texts=collection_result.documents[0], 
metadatas=collection_result.metadatas[0], @@ -131,9 +183,20 @@ def query_doc_with_hybrid_search( top_k=k, ) - ensemble_retriever = EnsembleRetriever( - retrievers=[bm25_retriever, vector_search_retriever], weights=[0.5, 0.5] - ) + if hybrid_bm25_weight <= 0: + ensemble_retriever = EnsembleRetriever( + retrievers=[vector_search_retriever], weights=[1.0] + ) + elif hybrid_bm25_weight >= 1: + ensemble_retriever = EnsembleRetriever( + retrievers=[bm25_retriever], weights=[1.0] + ) + else: + ensemble_retriever = EnsembleRetriever( + retrievers=[bm25_retriever, vector_search_retriever], + weights=[hybrid_bm25_weight, 1.0 - hybrid_bm25_weight], + ) + compressor = RerankCompressor( embedding_function=embedding_function, top_n=k_reranker, @@ -157,7 +220,11 @@ def query_doc_with_hybrid_search( zip(distances, metadatas, documents), key=lambda x: x[0], reverse=True ) sorted_items = sorted_items[:k] - distances, documents, metadatas = map(list, zip(*sorted_items)) + + if sorted_items: + distances, documents, metadatas = map(list, zip(*sorted_items)) + else: + distances, documents, metadatas = [], [], [] result = { "distances": [distances], @@ -201,6 +268,13 @@ def merge_and_sort_query_results(query_results: list[dict], k: int) -> dict: combined = dict() # To store documents with unique document hashes for data in query_results: + if ( + len(data.get("distances", [])) == 0 + or len(data.get("documents", [])) == 0 + or len(data.get("metadatas", [])) == 0 + ): + continue + distances = data["distances"][0] documents = data["documents"][0] metadatas = data["metadatas"][0] @@ -313,6 +387,7 @@ def query_collection_with_hybrid_search( reranking_function, k_reranker: int, r: float, + hybrid_bm25_weight: float, ) -> dict: results = [] error = False @@ -346,6 +421,7 @@ def process_query(collection_name, query): reranking_function=reranking_function, k_reranker=k_reranker, r=r, + hybrid_bm25_weight=hybrid_bm25_weight, ) return result, None except Exception as e: @@ -386,12 +462,13 @@ def 
get_embedding_function( url, key, embedding_batch_size, + azure_api_version=None, ): if embedding_engine == "": return lambda query, prefix=None, user=None: embedding_function.encode( query, **({"prompt": prefix} if prefix else {}) ).tolist() - elif embedding_engine in ["ollama", "openai"]: + elif embedding_engine in ["ollama", "openai", "azure_openai"]: func = lambda query, prefix=None, user=None: generate_embeddings( engine=embedding_engine, model=embedding_model, @@ -400,19 +477,21 @@ def get_embedding_function( url=url, key=key, user=user, + azure_api_version=azure_api_version, ) def generate_multiple(query, prefix, user, func): if isinstance(query, list): embeddings = [] for i in range(0, len(query), embedding_batch_size): - embeddings.extend( - func( - query[i : i + embedding_batch_size], - prefix=prefix, - user=user, - ) + batch_embeddings = func( + query[i : i + embedding_batch_size], + prefix=prefix, + user=user, ) + + if isinstance(batch_embeddings, list): + embeddings.extend(batch_embeddings) return embeddings else: return func(query, prefix, user) @@ -424,174 +503,289 @@ def generate_multiple(query, prefix, user, func): raise ValueError(f"Unknown embedding engine: {embedding_engine}") -def get_sources_from_files( +def get_reranking_function(reranking_engine, reranking_model, reranking_function): + if reranking_function is None: + return None + if reranking_engine == "external": + return lambda sentences, user=None: reranking_function.predict( + sentences, user=user + ) + else: + return lambda sentences, user=None: reranking_function.predict(sentences) + + +def get_sources_from_items( request, - files, + items, queries, embedding_function, k, reranking_function, k_reranker, r, + hybrid_bm25_weight, hybrid_search, full_context=False, + user: Optional[UserModel] = None, ): log.debug( - f"files: {files} {queries} {embedding_function} {reranking_function} {full_context}" + f"items: {items} {queries} {embedding_function} {reranking_function} {full_context}" ) 
extracted_collections = [] - relevant_contexts = [] + query_results = [] + + for item in items: + query_result = None + collection_names = [] + + if item.get("type") == "text": + # Raw Text + # Used during temporary chat file uploads or web page & youtube attachements + + if item.get("context") == "full": + if item.get("file"): + # if item has file data, use it + query_result = { + "documents": [ + [item.get("file", {}).get("data", {}).get("content")] + ], + "metadatas": [[item.get("file", {}).get("meta", {})]], + } - for file in files: + if query_result is None: + # Fallback + if item.get("collection_name"): + # If item has a collection name, use it + collection_names.append(item.get("collection_name")) + elif item.get("file"): + # If item has file data, use it + query_result = { + "documents": [ + [item.get("file", {}).get("data", {}).get("content")] + ], + "metadatas": [[item.get("file", {}).get("meta", {})]], + } + else: + # Fallback to item content + query_result = { + "documents": [[item.get("content")]], + "metadatas": [ + [{"file_id": item.get("id"), "name": item.get("name")}] + ], + } - context = None - if file.get("docs"): - # BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL - context = { - "documents": [[doc.get("content") for doc in file.get("docs")]], - "metadatas": [[doc.get("metadata") for doc in file.get("docs")]], - } - elif file.get("context") == "full": - # Manual Full Mode Toggle - context = { - "documents": [[file.get("file").get("data", {}).get("content")]], - "metadatas": [[{"file_id": file.get("id"), "name": file.get("name")}]], - } - elif ( - file.get("type") != "web_search" - and request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL - ): - # BYPASS_EMBEDDING_AND_RETRIEVAL - if file.get("type") == "collection": - file_ids = file.get("data", {}).get("file_ids", []) + elif item.get("type") == "note": + # Note Attached + note = Notes.get_note_by_id(item.get("id")) + + if note and ( + user.role == "admin" + or note.user_id == user.id + or 
has_access(user.id, "read", note.access_control) + ): + # User has access to the note + query_result = { + "documents": [[note.data.get("content", {}).get("md", "")]], + "metadatas": [[{"file_id": note.id, "name": note.title}]], + } - documents = [] - metadatas = [] - for file_id in file_ids: - file_object = Files.get_file_by_id(file_id) + elif item.get("type") == "chat": + # Chat Attached + chat = Chats.get_chat_by_id(item.get("id")) + + if chat and (user.role == "admin" or chat.user_id == user.id): + messages_map = chat.chat.get("history", {}).get("messages", {}) + message_id = chat.chat.get("history", {}).get("currentId") + + if messages_map and message_id: + # Reconstruct the message list in order + message_list = get_message_list(messages_map, message_id) + message_history = "\n".join( + [ + f"#### {m.get('role', 'user').capitalize()}\n{m.get('content')}\n" + for m in message_list + ] + ) - if file_object: - documents.append(file_object.data.get("content", "")) - metadatas.append( - { - "file_id": file_id, - "name": file_object.filename, - "source": file_object.filename, - } - ) + # User has access to the chat + query_result = { + "documents": [[message_history]], + "metadatas": [[{"file_id": chat.id, "name": chat.title}]], + } - context = { - "documents": [documents], - "metadatas": [metadatas], + elif item.get("type") == "url": + content, docs = get_content_from_url(request, item.get("url")) + if docs: + query_result = { + "documents": [[content]], + "metadatas": [[{"url": item.get("url"), "name": item.get("url")}]], } - - elif file.get("id"): - file_object = Files.get_file_by_id(file.get("id")) - if file_object: - context = { - "documents": [[file_object.data.get("content", "")]], + elif item.get("type") == "file": + if ( + item.get("context") == "full" + or request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL + ): + if item.get("file", {}).get("data", {}).get("content", ""): + # Manual Full Mode Toggle + # Used from chat file modal, we can assume that 
the file content will be available from item.get("file").get("data", {}).get("content") + query_result = { + "documents": [ + [item.get("file", {}).get("data", {}).get("content", "")] + ], "metadatas": [ [ { - "file_id": file.get("id"), - "name": file_object.filename, - "source": file_object.filename, + "file_id": item.get("id"), + "name": item.get("name"), + **item.get("file") + .get("data", {}) + .get("metadata", {}), } ] ], } - elif file.get("file").get("data"): - context = { - "documents": [[file.get("file").get("data", {}).get("content")]], - "metadatas": [ - [file.get("file").get("data", {}).get("metadata", {})] - ], - } - else: - collection_names = [] - if file.get("type") == "collection": - if file.get("legacy"): - collection_names = file.get("collection_names", []) + elif item.get("id"): + file_object = Files.get_file_by_id(item.get("id")) + if file_object: + query_result = { + "documents": [[file_object.data.get("content", "")]], + "metadatas": [ + [ + { + "file_id": item.get("id"), + "name": file_object.filename, + "source": file_object.filename, + } + ] + ], + } + else: + # Fallback to collection names + if item.get("legacy"): + collection_names.append(f"{item['id']}") else: - collection_names.append(file["id"]) - elif file.get("collection_name"): - collection_names.append(file["collection_name"]) - elif file.get("id"): - if file.get("legacy"): - collection_names.append(f"{file['id']}") + collection_names.append(f"file-{item['id']}") + + elif item.get("type") == "collection": + if ( + item.get("context") == "full" + or request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL + ): + # Manual Full Mode Toggle for Collection + knowledge_base = Knowledges.get_knowledge_by_id(item.get("id")) + + if knowledge_base and ( + user.role == "admin" + or knowledge_base.user_id == user.id + or has_access(user.id, "read", knowledge_base.access_control) + ): + + file_ids = knowledge_base.data.get("file_ids", []) + + documents = [] + metadatas = [] + for file_id in 
file_ids: + file_object = Files.get_file_by_id(file_id) + + if file_object: + documents.append(file_object.data.get("content", "")) + metadatas.append( + { + "file_id": file_id, + "name": file_object.filename, + "source": file_object.filename, + } + ) + + query_result = { + "documents": [documents], + "metadatas": [metadatas], + } + else: + # Fallback to collection names + if item.get("legacy"): + collection_names = item.get("collection_names", []) else: - collection_names.append(f"file-{file['id']}") + collection_names.append(item["id"]) + elif item.get("docs"): + # BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL + query_result = { + "documents": [[doc.get("content") for doc in item.get("docs")]], + "metadatas": [[doc.get("metadata") for doc in item.get("docs")]], + } + elif item.get("collection_name"): + # Direct Collection Name + collection_names.append(item["collection_name"]) + elif item.get("collection_names"): + # Collection Names List + collection_names.extend(item["collection_names"]) + + # If query_result is None + # Fallback to collection names and vector search the collections + if query_result is None and collection_names: collection_names = set(collection_names).difference(extracted_collections) if not collection_names: - log.debug(f"skipping {file} as it has already been extracted") + log.debug(f"skipping {item} as it has already been extracted") continue - if full_context: - try: - context = get_all_items_from_collections(collection_names) - except Exception as e: - log.exception(e) - - else: - try: - context = None - if file.get("type") == "text": - context = file["content"] - else: - if hybrid_search: - try: - context = query_collection_with_hybrid_search( - collection_names=collection_names, - queries=queries, - embedding_function=embedding_function, - k=k, - reranking_function=reranking_function, - k_reranker=k_reranker, - r=r, - ) - except Exception as e: - log.debug( - "Error when using hybrid search, using" - " non hybrid search as fallback." 
- ) - - if (not hybrid_search) or (context is None): - context = query_collection( + try: + if full_context: + query_result = get_all_items_from_collections(collection_names) + else: + query_result = None # Initialize to None + if hybrid_search: + try: + query_result = query_collection_with_hybrid_search( collection_names=collection_names, queries=queries, embedding_function=embedding_function, k=k, + reranking_function=reranking_function, + k_reranker=k_reranker, + r=r, + hybrid_bm25_weight=hybrid_bm25_weight, + ) + except Exception as e: + log.debug( + "Error when using hybrid search, using non hybrid search as fallback." ) - except Exception as e: - log.exception(e) - extracted_collections.extend(collection_names) + # fallback to non-hybrid search + if not hybrid_search and query_result is None: + query_result = query_collection( + collection_names=collection_names, + queries=queries, + embedding_function=embedding_function, + k=k, + ) + except Exception as e: + log.exception(e) - if context: - if "data" in file: - del file["data"] + extracted_collections.extend(collection_names) - relevant_contexts.append({**context, "file": file}) + if query_result: + if "data" in item: + del item["data"] + query_results.append({**query_result, "file": item}) sources = [] - for context in relevant_contexts: + for query_result in query_results: try: - if "documents" in context: - if "metadatas" in context: + if "documents" in query_result: + if "metadatas" in query_result: source = { - "source": context["file"], - "document": context["documents"][0], - "metadata": context["metadatas"][0], + "source": query_result["file"], + "document": query_result["documents"][0], + "metadata": query_result["metadatas"][0], } - if "distances" in context and context["distances"]: - source["distances"] = context["distances"][0] + if "distances" in query_result and query_result["distances"]: + source["distances"] = query_result["distances"][0] sources.append(source) except Exception as e: 
log.exception(e) - return sources @@ -659,7 +853,7 @@ def generate_openai_batch_embeddings( "Authorization": f"Bearer {key}", **( { - "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Name": quote(user.name, safe=" "), "X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Role": user.role, @@ -681,6 +875,60 @@ def generate_openai_batch_embeddings( return None +def generate_azure_openai_batch_embeddings( + model: str, + texts: list[str], + url: str, + key: str = "", + version: str = "", + prefix: str = None, + user: UserModel = None, +) -> Optional[list[list[float]]]: + try: + log.debug( + f"generate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}" + ) + json_data = {"input": texts} + if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str): + json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix + + url = f"{url}/openai/deployments/{model}/embeddings?api-version={version}" + + for _ in range(5): + r = requests.post( + url, + headers={ + "Content-Type": "application/json", + "api-key": key, + **( + { + "X-OpenWebUI-User-Name": quote(user.name, safe=" "), + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + } + if ENABLE_FORWARD_USER_INFO_HEADERS and user + else {} + ), + }, + json=json_data, + ) + if r.status_code == 429: + retry = float(r.headers.get("Retry-After", "1")) + time.sleep(retry) + continue + r.raise_for_status() + data = r.json() + if "data" in data: + return [elem["embedding"] for elem in data["data"]] + else: + raise Exception("Something went wrong :/") + return None + except Exception as e: + log.exception(f"Error generating azure openai batch embeddings: {e}") + return None + + def generate_ollama_batch_embeddings( model: str, texts: list[str], @@ -704,7 +952,7 @@ def generate_ollama_batch_embeddings( "Authorization": f"Bearer {key}", **( { - "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Name": 
quote(user.name, safe=" "), "X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Role": user.role, @@ -745,38 +993,33 @@ def generate_embeddings( text = f"{prefix}{text}" if engine == "ollama": - if isinstance(text, list): - embeddings = generate_ollama_batch_embeddings( - **{ - "model": model, - "texts": text, - "url": url, - "key": key, - "prefix": prefix, - "user": user, - } - ) - else: - embeddings = generate_ollama_batch_embeddings( - **{ - "model": model, - "texts": [text], - "url": url, - "key": key, - "prefix": prefix, - "user": user, - } - ) + embeddings = generate_ollama_batch_embeddings( + **{ + "model": model, + "texts": text if isinstance(text, list) else [text], + "url": url, + "key": key, + "prefix": prefix, + "user": user, + } + ) return embeddings[0] if isinstance(text, str) else embeddings elif engine == "openai": - if isinstance(text, list): - embeddings = generate_openai_batch_embeddings( - model, text, url, key, prefix, user - ) - else: - embeddings = generate_openai_batch_embeddings( - model, [text], url, key, prefix, user - ) + embeddings = generate_openai_batch_embeddings( + model, text if isinstance(text, list) else [text], url, key, prefix, user + ) + return embeddings[0] if isinstance(text, str) else embeddings + elif engine == "azure_openai": + azure_api_version = kwargs.get("azure_api_version", "") + embeddings = generate_azure_openai_batch_embeddings( + model, + text if isinstance(text, list) else [text], + url, + key, + azure_api_version, + prefix, + user, + ) return embeddings[0] if isinstance(text, str) else embeddings @@ -805,8 +1048,9 @@ def compress_documents( ) -> Sequence[Document]: reranking = self.reranking_function is not None + scores = None if reranking: - scores = self.reranking_function.predict( + scores = self.reranking_function( [(query, doc.page_content) for doc in documents] ) else: @@ -818,22 +1062,31 @@ def compress_documents( ) scores = util.cos_sim(query_embedding, 
document_embedding)[0] - docs_with_scores = list( - zip(documents, scores.tolist() if not isinstance(scores, list) else scores) - ) - if self.r_score: - docs_with_scores = [ - (d, s) for d, s in docs_with_scores if s >= self.r_score - ] - - result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True) - final_results = [] - for doc, doc_score in result[: self.top_n]: - metadata = doc.metadata - metadata["score"] = doc_score - doc = Document( - page_content=doc.page_content, - metadata=metadata, + if scores is not None: + docs_with_scores = list( + zip( + documents, + scores.tolist() if not isinstance(scores, list) else scores, + ) + ) + if self.r_score: + docs_with_scores = [ + (d, s) for d, s in docs_with_scores if s >= self.r_score + ] + + result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True) + final_results = [] + for doc, doc_score in result[: self.top_n]: + metadata = doc.metadata + metadata["score"] = doc_score + doc = Document( + page_content=doc.page_content, + metadata=metadata, + ) + final_results.append(doc) + return final_results + else: + log.warning( + "No valid scores found, check your reranking function. Returning original documents." 
) - final_results.append(doc) - return final_results + return documents diff --git a/backend/open_webui/retrieval/vector/dbs/chroma.py b/backend/open_webui/retrieval/vector/dbs/chroma.py index f9adc9c95f3..1fdb064c51f 100755 --- a/backend/open_webui/retrieval/vector/dbs/chroma.py +++ b/backend/open_webui/retrieval/vector/dbs/chroma.py @@ -11,6 +11,8 @@ SearchResult, GetResult, ) +from open_webui.retrieval.vector.utils import process_metadata + from open_webui.config import ( CHROMA_DATA_PATH, CHROMA_HTTP_HOST, @@ -144,7 +146,7 @@ def insert(self, collection_name: str, items: list[VectorItem]): ids = [item["id"] for item in items] documents = [item["text"] for item in items] embeddings = [item["vector"] for item in items] - metadatas = [item["metadata"] for item in items] + metadatas = [process_metadata(item["metadata"]) for item in items] for batch in create_batches( api=self.client, @@ -164,7 +166,7 @@ def upsert(self, collection_name: str, items: list[VectorItem]): ids = [item["id"] for item in items] documents = [item["text"] for item in items] embeddings = [item["vector"] for item in items] - metadatas = [item["metadata"] for item in items] + metadatas = [process_metadata(item["metadata"]) for item in items] collection.upsert( ids=ids, documents=documents, embeddings=embeddings, metadatas=metadatas diff --git a/backend/open_webui/retrieval/vector/dbs/elasticsearch.py b/backend/open_webui/retrieval/vector/dbs/elasticsearch.py index 18a915e381f..6de0d859f8a 100644 --- a/backend/open_webui/retrieval/vector/dbs/elasticsearch.py +++ b/backend/open_webui/retrieval/vector/dbs/elasticsearch.py @@ -2,6 +2,8 @@ from typing import Optional import ssl from elasticsearch.helpers import bulk, scan + +from open_webui.retrieval.vector.utils import process_metadata from open_webui.retrieval.vector.main import ( VectorDBBase, VectorItem, @@ -243,7 +245,7 @@ def insert(self, collection_name: str, items: list[VectorItem]): "collection": collection_name, "vector": item["vector"], 
"text": item["text"], - "metadata": item["metadata"], + "metadata": process_metadata(item["metadata"]), }, } for item in batch @@ -264,7 +266,7 @@ def upsert(self, collection_name: str, items: list[VectorItem]): "collection": collection_name, "vector": item["vector"], "text": item["text"], - "metadata": item["metadata"], + "metadata": process_metadata(item["metadata"]), }, "doc_as_upsert": True, } diff --git a/backend/open_webui/retrieval/vector/dbs/milvus.py b/backend/open_webui/retrieval/vector/dbs/milvus.py index a4bad13d00d..98f8e335f21 100644 --- a/backend/open_webui/retrieval/vector/dbs/milvus.py +++ b/backend/open_webui/retrieval/vector/dbs/milvus.py @@ -1,8 +1,12 @@ from pymilvus import MilvusClient as Client from pymilvus import FieldSchema, DataType +from pymilvus import connections, Collection + import json import logging from typing import Optional + +from open_webui.retrieval.vector.utils import process_metadata from open_webui.retrieval.vector.main import ( VectorDBBase, VectorItem, @@ -18,6 +22,8 @@ MILVUS_HNSW_M, MILVUS_HNSW_EFCONSTRUCTION, MILVUS_IVF_FLAT_NLIST, + MILVUS_DISKANN_MAX_DEGREE, + MILVUS_DISKANN_SEARCH_LIST_SIZE, ) from open_webui.env import SRC_LOG_LEVELS @@ -127,12 +133,18 @@ def _create_collection(self, collection_name: str, dimension: int): elif index_type == "IVF_FLAT": index_creation_params = {"nlist": MILVUS_IVF_FLAT_NLIST} log.info(f"IVF_FLAT params: {index_creation_params}") + elif index_type == "DISKANN": + index_creation_params = { + "max_degree": MILVUS_DISKANN_MAX_DEGREE, + "search_list_size": MILVUS_DISKANN_SEARCH_LIST_SIZE, + } + log.info(f"DISKANN params: {index_creation_params}") elif index_type in ["FLAT", "AUTOINDEX"]: log.info(f"Using {index_type} index with no specific build-time params.") else: log.warning( f"Unsupported MILVUS_INDEX_TYPE: '{index_type}'. " - f"Supported types: HNSW, IVF_FLAT, FLAT, AUTOINDEX. " + f"Supported types: HNSW, IVF_FLAT, DISKANN, FLAT, AUTOINDEX. 
" f"Milvus will use its default for the collection if this type is not directly supported for index creation." ) # For unsupported types, pass the type directly to Milvus; it might handle it or use a default. @@ -185,7 +197,9 @@ def search( ) return self._result_to_search_result(result) - def query(self, collection_name: str, filter: dict, limit: Optional[int] = None): + def query(self, collection_name: str, filter: dict, limit: int = -1): + connections.connect(uri=MILVUS_URI, token=MILVUS_TOKEN, db_name=MILVUS_DB) + # Construct the filter string for querying collection_name = collection_name.replace("-", "_") if not self.has_collection(collection_name): @@ -199,72 +213,36 @@ def query(self, collection_name: str, filter: dict, limit: Optional[int] = None) for key, value in filter.items() ] ) - max_limit = 16383 # The maximum number of records per request - all_results = [] - if limit is None: - # Milvus default limit for query if not specified is 16384, but docs mention iteration. - # Let's set a practical high number if "all" is intended, or handle true pagination. - # For now, if limit is None, we'll fetch in batches up to a very large number. - # This part could be refined based on expected use cases for "get all". - # For this function signature, None implies "as many as possible" up to Milvus limits. - limit = ( - 16384 * 10 - ) # A large number to signify fetching many, will be capped by actual data or max_limit per call. - log.info( - f"Limit not specified for query, fetching up to {limit} results in batches." 
- ) - # Initialize offset and remaining to handle pagination - offset = 0 - remaining = limit + collection = Collection(f"{self.collection_prefix}_{collection_name}") + collection.load() + all_results = [] try: log.info( f"Querying collection {self.collection_prefix}_{collection_name} with filter: '{filter_string}', limit: {limit}" ) - # Loop until there are no more items to fetch or the desired limit is reached - while remaining > 0: - current_fetch = min( - max_limit, remaining if isinstance(remaining, int) else max_limit - ) - log.debug( - f"Querying with offset: {offset}, current_fetch: {current_fetch}" - ) - - results = self.client.query( - collection_name=f"{self.collection_prefix}_{collection_name}", - filter=filter_string, - output_fields=[ - "id", - "data", - "metadata", - ], # Explicitly list needed fields. Vector not usually needed in query. - limit=current_fetch, - offset=offset, - ) - - if not results: - log.debug("No more results from query.") - break - all_results.extend(results) - results_count = len(results) - log.debug(f"Fetched {results_count} results in this batch.") - - if isinstance(remaining, int): - remaining -= results_count - - offset += results_count + iterator = collection.query_iterator( + filter=filter_string, + output_fields=[ + "id", + "data", + "metadata", + ], + limit=limit, # Pass the limit directly; -1 means no limit. + ) - # Break the loop if the results returned are less than the requested fetch count (means end of data) - if results_count < current_fetch: - log.debug( - "Fetched less than requested, assuming end of results for this query." 
- ) + while True: + result = iterator.next() + if not result: + iterator.close() break + all_results += result log.info(f"Total results from query: {len(all_results)}") return self._result_to_get_result([all_results]) + except Exception as e: log.exception( f"Error querying collection {self.collection_prefix}_{collection_name} with filter '{filter_string}' and limit {limit}: {e}" @@ -279,7 +257,7 @@ def get(self, collection_name: str) -> Optional[GetResult]: ) # Using query with a trivial filter to get all items. # This will use the paginated query logic. - return self.query(collection_name=collection_name, filter={}, limit=None) + return self.query(collection_name=collection_name, filter={}, limit=-1) def insert(self, collection_name: str, items: list[VectorItem]): # Insert the items into the collection, if the collection does not exist, it will be created. @@ -311,7 +289,7 @@ def insert(self, collection_name: str, items: list[VectorItem]): "id": item["id"], "vector": item["vector"], "data": {"text": item["text"]}, - "metadata": item["metadata"], + "metadata": process_metadata(item["metadata"]), } for item in items ], @@ -347,7 +325,7 @@ def upsert(self, collection_name: str, items: list[VectorItem]): "id": item["id"], "vector": item["vector"], "data": {"text": item["text"]}, - "metadata": item["metadata"], + "metadata": process_metadata(item["metadata"]), } for item in items ], diff --git a/backend/open_webui/retrieval/vector/dbs/milvus_multitenancy.py b/backend/open_webui/retrieval/vector/dbs/milvus_multitenancy.py new file mode 100644 index 00000000000..5c80d155d35 --- /dev/null +++ b/backend/open_webui/retrieval/vector/dbs/milvus_multitenancy.py @@ -0,0 +1,282 @@ +import logging +from typing import Optional, Tuple, List, Dict, Any + +from open_webui.config import ( + MILVUS_URI, + MILVUS_TOKEN, + MILVUS_DB, + MILVUS_COLLECTION_PREFIX, + MILVUS_INDEX_TYPE, + MILVUS_METRIC_TYPE, + MILVUS_HNSW_M, + MILVUS_HNSW_EFCONSTRUCTION, + MILVUS_IVF_FLAT_NLIST, +) +from 
open_webui.env import SRC_LOG_LEVELS +from open_webui.retrieval.vector.main import ( + GetResult, + SearchResult, + VectorDBBase, + VectorItem, +) +from pymilvus import ( + connections, + utility, + Collection, + CollectionSchema, + FieldSchema, + DataType, +) + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + +RESOURCE_ID_FIELD = "resource_id" + + +class MilvusClient(VectorDBBase): + def __init__(self): + # Milvus collection names can only contain numbers, letters, and underscores. + self.collection_prefix = MILVUS_COLLECTION_PREFIX.replace("-", "_") + connections.connect( + alias="default", + uri=MILVUS_URI, + token=MILVUS_TOKEN, + db_name=MILVUS_DB, + ) + + # Main collection types for multi-tenancy + self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories" + self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge" + self.FILE_COLLECTION = f"{self.collection_prefix}_files" + self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web_search" + self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash_based" + self.shared_collections = [ + self.MEMORY_COLLECTION, + self.KNOWLEDGE_COLLECTION, + self.FILE_COLLECTION, + self.WEB_SEARCH_COLLECTION, + self.HASH_BASED_COLLECTION, + ] + + def _get_collection_and_resource_id(self, collection_name: str) -> Tuple[str, str]: + """ + Maps the traditional collection name to multi-tenant collection and resource ID. + + WARNING: This mapping relies on current Open WebUI naming conventions for + collection names. If Open WebUI changes how it generates collection names + (e.g., "user-memory-" prefix, "file-" prefix, web search patterns, or hash + formats), this mapping will break and route data to incorrect collections. + POTENTIALLY CAUSING HUGE DATA CORRUPTION, DATA CONSISTENCY ISSUES AND INCORRECT + DATA MAPPING INSIDE THE DATABASE. 
+ """ + resource_id = collection_name + + if collection_name.startswith("user-memory-"): + return self.MEMORY_COLLECTION, resource_id + elif collection_name.startswith("file-"): + return self.FILE_COLLECTION, resource_id + elif collection_name.startswith("web-search-"): + return self.WEB_SEARCH_COLLECTION, resource_id + elif len(collection_name) == 63 and all( + c in "0123456789abcdef" for c in collection_name + ): + return self.HASH_BASED_COLLECTION, resource_id + else: + return self.KNOWLEDGE_COLLECTION, resource_id + + def _create_shared_collection(self, mt_collection_name: str, dimension: int): + fields = [ + FieldSchema( + name="id", + dtype=DataType.VARCHAR, + is_primary=True, + auto_id=False, + max_length=36, + ), + FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension), + FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535), + FieldSchema(name="metadata", dtype=DataType.JSON), + FieldSchema(name=RESOURCE_ID_FIELD, dtype=DataType.VARCHAR, max_length=255), + ] + schema = CollectionSchema(fields, "Shared collection for multi-tenancy") + collection = Collection(mt_collection_name, schema) + + index_params = { + "metric_type": MILVUS_METRIC_TYPE, + "index_type": MILVUS_INDEX_TYPE, + "params": {}, + } + if MILVUS_INDEX_TYPE == "HNSW": + index_params["params"] = { + "M": MILVUS_HNSW_M, + "efConstruction": MILVUS_HNSW_EFCONSTRUCTION, + } + elif MILVUS_INDEX_TYPE == "IVF_FLAT": + index_params["params"] = {"nlist": MILVUS_IVF_FLAT_NLIST} + + collection.create_index("vector", index_params) + collection.create_index(RESOURCE_ID_FIELD) + log.info(f"Created shared collection: {mt_collection_name}") + return collection + + def _ensure_collection(self, mt_collection_name: str, dimension: int): + if not utility.has_collection(mt_collection_name): + self._create_shared_collection(mt_collection_name, dimension) + + def has_collection(self, collection_name: str) -> bool: + mt_collection, resource_id = self._get_collection_and_resource_id( + 
collection_name + ) + if not utility.has_collection(mt_collection): + return False + + collection = Collection(mt_collection) + collection.load() + res = collection.query(expr=f"{RESOURCE_ID_FIELD} == '{resource_id}'", limit=1) + return len(res) > 0 + + def upsert(self, collection_name: str, items: List[VectorItem]): + if not items: + return + mt_collection, resource_id = self._get_collection_and_resource_id( + collection_name + ) + dimension = len(items[0]["vector"]) + self._ensure_collection(mt_collection, dimension) + collection = Collection(mt_collection) + + entities = [ + { + "id": item["id"], + "vector": item["vector"], + "text": item["text"], + "metadata": item["metadata"], + RESOURCE_ID_FIELD: resource_id, + } + for item in items + ] + collection.insert(entities) + collection.flush() + + def search( + self, collection_name: str, vectors: List[List[float]], limit: int + ) -> Optional[SearchResult]: + if not vectors: + return None + + mt_collection, resource_id = self._get_collection_and_resource_id( + collection_name + ) + if not utility.has_collection(mt_collection): + return None + + collection = Collection(mt_collection) + collection.load() + + search_params = {"metric_type": MILVUS_METRIC_TYPE, "params": {}} + results = collection.search( + data=vectors, + anns_field="vector", + param=search_params, + limit=limit, + expr=f"{RESOURCE_ID_FIELD} == '{resource_id}'", + output_fields=["id", "text", "metadata"], + ) + + ids, documents, metadatas, distances = [], [], [], [] + for hits in results: + batch_ids, batch_docs, batch_metadatas, batch_dists = [], [], [], [] + for hit in hits: + batch_ids.append(hit.entity.get("id")) + batch_docs.append(hit.entity.get("text")) + batch_metadatas.append(hit.entity.get("metadata")) + batch_dists.append(hit.distance) + ids.append(batch_ids) + documents.append(batch_docs) + metadatas.append(batch_metadatas) + distances.append(batch_dists) + + return SearchResult( + ids=ids, documents=documents, metadatas=metadatas, 
distances=distances + ) + + def delete( + self, + collection_name: str, + ids: Optional[List[str]] = None, + filter: Optional[Dict[str, Any]] = None, + ): + mt_collection, resource_id = self._get_collection_and_resource_id( + collection_name + ) + if not utility.has_collection(mt_collection): + return + + collection = Collection(mt_collection) + + # Build expression + expr = [f"{RESOURCE_ID_FIELD} == '{resource_id}'"] + if ids: + # Milvus expects a string list for 'in' operator + id_list_str = ", ".join([f"'{id_val}'" for id_val in ids]) + expr.append(f"id in [{id_list_str}]") + + if filter: + for key, value in filter.items(): + expr.append(f"metadata['{key}'] == '{value}'") + + collection.delete(" and ".join(expr)) + + def reset(self): + for collection_name in self.shared_collections: + if utility.has_collection(collection_name): + utility.drop_collection(collection_name) + + def delete_collection(self, collection_name: str): + mt_collection, resource_id = self._get_collection_and_resource_id( + collection_name + ) + if not utility.has_collection(mt_collection): + return + + collection = Collection(mt_collection) + collection.delete(f"{RESOURCE_ID_FIELD} == '{resource_id}'") + + def query( + self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None + ) -> Optional[GetResult]: + mt_collection, resource_id = self._get_collection_and_resource_id( + collection_name + ) + if not utility.has_collection(mt_collection): + return None + + collection = Collection(mt_collection) + collection.load() + + expr = [f"{RESOURCE_ID_FIELD} == '{resource_id}'"] + if filter: + for key, value in filter.items(): + if isinstance(value, str): + expr.append(f"metadata['{key}'] == '{value}'") + else: + expr.append(f"metadata['{key}'] == {value}") + + results = collection.query( + expr=" and ".join(expr), + output_fields=["id", "text", "metadata"], + limit=limit, + ) + + ids = [res["id"] for res in results] + documents = [res["text"] for res in results] + metadatas = 
[res["metadata"] for res in results] + + return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas]) + + def get(self, collection_name: str) -> Optional[GetResult]: + return self.query(collection_name, filter={}, limit=None) + + def insert(self, collection_name: str, items: List[VectorItem]): + return self.upsert(collection_name, items) diff --git a/backend/open_webui/retrieval/vector/dbs/opensearch.py b/backend/open_webui/retrieval/vector/dbs/opensearch.py index 60ef2d906cf..2e946710e24 100644 --- a/backend/open_webui/retrieval/vector/dbs/opensearch.py +++ b/backend/open_webui/retrieval/vector/dbs/opensearch.py @@ -2,6 +2,7 @@ from opensearchpy.helpers import bulk from typing import Optional +from open_webui.retrieval.vector.utils import process_metadata from open_webui.retrieval.vector.main import ( VectorDBBase, VectorItem, @@ -157,10 +158,10 @@ def query( for field, value in filter.items(): query_body["query"]["bool"]["filter"].append( - {"match": {"metadata." + str(field): value}} + {"term": {"metadata." 
+ str(field) + ".keyword": value}} ) - size = limit if limit else 10 + size = limit if limit else 10000 try: result = self.client.search( @@ -200,12 +201,13 @@ def insert(self, collection_name: str, items: list[VectorItem]): "_source": { "vector": item["vector"], "text": item["text"], - "metadata": item["metadata"], + "metadata": process_metadata(item["metadata"]), }, } for item in batch ] bulk(self.client, actions) + self.client.indices.refresh(self._get_index_name(collection_name)) def upsert(self, collection_name: str, items: list[VectorItem]): self._create_index_if_not_exists( @@ -221,13 +223,14 @@ def upsert(self, collection_name: str, items: list[VectorItem]): "doc": { "vector": item["vector"], "text": item["text"], - "metadata": item["metadata"], + "metadata": process_metadata(item["metadata"]), }, "doc_as_upsert": True, } for item in batch ] bulk(self.client, actions) + self.client.indices.refresh(self._get_index_name(collection_name)) def delete( self, @@ -251,11 +254,12 @@ def delete( } for field, value in filter.items(): query_body["query"]["bool"]["filter"].append( - {"match": {"metadata." + str(field): value}} + {"term": {"metadata." 
+ str(field) + ".keyword": value}} ) self.client.delete_by_query( index=self._get_index_name(collection_name), body=query_body ) + self.client.indices.refresh(self._get_index_name(collection_name)) def reset(self): indices = self.client.indices.get(index=f"{self.index_prefix}_*") diff --git a/backend/open_webui/retrieval/vector/dbs/oracle23ai.py b/backend/open_webui/retrieval/vector/dbs/oracle23ai.py new file mode 100644 index 00000000000..b714588bdc2 --- /dev/null +++ b/backend/open_webui/retrieval/vector/dbs/oracle23ai.py @@ -0,0 +1,943 @@ +""" +Oracle 23ai Vector Database Client - Fixed Version + +# .env +VECTOR_DB = "oracle23ai" + +## DBCS or oracle 23ai free +ORACLE_DB_USE_WALLET = false +ORACLE_DB_USER = "DEMOUSER" +ORACLE_DB_PASSWORD = "Welcome123456" +ORACLE_DB_DSN = "localhost:1521/FREEPDB1" + +## ADW or ATP +# ORACLE_DB_USE_WALLET = true +# ORACLE_DB_USER = "DEMOUSER" +# ORACLE_DB_PASSWORD = "Welcome123456" +# ORACLE_DB_DSN = "medium" +# ORACLE_DB_DSN = "(description= (retry_count=3)(retry_delay=3)(address=(protocol=tcps)(port=1522)(host=xx.oraclecloud.com))(connect_data=(service_name=yy.adb.oraclecloud.com))(security=(ssl_server_dn_match=no)))" +# ORACLE_WALLET_DIR = "/home/opc/adb_wallet" +# ORACLE_WALLET_PASSWORD = "Welcome1" + +ORACLE_VECTOR_LENGTH = 768 + +ORACLE_DB_POOL_MIN = 2 +ORACLE_DB_POOL_MAX = 10 +ORACLE_DB_POOL_INCREMENT = 1 +""" + +from typing import Optional, List, Dict, Any, Union +from decimal import Decimal +import logging +import os +import threading +import time +import json +import array +import oracledb + +from open_webui.retrieval.vector.main import ( + VectorDBBase, + VectorItem, + SearchResult, + GetResult, +) + +from open_webui.config import ( + ORACLE_DB_USE_WALLET, + ORACLE_DB_USER, + ORACLE_DB_PASSWORD, + ORACLE_DB_DSN, + ORACLE_WALLET_DIR, + ORACLE_WALLET_PASSWORD, + ORACLE_VECTOR_LENGTH, + ORACLE_DB_POOL_MIN, + ORACLE_DB_POOL_MAX, + ORACLE_DB_POOL_INCREMENT, +) +from open_webui.env import SRC_LOG_LEVELS + +log = 
logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
+
+class Oracle23aiClient(VectorDBBase):
+    """
+    Oracle Vector Database Client for vector similarity search using Oracle Database 23ai.
+
+    This client provides an interface to store, retrieve, and search vector embeddings
+    in an Oracle database. It uses connection pooling for database access and stores
+    every collection in a single ``document_chunk`` table, discriminated by the
+    ``collection_name`` column.
+
+    Attributes:
+        pool: Connection pool for Oracle database connections
+    """
+
+    def __init__(self) -> None:
+        """
+        Initialize the Oracle23aiClient with a connection pool.
+
+        Creates the connection pool (wallet-based or basic authentication, depending on
+        ``ORACLE_DB_USE_WALLET``) and creates the table and indexes if they do not exist.
+
+        Raises:
+            Exception: If pool creation or database initialization fails. A pool that
+                was already created is closed again before the error propagates.
+        """
+        self.pool = None
+
+        try:
+            log.info(
+                f"Creating Connection Pool [{ORACLE_DB_USER}:**@{ORACLE_DB_DSN}]"
+            )
+            self._create_pool()
+
+            with self.get_connection() as connection:
+                log.info(f"Connection version: {connection.version}")
+                self._initialize_database(connection)
+
+            log.info("Oracle Vector Search initialization complete.")
+        except Exception as e:
+            log.exception(f"Error during Oracle Vector Search initialization: {e}")
+            # Do not leak a half-initialized pool when schema setup fails.
+            if self.pool is not None:
+                try:
+                    self.pool.close(force=True)
+                except Exception as close_error:
+                    log.warning(f"Error closing pool after failed init: {close_error}")
+                self.pool = None
+            raise
+
+    def _create_pool(self) -> None:
+        """
+        Create the connection pool matching the configured database type.
+
+        Shared by ``__init__`` and ``_reconnect_pool`` so both always build the
+        pool the same way.
+        """
+        if ORACLE_DB_USE_WALLET:
+            self._create_adb_pool()
+        else:  # DBCS
+            self._create_dbcs_pool()
+
+    def _create_adb_pool(self) -> None:
+        """
+        Create connection pool for Oracle Autonomous Database.
+
+        Uses wallet-based authentication.
+        """
+        self.pool = oracledb.create_pool(
+            user=ORACLE_DB_USER,
+            password=ORACLE_DB_PASSWORD,
+            dsn=ORACLE_DB_DSN,
+            min=ORACLE_DB_POOL_MIN,
+            max=ORACLE_DB_POOL_MAX,
+            increment=ORACLE_DB_POOL_INCREMENT,
+            config_dir=ORACLE_WALLET_DIR,
+            wallet_location=ORACLE_WALLET_DIR,
+            wallet_password=ORACLE_WALLET_PASSWORD,
+        )
+        log.info("Created ADB connection pool with wallet authentication.")
+
+    def _create_dbcs_pool(self) -> None:
+        """
+        Create connection pool for Oracle Database Cloud Service.
+
+        Uses basic authentication without wallet.
+        """
+        self.pool = oracledb.create_pool(
+            user=ORACLE_DB_USER,
+            password=ORACLE_DB_PASSWORD,
+            dsn=ORACLE_DB_DSN,
+            min=ORACLE_DB_POOL_MIN,
+            max=ORACLE_DB_POOL_MAX,
+            increment=ORACLE_DB_POOL_INCREMENT,
+        )
+        log.info("Created DB connection pool with basic authentication.")
+
+    def get_connection(self):
+        """
+        Acquire a connection from the connection pool with retry logic.
+
+        Retries up to three times with exponential backoff (1s, 2s) on
+        ``oracledb.DatabaseError``; the last failure is re-raised.
+
+        Returns:
+            connection: A database connection with output type handler configured
+
+        Raises:
+            oracledb.DatabaseError: If every attempt to acquire a connection fails
+        """
+        max_retries = 3
+        for attempt in range(max_retries):
+            try:
+                connection = self.pool.acquire()
+                connection.outputtypehandler = self._output_type_handler
+                return connection
+            except oracledb.DatabaseError as e:
+                # e.args is not guaranteed to hold exactly one element, so do not
+                # tuple-unpack it: that would mask the real error with a ValueError.
+                error_obj = e.args[0] if e.args else e
+                message = getattr(error_obj, "message", str(error_obj))
+                log.exception(f"Connection attempt {attempt + 1} failed: {message}")
+
+                if attempt < max_retries - 1:
+                    wait_time = 2**attempt
+                    log.info(f"Retrying in {wait_time} seconds...")
+                    time.sleep(wait_time)
+                else:
+                    raise
+
+    def start_health_monitor(self, interval_seconds: int = 60):
+        """
+        Start a background thread to periodically check the health of the connection pool.
+
+        The thread is a daemon thread and loops forever; each iteration calls
+        ``ensure_connection`` and then sleeps for ``interval_seconds``.
+
+        Args:
+            interval_seconds (int): Number of seconds between health checks
+        """
+
+        def _monitor():
+            while True:
+                try:
+                    log.info("[HealthCheck] Running periodic DB health check...")
+                    self.ensure_connection()
+                    log.info("[HealthCheck] Connection is healthy.")
+                except Exception as e:
+                    log.exception(f"[HealthCheck] Connection health check failed: {e}")
+                time.sleep(interval_seconds)
+
+        thread = threading.Thread(target=_monitor, daemon=True)
+        thread.start()
+        log.info(f"Started DB health monitor every {interval_seconds} seconds.")
+
+    def _reconnect_pool(self):
+        """
+        Attempt to reinitialize the connection pool if it's been closed or broken.
+
+        Raises:
+            Exception: If the new pool cannot be created
+        """
+        try:
+            log.info("Attempting to reinitialize the Oracle connection pool...")
+
+            # Close existing pool if it exists. force=True because a broken pool
+            # typically still has connections checked out, and a non-forced close
+            # fails in that situation.
+            if self.pool:
+                try:
+                    self.pool.close(force=True)
+                except Exception as close_error:
+                    log.warning(f"Error closing existing pool: {close_error}")
+
+            # Re-create the appropriate connection pool based on DB type
+            self._create_pool()
+
+            log.info("Connection pool reinitialized.")
+        except Exception as e:
+            log.exception(f"Failed to reinitialize the connection pool: {e}")
+            raise
+
+    def ensure_connection(self):
+        """
+        Ensure the database connection is alive, reconnecting pool if needed.
+
+        Runs ``SELECT 1 FROM dual`` on a pooled connection; on any failure the
+        pool is rebuilt via ``_reconnect_pool``.
+        """
+        try:
+            with self.get_connection() as connection:
+                with connection.cursor() as cursor:
+                    cursor.execute("SELECT 1 FROM dual")
+        except Exception as e:
+            log.exception(
+                f"Connection check failed: {e}, attempting to reconnect pool..."
+            )
+            self._reconnect_pool()
+
+    def _output_type_handler(self, cursor, metadata):
+        """
+        Handle Oracle vector type conversion.
+
+        Args:
+            cursor: Oracle database cursor
+            metadata: Metadata for the column
+
+        Returns:
+            A variable converting VECTOR columns to ``list``; ``None`` (default
+            handling) for every other column type
+        """
+        if metadata.type_code is oracledb.DB_TYPE_VECTOR:
+            return cursor.var(
+                metadata.type_code, arraysize=cursor.arraysize, outconverter=list
+            )
+
+    def _initialize_database(self, connection) -> None:
+        """
+        Initialize database schema, tables and indexes.
+
+        Creates the document_chunk table and necessary indexes if they don't exist.
+        Each statement is wrapped in PL/SQL that ignores ORA-00955
+        ("name is already used by an existing object").
+
+        Args:
+            connection: Oracle database connection
+
+        Raises:
+            Exception: If schema initialization fails
+        """
+        with connection.cursor() as cursor:
+            try:
+                log.info("Creating Table document_chunk")
+                cursor.execute(
+                    """
+                    BEGIN
+                        EXECUTE IMMEDIATE '
+                            CREATE TABLE IF NOT EXISTS document_chunk (
+                                id VARCHAR2(255) PRIMARY KEY,
+                                collection_name VARCHAR2(255) NOT NULL,
+                                text CLOB,
+                                vmetadata JSON,
+                                vector vector(*, float32)
+                            )
+                        ';
+                    EXCEPTION
+                        WHEN OTHERS THEN
+                            IF SQLCODE != -955 THEN
+                                RAISE;
+                            END IF;
+                    END;
+                    """
+                )
+
+                log.info("Creating Index document_chunk_collection_name_idx")
+                cursor.execute(
+                    """
+                    BEGIN
+                        EXECUTE IMMEDIATE '
+                            CREATE INDEX IF NOT EXISTS document_chunk_collection_name_idx
+                            ON document_chunk (collection_name)
+                        ';
+                    EXCEPTION
+                        WHEN OTHERS THEN
+                            IF SQLCODE != -955 THEN
+                                RAISE;
+                            END IF;
+                    END;
+                    """
+                )
+
+                log.info("Creating VECTOR INDEX document_chunk_vector_ivf_idx")
+                cursor.execute(
+                    """
+                    BEGIN
+                        EXECUTE IMMEDIATE '
+                            CREATE VECTOR INDEX IF NOT EXISTS document_chunk_vector_ivf_idx
+                            ON document_chunk(vector)
+                            ORGANIZATION NEIGHBOR PARTITIONS
+                            DISTANCE COSINE
+                            WITH TARGET ACCURACY 95
+                            PARAMETERS (TYPE IVF, NEIGHBOR PARTITIONS 100)
+                        ';
+                    EXCEPTION
+                        WHEN OTHERS THEN
+                            IF SQLCODE != -955 THEN
+                                RAISE;
+                            END IF;
+                    END;
+                    """
+                )
+
+                connection.commit()
+                log.info("Database initialization completed successfully.")
+
+            except Exception as e:
+                connection.rollback()
+                log.exception(f"Error during database initialization: {e}")
+                raise
+
+    def check_vector_length(self) -> None:
+        """
+        Check vector length compatibility (placeholder).
+
+        This method would check if the configured vector length matches the database schema.
+        Currently implemented as a placeholder.
+        """
+        pass
+
+    def _vector_to_blob(self, vector: List[float]) -> "array.array":
+        """
+        Convert a vector to the bind value used for the VECTOR column.
+
+        Args:
+            vector (List[float]): The vector to convert
+
+        Returns:
+            array.array: The vector as a typecode ``"f"`` (32-bit float) array.
+                Despite the method name this is not a ``bytes``/BLOB value.
+        """
+        return array.array("f", vector)
+
+    def adjust_vector_length(self, vector: List[float]) -> List[float]:
+        """
+        Adjust vector to the expected length if needed.
+
+        Currently a no-op: the vector is returned unchanged.
+
+        Args:
+            vector (List[float]): The vector to adjust
+
+        Returns:
+            List[float]: The adjusted vector
+        """
+        return vector
+
+    def _decimal_handler(self, obj):
+        """
+        Handle Decimal objects for JSON serialization.
+
+        Args:
+            obj: Object to serialize
+
+        Returns:
+            float: Converted decimal value
+
+        Raises:
+            TypeError: If object is not JSON serializable
+        """
+        if isinstance(obj, Decimal):
+            return float(obj)
+        raise TypeError(f"{obj} is not JSON serializable")
+
+    def _metadata_to_json(self, metadata: Dict) -> str:
+        """
+        Convert metadata dictionary to JSON string.
+
+        Args:
+            metadata (Dict): Metadata dictionary
+
+        Returns:
+            str: JSON representation of metadata ("{}" for empty/None metadata)
+        """
+        return json.dumps(metadata, default=self._decimal_handler) if metadata else "{}"
+
+    def _json_to_metadata(self, json_str: str) -> Dict:
+        """
+        Convert JSON string to metadata dictionary.
+
+        Args:
+            json_str (str): JSON string
+
+        Returns:
+            Dict: Metadata dictionary (empty for empty/None input)
+        """
+        return json.loads(json_str) if json_str else {}
+
+    @staticmethod
+    def _metadata_json_path(key: str) -> str:
+        """
+        Build the SQL/JSON path for a metadata filter key.
+
+        The path has to be embedded in the SQL text as a literal (it cannot be a
+        bind variable), so the key is restricted to dot-separated identifiers
+        made of letters, digits and underscores. Anything else is rejected
+        instead of being interpolated into the statement.
+
+        Args:
+            key (str): Metadata field name, e.g. "source" or "file.id"
+
+        Returns:
+            str: The path expression, e.g. "$.source"
+
+        Raises:
+            ValueError: If the key contains characters outside the whitelist
+        """
+        import re
+
+        if not isinstance(key, str) or not re.fullmatch(
+            r"[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_]*)*", key
+        ):
+            raise ValueError(f"Invalid metadata filter key: {key!r}")
+        return f"$.{key}"
+
+    def _append_metadata_filters(
+        self, query: str, params: Dict, filter: Optional[Dict]
+    ) -> str:
+        """
+        Append one ``JSON_VALUE(...) = :value_N`` predicate per filter entry.
+
+        Filter values are always passed as bind variables (stringified); only the
+        validated JSON path is placed in the SQL text.
+
+        Args:
+            query (str): SQL text that already ends in a WHERE clause
+            params (Dict): Bind dictionary, extended in place
+            filter (Optional[Dict]): Metadata filters; None/empty adds nothing
+
+        Returns:
+            str: The extended SQL text
+
+        Raises:
+            ValueError: If a filter key is not a valid metadata field name
+        """
+        for i, (key, value) in enumerate((filter or {}).items()):
+            param_name = f"value_{i}"
+            json_path = self._metadata_json_path(key)
+            query += f" AND JSON_VALUE(vmetadata, '{json_path}' RETURNING VARCHAR2(4096)) = :{param_name}"
+            params[param_name] = str(value)
+        return query
+
+    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
+        """
+        Insert vector items into the database.
+
+        All items are inserted in one transaction; on any error the transaction
+        is rolled back and the error re-raised.
+
+        Args:
+            collection_name (str): Name of the collection
+            items (List[VectorItem]): List of vector items to insert
+
+        Raises:
+            Exception: If insertion fails
+
+        Example:
+            >>> client = Oracle23aiClient()
+            >>> items = [
+            ...     {"id": "1", "text": "Sample text", "vector": [0.1, 0.2, ...], "metadata": {"source": "doc1"}},
+            ...     {"id": "2", "text": "Another text", "vector": [0.3, 0.4, ...], "metadata": {"source": "doc2"}}
+            ... ]
+            >>> client.insert("my_collection", items)
+        """
+        log.info(f"Inserting {len(items)} items into collection '{collection_name}'.")
+
+        with self.get_connection() as connection:
+            try:
+                with connection.cursor() as cursor:
+                    for item in items:
+                        vector_blob = self._vector_to_blob(item["vector"])
+                        metadata_json = self._metadata_to_json(item["metadata"])
+
+                        cursor.execute(
+                            """
+                            INSERT INTO document_chunk
+                            (id, collection_name, text, vmetadata, vector)
+                            VALUES (:id, :collection_name, :text, :metadata, :vector)
+                            """,
+                            {
+                                "id": item["id"],
+                                "collection_name": collection_name,
+                                "text": item["text"],
+                                "metadata": metadata_json,
+                                "vector": vector_blob,
+                            },
+                        )
+
+                connection.commit()
+                log.info(
+                    f"Successfully inserted {len(items)} items into collection '{collection_name}'."
+                )
+
+            except Exception as e:
+                connection.rollback()
+                log.exception(f"Error during insert: {e}")
+                raise
+
+    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
+        """
+        Update or insert vector items into the database.
+
+        If an item with the same ID exists, it will be updated;
+        otherwise, it will be inserted.
+
+        Args:
+            collection_name (str): Name of the collection
+            items (List[VectorItem]): List of vector items to upsert
+
+        Raises:
+            Exception: If upsert operation fails
+
+        Example:
+            >>> client = Oracle23aiClient()
+            >>> items = [
+            ...     {"id": "1", "text": "Updated text", "vector": [0.1, 0.2, ...], "metadata": {"source": "doc1"}},
+            ...     {"id": "3", "text": "New item", "vector": [0.5, 0.6, ...], "metadata": {"source": "doc3"}}
+            ... ]
+            >>> client.upsert("my_collection", items)
+        """
+        log.info(f"Upserting {len(items)} items into collection '{collection_name}'.")
+
+        with self.get_connection() as connection:
+            try:
+                with connection.cursor() as cursor:
+                    for item in items:
+                        vector_blob = self._vector_to_blob(item["vector"])
+                        metadata_json = self._metadata_to_json(item["metadata"])
+
+                        # Separate bind names for the UPDATE and INSERT branches:
+                        # the statement binds each placeholder by name.
+                        cursor.execute(
+                            """
+                            MERGE INTO document_chunk d
+                            USING (SELECT :merge_id as id FROM dual) s
+                            ON (d.id = s.id)
+                            WHEN MATCHED THEN
+                                UPDATE SET
+                                    collection_name = :upd_collection_name,
+                                    text = :upd_text,
+                                    vmetadata = :upd_metadata,
+                                    vector = :upd_vector
+                            WHEN NOT MATCHED THEN
+                                INSERT (id, collection_name, text, vmetadata, vector)
+                                VALUES (:ins_id, :ins_collection_name, :ins_text, :ins_metadata, :ins_vector)
+                            """,
+                            {
+                                "merge_id": item["id"],
+                                "upd_collection_name": collection_name,
+                                "upd_text": item["text"],
+                                "upd_metadata": metadata_json,
+                                "upd_vector": vector_blob,
+                                "ins_id": item["id"],
+                                "ins_collection_name": collection_name,
+                                "ins_text": item["text"],
+                                "ins_metadata": metadata_json,
+                                "ins_vector": vector_blob,
+                            },
+                        )
+
+                connection.commit()
+                log.info(
+                    f"Successfully upserted {len(items)} items into collection '{collection_name}'."
+                )
+
+            except Exception as e:
+                connection.rollback()
+                log.exception(f"Error during upsert: {e}")
+                raise
+
+    def search(
+        self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
+    ) -> Optional[SearchResult]:
+        """
+        Search for similar vectors in the database.
+
+        Performs vector similarity search using cosine distance, one SQL query
+        per query vector.
+
+        Args:
+            collection_name (str): Name of the collection to search
+            vectors (List[List[Union[float, int]]]): Query vectors to find similar items for
+            limit (int): Maximum number of results to return per query
+
+        Returns:
+            Optional[SearchResult]: Search results containing ids, distances, documents,
+                and metadata; None if no vectors were given or an error occurred
+
+        Example:
+            >>> client = Oracle23aiClient()
+            >>> query_vector = [0.1, 0.2, 0.3, ...]  # Must match VECTOR_LENGTH
+            >>> results = client.search("my_collection", [query_vector], limit=5)
+            >>> if results:
+            ...     log.info(f"Found {len(results.ids[0])} matches")
+            ...     for i, (id, dist) in enumerate(zip(results.ids[0], results.distances[0])):
+            ...         log.info(f"Match {i+1}: id={id}, distance={dist}")
+        """
+        log.info(
+            f"Searching items from collection '{collection_name}' with limit {limit}."
+        )
+
+        try:
+            if not vectors:
+                log.warning("No vectors provided for search.")
+                return None
+
+            num_queries = len(vectors)
+
+            ids = [[] for _ in range(num_queries)]
+            distances = [[] for _ in range(num_queries)]
+            documents = [[] for _ in range(num_queries)]
+            metadatas = [[] for _ in range(num_queries)]
+
+            with self.get_connection() as connection:
+                with connection.cursor() as cursor:
+                    for qid, vector in enumerate(vectors):
+                        vector_blob = self._vector_to_blob(vector)
+
+                        cursor.execute(
+                            """
+                            SELECT dc.id, dc.text,
+                                JSON_SERIALIZE(dc.vmetadata RETURNING VARCHAR2(4096)) as vmetadata,
+                                VECTOR_DISTANCE(dc.vector, :query_vector, COSINE) as distance
+                            FROM document_chunk dc
+                            WHERE dc.collection_name = :collection_name
+                            ORDER BY VECTOR_DISTANCE(dc.vector, :query_vector, COSINE)
+                            FETCH APPROX FIRST :limit ROWS ONLY
+                            """,
+                            {
+                                "query_vector": vector_blob,
+                                "collection_name": collection_name,
+                                "limit": limit,
+                            },
+                        )
+
+                        results = cursor.fetchall()
+
+                        for row in results:
+                            ids[qid].append(row[0])
+                            documents[qid].append(
+                                row[1].read()
+                                if isinstance(row[1], oracledb.LOB)
+                                else str(row[1])
+                            )
+                            # Metadata comes back as serialized JSON text (or a LOB).
+                            metadata_str = (
+                                row[2].read()
+                                if isinstance(row[2], oracledb.LOB)
+                                else row[2]
+                            )
+                            metadatas[qid].append(self._json_to_metadata(metadata_str))
+                            distances[qid].append(float(row[3]))
+
+            log.info(
+                f"Search completed. Found {sum(len(ids[i]) for i in range(num_queries))} total results."
+            )
+
+            return SearchResult(
+                ids=ids, distances=distances, documents=documents, metadatas=metadatas
+            )
+
+        except Exception as e:
+            log.exception(f"Error during search: {e}")
+            return None
+
+    def query(
+        self, collection_name: str, filter: Dict, limit: Optional[int] = None
+    ) -> Optional[GetResult]:
+        """
+        Query items based on metadata filters.
+
+        Retrieves items whose metadata fields equal the given (stringified) values.
+
+        Args:
+            collection_name (str): Name of the collection to query
+            filter (Dict[str, Any]): Metadata filters to apply
+            limit (Optional[int]): Maximum number of results to return (default 100)
+
+        Returns:
+            Optional[GetResult]: Query results containing ids, documents, and metadata;
+                None if nothing matched or an error occurred (including an
+                invalid filter key)
+
+        Example:
+            >>> client = Oracle23aiClient()
+            >>> filter = {"source": "doc1", "category": "finance"}
+            >>> results = client.query("my_collection", filter, limit=20)
+            >>> if results:
+            ...     print(f"Found {len(results.ids[0])} matching documents")
+        """
+        log.info(f"Querying items from collection '{collection_name}' with filters.")
+
+        try:
+            limit = limit or 100
+
+            query = """
+                SELECT id, text, JSON_SERIALIZE(vmetadata RETURNING VARCHAR2(4096)) as vmetadata
+                FROM document_chunk
+                WHERE collection_name = :collection_name
+            """
+
+            params = {"collection_name": collection_name}
+
+            # Keys are validated before being embedded; values are bound.
+            query = self._append_metadata_filters(query, params, filter)
+
+            query += " FETCH FIRST :limit ROWS ONLY"
+            params["limit"] = limit
+
+            with self.get_connection() as connection:
+                with connection.cursor() as cursor:
+                    cursor.execute(query, params)
+                    results = cursor.fetchall()
+
+                    if not results:
+                        log.info("No results found for query.")
+                        return None
+
+                    # LOBs must be read while the connection is still open.
+                    ids = [[row[0] for row in results]]
+                    documents = [
+                        [
+                            row[1].read() if isinstance(row[1], oracledb.LOB) else str(row[1])
+                            for row in results
+                        ]
+                    ]
+                    metadatas = [
+                        [
+                            self._json_to_metadata(
+                                row[2].read() if isinstance(row[2], oracledb.LOB) else row[2]
+                            )
+                            for row in results
+                        ]
+                    ]
+
+            log.info(f"Query completed. Found {len(results)} results.")
+
+            return GetResult(ids=ids, documents=documents, metadatas=metadatas)
+
+        except Exception as e:
+            log.exception(f"Error during query: {e}")
+            return None
+
+    def get(
+        self, collection_name: str, limit: Optional[int] = None
+    ) -> Optional[GetResult]:
+        """
+        Get all items in a collection.
+
+        Retrieves items from a specified collection up to the limit.
+
+        Args:
+            collection_name (str): Name of the collection to retrieve
+            limit (Optional[int]): Maximum number of items to retrieve; defaults
+                to 1000 when omitted
+
+        Returns:
+            Optional[GetResult]: Result containing ids, documents, and metadata;
+                None if the collection is empty or an error occurred
+
+        Example:
+            >>> client = Oracle23aiClient()
+            >>> results = client.get("my_collection", limit=50)
+            >>> if results:
+            ...     print(f"Retrieved {len(results.ids[0])} documents from collection")
+        """
+        # Resolve the limit BEFORE logging it: the previous version logged an
+        # unassigned local and raised UnboundLocalError on every call.
+        if limit is None:
+            limit = 1000  # Default cap for get operation
+
+        log.info(
+            f"Getting items from collection '{collection_name}' with limit {limit}."
+        )
+
+        try:
+            with self.get_connection() as connection:
+                with connection.cursor() as cursor:
+                    cursor.execute(
+                        """
+                        SELECT /*+ MONITOR */ id, text, JSON_SERIALIZE(vmetadata RETURNING VARCHAR2(4096)) as vmetadata
+                        FROM document_chunk
+                        WHERE collection_name = :collection_name
+                        FETCH FIRST :limit ROWS ONLY
+                        """,
+                        {"collection_name": collection_name, "limit": limit},
+                    )
+
+                    results = cursor.fetchall()
+
+                    if not results:
+                        log.info("No results found.")
+                        return None
+
+                    # LOBs must be read while the connection is still open.
+                    ids = [[row[0] for row in results]]
+                    documents = [
+                        [
+                            row[1].read() if isinstance(row[1], oracledb.LOB) else str(row[1])
+                            for row in results
+                        ]
+                    ]
+                    metadatas = [
+                        [
+                            self._json_to_metadata(
+                                row[2].read() if isinstance(row[2], oracledb.LOB) else row[2]
+                            )
+                            for row in results
+                        ]
+                    ]
+
+            return GetResult(ids=ids, documents=documents, metadatas=metadatas)
+
+        except Exception as e:
+            log.exception(f"Error during get: {e}")
+            return None
+
+    def delete(
+        self,
+        collection_name: str,
+        ids: Optional[List[str]] = None,
+        filter: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        """
+        Delete items from the database.
+
+        Deletes items from a collection based on IDs and/or metadata filters.
+        With neither ``ids`` nor ``filter`` every item of the collection is deleted.
+
+        Args:
+            collection_name (str): Name of the collection to delete from
+            ids (Optional[List[str]]): Specific item IDs to delete
+            filter (Optional[Dict[str, Any]]): Metadata filters for deletion
+
+        Raises:
+            ValueError: If a filter key is not a valid metadata field name
+            Exception: If deletion fails
+
+        Example:
+            >>> client = Oracle23aiClient()
+            >>> # Delete specific items by ID
+            >>> client.delete("my_collection", ids=["1", "3", "5"])
+            >>> # Or delete by metadata filter
+            >>> client.delete("my_collection", filter={"source": "deprecated_source"})
+        """
+        log.info(f"Deleting items from collection '{collection_name}'.")
+
+        try:
+            query = (
+                "DELETE FROM document_chunk WHERE collection_name = :collection_name"
+            )
+            params = {"collection_name": collection_name}
+
+            if ids:
+                # One bind variable per id keeps the IN list fully parameterized.
+                placeholders = ",".join([f":id_{i}" for i in range(len(ids))])
+                query += f" AND id IN ({placeholders})"
+                for i, id_val in enumerate(ids):
+                    params[f"id_{i}"] = id_val
+
+            if filter:
+                # Keys are validated before being embedded; values are bound.
+                query = self._append_metadata_filters(query, params, filter)
+
+            with self.get_connection() as connection:
+                with connection.cursor() as cursor:
+                    cursor.execute(query, params)
+                    deleted = cursor.rowcount
+                connection.commit()
+
+            log.info(f"Deleted {deleted} items from collection '{collection_name}'.")
+
+        except Exception as e:
+            log.exception(f"Error during delete: {e}")
+            raise
+
+    def reset(self) -> None:
+        """
+        Reset the database by deleting all items.
+
+        Deletes all items from the document_chunk table.
+
+        Raises:
+            Exception: If reset fails
+
+        Example:
+            >>> client = Oracle23aiClient()
+            >>> client.reset()  # Warning: Removes all data!
+        """
+        log.info("Resetting database - deleting all items.")
+
+        try:
+            with self.get_connection() as connection:
+                with connection.cursor() as cursor:
+                    cursor.execute("DELETE FROM document_chunk")
+                    deleted = cursor.rowcount
+                connection.commit()
+
+            log.info(
+                f"Reset complete. Deleted {deleted} items from 'document_chunk' table."
+            )
+
+        except Exception as e:
+            log.exception(f"Error during reset: {e}")
+            raise
+
+    def close(self) -> None:
+        """
+        Close the database connection pool.
+
+        Closes the connection pool and releases all resources. Errors are logged,
+        not raised.
+
+        Example:
+            >>> client = Oracle23aiClient()
+            >>> # After finishing all operations
+            >>> client.close()
+        """
+        try:
+            if hasattr(self, "pool") and self.pool:
+                # force=True: a plain close() fails while any connection is
+                # still checked out, which would leave the pool open.
+                self.pool.close(force=True)
+                log.info("Oracle Vector Search connection pool closed.")
+        except Exception as e:
+            log.exception(f"Error closing connection pool: {e}")
+
+    def has_collection(self, collection_name: str) -> bool:
+        """
+        Check if a collection exists.
+
+        A collection exists when at least one row carries its name.
+
+        Args:
+            collection_name (str): Name of the collection to check
+
+        Returns:
+            bool: True if the collection exists, False otherwise (also on error)
+
+        Example:
+            >>> client = Oracle23aiClient()
+            >>> if client.has_collection("my_collection"):
+            ...     print("Collection exists!")
+            ... else:
+            ...     print("Collection does not exist.")
+        """
+        try:
+            with self.get_connection() as connection:
+                with connection.cursor() as cursor:
+                    # Existence probe: stops at the first matching row instead
+                    # of counting the whole collection.
+                    cursor.execute(
+                        """
+                        SELECT 1
+                        FROM document_chunk
+                        WHERE collection_name = :collection_name
+                        FETCH FIRST 1 ROWS ONLY
+                        """,
+                        {"collection_name": collection_name},
+                    )
+
+                    return cursor.fetchone() is not None
+
+        except Exception as e:
+            log.exception(f"Error checking collection existence: {e}")
+            return False
+
+    def delete_collection(self, collection_name: str) -> None:
+        """
+        Delete an entire collection.
+
+        Removes all items belonging to the specified collection.
+
+        Args:
+            collection_name (str): Name of the collection to delete
+
+        Raises:
+            Exception: If deletion fails
+
+        Example:
+            >>> client = Oracle23aiClient()
+            >>> client.delete_collection("obsolete_collection")
+        """
+        log.info(f"Deleting collection '{collection_name}'.")
+
+        try:
+            with self.get_connection() as connection:
+                with connection.cursor() as cursor:
+                    cursor.execute(
+                        """
+                        DELETE FROM document_chunk
+                        WHERE collection_name = :collection_name
+                        """,
+                        {"collection_name": collection_name},
+                    )
+
+                    deleted = cursor.rowcount
+                connection.commit()
+
+            log.info(
+                f"Collection '{collection_name}' deleted. Removed {deleted} items."
+            )
+
+        except Exception as e:
+            log.exception(f"Error deleting collection '{collection_name}': {e}")
+            raise
diff --git a/backend/open_webui/retrieval/vector/dbs/pgvector.py b/backend/open_webui/retrieval/vector/dbs/pgvector.py index b6cb2a4e25e..312b48944c9 100644 --- a/backend/open_webui/retrieval/vector/dbs/pgvector.py +++ b/backend/open_webui/retrieval/vector/dbs/pgvector.py @@ -1,12 +1,16 @@ from typing import Optional, List, Dict, Any import logging +import json from sqlalchemy import ( + func, + literal, cast, column, create_engine, Column, Integer, MetaData, + LargeBinary, select, text, Text, @@ -14,7 +18,7 @@ values, ) from sqlalchemy.sql import true -from sqlalchemy.pool import NullPool +from sqlalchemy.pool import NullPool, QueuePool from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker from sqlalchemy.dialects.postgresql import JSONB, array @@ -22,13 +26,25 @@ from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.exc import NoSuchTableError + +from open_webui.retrieval.vector.utils import process_metadata from open_webui.retrieval.vector.main import ( VectorDBBase, VectorItem, SearchResult, GetResult, ) -from open_webui.config import PGVECTOR_DB_URL, PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH +from open_webui.config import ( + PGVECTOR_DB_URL, + PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH, + PGVECTOR_CREATE_EXTENSION, + PGVECTOR_PGCRYPTO, +
PGVECTOR_PGCRYPTO_KEY, + PGVECTOR_POOL_SIZE, + PGVECTOR_POOL_MAX_OVERFLOW, + PGVECTOR_POOL_TIMEOUT, + PGVECTOR_POOL_RECYCLE, +) from open_webui.env import SRC_LOG_LEVELS @@ -39,14 +55,27 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) +def pgcrypto_encrypt(val, key): + return func.pgp_sym_encrypt(val, literal(key)) + + +def pgcrypto_decrypt(col, key, outtype="text"): + return func.cast(func.pgp_sym_decrypt(col, literal(key)), outtype) + + class DocumentChunk(Base): __tablename__ = "document_chunk" id = Column(Text, primary_key=True) vector = Column(Vector(dim=VECTOR_LENGTH), nullable=True) collection_name = Column(Text, nullable=False) - text = Column(Text, nullable=True) - vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True) + + if PGVECTOR_PGCRYPTO: + text = Column(LargeBinary, nullable=True) + vmetadata = Column(LargeBinary, nullable=True) + else: + text = Column(Text, nullable=True) + vmetadata = Column(MutableDict.as_mutable(JSONB), nullable=True) class PgvectorClient(VectorDBBase): @@ -58,9 +87,24 @@ def __init__(self) -> None: self.session = Session else: - engine = create_engine( - PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool - ) + if isinstance(PGVECTOR_POOL_SIZE, int): + if PGVECTOR_POOL_SIZE > 0: + engine = create_engine( + PGVECTOR_DB_URL, + pool_size=PGVECTOR_POOL_SIZE, + max_overflow=PGVECTOR_POOL_MAX_OVERFLOW, + pool_timeout=PGVECTOR_POOL_TIMEOUT, + pool_recycle=PGVECTOR_POOL_RECYCLE, + pool_pre_ping=True, + poolclass=QueuePool, + ) + else: + engine = create_engine( + PGVECTOR_DB_URL, pool_pre_ping=True, poolclass=NullPool + ) + else: + engine = create_engine(PGVECTOR_DB_URL, pool_pre_ping=True) + SessionLocal = sessionmaker( autocommit=False, autoflush=False, bind=engine, expire_on_commit=False ) @@ -68,7 +112,41 @@ def __init__(self) -> None: try: # Ensure the pgvector extension is available - self.session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;")) + # Use a conditional check to avoid permission issues on Azure PostgreSQL 
+ if PGVECTOR_CREATE_EXTENSION: + self.session.execute( + text( + """ + DO $$ + BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'vector') THEN + CREATE EXTENSION IF NOT EXISTS vector; + END IF; + END $$; + """ + ) + ) + + if PGVECTOR_PGCRYPTO: + # Ensure the pgcrypto extension is available for encryption + # Use a conditional check to avoid permission issues on Azure PostgreSQL + self.session.execute( + text( + """ + DO $$ + BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pgcrypto') THEN + CREATE EXTENSION IF NOT EXISTS pgcrypto; + END IF; + END $$; + """ + ) + ) + + if not PGVECTOR_PGCRYPTO_KEY: + raise ValueError( + "PGVECTOR_PGCRYPTO_KEY must be set when PGVECTOR_PGCRYPTO is enabled." + ) # Check vector length consistency self.check_vector_length() @@ -147,22 +225,54 @@ def adjust_vector_length(self, vector: List[float]) -> List[float]: def insert(self, collection_name: str, items: List[VectorItem]) -> None: try: - new_items = [] - for item in items: - vector = self.adjust_vector_length(item["vector"]) - new_chunk = DocumentChunk( - id=item["id"], - vector=vector, - collection_name=collection_name, - text=item["text"], - vmetadata=item["metadata"], + if PGVECTOR_PGCRYPTO: + for item in items: + vector = self.adjust_vector_length(item["vector"]) + # Use raw SQL for BYTEA/pgcrypto + # Ensure metadata is converted to its JSON text representation + json_metadata = json.dumps(item["metadata"]) + self.session.execute( + text( + """ + INSERT INTO document_chunk + (id, vector, collection_name, text, vmetadata) + VALUES ( + :id, :vector, :collection_name, + pgp_sym_encrypt(:text, :key), + pgp_sym_encrypt(:metadata_text, :key) + ) + ON CONFLICT (id) DO NOTHING + """ + ), + { + "id": item["id"], + "vector": vector, + "collection_name": collection_name, + "text": item["text"], + "metadata_text": json_metadata, + "key": PGVECTOR_PGCRYPTO_KEY, + }, + ) + self.session.commit() + log.info(f"Encrypted & inserted {len(items)} into 
'{collection_name}'") + + else: + new_items = [] + for item in items: + vector = self.adjust_vector_length(item["vector"]) + new_chunk = DocumentChunk( + id=item["id"], + vector=vector, + collection_name=collection_name, + text=item["text"], + vmetadata=process_metadata(item["metadata"]), + ) + new_items.append(new_chunk) + self.session.bulk_save_objects(new_items) + self.session.commit() + log.info( + f"Inserted {len(new_items)} items into collection '{collection_name}'." ) - new_items.append(new_chunk) - self.session.bulk_save_objects(new_items) - self.session.commit() - log.info( - f"Inserted {len(new_items)} items into collection '{collection_name}'." - ) except Exception as e: self.session.rollback() log.exception(f"Error during insert: {e}") @@ -170,33 +280,66 @@ def insert(self, collection_name: str, items: List[VectorItem]) -> None: def upsert(self, collection_name: str, items: List[VectorItem]) -> None: try: - for item in items: - vector = self.adjust_vector_length(item["vector"]) - existing = ( - self.session.query(DocumentChunk) - .filter(DocumentChunk.id == item["id"]) - .first() - ) - if existing: - existing.vector = vector - existing.text = item["text"] - existing.vmetadata = item["metadata"] - existing.collection_name = ( - collection_name # Update collection_name if necessary + if PGVECTOR_PGCRYPTO: + for item in items: + vector = self.adjust_vector_length(item["vector"]) + json_metadata = json.dumps(item["metadata"]) + self.session.execute( + text( + """ + INSERT INTO document_chunk + (id, vector, collection_name, text, vmetadata) + VALUES ( + :id, :vector, :collection_name, + pgp_sym_encrypt(:text, :key), + pgp_sym_encrypt(:metadata_text, :key) + ) + ON CONFLICT (id) DO UPDATE SET + vector = EXCLUDED.vector, + collection_name = EXCLUDED.collection_name, + text = EXCLUDED.text, + vmetadata = EXCLUDED.vmetadata + """ + ), + { + "id": item["id"], + "vector": vector, + "collection_name": collection_name, + "text": item["text"], + "metadata_text": 
json_metadata, + "key": PGVECTOR_PGCRYPTO_KEY, + }, ) - else: - new_chunk = DocumentChunk( - id=item["id"], - vector=vector, - collection_name=collection_name, - text=item["text"], - vmetadata=item["metadata"], + self.session.commit() + log.info(f"Encrypted & upserted {len(items)} into '{collection_name}'") + else: + for item in items: + vector = self.adjust_vector_length(item["vector"]) + existing = ( + self.session.query(DocumentChunk) + .filter(DocumentChunk.id == item["id"]) + .first() ) - self.session.add(new_chunk) - self.session.commit() - log.info( - f"Upserted {len(items)} items into collection '{collection_name}'." - ) + if existing: + existing.vector = vector + existing.text = item["text"] + existing.vmetadata = process_metadata(item["metadata"]) + existing.collection_name = ( + collection_name # Update collection_name if necessary + ) + else: + new_chunk = DocumentChunk( + id=item["id"], + vector=vector, + collection_name=collection_name, + text=item["text"], + vmetadata=process_metadata(item["metadata"]), + ) + self.session.add(new_chunk) + self.session.commit() + log.info( + f"Upserted {len(items)} items into collection '{collection_name}'." 
+ ) except Exception as e: self.session.rollback() log.exception(f"Error during upsert: {e}") @@ -230,16 +373,32 @@ def vector_expr(vector): .alias("query_vectors") ) + result_fields = [ + DocumentChunk.id, + ] + if PGVECTOR_PGCRYPTO: + result_fields.append( + pgcrypto_decrypt( + DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text + ).label("text") + ) + result_fields.append( + pgcrypto_decrypt( + DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB + ).label("vmetadata") + ) + else: + result_fields.append(DocumentChunk.text) + result_fields.append(DocumentChunk.vmetadata) + result_fields.append( + (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)).label( + "distance" + ) + ) + # Build the lateral subquery for each query vector subq = ( - select( - DocumentChunk.id, - DocumentChunk.text, - DocumentChunk.vmetadata, - ( - DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector) - ).label("distance"), - ) + select(*result_fields) .where(DocumentChunk.collection_name == collection_name) .order_by( (DocumentChunk.vector.cosine_distance(query_vectors.c.q_vector)) @@ -288,10 +447,12 @@ def vector_expr(vector): documents[qid].append(row.text) metadatas[qid].append(row.vmetadata) + self.session.rollback() # read-only transaction return SearchResult( ids=ids, distances=distances, documents=documents, metadatas=metadatas ) except Exception as e: + self.session.rollback() log.exception(f"Error during search: {e}") return None @@ -299,17 +460,43 @@ def query( self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None ) -> Optional[GetResult]: try: - query = self.session.query(DocumentChunk).filter( - DocumentChunk.collection_name == collection_name - ) + if PGVECTOR_PGCRYPTO: + # Build where clause for vmetadata filter + where_clauses = [DocumentChunk.collection_name == collection_name] + for key, value in filter.items(): + # decrypt then check key: JSON filter after decryption + where_clauses.append( + pgcrypto_decrypt( + 
DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB + )[key].astext + == str(value) + ) + stmt = select( + DocumentChunk.id, + pgcrypto_decrypt( + DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text + ).label("text"), + pgcrypto_decrypt( + DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB + ).label("vmetadata"), + ).where(*where_clauses) + if limit is not None: + stmt = stmt.limit(limit) + results = self.session.execute(stmt).all() + else: + query = self.session.query(DocumentChunk).filter( + DocumentChunk.collection_name == collection_name + ) - for key, value in filter.items(): - query = query.filter(DocumentChunk.vmetadata[key].astext == str(value)) + for key, value in filter.items(): + query = query.filter( + DocumentChunk.vmetadata[key].astext == str(value) + ) - if limit is not None: - query = query.limit(limit) + if limit is not None: + query = query.limit(limit) - results = query.all() + results = query.all() if not results: return None @@ -318,12 +505,14 @@ def query( documents = [[result.text for result in results]] metadatas = [[result.vmetadata for result in results]] + self.session.rollback() # read-only transaction return GetResult( ids=ids, documents=documents, metadatas=metadatas, ) except Exception as e: + self.session.rollback() log.exception(f"Error during query: {e}") return None @@ -331,23 +520,43 @@ def get( self, collection_name: str, limit: Optional[int] = None ) -> Optional[GetResult]: try: - query = self.session.query(DocumentChunk).filter( - DocumentChunk.collection_name == collection_name - ) - if limit is not None: - query = query.limit(limit) + if PGVECTOR_PGCRYPTO: + stmt = select( + DocumentChunk.id, + pgcrypto_decrypt( + DocumentChunk.text, PGVECTOR_PGCRYPTO_KEY, Text + ).label("text"), + pgcrypto_decrypt( + DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB + ).label("vmetadata"), + ).where(DocumentChunk.collection_name == collection_name) + if limit is not None: + stmt = stmt.limit(limit) + results = 
self.session.execute(stmt).all() + ids = [[row.id for row in results]] + documents = [[row.text for row in results]] + metadatas = [[row.vmetadata for row in results]] + else: - results = query.all() + query = self.session.query(DocumentChunk).filter( + DocumentChunk.collection_name == collection_name + ) + if limit is not None: + query = query.limit(limit) - if not results: - return None + results = query.all() - ids = [[result.id for result in results]] - documents = [[result.text for result in results]] - metadatas = [[result.vmetadata for result in results]] + if not results: + return None + ids = [[result.id for result in results]] + documents = [[result.text for result in results]] + metadatas = [[result.vmetadata for result in results]] + + self.session.rollback() # read-only transaction return GetResult(ids=ids, documents=documents, metadatas=metadatas) except Exception as e: + self.session.rollback() log.exception(f"Error during get: {e}") return None @@ -358,17 +567,33 @@ def delete( filter: Optional[Dict[str, Any]] = None, ) -> None: try: - query = self.session.query(DocumentChunk).filter( - DocumentChunk.collection_name == collection_name - ) - if ids: - query = query.filter(DocumentChunk.id.in_(ids)) - if filter: - for key, value in filter.items(): - query = query.filter( - DocumentChunk.vmetadata[key].astext == str(value) - ) - deleted = query.delete(synchronize_session=False) + if PGVECTOR_PGCRYPTO: + wheres = [DocumentChunk.collection_name == collection_name] + if ids: + wheres.append(DocumentChunk.id.in_(ids)) + if filter: + for key, value in filter.items(): + wheres.append( + pgcrypto_decrypt( + DocumentChunk.vmetadata, PGVECTOR_PGCRYPTO_KEY, JSONB + )[key].astext + == str(value) + ) + stmt = DocumentChunk.__table__.delete().where(*wheres) + result = self.session.execute(stmt) + deleted = result.rowcount + else: + query = self.session.query(DocumentChunk).filter( + DocumentChunk.collection_name == collection_name + ) + if ids: + query = 
query.filter(DocumentChunk.id.in_(ids)) + if filter: + for key, value in filter.items(): + query = query.filter( + DocumentChunk.vmetadata[key].astext == str(value) + ) + deleted = query.delete(synchronize_session=False) self.session.commit() log.info(f"Deleted {deleted} items from collection '{collection_name}'.") except Exception as e: @@ -399,8 +624,10 @@ def has_collection(self, collection_name: str) -> bool: .first() is not None ) + self.session.rollback() # read-only transaction return exists except Exception as e: + self.session.rollback() log.exception(f"Error checking collection existence: {e}") return False diff --git a/backend/open_webui/retrieval/vector/dbs/pinecone.py b/backend/open_webui/retrieval/vector/dbs/pinecone.py index c921089b6da..5bef0d9ea7d 100644 --- a/backend/open_webui/retrieval/vector/dbs/pinecone.py +++ b/backend/open_webui/retrieval/vector/dbs/pinecone.py @@ -1,13 +1,21 @@ from typing import Optional, List, Dict, Any, Union import logging import time # for measuring elapsed time -from pinecone import ServerlessSpec +from pinecone import Pinecone, ServerlessSpec + +# Add gRPC support for better performance (Pinecone best practice) +try: + from pinecone.grpc import PineconeGRPC + + GRPC_AVAILABLE = True +except ImportError: + GRPC_AVAILABLE = False import asyncio # for async upserts import functools # for partial binding in async tasks import concurrent.futures # for parallel batch upserts -from pinecone.grpc import PineconeGRPC # use gRPC client for faster upserts +import random # for jitter in retry backoff from open_webui.retrieval.vector.main import ( VectorDBBase, @@ -24,6 +32,8 @@ PINECONE_CLOUD, ) from open_webui.env import SRC_LOG_LEVELS +from open_webui.retrieval.vector.utils import process_metadata + NO_LIMIT = 10000 # Reasonable limit to avoid overwhelming the system BATCH_SIZE = 100 # Recommended batch size for Pinecone operations @@ -47,10 +57,25 @@ def __init__(self): self.metric = PINECONE_METRIC self.cloud = 
PINECONE_CLOUD - # Initialize Pinecone gRPC client for improved performance - self.client = PineconeGRPC( - api_key=self.api_key, environment=self.environment, cloud=self.cloud - ) + # Initialize Pinecone client for improved performance + if GRPC_AVAILABLE: + # Use gRPC client for better performance (Pinecone recommendation) + self.client = PineconeGRPC( + api_key=self.api_key, + pool_threads=20, # Improved connection pool size + timeout=30, # Reasonable timeout for operations + ) + self.using_grpc = True + log.info("Using Pinecone gRPC client for optimal performance") + else: + # Fallback to HTTP client with enhanced connection pooling + self.client = Pinecone( + api_key=self.api_key, + pool_threads=20, # Improved connection pool size + timeout=30, # Reasonable timeout for operations + ) + self.using_grpc = False + log.info("Using Pinecone HTTP client (gRPC not available)") # Persistent executor for batch operations self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=5) @@ -94,12 +119,53 @@ def _initialize_index(self) -> None: log.info(f"Using existing Pinecone index '{self.index_name}'") # Connect to the index - self.index = self.client.Index(self.index_name) + self.index = self.client.Index( + self.index_name, + pool_threads=20, # Enhanced connection pool for index operations + ) except Exception as e: log.error(f"Failed to initialize Pinecone index: {e}") raise RuntimeError(f"Failed to initialize Pinecone index: {e}") + def _retry_pinecone_operation(self, operation_func, max_retries=3): + """Retry Pinecone operations with exponential backoff for rate limits and network issues.""" + for attempt in range(max_retries): + try: + return operation_func() + except Exception as e: + error_str = str(e).lower() + # Check if it's a retryable error (rate limits, network issues, timeouts) + is_retryable = any( + keyword in error_str + for keyword in [ + "rate limit", + "quota", + "timeout", + "network", + "connection", + "unavailable", + "internal error", + 
"429", + "500", + "502", + "503", + "504", + ] + ) + + if not is_retryable or attempt == max_retries - 1: + # Don't retry for non-retryable errors or on final attempt + raise + + # Exponential backoff with jitter + delay = (2**attempt) + random.uniform(0, 1) + log.warning( + f"Pinecone operation failed (attempt {attempt + 1}/{max_retries}), " + f"retrying in {delay:.2f}s: {e}" + ) + time.sleep(delay) + def _create_points( self, items: List[VectorItem], collection_name_with_prefix: str ) -> List[Dict[str, Any]]: @@ -119,7 +185,7 @@ def _create_points( point = { "id": item["id"], "values": item["vector"], - "metadata": metadata, + "metadata": process_metadata(metadata), } points.append(point) return points @@ -147,8 +213,8 @@ def _result_to_get_result(self, matches: list) -> GetResult: metadatas = [] for match in matches: - metadata = match.get("metadata", {}) - ids.append(match["id"]) + metadata = getattr(match, "metadata", {}) or {} + ids.append(match.id if hasattr(match, "id") else match["id"]) documents.append(metadata.get("text", "")) metadatas.append(metadata) @@ -174,7 +240,8 @@ def has_collection(self, collection_name: str) -> bool: filter={"collection_name": collection_name_with_prefix}, include_metadata=False, ) - return len(response.matches) > 0 + matches = getattr(response, "matches", []) or [] + return len(matches) > 0 except Exception as e: log.exception( f"Error checking collection '{collection_name_with_prefix}': {e}" @@ -225,7 +292,8 @@ def insert(self, collection_name: str, items: List[VectorItem]) -> None: elapsed = time.time() - start_time log.debug(f"Insert of {len(points)} vectors took {elapsed:.2f} seconds") log.info( - f"Successfully inserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'" + f"Successfully inserted {len(points)} vectors in parallel batches " + f"into '{collection_name_with_prefix}'" ) def upsert(self, collection_name: str, items: List[VectorItem]) -> None: @@ -256,7 +324,8 @@ def upsert(self, 
collection_name: str, items: List[VectorItem]) -> None: elapsed = time.time() - start_time log.debug(f"Upsert of {len(points)} vectors took {elapsed:.2f} seconds") log.info( - f"Successfully upserted {len(points)} vectors in parallel batches into '{collection_name_with_prefix}'" + f"Successfully upserted {len(points)} vectors in parallel batches " + f"into '{collection_name_with_prefix}'" ) async def insert_async(self, collection_name: str, items: List[VectorItem]) -> None: @@ -287,7 +356,8 @@ async def insert_async(self, collection_name: str, items: List[VectorItem]) -> N log.error(f"Error in async insert batch: {result}") raise result log.info( - f"Successfully async inserted {len(points)} vectors in batches into '{collection_name_with_prefix}'" + f"Successfully async inserted {len(points)} vectors in batches " + f"into '{collection_name_with_prefix}'" ) async def upsert_async(self, collection_name: str, items: List[VectorItem]) -> None: @@ -318,35 +388,10 @@ async def upsert_async(self, collection_name: str, items: List[VectorItem]) -> N log.error(f"Error in async upsert batch: {result}") raise result log.info( - f"Successfully async upserted {len(points)} vectors in batches into '{collection_name_with_prefix}'" + f"Successfully async upserted {len(points)} vectors in batches " + f"into '{collection_name_with_prefix}'" ) - def streaming_upsert(self, collection_name: str, items: List[VectorItem]) -> None: - """Perform a streaming upsert over gRPC for performance testing.""" - if not items: - log.warning("No items to upsert via streaming") - return - - collection_name_with_prefix = self._get_collection_name_with_prefix( - collection_name - ) - points = self._create_points(items, collection_name_with_prefix) - - # Open a streaming upsert channel - stream = self.index.streaming_upsert() - try: - for point in points: - # send each point over the stream - stream.send(point) - # close the stream to finalize - stream.close() - log.info( - f"Successfully streamed upsert 
of {len(points)} vectors into '{collection_name_with_prefix}'" - ) - except Exception as e: - log.error(f"Error during streaming upsert: {e}") - raise - def search( self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int ) -> Optional[SearchResult]: @@ -374,7 +419,8 @@ def search( filter={"collection_name": collection_name_with_prefix}, ) - if not query_response.matches: + matches = getattr(query_response, "matches", []) or [] + if not matches: # Return empty result if no matches return SearchResult( ids=[[]], @@ -384,13 +430,13 @@ def search( ) # Convert to GetResult format - get_result = self._result_to_get_result(query_response.matches) + get_result = self._result_to_get_result(matches) # Calculate normalized distances based on metric distances = [ [ - self._normalize_distance(match.score) - for match in query_response.matches + self._normalize_distance(getattr(match, "score", 0.0)) + for match in matches ] ] @@ -432,7 +478,8 @@ def query( include_metadata=True, ) - return self._result_to_get_result(query_response.matches) + matches = getattr(query_response, "matches", []) or [] + return self._result_to_get_result(matches) except Exception as e: log.error(f"Error querying collection '{collection_name}': {e}") @@ -456,7 +503,8 @@ def get(self, collection_name: str) -> Optional[GetResult]: filter={"collection_name": collection_name_with_prefix}, ) - return self._result_to_get_result(query_response.matches) + matches = getattr(query_response, "matches", []) or [] + return self._result_to_get_result(matches) except Exception as e: log.error(f"Error getting collection '{collection_name}': {e}") @@ -482,10 +530,12 @@ def delete( # This is a limitation of Pinecone - be careful with ID uniqueness self.index.delete(ids=batch_ids) log.debug( - f"Deleted batch of {len(batch_ids)} vectors by ID from '{collection_name_with_prefix}'" + f"Deleted batch of {len(batch_ids)} vectors by ID " + f"from '{collection_name_with_prefix}'" ) log.info( - 
f"Successfully deleted {len(ids)} vectors by ID from '{collection_name_with_prefix}'" + f"Successfully deleted {len(ids)} vectors by ID " + f"from '{collection_name_with_prefix}'" ) elif filter: @@ -516,12 +566,12 @@ def reset(self) -> None: raise def close(self): - """Shut down the gRPC channel and thread pool.""" + """Shut down resources.""" try: - self.client.close() - log.info("Pinecone gRPC channel closed.") + # The new Pinecone client doesn't need explicit closing + pass except Exception as e: - log.warning(f"Failed to close Pinecone gRPC channel: {e}") + log.warning(f"Failed to clean up Pinecone resources: {e}") self._executor.shutdown(wait=True) def __enter__(self): diff --git a/backend/open_webui/retrieval/vector/dbs/qdrant.py b/backend/open_webui/retrieval/vector/dbs/qdrant.py index dfe2979076f..ea432974993 100644 --- a/backend/open_webui/retrieval/vector/dbs/qdrant.py +++ b/backend/open_webui/retrieval/vector/dbs/qdrant.py @@ -18,6 +18,9 @@ QDRANT_ON_DISK, QDRANT_GRPC_PORT, QDRANT_PREFER_GRPC, + QDRANT_COLLECTION_PREFIX, + QDRANT_TIMEOUT, + QDRANT_HNSW_M, ) from open_webui.env import SRC_LOG_LEVELS @@ -29,12 +32,14 @@ class QdrantClient(VectorDBBase): def __init__(self): - self.collection_prefix = "open-webui" + self.collection_prefix = QDRANT_COLLECTION_PREFIX self.QDRANT_URI = QDRANT_URI self.QDRANT_API_KEY = QDRANT_API_KEY self.QDRANT_ON_DISK = QDRANT_ON_DISK self.PREFER_GRPC = QDRANT_PREFER_GRPC self.GRPC_PORT = QDRANT_GRPC_PORT + self.QDRANT_TIMEOUT = QDRANT_TIMEOUT + self.QDRANT_HNSW_M = QDRANT_HNSW_M if not self.QDRANT_URI: self.client = None @@ -52,9 +57,14 @@ def __init__(self): grpc_port=self.GRPC_PORT, prefer_grpc=self.PREFER_GRPC, api_key=self.QDRANT_API_KEY, + timeout=self.QDRANT_TIMEOUT, ) else: - self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY) + self.client = Qclient( + url=self.QDRANT_URI, + api_key=self.QDRANT_API_KEY, + timeout=QDRANT_TIMEOUT, + ) def _result_to_get_result(self, points) -> GetResult: ids = [] @@ 
-84,8 +94,30 @@ def _create_collection(self, collection_name: str, dimension: int): distance=models.Distance.COSINE, on_disk=self.QDRANT_ON_DISK, ), + hnsw_config=models.HnswConfigDiff( + m=self.QDRANT_HNSW_M, + ), ) + # Create payload indexes for efficient filtering + self.client.create_payload_index( + collection_name=collection_name_with_prefix, + field_name="metadata.hash", + field_schema=models.KeywordIndexParams( + type=models.KeywordIndexType.KEYWORD, + is_tenant=False, + on_disk=self.QDRANT_ON_DISK, + ), + ) + self.client.create_payload_index( + collection_name=collection_name_with_prefix, + field_name="metadata.file_id", + field_schema=models.KeywordIndexParams( + type=models.KeywordIndexType.KEYWORD, + is_tenant=False, + on_disk=self.QDRANT_ON_DISK, + ), + ) log.info(f"collection {collection_name_with_prefix} successfully created!") def _create_collection_if_not_exists(self, collection_name, dimension): @@ -151,23 +183,23 @@ def query(self, collection_name: str, filter: dict, limit: Optional[int] = None) ) ) - points = self.client.query_points( + points = self.client.scroll( collection_name=f"{self.collection_prefix}_{collection_name}", - query_filter=models.Filter(should=field_conditions), + scroll_filter=models.Filter(should=field_conditions), limit=limit, ) - return self._result_to_get_result(points.points) + return self._result_to_get_result(points[0]) except Exception as e: log.exception(f"Error querying a collection '{collection_name}': {e}") return None def get(self, collection_name: str) -> Optional[GetResult]: # Get all the items in the collection. - points = self.client.query_points( + points = self.client.scroll( collection_name=f"{self.collection_prefix}_{collection_name}", limit=NO_LIMIT, # otherwise qdrant would set limit to 10! 
) - return self._result_to_get_result(points.points) + return self._result_to_get_result(points[0]) def insert(self, collection_name: str, items: list[VectorItem]): # Insert the items into the collection, if the collection does not exist, it will be created. diff --git a/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py b/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py index e83c437ef77..e9fa03d4591 100644 --- a/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py +++ b/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py @@ -1,5 +1,5 @@ import logging -from typing import Optional, Tuple +from typing import Optional, Tuple, List, Dict, Any from urllib.parse import urlparse import grpc @@ -9,6 +9,9 @@ QDRANT_ON_DISK, QDRANT_PREFER_GRPC, QDRANT_URI, + QDRANT_COLLECTION_PREFIX, + QDRANT_TIMEOUT, + QDRANT_HNSW_M, ) from open_webui.env import SRC_LOG_LEVELS from open_webui.retrieval.vector.main import ( @@ -23,39 +26,62 @@ from qdrant_client.models import models NO_LIMIT = 999999999 +TENANT_ID_FIELD = "tenant_id" +DEFAULT_DIMENSION = 384 log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) +def _tenant_filter(tenant_id: str) -> models.FieldCondition: + return models.FieldCondition( + key=TENANT_ID_FIELD, match=models.MatchValue(value=tenant_id) + ) + + +def _metadata_filter(key: str, value: Any) -> models.FieldCondition: + return models.FieldCondition( + key=f"metadata.{key}", match=models.MatchValue(value=value) + ) + + class QdrantClient(VectorDBBase): def __init__(self): - self.collection_prefix = "open-webui" + self.collection_prefix = QDRANT_COLLECTION_PREFIX self.QDRANT_URI = QDRANT_URI self.QDRANT_API_KEY = QDRANT_API_KEY self.QDRANT_ON_DISK = QDRANT_ON_DISK self.PREFER_GRPC = QDRANT_PREFER_GRPC self.GRPC_PORT = QDRANT_GRPC_PORT + self.QDRANT_TIMEOUT = QDRANT_TIMEOUT + self.QDRANT_HNSW_M = QDRANT_HNSW_M if not self.QDRANT_URI: - self.client = None - return + raise ValueError( + "QDRANT_URI is not 
set. Please configure it in the environment variables." + ) # Unified handling for either scheme parsed = urlparse(self.QDRANT_URI) host = parsed.hostname or self.QDRANT_URI http_port = parsed.port or 6333 # default REST port - if self.PREFER_GRPC: - self.client = Qclient( + self.client = ( + Qclient( host=host, port=http_port, grpc_port=self.GRPC_PORT, prefer_grpc=self.PREFER_GRPC, api_key=self.QDRANT_API_KEY, + timeout=self.QDRANT_TIMEOUT, ) - else: - self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY) + if self.PREFER_GRPC + else Qclient( + url=self.QDRANT_URI, + api_key=self.QDRANT_API_KEY, + timeout=self.QDRANT_TIMEOUT, + ) + ) # Main collection types for multi-tenancy self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories" @@ -65,23 +91,13 @@ def __init__(self): self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based" def _result_to_get_result(self, points) -> GetResult: - ids = [] - documents = [] - metadatas = [] - + ids, documents, metadatas = [], [], [] for point in points: payload = point.payload ids.append(point.id) documents.append(payload["text"]) metadatas.append(payload["metadata"]) - - return GetResult( - **{ - "ids": [ids], - "documents": [documents], - "metadatas": [metadatas], - } - ) + return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas]) def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]: """ @@ -89,6 +105,13 @@ def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str] Returns: tuple: (collection_name, tenant_id) + + WARNING: This mapping relies on current Open WebUI naming conventions for + collection names. If Open WebUI changes how it generates collection names + (e.g., "user-memory-" prefix, "file-" prefix, web search patterns, or hash + formats), this mapping will break and route data to incorrect collections. + POTENTIALLY CAUSING HUGE DATA CORRUPTION, DATA CONSISTENCY ISSUES AND INCORRECT + DATA MAPPING INSIDE THE DATABASE. 
""" # Check for user memory collections tenant_id = collection_name @@ -113,143 +136,53 @@ def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str] else: return self.KNOWLEDGE_COLLECTION, tenant_id - def _extract_error_message(self, exception): - """ - Extract error message from either HTTP or gRPC exceptions - - Returns: - tuple: (status_code, error_message) - """ - # Check if it's an HTTP exception - if isinstance(exception, UnexpectedResponse): - try: - error_data = exception.structured() - error_msg = error_data.get("status", {}).get("error", "") - return exception.status_code, error_msg - except Exception as inner_e: - log.error(f"Failed to parse HTTP error: {inner_e}") - return exception.status_code, str(exception) - - # Check if it's a gRPC exception - elif isinstance(exception, grpc.RpcError): - # Extract status code from gRPC error - status_code = None - if hasattr(exception, "code") and callable(exception.code): - status_code = exception.code().value[0] - - # Extract error message - error_msg = str(exception) - if "details =" in error_msg: - # Parse the details line which contains the actual error message - try: - details_line = [ - line.strip() - for line in error_msg.split("\n") - if "details =" in line - ][0] - error_msg = details_line.split("details =")[1].strip(' "') - except (IndexError, AttributeError): - # Fall back to full message if parsing fails - pass - - return status_code, error_msg - - # For any other type of exception - return None, str(exception) - - def _is_collection_not_found_error(self, exception): - """ - Check if the exception is due to collection not found, supporting both HTTP and gRPC - """ - status_code, error_msg = self._extract_error_message(exception) - - # HTTP error (404) - if ( - status_code == 404 - and "Collection" in error_msg - and "doesn't exist" in error_msg - ): - return True - - # gRPC error (NOT_FOUND status) - if ( - isinstance(exception, grpc.RpcError) - and exception.code() == 
grpc.StatusCode.NOT_FOUND - ): - return True - - return False - - def _is_dimension_mismatch_error(self, exception): + def _create_multi_tenant_collection( + self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION + ): """ - Check if the exception is due to dimension mismatch, supporting both HTTP and gRPC + Creates a collection with multi-tenancy configuration and payload indexes for tenant_id and metadata fields. """ - status_code, error_msg = self._extract_error_message(exception) - - # Common patterns in both HTTP and gRPC - return ( - "Vector dimension error" in error_msg - or "dimensions mismatch" in error_msg - or "invalid vector size" in error_msg + self.client.create_collection( + collection_name=mt_collection_name, + vectors_config=models.VectorParams( + size=dimension, + distance=models.Distance.COSINE, + on_disk=self.QDRANT_ON_DISK, + ), + # Disable global index building due to multitenancy + # For more details https://qdrant.tech/documentation/guides/multiple-partitions/#calibrate-performance + hnsw_config=models.HnswConfigDiff( + payload_m=self.QDRANT_HNSW_M, + m=0, + ), + ) + log.info( + f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!" ) - def _create_multi_tenant_collection_if_not_exists( - self, mt_collection_name: str, dimension: int = 384 - ): - """ - Creates a collection with multi-tenancy configuration if it doesn't exist. - Default dimension is set to 384 which corresponds to 'sentence-transformers/all-MiniLM-L6-v2'. - When creating collections dynamically (insert/upsert), the actual vector dimensions will be used. 
- """ - try: - # Try to create the collection directly - will fail if it already exists - self.client.create_collection( - collection_name=mt_collection_name, - vectors_config=models.VectorParams( - size=dimension, - distance=models.Distance.COSINE, - on_disk=self.QDRANT_ON_DISK, - ), - hnsw_config=models.HnswConfigDiff( - payload_m=16, # Enable per-tenant indexing - m=0, - on_disk=self.QDRANT_ON_DISK, - ), - ) + self.client.create_payload_index( + collection_name=mt_collection_name, + field_name=TENANT_ID_FIELD, + field_schema=models.KeywordIndexParams( + type=models.KeywordIndexType.KEYWORD, + is_tenant=True, + on_disk=self.QDRANT_ON_DISK, + ), + ) - # Create tenant ID payload index + for field in ("metadata.hash", "metadata.file_id"): self.client.create_payload_index( collection_name=mt_collection_name, - field_name="tenant_id", + field_name=field, field_schema=models.KeywordIndexParams( type=models.KeywordIndexType.KEYWORD, - is_tenant=True, on_disk=self.QDRANT_ON_DISK, ), - wait=True, ) - log.info( - f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!" - ) - except (UnexpectedResponse, grpc.RpcError) as e: - # Check for the specific error indicating collection already exists - status_code, error_msg = self._extract_error_message(e) - - # HTTP status code 409 or gRPC ALREADY_EXISTS - if (isinstance(e, UnexpectedResponse) and status_code == 409) or ( - isinstance(e, grpc.RpcError) - and e.code() == grpc.StatusCode.ALREADY_EXISTS - ): - if "already exists" in error_msg: - log.debug(f"Collection {mt_collection_name} already exists") - return - # If it's not an already exists error, re-raise - raise e - except Exception as e: - raise e - - def _create_points(self, items: list[VectorItem], tenant_id: str): + def _create_points( + self, items: List[VectorItem], tenant_id: str + ) -> List[PointStruct]: """ Create point structs from vector items with tenant ID. 
""" @@ -260,56 +193,42 @@ def _create_points(self, items: list[VectorItem], tenant_id: str): payload={ "text": item["text"], "metadata": item["metadata"], - "tenant_id": tenant_id, + TENANT_ID_FIELD: tenant_id, }, ) for item in items ] + def _ensure_collection( + self, mt_collection_name: str, dimension: int = DEFAULT_DIMENSION + ): + """ + Ensure the collection exists and payload indexes are created for tenant_id and metadata fields. + """ + if not self.client.collection_exists(collection_name=mt_collection_name): + self._create_multi_tenant_collection(mt_collection_name, dimension) + def has_collection(self, collection_name: str) -> bool: """ Check if a logical collection exists by checking for any points with the tenant ID. """ if not self.client: return False - - # Map to multi-tenant collection and tenant ID mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name) - - # Create tenant filter - tenant_filter = models.FieldCondition( - key="tenant_id", match=models.MatchValue(value=tenant_id) - ) - - try: - # Try directly querying - most of the time collection should exist - response = self.client.query_points( - collection_name=mt_collection, - query_filter=models.Filter(must=[tenant_filter]), - limit=1, - ) - - # Collection exists with this tenant ID if there are points - return len(response.points) > 0 - except (UnexpectedResponse, grpc.RpcError) as e: - if self._is_collection_not_found_error(e): - log.debug(f"Collection {mt_collection} doesn't exist") - return False - else: - # For other API errors, log and return False - _, error_msg = self._extract_error_message(e) - log.warning(f"Unexpected Qdrant error: {error_msg}") - return False - except Exception as e: - # For any other errors, log and return False - log.debug(f"Error checking collection {mt_collection}: {e}") + if not self.client.collection_exists(collection_name=mt_collection): return False + tenant_filter = _tenant_filter(tenant_id) + count_result = self.client.count( + 
collection_name=mt_collection, + count_filter=models.Filter(must=[tenant_filter]), + ) + return count_result.count > 0 def delete( self, collection_name: str, - ids: Optional[list[str]] = None, - filter: Optional[dict] = None, + ids: Optional[List[str]] = None, + filter: Optional[Dict[str, Any]] = None, ): """ Delete vectors by ID or filter from a collection with tenant isolation. @@ -317,189 +236,76 @@ def delete( if not self.client: return None - # Map to multi-tenant collection and tenant ID mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name) + if not self.client.collection_exists(collection_name=mt_collection): + log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete") + return None - # Create tenant filter - tenant_filter = models.FieldCondition( - key="tenant_id", match=models.MatchValue(value=tenant_id) - ) - - must_conditions = [tenant_filter] + must_conditions = [_tenant_filter(tenant_id)] should_conditions = [] - if ids: - for id_value in ids: - should_conditions.append( - models.FieldCondition( - key="metadata.id", - match=models.MatchValue(value=id_value), - ), - ) + should_conditions = [_metadata_filter("id", id_value) for id_value in ids] elif filter: - for key, value in filter.items(): - must_conditions.append( - models.FieldCondition( - key=f"metadata.{key}", - match=models.MatchValue(value=value), - ), - ) - - try: - # Try to delete directly - most of the time collection should exist - update_result = self.client.delete( - collection_name=mt_collection, - points_selector=models.FilterSelector( - filter=models.Filter(must=must_conditions, should=should_conditions) - ), - ) + must_conditions += [_metadata_filter(k, v) for k, v in filter.items()] - return update_result - except (UnexpectedResponse, grpc.RpcError) as e: - if self._is_collection_not_found_error(e): - log.debug( - f"Collection {mt_collection} doesn't exist, nothing to delete" - ) - return None - else: - # For other API errors, log and re-raise 
- _, error_msg = self._extract_error_message(e) - log.warning(f"Unexpected Qdrant error: {error_msg}") - raise - except Exception as e: - # For non-Qdrant exceptions, re-raise - raise + return self.client.delete( + collection_name=mt_collection, + points_selector=models.FilterSelector( + filter=models.Filter(must=must_conditions, should=should_conditions) + ), + ) def search( - self, collection_name: str, vectors: list[list[float | int]], limit: int + self, collection_name: str, vectors: List[List[float | int]], limit: int ) -> Optional[SearchResult]: """ Search for the nearest neighbor items based on the vectors with tenant isolation. """ - if not self.client: + if not self.client or not vectors: return None - - # Map to multi-tenant collection and tenant ID mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name) - - # Get the vector dimension from the query vector - dimension = len(vectors[0]) if vectors and len(vectors) > 0 else None - - try: - # Try the search operation directly - most of the time collection should exist - - # Create tenant filter - tenant_filter = models.FieldCondition( - key="tenant_id", match=models.MatchValue(value=tenant_id) - ) - - # Ensure vector dimensions match the collection - collection_dim = self.client.get_collection( - mt_collection - ).config.params.vectors.size - - if collection_dim != dimension: - if collection_dim < dimension: - vectors = [vector[:collection_dim] for vector in vectors] - else: - vectors = [ - vector + [0] * (collection_dim - dimension) - for vector in vectors - ] - - # Search with tenant filter - prefetch_query = models.Prefetch( - filter=models.Filter(must=[tenant_filter]), - limit=NO_LIMIT, - ) - query_response = self.client.query_points( - collection_name=mt_collection, - query=vectors[0], - prefetch=prefetch_query, - limit=limit, - ) - - get_result = self._result_to_get_result(query_response.points) - return SearchResult( - ids=get_result.ids, - documents=get_result.documents, - 
metadatas=get_result.metadatas, - # qdrant distance is [-1, 1], normalize to [0, 1] - distances=[ - [(point.score + 1.0) / 2.0 for point in query_response.points] - ], - ) - except (UnexpectedResponse, grpc.RpcError) as e: - if self._is_collection_not_found_error(e): - log.debug( - f"Collection {mt_collection} doesn't exist, search returns None" - ) - return None - else: - # For other API errors, log and re-raise - _, error_msg = self._extract_error_message(e) - log.warning(f"Unexpected Qdrant error during search: {error_msg}") - raise - except Exception as e: - # For non-Qdrant exceptions, log and return None - log.exception(f"Error searching collection '{collection_name}': {e}") + if not self.client.collection_exists(collection_name=mt_collection): + log.debug(f"Collection {mt_collection} doesn't exist, search returns None") return None - def query(self, collection_name: str, filter: dict, limit: Optional[int] = None): + tenant_filter = _tenant_filter(tenant_id) + query_response = self.client.query_points( + collection_name=mt_collection, + query=vectors[0], + limit=limit, + query_filter=models.Filter(must=[tenant_filter]), + ) + get_result = self._result_to_get_result(query_response.points) + return SearchResult( + ids=get_result.ids, + documents=get_result.documents, + metadatas=get_result.metadatas, + distances=[[(point.score + 1.0) / 2.0 for point in query_response.points]], + ) + + def query( + self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None + ): """ Query points with filters and tenant isolation. 
""" if not self.client: return None - - # Map to multi-tenant collection and tenant ID mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name) - - # Set default limit if not provided + if not self.client.collection_exists(collection_name=mt_collection): + log.debug(f"Collection {mt_collection} doesn't exist, query returns None") + return None if limit is None: limit = NO_LIMIT - - # Create tenant filter - tenant_filter = models.FieldCondition( - key="tenant_id", match=models.MatchValue(value=tenant_id) - ) - - # Create metadata filters - field_conditions = [] - for key, value in filter.items(): - field_conditions.append( - models.FieldCondition( - key=f"metadata.{key}", match=models.MatchValue(value=value) - ) - ) - - # Combine tenant filter with metadata filters + tenant_filter = _tenant_filter(tenant_id) + field_conditions = [_metadata_filter(k, v) for k, v in filter.items()] combined_filter = models.Filter(must=[tenant_filter, *field_conditions]) - - try: - # Try the query directly - most of the time collection should exist - points = self.client.query_points( - collection_name=mt_collection, - query_filter=combined_filter, - limit=limit, - ) - - return self._result_to_get_result(points.points) - except (UnexpectedResponse, grpc.RpcError) as e: - if self._is_collection_not_found_error(e): - log.debug( - f"Collection {mt_collection} doesn't exist, query returns None" - ) - return None - else: - # For other API errors, log and re-raise - _, error_msg = self._extract_error_message(e) - log.warning(f"Unexpected Qdrant error during query: {error_msg}") - raise - except Exception as e: - # For non-Qdrant exceptions, log and re-raise - log.exception(f"Error querying collection '{collection_name}': {e}") - return None + points = self.client.scroll( + collection_name=mt_collection, + scroll_filter=combined_filter, + limit=limit, + ) + return self._result_to_get_result(points[0]) def get(self, collection_name: str) -> Optional[GetResult]: """ @@ 
-507,169 +313,36 @@ def get(self, collection_name: str) -> Optional[GetResult]: """ if not self.client: return None - - # Map to multi-tenant collection and tenant ID mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name) - - # Create tenant filter - tenant_filter = models.FieldCondition( - key="tenant_id", match=models.MatchValue(value=tenant_id) - ) - - try: - # Try to get points directly - most of the time collection should exist - points = self.client.query_points( - collection_name=mt_collection, - query_filter=models.Filter(must=[tenant_filter]), - limit=NO_LIMIT, - ) - - return self._result_to_get_result(points.points) - except (UnexpectedResponse, grpc.RpcError) as e: - if self._is_collection_not_found_error(e): - log.debug(f"Collection {mt_collection} doesn't exist, get returns None") - return None - else: - # For other API errors, log and re-raise - _, error_msg = self._extract_error_message(e) - log.warning(f"Unexpected Qdrant error during get: {error_msg}") - raise - except Exception as e: - # For non-Qdrant exceptions, log and return None - log.exception(f"Error getting collection '{collection_name}': {e}") + if not self.client.collection_exists(collection_name=mt_collection): + log.debug(f"Collection {mt_collection} doesn't exist, get returns None") return None - - def _handle_operation_with_error_retry( - self, operation_name, mt_collection, points, dimension - ): - """ - Private helper to handle common error cases for insert and upsert operations. 
- - Args: - operation_name: 'insert' or 'upsert' - mt_collection: The multi-tenant collection name - points: The vector points to insert/upsert - dimension: The dimension of the vectors - - Returns: - The operation result (for upsert) or None (for insert) - """ - try: - if operation_name == "insert": - self.client.upload_points(mt_collection, points) - return None - else: # upsert - return self.client.upsert(mt_collection, points) - except (UnexpectedResponse, grpc.RpcError) as e: - # Handle collection not found - if self._is_collection_not_found_error(e): - log.info( - f"Collection {mt_collection} doesn't exist. Creating it with dimension {dimension}." - ) - # Create collection with correct dimensions from our vectors - self._create_multi_tenant_collection_if_not_exists( - mt_collection_name=mt_collection, dimension=dimension - ) - # Try operation again - no need for dimension adjustment since we just created with correct dimensions - if operation_name == "insert": - self.client.upload_points(mt_collection, points) - return None - else: # upsert - return self.client.upsert(mt_collection, points) - - # Handle dimension mismatch - elif self._is_dimension_mismatch_error(e): - # For dimension errors, the collection must exist, so get its configuration - mt_collection_info = self.client.get_collection(mt_collection) - existing_size = mt_collection_info.config.params.vectors.size - - log.info( - f"Dimension mismatch: Collection {mt_collection} expects {existing_size}, got {dimension}" - ) - - if existing_size < dimension: - # Truncate vectors to fit - log.info( - f"Truncating vectors from {dimension} to {existing_size} dimensions" - ) - points = [ - PointStruct( - id=point.id, - vector=point.vector[:existing_size], - payload=point.payload, - ) - for point in points - ] - elif existing_size > dimension: - # Pad vectors with zeros - log.info( - f"Padding vectors from {dimension} to {existing_size} dimensions with zeros" - ) - points = [ - PointStruct( - id=point.id, - 
vector=point.vector - + [0] * (existing_size - len(point.vector)), - payload=point.payload, - ) - for point in points - ] - # Try operation again with adjusted dimensions - if operation_name == "insert": - self.client.upload_points(mt_collection, points) - return None - else: # upsert - return self.client.upsert(mt_collection, points) - else: - # Not a known error we can handle, log and re-raise - _, error_msg = self._extract_error_message(e) - log.warning(f"Unhandled Qdrant error: {error_msg}") - raise - except Exception as e: - # For non-Qdrant exceptions, re-raise - raise - - def insert(self, collection_name: str, items: list[VectorItem]): - """ - Insert items with tenant ID. - """ - if not self.client or not items: - return None - - # Map to multi-tenant collection and tenant ID - mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name) - - # Get dimensions from the actual vectors - dimension = len(items[0]["vector"]) if items else None - - # Create points with tenant ID - points = self._create_points(items, tenant_id) - - # Handle the operation with error retry - return self._handle_operation_with_error_retry( - "insert", mt_collection, points, dimension + tenant_filter = _tenant_filter(tenant_id) + points = self.client.scroll( + collection_name=mt_collection, + scroll_filter=models.Filter(must=[tenant_filter]), + limit=NO_LIMIT, ) + return self._result_to_get_result(points[0]) - def upsert(self, collection_name: str, items: list[VectorItem]): + def upsert(self, collection_name: str, items: List[VectorItem]): """ Upsert items with tenant ID. 
""" if not self.client or not items: return None - - # Map to multi-tenant collection and tenant ID mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name) - - # Get dimensions from the actual vectors - dimension = len(items[0]["vector"]) if items else None - - # Create points with tenant ID + dimension = len(items[0]["vector"]) + self._ensure_collection(mt_collection, dimension) points = self._create_points(items, tenant_id) + self.client.upload_points(mt_collection, points) + return None - # Handle the operation with error retry - return self._handle_operation_with_error_retry( - "upsert", mt_collection, points, dimension - ) + def insert(self, collection_name: str, items: List[VectorItem]): + """ + Insert items with tenant ID. + """ + return self.upsert(collection_name, items) def reset(self): """ @@ -677,11 +350,9 @@ def reset(self): """ if not self.client: return None - - collection_names = self.client.get_collections().collections - for collection_name in collection_names: - if collection_name.name.startswith(self.collection_prefix): - self.client.delete_collection(collection_name=collection_name.name) + for collection in self.client.get_collections().collections: + if collection.name.startswith(self.collection_prefix): + self.client.delete_collection(collection_name=collection.name) def delete_collection(self, collection_name: str): """ @@ -689,24 +360,13 @@ def delete_collection(self, collection_name: str): """ if not self.client: return None - - # Map to multi-tenant collection and tenant ID mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name) - - tenant_filter = models.FieldCondition( - key="tenant_id", match=models.MatchValue(value=tenant_id) - ) - - field_conditions = [tenant_filter] - - update_result = self.client.delete( + if not self.client.collection_exists(collection_name=mt_collection): + log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete") + return None + self.client.delete( 
from open_webui.retrieval.vector.utils import process_metadata
from open_webui.retrieval.vector.main import (
    VectorDBBase,
    VectorItem,
    GetResult,
    SearchResult,
)
from open_webui.config import S3_VECTOR_BUCKET_NAME, S3_VECTOR_REGION
from open_webui.env import SRC_LOG_LEVELS
from typing import List, Optional, Dict, Any, Union
import logging
import boto3

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


class S3VectorClient(VectorDBBase):
    """
    AWS S3 Vector integration for Open WebUI Knowledge.

    Every collection is stored as one index inside the vector bucket named by
    ``S3_VECTOR_BUCKET_NAME``. The document text is kept in the ``text``
    metadata key of each vector so it can be returned by ``get``/``search``.
    """

    # Number of vectors/keys sent per put_vectors / delete_vectors request and
    # page size requested from list_vectors.
    _BATCH_SIZE = 500

    # Maximum number of metadata keys stored with a single vector.
    _MAX_METADATA_KEYS = 10

    # Metadata keys that are kept first when metadata has to be trimmed.
    _IMPORTANT_METADATA_KEYS = (
        "text",  # The actual document content
        "file_id",  # File ID
        "source",  # Document source file
        "title",  # Document title
        "page",  # Page number
        "total_pages",  # Total pages in document
        "embedding_config",  # Embedding configuration
        "created_by",  # User who created it
        "name",  # Document name
        "hash",  # Content hash
    )

    def __init__(self):
        self.bucket_name = S3_VECTOR_BUCKET_NAME
        self.region = S3_VECTOR_REGION

        # Simple validation - log warnings instead of raising exceptions
        if not self.bucket_name:
            log.warning("S3_VECTOR_BUCKET_NAME not set - S3Vector will not work")
        if not self.region:
            log.warning("S3_VECTOR_REGION not set - S3Vector will not work")

        if self.bucket_name and self.region:
            try:
                self.client = boto3.client("s3vectors", region_name=self.region)
                log.info(
                    f"S3Vector client initialized for bucket '{self.bucket_name}' in region '{self.region}'"
                )
            except Exception as e:
                log.error(f"Failed to initialize S3Vector client: {e}")
                self.client = None
        else:
            self.client = None

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _aws_error_code(error: Exception) -> Optional[str]:
        """
        Return the AWS error code carried by ``error.response``, or None when
        the exception has no such payload.
        """
        response = getattr(error, "response", None)
        if isinstance(response, dict) and "Error" in response:
            return response["Error"].get("Code")
        return None

    @staticmethod
    def _document_text(vector_id: Any, metadata: Any) -> Any:
        """
        Pick the document text for a vector: the ``text`` metadata field first,
        then ``content``/``document``, and finally the vector key itself.
        """
        if not isinstance(metadata, dict):
            return vector_id
        return (
            metadata.get("text")
            or metadata.get("content")
            or metadata.get("document")
            or vector_id
        )

    @staticmethod
    def _empty_result() -> GetResult:
        """Return a GetResult that contains no vectors."""
        return GetResult(ids=[[]], documents=[[]], metadatas=[[]])

    def _iter_index_names(self):
        """
        Yield the name of every index in the bucket, following ``nextToken``
        so that indexes beyond the first response page are not missed.
        """
        next_token = None
        while True:
            params = {"vectorBucketName": self.bucket_name}
            if next_token:
                params["nextToken"] = next_token
            response = self.client.list_indexes(**params)

            for index in response.get("indexes", []):
                index_name = index.get("indexName")
                if index_name:
                    yield index_name

            next_token = response.get("nextToken")
            if not next_token:
                return

    def _create_index(
        self,
        index_name: str,
        dimension: int,
        data_type: str = "float32",
        distance_metric: str = "cosine",
    ) -> None:
        """
        Create a new index in the S3 vector bucket for the given collection if it does not exist.

        Raises whatever the underlying ``create_index`` call raises.
        """
        if self.has_collection(index_name):
            log.debug(f"Index '{index_name}' already exists, skipping creation")
            return

        try:
            self.client.create_index(
                vectorBucketName=self.bucket_name,
                indexName=index_name,
                dataType=data_type,
                dimension=dimension,
                distanceMetric=distance_metric,
            )
            log.info(
                f"Created S3 index: {index_name} (dim={dimension}, type={data_type}, metric={distance_metric})"
            )
        except Exception as e:
            log.error(f"Error creating S3 index '{index_name}': {e}")
            raise

    def _filter_metadata(
        self, metadata: Dict[str, Any], item_id: str
    ) -> Dict[str, Any]:
        """
        Filter vector metadata keys to comply with S3 Vector API limit of 10 keys maximum.

        Important keys are kept first (in priority order); remaining slots are
        filled with the other keys in their original order. Metadata that is
        not a dict or already fits is returned unchanged.
        """
        max_keys = self._MAX_METADATA_KEYS
        if not isinstance(metadata, dict) or len(metadata) <= max_keys:
            return metadata

        ordered_keys = [key for key in self._IMPORTANT_METADATA_KEYS if key in metadata]
        ordered_keys += [key for key in metadata if key not in ordered_keys]
        filtered_metadata = {key: metadata[key] for key in ordered_keys[:max_keys]}

        log.warning(
            f"Metadata for key '{item_id}' had {len(metadata)} keys, limited to 10 keys"
        )
        return filtered_metadata

    def _to_s3_vector(self, item: VectorItem) -> Dict[str, Any]:
        """Convert one Open WebUI vector item into the put_vectors payload shape."""
        vector_data = item["vector"]
        if isinstance(vector_data, list):
            # Convert list to float32 values as required by S3 Vector API
            vector_data = [float(x) for x in vector_data]

        # Copy so the caller's dict is never mutated; tolerate metadata=None.
        metadata = dict(item.get("metadata") or {})

        # Add the text field to metadata so it's available for retrieval
        metadata["text"] = item["text"]

        # Convert metadata to string format for consistency
        metadata = process_metadata(metadata)

        # Filter metadata to comply with S3 Vector API limit of 10 keys
        metadata = self._filter_metadata(metadata, item["id"])

        return {
            "key": item["id"],
            "data": {"float32": vector_data},
            "metadata": metadata,
        }

    def _put_items(
        self, collection_name: str, items: List[VectorItem], action: str
    ) -> None:
        """
        Shared implementation of ``insert`` and ``upsert``: create the index
        when missing, then write the items with put_vectors in batches.

        ``action`` is only used in log messages ("insert" or "upsert").
        """
        if not items:
            log.warning(f"No items to {action}")
            return

        # The index dimension is taken from the first item.
        dimension = len(items[0]["vector"])

        try:
            if not self.has_collection(collection_name):
                log.info(
                    f"Index '{collection_name}' does not exist. Creating index for {action}."
                )
                self._create_index(
                    index_name=collection_name,
                    dimension=dimension,
                    data_type="float32",
                    distance_metric="cosine",
                )

            vectors = [self._to_s3_vector(item) for item in items]

            batch_size = self._BATCH_SIZE
            for start in range(0, len(vectors), batch_size):
                batch = vectors[start : start + batch_size]
                self.client.put_vectors(
                    vectorBucketName=self.bucket_name,
                    indexName=collection_name,
                    vectors=batch,
                )
                log.info(
                    f"{action.capitalize()} batch {start // batch_size + 1}: {len(batch)} vectors into index '{collection_name}'."
                )

            log.info(
                f"Completed {action} of {len(vectors)} vectors into index '{collection_name}'."
            )
        except Exception as e:
            log.error(f"Error during {action} of vectors: {e}")
            raise

    def _delete_keys(self, collection_name: str, keys: List[str]) -> None:
        """Delete vectors by key, splitting the request into batches."""
        batch_size = self._BATCH_SIZE
        for start in range(0, len(keys), batch_size):
            self.client.delete_vectors(
                vectorBucketName=self.bucket_name,
                indexName=collection_name,
                keys=keys[start : start + batch_size],
            )

    # ------------------------------------------------------------------
    # VectorDBBase interface
    # ------------------------------------------------------------------

    def has_collection(self, collection_name: str) -> bool:
        """
        Check if a vector index (collection) exists in the S3 vector bucket.

        Returns False when the client is not initialized or listing fails.
        """
        if not self.client:
            return False

        try:
            return any(
                name == collection_name for name in self._iter_index_names()
            )
        except Exception as e:
            log.error(f"Error listing indexes: {e}")
            return False

    def delete_collection(self, collection_name: str) -> None:
        """
        Delete an entire S3 Vector index/collection.
        """

        if not self.has_collection(collection_name):
            log.warning(
                f"Collection '{collection_name}' does not exist, nothing to delete"
            )
            return

        try:
            log.info(f"Deleting collection '{collection_name}'")
            self.client.delete_index(
                vectorBucketName=self.bucket_name, indexName=collection_name
            )
            log.info(f"Successfully deleted collection '{collection_name}'")
        except Exception as e:
            log.error(f"Error deleting collection '{collection_name}': {e}")
            raise

    def insert(self, collection_name: str, items: List[VectorItem]) -> None:
        """
        Insert vector items into the S3 Vector index. Create index if it does not exist.
        """
        self._put_items(collection_name, items, "insert")

    def upsert(self, collection_name: str, items: List[VectorItem]) -> None:
        """
        Insert or update vector items in the S3 Vector index. Create index if it does not exist.
        """
        self._put_items(collection_name, items, "upsert")

    def search(
        self, collection_name: str, vectors: List[List[Union[float, int]]], limit: int
    ) -> Optional[SearchResult]:
        """
        Search for similar vectors in a collection using multiple query vectors.

        Returns one result list per query vector, or None when the collection
        does not exist, no vectors are given, or the service reports a
        not-found / validation / access-denied error. Other errors are re-raised.
        """

        if not self.has_collection(collection_name):
            log.warning(f"Collection '{collection_name}' does not exist")
            return None

        if not vectors:
            log.warning("No query vectors provided")
            return None

        try:
            log.info(
                f"Searching collection '{collection_name}' with {len(vectors)} query vectors, limit={limit}"
            )

            all_ids = []
            all_documents = []
            all_metadatas = []
            all_distances = []

            for i, query_vector in enumerate(vectors):
                log.debug(f"Processing query vector {i+1}/{len(vectors)}")

                response = self.client.query_vectors(
                    vectorBucketName=self.bucket_name,
                    indexName=collection_name,
                    topK=limit,
                    queryVector={"float32": [float(x) for x in query_vector]},
                    returnMetadata=True,
                    returnDistance=True,
                )

                query_ids = []
                query_documents = []
                query_metadatas = []
                query_distances = []

                for vector in response.get("vectors", []):
                    vector_id = vector.get("key")
                    vector_metadata = vector.get("metadata", {})

                    query_ids.append(vector_id)
                    query_documents.append(
                        self._document_text(vector_id, vector_metadata)
                    )
                    query_metadatas.append(vector_metadata)
                    query_distances.append(vector.get("distance", 0.0))

                all_ids.append(query_ids)
                all_documents.append(query_documents)
                all_metadatas.append(query_metadatas)
                all_distances.append(query_distances)

            log.info(f"Search completed. Found results for {len(all_ids)} queries")

            return SearchResult(
                ids=all_ids if all_ids else None,
                documents=all_documents if all_documents else None,
                metadatas=all_metadatas if all_metadatas else None,
                distances=all_distances if all_distances else None,
            )

        except Exception as e:
            log.error(f"Error searching collection '{collection_name}': {str(e)}")
            error_code = self._aws_error_code(e)
            if error_code == "NotFoundException":
                log.warning(f"Collection '{collection_name}' not found")
                return None
            if error_code == "ValidationException":
                log.error("Invalid query vector dimensions or parameters")
                return None
            if error_code == "AccessDeniedException":
                log.error(
                    f"Access denied for collection '{collection_name}'. Check permissions."
                )
                return None
            raise

    def query(
        self, collection_name: str, filter: Dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """
        Query vectors from a collection using metadata filter.

        All vectors are fetched with ``get`` and the filter is applied
        client-side. An empty filter returns every vector.
        """

        if not self.has_collection(collection_name):
            log.warning(f"Collection '{collection_name}' does not exist")
            return self._empty_result()

        if not filter:
            log.warning("No filter provided, returning all vectors")
            return self.get(collection_name)

        try:
            log.info(f"Querying collection '{collection_name}' with filter: {filter}")

            all_vectors_result = self.get(collection_name)

            if not all_vectors_result or not all_vectors_result.ids:
                log.warning("No vectors found in collection")
                return self._empty_result()

            all_ids = all_vectors_result.ids[0] if all_vectors_result.ids else []
            all_documents = (
                all_vectors_result.documents[0] if all_vectors_result.documents else []
            )
            all_metadatas = (
                all_vectors_result.metadatas[0] if all_vectors_result.metadatas else []
            )

            filtered_ids = []
            filtered_documents = []
            filtered_metadatas = []

            for i, metadata in enumerate(all_metadatas):
                if self._matches_filter(metadata, filter):
                    if i < len(all_ids):
                        filtered_ids.append(all_ids[i])
                    if i < len(all_documents):
                        filtered_documents.append(all_documents[i])
                    filtered_metadatas.append(metadata)

                    # Stop as soon as the requested number of matches is reached
                    if limit and len(filtered_ids) >= limit:
                        break

            log.info(
                f"Filter applied: {len(filtered_ids)} vectors match out of {len(all_ids)} total"
            )

            if filtered_ids:
                return GetResult(
                    ids=[filtered_ids],
                    documents=[filtered_documents],
                    metadatas=[filtered_metadatas],
                )
            return self._empty_result()

        except Exception as e:
            log.error(f"Error querying collection '{collection_name}': {str(e)}")
            error_code = self._aws_error_code(e)
            if error_code == "NotFoundException":
                log.warning(f"Collection '{collection_name}' not found")
                return self._empty_result()
            if error_code == "AccessDeniedException":
                log.error(
                    f"Access denied for collection '{collection_name}'. Check permissions."
                )
                return self._empty_result()
            raise

    def get(self, collection_name: str) -> Optional[GetResult]:
        """
        Retrieve all vectors from a collection.

        Pages through ``list_vectors`` (metadata only, no vector data) and
        returns ids, documents and metadatas each wrapped in an outer list.
        """

        if not self.has_collection(collection_name):
            log.warning(f"Collection '{collection_name}' does not exist")
            return self._empty_result()

        try:
            log.info(f"Retrieving all vectors from collection '{collection_name}'")

            all_ids = []
            all_documents = []
            all_metadatas = []

            next_token = None

            while True:
                request_params = {
                    "vectorBucketName": self.bucket_name,
                    "indexName": collection_name,
                    "returnData": False,  # Don't include vector data (not needed for get)
                    "returnMetadata": True,  # Include metadata
                    "maxResults": self._BATCH_SIZE,  # Use reasonable page size
                }
                if next_token:
                    request_params["nextToken"] = next_token

                response = self.client.list_vectors(**request_params)

                for vector in response.get("vectors", []):
                    vector_id = vector.get("key")
                    vector_metadata = vector.get("metadata", {})
                    document_text = self._document_text(vector_id, vector_metadata)

                    if isinstance(vector_metadata, dict):
                        # Log the actual content for debugging
                        log.debug(
                            f"Document text preview (first 200 chars): {str(document_text)[:200]}"
                        )

                    all_ids.append(vector_id)
                    all_documents.append(document_text)
                    all_metadatas.append(vector_metadata)

                # Check if there are more pages
                next_token = response.get("nextToken")
                if not next_token:
                    break

            log.info(
                f"Retrieved {len(all_ids)} vectors from collection '{collection_name}'"
            )

            # The Open WebUI GetResult expects lists of lists, so we wrap each list
            if all_ids:
                return GetResult(
                    ids=[all_ids], documents=[all_documents], metadatas=[all_metadatas]
                )
            return self._empty_result()

        except Exception as e:
            log.error(
                f"Error retrieving vectors from collection '{collection_name}': {str(e)}"
            )
            error_code = self._aws_error_code(e)
            if error_code == "NotFoundException":
                log.warning(f"Collection '{collection_name}' not found")
                return self._empty_result()
            if error_code == "AccessDeniedException":
                log.error(
                    f"Access denied for collection '{collection_name}'. Check permissions."
                )
                return self._empty_result()
            raise

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict] = None,
    ) -> None:
        """
        Delete vectors by ID or filter from a collection.

        ``ids`` takes precedence over ``filter``. Filter deletion is
        implemented as query-then-delete. When a ``file_id`` filter is applied
        to a collection whose name does not start with ``file-``, the matching
        ``file-<file_id>`` collection is deleted as well.
        """

        if not self.has_collection(collection_name):
            log.warning(
                f"Collection '{collection_name}' does not exist, nothing to delete"
            )
            return

        # Check if this is a knowledge collection (not file-specific)
        is_knowledge_collection = not collection_name.startswith("file-")

        try:
            if ids:
                log.info(
                    f"Deleting {len(ids)} vectors by IDs from collection '{collection_name}'"
                )
                self._delete_keys(collection_name, list(ids))
                log.info(f"Deleted {len(ids)} vectors from index '{collection_name}'")

            elif filter:
                log.info(
                    f"Deleting vectors by filter from collection '{collection_name}': {filter}"
                )

                # If this is a knowledge collection and we have a file_id filter,
                # also clean up the corresponding file-specific collection
                if is_knowledge_collection and "file_id" in filter:
                    file_collection_name = f"file-{filter['file_id']}"
                    if self.has_collection(file_collection_name):
                        log.info(
                            f"Found related file-specific collection '{file_collection_name}', deleting it to prevent duplicates"
                        )
                        self.delete_collection(file_collection_name)

                # Query first to get the keys matching the filter, then delete them
                query_result = self.query(collection_name, filter)
                if query_result and query_result.ids and query_result.ids[0]:
                    matching_ids = query_result.ids[0]
                    log.info(
                        f"Found {len(matching_ids)} vectors matching filter, deleting them"
                    )
                    self._delete_keys(collection_name, matching_ids)
                    log.info(
                        f"Deleted {len(matching_ids)} vectors from index '{collection_name}' using filter"
                    )
                else:
                    log.warning("No vectors found matching the filter criteria")
            else:
                log.warning("No IDs or filter provided for deletion")
        except Exception as e:
            log.error(
                f"Error deleting vectors from collection '{collection_name}': {e}"
            )
            raise

    def reset(self) -> None:
        """
        Reset/clear all vector data. For S3 Vector, this deletes all indexes.

        A failure to delete an individual index is logged and skipped; a
        failure to list the indexes is logged and re-raised.
        """

        try:
            log.warning(
                "Reset called - this will delete all vector indexes in the S3 bucket"
            )

            # Collect every page of names first so deletion does not interfere
            # with the listing that is still in progress.
            index_names = list(self._iter_index_names())

            if not index_names:
                log.warning("No indexes found to delete")
                return

            deleted_count = 0
            for index_name in index_names:
                try:
                    self.client.delete_index(
                        vectorBucketName=self.bucket_name, indexName=index_name
                    )
                    deleted_count += 1
                    log.info(f"Deleted index: {index_name}")
                except Exception as e:
                    log.error(f"Error deleting index '{index_name}': {e}")

            log.info(f"Reset completed: deleted {deleted_count} indexes")

        except Exception as e:
            log.error(f"Error during reset: {e}")
            raise

    def _matches_filter(self, metadata: Dict[str, Any], filter: Dict[str, Any]) -> bool:
        """
        Check if metadata matches the given filter conditions.

        Supports plain equality, the per-field operators ``$eq``, ``$ne``,
        ``$in``, ``$nin`` and ``$exists``, and the top-level combinators
        ``$and`` / ``$or`` (lists of nested filters). Unknown ``$`` keys and
        unknown per-field operators are ignored.
        """
        if not isinstance(metadata, dict) or not isinstance(filter, dict):
            return False

        for key, expected_value in filter.items():
            # Top-level combinators
            if key.startswith("$"):
                if isinstance(expected_value, list):
                    if key == "$and" and not all(
                        self._matches_filter(metadata, condition)
                        for condition in expected_value
                    ):
                        return False
                    if key == "$or" and not any(
                        self._matches_filter(metadata, condition)
                        for condition in expected_value
                    ):
                        return False
                continue

            actual_value = metadata.get(key)

            if isinstance(expected_value, dict):
                # Per-field comparison operators
                for op, op_value in expected_value.items():
                    if op == "$eq":
                        if actual_value != op_value:
                            return False
                    elif op == "$ne":
                        if actual_value == op_value:
                            return False
                    elif op == "$in":
                        if (
                            not isinstance(op_value, list)
                            or actual_value not in op_value
                        ):
                            return False
                    elif op == "$nin":
                        if isinstance(op_value, list) and actual_value in op_value:
                            return False
                    elif op == "$exists":
                        if bool(op_value) != (key in metadata):
                            return False
            elif actual_value != expected_value:
                # Simple equality check
                return False

        return True
import ( + VECTOR_DB, + ENABLE_QDRANT_MULTITENANCY_MODE, + ENABLE_MILVUS_MULTITENANCY_MODE, +) class Vector: @@ -12,9 +16,16 @@ def get_vector(vector_type: str) -> VectorDBBase: """ match vector_type: case VectorType.MILVUS: - from open_webui.retrieval.vector.dbs.milvus import MilvusClient + if ENABLE_MILVUS_MULTITENANCY_MODE: + from open_webui.retrieval.vector.dbs.milvus_multitenancy import ( + MilvusClient, + ) + + return MilvusClient() + else: + from open_webui.retrieval.vector.dbs.milvus import MilvusClient - return MilvusClient() + return MilvusClient() case VectorType.QDRANT: if ENABLE_QDRANT_MULTITENANCY_MODE: from open_webui.retrieval.vector.dbs.qdrant_multitenancy import ( @@ -30,6 +41,10 @@ def get_vector(vector_type: str) -> VectorDBBase: from open_webui.retrieval.vector.dbs.pinecone import PineconeClient return PineconeClient() + case VectorType.S3VECTOR: + from open_webui.retrieval.vector.dbs.s3vector import S3VectorClient + + return S3VectorClient() case VectorType.OPENSEARCH: from open_webui.retrieval.vector.dbs.opensearch import OpenSearchClient @@ -48,6 +63,10 @@ def get_vector(vector_type: str) -> VectorDBBase: from open_webui.retrieval.vector.dbs.chroma import ChromaClient return ChromaClient() + case VectorType.ORACLE23AI: + from open_webui.retrieval.vector.dbs.oracle23ai import Oracle23aiClient + + return Oracle23aiClient() case _: raise ValueError(f"Unsupported vector type: {vector_type}") diff --git a/backend/open_webui/retrieval/vector/type.py b/backend/open_webui/retrieval/vector/type.py index b03bcb48289..7e517c169cd 100644 --- a/backend/open_webui/retrieval/vector/type.py +++ b/backend/open_webui/retrieval/vector/type.py @@ -9,3 +9,5 @@ class VectorType(StrEnum): ELASTICSEARCH = "elasticsearch" OPENSEARCH = "opensearch" PGVECTOR = "pgvector" + ORACLE23AI = "oracle23ai" + S3VECTOR = "s3vector" diff --git a/backend/open_webui/retrieval/vector/utils.py b/backend/open_webui/retrieval/vector/utils.py new file mode 100644 index 
00000000000..a597390b920 --- /dev/null +++ b/backend/open_webui/retrieval/vector/utils.py @@ -0,0 +1,28 @@ +from datetime import datetime + +KEYS_TO_EXCLUDE = ["content", "pages", "tables", "paragraphs", "sections", "figures"] + + +def filter_metadata(metadata: dict[str, any]) -> dict[str, any]: + metadata = { + key: value for key, value in metadata.items() if key not in KEYS_TO_EXCLUDE + } + return metadata + + +def process_metadata( + metadata: dict[str, any], +) -> dict[str, any]: + for key, value in metadata.items(): + # Remove large fields + if key in KEYS_TO_EXCLUDE: + del metadata[key] + + # Convert non-serializable fields to strings + if ( + isinstance(value, datetime) + or isinstance(value, list) + or isinstance(value, dict) + ): + metadata[key] = str(value) + return metadata diff --git a/backend/open_webui/retrieval/web/brave.py b/backend/open_webui/retrieval/web/brave.py index 3075db990f5..7bea5756203 100644 --- a/backend/open_webui/retrieval/web/brave.py +++ b/backend/open_webui/retrieval/web/brave.py @@ -36,7 +36,9 @@ def search_brave( return [ SearchResult( - link=result["url"], title=result.get("title"), snippet=result.get("snippet") + link=result["url"], + title=result.get("title"), + snippet=result.get("description"), ) for result in results[:count] ] diff --git a/backend/open_webui/retrieval/web/duckduckgo.py b/backend/open_webui/retrieval/web/duckduckgo.py index bf8ae6880bd..e4cf9d00ec7 100644 --- a/backend/open_webui/retrieval/web/duckduckgo.py +++ b/backend/open_webui/retrieval/web/duckduckgo.py @@ -2,8 +2,8 @@ from typing import Optional from open_webui.retrieval.web.main import SearchResult, get_filtered_results -from duckduckgo_search import DDGS -from duckduckgo_search.exceptions import RatelimitException +from ddgs import DDGS +from ddgs.exceptions import RatelimitException from open_webui.env import SRC_LOG_LEVELS log = logging.getLogger(__name__) @@ -11,7 +11,10 @@ def search_duckduckgo( - query: str, count: int, filter_list: 
Optional[list[str]] = None + query: str, + count: int, + filter_list: Optional[list[str]] = None, + concurrent_requests: Optional[int] = None, ) -> list[SearchResult]: """ Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects. @@ -25,6 +28,9 @@ def search_duckduckgo( # Use the DDGS context manager to create a DDGS object search_results = [] with DDGS() as ddgs: + if concurrent_requests: + ddgs.threads = concurrent_requests + # Use the ddgs.text() method to perform the search try: search_results = ddgs.text( diff --git a/backend/open_webui/retrieval/web/main.py b/backend/open_webui/retrieval/web/main.py index 28a749e7d2e..dc1eafb3317 100644 --- a/backend/open_webui/retrieval/web/main.py +++ b/backend/open_webui/retrieval/web/main.py @@ -11,7 +11,7 @@ def get_filtered_results(results, filter_list): return results filtered_results = [] for result in results: - url = result.get("url") or result.get("link", "") + url = result.get("url") or result.get("link", "") or result.get("href", "") if not validators.url(url): continue domain = urlparse(url).netloc diff --git a/backend/open_webui/retrieval/web/ollama.py b/backend/open_webui/retrieval/web/ollama.py new file mode 100644 index 00000000000..a199a14389b --- /dev/null +++ b/backend/open_webui/retrieval/web/ollama.py @@ -0,0 +1,51 @@ +import logging +from dataclasses import dataclass +from typing import Optional + +import requests +from open_webui.env import SRC_LOG_LEVELS +from open_webui.retrieval.web.main import SearchResult + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + + +def search_ollama_cloud( + url: str, + api_key: str, + query: str, + count: int, + filter_list: Optional[list[str]] = None, +) -> list[SearchResult]: + """Search using Ollama Search API and return the results as a list of SearchResult objects. 
+ + Args: + api_key (str): A Ollama Search API key + query (str): The query to search for + count (int): Number of results to return + filter_list (Optional[list[str]]): List of domains to filter results by + """ + log.info(f"Searching with Ollama for query: {query}") + + headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} + payload = {"query": query, "max_results": count} + + try: + response = requests.post(f"{url}/api/web_search", headers=headers, json=payload) + response.raise_for_status() + data = response.json() + + results = data.get("results", []) + log.info(f"Found {len(results)} results") + + return [ + SearchResult( + link=result.get("url", ""), + title=result.get("title", ""), + snippet=result.get("content", ""), + ) + for result in results + ] + except Exception as e: + log.error(f"Error searching Ollama: {e}") + return [] diff --git a/backend/open_webui/retrieval/web/perplexity.py b/backend/open_webui/retrieval/web/perplexity.py index e5314eb1f73..4e046668fa0 100644 --- a/backend/open_webui/retrieval/web/perplexity.py +++ b/backend/open_webui/retrieval/web/perplexity.py @@ -1,10 +1,20 @@ import logging -from typing import Optional, List +from typing import Optional, Literal import requests from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.env import SRC_LOG_LEVELS +MODELS = Literal[ + "sonar", + "sonar-pro", + "sonar-reasoning", + "sonar-reasoning-pro", + "sonar-deep-research", +] +SEARCH_CONTEXT_USAGE_LEVELS = Literal["low", "medium", "high"] + + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["RAG"]) @@ -14,6 +24,8 @@ def search_perplexity( query: str, count: int, filter_list: Optional[list[str]] = None, + model: MODELS = "sonar", + search_context_usage: SEARCH_CONTEXT_USAGE_LEVELS = "medium", ) -> list[SearchResult]: """Search using Perplexity API and return the results as a list of SearchResult objects. 
@@ -21,6 +33,9 @@ def search_perplexity( api_key (str): A Perplexity API key query (str): The query to search for count (int): Maximum number of results to return + filter_list (Optional[list[str]]): List of domains to filter results + model (str): The Perplexity model to use (sonar, sonar-pro) + search_context_usage (str): Search context usage level (low, medium, high) """ @@ -33,7 +48,7 @@ def search_perplexity( # Create payload for the API call payload = { - "model": "sonar", + "model": model, "messages": [ { "role": "system", @@ -43,6 +58,9 @@ def search_perplexity( ], "temperature": 0.2, # Lower temperature for more factual responses "stream": False, + "web_search_options": { + "search_context_usage": search_context_usage, + }, } headers = { diff --git a/backend/open_webui/retrieval/web/perplexity_search.py b/backend/open_webui/retrieval/web/perplexity_search.py new file mode 100644 index 00000000000..e3e0caa2b39 --- /dev/null +++ b/backend/open_webui/retrieval/web/perplexity_search.py @@ -0,0 +1,64 @@ +import logging +from typing import Optional, Literal +import requests + +from open_webui.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.env import SRC_LOG_LEVELS + + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["RAG"]) + + +def search_perplexity_search( + api_key: str, + query: str, + count: int, + filter_list: Optional[list[str]] = None, +) -> list[SearchResult]: + """Search using Perplexity API and return the results as a list of SearchResult objects. 
+ + Args: + api_key (str): A Perplexity API key + query (str): The query to search for + count (int): Maximum number of results to return + filter_list (Optional[list[str]]): List of domains to filter results + + """ + + # Handle PersistentConfig object + if hasattr(api_key, "__str__"): + api_key = str(api_key) + + try: + url = "https://api.perplexity.ai/search" + + # Create payload for the API call + payload = { + "query": query, + "max_results": count, + } + + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + } + + # Make the API request + response = requests.request("POST", url, json=payload, headers=headers) + # Parse the JSON response + json_response = response.json() + + # Extract citations from the response + results = json_response.get("results", []) + + return [ + SearchResult( + link=result["url"], title=result["title"], snippet=result["snippet"] + ) + for result in results + ] + + except Exception as e: + log.error(f"Error searching with Perplexity Search API: {e}") + return [] diff --git a/backend/open_webui/retrieval/web/searchapi.py b/backend/open_webui/retrieval/web/searchapi.py index 38bc0b5742e..d7704638c2b 100644 --- a/backend/open_webui/retrieval/web/searchapi.py +++ b/backend/open_webui/retrieval/web/searchapi.py @@ -42,7 +42,9 @@ def search_searchapi( results = get_filtered_results(results, filter_list) return [ SearchResult( - link=result["link"], title=result["title"], snippet=result["snippet"] + link=result["link"], + title=result.get("title"), + snippet=result.get("snippet"), ) for result in results[:count] ] diff --git a/backend/open_webui/retrieval/web/serpapi.py b/backend/open_webui/retrieval/web/serpapi.py index 028b6bcfe1f..8762210bfd2 100644 --- a/backend/open_webui/retrieval/web/serpapi.py +++ b/backend/open_webui/retrieval/web/serpapi.py @@ -42,7 +42,9 @@ def search_serpapi( results = get_filtered_results(results, filter_list) return [ SearchResult( - link=result["link"], 
title=result["title"], snippet=result["snippet"] + link=result["link"], + title=result.get("title"), + snippet=result.get("snippet"), ) for result in results[:count] ] diff --git a/backend/open_webui/retrieval/web/utils.py b/backend/open_webui/retrieval/web/utils.py index b8ec538d3b5..61356adb569 100644 --- a/backend/open_webui/retrieval/web/utils.py +++ b/backend/open_webui/retrieval/web/utils.py @@ -75,7 +75,8 @@ def safe_validate_urls(url: Sequence[str]) -> Sequence[str]: try: if validate_url(u): valid_urls.append(u) - except ValueError: + except Exception as e: + log.debug(f"Invalid URL {u}: {str(e)}") continue return valid_urls @@ -517,7 +518,7 @@ async def _fetch( async with session.get( url, **(self.requests_kwargs | kwargs), - ssl=AIOHTTP_CLIENT_SESSION_SSL, + allow_redirects=False, ) as response: if self.raise_for_status: response.raise_for_status() @@ -615,7 +616,7 @@ def get_web_loader( WebLoaderClass = SafeWebBaseLoader if WEB_LOADER_ENGINE.value == "playwright": WebLoaderClass = SafePlaywrightURLLoader - web_loader_args["playwright_timeout"] = PLAYWRIGHT_TIMEOUT.value * 1000 + web_loader_args["playwright_timeout"] = PLAYWRIGHT_TIMEOUT.value if PLAYWRIGHT_WS_URL.value: web_loader_args["playwright_ws_url"] = PLAYWRIGHT_WS_URL.value diff --git a/backend/open_webui/routers/audio.py b/backend/open_webui/routers/audio.py index a0f5af4fc4f..cb7a57b5b7e 100644 --- a/backend/open_webui/routers/audio.py +++ b/backend/open_webui/routers/audio.py @@ -3,21 +3,25 @@ import logging import os import uuid +import html from functools import lru_cache -from pathlib import Path from pydub import AudioSegment from pydub.silence import split_on_silence from concurrent.futures import ThreadPoolExecutor +from typing import Optional +from fnmatch import fnmatch import aiohttp import aiofiles import requests import mimetypes +from urllib.parse import urljoin, quote from fastapi import ( Depends, FastAPI, File, + Form, HTTPException, Request, UploadFile, @@ -93,12 +97,9 @@ def 
is_audio_conversion_required(file_path): # File is AAC/mp4a audio, recommend mp3 conversion return True - # If the codec name or file extension is in the supported formats - if ( - codec_name in SUPPORTED_FORMATS - or os.path.splitext(file_path)[1][1:].lower() in SUPPORTED_FORMATS - ): - return False # Already supported + # If the codec name is in the supported formats + if codec_name in SUPPORTED_FORMATS: + return False return True except Exception as e: @@ -153,6 +154,7 @@ def set_faster_whisper_model(model: str, auto_update: bool = False): class TTSConfigForm(BaseModel): OPENAI_API_BASE_URL: str OPENAI_API_KEY: str + OPENAI_PARAMS: Optional[dict] = None API_KEY: str ENGINE: str MODEL: str @@ -168,6 +170,7 @@ class STTConfigForm(BaseModel): OPENAI_API_KEY: str ENGINE: str MODEL: str + SUPPORTED_CONTENT_TYPES: list[str] = [] WHISPER_MODEL: str DEEPGRAM_API_KEY: str AZURE_API_KEY: str @@ -188,6 +191,7 @@ async def get_audio_config(request: Request, user=Depends(get_admin_user)): "tts": { "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL, "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY, + "OPENAI_PARAMS": request.app.state.config.TTS_OPENAI_PARAMS, "API_KEY": request.app.state.config.TTS_API_KEY, "ENGINE": request.app.state.config.TTS_ENGINE, "MODEL": request.app.state.config.TTS_MODEL, @@ -202,6 +206,7 @@ async def get_audio_config(request: Request, user=Depends(get_admin_user)): "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY, "ENGINE": request.app.state.config.STT_ENGINE, "MODEL": request.app.state.config.STT_MODEL, + "SUPPORTED_CONTENT_TYPES": request.app.state.config.STT_SUPPORTED_CONTENT_TYPES, "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL, "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY, "AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY, @@ -219,6 +224,7 @@ async def update_audio_config( ): request.app.state.config.TTS_OPENAI_API_BASE_URL = 
form_data.tts.OPENAI_API_BASE_URL request.app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY + request.app.state.config.TTS_OPENAI_PARAMS = form_data.tts.OPENAI_PARAMS request.app.state.config.TTS_API_KEY = form_data.tts.API_KEY request.app.state.config.TTS_ENGINE = form_data.tts.ENGINE request.app.state.config.TTS_MODEL = form_data.tts.MODEL @@ -236,6 +242,10 @@ async def update_audio_config( request.app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY request.app.state.config.STT_ENGINE = form_data.stt.ENGINE request.app.state.config.STT_MODEL = form_data.stt.MODEL + request.app.state.config.STT_SUPPORTED_CONTENT_TYPES = ( + form_data.stt.SUPPORTED_CONTENT_TYPES + ) + request.app.state.config.WHISPER_MODEL = form_data.stt.WHISPER_MODEL request.app.state.config.DEEPGRAM_API_KEY = form_data.stt.DEEPGRAM_API_KEY request.app.state.config.AUDIO_STT_AZURE_API_KEY = form_data.stt.AZURE_API_KEY @@ -250,15 +260,18 @@ async def update_audio_config( request.app.state.faster_whisper_model = set_faster_whisper_model( form_data.stt.WHISPER_MODEL, WHISPER_MODEL_AUTO_UPDATE ) + else: + request.app.state.faster_whisper_model = None return { "tts": { - "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL, - "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY, - "API_KEY": request.app.state.config.TTS_API_KEY, "ENGINE": request.app.state.config.TTS_ENGINE, "MODEL": request.app.state.config.TTS_MODEL, "VOICE": request.app.state.config.TTS_VOICE, + "OPENAI_API_BASE_URL": request.app.state.config.TTS_OPENAI_API_BASE_URL, + "OPENAI_API_KEY": request.app.state.config.TTS_OPENAI_API_KEY, + "OPENAI_PARAMS": request.app.state.config.TTS_OPENAI_PARAMS, + "API_KEY": request.app.state.config.TTS_API_KEY, "SPLIT_ON": request.app.state.config.TTS_SPLIT_ON, "AZURE_SPEECH_REGION": request.app.state.config.TTS_AZURE_SPEECH_REGION, "AZURE_SPEECH_BASE_URL": request.app.state.config.TTS_AZURE_SPEECH_BASE_URL, @@ -269,6 +282,7 @@ async def 
update_audio_config( "OPENAI_API_KEY": request.app.state.config.STT_OPENAI_API_KEY, "ENGINE": request.app.state.config.STT_ENGINE, "MODEL": request.app.state.config.STT_MODEL, + "SUPPORTED_CONTENT_TYPES": request.app.state.config.STT_SUPPORTED_CONTENT_TYPES, "WHISPER_MODEL": request.app.state.config.WHISPER_MODEL, "DEEPGRAM_API_KEY": request.app.state.config.DEEPGRAM_API_KEY, "AZURE_API_KEY": request.app.state.config.AUDIO_STT_AZURE_API_KEY, @@ -318,6 +332,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): log.exception(e) raise HTTPException(status_code=400, detail="Invalid JSON payload") + r = None if request.app.state.config.TTS_ENGINE == "openai": payload["model"] = request.app.state.config.TTS_MODEL @@ -326,7 +341,12 @@ async def speech(request: Request, user=Depends(get_verified_user)): async with aiohttp.ClientSession( timeout=timeout, trust_env=True ) as session: - async with session.post( + payload = { + **payload, + **(request.app.state.config.TTS_OPENAI_PARAMS or {}), + } + + r = await session.post( url=f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech", json=payload, headers={ @@ -334,7 +354,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): "Authorization": f"Bearer {request.app.state.config.TTS_OPENAI_API_KEY}", **( { - "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Name": quote(user.name, safe=" "), "X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Role": user.role, @@ -344,14 +364,15 @@ async def speech(request: Request, user=Depends(get_verified_user)): ), }, ssl=AIOHTTP_CLIENT_SESSION_SSL, - ) as r: - r.raise_for_status() + ) - async with aiofiles.open(file_path, "wb") as f: - await f.write(await r.read()) + r.raise_for_status() - async with aiofiles.open(file_body_path, "w") as f: - await f.write(json.dumps(payload)) + async with aiofiles.open(file_path, "wb") as f: + await f.write(await r.read()) + + async with 
aiofiles.open(file_body_path, "w") as f: + await f.write(json.dumps(payload)) return FileResponse(file_path) @@ -359,18 +380,22 @@ async def speech(request: Request, user=Depends(get_verified_user)): log.exception(e) detail = None - try: - if r.status != 200: - res = await r.json() + status_code = 500 + detail = f"Open WebUI: Server Connection Error" + + if r is not None: + status_code = r.status + try: + res = await r.json() if "error" in res: - detail = f"External: {res['error'].get('message', '')}" - except Exception: - detail = f"External: {e}" + detail = f"External: {res['error']}" + except Exception: + detail = f"External: {e}" raise HTTPException( - status_code=getattr(r, "status", 500) if r else 500, - detail=detail if detail else "Open WebUI: Server Connection Error", + status_code=status_code, + detail=detail, ) elif request.app.state.config.TTS_ENGINE == "elevenlabs": @@ -443,7 +468,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): try: data = f""" - {payload["input"]} + {html.escape(payload["input"])} """ timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) async with aiohttp.ClientSession( @@ -527,11 +552,18 @@ async def speech(request: Request, user=Depends(get_verified_user)): return FileResponse(file_path) -def transcription_handler(request, file_path): +def transcription_handler(request, file_path, metadata): filename = os.path.basename(file_path) file_dir = os.path.dirname(file_path) id = filename.split(".")[0] + metadata = metadata or {} + + languages = [ + metadata.get("language", None) if not WHISPER_LANGUAGE else WHISPER_LANGUAGE, + None, # Always fallback to None in case transcription fails + ] + if request.app.state.config.STT_ENGINE == "": if request.app.state.faster_whisper_model is None: request.app.state.faster_whisper_model = set_faster_whisper_model( @@ -543,7 +575,7 @@ def transcription_handler(request, file_path): file_path, beam_size=5, vad_filter=request.app.state.config.WHISPER_VAD_FILTER, - 
language=WHISPER_LANGUAGE, + language=languages[0], ) log.info( "Detected language '%s' with probability %f" @@ -563,14 +595,26 @@ def transcription_handler(request, file_path): elif request.app.state.config.STT_ENGINE == "openai": r = None try: - r = requests.post( - url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions", - headers={ - "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}" - }, - files={"file": (filename, open(file_path, "rb"))}, - data={"model": request.app.state.config.STT_MODEL}, - ) + for language in languages: + payload = { + "model": request.app.state.config.STT_MODEL, + } + + if language: + payload["language"] = language + + r = requests.post( + url=f"{request.app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions", + headers={ + "Authorization": f"Bearer {request.app.state.config.STT_OPENAI_API_KEY}" + }, + files={"file": (filename, open(file_path, "rb"))}, + data=payload, + ) + + if r.status_code == 200: + # Successful transcription + break r.raise_for_status() data = r.json() @@ -612,18 +656,26 @@ def transcription_handler(request, file_path): "Content-Type": mime, } - # Add model if specified - params = {} - if request.app.state.config.STT_MODEL: - params["model"] = request.app.state.config.STT_MODEL - - # Make request to Deepgram API - r = requests.post( - "https://api.deepgram.com/v1/listen", - headers=headers, - params=params, - data=file_data, - ) + for language in languages: + params = {} + if request.app.state.config.STT_MODEL: + params["model"] = request.app.state.config.STT_MODEL + + if language: + params["language"] = language + + # Make request to Deepgram API + r = requests.post( + "https://api.deepgram.com/v1/listen?smart_format=true", + headers=headers, + params=params, + data=file_data, + ) + + if r.status_code == 200: + # Successful transcription + break + r.raise_for_status() response_data = r.json() @@ -777,8 +829,8 @@ def transcription_handler(request, file_path): ) 
-def transcribe(request: Request, file_path): - log.info(f"transcribe: {file_path}") +def transcribe(request: Request, file_path: str, metadata: Optional[dict] = None): + log.info(f"transcribe: {file_path} {metadata}") if is_audio_conversion_required(file_path): file_path = convert_audio_to_mp3(file_path) @@ -804,7 +856,7 @@ def transcribe(request: Request, file_path): with ThreadPoolExecutor() as executor: # Submit tasks for each chunk_path futures = [ - executor.submit(transcription_handler, request, chunk_path) + executor.submit(transcription_handler, request, chunk_path, metadata) for chunk_path in chunk_paths ] # Gather results as they complete @@ -812,10 +864,9 @@ def transcribe(request: Request, file_path): try: results.append(future.result()) except Exception as transcribe_exc: - log.exception(f"Error transcribing chunk: {transcribe_exc}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="Error during transcription.", + detail=f"Error transcribing chunk: {transcribe_exc}", ) finally: # Clean up only the temporary chunks, never the original file @@ -897,14 +948,23 @@ def split_audio(file_path, max_bytes, format="mp3", bitrate="32k"): def transcription( request: Request, file: UploadFile = File(...), + language: Optional[str] = Form(None), user=Depends(get_verified_user), ): log.info(f"file.content_type: {file.content_type}") - SUPPORTED_CONTENT_TYPES = {"video/webm"} # Extend if you add more video types! 
- if not ( - file.content_type.startswith("audio/") - or file.content_type in SUPPORTED_CONTENT_TYPES + stt_supported_content_types = getattr( + request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", [] + ) + + if not any( + fnmatch(file.content_type, content_type) + for content_type in ( + stt_supported_content_types + if stt_supported_content_types + and any(t.strip() for t in stt_supported_content_types) + else ["audio/*", "video/webm"] + ) ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -926,7 +986,12 @@ def transcription( f.write(contents) try: - result = transcribe(request, file_path) + metadata = None + + if language: + metadata = {"language": language} + + result = transcribe(request, file_path, metadata) return { **result, diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py index 793bdfd30a2..e3271250c14 100644 --- a/backend/open_webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -15,19 +15,22 @@ SigninResponse, SignupForm, UpdatePasswordForm, - UpdateProfileForm, UserResponse, ) -from open_webui.models.users import Users +from open_webui.models.users import Users, UpdateProfileForm +from open_webui.models.groups import Groups +from open_webui.models.oauth_sessions import OAuthSessions from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES from open_webui.env import ( WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_NAME_HEADER, + WEBUI_AUTH_TRUSTED_GROUPS_HEADER, WEBUI_AUTH_COOKIE_SAME_SITE, WEBUI_AUTH_COOKIE_SECURE, WEBUI_AUTH_SIGNOUT_REDIRECT_URL, + ENABLE_INITIAL_ADMIN_SIGNUP, SRC_LOG_LEVELS, ) from fastapi import APIRouter, Depends, HTTPException, Request, status @@ -53,9 +56,8 @@ from ssl import CERT_NONE, CERT_REQUIRED, PROTOCOL_TLS -if ENABLE_LDAP.value: - from ldap3 import Server, Connection, NONE, Tls - from ldap3.utils.conv import escape_filter_chars +from ldap3 import Server, Connection, NONE, Tls +from ldap3.utils.conv import escape_filter_chars 
router = APIRouter() @@ -72,7 +74,13 @@ class SessionUserResponse(Token, UserResponse): permissions: Optional[dict] = None -@router.get("/", response_model=SessionUserResponse) +class SessionUserInfoResponse(SessionUserResponse): + bio: Optional[str] = None + gender: Optional[str] = None + date_of_birth: Optional[datetime.date] = None + + +@router.get("/", response_model=SessionUserInfoResponse) async def get_session_user( request: Request, response: Response, user=Depends(get_current_user) ): @@ -120,6 +128,9 @@ async def get_session_user( "name": user.name, "role": user.role, "profile_image_url": user.profile_image_url, + "bio": user.bio, + "gender": user.gender, + "date_of_birth": user.date_of_birth, "permissions": user_permissions, } @@ -136,7 +147,7 @@ async def update_profile( if session_user: user = Users.update_user_by_id( session_user.id, - {"profile_image_url": form_data.profile_image_url, "name": form_data.name}, + form_data.model_dump(), ) if user: return user @@ -227,14 +238,30 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): if not connection_app.bind(): raise HTTPException(400, detail="Application account bind failed") + ENABLE_LDAP_GROUP_MANAGEMENT = ( + request.app.state.config.ENABLE_LDAP_GROUP_MANAGEMENT + ) + ENABLE_LDAP_GROUP_CREATION = request.app.state.config.ENABLE_LDAP_GROUP_CREATION + LDAP_ATTRIBUTE_FOR_GROUPS = request.app.state.config.LDAP_ATTRIBUTE_FOR_GROUPS + + search_attributes = [ + f"{LDAP_ATTRIBUTE_FOR_USERNAME}", + f"{LDAP_ATTRIBUTE_FOR_MAIL}", + "cn", + ] + + if ENABLE_LDAP_GROUP_MANAGEMENT: + search_attributes.append(f"{LDAP_ATTRIBUTE_FOR_GROUPS}") + log.info( + f"LDAP Group Management enabled. 
Adding {LDAP_ATTRIBUTE_FOR_GROUPS} to search attributes" + ) + + log.info(f"LDAP search attributes: {search_attributes}") + search_success = connection_app.search( search_base=LDAP_SEARCH_BASE, search_filter=f"(&({LDAP_ATTRIBUTE_FOR_USERNAME}={escape_filter_chars(form_data.user.lower())}){LDAP_SEARCH_FILTERS})", - attributes=[ - f"{LDAP_ATTRIBUTE_FOR_USERNAME}", - f"{LDAP_ATTRIBUTE_FOR_MAIL}", - "cn", - ], + attributes=search_attributes, ) if not search_success or not connection_app.entries: @@ -257,6 +284,69 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): cn = str(entry["cn"]) user_dn = entry.entry_dn + user_groups = [] + if ENABLE_LDAP_GROUP_MANAGEMENT and LDAP_ATTRIBUTE_FOR_GROUPS in entry: + group_dns = entry[LDAP_ATTRIBUTE_FOR_GROUPS] + log.info(f"LDAP raw group DNs for user {username}: {group_dns}") + + if group_dns: + log.info(f"LDAP group_dns original: {group_dns}") + log.info(f"LDAP group_dns type: {type(group_dns)}") + log.info(f"LDAP group_dns length: {len(group_dns)}") + + if hasattr(group_dns, "value"): + group_dns = group_dns.value + log.info(f"Extracted .value property: {group_dns}") + elif hasattr(group_dns, "__iter__") and not isinstance( + group_dns, (str, bytes) + ): + group_dns = list(group_dns) + log.info(f"Converted to list: {group_dns}") + + if isinstance(group_dns, list): + group_dns = [str(item) for item in group_dns] + else: + group_dns = [str(group_dns)] + + log.info( + f"LDAP group_dns after processing - type: {type(group_dns)}, length: {len(group_dns)}" + ) + + for group_idx, group_dn in enumerate(group_dns): + group_dn = str(group_dn) + log.info(f"Processing group DN #{group_idx + 1}: {group_dn}") + + try: + group_cn = None + + for item in group_dn.split(","): + item = item.strip() + if item.upper().startswith("CN="): + group_cn = item[3:] + break + + if group_cn: + user_groups.append(group_cn) + + else: + log.warning( + f"Could not extract CN from group DN: {group_dn}" + ) + except Exception as e: + 
log.warning( + f"Failed to extract group name from DN {group_dn}: {e}" + ) + + log.info( + f"LDAP groups for user {username}: {user_groups} (total: {len(user_groups)})" + ) + else: + log.info(f"No groups found for user {username}") + elif ENABLE_LDAP_GROUP_MANAGEMENT: + log.warning( + f"LDAP Group Management enabled but {LDAP_ATTRIBUTE_FOR_GROUPS} attribute not found in user entry" + ) + if username == form_data.user.lower(): connection_user = Connection( server, @@ -271,11 +361,9 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): user = Users.get_user_by_email(email) if not user: try: - user_count = Users.get_num_users() - role = ( "admin" - if user_count == 0 + if not Users.has_users() else request.app.state.config.DEFAULT_USER_ROLE ) @@ -299,7 +387,7 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): 500, detail="Internal error occurred during LDAP user creation." ) - user = Auths.authenticate_user_by_trusted_header(email) + user = Auths.authenticate_user_by_email(email) if user: expires_delta = parse_duration(request.app.state.config.JWT_EXPIRES_IN) @@ -332,6 +420,22 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm): user.id, request.app.state.config.USER_PERMISSIONS ) + if ( + user.role != "admin" + and ENABLE_LDAP_GROUP_MANAGEMENT + and user_groups + ): + if ENABLE_LDAP_GROUP_CREATION: + Groups.create_groups_by_group_names(user.id, user_groups) + + try: + Groups.sync_groups_by_group_names(user.id, user_groups) + log.info( + f"Successfully synced groups for user {user.id}: {user_groups}" + ) + except Exception as e: + log.error(f"Failed to sync groups for user {user.id}: {e}") + return { "token": token, "token_type": "Bearer", @@ -363,21 +467,29 @@ async def signin(request: Request, response: Response, form_data: SigninForm): if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers: raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER) - 
trusted_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower() - trusted_name = trusted_email + email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower() + name = email + if WEBUI_AUTH_TRUSTED_NAME_HEADER: - trusted_name = request.headers.get( - WEBUI_AUTH_TRUSTED_NAME_HEADER, trusted_email - ) - if not Users.get_user_by_email(trusted_email.lower()): + name = request.headers.get(WEBUI_AUTH_TRUSTED_NAME_HEADER, email) + + if not Users.get_user_by_email(email.lower()): await signup( request, response, - SignupForm( - email=trusted_email, password=str(uuid.uuid4()), name=trusted_name - ), + SignupForm(email=email, password=str(uuid.uuid4()), name=name), ) - user = Auths.authenticate_user_by_trusted_header(trusted_email) + + user = Auths.authenticate_user_by_email(email) + if WEBUI_AUTH_TRUSTED_GROUPS_HEADER and user and user.role != "admin": + group_names = request.headers.get( + WEBUI_AUTH_TRUSTED_GROUPS_HEADER, "" + ).split(",") + group_names = [name.strip() for name in group_names if name.strip()] + + if group_names: + Groups.sync_groups_by_group_names(user.id, group_names) + elif WEBUI_AUTH == False: admin_email = "admin@localhost" admin_password = "admin" @@ -385,7 +497,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm): if Users.get_user_by_email(admin_email.lower()): user = Auths.authenticate_user(admin_email.lower(), admin_password) else: - if Users.get_num_users() != 0: + if Users.has_users(): raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS) await signup( @@ -452,22 +564,23 @@ async def signin(request: Request, response: Response, form_data: SigninForm): @router.post("/signup", response_model=SessionUserResponse) async def signup(request: Request, response: Response, form_data: SignupForm): + has_users = Users.has_users() if WEBUI_AUTH: if ( not request.app.state.config.ENABLE_SIGNUP or not request.app.state.config.ENABLE_LOGIN_FORM ): - raise HTTPException( - status.HTTP_403_FORBIDDEN, 
detail=ERROR_MESSAGES.ACCESS_PROHIBITED - ) + if has_users or not ENABLE_INITIAL_ADMIN_SIGNUP: + raise HTTPException( + status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED + ) else: - if Users.get_num_users() != 0: + if has_users: raise HTTPException( status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED ) - user_count = Users.get_num_users() if not validate_email_format(form_data.email.lower()): raise HTTPException( status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT @@ -477,9 +590,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm): raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) try: - role = ( - "admin" if user_count == 0 else request.app.state.config.DEFAULT_USER_ROLE - ) + role = "admin" if not has_users else request.app.state.config.DEFAULT_USER_ROLE # The password passed to bcrypt must be 72 bytes or fewer. If it is longer, it will be truncated before hashing. if len(form_data.password.encode("utf-8")) > 72: @@ -525,7 +636,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm): ) if request.app.state.config.WEBHOOK_URL: - post_webhook( + await post_webhook( request.app.state.WEBUI_NAME, request.app.state.config.WEBHOOK_URL, WEBHOOK_MESSAGES.USER_SIGNUP(user.name), @@ -540,7 +651,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm): user.id, request.app.state.config.USER_PERMISSIONS ) - if user_count == 0: + if not has_users: # Disable signup after the first user is created request.app.state.config.ENABLE_SIGNUP = False @@ -565,37 +676,52 @@ async def signup(request: Request, response: Response, form_data: SignupForm): @router.get("/signout") async def signout(request: Request, response: Response): response.delete_cookie("token") + response.delete_cookie("oui-session") + response.delete_cookie("oauth_id_token") + + oauth_session_id = request.cookies.get("oauth_session_id") + if oauth_session_id: + 
response.delete_cookie("oauth_session_id") + + session = OAuthSessions.get_session_by_id(oauth_session_id) + oauth_server_metadata_url = ( + request.app.state.oauth_manager.get_server_metadata_url(session.provider) + if session + else None + ) or OPENID_PROVIDER_URL.value - if ENABLE_OAUTH_SIGNUP.value: - oauth_id_token = request.cookies.get("oauth_id_token") - if oauth_id_token: + if session and oauth_server_metadata_url: + oauth_id_token = session.token.get("id_token") try: - async with ClientSession() as session: - async with session.get(OPENID_PROVIDER_URL.value) as resp: - if resp.status == 200: - openid_data = await resp.json() + async with ClientSession(trust_env=True) as session: + async with session.get(oauth_server_metadata_url) as r: + if r.status == 200: + openid_data = await r.json() logout_url = openid_data.get("end_session_endpoint") - if logout_url: - response.delete_cookie("oauth_id_token") + if logout_url: return JSONResponse( status_code=200, content={ "status": True, - "redirect_url": f"{logout_url}?id_token_hint={oauth_id_token}", + "redirect_url": f"{logout_url}?id_token_hint={oauth_id_token}" + + ( + f"&post_logout_redirect_uri={WEBUI_AUTH_SIGNOUT_REDIRECT_URL}" + if WEBUI_AUTH_SIGNOUT_REDIRECT_URL + else "" + ), }, headers=response.headers, ) else: - raise HTTPException( - status_code=resp.status, - detail="Failed to fetch OpenID configuration", - ) + raise Exception("Failed to fetch OpenID configuration") + except Exception as e: log.error(f"OpenID signout error: {str(e)}") raise HTTPException( status_code=500, detail="Failed to sign out from the OpenID provider.", + headers=response.headers, ) if WEBUI_AUTH_SIGNOUT_REDIRECT_URL: diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py index 6da3f04cee6..fda0879594b 100644 --- a/backend/open_webui/routers/channels.py +++ b/backend/open_webui/routers/channels.py @@ -10,7 +10,13 @@ from open_webui.socket.main import sio, get_user_ids_from_room from 
open_webui.models.users import Users, UserNameResponse -from open_webui.models.channels import Channels, ChannelModel, ChannelForm +from open_webui.models.groups import Groups +from open_webui.models.channels import ( + Channels, + ChannelModel, + ChannelForm, + ChannelResponse, +) from open_webui.models.messages import ( Messages, MessageModel, @@ -24,9 +30,17 @@ from open_webui.env import SRC_LOG_LEVELS +from open_webui.utils.models import ( + get_all_models, + get_filtered_models, +) +from open_webui.utils.chat import generate_chat_completion + + from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access, get_users_with_access from open_webui.utils.webhook import post_webhook +from open_webui.utils.channels import extract_mentions, replace_mentions log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) @@ -40,10 +54,14 @@ @router.get("/", response_model=list[ChannelModel]) async def get_channels(user=Depends(get_verified_user)): + return Channels.get_channels_by_user_id(user.id) + + +@router.get("/list", response_model=list[ChannelModel]) +async def get_all_channels(user=Depends(get_verified_user)): if user.role == "admin": return Channels.get_channels() - else: - return Channels.get_channels_by_user_id(user.id) + return Channels.get_channels_by_user_id(user.id) ############################ @@ -68,7 +86,7 @@ async def create_new_channel(form_data: ChannelForm, user=Depends(get_admin_user ############################ -@router.get("/{id}", response_model=Optional[ChannelModel]) +@router.get("/{id}", response_model=Optional[ChannelResponse]) async def get_channel_by_id(id: str, user=Depends(get_verified_user)): channel = Channels.get_channel_by_id(id) if not channel: @@ -83,7 +101,16 @@ async def get_channel_by_id(id: str, user=Depends(get_verified_user)): status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) - return ChannelModel(**channel.model_dump()) + 
write_access = has_access( + user.id, type="write", access_control=channel.access_control, strict=False + ) + + return ChannelResponse( + **{ + **channel.model_dump(), + "write_access": write_access or user.role == "admin", + } + ) ############################ @@ -140,7 +167,7 @@ async def delete_channel_by_id(id: str, user=Depends(get_admin_user)): class MessageUserResponse(MessageResponse): - user: UserNameResponse + pass @router.get("/{id}/messages", response_model=list[MessageUserResponse]) @@ -169,15 +196,17 @@ async def get_channel_messages( user = Users.get_user_by_id(message.user_id) users[message.user_id] = user - replies = Messages.get_replies_by_message_id(message.id) - latest_reply_at = replies[0].created_at if replies else None + thread_replies = Messages.get_thread_replies_by_message_id(message.id) + latest_thread_reply_at = ( + thread_replies[0].created_at if thread_replies else None + ) messages.append( MessageUserResponse( **{ **message.model_dump(), - "reply_count": len(replies), - "latest_reply_at": latest_reply_at, + "reply_count": len(thread_replies), + "latest_reply_at": latest_thread_reply_at, "reactions": Messages.get_reactions_by_message_id(message.id), "user": UserNameResponse(**users[message.user_id].model_dump()), } @@ -196,16 +225,13 @@ async def send_notification(name, webui_url, channel, message, active_user_ids): users = get_users_with_access("read", channel.access_control) for user in users: - if user.id in active_user_ids: - continue - else: + if user.id not in active_user_ids: if user.settings: webhook_url = user.settings.ui.get("notifications", {}).get( "webhook_url", None ) - if webhook_url: - post_webhook( + await post_webhook( name, webhook_url, f"#{channel.name} - {webui_url}/channels/{channel.id}\n\n{message.content}", @@ -217,14 +243,185 @@ async def send_notification(name, webui_url, channel, message, active_user_ids): }, ) + return True -@router.post("/{id}/messages/post", response_model=Optional[MessageModel]) -async def 
post_new_message( - request: Request, - id: str, - form_data: MessageForm, - background_tasks: BackgroundTasks, - user=Depends(get_verified_user), + +async def model_response_handler(request, channel, message, user): + MODELS = { + model["id"]: model + for model in get_filtered_models(await get_all_models(request, user=user), user) + } + + mentions = extract_mentions(message.content) + message_content = replace_mentions(message.content) + + model_mentions = {} + + # check if the message is a reply to a message sent by a model + if ( + message.reply_to_message + and message.reply_to_message.meta + and message.reply_to_message.meta.get("model_id", None) + ): + model_id = message.reply_to_message.meta.get("model_id", None) + model_mentions[model_id] = {"id": model_id, "id_type": "M"} + + # check if any of the mentions are models + for mention in mentions: + if mention["id_type"] == "M" and mention["id"] not in model_mentions: + model_mentions[mention["id"]] = mention + + if not model_mentions: + return False + + for mention in model_mentions.values(): + model_id = mention["id"] + model = MODELS.get(model_id, None) + + if model: + try: + # reverse to get in chronological order + thread_messages = Messages.get_messages_by_parent_id( + channel.id, + message.parent_id if message.parent_id else message.id, + )[::-1] + + response_message, channel = await new_message_handler( + request, + channel.id, + MessageForm( + **{ + "parent_id": ( + message.parent_id if message.parent_id else message.id + ), + "content": f"", + "data": {}, + "meta": { + "model_id": model_id, + "model_name": model.get("name", model_id), + }, + } + ), + user, + ) + + thread_history = [] + images = [] + message_users = {} + + for thread_message in thread_messages: + message_user = None + if thread_message.user_id not in message_users: + message_user = Users.get_user_by_id(thread_message.user_id) + message_users[thread_message.user_id] = message_user + else: + message_user = 
message_users[thread_message.user_id] + + if thread_message.meta and thread_message.meta.get( + "model_id", None + ): + # If the message was sent by a model, use the model name + message_model_id = thread_message.meta.get("model_id", None) + message_model = MODELS.get(message_model_id, None) + username = ( + message_model.get("name", message_model_id) + if message_model + else message_model_id + ) + else: + username = message_user.name if message_user else "Unknown" + + thread_history.append( + f"{username}: {replace_mentions(thread_message.content)}" + ) + + thread_message_files = thread_message.data.get("files", []) + for file in thread_message_files: + if file.get("type", "") == "image": + images.append(file.get("url", "")) + + thread_history_string = "\n\n".join(thread_history) + system_message = { + "role": "system", + "content": f"You are {model.get('name', model_id)}, participating in a threaded conversation. Be concise and conversational." + + ( + f"Here's the thread history:\n\n\n{thread_history_string}\n\n\nContinue the conversation naturally as {model.get('name', model_id)}, addressing the most recent message while being aware of the full context." 
+ if thread_history + else "" + ), + } + + content = f"{user.name if user else 'User'}: {message_content}" + if images: + content = [ + { + "type": "text", + "text": content, + }, + *[ + { + "type": "image_url", + "image_url": { + "url": image, + }, + } + for image in images + ], + ] + + form_data = { + "model": model_id, + "messages": [ + system_message, + {"role": "user", "content": content}, + ], + "stream": False, + } + + res = await generate_chat_completion( + request, + form_data=form_data, + user=user, + ) + + if res: + if res.get("choices", []) and len(res["choices"]) > 0: + await update_message_by_id( + channel.id, + response_message.id, + MessageForm( + **{ + "content": res["choices"][0]["message"]["content"], + "meta": { + "done": True, + }, + } + ), + user, + ) + elif res.get("error", None): + await update_message_by_id( + channel.id, + response_message.id, + MessageForm( + **{ + "content": f"Error: {res['error']}", + "meta": { + "done": True, + }, + } + ), + user, + ) + except Exception as e: + log.info(e) + pass + + return True + + +async def new_message_handler( + request: Request, id: str, form_data: MessageForm, user=Depends(get_verified_user) ): channel = Channels.get_channel_by_id(id) if not channel: @@ -233,7 +430,7 @@ async def post_new_message( ) if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control + user.id, type="write", access_control=channel.access_control, strict=False ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -241,31 +438,21 @@ async def post_new_message( try: message = Messages.insert_new_message(form_data, channel.id, user.id) - if message: + message = Messages.get_message_by_id(message.id) event_data = { "channel_id": channel.id, "message_id": message.id, "data": { "type": "message", - "data": MessageUserResponse( - **{ - **message.model_dump(), - "reply_count": 0, - "latest_reply_at": None, - "reactions": 
Messages.get_reactions_by_message_id( - message.id - ), - "user": UserNameResponse(**user.model_dump()), - } - ).model_dump(), + "data": message.model_dump(), }, "user": UserNameResponse(**user.model_dump()).model_dump(), "channel": channel.model_dump(), } await sio.emit( - "channel-events", + "events:channel", event_data, to=f"channel:{channel.id}", ) @@ -276,33 +463,45 @@ async def post_new_message( if parent_message: await sio.emit( - "channel-events", + "events:channel", { "channel_id": channel.id, "message_id": parent_message.id, "data": { "type": "message:reply", - "data": MessageUserResponse( - **{ - **parent_message.model_dump(), - "user": UserNameResponse( - **Users.get_user_by_id( - parent_message.user_id - ).model_dump() - ), - } - ).model_dump(), + "data": parent_message.model_dump(), }, "user": UserNameResponse(**user.model_dump()).model_dump(), "channel": channel.model_dump(), }, to=f"channel:{channel.id}", ) + return message, channel + else: + raise Exception("Error creating message") + except Exception as e: + log.exception(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) - active_user_ids = get_user_ids_from_room(f"channel:{channel.id}") - background_tasks.add_task( - send_notification, +@router.post("/{id}/messages/post", response_model=Optional[MessageModel]) +async def post_new_message( + request: Request, + id: str, + form_data: MessageForm, + background_tasks: BackgroundTasks, + user=Depends(get_verified_user), +): + + try: + message, channel = await new_message_handler(request, id, form_data, user) + active_user_ids = get_user_ids_from_room(f"channel:{channel.id}") + + async def background_handler(): + await model_response_handler(request, channel, message, user) + await send_notification( request.app.state.WEBUI_NAME, request.app.state.config.WEBUI_URL, channel, @@ -310,7 +509,12 @@ async def post_new_message( active_user_ids, ) - return MessageModel(**message.model_dump()) + 
background_tasks.add_task(background_handler) + + return message + + except HTTPException as e: + raise e except Exception as e: log.exception(e) raise HTTPException( @@ -430,13 +634,6 @@ async def update_message_by_id( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) - if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control - ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) - message = Messages.get_message_by_id(message_id) if not message: raise HTTPException( @@ -448,26 +645,28 @@ async def update_message_by_id( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() ) + if ( + user.role != "admin" + and message.user_id != user.id + and not has_access(user.id, type="read", access_control=channel.access_control) + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() + ) + try: message = Messages.update_message_by_id(message_id, form_data) message = Messages.get_message_by_id(message_id) if message: await sio.emit( - "channel-events", + "events:channel", { "channel_id": channel.id, "message_id": message.id, "data": { "type": "message:update", - "data": MessageUserResponse( - **{ - **message.model_dump(), - "user": UserNameResponse( - **user.model_dump() - ).model_dump(), - } - ).model_dump(), + "data": message.model_dump(), }, "user": UserNameResponse(**user.model_dump()).model_dump(), "channel": channel.model_dump(), @@ -503,7 +702,7 @@ async def add_reaction_to_message( ) if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control + user.id, type="write", access_control=channel.access_control, strict=False ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -525,7 +724,7 @@ async def add_reaction_to_message( message = Messages.get_message_by_id(message_id) await sio.emit( - 
"channel-events", + "events:channel", { "channel_id": channel.id, "message_id": message.id, @@ -533,9 +732,6 @@ async def add_reaction_to_message( "type": "message:reaction:add", "data": { **message.model_dump(), - "user": UserNameResponse( - **Users.get_user_by_id(message.user_id).model_dump() - ).model_dump(), "name": form_data.name, }, }, @@ -569,7 +765,7 @@ async def remove_reaction_by_id_and_user_id_and_name( ) if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control + user.id, type="write", access_control=channel.access_control, strict=False ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -594,7 +790,7 @@ async def remove_reaction_by_id_and_user_id_and_name( message = Messages.get_message_by_id(message_id) await sio.emit( - "channel-events", + "events:channel", { "channel_id": channel.id, "message_id": message.id, @@ -602,9 +798,6 @@ async def remove_reaction_by_id_and_user_id_and_name( "type": "message:reaction:remove", "data": { **message.model_dump(), - "user": UserNameResponse( - **Users.get_user_by_id(message.user_id).model_dump() - ).model_dump(), "name": form_data.name, }, }, @@ -637,13 +830,6 @@ async def delete_message_by_id( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) - if user.role != "admin" and not has_access( - user.id, type="read", access_control=channel.access_control - ): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() - ) - message = Messages.get_message_by_id(message_id) if not message: raise HTTPException( @@ -655,10 +841,21 @@ async def delete_message_by_id( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() ) + if ( + user.role != "admin" + and message.user_id != user.id + and not has_access( + user.id, type="write", access_control=channel.access_control, strict=False + ) + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, 
detail=ERROR_MESSAGES.DEFAULT() + ) + try: Messages.delete_message_by_id(message_id) await sio.emit( - "channel-events", + "events:channel", { "channel_id": channel.id, "message_id": message.id, @@ -681,22 +878,13 @@ async def delete_message_by_id( if parent_message: await sio.emit( - "channel-events", + "events:channel", { "channel_id": channel.id, "message_id": parent_message.id, "data": { "type": "message:reply", - "data": MessageUserResponse( - **{ - **parent_message.model_dump(), - "user": UserNameResponse( - **Users.get_user_by_id( - parent_message.user_id - ).model_dump() - ), - } - ).model_dump(), + "data": parent_message.model_dump(), }, "user": UserNameResponse(**user.model_dump()).model_dump(), "channel": channel.model_dump(), diff --git a/backend/open_webui/routers/chats.py b/backend/open_webui/routers/chats.py index 6f00dd4d7ca..2587c5ff8e5 100644 --- a/backend/open_webui/routers/chats.py +++ b/backend/open_webui/routers/chats.py @@ -36,16 +36,33 @@ @router.get("/", response_model=list[ChatTitleIdResponse]) @router.get("/list", response_model=list[ChatTitleIdResponse]) -async def get_session_user_chat_list( - user=Depends(get_verified_user), page: Optional[int] = None +def get_session_user_chat_list( + user=Depends(get_verified_user), + page: Optional[int] = None, + include_pinned: Optional[bool] = False, + include_folders: Optional[bool] = False, ): - if page is not None: - limit = 60 - skip = (page - 1) * limit - - return Chats.get_chat_title_id_list_by_user_id(user.id, skip=skip, limit=limit) - else: - return Chats.get_chat_title_id_list_by_user_id(user.id) + try: + if page is not None: + limit = 60 + skip = (page - 1) * limit + + return Chats.get_chat_title_id_list_by_user_id( + user.id, + include_folders=include_folders, + include_pinned=include_pinned, + skip=skip, + limit=limit, + ) + else: + return Chats.get_chat_title_id_list_by_user_id( + user.id, include_folders=include_folders, include_pinned=include_pinned + ) + except Exception as e: + 
log.exception(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) ############################ @@ -76,17 +93,34 @@ async def delete_all_user_chats(request: Request, user=Depends(get_verified_user @router.get("/list/user/{user_id}", response_model=list[ChatTitleIdResponse]) async def get_user_chat_list_by_user_id( user_id: str, + page: Optional[int] = None, + query: Optional[str] = None, + order_by: Optional[str] = None, + direction: Optional[str] = None, user=Depends(get_admin_user), - skip: int = 0, - limit: int = 50, ): if not ENABLE_ADMIN_CHAT_ACCESS: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) + + if page is None: + page = 1 + + limit = 60 + skip = (page - 1) * limit + + filter = {} + if query: + filter["query"] = query + if order_by: + filter["order_by"] = order_by + if direction: + filter["direction"] = direction + return Chats.get_chat_list_by_user_id( - user_id, include_archived=True, skip=skip, limit=limit + user_id, include_archived=True, filter=filter, skip=skip, limit=limit ) @@ -141,7 +175,7 @@ async def import_chat(form_data: ChatImportForm, user=Depends(get_verified_user) @router.get("/search", response_model=list[ChatTitleIdResponse]) -async def search_user_chats( +def search_user_chats( text: str, page: Optional[int] = None, user=Depends(get_verified_user) ): if page is None: @@ -189,15 +223,37 @@ async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user) ] +@router.get("/folder/{folder_id}/list") +async def get_chat_list_by_folder_id( + folder_id: str, page: Optional[int] = 1, user=Depends(get_verified_user) +): + try: + limit = 60 + skip = (page - 1) * limit + + return [ + {"title": chat.title, "id": chat.id, "updated_at": chat.updated_at} + for chat in Chats.get_chats_by_folder_id_and_user_id( + folder_id, user.id, skip=skip, limit=limit + ) + ] + + except Exception as e: + log.exception(e) + raise 
HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) + + ############################ # GetPinnedChats ############################ -@router.get("/pinned", response_model=list[ChatResponse]) +@router.get("/pinned", response_model=list[ChatTitleIdResponse]) async def get_user_pinned_chats(user=Depends(get_verified_user)): return [ - ChatResponse(**chat.model_dump()) + ChatTitleIdResponse(**chat.model_dump()) for chat in Chats.get_pinned_chats_by_user_id(user.id) ] @@ -267,9 +323,37 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)): @router.get("/archived", response_model=list[ChatTitleIdResponse]) async def get_archived_session_user_chat_list( - user=Depends(get_verified_user), skip: int = 0, limit: int = 50 + page: Optional[int] = None, + query: Optional[str] = None, + order_by: Optional[str] = None, + direction: Optional[str] = None, + user=Depends(get_verified_user), ): - return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit) + if page is None: + page = 1 + + limit = 60 + skip = (page - 1) * limit + + filter = {} + if query: + filter["query"] = query + if order_by: + filter["order_by"] = order_by + if direction: + filter["direction"] = direction + + chat_list = [ + ChatTitleIdResponse(**chat.model_dump()) + for chat in Chats.get_archived_chat_list_by_user_id( + user.id, + filter=filter, + skip=skip, + limit=limit, + ) + ] + + return chat_list ############################ @@ -282,6 +366,16 @@ async def archive_all_chats(user=Depends(get_verified_user)): return Chats.archive_all_chats_by_user_id(user.id) +############################ +# UnarchiveAllChats +############################ + + +@router.post("/unarchive/all", response_model=bool) +async def unarchive_all_chats(user=Depends(get_verified_user)): + return Chats.unarchive_all_chats_by_user_id(user.id) + + ############################ # GetSharedChatById ############################ @@ -564,7 +658,18 @@ async def clone_chat_by_id( 
"title": form_data.title if form_data.title else f"Clone of {chat.title}", } - chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat})) + chat = Chats.import_chat( + user.id, + ChatImportForm( + **{ + "chat": updated_chat, + "meta": chat.meta, + "pinned": chat.pinned, + "folder_id": chat.folder_id, + } + ), + ) + return ChatResponse(**chat.model_dump()) else: raise HTTPException( @@ -593,7 +698,17 @@ async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user)): "title": f"Clone of {chat.title}", } - chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat})) + chat = Chats.import_chat( + user.id, + ChatImportForm( + **{ + "chat": updated_chat, + "meta": chat.meta, + "pinned": chat.pinned, + "folder_id": chat.folder_id, + } + ), + ) return ChatResponse(**chat.model_dump()) else: raise HTTPException( @@ -639,8 +754,10 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user)): @router.post("/{id}/share", response_model=Optional[ChatResponse]) async def share_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)): - if not has_permission( - user.id, "chat.share", request.app.state.config.USER_PERMISSIONS + if (user.role != "admin") and ( + not has_permission( + user.id, "chat.share", request.app.state.config.USER_PERMISSIONS + ) ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py index 44b2ef40cfb..e7fa13d1ff2 100644 --- a/backend/open_webui/routers/configs.py +++ b/backend/open_webui/routers/configs.py @@ -1,5 +1,7 @@ +import logging from fastapi import APIRouter, Depends, Request, HTTPException from pydantic import BaseModel, ConfigDict +import aiohttp from typing import Optional @@ -7,11 +9,29 @@ from open_webui.config import get_config, save_config from open_webui.config import BannerModel -from open_webui.utils.tools import get_tool_server_data, get_tool_servers_data +from 
open_webui.utils.tools import ( + get_tool_server_data, + get_tool_server_url, + set_tool_servers, +) +from open_webui.utils.mcp.client import MCPClient +from open_webui.env import SRC_LOG_LEVELS + +from open_webui.utils.oauth import ( + get_discovery_urls, + get_oauth_client_info_with_dynamic_client_registration, + encrypt_data, + decrypt_data, + OAuthClientInformationFull, +) +from mcp.shared.auth import OAuthMetadata router = APIRouter() +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) + ############################ # ImportConfig @@ -39,35 +59,79 @@ async def export_config(user=Depends(get_admin_user)): ############################ -# Direct Connections Config +# Connections Config ############################ -class DirectConnectionsConfigForm(BaseModel): +class ConnectionsConfigForm(BaseModel): ENABLE_DIRECT_CONNECTIONS: bool + ENABLE_BASE_MODELS_CACHE: bool -@router.get("/direct_connections", response_model=DirectConnectionsConfigForm) -async def get_direct_connections_config(request: Request, user=Depends(get_admin_user)): +@router.get("/connections", response_model=ConnectionsConfigForm) +async def get_connections_config(request: Request, user=Depends(get_admin_user)): return { "ENABLE_DIRECT_CONNECTIONS": request.app.state.config.ENABLE_DIRECT_CONNECTIONS, + "ENABLE_BASE_MODELS_CACHE": request.app.state.config.ENABLE_BASE_MODELS_CACHE, } -@router.post("/direct_connections", response_model=DirectConnectionsConfigForm) -async def set_direct_connections_config( +@router.post("/connections", response_model=ConnectionsConfigForm) +async def set_connections_config( request: Request, - form_data: DirectConnectionsConfigForm, + form_data: ConnectionsConfigForm, user=Depends(get_admin_user), ): request.app.state.config.ENABLE_DIRECT_CONNECTIONS = ( form_data.ENABLE_DIRECT_CONNECTIONS ) + request.app.state.config.ENABLE_BASE_MODELS_CACHE = ( + form_data.ENABLE_BASE_MODELS_CACHE + ) + return { "ENABLE_DIRECT_CONNECTIONS": 
request.app.state.config.ENABLE_DIRECT_CONNECTIONS, + "ENABLE_BASE_MODELS_CACHE": request.app.state.config.ENABLE_BASE_MODELS_CACHE, } +class OAuthClientRegistrationForm(BaseModel): + url: str + client_id: str + client_name: Optional[str] = None + + +@router.post("/oauth/clients/register") +async def register_oauth_client( + request: Request, + form_data: OAuthClientRegistrationForm, + type: Optional[str] = None, + user=Depends(get_admin_user), +): + try: + oauth_client_id = form_data.client_id + if type: + oauth_client_id = f"{type}:{form_data.client_id}" + + oauth_client_info = ( + await get_oauth_client_info_with_dynamic_client_registration( + request, oauth_client_id, form_data.url + ) + ) + return { + "status": True, + "oauth_client_info": encrypt_data( + oauth_client_info.model_dump(mode="json") + ), + } + except Exception as e: + log.debug(f"Failed to register OAuth client: {e}") + raise HTTPException( + status_code=400, + detail=f"Failed to register OAuth client", + ) + + ############################ # ToolServers Config ############################ @@ -76,6 +140,7 @@ async def set_direct_connections_config( class ToolServerConnection(BaseModel): url: str path: str + type: Optional[str] = "openapi" # openapi, mcp auth_type: Optional[str] key: Optional[str] config: Optional[dict] @@ -104,9 +169,27 @@ async def set_tool_servers_config( connection.model_dump() for connection in form_data.TOOL_SERVER_CONNECTIONS ] - request.app.state.TOOL_SERVERS = await get_tool_servers_data( - request.app.state.config.TOOL_SERVER_CONNECTIONS - ) + await set_tool_servers(request) + + for connection in request.app.state.config.TOOL_SERVER_CONNECTIONS: + server_type = connection.get("type", "openapi") + if server_type == "mcp": + server_id = connection.get("info", {}).get("id") + auth_type = connection.get("auth_type", "none") + if auth_type == "oauth_2.1" and server_id: + try: + oauth_client_info = connection.get("info", {}).get( + "oauth_client_info", "" + ) + 
oauth_client_info = decrypt_data(oauth_client_info) + + request.app.state.oauth_client_manager.add_client( + f"{server_type}:{server_id}", + OAuthClientInformationFull(**oauth_client_info), + ) + except Exception as e: + log.debug(f"Failed to add OAuth client for MCP tool server: {e}") + continue return { "TOOL_SERVER_CONNECTIONS": request.app.state.config.TOOL_SERVER_CONNECTIONS, @@ -121,19 +204,106 @@ async def verify_tool_servers_config( Verify the connection to the tool server. """ try: - - token = None - if form_data.auth_type == "bearer": - token = form_data.key - elif form_data.auth_type == "session": - token = request.state.token.credentials - - url = f"{form_data.url}/{form_data.path}" - return await get_tool_server_data(token, url) + if form_data.type == "mcp": + if form_data.auth_type == "oauth_2.1": + discovery_urls = get_discovery_urls(form_data.url) + for discovery_url in discovery_urls: + log.debug( + f"Trying to fetch OAuth 2.1 discovery document from {discovery_url}" + ) + async with aiohttp.ClientSession() as session: + async with session.get( + discovery_url + ) as oauth_server_metadata_response: + if oauth_server_metadata_response.status == 200: + try: + oauth_server_metadata = ( + OAuthMetadata.model_validate( + await oauth_server_metadata_response.json() + ) + ) + return { + "status": True, + "oauth_server_metadata": oauth_server_metadata.model_dump( + mode="json" + ), + } + except Exception as e: + log.info( + f"Failed to parse OAuth 2.1 discovery document: {e}" + ) + raise HTTPException( + status_code=400, + detail=f"Failed to parse OAuth 2.1 discovery document from {discovery_url}", + ) + + raise HTTPException( + status_code=400, + detail=f"Failed to fetch OAuth 2.1 discovery document from {discovery_urls}", + ) + else: + try: + client = MCPClient() + headers = None + + token = None + if form_data.auth_type == "bearer": + token = form_data.key + elif form_data.auth_type == "session": + token = request.state.token.credentials + elif 
form_data.auth_type == "system_oauth": + try: + if request.cookies.get("oauth_session_id", None): + token = await request.app.state.oauth_manager.get_oauth_token( + user.id, + request.cookies.get("oauth_session_id", None), + ) + except Exception as e: + pass + + if token: + headers = {"Authorization": f"Bearer {token}"} + + await client.connect(form_data.url, headers=headers) + specs = await client.list_tool_specs() + return { + "status": True, + "specs": specs, + } + except Exception as e: + log.debug(f"Failed to create MCP client: {e}") + raise HTTPException( + status_code=400, + detail=f"Failed to create MCP client", + ) + finally: + if client: + await client.disconnect() + else: # openapi + token = None + if form_data.auth_type == "bearer": + token = form_data.key + elif form_data.auth_type == "session": + token = request.state.token.credentials + elif form_data.auth_type == "system_oauth": + try: + if request.cookies.get("oauth_session_id", None): + token = await request.app.state.oauth_manager.get_oauth_token( + user.id, + request.cookies.get("oauth_session_id", None), + ) + except Exception as e: + pass + + url = get_tool_server_url(form_data.url, form_data.path) + return await get_tool_server_data(token, url) + except HTTPException as e: + raise e except Exception as e: + log.debug(f"Failed to connect to the tool server: {e}") raise HTTPException( status_code=400, - detail=f"Failed to connect to the tool server: {str(e)}", + detail=f"Failed to connect to the tool server", ) diff --git a/backend/open_webui/routers/evaluations.py b/backend/open_webui/routers/evaluations.py index 164f3c40b48..c76a1f6915d 100644 --- a/backend/open_webui/routers/evaluations.py +++ b/backend/open_webui/routers/evaluations.py @@ -129,7 +129,10 @@ async def create_feedback( @router.get("/feedback/{id}", response_model=FeedbackModel) async def get_feedback_by_id(id: str, user=Depends(get_verified_user)): - feedback = Feedbacks.get_feedback_by_id_and_user_id(id=id, user_id=user.id) + 
if user.role == "admin": + feedback = Feedbacks.get_feedback_by_id(id=id) + else: + feedback = Feedbacks.get_feedback_by_id_and_user_id(id=id, user_id=user.id) if not feedback: raise HTTPException( @@ -143,9 +146,12 @@ async def get_feedback_by_id(id: str, user=Depends(get_verified_user)): async def update_feedback_by_id( id: str, form_data: FeedbackForm, user=Depends(get_verified_user) ): - feedback = Feedbacks.update_feedback_by_id_and_user_id( - id=id, user_id=user.id, form_data=form_data - ) + if user.role == "admin": + feedback = Feedbacks.update_feedback_by_id(id=id, form_data=form_data) + else: + feedback = Feedbacks.update_feedback_by_id_and_user_id( + id=id, user_id=user.id, form_data=form_data + ) if not feedback: raise HTTPException( diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index ad556d3272f..84d8f841cfc 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -1,24 +1,30 @@ import logging import os import uuid +import json from fnmatch import fnmatch from pathlib import Path from typing import Optional from urllib.parse import quote +import asyncio from fastapi import ( + BackgroundTasks, APIRouter, Depends, File, + Form, HTTPException, Request, UploadFile, status, Query, ) + from fastapi.responses import FileResponse, StreamingResponse from open_webui.constants import ERROR_MESSAGES from open_webui.env import SRC_LOG_LEVELS +from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT from open_webui.models.users import Users from open_webui.models.files import ( @@ -39,7 +45,6 @@ log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MODELS"]) - router = APIRouter() @@ -80,23 +85,103 @@ def has_access_to_file( ############################ +def process_uploaded_file(request, file, file_path, file_item, file_metadata, user): + try: + if file.content_type: + stt_supported_content_types = getattr( + request.app.state.config, "STT_SUPPORTED_CONTENT_TYPES", [] + ) 
+ + if any( + fnmatch(file.content_type, content_type) + for content_type in ( + stt_supported_content_types + if stt_supported_content_types + and any(t.strip() for t in stt_supported_content_types) + else ["audio/*", "video/webm"] + ) + ): + file_path = Storage.get_file(file_path) + result = transcribe(request, file_path, file_metadata) + + process_file( + request, + ProcessFileForm( + file_id=file_item.id, content=result.get("text", "") + ), + user=user, + ) + elif (not file.content_type.startswith(("image/", "video/"))) or ( + request.app.state.config.CONTENT_EXTRACTION_ENGINE == "external" + ): + process_file(request, ProcessFileForm(file_id=file_item.id), user=user) + else: + log.info( + f"File type {file.content_type} is not provided, but trying to process anyway" + ) + process_file(request, ProcessFileForm(file_id=file_item.id), user=user) + except Exception as e: + log.error(f"Error processing file: {file_item.id}") + Files.update_file_data_by_id( + file_item.id, + { + "status": "failed", + "error": str(e.detail) if hasattr(e, "detail") else str(e), + }, + ) + + @router.post("/", response_model=FileModelResponse) def upload_file( request: Request, + background_tasks: BackgroundTasks, file: UploadFile = File(...), + metadata: Optional[dict | str] = Form(None), + process: bool = Query(True), + process_in_background: bool = Query(True), user=Depends(get_verified_user), - file_metadata: dict = None, +): + return upload_file_handler( + request, + file=file, + metadata=metadata, + process=process, + process_in_background=process_in_background, + user=user, + background_tasks=background_tasks, + ) + + +def upload_file_handler( + request: Request, + file: UploadFile = File(...), + metadata: Optional[dict | str] = Form(None), process: bool = Query(True), + process_in_background: bool = Query(True), + user=Depends(get_verified_user), + background_tasks: Optional[BackgroundTasks] = None, ): log.info(f"file.content_type: {file.content_type}") - file_metadata = 
file_metadata if file_metadata else {} + if isinstance(metadata, str): + try: + metadata = json.loads(metadata) + except json.JSONDecodeError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Invalid metadata format"), + ) + file_metadata = metadata if metadata else {} + try: unsanitized_filename = file.filename filename = os.path.basename(unsanitized_filename) file_extension = os.path.splitext(filename)[1] - if request.app.state.config.ALLOWED_FILE_EXTENSIONS: + # Remove the leading dot from the file extension + file_extension = file_extension[1:] if file_extension else "" + + if process and request.app.state.config.ALLOWED_FILE_EXTENSIONS: request.app.state.config.ALLOWED_FILE_EXTENSIONS = [ ext for ext in request.app.state.config.ALLOWED_FILE_EXTENSIONS if ext ] @@ -113,13 +198,16 @@ def upload_file( id = str(uuid.uuid4()) name = filename filename = f"{id}_{filename}" - tags = { - "OpenWebUI-User-Email": user.email, - "OpenWebUI-User-Id": user.id, - "OpenWebUI-User-Name": user.name, - "OpenWebUI-File-Id": id, - } - contents, file_path = Storage.upload_file(file.file, filename, tags) + contents, file_path = Storage.upload_file( + file.file, + filename, + { + "OpenWebUI-User-Email": user.email, + "OpenWebUI-User-Id": user.id, + "OpenWebUI-User-Name": user.name, + "OpenWebUI-File-Id": id, + }, + ) file_item = Files.insert_new_file( user.id, @@ -128,6 +216,9 @@ def upload_file( "id": id, "filename": name, "path": file_path, + "data": { + **({"status": "pending"} if process else {}), + }, "meta": { "name": name, "content_type": file.content_type, @@ -137,59 +228,43 @@ def upload_file( } ), ) - if process: - try: - if file.content_type: - if file.content_type.startswith("audio/") or file.content_type in { - "video/webm" - }: - file_path = Storage.get_file(file_path) - result = transcribe(request, file_path) - - process_file( - request, - ProcessFileForm(file_id=id, content=result.get("text", "")), - user=user, - ) - 
elif file.content_type not in [ - "image/png", - "image/jpeg", - "image/gif", - "video/mp4", - "video/ogg", - "video/quicktime", - ]: - process_file(request, ProcessFileForm(file_id=id), user=user) - else: - log.info( - f"File type {file.content_type} is not provided, but trying to process anyway" - ) - process_file(request, ProcessFileForm(file_id=id), user=user) - file_item = Files.get_file_by_id(id=id) - except Exception as e: - log.exception(e) - log.error(f"Error processing file: {file_item.id}") - file_item = FileModelResponse( - **{ - **file_item.model_dump(), - "error": str(e.detail) if hasattr(e, "detail") else str(e), - } + if process: + if background_tasks and process_in_background: + background_tasks.add_task( + process_uploaded_file, + request, + file, + file_path, + file_item, + file_metadata, + user, ) - - if file_item: - return file_item + return {"status": True, **file_item.model_dump()} + else: + process_uploaded_file( + request, + file, + file_path, + file_item, + file_metadata, + user, + ) + return {"status": True, **file_item.model_dump()} else: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error uploading file"), - ) + if file_item: + return file_item + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error uploading file"), + ) except Exception as e: log.exception(e) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT(e), + detail=ERROR_MESSAGES.DEFAULT("Error uploading file"), ) @@ -266,6 +341,7 @@ async def delete_all_files(user=Depends(get_admin_user)): if result: try: Storage.delete_all_files() + VECTOR_DB_CLIENT.reset() except Exception as e: log.exception(e) log.error("Error deleting files") @@ -309,6 +385,63 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)): ) +@router.get("/{id}/process/status") +async def get_file_process_status( + id: str, stream: bool = 
Query(False), user=Depends(get_verified_user) +): + file = Files.get_file_by_id(id) + + if not file: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + if ( + file.user_id == user.id + or user.role == "admin" + or has_access_to_file(id, "read", user) + ): + if stream: + MAX_FILE_PROCESSING_DURATION = 3600 * 2 + + async def event_stream(file_item): + if file_item: + for _ in range(MAX_FILE_PROCESSING_DURATION): + file_item = Files.get_file_by_id(file_item.id) + if file_item: + data = file_item.model_dump().get("data", {}) + status = data.get("status") + + if status: + event = {"status": status} + if status == "failed": + event["error"] = data.get("error") + + yield f"data: {json.dumps(event)}\n\n" + if status in ("completed", "failed"): + break + else: + # Legacy + break + + await asyncio.sleep(0.5) + else: + yield f"data: {json.dumps({'status': 'not_found'})}\n\n" + + return StreamingResponse( + event_stream(file), + media_type="text/event-stream", + ) + else: + return {"status": file.data.get("status", "pending")} + else: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + ############################ # Get File Data Content By Id ############################ @@ -583,12 +716,12 @@ async def delete_file_by_id(id: str, user=Depends(get_verified_user)): or user.role == "admin" or has_access_to_file(id, "write", user) ): - # We should add Chroma cleanup here result = Files.delete_file_by_id(id) if result: try: Storage.delete_file(file.path) + VECTOR_DB_CLIENT.delete(collection_name=f"file-{id}") except Exception as e: log.exception(e) log.error("Error deleting files") diff --git a/backend/open_webui/routers/folders.py b/backend/open_webui/routers/folders.py index 2c41c92854b..b242b08e3a3 100644 --- a/backend/open_webui/routers/folders.py +++ b/backend/open_webui/routers/folders.py @@ -10,10 +10,15 @@ from open_webui.models.folders import ( FolderForm, + 
FolderUpdateForm, FolderModel, + FolderNameIdResponse, Folders, ) from open_webui.models.chats import Chats +from open_webui.models.files import Files +from open_webui.models.knowledge import Knowledges + from open_webui.config import UPLOAD_DIR from open_webui.env import SRC_LOG_LEVELS @@ -40,24 +45,46 @@ ############################ -@router.get("/", response_model=list[FolderModel]) +@router.get("/", response_model=list[FolderNameIdResponse]) async def get_folders(user=Depends(get_verified_user)): folders = Folders.get_folders_by_user_id(user.id) - return [ - { - **folder.model_dump(), - "items": { - "chats": [ - {"title": chat.title, "id": chat.id} - for chat in Chats.get_chats_by_folder_id_and_user_id( - folder.id, user.id - ) - ] - }, - } - for folder in folders - ] + # Verify folder data integrity + folder_list = [] + for folder in folders: + if folder.parent_id and not Folders.get_folder_by_id_and_user_id( + folder.parent_id, user.id + ): + folder = Folders.update_folder_parent_id_by_id_and_user_id( + folder.id, user.id, None + ) + + if folder.data: + if "files" in folder.data: + valid_files = [] + for file in folder.data["files"]: + + if file.get("type") == "file": + if Files.check_access_by_user_id( + file.get("id"), user.id, "read" + ): + valid_files.append(file) + elif file.get("type") == "collection": + if Knowledges.check_access_by_user_id( + file.get("id"), user.id, "read" + ): + valid_files.append(file) + else: + valid_files.append(file) + + folder.data["files"] = valid_files + Folders.update_folder_by_id_and_user_id( + folder.id, user.id, FolderUpdateForm(data=folder.data) + ) + + folder_list.append(FolderNameIdResponse(**folder.model_dump())) + + return folder_list ############################ @@ -78,7 +105,7 @@ def create_folder(form_data: FolderForm, user=Depends(get_verified_user)): ) try: - folder = Folders.insert_new_folder(user.id, form_data.name) + folder = Folders.insert_new_folder(user.id, form_data) return folder except Exception as e: 
log.exception(e) @@ -113,24 +140,24 @@ async def get_folder_by_id(id: str, user=Depends(get_verified_user)): @router.post("/{id}/update") async def update_folder_name_by_id( - id: str, form_data: FolderForm, user=Depends(get_verified_user) + id: str, form_data: FolderUpdateForm, user=Depends(get_verified_user) ): folder = Folders.get_folder_by_id_and_user_id(id, user.id) if folder: - existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name( - folder.parent_id, user.id, form_data.name - ) - if existing_folder: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Folder already exists"), - ) - try: - folder = Folders.update_folder_name_by_id_and_user_id( - id, user.id, form_data.name + if form_data.name is not None: + # Check if folder with same name exists + existing_folder = Folders.get_folder_by_parent_id_and_user_id_and_name( + folder.parent_id, user.id, form_data.name ) + if existing_folder and existing_folder.id != id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Folder already exists"), + ) + try: + folder = Folders.update_folder_by_id_and_user_id(id, user.id, form_data) return folder except Exception as e: log.exception(e) @@ -233,31 +260,41 @@ async def update_folder_is_expanded_by_id( async def delete_folder_by_id( request: Request, id: str, user=Depends(get_verified_user) ): - chat_delete_permission = has_permission( - user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS - ) - - if user.role != "admin" and not chat_delete_permission: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=ERROR_MESSAGES.ACCESS_PROHIBITED, + if Chats.count_chats_by_folder_id_and_user_id(id, user.id): + chat_delete_permission = has_permission( + user.id, "chat.delete", request.app.state.config.USER_PERMISSIONS ) - - folder = Folders.get_folder_by_id_and_user_id(id, user.id) - if folder: - try: - result = 
Folders.delete_folder_by_id_and_user_id(id, user.id) - if result: - return result - else: - raise Exception("Error deleting folder") - except Exception as e: - log.exception(e) - log.error(f"Error deleting folder: {id}") + if user.role != "admin" and not chat_delete_permission: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT("Error deleting folder"), + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) + + folders = [] + folders.append(Folders.get_folder_by_id_and_user_id(id, user.id)) + while folders: + folder = folders.pop() + if folder: + try: + folder_ids = Folders.delete_folder_by_id_and_user_id(id, user.id) + for folder_id in folder_ids: + Chats.delete_chats_by_user_id_and_folder_id(user.id, folder_id) + + return True + except Exception as e: + log.exception(e) + log.error(f"Error deleting folder: {id}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error deleting folder"), + ) + finally: + # Get all subfolders + subfolders = Folders.get_folders_by_parent_id_and_user_id( + folder.id, user.id + ) + folders.extend(subfolders) + else: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, diff --git a/backend/open_webui/routers/functions.py b/backend/open_webui/routers/functions.py index 206610138e0..c8f131553c3 100644 --- a/backend/open_webui/routers/functions.py +++ b/backend/open_webui/routers/functions.py @@ -1,5 +1,8 @@ import os +import re + import logging +import aiohttp from pathlib import Path from typing import Optional @@ -7,14 +10,22 @@ FunctionForm, FunctionModel, FunctionResponse, + FunctionUserResponse, + FunctionWithValvesModel, Functions, ) -from open_webui.utils.plugin import load_function_module_by_id, replace_imports +from open_webui.utils.plugin import ( + load_function_module_by_id, + replace_imports, + get_function_module_from_cache, +) from open_webui.config import CACHE_DIR from open_webui.constants 
import ERROR_MESSAGES from fastapi import APIRouter, Depends, HTTPException, Request, status from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.env import SRC_LOG_LEVELS +from pydantic import BaseModel, HttpUrl + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) @@ -32,14 +43,136 @@ async def get_functions(user=Depends(get_verified_user)): return Functions.get_functions() +@router.get("/list", response_model=list[FunctionUserResponse]) +async def get_function_list(user=Depends(get_admin_user)): + return Functions.get_function_list() + + ############################ # ExportFunctions ############################ -@router.get("/export", response_model=list[FunctionModel]) -async def get_functions(user=Depends(get_admin_user)): - return Functions.get_functions() +@router.get("/export", response_model=list[FunctionModel | FunctionWithValvesModel]) +async def get_functions(include_valves: bool = False, user=Depends(get_admin_user)): + return Functions.get_functions(include_valves=include_valves) + + +############################ +# LoadFunctionFromLink +############################ + + +class LoadUrlForm(BaseModel): + url: HttpUrl + + +def github_url_to_raw_url(url: str) -> str: + # Handle 'tree' (folder) URLs (add main.py at the end) + m1 = re.match(r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)", url) + if m1: + org, repo, branch, path = m1.groups() + return f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip('/')}/main.py" + + # Handle 'blob' (file) URLs + m2 = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)", url) + if m2: + org, repo, branch, path = m2.groups() + return ( + f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}" + ) + + # No match; return as-is + return url + + +@router.post("/load/url", response_model=Optional[dict]) +async def load_function_from_url( + request: Request, form_data: LoadUrlForm, 
user=Depends(get_admin_user) +): + # NOTE: This is NOT a SSRF vulnerability: + # This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use, + # and does NOT accept untrusted user input. Access is enforced by authentication. + + url = str(form_data.url) + if not url: + raise HTTPException(status_code=400, detail="Please enter a valid URL") + + url = github_url_to_raw_url(url) + url_parts = url.rstrip("/").split("/") + + file_name = url_parts[-1] + function_name = ( + file_name[:-3] + if ( + file_name.endswith(".py") + and (not file_name.startswith(("main.py", "index.py", "__init__.py"))) + ) + else url_parts[-2] if len(url_parts) > 1 else "function" + ) + + try: + async with aiohttp.ClientSession(trust_env=True) as session: + async with session.get( + url, headers={"Content-Type": "application/json"} + ) as resp: + if resp.status != 200: + raise HTTPException( + status_code=resp.status, detail="Failed to fetch the function" + ) + data = await resp.text() + if not data: + raise HTTPException( + status_code=400, detail="No data received from the URL" + ) + return { + "name": function_name, + "content": data, + } + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error importing function: {e}") + + +############################ +# SyncFunctions +############################ + + +class SyncFunctionsForm(BaseModel): + functions: list[FunctionWithValvesModel] = [] + + +@router.post("/sync", response_model=list[FunctionWithValvesModel]) +async def sync_functions( + request: Request, form_data: SyncFunctionsForm, user=Depends(get_admin_user) +): + try: + for function in form_data.functions: + function.content = replace_imports(function.content) + function_module, function_type, frontmatter = load_function_module_by_id( + function.id, + content=function.content, + ) + + if hasattr(function_module, "Valves") and function.valves: + Valves = function_module.Valves + try: + Valves( + **{k: v for k, v in function.valves.items() if v 
is not None} + ) + except Exception as e: + log.exception( + f"Error validating valves for function {function.id}: {e}" + ) + raise e + + return Functions.sync_functions(user.id, form_data.functions) + except Exception as e: + log.exception(f"Failed to load a function: {e}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) ############################ @@ -77,6 +210,9 @@ async def create_new_function( function_cache_dir = CACHE_DIR / "functions" / form_data.id function_cache_dir.mkdir(parents=True, exist_ok=True) + if function_type == "filter" and getattr(function_module, "toggle", None): + Functions.update_function_metadata_by_id(id, {"toggle": True}) + if function: return function else: @@ -193,6 +329,9 @@ async def update_function_by_id( function = Functions.update_function_by_id(id, updated) + if function_type == "filter" and getattr(function_module, "toggle", None): + Functions.update_function_metadata_by_id(id, {"toggle": True}) + if function: return function else: @@ -262,11 +401,9 @@ async def get_function_valves_spec_by_id( ): function = Functions.get_function_by_id(id) if function: - if id in request.app.state.FUNCTIONS: - function_module = request.app.state.FUNCTIONS[id] - else: - function_module, function_type, frontmatter = load_function_module_by_id(id) - request.app.state.FUNCTIONS[id] = function_module + function_module, function_type, frontmatter = get_function_module_from_cache( + request, id + ) if hasattr(function_module, "Valves"): Valves = function_module.Valves @@ -290,11 +427,9 @@ async def update_function_valves_by_id( ): function = Functions.get_function_by_id(id) if function: - if id in request.app.state.FUNCTIONS: - function_module = request.app.state.FUNCTIONS[id] - else: - function_module, function_type, frontmatter = load_function_module_by_id(id) - request.app.state.FUNCTIONS[id] = function_module + function_module, function_type, frontmatter = get_function_module_from_cache( + 
request, id + ) if hasattr(function_module, "Valves"): Valves = function_module.Valves @@ -302,8 +437,10 @@ async def update_function_valves_by_id( try: form_data = {k: v for k, v in form_data.items() if v is not None} valves = Valves(**form_data) - Functions.update_function_valves_by_id(id, valves.model_dump()) - return valves.model_dump() + + valves_dict = valves.model_dump(exclude_unset=True) + Functions.update_function_valves_by_id(id, valves_dict) + return valves_dict except Exception as e: log.exception(f"Error updating function values by id {id}: {e}") raise HTTPException( @@ -353,11 +490,9 @@ async def get_function_user_valves_spec_by_id( ): function = Functions.get_function_by_id(id) if function: - if id in request.app.state.FUNCTIONS: - function_module = request.app.state.FUNCTIONS[id] - else: - function_module, function_type, frontmatter = load_function_module_by_id(id) - request.app.state.FUNCTIONS[id] = function_module + function_module, function_type, frontmatter = get_function_module_from_cache( + request, id + ) if hasattr(function_module, "UserValves"): UserValves = function_module.UserValves @@ -377,11 +512,9 @@ async def update_function_user_valves_by_id( function = Functions.get_function_by_id(id) if function: - if id in request.app.state.FUNCTIONS: - function_module = request.app.state.FUNCTIONS[id] - else: - function_module, function_type, frontmatter = load_function_module_by_id(id) - request.app.state.FUNCTIONS[id] = function_module + function_module, function_type, frontmatter = get_function_module_from_cache( + request, id + ) if hasattr(function_module, "UserValves"): UserValves = function_module.UserValves @@ -389,10 +522,11 @@ async def update_function_user_valves_by_id( try: form_data = {k: v for k, v in form_data.items() if v is not None} user_valves = UserValves(**form_data) + user_valves_dict = user_valves.model_dump(exclude_unset=True) Functions.update_user_valves_by_id_and_user_id( - id, user.id, user_valves.model_dump() + id, 
user.id, user_valves_dict ) - return user_valves.model_dump() + return user_valves_dict except Exception as e: log.exception(f"Error updating function user valves by id {id}: {e}") raise HTTPException( diff --git a/backend/open_webui/routers/groups.py b/backend/open_webui/routers/groups.py index ae822c0d006..bf286fe001c 100755 --- a/backend/open_webui/routers/groups.py +++ b/backend/open_webui/routers/groups.py @@ -9,6 +9,7 @@ GroupForm, GroupUpdateForm, GroupResponse, + UserIdsForm, ) from open_webui.config import CACHE_DIR @@ -107,6 +108,56 @@ async def update_group_by_id( ) +############################ +# AddUserToGroupByUserIdAndGroupId +############################ + + +@router.post("/id/{id}/users/add", response_model=Optional[GroupResponse]) +async def add_user_to_group( + id: str, form_data: UserIdsForm, user=Depends(get_admin_user) +): + try: + if form_data.user_ids: + form_data.user_ids = Users.get_valid_user_ids(form_data.user_ids) + + group = Groups.add_users_to_group(id, form_data.user_ids) + if group: + return group + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error adding users to group"), + ) + except Exception as e: + log.exception(f"Error adding users to group {id}: {e}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), + ) + + +@router.post("/id/{id}/users/remove", response_model=Optional[GroupResponse]) +async def remove_users_from_group( + id: str, form_data: UserIdsForm, user=Depends(get_admin_user) +): + try: + group = Groups.remove_users_from_group(id, form_data.user_ids) + if group: + return group + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT("Error removing users from group"), + ) + except Exception as e: + log.exception(f"Error removing users from group {id}: {e}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.DEFAULT(e), 
+ ) + + ############################ # DeleteGroupById ############################ diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py index b8bb110f51d..059b3a23d72 100644 --- a/backend/open_webui/routers/images.py +++ b/backend/open_webui/routers/images.py @@ -8,12 +8,20 @@ from pathlib import Path from typing import Optional +from urllib.parse import quote import requests -from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile +from fastapi import ( + APIRouter, + Depends, + HTTPException, + Request, + UploadFile, +) + from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, SRC_LOG_LEVELS -from open_webui.routers.files import upload_file +from open_webui.routers.files import upload_file_handler from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.images.comfyui import ( ComfyUIGenerateImageForm, @@ -40,6 +48,7 @@ async def get_config(request: Request, user=Depends(get_admin_user)): "prompt_generation": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION, "openai": { "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL, + "OPENAI_API_VERSION": request.app.state.config.IMAGES_OPENAI_API_VERSION, "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY, }, "automatic1111": { @@ -64,6 +73,7 @@ async def get_config(request: Request, user=Depends(get_admin_user)): class OpenAIConfigForm(BaseModel): OPENAI_API_BASE_URL: str + OPENAI_API_VERSION: str OPENAI_API_KEY: str @@ -111,6 +121,9 @@ async def update_config( request.app.state.config.IMAGES_OPENAI_API_BASE_URL = ( form_data.openai.OPENAI_API_BASE_URL ) + request.app.state.config.IMAGES_OPENAI_API_VERSION = ( + form_data.openai.OPENAI_API_VERSION + ) request.app.state.config.IMAGES_OPENAI_API_KEY = form_data.openai.OPENAI_API_KEY request.app.state.config.IMAGES_GEMINI_API_BASE_URL = ( @@ -157,6 
+170,7 @@ async def update_config( "prompt_generation": request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION, "openai": { "OPENAI_API_BASE_URL": request.app.state.config.IMAGES_OPENAI_API_BASE_URL, + "OPENAI_API_VERSION": request.app.state.config.IMAGES_OPENAI_API_VERSION, "OPENAI_API_KEY": request.app.state.config.IMAGES_OPENAI_API_KEY, }, "automatic1111": { @@ -302,8 +316,16 @@ async def update_image_config( ): set_image_model(request, form_data.MODEL) + if form_data.IMAGE_SIZE == "auto" and form_data.MODEL != "gpt-image-1": + raise HTTPException( + status_code=400, + detail=ERROR_MESSAGES.INCORRECT_FORMAT( + " (auto is only allowed with gpt-image-1)." + ), + ) + pattern = r"^\d+x\d+$" - if re.match(pattern, form_data.IMAGE_SIZE): + if form_data.IMAGE_SIZE == "auto" or re.match(pattern, form_data.IMAGE_SIZE): request.app.state.config.IMAGE_SIZE = form_data.IMAGE_SIZE else: raise HTTPException( @@ -333,10 +355,11 @@ def get_models(request: Request, user=Depends(get_verified_user)): return [ {"id": "dall-e-2", "name": "DALL·E 2"}, {"id": "dall-e-3", "name": "DALL·E 3"}, + {"id": "gpt-image-1", "name": "GPT-IMAGE 1"}, ] elif request.app.state.config.IMAGE_GENERATION_ENGINE == "gemini": return [ - {"id": "imagen-3-0-generate-002", "name": "imagen-3.0 generate-002"}, + {"id": "imagen-3.0-generate-002", "name": "imagen-3.0 generate-002"}, ] elif request.app.state.config.IMAGE_GENERATION_ENGINE == "comfyui": # TODO - get models from comfyui @@ -419,7 +442,7 @@ def load_b64_image_data(b64_str): try: if "," in b64_str: header, encoded = b64_str.split(",", 1) - mime_type = header.split(";")[0] + mime_type = header.split(";")[0].lstrip("data:") img_data = base64.b64decode(encoded) else: mime_type = "image/png" @@ -427,7 +450,7 @@ def load_b64_image_data(b64_str): return img_data, mime_type except Exception as e: log.exception(f"Error loading image data: {e}") - return None + return None, None def load_url_image_data(url, headers=None): @@ -450,7 +473,7 @@ def 
load_url_image_data(url, headers=None): return None -def upload_image(request, image_metadata, image_data, content_type, user): +def upload_image(request, image_data, content_type, metadata, user): image_format = mimetypes.guess_extension(content_type) file = UploadFile( file=io.BytesIO(image_data), @@ -459,7 +482,13 @@ def upload_image(request, image_metadata, image_data, content_type, user): "content-type": content_type, }, ) - file_item = upload_file(request, file, user, file_metadata=image_metadata) + file_item = upload_file_handler( + request, + file=file, + metadata=metadata, + process=False, + user=user, + ) url = request.app.url_path_for("get_file_content_by_id", id=file_item.id) return url @@ -470,7 +499,22 @@ async def image_generations( form_data: GenerateImageForm, user=Depends(get_verified_user), ): - width, height = tuple(map(int, request.app.state.config.IMAGE_SIZE.split("x"))) + # if IMAGE_SIZE = 'auto', default WidthxHeight to the 512x512 default + # This is only relevant when the user has set IMAGE_SIZE to 'auto' with an + # image model other than gpt-image-1, which is warned about on settings save + + size = "512x512" + if ( + request.app.state.config.IMAGE_SIZE + and "x" in request.app.state.config.IMAGE_SIZE + ): + size = request.app.state.config.IMAGE_SIZE + + if form_data.size and "x" in form_data.size: + size = form_data.size + + width, height = tuple(map(int, size.split("x"))) + model = get_image_model(request) r = None try: @@ -482,17 +526,13 @@ async def image_generations( headers["Content-Type"] = "application/json" if ENABLE_FORWARD_USER_INFO_HEADERS: - headers["X-OpenWebUI-User-Name"] = user.name + headers["X-OpenWebUI-User-Name"] = quote(user.name, safe=" ") headers["X-OpenWebUI-User-Id"] = user.id headers["X-OpenWebUI-User-Email"] = user.email headers["X-OpenWebUI-User-Role"] = user.role data = { - "model": ( - request.app.state.config.IMAGE_GENERATION_MODEL - if request.app.state.config.IMAGE_GENERATION_MODEL != "" - else "dall-e-2" 
- ), + "model": model, "prompt": form_data.prompt, "n": form_data.n, "size": ( @@ -507,10 +547,16 @@ async def image_generations( ), } + api_version_query_param = "" + if request.app.state.config.IMAGES_OPENAI_API_VERSION: + api_version_query_param = ( + f"?api-version={request.app.state.config.IMAGES_OPENAI_API_VERSION}" + ) + # Use asyncio.to_thread for the requests.post call r = await asyncio.to_thread( requests.post, - url=f"{request.app.state.config.IMAGES_OPENAI_API_BASE_URL}/images/generations", + url=f"{request.app.state.config.IMAGES_OPENAI_API_BASE_URL}/images/generations{api_version_query_param}", json=data, headers=headers, ) @@ -526,7 +572,7 @@ async def image_generations( else: image_data, content_type = load_b64_image_data(image["b64_json"]) - url = upload_image(request, data, image_data, content_type, user) + url = upload_image(request, image_data, content_type, data, user) images.append({"url": url}) return images @@ -535,7 +581,6 @@ async def image_generations( headers["Content-Type"] = "application/json" headers["x-goog-api-key"] = request.app.state.config.IMAGES_GEMINI_API_KEY - model = get_image_model(request) data = { "instances": {"prompt": form_data.prompt}, "parameters": { @@ -560,7 +605,7 @@ async def image_generations( image_data, content_type = load_b64_image_data( image["bytesBase64Encoded"] ) - url = upload_image(request, data, image_data, content_type, user) + url = upload_image(request, image_data, content_type, data, user) images.append({"url": url}) return images @@ -591,7 +636,7 @@ async def image_generations( } ) res = await comfyui_generate_image( - request.app.state.config.IMAGE_GENERATION_MODEL, + model, form_data, user.id, request.app.state.config.COMFYUI_BASE_URL, @@ -611,9 +656,9 @@ async def image_generations( image_data, content_type = load_url_image_data(image["url"], headers) url = upload_image( request, - form_data.model_dump(exclude_none=True), image_data, content_type, + form_data.model_dump(exclude_none=True), user, 
) images.append({"url": url}) @@ -664,9 +709,9 @@ async def image_generations( image_data, content_type = load_b64_image_data(image) url = upload_image( request, - {**data, "info": res["info"]}, image_data, content_type, + {**data, "info": res["info"]}, user, ) images.append({"url": url}) diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py index e6e55f4d388..71722d706e5 100644 --- a/backend/open_webui/routers/knowledge.py +++ b/backend/open_webui/routers/knowledge.py @@ -1,6 +1,6 @@ from typing import List, Optional from pydantic import BaseModel -from fastapi import APIRouter, Depends, HTTPException, status, Request +from fastapi import APIRouter, Depends, HTTPException, status, Request, Query import logging from open_webui.models.knowledge import ( @@ -25,6 +25,7 @@ from open_webui.env import SRC_LOG_LEVELS +from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL from open_webui.models.models import Models, ModelForm @@ -42,7 +43,7 @@ async def get_knowledge(user=Depends(get_verified_user)): knowledge_bases = [] - if user.role == "admin": + if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: knowledge_bases = Knowledges.get_knowledge_bases() else: knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(user.id, "read") @@ -90,7 +91,7 @@ async def get_knowledge(user=Depends(get_verified_user)): async def get_knowledge_list(user=Depends(get_verified_user)): knowledge_bases = [] - if user.role == "admin": + if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: knowledge_bases = Knowledges.get_knowledge_bases() else: knowledge_bases = Knowledges.get_knowledge_bases_by_user_id(user.id, "write") @@ -150,6 +151,18 @@ async def create_new_knowledge( detail=ERROR_MESSAGES.UNAUTHORIZED, ) + # Check if user can share publicly + if ( + user.role != "admin" + and form_data.access_control == None + and not has_permission( + user.id, + "sharing.public_knowledge", + request.app.state.config.USER_PERMISSIONS, + ) + ): + 
form_data.access_control = {} + knowledge = Knowledges.insert_new_knowledge(user.id, form_data) if knowledge: @@ -284,6 +297,7 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user)): @router.post("/{id}/update", response_model=Optional[KnowledgeFilesResponse]) async def update_knowledge_by_id( + request: Request, id: str, form_data: KnowledgeForm, user=Depends(get_verified_user), @@ -305,10 +319,22 @@ async def update_knowledge_by_id( detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) + # Check if user can share publicly + if ( + user.role != "admin" + and form_data.access_control == None + and not has_permission( + user.id, + "sharing.public_knowledge", + request.app.state.config.USER_PERMISSIONS, + ) + ): + form_data.access_control = {} + knowledge = Knowledges.update_knowledge_by_id(id=id, form_data=form_data) if knowledge: file_ids = knowledge.data.get("file_ids", []) if knowledge.data else [] - files = Files.get_files_by_ids(file_ids) + files = Files.get_file_metadatas_by_ids(file_ids) return KnowledgeFilesResponse( **knowledge.model_dump(), @@ -491,6 +517,7 @@ def update_file_from_knowledge_by_id( def remove_file_from_knowledge_by_id( id: str, form_data: KnowledgeFileIdForm, + delete_file: bool = Query(True), user=Depends(get_verified_user), ): knowledge = Knowledges.get_knowledge_by_id(id=id) @@ -527,18 +554,19 @@ def remove_file_from_knowledge_by_id( log.debug(e) pass - try: - # Remove the file's collection from vector database - file_collection = f"file-{form_data.file_id}" - if VECTOR_DB_CLIENT.has_collection(collection_name=file_collection): - VECTOR_DB_CLIENT.delete_collection(collection_name=file_collection) - except Exception as e: - log.debug("This was most likely caused by bypassing embedding processing") - log.debug(e) - pass + if delete_file: + try: + # Remove the file's collection from vector database + file_collection = f"file-{form_data.file_id}" + if VECTOR_DB_CLIENT.has_collection(collection_name=file_collection): + 
VECTOR_DB_CLIENT.delete_collection(collection_name=file_collection) + except Exception as e: + log.debug("This was most likely caused by bypassing embedding processing") + log.debug(e) + pass - # Delete file from database - Files.delete_file_by_id(form_data.file_id) + # Delete file from database + Files.delete_file_by_id(form_data.file_id) if knowledge: data = knowledge.data or {} diff --git a/backend/open_webui/routers/memories.py b/backend/open_webui/routers/memories.py index 333e9ecc6af..11b3d0c96c5 100644 --- a/backend/open_webui/routers/memories.py +++ b/backend/open_webui/routers/memories.py @@ -82,6 +82,10 @@ class QueryMemoryForm(BaseModel): async def query_memory( request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user) ): + memories = Memories.get_memories_by_user_id(user.id) + if not memories: + raise HTTPException(status_code=404, detail="No memories found for user") + results = VECTOR_DB_CLIENT.search( collection_name=f"user-memory-{user.id}", vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user=user)], diff --git a/backend/open_webui/routers/models.py b/backend/open_webui/routers/models.py index 0cf3308f194..215cd8426c2 100644 --- a/backend/open_webui/routers/models.py +++ b/backend/open_webui/routers/models.py @@ -1,4 +1,9 @@ from typing import Optional +import io +import base64 +import json +import asyncio +import logging from open_webui.models.models import ( ModelForm, @@ -7,17 +12,33 @@ ModelUserResponse, Models, ) + +from pydantic import BaseModel from open_webui.constants import ERROR_MESSAGES -from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi import ( + APIRouter, + Depends, + HTTPException, + Request, + status, + Response, +) +from fastapi.responses import FileResponse, StreamingResponse from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access, has_permission +from open_webui.config import 
BYPASS_ADMIN_ACCESS_CONTROL, STATIC_DIR +log = logging.getLogger(__name__) router = APIRouter() +def validate_model_id(model_id: str) -> bool: + return bool(model_id) and len(model_id) <= 256 + + ########################### # GetModels ########################### @@ -25,7 +46,7 @@ @router.get("/", response_model=list[ModelUserResponse]) async def get_models(id: Optional[str] = None, user=Depends(get_verified_user)): - if user.role == "admin": + if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: return Models.get_models() else: return Models.get_models_by_user_id(user.id) @@ -67,6 +88,12 @@ async def create_new_model( detail=ERROR_MESSAGES.MODEL_ID_TAKEN, ) + if not validate_model_id(form_data.id): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.MODEL_ID_TOO_LONG, + ) + else: model = Models.insert_new_model(form_data, user.id) if model: @@ -78,6 +105,77 @@ async def create_new_model( ) +############################ +# ExportModels +############################ + + +@router.get("/export", response_model=list[ModelModel]) +async def export_models(user=Depends(get_admin_user)): + return Models.get_models() + + +############################ +# ImportModels +############################ + + +class ModelsImportForm(BaseModel): + models: list[dict] + + +@router.post("/import", response_model=bool) +async def import_models( + user=Depends(get_admin_user), form_data: ModelsImportForm = (...)
+): + try: + data = form_data.models + if isinstance(data, list): + for model_data in data: + # Here, you can add logic to validate model_data if needed + model_id = model_data.get("id") + + if model_id and validate_model_id(model_id): + existing_model = Models.get_model_by_id(model_id) + if existing_model: + # Update existing model + model_data["meta"] = model_data.get("meta", {}) + model_data["params"] = model_data.get("params", {}) + + updated_model = ModelForm( + **{**existing_model.model_dump(), **model_data} + ) + Models.update_model_by_id(model_id, updated_model) + else: + # Insert new model + model_data["meta"] = model_data.get("meta", {}) + model_data["params"] = model_data.get("params", {}) + new_model = ModelForm(**model_data) + Models.insert_new_model(user_id=user.id, form_data=new_model) + return True + else: + raise HTTPException(status_code=400, detail="Invalid JSON format") + except Exception as e: + log.exception(e) + raise HTTPException(status_code=500, detail=str(e)) + + +############################ +# SyncModels +############################ + + +class SyncModelsForm(BaseModel): + models: list[ModelModel] = [] + + +@router.post("/sync", response_model=list[ModelModel]) +async def sync_models( + request: Request, form_data: SyncModelsForm, user=Depends(get_admin_user) +): + return Models.sync_models(user.id, form_data.models) + + ########################### # GetModelById ########################### @@ -89,7 +187,7 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)): model = Models.get_model_by_id(id) if model: if ( - user.role == "admin" + (user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL) or model.user_id == user.id or has_access(user.id, "read", model.access_control) ): @@ -101,8 +199,41 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)): ) +########################### +# GetModelProfileImage +########################### + + +@router.get("/model/profile/image") +async def get_model_profile_image(id: str,
user=Depends(get_verified_user)): + model = Models.get_model_by_id(id) + if model: + if model.meta.profile_image_url: + if model.meta.profile_image_url.startswith("http"): + return Response( + status_code=status.HTTP_302_FOUND, + headers={"Location": model.meta.profile_image_url}, + ) + elif model.meta.profile_image_url.startswith("data:image"): + try: + header, base64_data = model.meta.profile_image_url.split(",", 1) + image_data = base64.b64decode(base64_data) + image_buffer = io.BytesIO(image_data) + + return StreamingResponse( + image_buffer, + media_type="image/png", + headers={"Content-Disposition": "inline; filename=image.png"}, + ) + except Exception as e: + pass + return FileResponse(f"{STATIC_DIR}/favicon.png") + else: + return FileResponse(f"{STATIC_DIR}/favicon.png") + + ############################ -# ToggelModelById +# ToggleModelById ############################ diff --git a/backend/open_webui/routers/notes.py b/backend/open_webui/routers/notes.py index 5ad5ff051e3..3858c4670f2 100644 --- a/backend/open_webui/routers/notes.py +++ b/backend/open_webui/routers/notes.py @@ -6,6 +6,9 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status, BackgroundTasks from pydantic import BaseModel +from open_webui.socket.main import sio + + from open_webui.models.users import Users, UserResponse from open_webui.models.notes import Notes, NoteModel, NoteForm, NoteUserResponse @@ -45,15 +48,23 @@ async def get_notes(request: Request, user=Depends(get_verified_user)): "user": UserResponse(**Users.get_user_by_id(note.user_id).model_dump()), } ) - for note in Notes.get_notes_by_user_id(user.id, "write") + for note in Notes.get_notes_by_permission(user.id, "write") ] return notes -@router.get("/list", response_model=list[NoteUserResponse]) -async def get_note_list(request: Request, user=Depends(get_verified_user)): +class NoteTitleIdResponse(BaseModel): + id: str + title: str + updated_at: int + created_at: int + +@router.get("/list", 
response_model=list[NoteTitleIdResponse]) +async def get_note_list( + request: Request, page: Optional[int] = None, user=Depends(get_verified_user) +): if user.role != "admin" and not has_permission( user.id, "features.notes", request.app.state.config.USER_PERMISSIONS ): @@ -62,14 +73,17 @@ async def get_note_list(request: Request, user=Depends(get_verified_user)): detail=ERROR_MESSAGES.UNAUTHORIZED, ) + limit = None + skip = None + if page is not None: + limit = 60 + skip = (page - 1) * limit + notes = [ - NoteUserResponse( - **{ - **note.model_dump(), - "user": UserResponse(**Users.get_user_by_id(note.user_id).model_dump()), - } + NoteTitleIdResponse(**note.model_dump()) + for note in Notes.get_notes_by_permission( + user.id, "write", skip=skip, limit=limit ) - for note in Notes.get_notes_by_user_id(user.id, "read") ] return notes @@ -124,10 +138,9 @@ async def get_note_by_id(request: Request, id: str, user=Depends(get_verified_us status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) - if ( - user.role != "admin" - and user.id != note.user_id - and not has_access(user.id, type="read", access_control=note.access_control) + if user.role != "admin" and ( + user.id != note.user_id + and (not has_access(user.id, type="read", access_control=note.access_control)) ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() @@ -159,17 +172,34 @@ async def update_note_by_id( status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) - if ( - user.role != "admin" - and user.id != note.user_id + if user.role != "admin" and ( + user.id != note.user_id and not has_access(user.id, type="write", access_control=note.access_control) ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.DEFAULT() ) + # Check if user can share publicly + if ( + user.role != "admin" + and form_data.access_control == None + and not has_permission( + user.id, + "sharing.public_notes", + 
request.app.state.config.USER_PERMISSIONS, + ) + ): + form_data.access_control = {} + try: note = Notes.update_note_by_id(id, form_data) + await sio.emit( + "note-events", + note.model_dump(), + to=f"note:{note.id}", + ) + return note except Exception as e: log.exception(e) @@ -199,9 +229,8 @@ async def delete_note_by_id(request: Request, id: str, user=Depends(get_verified status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND ) - if ( - user.role != "admin" - and user.id != note.user_id + if user.role != "admin" and ( + user.id != note.user_id and not has_access(user.id, type="write", access_control=note.access_control) ): raise HTTPException( diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 7c313ea9700..64b0687afa0 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -9,11 +9,16 @@ import random import re import time +from datetime import datetime + from typing import Optional, Union from urllib.parse import urlparse import aiohttp from aiocache import cached import requests +from urllib.parse import quote + +from open_webui.models.chats import Chats from open_webui.models.users import UserModel from open_webui.env import ( @@ -42,7 +47,7 @@ from open_webui.utils.payload import ( apply_model_params_to_body_ollama, apply_model_params_to_body_openai, - apply_model_system_prompt_to_body, + apply_system_prompt_to_body, ) from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access @@ -54,6 +59,7 @@ from open_webui.env import ( ENV, SRC_LOG_LEVELS, + MODELS_CACHE_TTL, AIOHTTP_CLIENT_SESSION_SSL, AIOHTTP_CLIENT_TIMEOUT, AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST, @@ -83,7 +89,7 @@ async def send_get_request(url, key=None, user: UserModel = None): **({"Authorization": f"Bearer {key}"} if key else {}), **( { - "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Name": quote(user.name, safe=" "), 
"X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Role": user.role, @@ -118,6 +124,7 @@ async def send_post_request( key: Optional[str] = None, content_type: Optional[str] = None, user: UserModel = None, + metadata: Optional[dict] = None, ): r = None @@ -134,10 +141,15 @@ async def send_post_request( **({"Authorization": f"Bearer {key}"} if key else {}), **( { - "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Name": quote(user.name, safe=" "), "X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Role": user.role, + **( + {"X-OpenWebUI-Chat-Id": metadata.get("chat_id")} + if metadata and metadata.get("chat_id") + else {} + ), } if ENABLE_FORWARD_USER_INFO_HEADERS and user else {} @@ -145,8 +157,23 @@ async def send_post_request( }, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) - r.raise_for_status() + if r.ok is False: + try: + res = await r.json() + await cleanup_response(r, session) + if "error" in res: + raise HTTPException(status_code=r.status, detail=res["error"]) + except HTTPException as e: + raise e # Re-raise HTTPException to be handled by FastAPI + except Exception as e: + log.error(f"Failed to parse error response: {e}") + raise HTTPException( + status_code=r.status, + detail=f"Open WebUI: Server Connection Error", + ) + + r.raise_for_status() # Raises an error for bad responses (4xx, 5xx) if stream: response_headers = dict(r.headers) @@ -163,24 +190,20 @@ async def send_post_request( ) else: res = await r.json() - await cleanup_response(r, session) return res + except HTTPException as e: + raise e # Re-raise HTTPException to be handled by FastAPI except Exception as e: - detail = None - - if r is not None: - try: - res = await r.json() - if "error" in res: - detail = f"Ollama: {res.get('error', 'Unknown error')}" - except Exception: - detail = f"Ollama: {e}" + detail = f"Ollama: {e}" raise HTTPException( status_code=r.status if r else 500, - detail=detail if detail else "Open WebUI: 
Server Connection Error", + detail=detail if e else "Open WebUI: Server Connection Error", ) + finally: + if not stream: + await cleanup_response(r, session) def get_api_key(idx, url, configs): @@ -229,7 +252,7 @@ async def verify_connection( **({"Authorization": f"Bearer {key}"} if key else {}), **( { - "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Name": quote(user.name, safe=" "), "X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Role": user.role, @@ -300,7 +323,27 @@ async def update_config( } -@cached(ttl=1) +def merge_ollama_models_lists(model_lists): + merged_models = {} + + for idx, model_list in enumerate(model_lists): + if model_list is not None: + for model in model_list: + id = model.get("model") + if id is not None: + if id not in merged_models: + model["urls"] = [idx] + merged_models[id] = model + else: + merged_models[id]["urls"].append(idx) + + return list(merged_models.values()) + + +@cached( + ttl=MODELS_CACHE_TTL, + key=lambda _, user: f"ollama_all_models_{user.id}" if user else "ollama_all_models", +) async def get_all_models(request: Request, user: UserModel = None): log.info("get_all_models()") if request.app.state.config.ENABLE_OLLAMA_API: @@ -364,23 +407,8 @@ async def get_all_models(request: Request, user: UserModel = None): if connection_type: model["connection_type"] = connection_type - def merge_models_lists(model_lists): - merged_models = {} - - for idx, model_list in enumerate(model_lists): - if model_list is not None: - for model in model_list: - id = model["model"] - if id not in merged_models: - model["urls"] = [idx] - merged_models[id] = model - else: - merged_models[id]["urls"].append(idx) - - return list(merged_models.values()) - models = { - "models": merge_models_lists( + "models": merge_ollama_models_lists( map( lambda response: response.get("models", []) if response else None, responses, @@ -388,6 +416,22 @@ def merge_models_lists(model_lists): ) } + try: + loaded_models = await 
get_ollama_loaded_models(request, user=user) + expires_map = { + m["model"]: m["expires_at"] + for m in loaded_models["models"] + if "expires_at" in m + } + + for m in models["models"]: + if m["model"] in expires_map: + # Parse ISO8601 datetime with offset, get unix timestamp as int + dt = datetime.fromisoformat(expires_map[m["model"]]) + m["expires_at"] = int(dt.timestamp()) + except Exception as e: + log.debug(f"Failed to get loaded models: {e}") + else: models = {"models": []} @@ -432,7 +476,7 @@ async def get_ollama_tags( **({"Authorization": f"Bearer {key}"} if key else {}), **( { - "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Name": quote(user.name, safe=" "), "X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Role": user.role, @@ -468,6 +512,68 @@ async def get_ollama_tags( return models +@router.get("/api/ps") +async def get_ollama_loaded_models(request: Request, user=Depends(get_admin_user)): + """ + List models that are currently loaded into Ollama memory, and which node they are loaded on. 
+ """ + if request.app.state.config.ENABLE_OLLAMA_API: + request_tasks = [] + for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS): + if (str(idx) not in request.app.state.config.OLLAMA_API_CONFIGS) and ( + url not in request.app.state.config.OLLAMA_API_CONFIGS # Legacy support + ): + request_tasks.append(send_get_request(f"{url}/api/ps", user=user)) + else: + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( + str(idx), + request.app.state.config.OLLAMA_API_CONFIGS.get( + url, {} + ), # Legacy support + ) + + enable = api_config.get("enable", True) + key = api_config.get("key", None) + + if enable: + request_tasks.append( + send_get_request(f"{url}/api/ps", key, user=user) + ) + else: + request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None))) + + responses = await asyncio.gather(*request_tasks) + + for idx, response in enumerate(responses): + if response: + url = request.app.state.config.OLLAMA_BASE_URLS[idx] + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( + str(idx), + request.app.state.config.OLLAMA_API_CONFIGS.get( + url, {} + ), # Legacy support + ) + + prefix_id = api_config.get("prefix_id", None) + + for model in response.get("models", []): + if prefix_id: + model["model"] = f"{prefix_id}.{model['model']}" + + models = { + "models": merge_ollama_models_lists( + map( + lambda response: response.get("models", []) if response else None, + responses, + ) + ) + } + else: + models = {"models": []} + + return models + + @router.get("/api/version") @router.get("/api/version/{url_idx}") async def get_ollama_versions(request: Request, url_idx: Optional[int] = None): @@ -541,34 +647,77 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None): return {"version": False} -@router.get("/api/ps") -async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)): - """ - List models that are currently loaded into Ollama memory, and which node they are loaded on. 
- """ - if request.app.state.config.ENABLE_OLLAMA_API: - request_tasks = [ - send_get_request( - f"{url}/api/ps", - request.app.state.config.OLLAMA_API_CONFIGS.get( - str(idx), - request.app.state.config.OLLAMA_API_CONFIGS.get( - url, {} - ), # Legacy support - ).get("key", None), +class ModelNameForm(BaseModel): + model: Optional[str] = None + model_config = ConfigDict( + extra="allow", + ) + + +@router.post("/api/unload") +async def unload_model( + request: Request, + form_data: ModelNameForm, + user=Depends(get_admin_user), +): + form_data = form_data.model_dump(exclude_none=True) + model_name = form_data.get("model", form_data.get("name")) + + if not model_name: + raise HTTPException( + status_code=400, detail="Missing name of the model to unload." + ) + + # Refresh/load models if needed, get mapping from name to URLs + await get_all_models(request, user=user) + models = request.app.state.OLLAMA_MODELS + + # Canonicalize model name (if not supplied with version) + if ":" not in model_name: + model_name = f"{model_name}:latest" + + if model_name not in models: + raise HTTPException( + status_code=400, detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model_name) + ) + url_indices = models[model_name]["urls"] + + # Send unload to ALL url_indices + results = [] + errors = [] + for idx in url_indices: + url = request.app.state.config.OLLAMA_BASE_URLS[idx] + api_config = request.app.state.config.OLLAMA_API_CONFIGS.get( + str(idx), request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}) + ) + key = get_api_key(idx, url, request.app.state.config.OLLAMA_API_CONFIGS) + + prefix_id = api_config.get("prefix_id", None) + if prefix_id and model_name.startswith(f"{prefix_id}."): + model_name = model_name[len(f"{prefix_id}.") :] + + payload = {"model": model_name, "keep_alive": 0, "prompt": ""} + + try: + res = await send_post_request( + url=f"{url}/api/generate", + payload=json.dumps(payload), + stream=False, + key=key, user=user, ) - for idx, url in 
enumerate(request.app.state.config.OLLAMA_BASE_URLS) - ] - responses = await asyncio.gather(*request_tasks) - - return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses)) - else: - return {} + results.append({"url_idx": idx, "success": True, "response": res}) + except Exception as e: + log.exception(f"Failed to unload model on node {idx}: {e}") + errors.append({"url_idx": idx, "success": False, "error": str(e)}) + if len(errors) > 0: + raise HTTPException( + status_code=500, + detail=f"Failed to unload model on {len(errors)} nodes: {errors}", + ) -class ModelNameForm(BaseModel): - name: str + return {"status": True} @router.post("/api/pull") @@ -579,11 +728,14 @@ async def pull_model( url_idx: int = 0, user=Depends(get_admin_user), ): + form_data = form_data.model_dump(exclude_none=True) + form_data["model"] = form_data.get("model", form_data.get("name")) + url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] log.info(f"url: {url}") # Admin should be able to pull models from any source - payload = {**form_data.model_dump(exclude_none=True), "insecure": True} + payload = {**form_data, "insecure": True} return await send_post_request( url=f"{url}/api/pull", @@ -594,7 +746,7 @@ async def pull_model( class PushModelForm(BaseModel): - name: str + model: str insecure: Optional[bool] = None stream: Optional[bool] = None @@ -611,12 +763,12 @@ async def push_model( await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS - if form_data.name in models: - url_idx = models[form_data.name]["urls"][0] + if form_data.model in models: + url_idx = models[form_data.model]["urls"][0] else: raise HTTPException( status_code=400, - detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), ) url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] @@ -694,7 +846,7 @@ async def copy_model( **({"Authorization": f"Bearer {key}"} if key else {}), **( { - "X-OpenWebUI-User-Name": user.name, + 
"X-OpenWebUI-User-Name": quote(user.name, safe=" "), "X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Role": user.role, @@ -735,16 +887,21 @@ async def delete_model( url_idx: Optional[int] = None, user=Depends(get_admin_user), ): + form_data = form_data.model_dump(exclude_none=True) + form_data["model"] = form_data.get("model", form_data.get("name")) + + model = form_data.get("model") + if url_idx is None: await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS - if form_data.name in models: - url_idx = models[form_data.name]["urls"][0] + if model in models: + url_idx = models[model]["urls"][0] else: raise HTTPException( status_code=400, - detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model), ) url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] @@ -754,13 +911,13 @@ async def delete_model( r = requests.request( method="DELETE", url=f"{url}/api/delete", - data=form_data.model_dump_json(exclude_none=True).encode(), + data=json.dumps(form_data).encode(), headers={ "Content-Type": "application/json", **({"Authorization": f"Bearer {key}"} if key else {}), **( { - "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Name": quote(user.name, safe=" "), "X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Role": user.role, @@ -796,16 +953,21 @@ async def delete_model( async def show_model_info( request: Request, form_data: ModelNameForm, user=Depends(get_verified_user) ): + form_data = form_data.model_dump(exclude_none=True) + form_data["model"] = form_data.get("model", form_data.get("name")) + await get_all_models(request, user=user) models = request.app.state.OLLAMA_MODELS - if form_data.name not in models: + model = form_data.get("model") + + if model not in models: raise HTTPException( status_code=400, - detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), + detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model), ) - 
url_idx = random.choice(models[form_data.name]["urls"]) + url_idx = random.choice(models[model]["urls"]) url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS) @@ -819,7 +981,7 @@ async def show_model_info( **({"Authorization": f"Bearer {key}"} if key else {}), **( { - "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Name": quote(user.name, safe=" "), "X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Role": user.role, @@ -828,7 +990,7 @@ async def show_model_info( else {} ), }, - data=form_data.model_dump_json(exclude_none=True).encode(), + data=json.dumps(form_data).encode(), ) r.raise_for_status() @@ -858,6 +1020,10 @@ class GenerateEmbedForm(BaseModel): options: Optional[dict] = None keep_alive: Optional[Union[int, str]] = None + model_config = ConfigDict( + extra="allow", + ) + @router.post("/api/embed") @router.post("/api/embed/{url_idx}") @@ -906,7 +1072,7 @@ async def embed( **({"Authorization": f"Bearer {key}"} if key else {}), **( { - "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Name": quote(user.name, safe=" "), "X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Role": user.role, @@ -993,7 +1159,7 @@ async def embeddings( **({"Authorization": f"Bearer {key}"} if key else {}), **( { - "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Name": quote(user.name, safe=" "), "X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Email": user.email, "X-OpenWebUI-User-Role": user.role, @@ -1113,6 +1279,9 @@ class GenerateChatCompletionForm(BaseModel): stream: Optional[bool] = True keep_alive: Optional[Union[int, str]] = None tools: Optional[list[dict]] = None + model_config = ConfigDict( + extra="allow", + ) async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = None): @@ -1150,7 +1319,9 @@ async def generate_chat_completion( detail=str(e), ) - payload = 
{**form_data.model_dump(exclude_none=True)} + if isinstance(form_data, BaseModel): + payload = {**form_data.model_dump(exclude_none=True)} + if "metadata" in payload: del payload["metadata"] @@ -1164,13 +1335,10 @@ async def generate_chat_completion( params = model_info.params.model_dump() if params: - if payload.get("options") is None: - payload["options"] = {} + system = params.pop("system", None) - payload["options"] = apply_model_params_to_body_ollama( - params, payload["options"] - ) - payload = apply_model_system_prompt_to_body(params, payload, metadata, user) + payload = apply_model_params_to_body_ollama(params, payload) + payload = apply_system_prompt_to_body(system, payload, metadata, user) # Check if user has access to the model if not bypass_filter and user.role == "user": @@ -1203,7 +1371,7 @@ async def generate_chat_completion( prefix_id = api_config.get("prefix_id", None) if prefix_id: payload["model"] = payload["model"].replace(f"{prefix_id}.", "") - # payload["keep_alive"] = -1 # keep alive forever + return await send_post_request( url=f"{url}/api/chat", payload=json.dumps(payload), @@ -1211,6 +1379,7 @@ async def generate_chat_completion( key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), content_type="application/x-ndjson", user=user, + metadata=metadata, ) @@ -1249,6 +1418,8 @@ async def generate_openai_completion( url_idx: Optional[int] = None, user=Depends(get_verified_user), ): + metadata = form_data.pop("metadata", None) + try: form_data = OpenAICompletionForm(**form_data) except Exception as e: @@ -1314,6 +1485,7 @@ async def generate_openai_completion( stream=payload.get("stream", False), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), user=user, + metadata=metadata, ) @@ -1352,8 +1524,10 @@ async def generate_openai_chat_completion( params = model_info.params.model_dump() if params: + system = params.pop("system", None) + payload = apply_model_params_to_body_openai(params, payload) - 
payload = apply_model_system_prompt_to_body(params, payload, metadata, user) + payload = apply_system_prompt_to_body(system, payload, metadata, user) # Check if user has access to the model if user.role == "user": @@ -1393,6 +1567,7 @@ async def generate_openai_chat_completion( stream=payload.get("stream", False), key=get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS), user=user, + metadata=metadata, ) @@ -1523,25 +1698,27 @@ async def download_file_stream( yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n' if done: - file.seek(0) - chunk_size = 1024 * 1024 * 2 - hashed = calculate_sha256(file, chunk_size) - file.seek(0) - - url = f"{ollama_url}/api/blobs/sha256:{hashed}" - response = requests.post(url, data=file) - - if response.ok: - res = { - "done": done, - "blob": f"sha256:{hashed}", - "name": file_name, - } - os.remove(file_path) - - yield f"data: {json.dumps(res)}\n\n" - else: - raise "Ollama: Could not create blob, Please try again." + file.close() + + with open(file_path, "rb") as file: + chunk_size = 1024 * 1024 * 2 + hashed = calculate_sha256(file, chunk_size) + + url = f"{ollama_url}/api/blobs/sha256:{hashed}" + with requests.Session() as session: + response = session.post(url, data=file, timeout=30) + + if response.ok: + res = { + "done": done, + "blob": f"sha256:{hashed}", + "name": file_name, + } + os.remove(file_path) + + yield f"data: {json.dumps(res)}\n\n" + else: + raise "Ollama: Could not create blob, Please try again." 
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf" diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index 96c21f9c03a..8c5e3da7364 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -2,17 +2,22 @@ import hashlib import json import logging -from pathlib import Path -from typing import Literal, Optional, overload +from typing import Optional import aiohttp from aiocache import cached import requests +from urllib.parse import quote +from azure.identity import DefaultAzureCredential, get_bearer_token_provider -from fastapi import Depends, FastAPI, HTTPException, Request, APIRouter -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import FileResponse, StreamingResponse +from fastapi import Depends, HTTPException, Request, APIRouter +from fastapi.responses import ( + FileResponse, + StreamingResponse, + JSONResponse, + PlainTextResponse, +) from pydantic import BaseModel from starlette.background import BackgroundTask @@ -21,6 +26,7 @@ CACHE_DIR, ) from open_webui.env import ( + MODELS_CACHE_TTL, AIOHTTP_CLIENT_SESSION_SSL, AIOHTTP_CLIENT_TIMEOUT, AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST, @@ -30,12 +36,12 @@ from open_webui.models.users import UserModel from open_webui.constants import ERROR_MESSAGES -from open_webui.env import ENV, SRC_LOG_LEVELS +from open_webui.env import SRC_LOG_LEVELS from open_webui.utils.payload import ( apply_model_params_to_body_openai, - apply_model_system_prompt_to_body, + apply_system_prompt_to_body, ) from open_webui.utils.misc import ( convert_logit_bias_input_to_json, @@ -66,7 +72,7 @@ async def send_get_request(url, key=None, user: UserModel = None): **({"Authorization": f"Bearer {key}"} if key else {}), **( { - "X-OpenWebUI-User-Name": user.name, + "X-OpenWebUI-User-Name": quote(user.name, safe=" "), "X-OpenWebUI-User-Id": user.id, "X-OpenWebUI-User-Email": user.email, 
"X-OpenWebUI-User-Role": user.role, @@ -94,12 +100,12 @@ async def cleanup_response( await session.close() -def openai_o_series_handler(payload): +def openai_reasoning_model_handler(payload): """ - Handle "o" series specific parameters + Handle reasoning model specific parameters """ if "max_tokens" in payload: - # Convert "max_tokens" to "max_completion_tokens" for all o-series models + # Convert "max_tokens" to "max_completion_tokens" for all reasoning models payload["max_completion_tokens"] = payload["max_tokens"] del payload["max_tokens"] @@ -115,6 +121,96 @@ def openai_o_series_handler(payload): return payload +async def get_headers_and_cookies( + request: Request, + url, + key=None, + config=None, + metadata: Optional[dict] = None, + user: UserModel = None, +): + cookies = {} + headers = { + "Content-Type": "application/json", + **( + { + "HTTP-Referer": "https://openwebui.com/", + "X-Title": "Open WebUI", + } + if "openrouter.ai" in url + else {} + ), + **( + { + "X-OpenWebUI-User-Name": quote(user.name, safe=" "), + "X-OpenWebUI-User-Id": user.id, + "X-OpenWebUI-User-Email": user.email, + "X-OpenWebUI-User-Role": user.role, + **( + {"X-OpenWebUI-Chat-Id": metadata.get("chat_id")} + if metadata and metadata.get("chat_id") + else {} + ), + } + if ENABLE_FORWARD_USER_INFO_HEADERS + else {} + ), + } + + token = None + auth_type = config.get("auth_type") + + if auth_type == "bearer" or auth_type is None: + # Default to bearer if not specified + token = f"{key}" + elif auth_type == "none": + token = None + elif auth_type == "session": + cookies = request.cookies + token = request.state.token.credentials + elif auth_type == "system_oauth": + cookies = request.cookies + + oauth_token = None + try: + if request.cookies.get("oauth_session_id", None): + oauth_token = await request.app.state.oauth_manager.get_oauth_token( + user.id, + request.cookies.get("oauth_session_id", None), + ) + except Exception as e: + log.error(f"Error getting OAuth token: {e}") + + if 
oauth_token: + token = f"{oauth_token.get('access_token', '')}" + + elif auth_type in ("azure_ad", "microsoft_entra_id"): + token = get_microsoft_entra_id_access_token() + + if token: + headers["Authorization"] = f"Bearer {token}" + + if config.get("headers") and isinstance(config.get("headers"), dict): + headers = {**headers, **config.get("headers")} + + return headers, cookies + + +def get_microsoft_entra_id_access_token(): + """ + Get Microsoft Entra ID access token using DefaultAzureCredential for Azure OpenAI. + Returns the token string or None if authentication fails. + """ + try: + token_provider = get_bearer_token_provider( + DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default" + ) + return token_provider() + except Exception as e: + log.error(f"Error getting Microsoft Entra ID access token: {e}") + return None + + ########################################## # # API routes @@ -206,34 +302,23 @@ async def speech(request: Request, user=Depends(get_verified_user)): return FileResponse(file_path) url = request.app.state.config.OPENAI_API_BASE_URLS[idx] + key = request.app.state.config.OPENAI_API_KEYS[idx] + api_config = request.app.state.config.OPENAI_API_CONFIGS.get( + str(idx), + request.app.state.config.OPENAI_API_CONFIGS.get(url, {}), # Legacy support + ) + + headers, cookies = await get_headers_and_cookies( + request, url, key, api_config, user=user + ) r = None try: r = requests.post( url=f"{url}/audio/speech", data=body, - headers={ - "Content-Type": "application/json", - "Authorization": f"Bearer {request.app.state.config.OPENAI_API_KEYS[idx]}", - **( - { - "HTTP-Referer": "https://openwebui.com/", - "X-Title": "Open WebUI", - } - if "openrouter.ai" in url - else {} - ), - **( - { - "X-OpenWebUI-User-Name": user.name, - "X-OpenWebUI-User-Id": user.id, - "X-OpenWebUI-User-Email": user.email, - "X-OpenWebUI-User-Role": user.role, - } - if ENABLE_FORWARD_USER_INFO_HEADERS - else {} - ), - }, + headers=headers, + cookies=cookies, 
stream=True, ) @@ -357,11 +442,22 @@ async def get_all_models_responses(request: Request, user: UserModel) -> list: prefix_id = api_config.get("prefix_id", None) tags = api_config.get("tags", []) - for model in ( + model_list = ( response if isinstance(response, list) else response.get("data", []) - ): + ) + if not isinstance(model_list, list): + # Catch non-list responses + model_list = [] + + for model in model_list: + # Remove name key if its value is None #16689 + if "name" in model and model["name"] is None: + del model["name"] + if prefix_id: - model["id"] = f"{prefix_id}.{model['id']}" + model["id"] = ( + f"{prefix_id}.{model.get('id', model.get('name', ''))}" + ) if tags: model["tags"] = tags @@ -386,7 +482,10 @@ async def get_filtered_models(models, user): return filtered_models -@cached(ttl=1) +@cached( + ttl=MODELS_CACHE_TTL, + key=lambda _, user: f"openai_all_models_{user.id}" if user else "openai_all_models", +) async def get_all_models(request: Request, user: UserModel) -> dict[str, list]: log.info("get_all_models()") @@ -474,19 +573,9 @@ async def get_models( timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST), ) as session: try: - headers = { - "Content-Type": "application/json", - **( - { - "X-OpenWebUI-User-Name": user.name, - "X-OpenWebUI-User-Id": user.id, - "X-OpenWebUI-User-Email": user.email, - "X-OpenWebUI-User-Role": user.role, - } - if ENABLE_FORWARD_USER_INFO_HEADERS - else {} - ), - } + headers, cookies = await get_headers_and_cookies( + request, url, key, api_config, user=user + ) if api_config.get("azure", False): models = { @@ -494,11 +583,10 @@ async def get_models( "object": "list", } else: - headers["Authorization"] = f"Bearer {key}" - async with session.get( f"{url}/models", headers=headers, + cookies=cookies, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as r: if r.status != 200: @@ -557,7 +645,9 @@ class ConnectionVerificationForm(BaseModel): @router.post("/verify") async def verify_connection( - form_data: 
ConnectionVerificationForm, user=Depends(get_admin_user) + request: Request, + form_data: ConnectionVerificationForm, + user=Depends(get_admin_user), ): url = form_data.url key = form_data.key @@ -569,56 +659,61 @@ async def verify_connection( timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST), ) as session: try: - headers = { - "Content-Type": "application/json", - **( - { - "X-OpenWebUI-User-Name": user.name, - "X-OpenWebUI-User-Id": user.id, - "X-OpenWebUI-User-Email": user.email, - "X-OpenWebUI-User-Role": user.role, - } - if ENABLE_FORWARD_USER_INFO_HEADERS - else {} - ), - } + headers, cookies = await get_headers_and_cookies( + request, url, key, api_config, user=user + ) if api_config.get("azure", False): - headers["api-key"] = key - api_version = api_config.get("api_version", "") or "2023-03-15-preview" + # Only set api-key header if not using Azure Entra ID authentication + auth_type = api_config.get("auth_type", "bearer") + if auth_type not in ("azure_ad", "microsoft_entra_id"): + headers["api-key"] = key + api_version = api_config.get("api_version", "") or "2023-03-15-preview" async with session.get( url=f"{url}/openai/models?api-version={api_version}", headers=headers, + cookies=cookies, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as r: + try: + response_data = await r.json() + except Exception: + response_data = await r.text() + if r.status != 200: - # Extract response error details if available - error_detail = f"HTTP Error: {r.status}" - res = await r.json() - if "error" in res: - error_detail = f"External Error: {res['error']}" - raise Exception(error_detail) - - response_data = await r.json() + if isinstance(response_data, (dict, list)): + return JSONResponse( + status_code=r.status, content=response_data + ) + else: + return PlainTextResponse( + status_code=r.status, content=response_data + ) + return response_data else: - headers["Authorization"] = f"Bearer {key}" - async with session.get( f"{url}/models", headers=headers, + 
cookies=cookies, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) as r: + try: + response_data = await r.json() + except Exception: + response_data = await r.text() + if r.status != 200: - # Extract response error details if available - error_detail = f"HTTP Error: {r.status}" - res = await r.json() - if "error" in res: - error_detail = f"External Error: {res['error']}" - raise Exception(error_detail) - - response_data = await r.json() + if isinstance(response_data, (dict, list)): + return JSONResponse( + status_code=r.status, content=response_data + ) + else: + return PlainTextResponse( + status_code=r.status, content=response_data + ) + return response_data except aiohttp.ClientError as e: @@ -629,17 +724,12 @@ async def verify_connection( ) except Exception as e: log.exception(f"Unexpected error: {e}") - error_detail = f"Unexpected error: {str(e)}" - raise HTTPException(status_code=500, detail=error_detail) - + raise HTTPException( + status_code=500, detail="Open WebUI: Server Connection Error" + ) -def convert_to_azure_payload( - url, - payload: dict, -): - model = payload.get("model", "") - # Filter allowed parameters based on Azure OpenAI API +def get_azure_allowed_params(api_version: str) -> set[str]: allowed_params = { "messages", "temperature", @@ -669,8 +759,29 @@ def convert_to_azure_payload( "max_completion_tokens", } + try: + if api_version >= "2024-09-01-preview": + allowed_params.add("stream_options") + except ValueError: + log.debug( + f"Invalid API version {api_version} for Azure OpenAI. Defaulting to allowed parameters." 
+ ) + + return allowed_params + + +def is_openai_reasoning_model(model: str) -> bool: + return model.lower().startswith(("o1", "o3", "o4", "gpt-5")) + + +def convert_to_azure_payload(url, payload: dict, api_version: str): + model = payload.get("model", "") + + # Filter allowed parameters based on Azure OpenAI API + allowed_params = get_azure_allowed_params(api_version) + # Special handling for o-series models - if model.startswith("o") and model.endswith("-mini"): + if is_openai_reasoning_model(model): # Convert max_tokens to max_completion_tokens for o-series models if "max_tokens" in payload: payload["max_completion_tokens"] = payload["max_tokens"] @@ -715,8 +826,12 @@ async def generate_chat_completion( model_id = model_info.base_model_id params = model_info.params.model_dump() - payload = apply_model_params_to_body_openai(params, payload) - payload = apply_model_system_prompt_to_body(params, payload, metadata, user) + + if params: + system = params.pop("system", None) + + payload = apply_model_params_to_body_openai(params, payload) + payload = apply_system_prompt_to_body(system, payload, metadata, user) # Check if user has access to the model if not bypass_filter and user.role == "user": @@ -771,10 +886,9 @@ async def generate_chat_completion( url = request.app.state.config.OPENAI_API_BASE_URLS[idx] key = request.app.state.config.OPENAI_API_KEYS[idx] - # Check if model is from "o" series - is_o_series = payload["model"].lower().startswith(("o1", "o3", "o4")) - if is_o_series: - payload = openai_o_series_handler(payload) + # Check if model is a reasoning model that needs special handling + if is_openai_reasoning_model(payload["model"]): + payload = openai_reasoning_model_handler(payload) elif "api.openai.com" not in url: # Remove "max_completion_tokens" from the payload for backward compatibility if "max_completion_tokens" in payload: @@ -790,37 +904,23 @@ async def generate_chat_completion( convert_logit_bias_input_to_json(payload["logit_bias"]) ) - headers = { 
- "Content-Type": "application/json", - **( - { - "HTTP-Referer": "https://openwebui.com/", - "X-Title": "Open WebUI", - } - if "openrouter.ai" in url - else {} - ), - **( - { - "X-OpenWebUI-User-Name": user.name, - "X-OpenWebUI-User-Id": user.id, - "X-OpenWebUI-User-Email": user.email, - "X-OpenWebUI-User-Role": user.role, - } - if ENABLE_FORWARD_USER_INFO_HEADERS - else {} - ), - } + headers, cookies = await get_headers_and_cookies( + request, url, key, api_config, metadata, user=user + ) if api_config.get("azure", False): - request_url, payload = convert_to_azure_payload(url, payload) - api_version = api_config.get("api_version", "") or "2023-03-15-preview" - headers["api-key"] = key + api_version = api_config.get("api_version", "2023-03-15-preview") + request_url, payload = convert_to_azure_payload(url, payload, api_version) + + # Only set api-key header if not using Azure Entra ID authentication + auth_type = api_config.get("auth_type", "bearer") + if auth_type not in ("azure_ad", "microsoft_entra_id"): + headers["api-key"] = key + headers["api-version"] = api_version request_url = f"{request_url}/chat/completions?api-version={api_version}" else: request_url = f"{url}/chat/completions" - headers["Authorization"] = f"Bearer {key}" payload = json.dumps(payload) @@ -839,6 +939,7 @@ async def generate_chat_completion( url=request_url, data=payload, headers=headers, + cookies=cookies, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) @@ -860,27 +961,105 @@ async def generate_chat_completion( log.error(e) response = await r.text() - r.raise_for_status() + if r.status >= 400: + if isinstance(response, (dict, list)): + return JSONResponse(status_code=r.status, content=response) + else: + return PlainTextResponse(status_code=r.status, content=response) + return response except Exception as e: log.exception(e) - detail = None - if isinstance(response, dict): - if "error" in response: - detail = f"{response['error']['message'] if 'message' in response['error'] else response['error']}" - 
elif isinstance(response, str): - detail = response + raise HTTPException( + status_code=r.status if r else 500, + detail="Open WebUI: Server Connection Error", + ) + finally: + if not streaming: + await cleanup_response(r, session) + + +async def embeddings(request: Request, form_data: dict, user): + """ + Calls the embeddings endpoint for OpenAI-compatible providers. + + Args: + request (Request): The FastAPI request context. + form_data (dict): OpenAI-compatible embeddings payload. + user (UserModel): The authenticated user. + + Returns: + dict: OpenAI-compatible embeddings response. + """ + idx = 0 + # Prepare payload/body + body = json.dumps(form_data) + # Find correct backend url/key based on model + await get_all_models(request, user=user) + model_id = form_data.get("model") + models = request.app.state.OPENAI_MODELS + if model_id in models: + idx = models[model_id]["urlIdx"] + + url = request.app.state.config.OPENAI_API_BASE_URLS[idx] + key = request.app.state.config.OPENAI_API_KEYS[idx] + api_config = request.app.state.config.OPENAI_API_CONFIGS.get( + str(idx), + request.app.state.config.OPENAI_API_CONFIGS.get(url, {}), # Legacy support + ) + + r = None + session = None + streaming = False + + headers, cookies = await get_headers_and_cookies( + request, url, key, api_config, user=user + ) + try: + session = aiohttp.ClientSession(trust_env=True) + r = await session.request( + method="POST", + url=f"{url}/embeddings", + data=body, + headers=headers, + cookies=cookies, + ) + if "text/event-stream" in r.headers.get("Content-Type", ""): + streaming = True + return StreamingResponse( + r.content, + status_code=r.status, + headers=dict(r.headers), + background=BackgroundTask( + cleanup_response, response=r, session=session + ), + ) + else: + try: + response_data = await r.json() + except Exception: + response_data = await r.text() + + if r.status >= 400: + if isinstance(response_data, (dict, list)): + return JSONResponse(status_code=r.status, 
content=response_data) + else: + return PlainTextResponse( + status_code=r.status, content=response_data + ) + + return response_data + except Exception as e: + log.exception(e) raise HTTPException( status_code=r.status if r else 500, - detail=detail if detail else "Open WebUI: Server Connection Error", + detail="Open WebUI: Server Connection Error", ) finally: - if not streaming and session: - if r: - r.close() - await session.close() + if not streaming: + await cleanup_response(r, session) @router.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) @@ -906,33 +1085,26 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): streaming = False try: - headers = { - "Content-Type": "application/json", - **( - { - "X-OpenWebUI-User-Name": user.name, - "X-OpenWebUI-User-Id": user.id, - "X-OpenWebUI-User-Email": user.email, - "X-OpenWebUI-User-Role": user.role, - } - if ENABLE_FORWARD_USER_INFO_HEADERS - else {} - ), - } + headers, cookies = await get_headers_and_cookies( + request, url, key, api_config, user=user + ) if api_config.get("azure", False): - headers["api-key"] = key - headers["api-version"] = ( - api_config.get("api_version", "") or "2023-03-15-preview" - ) + api_version = api_config.get("api_version", "2023-03-15-preview") + + # Only set api-key header if not using Azure Entra ID authentication + auth_type = api_config.get("auth_type", "bearer") + if auth_type not in ("azure_ad", "microsoft_entra_id"): + headers["api-key"] = key + + headers["api-version"] = api_version payload = json.loads(body) - url, payload = convert_to_azure_payload(url, payload) + url, payload = convert_to_azure_payload(url, payload, api_version) body = json.dumps(payload).encode() - request_url = f"{url}/{path}?api-version={api_config.get('api_version', '2023-03-15-preview')}" + request_url = f"{url}/{path}?api-version={api_version}" else: - headers["Authorization"] = f"Bearer {key}" request_url = f"{url}/{path}" session = 
aiohttp.ClientSession(trust_env=True) @@ -941,9 +1113,9 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): url=request_url, data=body, headers=headers, + cookies=cookies, ssl=AIOHTTP_CLIENT_SESSION_SSL, ) - r.raise_for_status() # Check if response is SSE if "text/event-stream" in r.headers.get("Content-Type", ""): @@ -957,27 +1129,27 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): ), ) else: - response_data = await r.json() + try: + response_data = await r.json() + except Exception: + response_data = await r.text() + + if r.status >= 400: + if isinstance(response_data, (dict, list)): + return JSONResponse(status_code=r.status, content=response_data) + else: + return PlainTextResponse( + status_code=r.status, content=response_data + ) + return response_data except Exception as e: log.exception(e) - - detail = None - if r is not None: - try: - res = await r.json() - log.error(res) - if "error" in res: - detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}" - except Exception: - detail = f"External: {e}" raise HTTPException( status_code=r.status if r else 500, - detail=detail if detail else "Open WebUI: Server Connection Error", + detail="Open WebUI: Server Connection Error", ) finally: - if not streaming and session: - if r: - r.close() - await session.close() + if not streaming: + await cleanup_response(r, session) diff --git a/backend/open_webui/routers/prompts.py b/backend/open_webui/routers/prompts.py index 9fb946c6e72..5981f99f697 100644 --- a/backend/open_webui/routers/prompts.py +++ b/backend/open_webui/routers/prompts.py @@ -1,4 +1,5 @@ from typing import Optional +from fastapi import APIRouter, Depends, HTTPException, status, Request from open_webui.models.prompts import ( PromptForm, @@ -7,9 +8,9 @@ Prompts, ) from open_webui.constants import ERROR_MESSAGES -from fastapi import APIRouter, Depends, HTTPException, status, Request from 
open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access, has_permission +from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL router = APIRouter() @@ -20,7 +21,7 @@ @router.get("/", response_model=list[PromptModel]) async def get_prompts(user=Depends(get_verified_user)): - if user.role == "admin": + if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: prompts = Prompts.get_prompts() else: prompts = Prompts.get_prompts_by_user_id(user.id, "read") @@ -30,7 +31,7 @@ async def get_prompts(user=Depends(get_verified_user)): @router.get("/list", response_model=list[PromptUserResponse]) async def get_prompt_list(user=Depends(get_verified_user)): - if user.role == "admin": + if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: prompts = Prompts.get_prompts() else: prompts = Prompts.get_prompts_by_user_id(user.id, "write") diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 5cb47373f32..cb66e8926ec 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -5,7 +5,7 @@ import shutil import asyncio - +import re import uuid from datetime import datetime from pathlib import Path @@ -29,6 +29,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter +from langchain_text_splitters import MarkdownHeaderTextSplitter from langchain_core.documents import Document from open_webui.models.files import FileModel, Files @@ -45,6 +46,8 @@ # Web search engines from open_webui.retrieval.web.main import SearchResult from open_webui.retrieval.web.utils import get_web_loader +from open_webui.retrieval.web.ollama import search_ollama_cloud +from open_webui.retrieval.web.perplexity_search import search_perplexity_search from open_webui.retrieval.web.brave import search_brave from open_webui.retrieval.web.kagi import search_kagi from open_webui.retrieval.web.mojeek import search_mojeek @@ -68,13 
+71,16 @@ from open_webui.retrieval.web.external import search_external from open_webui.retrieval.utils import ( + get_content_from_url, get_embedding_function, + get_reranking_function, get_model_path, query_collection, query_collection_with_hybrid_search, query_doc, query_doc_with_hybrid_search, ) +from open_webui.retrieval.vector.utils import filter_metadata from open_webui.utils.misc import ( calculate_sha256_string, ) @@ -185,6 +191,26 @@ def get_rf( log.error(f"CrossEncoder: {e}") raise Exception(ERROR_MESSAGES.DEFAULT("CrossEncoder error")) + # Safely adjust pad_token_id if missing as some models do not have this in config + try: + model_cfg = getattr(rf, "model", None) + if model_cfg and hasattr(model_cfg, "config"): + cfg = model_cfg.config + if getattr(cfg, "pad_token_id", None) is None: + # Fallback to eos_token_id when available + eos = getattr(cfg, "eos_token_id", None) + if eos is not None: + cfg.pad_token_id = eos + log.debug( + f"Missing pad_token_id detected; set to eos_token_id={eos}" + ) + else: + log.warning( + "Neither pad_token_id nor eos_token_id present in model config" + ) + except Exception as e2: + log.warning(f"Failed to adjust pad_token_id on CrossEncoder: {e2}") + return rf @@ -239,6 +265,11 @@ async def get_embedding_config(request: Request, user=Depends(get_admin_user)): "url": request.app.state.config.RAG_OLLAMA_BASE_URL, "key": request.app.state.config.RAG_OLLAMA_API_KEY, }, + "azure_openai_config": { + "url": request.app.state.config.RAG_AZURE_OPENAI_BASE_URL, + "key": request.app.state.config.RAG_AZURE_OPENAI_API_KEY, + "version": request.app.state.config.RAG_AZURE_OPENAI_API_VERSION, + }, } @@ -252,9 +283,16 @@ class OllamaConfigForm(BaseModel): key: str +class AzureOpenAIConfigForm(BaseModel): + url: str + key: str + version: str + + class EmbeddingModelUpdateForm(BaseModel): openai_config: Optional[OpenAIConfigForm] = None ollama_config: Optional[OllamaConfigForm] = None + azure_openai_config: Optional[AzureOpenAIConfigForm] = 
None embedding_engine: str embedding_model: str embedding_batch_size: Optional[int] = 1 @@ -267,11 +305,27 @@ async def update_embedding_config( log.info( f"Updating embedding model: {request.app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" ) + if request.app.state.config.RAG_EMBEDDING_ENGINE == "": + # unloads current internal embedding model and clears VRAM cache + request.app.state.ef = None + request.app.state.EMBEDDING_FUNCTION = None + import gc + + gc.collect() + if DEVICE_TYPE == "cuda": + import torch + + if torch.cuda.is_available(): + torch.cuda.empty_cache() try: request.app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine request.app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model - if request.app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]: + if request.app.state.config.RAG_EMBEDDING_ENGINE in [ + "ollama", + "openai", + "azure_openai", + ]: if form_data.openai_config is not None: request.app.state.config.RAG_OPENAI_API_BASE_URL = ( form_data.openai_config.url @@ -288,6 +342,17 @@ async def update_embedding_config( form_data.ollama_config.key ) + if form_data.azure_openai_config is not None: + request.app.state.config.RAG_AZURE_OPENAI_BASE_URL = ( + form_data.azure_openai_config.url + ) + request.app.state.config.RAG_AZURE_OPENAI_API_KEY = ( + form_data.azure_openai_config.key + ) + request.app.state.config.RAG_AZURE_OPENAI_API_VERSION = ( + form_data.azure_openai_config.version + ) + request.app.state.config.RAG_EMBEDDING_BATCH_SIZE = ( form_data.embedding_batch_size ) @@ -304,14 +369,27 @@ async def update_embedding_config( ( request.app.state.config.RAG_OPENAI_API_BASE_URL if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else request.app.state.config.RAG_OLLAMA_BASE_URL + else ( + request.app.state.config.RAG_OLLAMA_BASE_URL + if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama" + else request.app.state.config.RAG_AZURE_OPENAI_BASE_URL + ) ), ( 
request.app.state.config.RAG_OPENAI_API_KEY if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else request.app.state.config.RAG_OLLAMA_API_KEY + else ( + request.app.state.config.RAG_OLLAMA_API_KEY + if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama" + else request.app.state.config.RAG_AZURE_OPENAI_API_KEY + ) ), request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, + azure_api_version=( + request.app.state.config.RAG_AZURE_OPENAI_API_VERSION + if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai" + else None + ), ) return { @@ -327,6 +405,11 @@ async def update_embedding_config( "url": request.app.state.config.RAG_OLLAMA_BASE_URL, "key": request.app.state.config.RAG_OLLAMA_API_KEY, }, + "azure_openai_config": { + "url": request.app.state.config.RAG_AZURE_OPENAI_BASE_URL, + "key": request.app.state.config.RAG_AZURE_OPENAI_API_KEY, + "version": request.app.state.config.RAG_AZURE_OPENAI_API_VERSION, + }, } except Exception as e: log.exception(f"Problem updating embedding model: {e}") @@ -349,19 +432,45 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)): "ENABLE_RAG_HYBRID_SEARCH": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, "TOP_K_RERANKER": request.app.state.config.TOP_K_RERANKER, "RELEVANCE_THRESHOLD": request.app.state.config.RELEVANCE_THRESHOLD, + "HYBRID_BM25_WEIGHT": request.app.state.config.HYBRID_BM25_WEIGHT, # Content extraction settings "CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE, "PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES, + "DATALAB_MARKER_API_KEY": request.app.state.config.DATALAB_MARKER_API_KEY, + "DATALAB_MARKER_API_BASE_URL": request.app.state.config.DATALAB_MARKER_API_BASE_URL, + "DATALAB_MARKER_ADDITIONAL_CONFIG": request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG, + "DATALAB_MARKER_SKIP_CACHE": request.app.state.config.DATALAB_MARKER_SKIP_CACHE, + "DATALAB_MARKER_FORCE_OCR": 
request.app.state.config.DATALAB_MARKER_FORCE_OCR, + "DATALAB_MARKER_PAGINATE": request.app.state.config.DATALAB_MARKER_PAGINATE, + "DATALAB_MARKER_STRIP_EXISTING_OCR": request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR, + "DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION": request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION, + "DATALAB_MARKER_FORMAT_LINES": request.app.state.config.DATALAB_MARKER_FORMAT_LINES, + "DATALAB_MARKER_USE_LLM": request.app.state.config.DATALAB_MARKER_USE_LLM, + "DATALAB_MARKER_OUTPUT_FORMAT": request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT, "EXTERNAL_DOCUMENT_LOADER_URL": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL, "EXTERNAL_DOCUMENT_LOADER_API_KEY": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY, "TIKA_SERVER_URL": request.app.state.config.TIKA_SERVER_URL, "DOCLING_SERVER_URL": request.app.state.config.DOCLING_SERVER_URL, + "DOCLING_PARAMS": request.app.state.config.DOCLING_PARAMS, + "DOCLING_DO_OCR": request.app.state.config.DOCLING_DO_OCR, + "DOCLING_FORCE_OCR": request.app.state.config.DOCLING_FORCE_OCR, "DOCLING_OCR_ENGINE": request.app.state.config.DOCLING_OCR_ENGINE, "DOCLING_OCR_LANG": request.app.state.config.DOCLING_OCR_LANG, + "DOCLING_PDF_BACKEND": request.app.state.config.DOCLING_PDF_BACKEND, + "DOCLING_TABLE_MODE": request.app.state.config.DOCLING_TABLE_MODE, + "DOCLING_PIPELINE": request.app.state.config.DOCLING_PIPELINE, "DOCLING_DO_PICTURE_DESCRIPTION": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION, + "DOCLING_PICTURE_DESCRIPTION_MODE": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE, + "DOCLING_PICTURE_DESCRIPTION_LOCAL": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL, + "DOCLING_PICTURE_DESCRIPTION_API": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API, "DOCUMENT_INTELLIGENCE_ENDPOINT": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT, "DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY, 
"MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY, + # MinerU settings + "MINERU_API_MODE": request.app.state.config.MINERU_API_MODE, + "MINERU_API_URL": request.app.state.config.MINERU_API_URL, + "MINERU_API_KEY": request.app.state.config.MINERU_API_KEY, + "MINERU_PARAMS": request.app.state.config.MINERU_PARAMS, # Reranking settings "RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL, "RAG_RERANKING_ENGINE": request.app.state.config.RAG_RERANKING_ENGINE, @@ -374,6 +483,8 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)): # File upload settings "FILE_MAX_SIZE": request.app.state.config.FILE_MAX_SIZE, "FILE_MAX_COUNT": request.app.state.config.FILE_MAX_COUNT, + "FILE_IMAGE_COMPRESSION_WIDTH": request.app.state.config.FILE_IMAGE_COMPRESSION_WIDTH, + "FILE_IMAGE_COMPRESSION_HEIGHT": request.app.state.config.FILE_IMAGE_COMPRESSION_HEIGHT, "ALLOWED_FILE_EXTENSIONS": request.app.state.config.ALLOWED_FILE_EXTENSIONS, # Integration settings "ENABLE_GOOGLE_DRIVE_INTEGRATION": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, @@ -385,8 +496,11 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)): "WEB_SEARCH_TRUST_ENV": request.app.state.config.WEB_SEARCH_TRUST_ENV, "WEB_SEARCH_RESULT_COUNT": request.app.state.config.WEB_SEARCH_RESULT_COUNT, "WEB_SEARCH_CONCURRENT_REQUESTS": request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS, + "WEB_LOADER_CONCURRENT_REQUESTS": request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS, "WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL, + "BYPASS_WEB_SEARCH_WEB_LOADER": request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER, + "OLLAMA_CLOUD_WEB_SEARCH_API_KEY": request.app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY, "SEARXNG_QUERY_URL": request.app.state.config.SEARXNG_QUERY_URL, "YACY_QUERY_URL": 
request.app.state.config.YACY_QUERY_URL, "YACY_USERNAME": request.app.state.config.YACY_USERNAME, @@ -411,6 +525,8 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)): "BING_SEARCH_V7_SUBSCRIPTION_KEY": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, "EXA_API_KEY": request.app.state.config.EXA_API_KEY, "PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY, + "PERPLEXITY_MODEL": request.app.state.config.PERPLEXITY_MODEL, + "PERPLEXITY_SEARCH_CONTEXT_USAGE": request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE, "SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID, "SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK, "WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE, @@ -437,8 +553,11 @@ class WebConfig(BaseModel): WEB_SEARCH_TRUST_ENV: Optional[bool] = None WEB_SEARCH_RESULT_COUNT: Optional[int] = None WEB_SEARCH_CONCURRENT_REQUESTS: Optional[int] = None + WEB_LOADER_CONCURRENT_REQUESTS: Optional[int] = None WEB_SEARCH_DOMAIN_FILTER_LIST: Optional[List[str]] = [] BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None + BYPASS_WEB_SEARCH_WEB_LOADER: Optional[bool] = None + OLLAMA_CLOUD_WEB_SEARCH_API_KEY: Optional[str] = None SEARXNG_QUERY_URL: Optional[str] = None YACY_QUERY_URL: Optional[str] = None YACY_USERNAME: Optional[str] = None @@ -463,6 +582,8 @@ class WebConfig(BaseModel): BING_SEARCH_V7_SUBSCRIPTION_KEY: Optional[str] = None EXA_API_KEY: Optional[str] = None PERPLEXITY_API_KEY: Optional[str] = None + PERPLEXITY_MODEL: Optional[str] = None + PERPLEXITY_SEARCH_CONTEXT_USAGE: Optional[str] = None SOUGOU_API_SID: Optional[str] = None SOUGOU_API_SK: Optional[str] = None WEB_LOADER_ENGINE: Optional[str] = None @@ -492,22 +613,51 @@ class ConfigForm(BaseModel): ENABLE_RAG_HYBRID_SEARCH: Optional[bool] = None TOP_K_RERANKER: Optional[int] = None RELEVANCE_THRESHOLD: Optional[float] = None + HYBRID_BM25_WEIGHT: Optional[float] = None # Content extraction settings 
CONTENT_EXTRACTION_ENGINE: Optional[str] = None PDF_EXTRACT_IMAGES: Optional[bool] = None + + DATALAB_MARKER_API_KEY: Optional[str] = None + DATALAB_MARKER_API_BASE_URL: Optional[str] = None + DATALAB_MARKER_ADDITIONAL_CONFIG: Optional[str] = None + DATALAB_MARKER_SKIP_CACHE: Optional[bool] = None + DATALAB_MARKER_FORCE_OCR: Optional[bool] = None + DATALAB_MARKER_PAGINATE: Optional[bool] = None + DATALAB_MARKER_STRIP_EXISTING_OCR: Optional[bool] = None + DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION: Optional[bool] = None + DATALAB_MARKER_FORMAT_LINES: Optional[bool] = None + DATALAB_MARKER_USE_LLM: Optional[bool] = None + DATALAB_MARKER_OUTPUT_FORMAT: Optional[str] = None + EXTERNAL_DOCUMENT_LOADER_URL: Optional[str] = None EXTERNAL_DOCUMENT_LOADER_API_KEY: Optional[str] = None TIKA_SERVER_URL: Optional[str] = None DOCLING_SERVER_URL: Optional[str] = None + DOCLING_PARAMS: Optional[dict] = None + DOCLING_DO_OCR: Optional[bool] = None + DOCLING_FORCE_OCR: Optional[bool] = None DOCLING_OCR_ENGINE: Optional[str] = None DOCLING_OCR_LANG: Optional[str] = None + DOCLING_PDF_BACKEND: Optional[str] = None + DOCLING_TABLE_MODE: Optional[str] = None + DOCLING_PIPELINE: Optional[str] = None DOCLING_DO_PICTURE_DESCRIPTION: Optional[bool] = None + DOCLING_PICTURE_DESCRIPTION_MODE: Optional[str] = None + DOCLING_PICTURE_DESCRIPTION_LOCAL: Optional[dict] = None + DOCLING_PICTURE_DESCRIPTION_API: Optional[dict] = None DOCUMENT_INTELLIGENCE_ENDPOINT: Optional[str] = None DOCUMENT_INTELLIGENCE_KEY: Optional[str] = None MISTRAL_OCR_API_KEY: Optional[str] = None + # MinerU settings + MINERU_API_MODE: Optional[str] = None + MINERU_API_URL: Optional[str] = None + MINERU_API_KEY: Optional[str] = None + MINERU_PARAMS: Optional[dict] = None + # Reranking settings RAG_RERANKING_MODEL: Optional[str] = None RAG_RERANKING_ENGINE: Optional[str] = None @@ -522,6 +672,8 @@ class ConfigForm(BaseModel): # File upload settings FILE_MAX_SIZE: Optional[int] = None FILE_MAX_COUNT: Optional[int] = None + 
FILE_IMAGE_COMPRESSION_WIDTH: Optional[int] = None + FILE_IMAGE_COMPRESSION_HEIGHT: Optional[int] = None ALLOWED_FILE_EXTENSIONS: Optional[List[str]] = None # Integration settings @@ -564,9 +716,6 @@ async def update_rag_config( if form_data.ENABLE_RAG_HYBRID_SEARCH is not None else request.app.state.config.ENABLE_RAG_HYBRID_SEARCH ) - # Free up memory if hybrid search is disabled - if not request.app.state.config.ENABLE_RAG_HYBRID_SEARCH: - request.app.state.rf = None request.app.state.config.TOP_K_RERANKER = ( form_data.TOP_K_RERANKER @@ -578,6 +727,11 @@ async def update_rag_config( if form_data.RELEVANCE_THRESHOLD is not None else request.app.state.config.RELEVANCE_THRESHOLD ) + request.app.state.config.HYBRID_BM25_WEIGHT = ( + form_data.HYBRID_BM25_WEIGHT + if form_data.HYBRID_BM25_WEIGHT is not None + else request.app.state.config.HYBRID_BM25_WEIGHT + ) # Content extraction settings request.app.state.config.CONTENT_EXTRACTION_ENGINE = ( @@ -590,6 +744,61 @@ async def update_rag_config( if form_data.PDF_EXTRACT_IMAGES is not None else request.app.state.config.PDF_EXTRACT_IMAGES ) + request.app.state.config.DATALAB_MARKER_API_KEY = ( + form_data.DATALAB_MARKER_API_KEY + if form_data.DATALAB_MARKER_API_KEY is not None + else request.app.state.config.DATALAB_MARKER_API_KEY + ) + request.app.state.config.DATALAB_MARKER_API_BASE_URL = ( + form_data.DATALAB_MARKER_API_BASE_URL + if form_data.DATALAB_MARKER_API_BASE_URL is not None + else request.app.state.config.DATALAB_MARKER_API_BASE_URL + ) + request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG = ( + form_data.DATALAB_MARKER_ADDITIONAL_CONFIG + if form_data.DATALAB_MARKER_ADDITIONAL_CONFIG is not None + else request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG + ) + request.app.state.config.DATALAB_MARKER_SKIP_CACHE = ( + form_data.DATALAB_MARKER_SKIP_CACHE + if form_data.DATALAB_MARKER_SKIP_CACHE is not None + else request.app.state.config.DATALAB_MARKER_SKIP_CACHE + ) + 
request.app.state.config.DATALAB_MARKER_FORCE_OCR = ( + form_data.DATALAB_MARKER_FORCE_OCR + if form_data.DATALAB_MARKER_FORCE_OCR is not None + else request.app.state.config.DATALAB_MARKER_FORCE_OCR + ) + request.app.state.config.DATALAB_MARKER_PAGINATE = ( + form_data.DATALAB_MARKER_PAGINATE + if form_data.DATALAB_MARKER_PAGINATE is not None + else request.app.state.config.DATALAB_MARKER_PAGINATE + ) + request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR = ( + form_data.DATALAB_MARKER_STRIP_EXISTING_OCR + if form_data.DATALAB_MARKER_STRIP_EXISTING_OCR is not None + else request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR + ) + request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION = ( + form_data.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION + if form_data.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION is not None + else request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION + ) + request.app.state.config.DATALAB_MARKER_FORMAT_LINES = ( + form_data.DATALAB_MARKER_FORMAT_LINES + if form_data.DATALAB_MARKER_FORMAT_LINES is not None + else request.app.state.config.DATALAB_MARKER_FORMAT_LINES + ) + request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT = ( + form_data.DATALAB_MARKER_OUTPUT_FORMAT + if form_data.DATALAB_MARKER_OUTPUT_FORMAT is not None + else request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT + ) + request.app.state.config.DATALAB_MARKER_USE_LLM = ( + form_data.DATALAB_MARKER_USE_LLM + if form_data.DATALAB_MARKER_USE_LLM is not None + else request.app.state.config.DATALAB_MARKER_USE_LLM + ) request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL = ( form_data.EXTERNAL_DOCUMENT_LOADER_URL if form_data.EXTERNAL_DOCUMENT_LOADER_URL is not None @@ -610,6 +819,21 @@ async def update_rag_config( if form_data.DOCLING_SERVER_URL is not None else request.app.state.config.DOCLING_SERVER_URL ) + request.app.state.config.DOCLING_PARAMS = ( + form_data.DOCLING_PARAMS + if form_data.DOCLING_PARAMS is not None + else 
request.app.state.config.DOCLING_PARAMS + ) + request.app.state.config.DOCLING_DO_OCR = ( + form_data.DOCLING_DO_OCR + if form_data.DOCLING_DO_OCR is not None + else request.app.state.config.DOCLING_DO_OCR + ) + request.app.state.config.DOCLING_FORCE_OCR = ( + form_data.DOCLING_FORCE_OCR + if form_data.DOCLING_FORCE_OCR is not None + else request.app.state.config.DOCLING_FORCE_OCR + ) request.app.state.config.DOCLING_OCR_ENGINE = ( form_data.DOCLING_OCR_ENGINE if form_data.DOCLING_OCR_ENGINE is not None @@ -620,13 +844,43 @@ async def update_rag_config( if form_data.DOCLING_OCR_LANG is not None else request.app.state.config.DOCLING_OCR_LANG ) - + request.app.state.config.DOCLING_PDF_BACKEND = ( + form_data.DOCLING_PDF_BACKEND + if form_data.DOCLING_PDF_BACKEND is not None + else request.app.state.config.DOCLING_PDF_BACKEND + ) + request.app.state.config.DOCLING_TABLE_MODE = ( + form_data.DOCLING_TABLE_MODE + if form_data.DOCLING_TABLE_MODE is not None + else request.app.state.config.DOCLING_TABLE_MODE + ) + request.app.state.config.DOCLING_PIPELINE = ( + form_data.DOCLING_PIPELINE + if form_data.DOCLING_PIPELINE is not None + else request.app.state.config.DOCLING_PIPELINE + ) request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION = ( form_data.DOCLING_DO_PICTURE_DESCRIPTION if form_data.DOCLING_DO_PICTURE_DESCRIPTION is not None else request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION ) + request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE = ( + form_data.DOCLING_PICTURE_DESCRIPTION_MODE + if form_data.DOCLING_PICTURE_DESCRIPTION_MODE is not None + else request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE + ) + request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL = ( + form_data.DOCLING_PICTURE_DESCRIPTION_LOCAL + if form_data.DOCLING_PICTURE_DESCRIPTION_LOCAL is not None + else request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL + ) + request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API = ( + 
form_data.DOCLING_PICTURE_DESCRIPTION_API + if form_data.DOCLING_PICTURE_DESCRIPTION_API is not None + else request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API + ) + request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = ( form_data.DOCUMENT_INTELLIGENCE_ENDPOINT if form_data.DOCUMENT_INTELLIGENCE_ENDPOINT is not None @@ -643,7 +897,41 @@ async def update_rag_config( else request.app.state.config.MISTRAL_OCR_API_KEY ) + # MinerU settings + request.app.state.config.MINERU_API_MODE = ( + form_data.MINERU_API_MODE + if form_data.MINERU_API_MODE is not None + else request.app.state.config.MINERU_API_MODE + ) + request.app.state.config.MINERU_API_URL = ( + form_data.MINERU_API_URL + if form_data.MINERU_API_URL is not None + else request.app.state.config.MINERU_API_URL + ) + request.app.state.config.MINERU_API_KEY = ( + form_data.MINERU_API_KEY + if form_data.MINERU_API_KEY is not None + else request.app.state.config.MINERU_API_KEY + ) + request.app.state.config.MINERU_PARAMS = ( + form_data.MINERU_PARAMS + if form_data.MINERU_PARAMS is not None + else request.app.state.config.MINERU_PARAMS + ) + # Reranking settings + if request.app.state.config.RAG_RERANKING_ENGINE == "": + # Unloading the internal reranker and clear VRAM memory + request.app.state.rf = None + request.app.state.RERANKING_FUNCTION = None + import gc + + gc.collect() + if DEVICE_TYPE == "cuda": + import torch + + if torch.cuda.is_available(): + torch.cuda.empty_cache() request.app.state.config.RAG_RERANKING_ENGINE = ( form_data.RAG_RERANKING_ENGINE if form_data.RAG_RERANKING_ENGINE is not None @@ -666,16 +954,30 @@ async def update_rag_config( f"Updating reranking model: {request.app.state.config.RAG_RERANKING_MODEL} to {form_data.RAG_RERANKING_MODEL}" ) try: - request.app.state.config.RAG_RERANKING_MODEL = form_data.RAG_RERANKING_MODEL + request.app.state.config.RAG_RERANKING_MODEL = ( + form_data.RAG_RERANKING_MODEL + if form_data.RAG_RERANKING_MODEL is not None + else 
request.app.state.config.RAG_RERANKING_MODEL + ) try: - request.app.state.rf = get_rf( - request.app.state.config.RAG_RERANKING_ENGINE, - request.app.state.config.RAG_RERANKING_MODEL, - request.app.state.config.RAG_EXTERNAL_RERANKER_URL, - request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY, - True, - ) + if ( + request.app.state.config.ENABLE_RAG_HYBRID_SEARCH + and not request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL + ): + request.app.state.rf = get_rf( + request.app.state.config.RAG_RERANKING_ENGINE, + request.app.state.config.RAG_RERANKING_MODEL, + request.app.state.config.RAG_EXTERNAL_RERANKER_URL, + request.app.state.config.RAG_EXTERNAL_RERANKER_API_KEY, + True, + ) + + request.app.state.RERANKING_FUNCTION = get_reranking_function( + request.app.state.config.RAG_RERANKING_ENGINE, + request.app.state.config.RAG_RERANKING_MODEL, + request.app.state.rf, + ) except Exception as e: log.error(f"Error loading reranking model: {e}") request.app.state.config.ENABLE_RAG_HYBRID_SEARCH = False @@ -704,15 +1006,13 @@ async def update_rag_config( ) # File upload settings - request.app.state.config.FILE_MAX_SIZE = ( - form_data.FILE_MAX_SIZE - if form_data.FILE_MAX_SIZE is not None - else request.app.state.config.FILE_MAX_SIZE + request.app.state.config.FILE_MAX_SIZE = form_data.FILE_MAX_SIZE + request.app.state.config.FILE_MAX_COUNT = form_data.FILE_MAX_COUNT + request.app.state.config.FILE_IMAGE_COMPRESSION_WIDTH = ( + form_data.FILE_IMAGE_COMPRESSION_WIDTH ) - request.app.state.config.FILE_MAX_COUNT = ( - form_data.FILE_MAX_COUNT - if form_data.FILE_MAX_COUNT is not None - else request.app.state.config.FILE_MAX_COUNT + request.app.state.config.FILE_IMAGE_COMPRESSION_HEIGHT = ( + form_data.FILE_IMAGE_COMPRESSION_HEIGHT ) request.app.state.config.ALLOWED_FILE_EXTENSIONS = ( form_data.ALLOWED_FILE_EXTENSIONS @@ -745,12 +1045,21 @@ async def update_rag_config( request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS = ( form_data.web.WEB_SEARCH_CONCURRENT_REQUESTS 
) + request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS = ( + form_data.web.WEB_LOADER_CONCURRENT_REQUESTS + ) request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST = ( form_data.web.WEB_SEARCH_DOMAIN_FILTER_LIST ) request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL = ( form_data.web.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL ) + request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER = ( + form_data.web.BYPASS_WEB_SEARCH_WEB_LOADER + ) + request.app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY = ( + form_data.web.OLLAMA_CLOUD_WEB_SEARCH_API_KEY + ) request.app.state.config.SEARXNG_QUERY_URL = form_data.web.SEARXNG_QUERY_URL request.app.state.config.YACY_QUERY_URL = form_data.web.YACY_QUERY_URL request.app.state.config.YACY_USERNAME = form_data.web.YACY_USERNAME @@ -787,6 +1096,10 @@ async def update_rag_config( ) request.app.state.config.EXA_API_KEY = form_data.web.EXA_API_KEY request.app.state.config.PERPLEXITY_API_KEY = form_data.web.PERPLEXITY_API_KEY + request.app.state.config.PERPLEXITY_MODEL = form_data.web.PERPLEXITY_MODEL + request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE = ( + form_data.web.PERPLEXITY_SEARCH_CONTEXT_USAGE + ) request.app.state.config.SOUGOU_API_SID = form_data.web.SOUGOU_API_SID request.app.state.config.SOUGOU_API_SK = form_data.web.SOUGOU_API_SK @@ -837,19 +1150,44 @@ async def update_rag_config( "ENABLE_RAG_HYBRID_SEARCH": request.app.state.config.ENABLE_RAG_HYBRID_SEARCH, "TOP_K_RERANKER": request.app.state.config.TOP_K_RERANKER, "RELEVANCE_THRESHOLD": request.app.state.config.RELEVANCE_THRESHOLD, + "HYBRID_BM25_WEIGHT": request.app.state.config.HYBRID_BM25_WEIGHT, # Content extraction settings "CONTENT_EXTRACTION_ENGINE": request.app.state.config.CONTENT_EXTRACTION_ENGINE, "PDF_EXTRACT_IMAGES": request.app.state.config.PDF_EXTRACT_IMAGES, + "DATALAB_MARKER_API_KEY": request.app.state.config.DATALAB_MARKER_API_KEY, + "DATALAB_MARKER_API_BASE_URL": request.app.state.config.DATALAB_MARKER_API_BASE_URL, + 
"DATALAB_MARKER_ADDITIONAL_CONFIG": request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG, + "DATALAB_MARKER_SKIP_CACHE": request.app.state.config.DATALAB_MARKER_SKIP_CACHE, + "DATALAB_MARKER_FORCE_OCR": request.app.state.config.DATALAB_MARKER_FORCE_OCR, + "DATALAB_MARKER_PAGINATE": request.app.state.config.DATALAB_MARKER_PAGINATE, + "DATALAB_MARKER_STRIP_EXISTING_OCR": request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR, + "DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION": request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION, + "DATALAB_MARKER_USE_LLM": request.app.state.config.DATALAB_MARKER_USE_LLM, + "DATALAB_MARKER_OUTPUT_FORMAT": request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT, "EXTERNAL_DOCUMENT_LOADER_URL": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL, "EXTERNAL_DOCUMENT_LOADER_API_KEY": request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY, "TIKA_SERVER_URL": request.app.state.config.TIKA_SERVER_URL, "DOCLING_SERVER_URL": request.app.state.config.DOCLING_SERVER_URL, + "DOCLING_PARAMS": request.app.state.config.DOCLING_PARAMS, + "DOCLING_DO_OCR": request.app.state.config.DOCLING_DO_OCR, + "DOCLING_FORCE_OCR": request.app.state.config.DOCLING_FORCE_OCR, "DOCLING_OCR_ENGINE": request.app.state.config.DOCLING_OCR_ENGINE, "DOCLING_OCR_LANG": request.app.state.config.DOCLING_OCR_LANG, + "DOCLING_PDF_BACKEND": request.app.state.config.DOCLING_PDF_BACKEND, + "DOCLING_TABLE_MODE": request.app.state.config.DOCLING_TABLE_MODE, + "DOCLING_PIPELINE": request.app.state.config.DOCLING_PIPELINE, "DOCLING_DO_PICTURE_DESCRIPTION": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION, + "DOCLING_PICTURE_DESCRIPTION_MODE": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE, + "DOCLING_PICTURE_DESCRIPTION_LOCAL": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL, + "DOCLING_PICTURE_DESCRIPTION_API": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API, "DOCUMENT_INTELLIGENCE_ENDPOINT": 
request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT, "DOCUMENT_INTELLIGENCE_KEY": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY, "MISTRAL_OCR_API_KEY": request.app.state.config.MISTRAL_OCR_API_KEY, + # MinerU settings + "MINERU_API_MODE": request.app.state.config.MINERU_API_MODE, + "MINERU_API_URL": request.app.state.config.MINERU_API_URL, + "MINERU_API_KEY": request.app.state.config.MINERU_API_KEY, + "MINERU_PARAMS": request.app.state.config.MINERU_PARAMS, # Reranking settings "RAG_RERANKING_MODEL": request.app.state.config.RAG_RERANKING_MODEL, "RAG_RERANKING_ENGINE": request.app.state.config.RAG_RERANKING_ENGINE, @@ -862,6 +1200,8 @@ async def update_rag_config( # File upload settings "FILE_MAX_SIZE": request.app.state.config.FILE_MAX_SIZE, "FILE_MAX_COUNT": request.app.state.config.FILE_MAX_COUNT, + "FILE_IMAGE_COMPRESSION_WIDTH": request.app.state.config.FILE_IMAGE_COMPRESSION_WIDTH, + "FILE_IMAGE_COMPRESSION_HEIGHT": request.app.state.config.FILE_IMAGE_COMPRESSION_HEIGHT, "ALLOWED_FILE_EXTENSIONS": request.app.state.config.ALLOWED_FILE_EXTENSIONS, # Integration settings "ENABLE_GOOGLE_DRIVE_INTEGRATION": request.app.state.config.ENABLE_GOOGLE_DRIVE_INTEGRATION, @@ -873,8 +1213,11 @@ async def update_rag_config( "WEB_SEARCH_TRUST_ENV": request.app.state.config.WEB_SEARCH_TRUST_ENV, "WEB_SEARCH_RESULT_COUNT": request.app.state.config.WEB_SEARCH_RESULT_COUNT, "WEB_SEARCH_CONCURRENT_REQUESTS": request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS, + "WEB_LOADER_CONCURRENT_REQUESTS": request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS, "WEB_SEARCH_DOMAIN_FILTER_LIST": request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, "BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL": request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL, + "BYPASS_WEB_SEARCH_WEB_LOADER": request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER, + "OLLAMA_CLOUD_WEB_SEARCH_API_KEY": request.app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY, "SEARXNG_QUERY_URL": 
request.app.state.config.SEARXNG_QUERY_URL, "YACY_QUERY_URL": request.app.state.config.YACY_QUERY_URL, "YACY_USERNAME": request.app.state.config.YACY_USERNAME, @@ -899,6 +1242,8 @@ async def update_rag_config( "BING_SEARCH_V7_SUBSCRIPTION_KEY": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY, "EXA_API_KEY": request.app.state.config.EXA_API_KEY, "PERPLEXITY_API_KEY": request.app.state.config.PERPLEXITY_API_KEY, + "PERPLEXITY_MODEL": request.app.state.config.PERPLEXITY_MODEL, + "PERPLEXITY_SEARCH_CONTEXT_USAGE": request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE, "SOUGOU_API_SID": request.app.state.config.SOUGOU_API_SID, "SOUGOU_API_SK": request.app.state.config.SOUGOU_API_SK, "WEB_LOADER_ENGINE": request.app.state.config.WEB_LOADER_ENGINE, @@ -976,6 +1321,7 @@ def _get_docs_info(docs: list[Document]) -> str: chunk_overlap=request.app.state.config.CHUNK_OVERLAP, add_start_index=True, ) + docs = text_splitter.split_documents(docs) elif request.app.state.config.TEXT_SPLITTER == "token": log.info( f"Using token text splitter: {request.app.state.config.TIKTOKEN_ENCODING_NAME}" @@ -988,11 +1334,56 @@ def _get_docs_info(docs: list[Document]) -> str: chunk_overlap=request.app.state.config.CHUNK_OVERLAP, add_start_index=True, ) + docs = text_splitter.split_documents(docs) + elif request.app.state.config.TEXT_SPLITTER == "markdown_header": + log.info("Using markdown header text splitter") + + # Define headers to split on - covering most common markdown header levels + headers_to_split_on = [ + ("#", "Header 1"), + ("##", "Header 2"), + ("###", "Header 3"), + ("####", "Header 4"), + ("#####", "Header 5"), + ("######", "Header 6"), + ] + + markdown_splitter = MarkdownHeaderTextSplitter( + headers_to_split_on=headers_to_split_on, + strip_headers=False, # Keep headers in content for context + ) + + md_split_docs = [] + for doc in docs: + md_header_splits = markdown_splitter.split_text(doc.page_content) + text_splitter = RecursiveCharacterTextSplitter( + 
chunk_size=request.app.state.config.CHUNK_SIZE, + chunk_overlap=request.app.state.config.CHUNK_OVERLAP, + add_start_index=True, + ) + md_header_splits = text_splitter.split_documents(md_header_splits) + + # Convert back to Document objects, preserving original metadata + for split_chunk in md_header_splits: + headings_list = [] + # Extract header values in order based on headers_to_split_on + for _, header_meta_key_name in headers_to_split_on: + if header_meta_key_name in split_chunk.metadata: + headings_list.append( + split_chunk.metadata[header_meta_key_name] + ) + + md_split_docs.append( + Document( + page_content=split_chunk.page_content, + metadata={**doc.metadata, "headings": headings_list}, + ) + ) + + docs = md_split_docs else: raise ValueError(ERROR_MESSAGES.DEFAULT("Invalid text splitter")) - docs = text_splitter.split_documents(docs) - if len(docs) == 0: raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT) @@ -1001,27 +1392,14 @@ def _get_docs_info(docs: list[Document]) -> str: { **doc.metadata, **(metadata if metadata else {}), - "embedding_config": json.dumps( - { - "engine": request.app.state.config.RAG_EMBEDDING_ENGINE, - "model": request.app.state.config.RAG_EMBEDDING_MODEL, - } - ), + "embedding_config": { + "engine": request.app.state.config.RAG_EMBEDDING_ENGINE, + "model": request.app.state.config.RAG_EMBEDDING_MODEL, + }, } for doc in docs ] - # ChromaDB does not like datetime formats - # for meta-data so convert them to string. 
- for metadata in metadatas: - for key, value in metadata.items(): - if ( - isinstance(value, datetime) - or isinstance(value, list) - or isinstance(value, dict) - ): - metadata[key] = str(value) - try: if VECTOR_DB_CLIENT.has_collection(collection_name=collection_name): log.info(f"collection {collection_name} already exists") @@ -1035,7 +1413,7 @@ def _get_docs_info(docs: list[Document]) -> str: ) return True - log.info(f"adding to collection {collection_name}") + log.info(f"generating embeddings for {collection_name}") embedding_function = get_embedding_function( request.app.state.config.RAG_EMBEDDING_ENGINE, request.app.state.config.RAG_EMBEDDING_MODEL, @@ -1043,14 +1421,27 @@ def _get_docs_info(docs: list[Document]) -> str: ( request.app.state.config.RAG_OPENAI_API_BASE_URL if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else request.app.state.config.RAG_OLLAMA_BASE_URL + else ( + request.app.state.config.RAG_OLLAMA_BASE_URL + if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama" + else request.app.state.config.RAG_AZURE_OPENAI_BASE_URL + ) ), ( request.app.state.config.RAG_OPENAI_API_KEY if request.app.state.config.RAG_EMBEDDING_ENGINE == "openai" - else request.app.state.config.RAG_OLLAMA_API_KEY + else ( + request.app.state.config.RAG_OLLAMA_API_KEY + if request.app.state.config.RAG_EMBEDDING_ENGINE == "ollama" + else request.app.state.config.RAG_AZURE_OPENAI_API_KEY + ) ), request.app.state.config.RAG_EMBEDDING_BATCH_SIZE, + azure_api_version=( + request.app.state.config.RAG_AZURE_OPENAI_API_VERSION + if request.app.state.config.RAG_EMBEDDING_ENGINE == "azure_openai" + else None + ), ) embeddings = embedding_function( @@ -1058,6 +1449,7 @@ def _get_docs_info(docs: list[Document]) -> str: prefix=RAG_EMBEDDING_CONTENT_PREFIX, user=user, ) + log.info(f"embeddings generated {len(embeddings)} for {len(texts)} items") items = [ { @@ -1069,11 +1461,13 @@ def _get_docs_info(docs: list[Document]) -> str: for idx, text in enumerate(texts) ] + 
log.info(f"adding to collection {collection_name}") VECTOR_DB_CLIENT.insert( collection_name=collection_name, items=items, ) + log.info(f"added {len(items)} items to collection {collection_name}") return True except Exception as e: log.exception(e) @@ -1092,59 +1486,35 @@ def process_file( form_data: ProcessFileForm, user=Depends(get_verified_user), ): - try: + if user.role == "admin": file = Files.get_file_by_id(form_data.file_id) + else: + file = Files.get_file_by_id_and_user_id(form_data.file_id, user.id) - collection_name = form_data.collection_name - - if collection_name is None: - collection_name = f"file-{file.id}" - - if form_data.content: - # Update the content in the file - # Usage: /files/{file_id}/data/content/update, /files/ (audio file upload pipeline) - - try: - # /files/{file_id}/data/content/update - VECTOR_DB_CLIENT.delete_collection(collection_name=f"file-{file.id}") - except: - # Audio file upload pipeline - pass + if file: + try: - docs = [ - Document( - page_content=form_data.content.replace("
", "\n"), - metadata={ - **file.meta, - "name": file.filename, - "created_by": file.user_id, - "file_id": file.id, - "source": file.filename, - }, - ) - ] + collection_name = form_data.collection_name - text_content = form_data.content - elif form_data.collection_name: - # Check if the file has already been processed and save the content - # Usage: /knowledge/{id}/file/add, /knowledge/{id}/file/update + if collection_name is None: + collection_name = f"file-{file.id}" - result = VECTOR_DB_CLIENT.query( - collection_name=f"file-{file.id}", filter={"file_id": file.id} - ) + if form_data.content: + # Update the content in the file + # Usage: /files/{file_id}/data/content/update, /files/ (audio file upload pipeline) - if result is not None and len(result.ids[0]) > 0: - docs = [ - Document( - page_content=result.documents[0][idx], - metadata=result.metadatas[0][idx], + try: + # /files/{file_id}/data/content/update + VECTOR_DB_CLIENT.delete_collection( + collection_name=f"file-{file.id}" ) - for idx, id in enumerate(result.ids[0]) - ] - else: + except: + # Audio file upload pipeline + pass + docs = [ Document( - page_content=file.data.get("content", ""), + page_content=form_data.content.replace("
", "\n"), metadata={ **file.meta, "name": file.filename, @@ -1155,119 +1525,194 @@ def process_file( ) ] - text_content = file.data.get("content", "") - else: - # Process the file and save the content - # Usage: /files/ - file_path = file.path - if file_path: - file_path = Storage.get_file(file_path) - loader = Loader( - engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE, - EXTERNAL_DOCUMENT_LOADER_URL=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL, - EXTERNAL_DOCUMENT_LOADER_API_KEY=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY, - TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL, - DOCLING_SERVER_URL=request.app.state.config.DOCLING_SERVER_URL, - DOCLING_OCR_ENGINE=request.app.state.config.DOCLING_OCR_ENGINE, - DOCLING_OCR_LANG=request.app.state.config.DOCLING_OCR_LANG, - DOCLING_DO_PICTURE_DESCRIPTION=request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION, - PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES, - DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT, - DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY, - MISTRAL_OCR_API_KEY=request.app.state.config.MISTRAL_OCR_API_KEY, - ) - docs = loader.load( - file.filename, file.meta.get("content_type"), file_path + text_content = form_data.content + elif form_data.collection_name: + # Check if the file has already been processed and save the content + # Usage: /knowledge/{id}/file/add, /knowledge/{id}/file/update + + result = VECTOR_DB_CLIENT.query( + collection_name=f"file-{file.id}", filter={"file_id": file.id} ) - docs = [ - Document( - page_content=doc.page_content, - metadata={ - **doc.metadata, - "name": file.filename, - "created_by": file.user_id, - "file_id": file.id, - "source": file.filename, + if result is not None and len(result.ids[0]) > 0: + docs = [ + Document( + page_content=result.documents[0][idx], + metadata=result.metadatas[0][idx], + ) + for idx, id in enumerate(result.ids[0]) + ] + 
else: + docs = [ + Document( + page_content=file.data.get("content", ""), + metadata={ + **file.meta, + "name": file.filename, + "created_by": file.user_id, + "file_id": file.id, + "source": file.filename, + }, + ) + ] + + text_content = file.data.get("content", "") + else: + # Process the file and save the content + # Usage: /files/ + file_path = file.path + if file_path: + file_path = Storage.get_file(file_path) + loader = Loader( + engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE, + DATALAB_MARKER_API_KEY=request.app.state.config.DATALAB_MARKER_API_KEY, + DATALAB_MARKER_API_BASE_URL=request.app.state.config.DATALAB_MARKER_API_BASE_URL, + DATALAB_MARKER_ADDITIONAL_CONFIG=request.app.state.config.DATALAB_MARKER_ADDITIONAL_CONFIG, + DATALAB_MARKER_SKIP_CACHE=request.app.state.config.DATALAB_MARKER_SKIP_CACHE, + DATALAB_MARKER_FORCE_OCR=request.app.state.config.DATALAB_MARKER_FORCE_OCR, + DATALAB_MARKER_PAGINATE=request.app.state.config.DATALAB_MARKER_PAGINATE, + DATALAB_MARKER_STRIP_EXISTING_OCR=request.app.state.config.DATALAB_MARKER_STRIP_EXISTING_OCR, + DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION=request.app.state.config.DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION, + DATALAB_MARKER_FORMAT_LINES=request.app.state.config.DATALAB_MARKER_FORMAT_LINES, + DATALAB_MARKER_USE_LLM=request.app.state.config.DATALAB_MARKER_USE_LLM, + DATALAB_MARKER_OUTPUT_FORMAT=request.app.state.config.DATALAB_MARKER_OUTPUT_FORMAT, + EXTERNAL_DOCUMENT_LOADER_URL=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_URL, + EXTERNAL_DOCUMENT_LOADER_API_KEY=request.app.state.config.EXTERNAL_DOCUMENT_LOADER_API_KEY, + TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL, + DOCLING_SERVER_URL=request.app.state.config.DOCLING_SERVER_URL, + DOCLING_PARAMS={ + "do_ocr": request.app.state.config.DOCLING_DO_OCR, + "force_ocr": request.app.state.config.DOCLING_FORCE_OCR, + "ocr_engine": request.app.state.config.DOCLING_OCR_ENGINE, + "ocr_lang": request.app.state.config.DOCLING_OCR_LANG, + 
"pdf_backend": request.app.state.config.DOCLING_PDF_BACKEND, + "table_mode": request.app.state.config.DOCLING_TABLE_MODE, + "pipeline": request.app.state.config.DOCLING_PIPELINE, + "do_picture_description": request.app.state.config.DOCLING_DO_PICTURE_DESCRIPTION, + "picture_description_mode": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_MODE, + "picture_description_local": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_LOCAL, + "picture_description_api": request.app.state.config.DOCLING_PICTURE_DESCRIPTION_API, + **request.app.state.config.DOCLING_PARAMS, }, + PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES, + DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT, + DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY, + MISTRAL_OCR_API_KEY=request.app.state.config.MISTRAL_OCR_API_KEY, + MINERU_API_MODE=request.app.state.config.MINERU_API_MODE, + MINERU_API_URL=request.app.state.config.MINERU_API_URL, + MINERU_API_KEY=request.app.state.config.MINERU_API_KEY, + MINERU_PARAMS=request.app.state.config.MINERU_PARAMS, ) - for doc in docs - ] + docs = loader.load( + file.filename, file.meta.get("content_type"), file_path + ) + + docs = [ + Document( + page_content=doc.page_content, + metadata={ + **filter_metadata(doc.metadata), + "name": file.filename, + "created_by": file.user_id, + "file_id": file.id, + "source": file.filename, + }, + ) + for doc in docs + ] + else: + docs = [ + Document( + page_content=file.data.get("content", ""), + metadata={ + **file.meta, + "name": file.filename, + "created_by": file.user_id, + "file_id": file.id, + "source": file.filename, + }, + ) + ] + text_content = " ".join([doc.page_content for doc in docs]) + + log.debug(f"text_content: {text_content}") + Files.update_file_data_by_id( + file.id, + {"content": text_content}, + ) + hash = calculate_sha256_string(text_content) + Files.update_file_hash_by_id(file.id, hash) + + if 
request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL: + Files.update_file_data_by_id(file.id, {"status": "completed"}) + return { + "status": True, + "collection_name": None, + "filename": file.filename, + "content": text_content, + } else: - docs = [ - Document( - page_content=file.data.get("content", ""), + try: + result = save_docs_to_vector_db( + request, + docs=docs, + collection_name=collection_name, metadata={ - **file.meta, - "name": file.filename, - "created_by": file.user_id, "file_id": file.id, - "source": file.filename, + "name": file.filename, + "hash": hash, }, + add=(True if form_data.collection_name else False), + user=user, ) - ] - text_content = " ".join([doc.page_content for doc in docs]) - - log.debug(f"text_content: {text_content}") - Files.update_file_data_by_id( - file.id, - {"content": text_content}, - ) + log.info(f"added {len(docs)} items to collection {collection_name}") + + if result: + Files.update_file_metadata_by_id( + file.id, + { + "collection_name": collection_name, + }, + ) + + Files.update_file_data_by_id( + file.id, + {"status": "completed"}, + ) + + return { + "status": True, + "collection_name": collection_name, + "filename": file.filename, + "content": text_content, + } + else: + raise Exception("Error saving document to vector database") + except Exception as e: + raise e - hash = calculate_sha256_string(text_content) - Files.update_file_hash_by_id(file.id, hash) + except Exception as e: + log.exception(e) + Files.update_file_data_by_id( + file.id, + {"status": "failed"}, + ) - if not request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL: - try: - result = save_docs_to_vector_db( - request, - docs=docs, - collection_name=collection_name, - metadata={ - "file_id": file.id, - "name": file.filename, - "hash": hash, - }, - add=(True if form_data.collection_name else False), - user=user, + if "No pandoc was found" in str(e): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + 
detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED, + ) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), ) - if result: - Files.update_file_metadata_by_id( - file.id, - { - "collection_name": collection_name, - }, - ) - - return { - "status": True, - "collection_name": collection_name, - "filename": file.filename, - "content": text_content, - } - except Exception as e: - raise e - else: - return { - "status": True, - "collection_name": None, - "filename": file.filename, - "content": text_content, - } - - except Exception as e: - log.exception(e) - if "No pandoc was found" in str(e): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED, - ) - else: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail=str(e), - ) + else: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND + ) class ProcessTextForm(BaseModel): @@ -1310,49 +1755,6 @@ def process_text( @router.post("/process/youtube") -def process_youtube_video( - request: Request, form_data: ProcessUrlForm, user=Depends(get_verified_user) -): - try: - collection_name = form_data.collection_name - if not collection_name: - collection_name = calculate_sha256_string(form_data.url)[:63] - - loader = YoutubeLoader( - form_data.url, - language=request.app.state.config.YOUTUBE_LOADER_LANGUAGE, - proxy_url=request.app.state.config.YOUTUBE_LOADER_PROXY_URL, - ) - - docs = loader.load() - content = " ".join([doc.page_content for doc in docs]) - log.debug(f"text_content: {content}") - - save_docs_to_vector_db( - request, docs, collection_name, overwrite=True, user=user - ) - - return { - "status": True, - "collection_name": collection_name, - "filename": form_data.url, - "file": { - "data": { - "content": content, - }, - "meta": { - "name": form_data.url, - }, - }, - } - except Exception as e: - log.exception(e) - raise HTTPException( - 
status_code=status.HTTP_400_BAD_REQUEST, - detail=ERROR_MESSAGES.DEFAULT(e), - ) - - @router.post("/process/web") def process_web( request: Request, form_data: ProcessUrlForm, user=Depends(get_verified_user) @@ -1362,19 +1764,16 @@ def process_web( if not collection_name: collection_name = calculate_sha256_string(form_data.url)[:63] - loader = get_web_loader( - form_data.url, - verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION, - requests_per_second=request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS, - ) - docs = loader.load() - content = " ".join([doc.page_content for doc in docs]) - + content, docs = get_content_from_url(request, form_data.url) log.debug(f"text_content: {content}") if not request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: save_docs_to_vector_db( - request, docs, collection_name, overwrite=True, user=user + request, + docs, + collection_name, + overwrite=True, + user=user, ) else: collection_name = None @@ -1425,7 +1824,25 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]: """ # TODO: add playwright to search the web - if engine == "searxng": + if engine == "ollama_cloud": + return search_ollama_cloud( + "https://ollama.com", + request.app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY, + query, + request.app.state.config.WEB_SEARCH_RESULT_COUNT, + request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, + ) + elif engine == "perplexity_search": + if request.app.state.config.PERPLEXITY_API_KEY: + return search_perplexity_search( + request.app.state.config.PERPLEXITY_API_KEY, + query, + request.app.state.config.WEB_SEARCH_RESULT_COUNT, + request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, + ) + else: + raise Exception("No PERPLEXITY_API_KEY found in environment variables") + elif engine == "searxng": if request.app.state.config.SEARXNG_QUERY_URL: return search_searxng( request.app.state.config.SEARXNG_QUERY_URL, @@ -1530,7 +1947,7 @@ def search_web(request: Request, 
engine: str, query: str) -> list[SearchResult]: request.app.state.config.SERPLY_API_KEY, query, request.app.state.config.WEB_SEARCH_RESULT_COUNT, - request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, + filter_list=request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, ) else: raise Exception("No SERPLY_API_KEY found in environment variables") @@ -1539,6 +1956,7 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]: query, request.app.state.config.WEB_SEARCH_RESULT_COUNT, request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, + concurrent_requests=request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS, ) elif engine == "tavily": if request.app.state.config.TAVILY_API_KEY: @@ -1550,6 +1968,16 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]: ) else: raise Exception("No TAVILY_API_KEY found in environment variables") + elif engine == "exa": + if request.app.state.config.EXA_API_KEY: + return search_exa( + request.app.state.config.EXA_API_KEY, + query, + request.app.state.config.WEB_SEARCH_RESULT_COUNT, + request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, + ) + else: + raise Exception("No EXA_API_KEY found in environment variables") elif engine == "searchapi": if request.app.state.config.SEARCHAPI_API_KEY: return search_searchapi( @@ -1600,6 +2028,8 @@ def search_web(request: Request, engine: str, query: str) -> list[SearchResult]: query, request.app.state.config.WEB_SEARCH_RESULT_COUNT, request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST, + model=request.app.state.config.PERPLEXITY_MODEL, + search_context_usage=request.app.state.config.PERPLEXITY_SEARCH_CONTEXT_USAGE, ) elif engine == "sougou": if ( @@ -1643,8 +2073,10 @@ async def process_web_search( ): urls = [] + result_items = [] + try: - logging.info( + logging.debug( f"trying to web search with {request.app.state.config.WEB_SEARCH_ENGINE, form_data.queries}" ) @@ -1664,6 +2096,7 @@ async def process_web_search( if result: for item in 
result: if item and item.link: + result_items.append(item) urls.append(item.link) urls = list(dict.fromkeys(urls)) @@ -1677,23 +2110,53 @@ async def process_web_search( detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e), ) - try: - loader = get_web_loader( - urls, - verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION, - requests_per_second=request.app.state.config.WEB_SEARCH_CONCURRENT_REQUESTS, - trust_env=request.app.state.config.WEB_SEARCH_TRUST_ENV, + if len(urls) == 0: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.DEFAULT("No results found from web search"), ) - docs = await loader.aload() + + try: + if request.app.state.config.BYPASS_WEB_SEARCH_WEB_LOADER: + search_results = [ + item for result in search_results for item in result if result + ] + + docs = [ + Document( + page_content=result.snippet, + metadata={ + "source": result.link, + "title": result.title, + "snippet": result.snippet, + "link": result.link, + }, + ) + for result in search_results + if hasattr(result, "snippet") and result.snippet is not None + ] + else: + loader = get_web_loader( + urls, + verify_ssl=request.app.state.config.ENABLE_WEB_LOADER_SSL_VERIFICATION, + requests_per_second=request.app.state.config.WEB_LOADER_CONCURRENT_REQUESTS, + trust_env=request.app.state.config.WEB_SEARCH_TRUST_ENV, + ) + docs = await loader.aload() + urls = [ doc.metadata.get("source") for doc in docs if doc.metadata.get("source") ] # only keep the urls returned by the loader + result_items = [ + dict(item) for item in result_items if item.link in urls + ] # only keep the search results that have been loaded if request.app.state.config.BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: return { "status": True, "collection_name": None, "filenames": urls, + "items": result_items, "docs": [ { "content": doc.page_content, @@ -1726,6 +2189,7 @@ async def process_web_search( return { "status": True, "collection_names": [collection_name], + "items": result_items, 
"filenames": urls, "loaded_count": len(docs), } @@ -1753,7 +2217,9 @@ def query_doc_handler( user=Depends(get_verified_user), ): try: - if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH: + if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and ( + form_data.hybrid is None or form_data.hybrid + ): collection_results = {} collection_results[form_data.collection_name] = VECTOR_DB_CLIENT.get( collection_name=form_data.collection_name @@ -1766,7 +2232,15 @@ def query_doc_handler( query, prefix=prefix, user=user ), k=form_data.k if form_data.k else request.app.state.config.TOP_K, - reranking_function=request.app.state.rf, + reranking_function=( + ( + lambda sentences: request.app.state.RERANKING_FUNCTION( + sentences, user=user + ) + ) + if request.app.state.RERANKING_FUNCTION + else None + ), k_reranker=form_data.k_reranker or request.app.state.config.TOP_K_RERANKER, r=( @@ -1774,6 +2248,11 @@ def query_doc_handler( if form_data.r else request.app.state.config.RELEVANCE_THRESHOLD ), + hybrid_bm25_weight=( + form_data.hybrid_bm25_weight + if form_data.hybrid_bm25_weight + else request.app.state.config.HYBRID_BM25_WEIGHT + ), user=user, ) else: @@ -1800,6 +2279,7 @@ class QueryCollectionsForm(BaseModel): k_reranker: Optional[int] = None r: Optional[float] = None hybrid: Optional[bool] = None + hybrid_bm25_weight: Optional[float] = None @router.post("/query/collection") @@ -1809,7 +2289,9 @@ def query_collection_handler( user=Depends(get_verified_user), ): try: - if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH: + if request.app.state.config.ENABLE_RAG_HYBRID_SEARCH and ( + form_data.hybrid is None or form_data.hybrid + ): return query_collection_with_hybrid_search( collection_names=form_data.collection_names, queries=[form_data.query], @@ -1817,7 +2299,15 @@ def query_collection_handler( query, prefix=prefix, user=user ), k=form_data.k if form_data.k else request.app.state.config.TOP_K, - reranking_function=request.app.state.rf, + reranking_function=( + ( + 
lambda sentences: request.app.state.RERANKING_FUNCTION( + sentences, user=user + ) + ) + if request.app.state.RERANKING_FUNCTION + else None + ), k_reranker=form_data.k_reranker or request.app.state.config.TOP_K_RERANKER, r=( @@ -1825,6 +2315,11 @@ def query_collection_handler( if form_data.r else request.app.state.config.RELEVANCE_THRESHOLD ), + hybrid_bm25_weight=( + form_data.hybrid_bm25_weight + if form_data.hybrid_bm25_weight + else request.app.state.config.HYBRID_BM25_WEIGHT + ), ) else: return query_collection( diff --git a/backend/open_webui/routers/scim.py b/backend/open_webui/routers/scim.py new file mode 100644 index 00000000000..de1b979c867 --- /dev/null +++ b/backend/open_webui/routers/scim.py @@ -0,0 +1,926 @@ +""" +Experimental SCIM 2.0 Implementation for Open WebUI +Provides System for Cross-domain Identity Management endpoints for users and groups + +NOTE: This is an experimental implementation and may not fully comply with SCIM 2.0 standards, and is subject to change. +""" + +import logging +import uuid +import time +from typing import Optional, List, Dict, Any +from datetime import datetime, timezone + +from fastapi import APIRouter, Depends, HTTPException, Request, Query, Header, status +from fastapi.responses import JSONResponse +from pydantic import BaseModel, Field, ConfigDict + +from open_webui.models.users import Users, UserModel +from open_webui.models.groups import Groups, GroupModel +from open_webui.utils.auth import ( + get_admin_user, + get_current_user, + decode_token, + get_verified_user, +) +from open_webui.constants import ERROR_MESSAGES +from open_webui.env import SRC_LOG_LEVELS + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["MAIN"]) + +router = APIRouter() + +# SCIM 2.0 Schema URIs +SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User" +SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group" +SCIM_LIST_RESPONSE_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:ListResponse" +SCIM_ERROR_SCHEMA 
= "urn:ietf:params:scim:api:messages:2.0:Error" + +# SCIM Resource Types +SCIM_RESOURCE_TYPE_USER = "User" +SCIM_RESOURCE_TYPE_GROUP = "Group" + + +def scim_error(status_code: int, detail: str, scim_type: Optional[str] = None): + """Create a SCIM-compliant error response""" + error_body = { + "schemas": [SCIM_ERROR_SCHEMA], + "status": str(status_code), + "detail": detail, + } + + if scim_type: + error_body["scimType"] = scim_type + elif status_code == 404: + error_body["scimType"] = "invalidValue" + elif status_code == 409: + error_body["scimType"] = "uniqueness" + elif status_code == 400: + error_body["scimType"] = "invalidSyntax" + + return JSONResponse(status_code=status_code, content=error_body) + + +class SCIMError(BaseModel): + """SCIM Error Response""" + + schemas: List[str] = [SCIM_ERROR_SCHEMA] + status: str + scimType: Optional[str] = None + detail: Optional[str] = None + + +class SCIMMeta(BaseModel): + """SCIM Resource Metadata""" + + resourceType: str + created: str + lastModified: str + location: Optional[str] = None + version: Optional[str] = None + + +class SCIMName(BaseModel): + """SCIM User Name""" + + formatted: Optional[str] = None + familyName: Optional[str] = None + givenName: Optional[str] = None + middleName: Optional[str] = None + honorificPrefix: Optional[str] = None + honorificSuffix: Optional[str] = None + + +class SCIMEmail(BaseModel): + """SCIM Email""" + + value: str + type: Optional[str] = "work" + primary: bool = True + display: Optional[str] = None + + +class SCIMPhoto(BaseModel): + """SCIM Photo""" + + value: str + type: Optional[str] = "photo" + primary: bool = True + display: Optional[str] = None + + +class SCIMGroupMember(BaseModel): + """SCIM Group Member""" + + value: str # User ID + ref: Optional[str] = Field(None, alias="$ref") + type: Optional[str] = "User" + display: Optional[str] = None + + +class SCIMUser(BaseModel): + """SCIM User Resource""" + + model_config = ConfigDict(populate_by_name=True) + + schemas: List[str] = 
[SCIM_USER_SCHEMA] + id: str + externalId: Optional[str] = None + userName: str + name: Optional[SCIMName] = None + displayName: str + emails: List[SCIMEmail] + active: bool = True + photos: Optional[List[SCIMPhoto]] = None + groups: Optional[List[Dict[str, str]]] = None + meta: SCIMMeta + + +class SCIMUserCreateRequest(BaseModel): + """SCIM User Create Request""" + + model_config = ConfigDict(populate_by_name=True) + + schemas: List[str] = [SCIM_USER_SCHEMA] + externalId: Optional[str] = None + userName: str + name: Optional[SCIMName] = None + displayName: str + emails: List[SCIMEmail] + active: bool = True + password: Optional[str] = None + photos: Optional[List[SCIMPhoto]] = None + + +class SCIMUserUpdateRequest(BaseModel): + """SCIM User Update Request""" + + model_config = ConfigDict(populate_by_name=True) + + schemas: List[str] = [SCIM_USER_SCHEMA] + id: Optional[str] = None + externalId: Optional[str] = None + userName: Optional[str] = None + name: Optional[SCIMName] = None + displayName: Optional[str] = None + emails: Optional[List[SCIMEmail]] = None + active: Optional[bool] = None + photos: Optional[List[SCIMPhoto]] = None + + +class SCIMGroup(BaseModel): + """SCIM Group Resource""" + + model_config = ConfigDict(populate_by_name=True) + + schemas: List[str] = [SCIM_GROUP_SCHEMA] + id: str + displayName: str + members: Optional[List[SCIMGroupMember]] = [] + meta: SCIMMeta + + +class SCIMGroupCreateRequest(BaseModel): + """SCIM Group Create Request""" + + model_config = ConfigDict(populate_by_name=True) + + schemas: List[str] = [SCIM_GROUP_SCHEMA] + displayName: str + members: Optional[List[SCIMGroupMember]] = [] + + +class SCIMGroupUpdateRequest(BaseModel): + """SCIM Group Update Request""" + + model_config = ConfigDict(populate_by_name=True) + + schemas: List[str] = [SCIM_GROUP_SCHEMA] + displayName: Optional[str] = None + members: Optional[List[SCIMGroupMember]] = None + + +class SCIMListResponse(BaseModel): + """SCIM List Response""" + + schemas: 
List[str] = [SCIM_LIST_RESPONSE_SCHEMA] + totalResults: int + itemsPerPage: int + startIndex: int + Resources: List[Any] + + +class SCIMPatchOperation(BaseModel): + """SCIM Patch Operation""" + + op: str # "add", "replace", "remove" + path: Optional[str] = None + value: Optional[Any] = None + + +class SCIMPatchRequest(BaseModel): + """SCIM Patch Request""" + + schemas: List[str] = ["urn:ietf:params:scim:api:messages:2.0:PatchOp"] + Operations: List[SCIMPatchOperation] + + +def get_scim_auth( + request: Request, authorization: Optional[str] = Header(None) +) -> bool: + """ + Verify SCIM authentication + Checks for SCIM-specific bearer token configured in the system + """ + if not authorization: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Authorization header required", + headers={"WWW-Authenticate": "Bearer"}, + ) + + try: + parts = authorization.split() + if len(parts) != 2: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid authorization format. 
Expected: Bearer ", + ) + + scheme, token = parts + if scheme.lower() != "bearer": + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid authentication scheme", + ) + + # Check if SCIM is enabled + scim_enabled = getattr(request.app.state, "SCIM_ENABLED", False) + log.info( + f"SCIM auth check - raw SCIM_ENABLED: {scim_enabled}, type: {type(scim_enabled)}" + ) + # Handle both PersistentConfig and direct value + if hasattr(scim_enabled, "value"): + scim_enabled = scim_enabled.value + log.info(f"SCIM enabled status after conversion: {scim_enabled}") + if not scim_enabled: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="SCIM is not enabled", + ) + + # Verify the SCIM token + scim_token = getattr(request.app.state, "SCIM_TOKEN", None) + # Handle both PersistentConfig and direct value + if hasattr(scim_token, "value"): + scim_token = scim_token.value + log.debug(f"SCIM token configured: {bool(scim_token)}") + if not scim_token or token != scim_token: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid SCIM token", + ) + + return True + except HTTPException: + # Re-raise HTTP exceptions as-is + raise + except Exception as e: + log.error(f"SCIM authentication error: {e}") + import traceback + + log.error(f"Traceback: {traceback.format_exc()}") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Authentication failed", + ) + + +def user_to_scim(user: UserModel, request: Request) -> SCIMUser: + """Convert internal User model to SCIM User""" + # Parse display name into name components + name_parts = user.name.split(" ", 1) if user.name else ["", ""] + given_name = name_parts[0] if name_parts else "" + family_name = name_parts[1] if len(name_parts) > 1 else "" + + # Get user's groups + user_groups = Groups.get_groups_by_member_id(user.id) + groups = [ + { + "value": group.id, + "display": group.name, + "$ref": 
f"{request.base_url}api/v1/scim/v2/Groups/{group.id}", + "type": "direct", + } + for group in user_groups + ] + + return SCIMUser( + id=user.id, + userName=user.email, + name=SCIMName( + formatted=user.name, + givenName=given_name, + familyName=family_name, + ), + displayName=user.name, + emails=[SCIMEmail(value=user.email)], + active=user.role != "pending", + photos=( + [SCIMPhoto(value=user.profile_image_url)] + if user.profile_image_url + else None + ), + groups=groups if groups else None, + meta=SCIMMeta( + resourceType=SCIM_RESOURCE_TYPE_USER, + created=datetime.fromtimestamp( + user.created_at, tz=timezone.utc + ).isoformat(), + lastModified=datetime.fromtimestamp( + user.updated_at, tz=timezone.utc + ).isoformat(), + location=f"{request.base_url}api/v1/scim/v2/Users/{user.id}", + ), + ) + + +def group_to_scim(group: GroupModel, request: Request) -> SCIMGroup: + """Convert internal Group model to SCIM Group""" + members = [] + for user_id in group.user_ids: + user = Users.get_user_by_id(user_id) + if user: + members.append( + SCIMGroupMember( + value=user.id, + ref=f"{request.base_url}api/v1/scim/v2/Users/{user.id}", + display=user.name, + ) + ) + + return SCIMGroup( + id=group.id, + displayName=group.name, + members=members, + meta=SCIMMeta( + resourceType=SCIM_RESOURCE_TYPE_GROUP, + created=datetime.fromtimestamp( + group.created_at, tz=timezone.utc + ).isoformat(), + lastModified=datetime.fromtimestamp( + group.updated_at, tz=timezone.utc + ).isoformat(), + location=f"{request.base_url}api/v1/scim/v2/Groups/{group.id}", + ), + ) + + +# SCIM Service Provider Config +@router.get("/ServiceProviderConfig") +async def get_service_provider_config(): + """Get SCIM Service Provider Configuration""" + return { + "schemas": ["urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"], + "patch": {"supported": True}, + "bulk": {"supported": False, "maxOperations": 1000, "maxPayloadSize": 1048576}, + "filter": {"supported": True, "maxResults": 200}, + 
"changePassword": {"supported": False}, + "sort": {"supported": False}, + "etag": {"supported": False}, + "authenticationSchemes": [ + { + "type": "oauthbearertoken", + "name": "OAuth Bearer Token", + "description": "Authentication using OAuth 2.0 Bearer Token", + } + ], + } + + +# SCIM Resource Types +@router.get("/ResourceTypes") +async def get_resource_types(request: Request): + """Get SCIM Resource Types""" + return [ + { + "schemas": ["urn:ietf:params:scim:schemas:core:2.0:ResourceType"], + "id": "User", + "name": "User", + "endpoint": "/Users", + "schema": SCIM_USER_SCHEMA, + "meta": { + "location": f"{request.base_url}api/v1/scim/v2/ResourceTypes/User", + "resourceType": "ResourceType", + }, + }, + { + "schemas": ["urn:ietf:params:scim:schemas:core:2.0:ResourceType"], + "id": "Group", + "name": "Group", + "endpoint": "/Groups", + "schema": SCIM_GROUP_SCHEMA, + "meta": { + "location": f"{request.base_url}api/v1/scim/v2/ResourceTypes/Group", + "resourceType": "ResourceType", + }, + }, + ] + + +# SCIM Schemas +@router.get("/Schemas") +async def get_schemas(): + """Get SCIM Schemas""" + return [ + { + "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Schema"], + "id": SCIM_USER_SCHEMA, + "name": "User", + "description": "User Account", + "attributes": [ + { + "name": "userName", + "type": "string", + "required": True, + "uniqueness": "server", + }, + {"name": "displayName", "type": "string", "required": True}, + { + "name": "emails", + "type": "complex", + "multiValued": True, + "required": True, + }, + {"name": "active", "type": "boolean", "required": False}, + ], + }, + { + "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Schema"], + "id": SCIM_GROUP_SCHEMA, + "name": "Group", + "description": "Group", + "attributes": [ + {"name": "displayName", "type": "string", "required": True}, + { + "name": "members", + "type": "complex", + "multiValued": True, + "required": False, + }, + ], + }, + ] + + +# Users endpoints +@router.get("/Users", 
response_model=SCIMListResponse) +async def get_users( + request: Request, + startIndex: int = Query(1, ge=1), + count: int = Query(20, ge=1, le=100), + filter: Optional[str] = None, + _: bool = Depends(get_scim_auth), +): + """List SCIM Users""" + skip = startIndex - 1 + limit = count + + # Get users from database + if filter: + # Simple filter parsing - supports userName eq "email" + # In production, you'd want a more robust filter parser + if "userName eq" in filter: + email = filter.split('"')[1] + user = Users.get_user_by_email(email) + users_list = [user] if user else [] + total = 1 if user else 0 + else: + response = Users.get_users(skip=skip, limit=limit) + users_list = response["users"] + total = response["total"] + else: + response = Users.get_users(skip=skip, limit=limit) + users_list = response["users"] + total = response["total"] + + # Convert to SCIM format + scim_users = [user_to_scim(user, request) for user in users_list] + + return SCIMListResponse( + totalResults=total, + itemsPerPage=len(scim_users), + startIndex=startIndex, + Resources=scim_users, + ) + + +@router.get("/Users/{user_id}", response_model=SCIMUser) +async def get_user( + user_id: str, + request: Request, + _: bool = Depends(get_scim_auth), +): + """Get SCIM User by ID""" + user = Users.get_user_by_id(user_id) + if not user: + return scim_error( + status_code=status.HTTP_404_NOT_FOUND, detail=f"User {user_id} not found" + ) + + return user_to_scim(user, request) + + +@router.post("/Users", response_model=SCIMUser, status_code=status.HTTP_201_CREATED) +async def create_user( + request: Request, + user_data: SCIMUserCreateRequest, + _: bool = Depends(get_scim_auth), +): + """Create SCIM User""" + # Check if user already exists + existing_user = Users.get_user_by_email(user_data.userName) + if existing_user: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=f"User with email {user_data.userName} already exists", + ) + + # Create user + user_id = str(uuid.uuid4()) 
+ email = user_data.emails[0].value if user_data.emails else user_data.userName + + # Parse name if provided + name = user_data.displayName + if user_data.name: + if user_data.name.formatted: + name = user_data.name.formatted + elif user_data.name.givenName or user_data.name.familyName: + name = f"{user_data.name.givenName or ''} {user_data.name.familyName or ''}".strip() + + # Get profile image if provided + profile_image = "/user.png" + if user_data.photos and len(user_data.photos) > 0: + profile_image = user_data.photos[0].value + + # Create user + new_user = Users.insert_new_user( + id=user_id, + name=name, + email=email, + profile_image_url=profile_image, + role="user" if user_data.active else "pending", + ) + + if not new_user: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to create user", + ) + + return user_to_scim(new_user, request) + + +@router.put("/Users/{user_id}", response_model=SCIMUser) +async def update_user( + user_id: str, + request: Request, + user_data: SCIMUserUpdateRequest, + _: bool = Depends(get_scim_auth), +): + """Update SCIM User (full update)""" + user = Users.get_user_by_id(user_id) + if not user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"User {user_id} not found", + ) + + # Build update dict + update_data = {} + + if user_data.userName: + update_data["email"] = user_data.userName + + if user_data.displayName: + update_data["name"] = user_data.displayName + elif user_data.name: + if user_data.name.formatted: + update_data["name"] = user_data.name.formatted + elif user_data.name.givenName or user_data.name.familyName: + update_data["name"] = ( + f"{user_data.name.givenName or ''} {user_data.name.familyName or ''}".strip() + ) + + if user_data.emails and len(user_data.emails) > 0: + update_data["email"] = user_data.emails[0].value + + if user_data.active is not None: + update_data["role"] = "user" if user_data.active else "pending" + + if user_data.photos and 
len(user_data.photos) > 0: + update_data["profile_image_url"] = user_data.photos[0].value + + # Update user + updated_user = Users.update_user_by_id(user_id, update_data) + if not updated_user: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to update user", + ) + + return user_to_scim(updated_user, request) + + +@router.patch("/Users/{user_id}", response_model=SCIMUser) +async def patch_user( + user_id: str, + request: Request, + patch_data: SCIMPatchRequest, + _: bool = Depends(get_scim_auth), +): + """Update SCIM User (partial update)""" + user = Users.get_user_by_id(user_id) + if not user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"User {user_id} not found", + ) + + update_data = {} + + for operation in patch_data.Operations: + op = operation.op.lower() + path = operation.path + value = operation.value + + if op == "replace": + if path == "active": + update_data["role"] = "user" if value else "pending" + elif path == "userName": + update_data["email"] = value + elif path == "displayName": + update_data["name"] = value + elif path == "emails[primary eq true].value": + update_data["email"] = value + elif path == "name.formatted": + update_data["name"] = value + + # Update user + if update_data: + updated_user = Users.update_user_by_id(user_id, update_data) + if not updated_user: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to update user", + ) + else: + updated_user = user + + return user_to_scim(updated_user, request) + + +@router.delete("/Users/{user_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_user( + user_id: str, + request: Request, + _: bool = Depends(get_scim_auth), +): + """Delete SCIM User""" + user = Users.get_user_by_id(user_id) + if not user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"User {user_id} not found", + ) + + success = Users.delete_user_by_id(user_id) + if not success: 
+ raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to delete user", + ) + + return None + + +# Groups endpoints +@router.get("/Groups", response_model=SCIMListResponse) +async def get_groups( + request: Request, + startIndex: int = Query(1, ge=1), + count: int = Query(20, ge=1, le=100), + filter: Optional[str] = None, + _: bool = Depends(get_scim_auth), +): + """List SCIM Groups""" + # Get all groups + groups_list = Groups.get_groups() + + # Apply pagination + total = len(groups_list) + start = startIndex - 1 + end = start + count + paginated_groups = groups_list[start:end] + + # Convert to SCIM format + scim_groups = [group_to_scim(group, request) for group in paginated_groups] + + return SCIMListResponse( + totalResults=total, + itemsPerPage=len(scim_groups), + startIndex=startIndex, + Resources=scim_groups, + ) + + +@router.get("/Groups/{group_id}", response_model=SCIMGroup) +async def get_group( + group_id: str, + request: Request, + _: bool = Depends(get_scim_auth), +): + """Get SCIM Group by ID""" + group = Groups.get_group_by_id(group_id) + if not group: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Group {group_id} not found", + ) + + return group_to_scim(group, request) + + +@router.post("/Groups", response_model=SCIMGroup, status_code=status.HTTP_201_CREATED) +async def create_group( + request: Request, + group_data: SCIMGroupCreateRequest, + _: bool = Depends(get_scim_auth), +): + """Create SCIM Group""" + # Extract member IDs + member_ids = [] + if group_data.members: + for member in group_data.members: + member_ids.append(member.value) + + # Create group + from open_webui.models.groups import GroupForm + + form = GroupForm( + name=group_data.displayName, + description="", + ) + + # Need to get the creating user's ID - we'll use the first admin + admin_user = Users.get_super_admin_user() + if not admin_user: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 
+ detail="No admin user found", + ) + + new_group = Groups.insert_new_group(admin_user.id, form) + if not new_group: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to create group", + ) + + # Add members if provided + if member_ids: + from open_webui.models.groups import GroupUpdateForm + + update_form = GroupUpdateForm( + name=new_group.name, + description=new_group.description, + user_ids=member_ids, + ) + Groups.update_group_by_id(new_group.id, update_form) + new_group = Groups.get_group_by_id(new_group.id) + + return group_to_scim(new_group, request) + + +@router.put("/Groups/{group_id}", response_model=SCIMGroup) +async def update_group( + group_id: str, + request: Request, + group_data: SCIMGroupUpdateRequest, + _: bool = Depends(get_scim_auth), +): + """Update SCIM Group (full update)""" + group = Groups.get_group_by_id(group_id) + if not group: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Group {group_id} not found", + ) + + # Build update form + from open_webui.models.groups import GroupUpdateForm + + update_form = GroupUpdateForm( + name=group_data.displayName if group_data.displayName else group.name, + description=group.description, + ) + + # Handle members if provided + if group_data.members is not None: + member_ids = [member.value for member in group_data.members] + update_form.user_ids = member_ids + + # Update group + updated_group = Groups.update_group_by_id(group_id, update_form) + if not updated_group: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to update group", + ) + + return group_to_scim(updated_group, request) + + +@router.patch("/Groups/{group_id}", response_model=SCIMGroup) +async def patch_group( + group_id: str, + request: Request, + patch_data: SCIMPatchRequest, + _: bool = Depends(get_scim_auth), +): + """Update SCIM Group (partial update)""" + group = Groups.get_group_by_id(group_id) + if not group: + raise 
HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Group {group_id} not found", + ) + + from open_webui.models.groups import GroupUpdateForm + + update_form = GroupUpdateForm( + name=group.name, + description=group.description, + user_ids=group.user_ids.copy() if group.user_ids else [], + ) + + for operation in patch_data.Operations: + op = operation.op.lower() + path = operation.path + value = operation.value + + if op == "replace": + if path == "displayName": + update_form.name = value + elif path == "members": + # Replace all members + update_form.user_ids = [member["value"] for member in value] + elif op == "add": + if path == "members": + # Add members + if isinstance(value, list): + for member in value: + if isinstance(member, dict) and "value" in member: + if member["value"] not in update_form.user_ids: + update_form.user_ids.append(member["value"]) + elif op == "remove": + if path and path.startswith("members[value eq"): + # Remove specific member + member_id = path.split('"')[1] + if member_id in update_form.user_ids: + update_form.user_ids.remove(member_id) + + # Update group + updated_group = Groups.update_group_by_id(group_id, update_form) + if not updated_group: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to update group", + ) + + return group_to_scim(updated_group, request) + + +@router.delete("/Groups/{group_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_group( + group_id: str, + request: Request, + _: bool = Depends(get_scim_auth), +): + """Delete SCIM Group""" + group = Groups.get_group_by_id(group_id) + if not group: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Group {group_id} not found", + ) + + success = Groups.delete_group_by_id(group_id) + if not success: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to delete group", + ) + + return None diff --git a/backend/open_webui/routers/tasks.py 
b/backend/open_webui/routers/tasks.py index f94346099ef..7585466f69c 100644 --- a/backend/open_webui/routers/tasks.py +++ b/backend/open_webui/routers/tasks.py @@ -9,6 +9,7 @@ from open_webui.utils.chat import generate_chat_completion from open_webui.utils.task import ( title_generation_template, + follow_up_generation_template, query_generation_template, image_prompt_generation_template, autocomplete_generation_template, @@ -25,6 +26,7 @@ from open_webui.config import ( DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE, + DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE, DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE, DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE, @@ -58,6 +60,8 @@ async def get_task_config(request: Request, user=Depends(get_verified_user)): "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION, "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, + "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE, + "ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION, "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION, "ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION, "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION, @@ -76,6 +80,8 @@ class TaskConfigForm(BaseModel): ENABLE_AUTOCOMPLETE_GENERATION: bool AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int TAGS_GENERATION_PROMPT_TEMPLATE: str + FOLLOW_UP_GENERATION_PROMPT_TEMPLATE: str + ENABLE_FOLLOW_UP_GENERATION: bool ENABLE_TAGS_GENERATION: bool ENABLE_SEARCH_QUERY_GENERATION: bool ENABLE_RETRIEVAL_QUERY_GENERATION: bool @@ -94,6 +100,13 @@ async def update_task_config( form_data.TITLE_GENERATION_PROMPT_TEMPLATE ) + 
request.app.state.config.ENABLE_FOLLOW_UP_GENERATION = ( + form_data.ENABLE_FOLLOW_UP_GENERATION + ) + request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = ( + form_data.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE + ) + request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = ( form_data.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE ) @@ -133,6 +146,8 @@ async def update_task_config( "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE, "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION, + "ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION, + "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE, "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION, "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION, "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE, @@ -183,14 +198,7 @@ async def generate_title( else: template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE - content = title_generation_template( - template, - form_data["messages"], - { - "name": user.name, - "location": user.info.get("location") if user.info else None, - }, - ) + content = title_generation_template(template, form_data["messages"], user) max_tokens = ( models[task_model_id].get("info", {}).get("params", {}).get("max_tokens", 1000) @@ -231,6 +239,79 @@ async def generate_title( ) +@router.post("/follow_up/completions") +async def generate_follow_ups( + request: Request, form_data: dict, user=Depends(get_verified_user) +): + + if not request.app.state.config.ENABLE_FOLLOW_UP_GENERATION: + return JSONResponse( + status_code=status.HTTP_200_OK, + content={"detail": "Follow-up generation is disabled"}, + ) + + if 
getattr(request.state, "direct", False) and hasattr(request.state, "model"): + models = { + request.state.model["id"]: request.state.model, + } + else: + models = request.app.state.MODELS + + model_id = form_data["model"] + if model_id not in models: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Model not found", + ) + + # Check if the user has a custom task model + # If the user has a custom task model, use that model + task_model_id = get_task_model_id( + model_id, + request.app.state.config.TASK_MODEL, + request.app.state.config.TASK_MODEL_EXTERNAL, + models, + ) + + log.debug( + f"generating chat title using model {task_model_id} for user {user.email} " + ) + + if request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE != "": + template = request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE + else: + template = DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE + + content = follow_up_generation_template(template, form_data["messages"], user) + + payload = { + "model": task_model_id, + "messages": [{"role": "user", "content": content}], + "stream": False, + "metadata": { + **(request.state.metadata if hasattr(request.state, "metadata") else {}), + "task": str(TASKS.FOLLOW_UP_GENERATION), + "task_body": form_data, + "chat_id": form_data.get("chat_id", None), + }, + } + + # Process the payload through the pipeline + try: + payload = await process_pipeline_inlet_filter(request, payload, user, models) + except Exception as e: + raise e + + try: + return await generate_chat_completion(request, form_data=payload, user=user) + except Exception as e: + log.error("Exception occurred", exc_info=True) + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"detail": "An internal error has occurred."}, + ) + + @router.post("/tags/completions") async def generate_chat_tags( request: Request, form_data: dict, user=Depends(get_verified_user) @@ -274,9 +355,7 @@ async def generate_chat_tags( else: template = 
DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE - content = tags_generation_template( - template, form_data["messages"], {"name": user.name} - ) + content = tags_generation_template(template, form_data["messages"], user) payload = { "model": task_model_id, @@ -342,13 +421,7 @@ async def generate_image_prompt( else: template = DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE - content = image_prompt_generation_template( - template, - form_data["messages"], - user={ - "name": user.name, - }, - ) + content = image_prompt_generation_template(template, form_data["messages"], user) payload = { "model": task_model_id, @@ -397,6 +470,10 @@ async def generate_queries( detail=f"Query generation is disabled", ) + if getattr(request.state, "cached_queries", None): + log.info(f"Reusing cached queries: {request.state.cached_queries}") + return request.state.cached_queries + if getattr(request.state, "direct", False) and hasattr(request.state, "model"): models = { request.state.model["id"]: request.state.model, @@ -429,9 +506,7 @@ async def generate_queries( else: template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE - content = query_generation_template( - template, form_data["messages"], {"name": user.name} - ) + content = query_generation_template(template, form_data["messages"], user) payload = { "model": task_model_id, @@ -516,9 +591,7 @@ async def generate_autocompletion( else: template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE - content = autocomplete_generation_template( - template, prompt, messages, type, {"name": user.name} - ) + content = autocomplete_generation_template(template, prompt, messages, type, user) payload = { "model": task_model_id, @@ -580,14 +653,7 @@ async def generate_emoji( template = DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE - content = emoji_generation_template( - template, - form_data["prompt"], - { - "name": user.name, - "location": user.info.get("location") if user.info else None, - }, - ) + content = emoji_generation_template(template, 
form_data["prompt"], user) payload = { "model": task_model_id, @@ -600,11 +666,11 @@ async def generate_emoji( "max_completion_tokens": 4, } ), - "chat_id": form_data.get("chat_id", None), "metadata": { **(request.state.metadata if hasattr(request.state, "metadata") else {}), "task": str(TASKS.EMOJI_GENERATION), "task_body": form_data, + "chat_id": form_data.get("chat_id", None), }, } diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py index 318f613983b..2fa3f6abf61 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -2,7 +2,14 @@ from pathlib import Path from typing import Optional import time +import re +import aiohttp +from open_webui.models.groups import Groups +from pydantic import BaseModel, HttpUrl +from fastapi import APIRouter, Depends, HTTPException, Request, status + +from open_webui.models.oauth_sessions import OAuthSessions from open_webui.models.tools import ( ToolForm, ToolModel, @@ -10,16 +17,20 @@ ToolUserResponse, Tools, ) -from open_webui.utils.plugin import load_tool_module_by_id, replace_imports -from open_webui.config import CACHE_DIR -from open_webui.constants import ERROR_MESSAGES -from fastapi import APIRouter, Depends, HTTPException, Request, status +from open_webui.utils.plugin import ( + load_tool_module_by_id, + replace_imports, + get_tool_module_from_cache, +) from open_webui.utils.tools import get_tool_specs from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.access_control import has_access, has_permission +from open_webui.utils.tools import get_tool_servers + from open_webui.env import SRC_LOG_LEVELS +from open_webui.config import CACHE_DIR, BYPASS_ADMIN_ACCESS_CONTROL +from open_webui.constants import ERROR_MESSAGES -from open_webui.utils.tools import get_tool_servers_data log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["MAIN"]) @@ -27,6 +38,15 @@ router = APIRouter() + +def get_tool_module(request, 
tool_id, load_from_db=True): + """ + Get the tool module by its ID. + """ + tool_module, _ = get_tool_module_from_cache(request, tool_id, load_from_db) + return tool_module + + ############################ # GetTools ############################ @@ -34,33 +54,37 @@ @router.get("/", response_model=list[ToolUserResponse]) async def get_tools(request: Request, user=Depends(get_verified_user)): + tools = [] - if not request.app.state.TOOL_SERVERS: - # If the tool servers are not set, we need to set them - # This is done only once when the server starts - # This is done to avoid loading the tool servers every time - - request.app.state.TOOL_SERVERS = await get_tool_servers_data( - request.app.state.config.TOOL_SERVER_CONNECTIONS + # Local Tools + for tool in Tools.get_tools(): + tool_module = get_tool_module(request, tool.id) + tools.append( + ToolUserResponse( + **{ + **tool.model_dump(), + "has_user_valves": hasattr(tool_module, "UserValves"), + } + ) ) - tools = Tools.get_tools() - for server in request.app.state.TOOL_SERVERS: + # OpenAPI Tool Servers + for server in await get_tool_servers(request): tools.append( ToolUserResponse( **{ - "id": f"server:{server['idx']}", - "user_id": f"server:{server['idx']}", - "name": server["openapi"] + "id": f"server:{server.get('id')}", + "user_id": f"server:{server.get('id')}", + "name": server.get("openapi", {}) .get("info", {}) .get("title", "Tool Server"), "meta": { - "description": server["openapi"] + "description": server.get("openapi", {}) .get("info", {}) .get("description", ""), }, "access_control": request.app.state.config.TOOL_SERVER_CONNECTIONS[ - server["idx"] + server.get("idx", 0) ] .get("config", {}) .get("access_control", None), @@ -70,15 +94,62 @@ async def get_tools(request: Request, user=Depends(get_verified_user)): ) ) - if user.role != "admin": + # MCP Tool Servers + for server in request.app.state.config.TOOL_SERVER_CONNECTIONS: + if server.get("type", "openapi") == "mcp": + server_id = server.get("info", 
{}).get("id") + auth_type = server.get("auth_type", "none") + + session_token = None + if auth_type == "oauth_2.1": + splits = server_id.split(":") + server_id = splits[-1] if len(splits) > 1 else server_id + + session_token = ( + await request.app.state.oauth_client_manager.get_oauth_token( + user.id, f"mcp:{server_id}" + ) + ) + + tools.append( + ToolUserResponse( + **{ + "id": f"server:mcp:{server.get('info', {}).get('id')}", + "user_id": f"server:mcp:{server.get('info', {}).get('id')}", + "name": server.get("info", {}).get("name", "MCP Tool Server"), + "meta": { + "description": server.get("info", {}).get( + "description", "" + ), + }, + "access_control": server.get("config", {}).get( + "access_control", None + ), + "updated_at": int(time.time()), + "created_at": int(time.time()), + **( + { + "authenticated": session_token is not None, + } + if auth_type == "oauth_2.1" + else {} + ), + } + ) + ) + + if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: + # Admin can see all tools + return tools + else: + user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)} tools = [ tool for tool in tools if tool.user_id == user.id - or has_access(user.id, "read", tool.access_control) + or has_access(user.id, "read", tool.access_control, user_group_ids) ] - - return tools + return tools ############################ @@ -88,13 +159,88 @@ async def get_tools(request: Request, user=Depends(get_verified_user)): @router.get("/list", response_model=list[ToolUserResponse]) async def get_tool_list(user=Depends(get_verified_user)): - if user.role == "admin": + if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL: tools = Tools.get_tools() else: tools = Tools.get_tools_by_user_id(user.id, "write") return tools +############################ +# LoadFunctionFromLink +############################ + + +class LoadUrlForm(BaseModel): + url: HttpUrl + + +def github_url_to_raw_url(url: str) -> str: + # Handle 'tree' (folder) URLs (add main.py at the end) + m1 = 
re.match(r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)/(.*)", url) + if m1: + org, repo, branch, path = m1.groups() + return f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path.rstrip('/')}/main.py" + + # Handle 'blob' (file) URLs + m2 = re.match(r"https://github\.com/([^/]+)/([^/]+)/blob/([^/]+)/(.*)", url) + if m2: + org, repo, branch, path = m2.groups() + return ( + f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}" + ) + + # No match; return as-is + return url + + +@router.post("/load/url", response_model=Optional[dict]) +async def load_tool_from_url( + request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user) +): + # NOTE: This is NOT a SSRF vulnerability: + # This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use, + # and does NOT accept untrusted user input. Access is enforced by authentication. + + url = str(form_data.url) + if not url: + raise HTTPException(status_code=400, detail="Please enter a valid URL") + + url = github_url_to_raw_url(url) + url_parts = url.rstrip("/").split("/") + + file_name = url_parts[-1] + tool_name = ( + file_name[:-3] + if ( + file_name.endswith(".py") + and (not file_name.startswith(("main.py", "index.py", "__init__.py"))) + ) + else url_parts[-2] if len(url_parts) > 1 else "function" + ) + + try: + async with aiohttp.ClientSession(trust_env=True) as session: + async with session.get( + url, headers={"Content-Type": "application/json"} + ) as resp: + if resp.status != 200: + raise HTTPException( + status_code=resp.status, detail="Failed to fetch the tool" + ) + data = await resp.text() + if not data: + raise HTTPException( + status_code=400, detail="No data received from the URL" + ) + return { + "name": tool_name, + "content": data, + } + except Exception as e: + raise HTTPException(status_code=500, detail=f"Error importing tool: {e}") + + ############################ # ExportTools ############################ @@ -386,8 +532,9 @@ 
async def update_tools_valves_by_id( try: form_data = {k: v for k, v in form_data.items() if v is not None} valves = Valves(**form_data) - Tools.update_tool_valves_by_id(id, valves.model_dump()) - return valves.model_dump() + valves_dict = valves.model_dump(exclude_unset=True) + Tools.update_tool_valves_by_id(id, valves_dict) + return valves_dict except Exception as e: log.exception(f"Failed to update tool valves by id {id}: {e}") raise HTTPException( @@ -462,10 +609,11 @@ async def update_tools_user_valves_by_id( try: form_data = {k: v for k, v in form_data.items() if v is not None} user_valves = UserValves(**form_data) + user_valves_dict = user_valves.model_dump(exclude_unset=True) Tools.update_user_valves_by_id_and_user_id( - id, user.id, user_valves.model_dump() + id, user.id, user_valves_dict ) - return user_valves.model_dump() + return user_valves_dict except Exception as e: log.exception(f"Failed to update user valves by id {id}: {e}") raise HTTPException( diff --git a/backend/open_webui/routers/users.py b/backend/open_webui/routers/users.py index 8702ae50bae..2dd229eeb77 100644 --- a/backend/open_webui/routers/users.py +++ b/backend/open_webui/routers/users.py @@ -1,12 +1,24 @@ import logging from typing import Optional +import base64 +import io + + +from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi.responses import Response, StreamingResponse, FileResponse +from pydantic import BaseModel + from open_webui.models.auths import Auths +from open_webui.models.oauth_sessions import OAuthSessions + from open_webui.models.groups import Groups from open_webui.models.chats import Chats from open_webui.models.users import ( UserModel, UserListResponse, + UserInfoListResponse, + UserIdNameListResponse, UserRoleUpdateForm, Users, UserSettings, @@ -14,11 +26,14 @@ ) -from open_webui.socket.main import get_active_status_by_user_id +from open_webui.socket.main import ( + get_active_status_by_user_id, + get_active_user_ids, + 
get_user_active_status, +) from open_webui.constants import ERROR_MESSAGES -from open_webui.env import SRC_LOG_LEVELS -from fastapi import APIRouter, Depends, HTTPException, Request, status -from pydantic import BaseModel +from open_webui.env import SRC_LOG_LEVELS, STATIC_DIR + from open_webui.utils.auth import get_admin_user, get_password_hash, get_verified_user from open_webui.utils.access_control import get_permissions, has_permission @@ -29,6 +44,24 @@ router = APIRouter() + +############################ +# GetActiveUsers +############################ + + +@router.get("/active") +async def get_active_users( + user=Depends(get_verified_user), +): + """ + Get a list of active users. + """ + return { + "user_ids": get_active_user_ids(), + } + + ############################ # GetUsers ############################ @@ -61,13 +94,30 @@ async def get_users( return Users.get_users(filter=filter, skip=skip, limit=limit) -@router.get("/all", response_model=UserListResponse) +@router.get("/all", response_model=UserInfoListResponse) async def get_all_users( user=Depends(get_admin_user), ): return Users.get_users() +@router.get("/search", response_model=UserIdNameListResponse) +async def search_users( + query: Optional[str] = None, + user=Depends(get_verified_user), +): + limit = PAGE_ITEM_COUNT + + page = 1 # Always return the first page for search + skip = (page - 1) * limit + + filter = {} + if query: + filter["query"] = query + + return Users.get_users(filter=filter, skip=skip, limit=limit) + + ############################ # User Groups ############################ @@ -107,12 +157,20 @@ class SharingPermissions(BaseModel): public_knowledge: bool = True public_prompts: bool = True public_tools: bool = True + public_notes: bool = True class ChatPermissions(BaseModel): controls: bool = True + valves: bool = True + system_prompt: bool = True + params: bool = True file_upload: bool = True delete: bool = True + delete_message: bool = True + continue_response: bool = True + 
regenerate_response: bool = True + rate_response: bool = True edit: bool = True share: bool = True export: bool = True @@ -165,22 +223,6 @@ async def update_default_user_permissions( return request.app.state.config.USER_PERMISSIONS -############################ -# UpdateUserRole -############################ - - -@router.post("/update/role", response_model=Optional[UserModel]) -async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)): - if user.id != form_data.id and form_data.id != Users.get_first_user().id: - return Users.update_user_role_by_id(form_data.id, form_data.role) - - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=ERROR_MESSAGES.ACTION_PROHIBITED, - ) - - ############################ # GetUserSettingsBySessionUser ############################ @@ -319,6 +361,67 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)): ) +@router.get("/{user_id}/oauth/sessions", response_model=Optional[dict]) +async def get_user_oauth_sessions_by_id(user_id: str, user=Depends(get_admin_user)): + sessions = OAuthSessions.get_sessions_by_user_id(user_id) + if sessions and len(sessions) > 0: + return sessions + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.USER_NOT_FOUND, + ) + + +############################ +# GetUserProfileImageById +############################ + + +@router.get("/{user_id}/profile/image") +async def get_user_profile_image_by_id(user_id: str, user=Depends(get_verified_user)): + user = Users.get_user_by_id(user_id) + if user: + if user.profile_image_url: + # check if it's url or base64 + if user.profile_image_url.startswith("http"): + return Response( + status_code=status.HTTP_302_FOUND, + headers={"Location": user.profile_image_url}, + ) + elif user.profile_image_url.startswith("data:image"): + try: + header, base64_data = user.profile_image_url.split(",", 1) + image_data = base64.b64decode(base64_data) + image_buffer = 
io.BytesIO(image_data) + + return StreamingResponse( + image_buffer, + media_type="image/png", + headers={"Content-Disposition": "inline; filename=image.png"}, + ) + except Exception as e: + pass + return FileResponse(f"{STATIC_DIR}/user.png") + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=ERROR_MESSAGES.USER_NOT_FOUND, + ) + + +############################ +# GetUserActiveStatusById +############################ + + +@router.get("/{user_id}/active", response_model=dict) +async def get_user_active_status_by_id(user_id: str, user=Depends(get_verified_user)): + return { + "active": get_user_active_status(user_id), + } + + ############################ # UpdateUserById ############################ @@ -333,11 +436,22 @@ async def update_user_by_id( # Prevent modification of the primary admin user by other admins try: first_user = Users.get_first_user() - if first_user and user_id == first_user.id and session_user.id != user_id: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail=ERROR_MESSAGES.ACTION_PROHIBITED, - ) + if first_user: + if user_id == first_user.id: + if session_user.id != user_id: + # If the user trying to update is the primary admin, and they are not the primary admin themselves + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACTION_PROHIBITED, + ) + + if form_data.role != "admin": + # If the primary admin is trying to change their own role, prevent it + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.ACTION_PROHIBITED, + ) + except Exception as e: log.error(f"Error checking primary admin status: {e}") raise HTTPException( @@ -365,6 +479,7 @@ async def update_user_by_id( updated_user = Users.update_user_by_id( user_id, { + "role": form_data.role, "name": form_data.name, "email": form_data.email.lower(), "profile_image_url": form_data.profile_image_url, @@ -423,3 +538,13 @@ async def delete_user_by_id(user_id: str, 
user=Depends(get_admin_user)): status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACTION_PROHIBITED, ) + + +############################ +# GetUserGroupsById +############################ + + +@router.get("/{user_id}/groups") +async def get_user_groups_by_id(user_id: str, user=Depends(get_admin_user)): + return Groups.get_groups_by_member_id(user_id) diff --git a/backend/open_webui/routers/utils.py b/backend/open_webui/routers/utils.py index b64adafb442..0e6768a6716 100644 --- a/backend/open_webui/routers/utils.py +++ b/backend/open_webui/routers/utils.py @@ -33,7 +33,7 @@ class CodeForm(BaseModel): @router.post("/code/format") -async def format_code(form_data: CodeForm, user=Depends(get_verified_user)): +async def format_code(form_data: CodeForm, user=Depends(get_admin_user)): try: formatted_code = black.format_str(form_data.code, mode=black.Mode()) return {"code": formatted_code} diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index 09eccd82675..47b2c579616 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -1,13 +1,18 @@ import asyncio +import random + import socketio import logging import sys import time +from typing import Dict, Set from redis import asyncio as aioredis +import pycrdt as Y from open_webui.models.users import Users, UserNameResponse from open_webui.models.channels import Channels from open_webui.models.chats import Chats +from open_webui.models.notes import Notes, NoteUpdateForm from open_webui.utils.redis import ( get_sentinels_from_env, get_sentinel_url_from_env, @@ -17,12 +22,18 @@ ENABLE_WEBSOCKET_SUPPORT, WEBSOCKET_MANAGER, WEBSOCKET_REDIS_URL, + WEBSOCKET_REDIS_CLUSTER, WEBSOCKET_REDIS_LOCK_TIMEOUT, WEBSOCKET_SENTINEL_PORT, WEBSOCKET_SENTINEL_HOSTS, + REDIS_KEY_PREFIX, ) from open_webui.utils.auth import decode_token -from open_webui.socket.utils import RedisDict, RedisLock +from open_webui.socket.utils import RedisDict, RedisLock, YdocManager +from 
open_webui.tasks import create_task, stop_item_tasks +from open_webui.utils.redis import get_redis_connection +from open_webui.utils.access_control import has_access, get_users_with_access + from open_webui.env import ( GLOBAL_LOG_LEVEL, @@ -35,6 +46,8 @@ log.setLevel(SRC_LOG_LEVELS["SOCKET"]) +REDIS = None + if WEBSOCKET_MANAGER == "redis": if WEBSOCKET_SENTINEL_HOSTS: mgr = socketio.AsyncRedisManager( @@ -69,30 +82,43 @@ if WEBSOCKET_MANAGER == "redis": log.debug("Using Redis to manage websockets.") + REDIS = get_redis_connection( + redis_url=WEBSOCKET_REDIS_URL, + redis_sentinels=get_sentinels_from_env( + WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT + ), + redis_cluster=WEBSOCKET_REDIS_CLUSTER, + async_mode=True, + ) + redis_sentinels = get_sentinels_from_env( WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT ) SESSION_POOL = RedisDict( - "open-webui:session_pool", + f"{REDIS_KEY_PREFIX}:session_pool", redis_url=WEBSOCKET_REDIS_URL, redis_sentinels=redis_sentinels, + redis_cluster=WEBSOCKET_REDIS_CLUSTER, ) USER_POOL = RedisDict( - "open-webui:user_pool", + f"{REDIS_KEY_PREFIX}:user_pool", redis_url=WEBSOCKET_REDIS_URL, redis_sentinels=redis_sentinels, + redis_cluster=WEBSOCKET_REDIS_CLUSTER, ) USAGE_POOL = RedisDict( - "open-webui:usage_pool", + f"{REDIS_KEY_PREFIX}:usage_pool", redis_url=WEBSOCKET_REDIS_URL, redis_sentinels=redis_sentinels, + redis_cluster=WEBSOCKET_REDIS_CLUSTER, ) clean_up_lock = RedisLock( redis_url=WEBSOCKET_REDIS_URL, - lock_name="usage_cleanup_lock", + lock_name=f"{REDIS_KEY_PREFIX}:usage_cleanup_lock", timeout_secs=WEBSOCKET_REDIS_LOCK_TIMEOUT, redis_sentinels=redis_sentinels, + redis_cluster=WEBSOCKET_REDIS_CLUSTER, ) aquire_func = clean_up_lock.aquire_lock renew_func = clean_up_lock.renew_lock @@ -101,14 +127,37 @@ SESSION_POOL = {} USER_POOL = {} USAGE_POOL = {} + aquire_func = release_func = renew_func = lambda: True +YDOC_MANAGER = YdocManager( + redis=REDIS, + redis_key_prefix=f"{REDIS_KEY_PREFIX}:ydoc:documents", +) + + 
async def periodic_usage_pool_cleanup(): - if not aquire_func(): - log.debug("Usage pool cleanup lock already exists. Not running it.") - return - log.debug("Running periodic_usage_pool_cleanup") + max_retries = 2 + retry_delay = random.uniform( + WEBSOCKET_REDIS_LOCK_TIMEOUT / 2, WEBSOCKET_REDIS_LOCK_TIMEOUT + ) + for attempt in range(max_retries + 1): + if aquire_func(): + break + else: + if attempt < max_retries: + log.debug( + f"Cleanup lock already exists. Retry {attempt + 1} after {retry_delay}s..." + ) + await asyncio.sleep(retry_delay) + else: + log.warning( + "Failed to acquire cleanup lock after retries. Skipping cleanup." + ) + return + + log.debug("Running periodic_cleanup") try: while True: if not renew_func(): @@ -135,11 +184,6 @@ async def periodic_usage_pool_cleanup(): USAGE_POOL[model_id] = connections send_usage = True - - if send_usage: - # Emit updated usage information after cleaning - await sio.emit("usage", {"models": get_models_in_use()}) - await asyncio.sleep(TIMEOUT_DURATION) finally: release_func() @@ -157,6 +201,47 @@ def get_models_in_use(): return models_in_use +def get_active_user_ids(): + """Get the list of active user IDs.""" + return list(USER_POOL.keys()) + + +def get_user_active_status(user_id): + """Check if a user is currently active.""" + return user_id in USER_POOL + + +def get_user_id_from_session_pool(sid): + user = SESSION_POOL.get(sid) + if user: + return user["id"] + return None + + +def get_session_ids_from_room(room): + """Get all session IDs from a specific room.""" + active_session_ids = sio.manager.get_participants( + namespace="/", + room=room, + ) + return [session_id[0] for session_id in active_session_ids] + + +def get_user_ids_from_room(room): + active_session_ids = get_session_ids_from_room(room) + + active_user_ids = list( + set([SESSION_POOL.get(session_id)["id"] for session_id in active_session_ids]) + ) + return active_user_ids + + +def get_active_status_by_user_id(user_id): + if user_id in USER_POOL: + 
return True + return False + + @sio.on("usage") async def usage(sid, data): if sid in SESSION_POOL: @@ -170,9 +255,6 @@ async def usage(sid, data): sid: {"updated_at": current_time}, } - # Broadcast the usage data to all clients - await sio.emit("usage", {"models": get_models_in_use()}) - @sio.event async def connect(sid, environ, auth): @@ -184,16 +266,14 @@ async def connect(sid, environ, auth): user = Users.get_user_by_id(data["id"]) if user: - SESSION_POOL[sid] = user.model_dump() + SESSION_POOL[sid] = user.model_dump( + exclude=["date_of_birth", "bio", "gender"] + ) if user.id in USER_POOL: USER_POOL[user.id] = USER_POOL[user.id] + [sid] else: USER_POOL[user.id] = [sid] - # print(f"user {user.name}({user.id}) connected with session ID {sid}") - await sio.emit("user-list", {"user_ids": list(USER_POOL.keys())}) - await sio.emit("usage", {"models": get_models_in_use()}) - @sio.on("user-join") async def user_join(sid, data): @@ -210,7 +290,7 @@ async def user_join(sid, data): if not user: return - SESSION_POOL[sid] = user.model_dump() + SESSION_POOL[sid] = user.model_dump(exclude=["date_of_birth", "bio", "gender"]) if user.id in USER_POOL: USER_POOL[user.id] = USER_POOL[user.id] + [sid] else: @@ -221,10 +301,6 @@ async def user_join(sid, data): log.debug(f"{channels=}") for channel in channels: await sio.enter_room(sid, f"channel:{channel.id}") - - # print(f"user {user.name}({user.id}) connected with session ID {sid}") - - await sio.emit("user-list", {"user_ids": list(USER_POOL.keys())}) return {"id": user.id, "name": user.name} @@ -249,7 +325,38 @@ async def join_channel(sid, data): await sio.enter_room(sid, f"channel:{channel.id}") -@sio.on("channel-events") +@sio.on("join-note") +async def join_note(sid, data): + auth = data["auth"] if "auth" in data else None + if not auth or "token" not in auth: + return + + token_data = decode_token(auth["token"]) + if token_data is None or "id" not in token_data: + return + + user = Users.get_user_by_id(token_data["id"]) + 
if not user: + return + + note = Notes.get_note_by_id(data["note_id"]) + if not note: + log.error(f"Note {data['note_id']} not found for user {user.id}") + return + + if ( + user.role != "admin" + and user.id != note.user_id + and not has_access(user.id, type="read", access_control=note.access_control) + ): + log.error(f"User {user.id} does not have access to note {data['note_id']}") + return + + log.debug(f"Joining note {note.id} for user {user.id}") + await sio.enter_room(sid, f"note:{note.id}") + + +@sio.on("events:channel") async def channel_events(sid, data): room = f"channel:{data['channel_id']}" participants = sio.manager.get_participants( @@ -266,7 +373,7 @@ async def channel_events(sid, data): if event_type == "typing": await sio.emit( - "channel-events", + "events:channel", { "channel_id": data["channel_id"], "message_id": data.get("message_id", None), @@ -277,10 +384,240 @@ async def channel_events(sid, data): ) -@sio.on("user-list") -async def user_list(sid): - if sid in SESSION_POOL: - await sio.emit("user-list", {"user_ids": list(USER_POOL.keys())}) +@sio.on("ydoc:document:join") +async def ydoc_document_join(sid, data): + """Handle user joining a document""" + user = SESSION_POOL.get(sid) + + try: + document_id = data["document_id"] + + if document_id.startswith("note:"): + note_id = document_id.split(":")[1] + note = Notes.get_note_by_id(note_id) + if not note: + log.error(f"Note {note_id} not found") + return + + if ( + user.get("role") != "admin" + and user.get("id") != note.user_id + and not has_access( + user.get("id"), type="read", access_control=note.access_control + ) + ): + log.error( + f"User {user.get('id')} does not have access to note {note_id}" + ) + return + + user_id = data.get("user_id", sid) + user_name = data.get("user_name", "Anonymous") + user_color = data.get("user_color", "#000000") + + log.info(f"User {user_id} joining document {document_id}") + await YDOC_MANAGER.add_user(document_id=document_id, user_id=sid) + + # Join 
Socket.IO room + await sio.enter_room(sid, f"doc_{document_id}") + + active_session_ids = get_session_ids_from_room(f"doc_{document_id}") + + # Get the Yjs document state + ydoc = Y.Doc() + updates = await YDOC_MANAGER.get_updates(document_id) + for update in updates: + ydoc.apply_update(bytes(update)) + + # Encode the entire document state as an update + state_update = ydoc.get_update() + await sio.emit( + "ydoc:document:state", + { + "document_id": document_id, + "state": list(state_update), # Convert bytes to list for JSON + "sessions": active_session_ids, + }, + room=sid, + ) + + # Notify other users about the new user + await sio.emit( + "ydoc:user:joined", + { + "document_id": document_id, + "user_id": user_id, + "user_name": user_name, + "user_color": user_color, + }, + room=f"doc_{document_id}", + skip_sid=sid, + ) + + log.info(f"User {user_id} successfully joined document {document_id}") + + except Exception as e: + log.error(f"Error in yjs_document_join: {e}") + await sio.emit("error", {"message": "Failed to join document"}, room=sid) + + +async def document_save_handler(document_id, data, user): + if document_id.startswith("note:"): + note_id = document_id.split(":")[1] + note = Notes.get_note_by_id(note_id) + if not note: + log.error(f"Note {note_id} not found") + return + + if ( + user.get("role") != "admin" + and user.get("id") != note.user_id + and not has_access( + user.get("id"), type="read", access_control=note.access_control + ) + ): + log.error(f"User {user.get('id')} does not have access to note {note_id}") + return + + Notes.update_note_by_id(note_id, NoteUpdateForm(data=data)) + + +@sio.on("ydoc:document:state") +async def yjs_document_state(sid, data): + """Send the current state of the Yjs document to the user""" + try: + document_id = data["document_id"] + room = f"doc_{document_id}" + + active_session_ids = get_session_ids_from_room(room) + + if sid not in active_session_ids: + log.warning(f"Session {sid} not in room {room}. 
Cannot send state.") + return + + if not await YDOC_MANAGER.document_exists(document_id): + log.warning(f"Document {document_id} not found") + return + + # Get the Yjs document state + ydoc = Y.Doc() + updates = await YDOC_MANAGER.get_updates(document_id) + for update in updates: + ydoc.apply_update(bytes(update)) + + # Encode the entire document state as an update + state_update = ydoc.get_update() + + await sio.emit( + "ydoc:document:state", + { + "document_id": document_id, + "state": list(state_update), # Convert bytes to list for JSON + "sessions": active_session_ids, + }, + room=sid, + ) + except Exception as e: + log.error(f"Error in yjs_document_state: {e}") + + +@sio.on("ydoc:document:update") +async def yjs_document_update(sid, data): + """Handle Yjs document updates""" + try: + document_id = data["document_id"] + + try: + await stop_item_tasks(REDIS, document_id) + except: + pass + + user_id = data.get("user_id", sid) + + update = data["update"] # List of bytes from frontend + + await YDOC_MANAGER.append_to_updates( + document_id=document_id, + update=update, # Convert list of bytes to bytes + ) + + # Broadcast update to all other users in the document + await sio.emit( + "ydoc:document:update", + { + "document_id": document_id, + "user_id": user_id, + "update": update, + "socket_id": sid, # Add socket_id to match frontend filtering + }, + room=f"doc_{document_id}", + skip_sid=sid, + ) + + async def debounced_save(): + await asyncio.sleep(0.5) + await document_save_handler( + document_id, data.get("data", {}), SESSION_POOL.get(sid) + ) + + if data.get("data"): + await create_task(REDIS, debounced_save(), document_id) + + except Exception as e: + log.error(f"Error in yjs_document_update: {e}") + + +@sio.on("ydoc:document:leave") +async def yjs_document_leave(sid, data): + """Handle user leaving a document""" + try: + document_id = data["document_id"] + user_id = data.get("user_id", sid) + + log.info(f"User {user_id} leaving document {document_id}") + + # 
Remove user from the document + await YDOC_MANAGER.remove_user(document_id=document_id, user_id=sid) + + # Leave Socket.IO room + await sio.leave_room(sid, f"doc_{document_id}") + + # Notify other users + await sio.emit( + "ydoc:user:left", + {"document_id": document_id, "user_id": user_id}, + room=f"doc_{document_id}", + ) + + if ( + await YDOC_MANAGER.document_exists(document_id) + and len(await YDOC_MANAGER.get_users(document_id)) == 0 + ): + log.info(f"Cleaning up document {document_id} as no users are left") + await YDOC_MANAGER.clear_document(document_id) + + except Exception as e: + log.error(f"Error in yjs_document_leave: {e}") + + +@sio.on("ydoc:awareness:update") +async def yjs_awareness_update(sid, data): + """Handle awareness updates (cursors, selections, etc.)""" + try: + document_id = data["document_id"] + user_id = data.get("user_id", sid) + update = data["update"] + + # Broadcast awareness update to all other users in the document + await sio.emit( + "ydoc:awareness:update", + {"document_id": document_id, "user_id": user_id, "update": update}, + room=f"doc_{document_id}", + skip_sid=sid, + ) + + except Exception as e: + log.error(f"Error in yjs_awareness_update: {e}") @sio.event @@ -295,7 +632,7 @@ async def disconnect(sid): if len(USER_POOL[user_id]) == 0: del USER_POOL[user_id] - await sio.emit("user-list", {"user_ids": list(USER_POOL.keys())}) + await YDOC_MANAGER.remove_user_from_all_documents(sid) else: pass # print(f"Unknown session ID {sid} disconnected") @@ -316,12 +653,15 @@ async def __event_emitter__(event_data): ) ) + chat_id = request_info.get("chat_id", None) + message_id = request_info.get("message_id", None) + emit_tasks = [ sio.emit( - "chat-events", + "events", { - "chat_id": request_info.get("chat_id", None), - "message_id": request_info.get("message_id", None), + "chat_id": chat_id, + "message_id": message_id, "data": event_data, }, to=session_id, @@ -330,8 +670,11 @@ async def __event_emitter__(event_data): ] await 
asyncio.gather(*emit_tasks) - - if update_db: + if ( + update_db + and message_id + and not request_info.get("chat_id", "").startswith("local:") + ): if "type" in event_data and event_data["type"] == "status": Chats.add_message_status_to_chat_by_id_and_message_id( request_info["chat_id"], @@ -368,13 +711,66 @@ async def __event_emitter__(event_data): }, ) + if "type" in event_data and event_data["type"] == "embeds": + message = Chats.get_message_by_id_and_message_id( + request_info["chat_id"], + request_info["message_id"], + ) + + embeds = event_data.get("data", {}).get("embeds", []) + embeds.extend(message.get("embeds", [])) + + Chats.upsert_message_to_chat_by_id_and_message_id( + request_info["chat_id"], + request_info["message_id"], + { + "embeds": embeds, + }, + ) + + if "type" in event_data and event_data["type"] == "files": + message = Chats.get_message_by_id_and_message_id( + request_info["chat_id"], + request_info["message_id"], + ) + + files = event_data.get("data", {}).get("files", []) + files.extend(message.get("files", [])) + + Chats.upsert_message_to_chat_by_id_and_message_id( + request_info["chat_id"], + request_info["message_id"], + { + "files": files, + }, + ) + + if event_data.get("type") in ["source", "citation"]: + data = event_data.get("data", {}) + if data.get("type") == None: + message = Chats.get_message_by_id_and_message_id( + request_info["chat_id"], + request_info["message_id"], + ) + + sources = message.get("sources", []) + sources.append(data) + + Chats.upsert_message_to_chat_by_id_and_message_id( + request_info["chat_id"], + request_info["message_id"], + { + "sources": sources, + }, + ) + return __event_emitter__ def get_event_call(request_info): async def __event_caller__(event_data): response = await sio.call( - "chat-events", + "events", { "chat_id": request_info.get("chat_id", None), "message_id": request_info.get("message_id", None), @@ -388,30 +784,3 @@ async def __event_caller__(event_data): get_event_caller = get_event_call - - 
-def get_user_id_from_session_pool(sid): - user = SESSION_POOL.get(sid) - if user: - return user["id"] - return None - - -def get_user_ids_from_room(room): - active_session_ids = sio.manager.get_participants( - namespace="/", - room=room, - ) - - active_user_ids = list( - set( - [SESSION_POOL.get(session_id[0])["id"] for session_id in active_session_ids] - ) - ) - return active_user_ids - - -def get_active_status_by_user_id(user_id): - if user_id in USER_POOL: - return True - return False diff --git a/backend/open_webui/socket/utils.py b/backend/open_webui/socket/utils.py index 85a8bb7909b..168d2fd88ef 100644 --- a/backend/open_webui/socket/utils.py +++ b/backend/open_webui/socket/utils.py @@ -1,16 +1,30 @@ import json import uuid from open_webui.utils.redis import get_redis_connection +from open_webui.env import REDIS_KEY_PREFIX +from typing import Optional, List, Tuple +import pycrdt as Y class RedisLock: - def __init__(self, redis_url, lock_name, timeout_secs, redis_sentinels=[]): + def __init__( + self, + redis_url, + lock_name, + timeout_secs, + redis_sentinels=[], + redis_cluster=False, + ): + self.lock_name = lock_name self.lock_id = str(uuid.uuid4()) self.timeout_secs = timeout_secs self.lock_obtained = False self.redis = get_redis_connection( - redis_url, redis_sentinels, decode_responses=True + redis_url, + redis_sentinels, + redis_cluster=redis_cluster, + decode_responses=True, ) def aquire_lock(self): @@ -33,10 +47,13 @@ def release_lock(self): class RedisDict: - def __init__(self, name, redis_url, redis_sentinels=[]): + def __init__(self, name, redis_url, redis_sentinels=[], redis_cluster=False): self.name = name self.redis = get_redis_connection( - redis_url, redis_sentinels, decode_responses=True + redis_url, + redis_sentinels, + redis_cluster=redis_cluster, + decode_responses=True, ) def __setitem__(self, key, value): @@ -89,3 +106,109 @@ def setdefault(self, key, default=None): if key not in self: self[key] = default return self[key] + + +class 
YdocManager: + def __init__( + self, + redis=None, + redis_key_prefix: str = f"{REDIS_KEY_PREFIX}:ydoc:documents", + ): + self._updates = {} + self._users = {} + self._redis = redis + self._redis_key_prefix = redis_key_prefix + + async def append_to_updates(self, document_id: str, update: bytes): + document_id = document_id.replace(":", "_") + if self._redis: + redis_key = f"{self._redis_key_prefix}:{document_id}:updates" + await self._redis.rpush(redis_key, json.dumps(list(update))) + else: + if document_id not in self._updates: + self._updates[document_id] = [] + self._updates[document_id].append(update) + + async def get_updates(self, document_id: str) -> List[bytes]: + document_id = document_id.replace(":", "_") + + if self._redis: + redis_key = f"{self._redis_key_prefix}:{document_id}:updates" + updates = await self._redis.lrange(redis_key, 0, -1) + return [bytes(json.loads(update)) for update in updates] + else: + return self._updates.get(document_id, []) + + async def document_exists(self, document_id: str) -> bool: + document_id = document_id.replace(":", "_") + + if self._redis: + redis_key = f"{self._redis_key_prefix}:{document_id}:updates" + return await self._redis.exists(redis_key) > 0 + else: + return document_id in self._updates + + async def get_users(self, document_id: str) -> List[str]: + document_id = document_id.replace(":", "_") + + if self._redis: + redis_key = f"{self._redis_key_prefix}:{document_id}:users" + users = await self._redis.smembers(redis_key) + return list(users) + else: + return self._users.get(document_id, []) + + async def add_user(self, document_id: str, user_id: str): + document_id = document_id.replace(":", "_") + + if self._redis: + redis_key = f"{self._redis_key_prefix}:{document_id}:users" + await self._redis.sadd(redis_key, user_id) + else: + if document_id not in self._users: + self._users[document_id] = set() + self._users[document_id].add(user_id) + + async def remove_user(self, document_id: str, user_id: str): + 
document_id = document_id.replace(":", "_") + + if self._redis: + redis_key = f"{self._redis_key_prefix}:{document_id}:users" + await self._redis.srem(redis_key, user_id) + else: + if document_id in self._users and user_id in self._users[document_id]: + self._users[document_id].remove(user_id) + + async def remove_user_from_all_documents(self, user_id: str): + if self._redis: + keys = await self._redis.keys(f"{self._redis_key_prefix}:*") + for key in keys: + if key.endswith(":users"): + await self._redis.srem(key, user_id) + + document_id = key.split(":")[-2] + if len(await self.get_users(document_id)) == 0: + await self.clear_document(document_id) + + else: + for document_id in list(self._users.keys()): + if user_id in self._users[document_id]: + self._users[document_id].remove(user_id) + if not self._users[document_id]: + del self._users[document_id] + + await self.clear_document(document_id) + + async def clear_document(self, document_id: str): + document_id = document_id.replace(":", "_") + + if self._redis: + redis_key = f"{self._redis_key_prefix}:{document_id}:updates" + await self._redis.delete(redis_key) + redis_users_key = f"{self._redis_key_prefix}:{document_id}:users" + await self._redis.delete(redis_users_key) + else: + if document_id in self._updates: + del self._updates[document_id] + if document_id in self._users: + del self._users[document_id] diff --git a/backend/open_webui/static/apple-touch-icon.png b/backend/open_webui/static/apple-touch-icon.png index ece4b85dbc8..98073734365 100644 Binary files a/backend/open_webui/static/apple-touch-icon.png and b/backend/open_webui/static/apple-touch-icon.png differ diff --git a/backend/open_webui/static/assets/pdf-style.css b/backend/open_webui/static/assets/pdf-style.css index 7cb5b0cd24a..8b4e8d23705 100644 --- a/backend/open_webui/static/assets/pdf-style.css +++ b/backend/open_webui/static/assets/pdf-style.css @@ -269,11 +269,6 @@ tbody + tbody { margin-bottom: 0; } -/* Add a rule to reset margin-bottom 
for

not followed by