From eaca4a033ea1c3827681d4ec5e4879da1afe29ae Mon Sep 17 00:00:00 2001 From: aminediro Date: Tue, 17 Sep 2024 16:14:17 +0200 Subject: [PATCH 01/63] remove sync/sync active simplify --- .../modules/brain/entity/brain_entity.py | 2 - .../knowledge/service/knowledge_service.py | 3 +- .../knowledge/tests/test_knowledge_entity.py | 2 + .../sync/controller/azure_sync_routes.py | 19 +- .../sync/controller/dropbox_sync_routes.py | 16 +- .../sync/controller/github_sync_routes.py | 19 +- .../sync/controller/google_sync_routes.py | 16 +- .../sync/controller/notion_sync_routes.py | 16 +- .../modules/sync/controller/sync_routes.py | 15 +- .../quivr_api/modules/sync/dto/__init__.py | 2 +- .../api/quivr_api/modules/sync/dto/inputs.py | 83 +--- .../api/quivr_api/modules/sync/dto/outputs.py | 10 +- .../modules/sync/entity/sync_models.py | 54 +-- .../sync/repository/notion_repository.py | 130 ++++++ .../sync/repository/sync_interfaces.py | 123 ------ .../sync/repository/sync_repository.py | 416 +++++++----------- .../modules/sync/repository/sync_user.py | 255 ----------- .../modules/sync/service/sync_notion.py | 2 +- .../modules/sync/service/sync_service.py | 143 +----- .../quivr_api/modules/sync/tests/conftest.py | 8 +- .../modules/sync/tests/test_notion_service.py | 2 +- .../modules/sync/utils/sync_exceptions.py | 31 ++ .../quivr_api/modules/sync/utils/syncutils.py | 6 +- backend/worker/quivr_worker/celery_worker.py | 6 +- .../syncs/process_active_syncs.py | 19 +- .../worker/quivr_worker/syncs/store_notion.py | 2 +- backend/worker/quivr_worker/syncs/utils.py | 8 +- 27 files changed, 452 insertions(+), 956 deletions(-) create mode 100644 backend/api/quivr_api/modules/sync/repository/notion_repository.py delete mode 100644 backend/api/quivr_api/modules/sync/repository/sync_interfaces.py delete mode 100644 backend/api/quivr_api/modules/sync/repository/sync_user.py create mode 100644 backend/api/quivr_api/modules/sync/utils/sync_exceptions.py diff --git 
a/backend/api/quivr_api/modules/brain/entity/brain_entity.py b/backend/api/quivr_api/modules/brain/entity/brain_entity.py index 0b8e3460c396..a9a618733636 100644 --- a/backend/api/quivr_api/modules/brain/entity/brain_entity.py +++ b/backend/api/quivr_api/modules/brain/entity/brain_entity.py @@ -68,8 +68,6 @@ class Brain(AsyncAttrs, SQLModel, table=True): knowledges: List[KnowledgeDB] = Relationship( back_populates="brains", link_model=KnowledgeBrain ) - - # TODO : add # "meaning" "public"."vector", # "tags" "public"."tags"[] diff --git a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py index cfc88884b98f..310a68e25211 100644 --- a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py @@ -209,8 +209,7 @@ async def insert_knowledge_brain( async def get_all_knowledge_in_brain(self, brain_id: UUID) -> List[Knowledge]: brain = await self.repository.get_brain_by_id(brain_id) - - all_knowledges = await brain.awaitable_attrs.knowledges + all_knowledges: List[KnowledgeDB] = await brain.awaitable_attrs.knowledges knowledges = [await knowledge.to_dto() for knowledge in all_knowledges] return knowledges diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py index 7376559ebc39..b9732aa1fba2 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py @@ -209,6 +209,8 @@ async def test_knowledge_dto(session, user, brain): km_dto = await km.to_dto() + breakpoint() + assert km_dto.file_name == km.file_name assert km_dto.url == km.url assert km_dto.extension == km.extension diff --git a/backend/api/quivr_api/modules/sync/controller/azure_sync_routes.py 
b/backend/api/quivr_api/modules/sync/controller/azure_sync_routes.py index c905fb5ba88c..4ec6ce4f137a 100644 --- a/backend/api/quivr_api/modules/sync/controller/azure_sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/azure_sync_routes.py @@ -4,10 +4,11 @@ from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import HTMLResponse from msal import ConfidentialClientApplication + from quivr_api.logger import get_logger from quivr_api.middlewares.auth import AuthBearer, get_current_user -from quivr_api.modules.sync.dto.inputs import SyncsUserInput, SyncUserUpdateInput -from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserService +from quivr_api.modules.sync.dto.inputs import SyncCreateInput, SyncUpdateInput +from quivr_api.modules.sync.service.sync_service import SyncsService from quivr_api.modules.user.entity.user_identity import UserIdentity from .successfull_connection import successfullConnectionPage @@ -16,8 +17,8 @@ logger = get_logger(__name__) # Initialize sync service -sync_service = SyncService() -sync_user_service = SyncUserService() +sync_service = SyncsService() +sync_user_service = SyncsService() # Initialize API router azure_sync_router = APIRouter() @@ -62,7 +63,7 @@ def authorize_azure( scopes=SCOPE, redirect_uri=REDIRECT_URI, state=state, prompt="select_account" ) - sync_user_input = SyncsUserInput( + sync_user_input = SyncCreateInput( user_id=str(current_user.id), name=name, provider="Azure", @@ -96,7 +97,7 @@ def oauth2callback_azure(request: Request): logger.debug( f"Handling OAuth2 callback for user: {current_user} with state: {state}" ) - sync_user_state = sync_user_service.get_sync_user_by_state(state_dict) + sync_user_state = sync_user_service.get_sync_by_state(state_dict) logger.info(f"Retrieved sync user state: {sync_user_state}") if not sync_user_state or state_dict != sync_user_state.state: @@ -136,10 +137,8 @@ def oauth2callback_azure(request: Request): user_email = 
user_info.get("mail") or user_info.get("userPrincipalName") logger.info(f"Retrieved email for user: {current_user} - {user_email}") - sync_user_input = SyncUserUpdateInput( - credentials=result, state={}, email=user_email - ) + sync_user_input = SyncUpdateInput(credentials=result, state={}, email=user_email) - sync_user_service.update_sync_user(current_user, state_dict, sync_user_input) + sync_user_service.update_sync(current_user, state_dict, sync_user_input) logger.info(f"Azure sync created successfully for user: {current_user}") return HTMLResponse(successfullConnectionPage) diff --git a/backend/api/quivr_api/modules/sync/controller/dropbox_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/dropbox_sync_routes.py index 83fa52fee132..5bd480864005 100644 --- a/backend/api/quivr_api/modules/sync/controller/dropbox_sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/dropbox_sync_routes.py @@ -7,8 +7,8 @@ from quivr_api.logger import get_logger from quivr_api.middlewares.auth import AuthBearer, get_current_user -from quivr_api.modules.sync.dto.inputs import SyncsUserInput, SyncUserUpdateInput -from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserService +from quivr_api.modules.sync.dto.inputs import SyncCreateInput, SyncUpdateInput +from quivr_api.modules.sync.service.sync_service import SyncsService from quivr_api.modules.user.entity.user_identity import UserIdentity from .successfull_connection import successfullConnectionPage @@ -20,8 +20,8 @@ SCOPE = ["files.metadata.read", "account_info.read", "files.content.read"] # Initialize sync service -sync_service = SyncService() -sync_user_service = SyncUserService() +sync_service = SyncsService() +sync_user_service = SyncsService() logger = get_logger(__name__) @@ -65,7 +65,7 @@ def authorize_dropbox( logger.info( f"Generated authorization URL: {authorize_url} for user: {current_user.id}" ) - sync_user_input = SyncsUserInput( + sync_user_input = SyncCreateInput( 
name=name, user_id=str(current_user.id), provider="DropBox", @@ -104,7 +104,7 @@ def oauth2callback_dropbox(request: Request): logger.debug( f"Handling OAuth2 callback for user: {current_user} with state: {state} and state_dict: {state_dict}" ) - sync_user_state = sync_user_service.get_sync_user_by_state(state_dict) + sync_user_state = sync_user_service.get_sync_by_state(state_dict) if not sync_user_state or state_dict != sync_user_state.state: logger.error("Invalid state parameter") @@ -145,12 +145,12 @@ def oauth2callback_dropbox(request: Request): "expires_in": str(oauth_result.expires_at), } - sync_user_input = SyncUserUpdateInput( + sync_user_input = SyncUpdateInput( credentials=result, state={}, email=user_email, ) - sync_user_service.update_sync_user(current_user, state_dict, sync_user_input) + sync_user_service.update_sync(current_user, state_dict, sync_user_input) logger.info(f"DropBox sync created successfully for user: {current_user}") return HTMLResponse(successfullConnectionPage) except Exception as e: diff --git a/backend/api/quivr_api/modules/sync/controller/github_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/github_sync_routes.py index ecc88a5b3aa6..04ede0394a40 100644 --- a/backend/api/quivr_api/modules/sync/controller/github_sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/github_sync_routes.py @@ -3,10 +3,11 @@ import requests from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import HTMLResponse + from quivr_api.logger import get_logger from quivr_api.middlewares.auth import AuthBearer, get_current_user -from quivr_api.modules.sync.dto.inputs import SyncsUserInput, SyncUserUpdateInput -from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserService +from quivr_api.modules.sync.dto.inputs import SyncCreateInput, SyncUpdateInput +from quivr_api.modules.sync.service.sync_service import SyncsService from quivr_api.modules.user.entity.user_identity import 
UserIdentity from .successfull_connection import successfullConnectionPage @@ -15,8 +16,8 @@ logger = get_logger(__name__) # Initialize sync service -sync_service = SyncService() -sync_user_service = SyncUserService() +sync_service = SyncsService() +sync_user_service = SyncsService() # Initialize API router github_sync_router = APIRouter() @@ -54,7 +55,7 @@ def authorize_github( f"&redirect_uri={REDIRECT_URI}&scope={SCOPE}&state={state}" ) - sync_user_input = SyncsUserInput( + sync_user_input = SyncCreateInput( user_id=str(current_user.id), name=name, provider="GitHub", @@ -84,7 +85,7 @@ def oauth2callback_github(request: Request): logger.debug( f"Handling OAuth2 callback for user: {current_user} with state: {state}" ) - sync_user_state = sync_user_service.get_sync_user_by_state(state_dict) + sync_user_state = sync_user_service.get_sync_by_state(state_dict) logger.info(f"Retrieved sync user state: {sync_user_state}") if state_dict != sync_user_state["state"]: @@ -146,10 +147,8 @@ def oauth2callback_github(request: Request): logger.info(f"Retrieved email for user: {current_user} - {user_email}") - sync_user_input = SyncUserUpdateInput( - credentials=result, state={}, email=user_email - ) + sync_user_input = SyncUpdateInput(credentials=result, state={}, email=user_email) - sync_user_service.update_sync_user(current_user, state_dict, sync_user_input) + sync_user_service.update_sync(current_user, state_dict, sync_user_input) logger.info(f"GitHub sync created successfully for user: {current_user}") return HTMLResponse(successfullConnectionPage) diff --git a/backend/api/quivr_api/modules/sync/controller/google_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/google_sync_routes.py index c9b5b3bf478c..7a1d437842c6 100644 --- a/backend/api/quivr_api/modules/sync/controller/google_sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/google_sync_routes.py @@ -9,8 +9,8 @@ from quivr_api.logger import get_logger from quivr_api.middlewares.auth import 
AuthBearer, get_current_user -from quivr_api.modules.sync.dto.inputs import SyncsUserInput, SyncUserUpdateInput -from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserService +from quivr_api.modules.sync.dto.inputs import SyncCreateInput, SyncUpdateInput +from quivr_api.modules.sync.service.sync_service import SyncsService from quivr_api.modules.user.entity.user_identity import UserIdentity from .successfull_connection import successfullConnectionPage @@ -22,8 +22,8 @@ logger = get_logger(__name__) # Initialize sync service -sync_service = SyncService() -sync_user_service = SyncUserService() +sync_service = SyncsService() +sync_user_service = SyncsService() # Initialize API router google_sync_router = APIRouter() @@ -94,7 +94,7 @@ def authorize_google( logger.info( f"Generated authorization URL: {authorization_url} for user: {current_user.id}" ) - sync_user_input = SyncsUserInput( + sync_user_input = SyncCreateInput( name=name, user_id=str(current_user.id), provider="Google", @@ -126,7 +126,7 @@ def oauth2callback_google(request: Request): logger.debug( f"Handling OAuth2 callback for user: {current_user} with state: {state}" ) - sync_user_state = sync_user_service.get_sync_user_by_state(state_dict) + sync_user_state = sync_user_service.get_sync_by_state(state_dict) logger.info(f"Retrieved sync user state: {sync_user_state}") if not sync_user_state or state_dict != sync_user_state.state: @@ -154,11 +154,11 @@ def oauth2callback_google(request: Request): user_email = user_info.get("email") logger.info(f"Retrieved email for user: {current_user} - {user_email}") - sync_user_input = SyncUserUpdateInput( + sync_user_input = SyncUpdateInput( credentials=json.loads(creds.to_json()), state={}, email=user_email, ) - sync_user_service.update_sync_user(current_user, state_dict, sync_user_input) + sync_user_service.update_sync(current_user, state_dict, sync_user_input) logger.info(f"Google Drive sync created successfully for user: {current_user}") return 
HTMLResponse(successfullConnectionPage) diff --git a/backend/api/quivr_api/modules/sync/controller/notion_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/notion_sync_routes.py index 1cc8b2b9fe78..c4fd1a034664 100644 --- a/backend/api/quivr_api/modules/sync/controller/notion_sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/notion_sync_routes.py @@ -10,8 +10,8 @@ from quivr_api.celery_config import celery from quivr_api.logger import get_logger from quivr_api.middlewares.auth import AuthBearer, get_current_user -from quivr_api.modules.sync.dto.inputs import SyncsUserInput, SyncUserUpdateInput -from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserService +from quivr_api.modules.sync.dto.inputs import SyncCreateInput, SyncUpdateInput +from quivr_api.modules.sync.service.sync_service import SyncsService from quivr_api.modules.user.entity.user_identity import UserIdentity from .successfull_connection import successfullConnectionPage @@ -25,8 +25,8 @@ # Initialize sync service -sync_service = SyncService() -sync_user_service = SyncUserService() +sync_service = SyncsService() +sync_user_service = SyncsService() logger = get_logger(__name__) @@ -59,7 +59,7 @@ def authorize_notion( logger.info( f"Generated authorization URL: {authorize_url} for user: {current_user.id}" ) - sync_user_input = SyncsUserInput( + sync_user_input = SyncCreateInput( name=name, user_id=str(current_user.id), provider="Notion", @@ -93,7 +93,7 @@ def oauth2callback_notion(request: Request, background_tasks: BackgroundTasks): logger.debug( f"Handling OAuth2 callback for user: {current_user} with state: {state} and state_dict: {state_dict}" ) - sync_user_state = sync_user_service.get_sync_user_by_state(state_dict) + sync_user_state = sync_user_service.get_sync_by_state(state_dict) if not sync_user_state or state_dict != sync_user_state.state: logger.error(f"Invalid state parameter for {sync_user_state}") @@ -143,12 +143,12 @@ def 
oauth2callback_notion(request: Request, background_tasks: BackgroundTasks): "expires_in": oauth_result.get("expires_in", ""), } - sync_user_input = SyncUserUpdateInput( + sync_user_input = SyncUpdateInput( credentials=result, state={}, email=user_email, ) - sync_user_service.update_sync_user(current_user, state_dict, sync_user_input) + sync_user_service.update_sync(current_user, state_dict, sync_user_input) logger.info(f"Notion sync created successfully for user: {current_user}") # launch celery task to sync notion data celery.send_task( diff --git a/backend/api/quivr_api/modules/sync/controller/sync_routes.py b/backend/api/quivr_api/modules/sync/controller/sync_routes.py index 3adcbe41b4bb..76b18e94303e 100644 --- a/backend/api/quivr_api/modules/sync/controller/sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/sync_routes.py @@ -1,6 +1,6 @@ import os import uuid -from typing import Annotated, List +from typing import List from fastapi import APIRouter, Depends, HTTPException, status @@ -23,7 +23,7 @@ from quivr_api.modules.sync.dto.outputs import AuthMethodEnum from quivr_api.modules.sync.entity.sync_models import SyncsActive from quivr_api.modules.sync.service.sync_notion import SyncNotionService -from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserService +from quivr_api.modules.sync.service.sync_service import SyncsService from quivr_api.modules.user.entity.user_identity import UserIdentity notification_service = NotificationService() @@ -35,9 +35,8 @@ logger = get_logger(__name__) # Initialize sync service -sync_service = SyncService() -sync_user_service = SyncUserService() -NotionServiceDep = Annotated[SyncNotionService, Depends(get_service(SyncNotionService))] +sync_service = SyncsService() +sync_user_service = SyncsService() # Initialize API router @@ -119,7 +118,7 @@ async def get_user_syncs(current_user: UserIdentity = Depends(get_current_user)) List: A list of syncs for the user. 
""" logger.debug(f"Fetching user syncs for user: {current_user.id}") - return sync_user_service.get_syncs_user(current_user.id) + return sync_user_service.get_syncs(current_user.id) @sync_router.delete( @@ -144,7 +143,7 @@ async def delete_user_sync( logger.debug( f"Deleting user sync for user: {current_user.id} with sync ID: {sync_id}" ) - sync_user_service.delete_sync_user(sync_id, str(current_user.id)) # type: ignore + sync_user_service.delete_sync(sync_id, str(current_user.id)) # type: ignore return None @@ -372,8 +371,8 @@ async def get_active_syncs_for_user( ) async def get_files_folder_user_sync( user_sync_id: int, - notion_service: NotionServiceDep, folder_id: str | None = None, + notion_service: SyncNotionService = Depends(get_service(SyncNotionService)), current_user: UserIdentity = Depends(get_current_user), ): """ diff --git a/backend/api/quivr_api/modules/sync/dto/__init__.py b/backend/api/quivr_api/modules/sync/dto/__init__.py index 986765df6302..b40f07618217 100644 --- a/backend/api/quivr_api/modules/sync/dto/__init__.py +++ b/backend/api/quivr_api/modules/sync/dto/__init__.py @@ -1 +1 @@ -from .outputs import SyncsDescription, SyncsUserOutput +from .outputs import SyncsDescription, SyncsOutput diff --git a/backend/api/quivr_api/modules/sync/dto/inputs.py b/backend/api/quivr_api/modules/sync/dto/inputs.py index 25847e2105b4..a267304dbaa9 100644 --- a/backend/api/quivr_api/modules/sync/dto/inputs.py +++ b/backend/api/quivr_api/modules/sync/dto/inputs.py @@ -1,9 +1,9 @@ -from typing import List, Optional +from uuid import UUID from pydantic import BaseModel -class SyncsUserInput(BaseModel): +class SyncCreateInput(BaseModel): """ Input model for creating a new sync user. @@ -15,7 +15,7 @@ class SyncsUserInput(BaseModel): state (dict): The state information for the sync user. 
""" - user_id: str + user_id: UUID name: str provider: str credentials: dict @@ -23,7 +23,7 @@ class SyncsUserInput(BaseModel): additional_data: dict = {} -class SyncUserUpdateInput(BaseModel): +class SyncUpdateInput(BaseModel): """ Input model for updating an existing sync user. @@ -35,78 +35,3 @@ class SyncUserUpdateInput(BaseModel): credentials: dict state: dict email: str - - -class SyncActiveSettings(BaseModel): - """ - Sync active settings. - - Attributes: - folders (List[str] | None): A list of folder paths to be synced, or None if not applicable. - files (List[str] | None): A list of file paths to be synced, or None if not applicable. - """ - - folders: Optional[List[str]] = None - files: Optional[List[str]] = None - - -class SyncsActiveInput(BaseModel): - """ - Input model for creating a new active sync. - - Attributes: - name (str): The name of the sync. - syncs_user_id (int): The ID of the sync user associated with this sync. - settings (SyncActiveSettings): The settings for the active sync. - """ - - name: str - syncs_user_id: int - settings: SyncActiveSettings - brain_id: str - notification_id: Optional[str] = None - - -class SyncsActiveUpdateInput(BaseModel): - """ - Input model for updating an existing active sync. - - Attributes: - name (str): The updated name of the sync. - sync_interval_minutes (int): The updated sync interval in minutes. - settings (dict): The updated settings for the active sync. - """ - - name: Optional[str] = None - settings: Optional[SyncActiveSettings] = None - last_synced: Optional[str] = None - force_sync: Optional[bool] = False - notification_id: Optional[str] = None - - -class SyncFileInput(BaseModel): - """ - Input model for creating a new sync file. - - Attributes: - path (str): The path of the file. - syncs_active_id (int): The ID of the active sync associated with this file. 
- """ - - path: str - syncs_active_id: int - last_modified: str - brain_id: str - supported: Optional[bool] = True - - -class SyncFileUpdateInput(BaseModel): - """ - Input model for updating an existing sync file. - - Attributes: - last_modified (datetime.datetime): The updated last modified date and time. - """ - - last_modified: Optional[str] = None - supported: Optional[bool] = None diff --git a/backend/api/quivr_api/modules/sync/dto/outputs.py b/backend/api/quivr_api/modules/sync/dto/outputs.py index 498f66177b01..4702ae541753 100644 --- a/backend/api/quivr_api/modules/sync/dto/outputs.py +++ b/backend/api/quivr_api/modules/sync/dto/outputs.py @@ -7,13 +7,21 @@ class AuthMethodEnum(str, Enum): URI_WITH_CALLBACK = "uri_with_callback" +class SyncProvider(str, Enum): + GOOGLE = "google" + AZURE = "azure" + DROPBOX = "dropbox" + NOTION = "notion" + GITHUB = "github" + + class SyncsDescription(BaseModel): name: str description: str auth_method: AuthMethodEnum -class SyncsUserOutput(BaseModel): +class SyncsOutput(BaseModel): user_id: str provider: str state: dict diff --git a/backend/api/quivr_api/modules/sync/entity/sync_models.py b/backend/api/quivr_api/modules/sync/entity/sync_models.py index 8cc2297377bf..f1f7cbfd8200 100644 --- a/backend/api/quivr_api/modules/sync/entity/sync_models.py +++ b/backend/api/quivr_api/modules/sync/entity/sync_models.py @@ -2,13 +2,22 @@ import io from dataclasses import dataclass from datetime import datetime -from typing import Optional +from typing import Dict, Optional from uuid import UUID from pydantic import BaseModel -from sqlmodel import TIMESTAMP, Column, Field, Relationship, SQLModel, text +from sqlmodel import ( # noqa: F811 + JSON, + TIMESTAMP, + Column, + Field, + Relationship, + SQLModel, + text, +) from sqlmodel import UUID as PGUUID +from quivr_api.modules.sync.dto.outputs import SyncsOutput from quivr_api.modules.user.entity.user_identity import User @@ -27,15 +36,6 @@ def file_sha1(self) -> str: return m.hexdigest() 
-class DBSyncFile(BaseModel): - id: int - path: str - syncs_active_id: int - last_modified: str - brain_id: str - supported: bool - - class SyncFile(BaseModel): id: str name: str @@ -50,32 +50,20 @@ class SyncFile(BaseModel): type: Optional[str] = None -class SyncsUser(BaseModel): - id: int - user_id: UUID +class Syncs(SQLModel, table=True): + __tablename__ = "syns_user" # type: ignore + id: UUID | None = Field(default=None, primary_key=True) + user_id: UUID = Field(foreign_key="users.id", nullable=False) name: str provider: str - credentials: dict - state: dict + credentials: Dict[str, str] | None = Field( + default=None, sa_column=Column("state", JSON) + ) + state: Dict[str, str] | None = Field(default=None, sa_column=Column("state", JSON)) additional_data: dict - -class SyncsActive(BaseModel): - id: int - name: str - syncs_user_id: int - user_id: UUID - settings: dict - last_synced: str - sync_interval_minutes: int - brain_id: UUID - syncs_user: Optional[SyncsUser] = None - notification_id: Optional[str] = None - - -# TODO: all of this should be rewritten -class SyncsActiveDetails(BaseModel): - pass + def to_dto(self) -> SyncsOutput: + return SyncsOutput(user_id=self.user_id, provider=self.provider) class NotionSyncFile(SQLModel, table=True): diff --git a/backend/api/quivr_api/modules/sync/repository/notion_repository.py b/backend/api/quivr_api/modules/sync/repository/notion_repository.py new file mode 100644 index 000000000000..f87be94c029e --- /dev/null +++ b/backend/api/quivr_api/modules/sync/repository/notion_repository.py @@ -0,0 +1,130 @@ +from typing import List, Sequence +from uuid import UUID + +from sqlalchemy import or_ +from sqlalchemy.exc import IntegrityError +from sqlmodel import col, select +from sqlmodel.ext.asyncio.session import AsyncSession + +from quivr_api.logger import get_logger +from quivr_api.modules.dependencies import BaseRepository, get_supabase_client +from quivr_api.modules.notification.service.notification_service import ( + 
NotificationService, +) +from quivr_api.modules.sync.entity.sync_models import NotionSyncFile + +notification_service = NotificationService() + +logger = get_logger(__name__) + + +class NotionRepository(BaseRepository): + def __init__(self, session: AsyncSession): + super().__init__(session) + self.session = session + self.db = get_supabase_client() + + async def get_user_notion_files(self, user_id: UUID) -> Sequence[NotionSyncFile]: + query = select(NotionSyncFile).where(NotionSyncFile.user_id == user_id) + response = await self.session.exec(query) + return response.all() + + async def create_notion_files( + self, notion_files: List[NotionSyncFile] + ) -> List[NotionSyncFile]: + try: + self.session.add_all(notion_files) + await self.session.commit() + except IntegrityError: + await self.session.rollback() + raise Exception("Integrity error while creating notion files.") + except Exception as e: + await self.session.rollback() + raise e + + return notion_files + + async def update_notion_file(self, updated_notion_file: NotionSyncFile) -> bool: + try: + is_update = False + query = select(NotionSyncFile).where( + NotionSyncFile.notion_id == updated_notion_file.notion_id + ) + result = await self.session.exec(query) + existing_page = result.one_or_none() + + if existing_page: + # Update existing page + existing_page.name = updated_notion_file.name + existing_page.last_modified = updated_notion_file.last_modified + self.session.add(existing_page) + is_update = True + else: + # Add new page + self.session.add(updated_notion_file) + + await self.session.commit() + + # Refresh the object that's actually in the session + refreshed_file = existing_page if existing_page else updated_notion_file + await self.session.refresh(refreshed_file) + + logger.info(f"Updated notion file in notion repo: {refreshed_file}") + return is_update + + except IntegrityError as ie: + logger.error(f"IntegrityError occurred: {ie}") + await self.session.rollback() + raise Exception("Integrity error 
while updating notion file.") + except Exception as e: + logger.error(f"Exception occurred: {e}") + await self.session.rollback() + raise + + async def get_notion_files_by_ids(self, ids: List[str]) -> Sequence[NotionSyncFile]: + query = select(NotionSyncFile).where(NotionSyncFile.notion_id.in_(ids)) # type: ignore + response = await self.session.exec(query) + return response.all() + + async def get_notion_files_by_parent_id( + self, parent_id: str | None + ) -> Sequence[NotionSyncFile]: + query = select(NotionSyncFile).where(NotionSyncFile.parent_id == parent_id) + response = await self.session.exec(query) + return response.all() + + async def get_all_notion_files(self) -> Sequence[NotionSyncFile]: + query = select(NotionSyncFile) + response = await self.session.exec(query) + return response.all() + + async def is_folder_page(self, page_id: str) -> bool: + query = select(NotionSyncFile).where(NotionSyncFile.parent_id == page_id) + response = await self.session.exec(query) + return response.first() is not None + + async def delete_notion_page(self, notion_id: UUID): + query = select(NotionSyncFile).where(NotionSyncFile.notion_id == notion_id) + response = await self.session.exec(query) + notion_file = response.first() + if notion_file: + await self.session.delete(notion_file) + await self.session.commit() + return notion_file + return None + + async def delete_notion_pages(self, notion_ids: List[UUID]): + query = select(NotionSyncFile).where( + or_( + col(NotionSyncFile.notion_id).in_(notion_ids), + col(NotionSyncFile.parent_id).in_(notion_ids), + ) + ) + response = await self.session.exec(query) + notion_files = response.all() + if notion_files: + for notion_file in notion_files: + await self.session.delete(notion_file) + await self.session.commit() + return notion_files + return None diff --git a/backend/api/quivr_api/modules/sync/repository/sync_interfaces.py b/backend/api/quivr_api/modules/sync/repository/sync_interfaces.py deleted file mode 100644 index 
dc6a9add6eb9..000000000000 --- a/backend/api/quivr_api/modules/sync/repository/sync_interfaces.py +++ /dev/null @@ -1,123 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, List, Literal -from uuid import UUID - -from quivr_api.modules.sync.dto.inputs import ( - SyncFileInput, - SyncFileUpdateInput, - SyncsActiveInput, - SyncsActiveUpdateInput, - SyncsUserInput, - SyncUserUpdateInput, -) -from quivr_api.modules.sync.entity.sync_models import ( - DBSyncFile, - SyncFile, - SyncsActive, -) - - -class SyncUserInterface(ABC): - @abstractmethod - def create_sync_user( - self, - sync_user_input: SyncsUserInput, - ): - pass - - @abstractmethod - def get_syncs_user(self, user_id: str, sync_user_id: int | None = None): - pass - - @abstractmethod - def get_sync_user_by_id(self, sync_id: int): - pass - - @abstractmethod - def delete_sync_user(self, sync_user_id: int, user_id: UUID | str): - pass - - @abstractmethod - def get_sync_user_by_state(self, state: dict): - pass - - @abstractmethod - def update_sync_user( - self, sync_user_id: int, state: dict, sync_user_input: SyncUserUpdateInput - ): - pass - - @abstractmethod - async def get_files_folder_user_sync( - self, - sync_active_id: int, - user_id: str, - notion_service: Any = None, - folder_id: int | str | None = None, - recursive: bool = False, - ) -> None | dict[str, List[SyncFile]] | Literal["No sync found"]: - pass - - @abstractmethod - def get_all_notion_user_syncs(self): - pass - - -class SyncInterface(ABC): - @abstractmethod - def create_sync_active( - self, - sync_active_input: SyncsActiveInput, - user_id: str, - ) -> SyncsActive | None: - pass - - @abstractmethod - def get_syncs_active(self, user_id: UUID | str) -> List[SyncsActive]: - pass - - @abstractmethod - def update_sync_active( - self, sync_id: UUID | int, sync_active_input: SyncsActiveUpdateInput - ): - pass - - @abstractmethod - def delete_sync_active(self, sync_active_id: int, user_id: str): - pass - - @abstractmethod - def 
get_details_sync_active(self, sync_active_id: int): - pass - - @abstractmethod - async def get_syncs_active_in_interval(self) -> List[SyncsActive]: - pass - - -class SyncFileInterface(ABC): - @abstractmethod - def create_sync_file(self, sync_file_input: SyncFileInput) -> DBSyncFile: - pass - - @abstractmethod - def get_sync_files(self, sync_active_id: int) -> list[DBSyncFile]: - pass - - @abstractmethod - def update_sync_file(self, sync_file_id: int, sync_file_input: SyncFileUpdateInput): - pass - - @abstractmethod - def delete_sync_file(self, sync_file_id: int): - pass - - @abstractmethod - def update_or_create_sync_file( - self, - file: SyncFile, - sync_active: SyncsActive, - previous_file: DBSyncFile | None, - supported: bool, - ) -> DBSyncFile | None: - pass diff --git a/backend/api/quivr_api/modules/sync/repository/sync_repository.py b/backend/api/quivr_api/modules/sync/repository/sync_repository.py index 59e013b22925..6282f0ffcbe7 100644 --- a/backend/api/quivr_api/modules/sync/repository/sync_repository.py +++ b/backend/api/quivr_api/modules/sync/repository/sync_repository.py @@ -1,318 +1,218 @@ -from datetime import datetime, timedelta -from typing import List, Sequence +import json +from sqlite3 import IntegrityError +from typing import List from uuid import UUID -from quivr_api.logger import get_logger -from quivr_api.modules.dependencies import (BaseRepository, get_supabase_client) -from quivr_api.modules.notification.service.notification_service import \ - NotificationService -from quivr_api.modules.sync.dto.inputs import (SyncsActiveInput, - SyncsActiveUpdateInput) -from quivr_api.modules.sync.entity.sync_models import (NotionSyncFile, - SyncsActive) -from quivr_api.modules.sync.repository.sync_interfaces import SyncInterface -from sqlalchemy import or_ -from sqlalchemy.exc import IntegrityError -from sqlmodel import col, select +from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession -notification_service = 
NotificationService() +from quivr_api.logger import get_logger +from quivr_api.modules.dependencies import BaseRepository, get_supabase_client +from quivr_api.modules.sync.dto.inputs import SyncCreateInput, SyncUpdateInput +from quivr_api.modules.sync.dto.outputs import SyncProvider +from quivr_api.modules.sync.entity.sync_models import SyncFile, Syncs +from quivr_api.modules.sync.repository.notion_repository import NotionRepository +from quivr_api.modules.sync.service.sync_notion import SyncNotionService +from quivr_api.modules.sync.utils.sync_exceptions import SyncNotFoundException +from quivr_api.modules.sync.utils.sync import ( + AzureDriveSync, + BaseSync, + DropboxSync, + GitHubSync, + GoogleDriveSync, + NotionSync, +) +from quivr_api.modules.sync.utils.sync_exceptions import SyncEmptyCredentials logger = get_logger(__name__) -class Sync(SyncInterface): - def __init__(self): - """ - Initialize the Sync class with a Supabase client. - """ - supabase_client = get_supabase_client() - self.db = supabase_client # type: ignore - logger.debug("Supabase client initialized") +class SyncsRepository(BaseRepository): + def __init__(self, session: AsyncSession): + self.session = session + self.db = get_supabase_client() - def create_sync_active( - self, sync_active_input: SyncsActiveInput, user_id: str - ) -> SyncsActive | None: + self.sync_provider_mapping: dict[SyncProvider, BaseSync] = { + SyncProvider.GOOGLE: GoogleDriveSync(), + SyncProvider.DROPBOX: DropboxSync(), + SyncProvider.AZURE: AzureDriveSync(), + SyncProvider.NOTION: NotionSync( + notion_service=SyncNotionService(NotionRepository(self.session)) + ), + SyncProvider.GITHUB: GitHubSync(), + } + + async def create_sync( + self, + sync_user_input: SyncCreateInput, + ) -> Syncs: """ - Create a new active sync in the database. + Create a new sync user in the database. Args: - sync_active_input (SyncsActiveInput): The input data for creating an active sync. - user_id (str): The user ID associated with the active sync.
+ sync_user_input (SyncsUserInput): The input data for creating a sync user. Returns: - SyncsActive or None: The created active sync data or None if creation failed. """ - logger.info( - "Creating active sync for user_id: %s with input: %s", - user_id, - sync_active_input, - ) - sync_active_input_dict = sync_active_input.model_dump() - sync_active_input_dict["user_id"] = user_id - response = ( - self.db.from_("syncs_active").insert(sync_active_input_dict).execute() - ) - if response.data: - logger.info("Active sync created successfully: %s", response.data[0]) - return SyncsActive(**response.data[0]) + logger.info("Creating sync user with input: %s", sync_user_input) + try: + sync = Syncs.model_validate(sync_user_input.model_dump()) + self.session.add(sync) + await self.session.commit() + await self.session.refresh(sync) + return sync + except IntegrityError: + await self.session.rollback() + raise + except Exception: + await self.session.rollback() + raise - logger.error("Failed to create active sync for user_id: %s", user_id) + async def get_sync_id(self, sync_id: int) -> Syncs | None: + """ + Retrieve sync users from the database. + """ + query = select(Syncs).where(Syncs.id == sync_id) + result = await self.session.exec(query) + sync = result.first() + if not sync: + logger.error( + f"No sync user found for sync_id: {sync_id}", + ) + raise SyncNotFoundException() + return sync - def get_syncs_active(self, user_id: UUID | str) -> List[SyncsActive]: + async def get_syncs(self, user_id: UUID, sync_id: int | None = None): """ - Retrieve active syncs from the database. + Retrieve sync users from the database. Args: - user_id (str): The user ID to filter active syncs. + user_id (str): The user ID to filter sync users. + sync_user_id (int, optional): The sync user ID to filter sync users. Defaults to None. Returns: - List[SyncsActive]: A list of active syncs matching the criteria. + list: A list of sync users matching the criteria. 
""" - logger.info("Retrieving active syncs for user_id: %s", user_id) - response = ( - self.db.from_("syncs_active") - .select("*, syncs_user(*)") - .eq("user_id", user_id) - .execute() + logger.info( + "Retrieving sync users for user_id: %s, sync_user_id: %s", + user_id, + sync_id, ) - if response.data: - logger.info("Active syncs retrieved successfully: %s", response.data) - return [SyncsActive(**sync) for sync in response.data] - logger.warning("No active syncs found for user_id: %s", user_id) - return [] + query = select(Syncs).where(Syncs.id == sync_id).where(Syncs.user_id == user_id) + result = await self.session.exec(query) + sync = result.first() + if not sync: + logger.error( + f"No sync user found for sync_id: {sync_id}", + ) + raise SyncNotFoundException() + return sync - def update_sync_active( - self, sync_id: int | str, sync_active_input: SyncsActiveUpdateInput - ) -> SyncsActive | None: + async def get_sync_user_by_state(self, state: dict) -> Syncs | None: """ - Update an active sync in the database. + Retrieve a sync user by their state. Args: - sync_id (int): The ID of the active sync. - sync_active_input (SyncsActiveUpdateInput): The input data for updating the active sync. + state (dict): The state to filter sync users. Returns: - dict or None: The updated active sync data or None if update failed. + dict or None: The sync user data matching the state or None if not found. 
""" - logger.info( - "Updating active sync with sync_id: %s, input: %s", - sync_id, - sync_active_input, - ) - - response = ( - self.db.from_("syncs_active") - .update(sync_active_input.model_dump(exclude_unset=True)) - .eq("id", sync_id) - .execute() - ) + logger.info("Getting sync user by state: %s", state) - if response.data: - logger.info("Active sync updated successfully: %s", response.data[0]) - return SyncsActive.model_validate(response.data[0]) + query = select(Syncs).where(Syncs.state == state) + result = await self.session.exec(query) + sync = result.first() + if not sync: + raise SyncNotFoundException() + return sync - logger.error("Failed to update active sync with sync_id: %s", sync_id) + return None - def delete_sync_active(self, sync_active_id: int, user_id: UUID): + def delete_sync(self, sync_id: int, user_id: UUID | str): """ - Delete an active sync from the database. + Delete a sync user from the database. Args: - sync_active_id (int): The ID of the active sync. - user_id (str): The user ID associated with the active sync. - - Returns: - dict or None: The deleted active sync data or None if deletion failed. + provider (str): The provider of the sync user. + user_id (str): The user ID of the sync user. 
""" logger.info( - "Deleting active sync with sync_active_id: %s, user_id: %s", - sync_active_id, - user_id, - ) - response = ( - self.db.from_("syncs_active") - .delete() - .eq("id", sync_active_id) - .eq("user_id", str(user_id)) - .execute() + "Deleting sync user with sync_id: %s, user_id: %s", sync_id, user_id ) - if response.data: - logger.info("Active sync deleted successfully: %s", response.data[0]) - return response.data[0] - logger.warning( - "Failed to delete active sync with sync_active_id: %s, user_id: %s", - sync_active_id, - user_id, - ) - return None - - def get_details_sync_active(self, sync_active_id: int): + self.db.from_("syncs_user").delete().eq("id", sync_id).eq( + "user_id", user_id + ).execute() + logger.info("Sync user deleted successfully") + + def update_sync_user( + self, sync_user_id: UUID, state: dict, sync_user_input: SyncUpdateInput + ): """ - Retrieve details of an active sync, including associated sync user data. + Update a sync user in the database. Args: - sync_active_id (int): The ID of the active sync. - - Returns: - dict or None: The detailed active sync data or None if not found. + sync_user_id (str): The user ID of the sync user. + state (dict): The state to filter sync users. + sync_user_input (SyncUserUpdateInput): The input data for updating the sync user. 
""" logger.info( - "Retrieving details for active sync with sync_active_id: %s", sync_active_id - ) - response = ( - self.db.table("syncs_active") - .select("*, syncs_user(provider, credentials)") - .eq("id", sync_active_id) - .execute() - ) - if response.data: - logger.info( - "Details for active sync retrieved successfully: %s", response.data[0] - ) - return response.data[0] - logger.warning( - "No details found for active sync with sync_active_id: %s", sync_active_id + "Updating sync user with user_id: %s, state: %s, input: %s", + sync_user_id, + state, + sync_user_input, ) - return None - async def get_syncs_active_in_interval(self) -> List[SyncsActive]: + state_str = json.dumps(state) + self.db.from_("syncs_user").update(sync_user_input.model_dump()).eq( + "user_id", str(sync_user_id) + ).eq("state", state_str).execute() + logger.info("Sync user updated successfully") + + def get_all_notion_user_syncs(self): """ - Retrieve active syncs that are due for synchronization based on their interval. + Retrieve all Notion sync users from the database. Returns: - list: A list of active syncs that are due for synchronization. + list: A list of Notion sync users. 
""" - logger.info("Retrieving active syncs due for synchronization") - - current_time = datetime.now() - - # The Query filters the active syncs based on the sync_interval_minutes field and last_synced timestamp + logger.info("Retrieving all Notion sync users") response = ( - self.db.table("syncs_active") - .select("*") - .lt("last_synced", (current_time - timedelta(minutes=360)).isoformat()) - .execute() - ) - - force_sync = ( - self.db.table("syncs_active").select("*").eq("force_sync", True).execute() + self.db.from_("syncs_user").select("*").eq("provider", "Notion").execute() ) - merge_data = response.data + force_sync.data - if merge_data: - logger.info("Active syncs retrieved successfully: %s", merge_data) - return [SyncsActive(**sync) for sync in merge_data] - logger.info("No active syncs found due for synchronization") + if response.data: + logger.info("Notion sync users retrieved successfully") + return response.data return [] - -class NotionRepository(BaseRepository): - def __init__(self, session: AsyncSession): - super().__init__(session) - self.session = session - self.db = get_supabase_client() - - async def get_user_notion_files(self, user_id: UUID) -> Sequence[NotionSyncFile]: - query = select(NotionSyncFile).where(NotionSyncFile.user_id == user_id) - response = await self.session.exec(query) - return response.all() - - async def create_notion_files( - self, notion_files: List[NotionSyncFile] - ) -> List[NotionSyncFile]: - try: - self.session.add_all(notion_files) - await self.session.commit() - except IntegrityError: - await self.session.rollback() - raise Exception("Integrity error while creating notion files.") - except Exception as e: - await self.session.rollback() - raise e - - return notion_files - - async def update_notion_file(self, updated_notion_file: NotionSyncFile) -> bool: - try: - is_update = False - query = select(NotionSyncFile).where( - NotionSyncFile.notion_id == updated_notion_file.notion_id + async def get_files_folder_user_sync( + 
self, + sync_active_id: int, + user_id: UUID, + folder_id: str | None = None, + recursive: bool = False, + ) -> List[SyncFile] | None: + logger.info( + "Retrieving files for user sync with sync_active_id: %s, user_id: %s, folder_id: %s", + sync_active_id, + user_id, + folder_id, + ) + sync_user = await self.get_syncs(user_id=user_id, sync_id=sync_active_id) + if not sync_user: + logger.error( + "No sync user found for sync_active_id: %s, user_id: %s", + sync_active_id, + user_id, ) - result = await self.session.exec(query) - existing_page = result.one_or_none() - - if existing_page: - # Update existing page - existing_page.name = updated_notion_file.name - existing_page.last_modified = updated_notion_file.last_modified - self.session.add(existing_page) - is_update = True - else: - # Add new page - self.session.add(updated_notion_file) - - await self.session.commit() - - # Refresh the object that's actually in the session - refreshed_file = existing_page if existing_page else updated_notion_file - await self.session.refresh(refreshed_file) - - logger.info(f"Updated notion file in notion repo: {refreshed_file}") - return is_update + return None - except IntegrityError as ie: - logger.error(f"IntegrityError occurred: {ie}") - await self.session.rollback() - raise Exception("Integrity error while updating notion file.") - except Exception as e: - logger.error(f"Exception occurred: {e}") - await self.session.rollback() - raise - - async def get_notion_files_by_ids(self, ids: List[str]) -> Sequence[NotionSyncFile]: - query = select(NotionSyncFile).where(NotionSyncFile.notion_id.in_(ids)) # type: ignore - response = await self.session.exec(query) - return response.all() - - async def get_notion_files_by_parent_id( - self, parent_id: str | None - ) -> Sequence[NotionSyncFile]: - query = select(NotionSyncFile).where(NotionSyncFile.parent_id == parent_id) - response = await self.session.exec(query) - return response.all() - - async def get_all_notion_files(self) -> 
Sequence[NotionSyncFile]: - query = select(NotionSyncFile) - response = await self.session.exec(query) - return response.all() + provider = sync_user.provider.lower() + sync_provider = self.sync_provider_mapping[SyncProvider(provider)] - async def is_folder_page(self, page_id: str) -> bool: - query = select(NotionSyncFile).where(NotionSyncFile.parent_id == page_id) - response = await self.session.exec(query) - return response.first() is not None + if sync_user.credentials is None: + raise SyncEmptyCredentials - async def delete_notion_page(self, notion_id: UUID): - query = select(NotionSyncFile).where(NotionSyncFile.notion_id == notion_id) - response = await self.session.exec(query) - notion_file = response.first() - if notion_file: - await self.session.delete(notion_file) - await self.session.commit() - return notion_file - return None - - async def delete_notion_pages(self, notion_ids: List[UUID]): - query = select(NotionSyncFile).where( - or_( - col(NotionSyncFile.notion_id).in_(notion_ids), - col(NotionSyncFile.parent_id).in_(notion_ids), - ) + return await sync_provider.aget_files( + sync_user.credentials, folder_id if folder_id else "", recursive ) - response = await self.session.exec(query) - notion_files = response.all() - if notion_files: - for notion_file in notion_files: - await self.session.delete(notion_file) - await self.session.commit() - return notion_files - return None diff --git a/backend/api/quivr_api/modules/sync/repository/sync_user.py b/backend/api/quivr_api/modules/sync/repository/sync_user.py deleted file mode 100644 index efb3e9c892c6..000000000000 --- a/backend/api/quivr_api/modules/sync/repository/sync_user.py +++ /dev/null @@ -1,255 +0,0 @@ -import json -from typing import List, Literal -from uuid import UUID - -from quivr_api.logger import get_logger -from quivr_api.modules.dependencies import get_supabase_client -from quivr_api.modules.sync.dto.inputs import SyncsUserInput, SyncUserUpdateInput -from 
quivr_api.modules.sync.entity.sync_models import SyncFile, SyncsUser -from quivr_api.modules.sync.service.sync_notion import SyncNotionService -from quivr_api.modules.sync.utils.sync import ( - AzureDriveSync, - BaseSync, - DropboxSync, - GitHubSync, - GoogleDriveSync, - NotionSync, -) - -logger = get_logger(__name__) - - -class SyncUserRepository: - def __init__(self): - """ - Initialize the Sync class with a Supabase client. - """ - supabase_client = get_supabase_client() - self.db = supabase_client # type: ignore - logger.debug("Supabase client initialized") - - def create_sync_user( - self, - sync_user_input: SyncsUserInput, - ): - """ - Create a new sync user in the database. - - Args: - sync_user_input (SyncsUserInput): The input data for creating a sync user. - - Returns: - dict or None: The created sync user data or None if creation failed. - """ - logger.info("Creating sync user with input: %s", sync_user_input) - response = ( - self.db.from_("syncs_user") - .insert(sync_user_input.model_dump(exclude_none=True, exclude_unset=True)) - .execute() - ) - if response.data: - logger.info("Sync user created successfully: %s", response.data[0]) - return response.data[0] - logger.warning("Failed to create sync user") - - def get_sync_user_by_id(self, sync_id: int) -> SyncsUser | None: - """ - Retrieve sync users from the database. - """ - response = self.db.from_("syncs_user").select("*").eq("id", sync_id).execute() - if response.data: - logger.info("Sync user found: %s", response.data[0]) - return SyncsUser.model_validate(response.data[0]) - logger.error("No sync user found for sync_id: %s", sync_id) - - def get_syncs_user(self, user_id: UUID, sync_user_id: int | None = None): - """ - Retrieve sync users from the database. - - Args: - user_id (str): The user ID to filter sync users. - sync_user_id (int, optional): The sync user ID to filter sync users. Defaults to None. - - Returns: - list: A list of sync users matching the criteria. 
- """ - logger.info( - "Retrieving sync users for user_id: %s, sync_user_id: %s", - user_id, - sync_user_id, - ) - query = self.db.from_("syncs_user").select("*").eq("user_id", user_id) - if sync_user_id: - query = query.eq("id", str(sync_user_id)) - response = query.execute() - if response.data: - # logger.info("Sync users retrieved successfully: %s", response.data) - return response.data - logger.warning( - "No sync users found for user_id: %s, sync_user_id: %s", - user_id, - sync_user_id, - ) - return [] - - def get_sync_user_by_state(self, state: dict) -> SyncsUser | None: - """ - Retrieve a sync user by their state. - - Args: - state (dict): The state to filter sync users. - - Returns: - dict or None: The sync user data matching the state or None if not found. - """ - logger.info("Getting sync user by state: %s", state) - - state_str = json.dumps(state) - response = ( - self.db.from_("syncs_user").select("*").eq("state", state_str).execute() - ) - if response.data and len(response.data) > 0: - logger.info("Sync user found by state: %s", response.data[0]) - sync_user = SyncsUser.model_validate(response.data[0]) - return sync_user - logger.error("No sync user found for state: %s", state) - return None - - def delete_sync_user(self, sync_id: int, user_id: UUID | str): - """ - Delete a sync user from the database. - - Args: - provider (str): The provider of the sync user. - user_id (str): The user ID of the sync user. - """ - logger.info( - "Deleting sync user with sync_id: %s, user_id: %s", sync_id, user_id - ) - self.db.from_("syncs_user").delete().eq("id", sync_id).eq( - "user_id", user_id - ).execute() - logger.info("Sync user deleted successfully") - - def update_sync_user( - self, sync_user_id: UUID, state: dict, sync_user_input: SyncUserUpdateInput - ): - """ - Update a sync user in the database. - - Args: - sync_user_id (str): The user ID of the sync user. - state (dict): The state to filter sync users. 
- sync_user_input (SyncUserUpdateInput): The input data for updating the sync user. - """ - logger.info( - "Updating sync user with user_id: %s, state: %s, input: %s", - sync_user_id, - state, - sync_user_input, - ) - - state_str = json.dumps(state) - self.db.from_("syncs_user").update(sync_user_input.model_dump()).eq( - "user_id", str(sync_user_id) - ).eq("state", state_str).execute() - logger.info("Sync user updated successfully") - - def get_all_notion_user_syncs(self): - """ - Retrieve all Notion sync users from the database. - - Returns: - list: A list of Notion sync users. - """ - logger.info("Retrieving all Notion sync users") - response = ( - self.db.from_("syncs_user").select("*").eq("provider", "Notion").execute() - ) - if response.data: - logger.info("Notion sync users retrieved successfully") - return response.data - logger.warning("No Notion sync users found") - return [] - - async def get_files_folder_user_sync( - self, - sync_active_id: int, - user_id: UUID, - notion_service: SyncNotionService | None, - folder_id: str | None = None, - recursive: bool = False, - ) -> None | dict[str, List[SyncFile]] | Literal["No sync found"]: - """ - Retrieve files from a user's sync folder, either from Google Drive or Azure. - - Args: - sync_active_id (int): The ID of the active sync. - user_id (str): The user ID associated with the active sync. - folder_id (str, optional): The folder ID to filter files. Defaults to None. - - Returns: - dict or str: A dictionary containing the list of files or a string indicating the sync provider. 
- """ - logger.info( - "Retrieving files for user sync with sync_active_id: %s, user_id: %s, folder_id: %s", - sync_active_id, - user_id, - folder_id, - ) - # Check whether the sync is Google or Azure - sync_user = self.get_syncs_user(user_id=user_id, sync_user_id=sync_active_id) - if not sync_user: - logger.warning( - "No sync user found for sync_active_id: %s, user_id: %s", - sync_active_id, - user_id, - ) - return None - - sync_user = sync_user[0] - sync: BaseSync - - provider = sync_user["provider"].lower() - if provider == "google": - logger.info("Getting files for Google sync") - sync = GoogleDriveSync() - return {"files": sync.get_files(sync_user["credentials"], folder_id)} - elif provider == "azure": - logger.info("Getting files for Azure sync") - sync = AzureDriveSync() - return { - "files": sync.get_files(sync_user["credentials"], folder_id, recursive) - } - elif provider == "dropbox": - logger.info("Getting files for Drop Box sync") - sync = DropboxSync() - return { - "files": sync.get_files( - sync_user["credentials"], folder_id if folder_id else "", recursive - ) - } - elif provider == "notion": - if notion_service is None: - raise ValueError("provider notion but notion_service is None") - logger.info("Getting files for Notion sync") - sync = NotionSync(notion_service=notion_service) - return { - "files": await sync.aget_files( - sync_user["credentials"], folder_id if folder_id else "", recursive - ) - } - elif provider == "github": - logger.info("Getting files for GitHub sync") - sync = GitHubSync() - return { - "files": sync.get_files( - sync_user["credentials"], folder_id if folder_id else "", recursive - ) - } - - else: - logger.warning( - "No sync found for provider: %s", sync_user["provider"], recursive - ) - return "No sync found" diff --git a/backend/api/quivr_api/modules/sync/service/sync_notion.py b/backend/api/quivr_api/modules/sync/service/sync_notion.py index 09ede38cda9b..fa613ed1f03e 100644 --- 
a/backend/api/quivr_api/modules/sync/service/sync_notion.py +++ b/backend/api/quivr_api/modules/sync/service/sync_notion.py @@ -8,7 +8,7 @@ from quivr_api.modules.dependencies import BaseService from quivr_api.modules.sync.entity.notion_page import NotionPage, NotionSearchResult from quivr_api.modules.sync.entity.sync_models import NotionSyncFile -from quivr_api.modules.sync.repository.sync_repository import NotionRepository +from quivr_api.modules.sync.repository.notion_repository import NotionRepository logger = get_logger(__name__) diff --git a/backend/api/quivr_api/modules/sync/service/sync_service.py b/backend/api/quivr_api/modules/sync/service/sync_service.py index 242eb4c36aca..6d5ce968ac09 100644 --- a/backend/api/quivr_api/modules/sync/service/sync_service.py +++ b/backend/api/quivr_api/modules/sync/service/sync_service.py @@ -1,86 +1,42 @@ -from abc import ABC, abstractmethod -from typing import Dict, List, Union from uuid import UUID from quivr_api.logger import get_logger +from quivr_api.modules.dependencies import BaseService from quivr_api.modules.sync.dto.inputs import ( - SyncsActiveInput, - SyncsActiveUpdateInput, - SyncsUserInput, - SyncUserUpdateInput, + SyncCreateInput, + SyncUpdateInput, ) -from quivr_api.modules.sync.entity.sync_models import SyncsActive, SyncsUser -from quivr_api.modules.sync.repository.sync_repository import Sync -from quivr_api.modules.sync.repository.sync_user import SyncUserRepository +from quivr_api.modules.sync.dto.outputs import SyncsOutput +from quivr_api.modules.sync.repository.sync_repository import SyncsRepository from quivr_api.modules.sync.service.sync_notion import SyncNotionService logger = get_logger(__name__) -class ISyncUserService(ABC): - @abstractmethod - def get_syncs_user(self, user_id: UUID, sync_user_id: Union[int, None] = None): - pass +class SyncsService(BaseService[SyncsRepository]): + repository_cls = SyncsRepository - @abstractmethod - def create_sync_user(self, sync_user_input: SyncsUserInput): - 
pass + def __init__(self, repository: SyncsRepository): + self.repository = repository - @abstractmethod - def delete_sync_user(self, sync_id: int, user_id: str): - pass + async def create_sync_user(self, sync_user_input: SyncCreateInput) -> SyncsOutput: + sync = await self.repository.create_sync(sync_user_input) + return sync.to_dto() - @abstractmethod - def get_sync_user_by_state(self, state: Dict) -> Union["SyncsUser", None]: - pass + def get_syncs(self, user_id: UUID, sync_id: int | None = None): + return self.repository.get_syncs(user_id, sync_id) - @abstractmethod - def get_sync_user_by_id(self, sync_id: int): - pass + def delete_sync(self, sync_id: int, user_id: str): + return self.repository.delete_sync(sync_id, user_id) - @abstractmethod - def update_sync_user( - self, sync_user_id: UUID, state: Dict, sync_user_input: SyncUserUpdateInput - ): - pass - - @abstractmethod - def get_all_notion_user_syncs(self): - pass - - @abstractmethod - async def get_files_folder_user_sync( - self, - sync_active_id: int, - user_id: UUID, - folder_id: Union[str, None] = None, - recursive: bool = False, - notion_service: Union["SyncNotionService", None] = None, - ): - pass - - -class SyncUserService(ISyncUserService): - def __init__(self): - self.repository = SyncUserRepository() - - def get_syncs_user(self, user_id: UUID, sync_user_id: int | None = None): - return self.repository.get_syncs_user(user_id, sync_user_id) - - def create_sync_user(self, sync_user_input: SyncsUserInput): - return self.repository.create_sync_user(sync_user_input) - - def delete_sync_user(self, sync_id: int, user_id: str): - return self.repository.delete_sync_user(sync_id, user_id) - - def get_sync_user_by_state(self, state: dict) -> SyncsUser | None: + def get_sync_by_state(self, state: dict) -> SyncsOutput | None: return self.repository.get_sync_user_by_state(state) - def get_sync_user_by_id(self, sync_id: int): - return self.repository.get_sync_user_by_id(sync_id) + def get_sync_by_id(self, 
sync_id: int): + return self.repository.get_sync_id(sync_id) - def update_sync_user( - self, sync_user_id: UUID, state: dict, sync_user_input: SyncUserUpdateInput + def update_sync( + self, sync_user_id: UUID, state: dict, sync_user_input: SyncUpdateInput ): return self.repository.update_sync_user(sync_user_id, state, sync_user_input) @@ -102,60 +58,3 @@ async def get_files_folder_user_sync( recursive=recursive, notion_service=notion_service, ) - - -class ISyncService(ABC): - @abstractmethod - def create_sync_active( - self, sync_active_input: SyncsActiveInput, user_id: str - ) -> Union["SyncsActive", None]: - pass - - @abstractmethod - def get_syncs_active(self, user_id: str) -> List[SyncsActive]: - pass - - @abstractmethod - def update_sync_active( - self, sync_id: int, sync_active_input: SyncsActiveUpdateInput - ): - pass - - @abstractmethod - def delete_sync_active(self, sync_active_id: int, user_id: UUID): - pass - - @abstractmethod - async def get_syncs_active_in_interval(self) -> List[SyncsActive]: - pass - - @abstractmethod - def get_details_sync_active(self, sync_active_id: int): - pass - - -class SyncService(ISyncService): - def __init__(self): - self.repository = Sync() - - def create_sync_active( - self, sync_active_input: SyncsActiveInput, user_id: str - ) -> SyncsActive | None: - return self.repository.create_sync_active(sync_active_input, user_id) - - def get_syncs_active(self, user_id: str) -> List[SyncsActive]: - return self.repository.get_syncs_active(user_id) - - def update_sync_active( - self, sync_id: int, sync_active_input: SyncsActiveUpdateInput - ): - return self.repository.update_sync_active(sync_id, sync_active_input) - - def delete_sync_active(self, sync_active_id: int, user_id: UUID): - return self.repository.delete_sync_active(sync_active_id, user_id) - - async def get_syncs_active_in_interval(self) -> List[SyncsActive]: - return await self.repository.get_syncs_active_in_interval() - - def get_details_sync_active(self, sync_active_id: 
int): - return self.repository.get_details_sync_active(sync_active_id) diff --git a/backend/api/quivr_api/modules/sync/tests/conftest.py b/backend/api/quivr_api/modules/sync/tests/conftest.py index 0955e2edbd8b..bd52fbc4d99b 100644 --- a/backend/api/quivr_api/modules/sync/tests/conftest.py +++ b/backend/api/quivr_api/modules/sync/tests/conftest.py @@ -30,12 +30,12 @@ NotificationService, ) from quivr_api.modules.sync.dto.inputs import ( + SyncCreateInput, SyncFileInput, SyncFileUpdateInput, SyncsActiveInput, SyncsActiveUpdateInput, - SyncsUserInput, - SyncUserUpdateInput, + SyncUpdateInput, ) from quivr_api.modules.sync.entity.notion_page import ( BlockParent, @@ -487,7 +487,7 @@ def get_syncs_user(self, user_id: UUID, sync_user_id: int | None = None): def get_sync_user_by_id(self, sync_id: int): return self.map_id[sync_id] - def create_sync_user(self, sync_user_input: SyncsUserInput): + def create_sync_user(self, sync_user_input: SyncCreateInput): id = len(self.map_userid) + 1 self.map_userid[sync_user_input.user_id] = SyncsUser( id=id, **sync_user_input.model_dump() @@ -503,7 +503,7 @@ def get_sync_user_by_state(self, state: dict) -> SyncsUser | None: return list(self.map_userid.values())[-1] def update_sync_user( - self, sync_user_id: UUID, state: dict, sync_user_input: SyncUserUpdateInput + self, sync_user_id: UUID, state: dict, sync_user_input: SyncUpdateInput ): return diff --git a/backend/api/quivr_api/modules/sync/tests/test_notion_service.py b/backend/api/quivr_api/modules/sync/tests/test_notion_service.py index 7f0a429eb9de..df1d82109c9d 100644 --- a/backend/api/quivr_api/modules/sync/tests/test_notion_service.py +++ b/backend/api/quivr_api/modules/sync/tests/test_notion_service.py @@ -7,7 +7,7 @@ from quivr_api.modules.brain.integrations.Notion.Notion_connector import NotionPage from quivr_api.modules.sync.entity.notion_page import NotionSearchResult -from quivr_api.modules.sync.repository.sync_repository import NotionRepository +from 
quivr_api.modules.sync.repository.notion_repository import NotionRepository from quivr_api.modules.sync.service.sync_notion import ( SyncNotionService, fetch_limit_notion_pages, diff --git a/backend/api/quivr_api/modules/sync/utils/sync_exceptions.py b/backend/api/quivr_api/modules/sync/utils/sync_exceptions.py new file mode 100644 index 000000000000..e7b8cf63e5f5 --- /dev/null +++ b/backend/api/quivr_api/modules/sync/utils/sync_exceptions.py @@ -0,0 +1,31 @@ +class SyncException(Exception): + def __init__(self, message="A sync-related error occurred"): + self.message = message + super().__init__(self.message) + + +class SyncCreationError(SyncException): + def __init__(self, message="An error occurred while creating"): + super().__init__(message) + + +class SyncUpdateError(SyncException): + def __init__(self, message="An error occurred while updating"): + super().__init__(message) + + +class SyncDeleteError(SyncException): + def __init__(self, message="An error occurred while deleting"): + super().__init__(message) + + +class SyncEmptyCredentials(SyncException): + def __init__( + self, message="You do not have credentials to access files from this sync." 
+ ): + super().__init__(message) + + +class SyncNotFoundException(SyncException): + def __init__(self, message="The requested sync was not found"): + super().__init__(message) diff --git a/backend/api/quivr_api/modules/sync/utils/syncutils.py b/backend/api/quivr_api/modules/sync/utils/syncutils.py index 5f6a63628d4a..2a2e5563ef94 100644 --- a/backend/api/quivr_api/modules/sync/utils/syncutils.py +++ b/backend/api/quivr_api/modules/sync/utils/syncutils.py @@ -346,16 +346,16 @@ async def get_syncfiles_from_ids( async def direct_sync( self, sync_active: SyncsActive, - user_sync: SyncsUser, + sync_user: SyncsUser, files_ids: list[str], folder_ids: list[str], ): files = await self.get_syncfiles_from_ids( - user_sync.credentials, files_ids, folder_ids + sync_user.credentials, files_ids, folder_ids ) processed_files = await self.process_sync_files( files=files, - current_user=user_sync, + current_user=sync_user, sync_active=sync_active, ) diff --git a/backend/worker/quivr_worker/celery_worker.py b/backend/worker/quivr_worker/celery_worker.py index ceb1632c8ce4..d1bd6fb6e67a 100644 --- a/backend/worker/quivr_worker/celery_worker.py +++ b/backend/worker/quivr_worker/celery_worker.py @@ -20,7 +20,7 @@ ) from quivr_api.modules.sync.repository.sync_files import SyncFilesRepository from quivr_api.modules.sync.service.sync_notion import SyncNotionService -from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserService +from quivr_api.modules.sync.service.sync_service import SyncsService from quivr_api.modules.vector.repository.vectors_repository import VectorRepository from quivr_api.modules.vector.service.vector_service import VectorService from quivr_api.utils.telemetry import maybe_send_telemetry @@ -53,8 +53,8 @@ supabase_client = get_supabase_client() # document_vector_store = get_documents_vector_store() notification_service = NotificationService() -sync_active_service = SyncService() -sync_user_service = SyncUserService() +sync_active_service = 
SyncsService() +sync_user_service = SyncsService() sync_files_repo_service = SyncFilesRepository() brain_service = BrainService() brain_vectors = BrainsVectors() diff --git a/backend/worker/quivr_worker/syncs/process_active_syncs.py b/backend/worker/quivr_worker/syncs/process_active_syncs.py index d190c219166b..92733e466d63 100644 --- a/backend/worker/quivr_worker/syncs/process_active_syncs.py +++ b/backend/worker/quivr_worker/syncs/process_active_syncs.py @@ -7,14 +7,13 @@ from quivr_api.modules.notification.service.notification_service import ( NotificationService, ) -from quivr_api.modules.sync.entity.sync_models import SyncsActive -from quivr_api.modules.sync.repository.sync_repository import NotionRepository +from quivr_api.modules.sync.repository.notion_repository import NotionRepository from quivr_api.modules.sync.service.sync_notion import ( SyncNotionService, fetch_limit_notion_pages, update_notion_pages, ) -from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserService +from quivr_api.modules.sync.service.sync_service import SyncsService from quivr_api.modules.sync.utils.syncutils import SyncUtils from sqlalchemy.ext.asyncio import AsyncEngine from sqlmodel import text @@ -35,9 +34,7 @@ async def process_sync( ): async with build_syncs_utils(services) as mapping_syncs_utils: try: - user_sync = services.sync_user_service.get_sync_user_by_id( - sync.syncs_user_id - ) + user_sync = services.sync_user_service.get_sync_by_id(sync.syncs_user_id) services.notification_service.remove_notification_by_id( sync.notification_id ) @@ -45,7 +42,7 @@ async def process_sync( sync_util = mapping_syncs_utils[user_sync.provider.lower()] await sync_util.direct_sync( sync_active=sync, - user_sync=user_sync, + sync_user=user_sync, files_ids=files_ids, folder_ids=folder_ids, ) @@ -69,8 +66,8 @@ async def process_all_active_syncs(sync_services: SyncServices): async def _process_all_active_syncs( - sync_active_service: SyncService, - sync_user_service: 
SyncUserService, + sync_active_service: SyncsService, + sync_user_service: SyncsService, mapping_syncs_utils: dict[str, SyncUtils], notification_service: NotificationService, ): @@ -78,7 +75,7 @@ async def _process_all_active_syncs( logger.debug(f"Found active syncs: {active_syncs}") for sync in active_syncs: try: - user_sync = sync_user_service.get_sync_user_by_id(sync.syncs_user_id) + user_sync = sync_user_service.get_sync_by_id(sync.syncs_user_id) # TODO: this should be global # NOTE: Remove the global notification notification_service.remove_notification_by_id(sync.notification_id) @@ -104,7 +101,7 @@ async def process_notion_sync( await session.execute( text("SET SESSION idle_in_transaction_session_timeout = '5min';") ) - sync_user_service = SyncUserService() + sync_user_service = SyncsService() notion_repository = NotionRepository(session) notion_service = SyncNotionService(notion_repository) diff --git a/backend/worker/quivr_worker/syncs/store_notion.py b/backend/worker/quivr_worker/syncs/store_notion.py index 821de8874096..7349eef3cfd3 100644 --- a/backend/worker/quivr_worker/syncs/store_notion.py +++ b/backend/worker/quivr_worker/syncs/store_notion.py @@ -3,7 +3,7 @@ from notion_client import Client from quivr_api.logger import get_logger -from quivr_api.modules.sync.repository.sync_repository import NotionRepository +from quivr_api.modules.sync.repository.notion_repository import NotionRepository from quivr_api.modules.sync.service.sync_notion import ( SyncNotionService, fetch_limit_notion_pages, diff --git a/backend/worker/quivr_worker/syncs/utils.py b/backend/worker/quivr_worker/syncs/utils.py index bbc3c75f8588..da8a6af83b3b 100644 --- a/backend/worker/quivr_worker/syncs/utils.py +++ b/backend/worker/quivr_worker/syncs/utils.py @@ -11,12 +11,12 @@ from quivr_api.modules.notification.service.notification_service import ( NotificationService, ) +from quivr_api.modules.sync.repository.notion_repository import NotionRepository from 
quivr_api.modules.sync.repository.sync_files import SyncFilesRepository -from quivr_api.modules.sync.repository.sync_repository import NotionRepository from quivr_api.modules.sync.service.sync_notion import ( SyncNotionService, ) -from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserService +from quivr_api.modules.sync.service.sync_service import SyncsService from quivr_api.modules.sync.utils.sync import ( AzureDriveSync, DropboxSync, @@ -37,8 +37,8 @@ @dataclass class SyncServices: async_engine: AsyncEngine - sync_active_service: SyncService - sync_user_service: SyncUserService + sync_active_service: SyncsService + sync_user_service: SyncsService sync_files_repo_service: SyncFilesRepository notification_service: NotificationService brain_vectors: BrainsVectors From 3daef917262b1a89b02a8d38c87eef7619f62464 Mon Sep 17 00:00:00 2001 From: aminediro Date: Tue, 17 Sep 2024 18:34:04 +0200 Subject: [PATCH 02/63] sync repository --- .../sync/repository/sync_repository.py | 51 +++++++------------ .../modules/sync/service/sync_service.py | 2 +- .../modules/sync/utils/sync_exceptions.py | 5 ++ 3 files changed, 23 insertions(+), 35 deletions(-) diff --git a/backend/api/quivr_api/modules/sync/repository/sync_repository.py b/backend/api/quivr_api/modules/sync/repository/sync_repository.py index 6282f0ffcbe7..058c9761b591 100644 --- a/backend/api/quivr_api/modules/sync/repository/sync_repository.py +++ b/backend/api/quivr_api/modules/sync/repository/sync_repository.py @@ -1,9 +1,8 @@ -import json from sqlite3 import IntegrityError -from typing import List +from typing import Any, List from uuid import UUID -from sqlmodel import select +from sqlmodel import delete, select from sqlmodel.ext.asyncio.session import AsyncSession from quivr_api.logger import get_logger @@ -22,7 +21,10 @@ GoogleDriveSync, NotionSync, ) -from quivr_api.modules.sync.utils.sync_exceptions import SyncEmptyCredentials +from quivr_api.modules.sync.utils.sync_exceptions import ( + 
SyncEmptyCredentials, + SyncProviderError, +) logger = get_logger(__name__) @@ -129,44 +131,22 @@ async def get_sync_user_by_state(self, state: dict) -> Syncs | None: return None - def delete_sync(self, sync_id: int, user_id: UUID | str): - """ - Delete a sync user from the database. - - Args: - provider (str): The provider of the sync user. - user_id (str): The user ID of the sync user. - """ + async def delete_sync(self, sync_id: int, user_id: UUID): logger.info( "Deleting sync user with sync_id: %s, user_id: %s", sync_id, user_id ) - self.db.from_("syncs_user").delete().eq("id", sync_id).eq( - "user_id", user_id - ).execute() + await self.session.execute( + delete(Syncs).where(Syncs.id == sync_id).where(Syncs.user_id == user_id) + ) logger.info("Sync user deleted successfully") - def update_sync_user( - self, sync_user_id: UUID, state: dict, sync_user_input: SyncUpdateInput - ): - """ - Update a sync user in the database. - - Args: - sync_user_id (str): The user ID of the sync user. - state (dict): The state to filter sync users. - sync_user_input (SyncUserUpdateInput): The input data for updating the sync user. 
- """ + def update_sync(self, sync_id: UUID, sync_input: SyncUpdateInput | dict[str, Any]): logger.info( "Updating sync user with user_id: %s, state: %s, input: %s", - sync_user_id, - state, - sync_user_input, + sync_id, + sync_input, ) - state_str = json.dumps(state) - self.db.from_("syncs_user").update(sync_user_input.model_dump()).eq( - "user_id", str(sync_user_id) - ).eq("state", state_str).execute() logger.info("Sync user updated successfully") def get_all_notion_user_syncs(self): @@ -208,7 +188,10 @@ async def get_files_folder_user_sync( return None provider = sync_user.provider.lower() - sync_provider = self.sync_provider_mapping[SyncProvider(provider)] + try: + sync_provider = self.sync_provider_mapping[SyncProvider(provider)] + except KeyError: + raise SyncProviderError if sync_user.credentials is None: raise SyncEmptyCredentials diff --git a/backend/api/quivr_api/modules/sync/service/sync_service.py b/backend/api/quivr_api/modules/sync/service/sync_service.py index 6d5ce968ac09..dcc3e3aeb24c 100644 --- a/backend/api/quivr_api/modules/sync/service/sync_service.py +++ b/backend/api/quivr_api/modules/sync/service/sync_service.py @@ -38,7 +38,7 @@ def get_sync_by_id(self, sync_id: int): def update_sync( self, sync_user_id: UUID, state: dict, sync_user_input: SyncUpdateInput ): - return self.repository.update_sync_user(sync_user_id, state, sync_user_input) + return self.repository.update_sync(sync_user_id, state, sync_user_input) def get_all_notion_user_syncs(self): return self.repository.get_all_notion_user_syncs() diff --git a/backend/api/quivr_api/modules/sync/utils/sync_exceptions.py b/backend/api/quivr_api/modules/sync/utils/sync_exceptions.py index e7b8cf63e5f5..5d2ad17fcf31 100644 --- a/backend/api/quivr_api/modules/sync/utils/sync_exceptions.py +++ b/backend/api/quivr_api/modules/sync/utils/sync_exceptions.py @@ -29,3 +29,8 @@ def __init__( class SyncNotFoundException(SyncException): def __init__(self, message="The requested sync was not found"): 
super().__init__(message) + + +class SyncProviderError(SyncException): + def __init__(self, message="Unknown provider"): + super().__init__(message) From 5803fb9ecf43d2ff1c9cc8a6a4d3eac767fe12b4 Mon Sep 17 00:00:00 2001 From: aminediro Date: Tue, 17 Sep 2024 19:17:21 +0200 Subject: [PATCH 03/63] sync service --- .../api/quivr_api/modules/sync/dto/outputs.py | 9 +++--- .../modules/sync/entity/sync_models.py | 9 ++++-- .../sync/repository/sync_repository.py | 31 ++++++++++++++----- .../modules/sync/service/sync_service.py | 28 ++++++++--------- .../modules/sync/tests/test_sync_service.py | 0 5 files changed, 50 insertions(+), 27 deletions(-) create mode 100644 backend/api/quivr_api/modules/sync/tests/test_sync_service.py diff --git a/backend/api/quivr_api/modules/sync/dto/outputs.py b/backend/api/quivr_api/modules/sync/dto/outputs.py index 4702ae541753..6330abacba2d 100644 --- a/backend/api/quivr_api/modules/sync/dto/outputs.py +++ b/backend/api/quivr_api/modules/sync/dto/outputs.py @@ -1,4 +1,5 @@ from enum import Enum +from uuid import UUID from pydantic import BaseModel @@ -22,7 +23,7 @@ class SyncsDescription(BaseModel): class SyncsOutput(BaseModel): - user_id: str - provider: str - state: dict - credentials: dict + user_id: UUID + provider: SyncProvider + state: dict | None + credentials: dict | None diff --git a/backend/api/quivr_api/modules/sync/entity/sync_models.py b/backend/api/quivr_api/modules/sync/entity/sync_models.py index f1f7cbfd8200..e8a01eaf37ba 100644 --- a/backend/api/quivr_api/modules/sync/entity/sync_models.py +++ b/backend/api/quivr_api/modules/sync/entity/sync_models.py @@ -17,7 +17,7 @@ ) from sqlmodel import UUID as PGUUID -from quivr_api.modules.sync.dto.outputs import SyncsOutput +from quivr_api.modules.sync.dto.outputs import SyncProvider, SyncsOutput from quivr_api.modules.user.entity.user_identity import User @@ -63,7 +63,12 @@ class Syncs(SQLModel, table=True): additional_data: dict def to_dto(self) -> SyncsOutput: - return 
SyncsOutput(user_id=self.user_id, provider=self.provider) + return SyncsOutput( + user_id=self.user_id, + provider=SyncProvider(self.provider), + credentials=self.credentials, + state=self.state, + ) class NotionSyncFile(SQLModel, table=True): diff --git a/backend/api/quivr_api/modules/sync/repository/sync_repository.py b/backend/api/quivr_api/modules/sync/repository/sync_repository.py index 058c9761b591..d81894bb2a67 100644 --- a/backend/api/quivr_api/modules/sync/repository/sync_repository.py +++ b/backend/api/quivr_api/modules/sync/repository/sync_repository.py @@ -12,7 +12,6 @@ from quivr_api.modules.sync.entity.sync_models import SyncFile, Syncs from quivr_api.modules.sync.repository.notion_repository import NotionRepository from quivr_api.modules.sync.service.sync_notion import SyncNotionService -from quivr_api.modules.sync.utils.exceptions import SyncNotFoundException from quivr_api.modules.sync.utils.sync import ( AzureDriveSync, BaseSync, @@ -23,7 +22,9 @@ ) from quivr_api.modules.sync.utils.sync_exceptions import ( SyncEmptyCredentials, + SyncNotFoundException, SyncProviderError, + SyncUpdateError, ) logger = get_logger(__name__) @@ -70,7 +71,7 @@ async def create_sync( await self.session.rollback() raise - async def get_sync_id(self, sync_id: int) -> Syncs | None: + async def get_sync_id(self, sync_id: int) -> Syncs: """ Retrieve sync users from the database. """ @@ -110,7 +111,7 @@ async def get_syncs(self, user_id: UUID, sync_id: int | None = None): raise SyncNotFoundException() return sync - async def get_sync_user_by_state(self, state: dict) -> Syncs | None: + async def get_sync_user_by_state(self, state: dict) -> Syncs: """ Retrieve a sync user by their state. 
@@ -140,14 +141,30 @@ async def delete_sync(self, sync_id: int, user_id: UUID): ) logger.info("Sync user deleted successfully") - def update_sync(self, sync_id: UUID, sync_input: SyncUpdateInput | dict[str, Any]): - logger.info( + async def update_sync( + self, sync: Syncs, sync_input: SyncUpdateInput | dict[str, Any] + ): + logger.debug( "Updating sync user with user_id: %s, state: %s, input: %s", - sync_id, + sync.id, sync_input, ) + try: + if isinstance(sync_input, dict): + update_data = sync_input + else: + update_data = sync_input.model_dump(exclude_unset=True) + for field in update_data: + setattr(sync, field, update_data[field]) - logger.info("Sync user updated successfully") + self.session.add(sync) + await self.session.commit() + await self.session.refresh(sync) + return sync + except IntegrityError as e: + await self.session.rollback() + logger.error(f"Error updating knowledge {e}") + raise SyncUpdateError def get_all_notion_user_syncs(self): """ diff --git a/backend/api/quivr_api/modules/sync/service/sync_service.py b/backend/api/quivr_api/modules/sync/service/sync_service.py index dcc3e3aeb24c..62324d652979 100644 --- a/backend/api/quivr_api/modules/sync/service/sync_service.py +++ b/backend/api/quivr_api/modules/sync/service/sync_service.py @@ -8,7 +8,6 @@ ) from quivr_api.modules.sync.dto.outputs import SyncsOutput from quivr_api.modules.sync.repository.sync_repository import SyncsRepository -from quivr_api.modules.sync.service.sync_notion import SyncNotionService logger = get_logger(__name__) @@ -23,24 +22,27 @@ async def create_sync_user(self, sync_user_input: SyncCreateInput) -> SyncsOutpu sync = await self.repository.create_sync(sync_user_input) return sync.to_dto() - def get_syncs(self, user_id: UUID, sync_id: int | None = None): + async def get_syncs(self, user_id: UUID, sync_id: int | None = None): return self.repository.get_syncs(user_id, sync_id) - def delete_sync(self, sync_id: int, user_id: str): - return 
self.repository.delete_sync(sync_id, user_id) + async def delete_sync(self, sync_id: int, user_id: UUID): + await self.repository.delete_sync(sync_id, user_id) - def get_sync_by_state(self, state: dict) -> SyncsOutput | None: - return self.repository.get_sync_user_by_state(state) + async def get_sync_by_state(self, state: dict) -> SyncsOutput: + sync = await self.repository.get_sync_user_by_state(state) + return sync.to_dto() - def get_sync_by_id(self, sync_id: int): + async def get_sync_by_id(self, sync_id: int): return self.repository.get_sync_id(sync_id) - def update_sync( - self, sync_user_id: UUID, state: dict, sync_user_input: SyncUpdateInput - ): - return self.repository.update_sync(sync_user_id, state, sync_user_input) + async def update_sync( + self, sync_id: int, sync_user_input: SyncUpdateInput + ) -> SyncsOutput: + sync = await self.repository.get_sync_id(sync_id) + sync = await self.repository.update_sync(sync, sync_user_input) + return sync.to_dto() - def get_all_notion_user_syncs(self): + async def get_all_notion_user_syncs(self): return self.repository.get_all_notion_user_syncs() async def get_files_folder_user_sync( @@ -49,12 +51,10 @@ async def get_files_folder_user_sync( user_id: UUID, folder_id: str | None = None, recursive: bool = False, - notion_service: SyncNotionService | None = None, ): return await self.repository.get_files_folder_user_sync( sync_active_id=sync_active_id, user_id=user_id, folder_id=folder_id, recursive=recursive, - notion_service=notion_service, ) diff --git a/backend/api/quivr_api/modules/sync/tests/test_sync_service.py b/backend/api/quivr_api/modules/sync/tests/test_sync_service.py new file mode 100644 index 000000000000..e69de29bb2d1 From 13b288fa955a0f8873aa8be910a1e1f5d713630b Mon Sep 17 00:00:00 2001 From: aminediro Date: Wed, 18 Sep 2024 13:48:30 +0200 Subject: [PATCH 04/63] oauth2 state --- .../knowledge/service/knowledge_service.py | 83 +-- .../knowledge/tests/test_knowledge_entity.py | 2 - 
.../sync/controller/azure_sync_routes.py | 99 ++- .../sync/controller/dropbox_sync_routes.py | 80 ++- .../sync/controller/github_sync_routes.py | 81 ++- .../sync/controller/google_sync_routes.py | 88 ++- .../sync/controller/notion_sync_routes.py | 82 ++- .../sync/controller/successfull_connection.py | 2 +- .../modules/sync/controller/sync_routes.py | 266 +------ .../api/quivr_api/modules/sync/dto/inputs.py | 7 +- .../api/quivr_api/modules/sync/dto/outputs.py | 1 + .../modules/sync/entity/notion_page.py | 1 + .../modules/sync/entity/sync_models.py | 13 +- .../modules/sync/service/sync_service.py | 8 +- .../quivr_api/modules/sync/tests/conftest.py | 667 +++++++++-------- .../modules/sync/tests/test_sync_service.py | 40 ++ .../modules/sync/tests/test_syncutils.py | 671 +++++++++--------- .../quivr_api/modules/sync/utils/oauth2.py | 9 + .../quivr_api/modules/sync/utils/syncutils.py | 9 - backend/worker/tests/test_sync.py | 0 20 files changed, 1046 insertions(+), 1163 deletions(-) create mode 100644 backend/api/quivr_api/modules/sync/utils/oauth2.py create mode 100644 backend/worker/tests/test_sync.py diff --git a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py index 310a68e25211..4c7b48571026 100644 --- a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py @@ -28,11 +28,6 @@ KnowledgeForbiddenAccess, UploadError, ) -from quivr_api.modules.sync.entity.sync_models import ( - DBSyncFile, - DownloadedSyncFile, - SyncFile, -) from quivr_api.modules.upload.service.upload_file import check_file_exists logger = get_logger(__name__) @@ -294,43 +289,43 @@ async def remove_all_knowledges_from_brain(self, brain_id: UUID) -> None: # TODO: REDO THIS MESS !!!! 
# REMOVE ALL SYNC TABLES and start from scratch - async def update_or_create_knowledge_sync( - self, - brain_id: UUID, - user_id: UUID, - file: SyncFile, - new_sync_file: DBSyncFile | None, - prev_sync_file: DBSyncFile | None, - downloaded_file: DownloadedSyncFile, - source: str, - source_link: str, - ) -> Knowledge: - sync_id = None - # TODO: THIS IS A HACK!! Remove all of this - if prev_sync_file: - prev_knowledge = await self.get_knowledge_sync(sync_id=prev_sync_file.id) - if len(prev_knowledge.brains) > 1: - await self.repository.remove_knowledge_from_brain( - prev_knowledge.id, brain_id - ) - else: - await self.repository.remove_knowledge_by_id(prev_knowledge.id) - sync_id = prev_sync_file.id + # async def update_or_create_knowledge_sync( + # self, + # brain_id: UUID, + # user_id: UUID, + # file: SyncFile, + # new_sync_file: DBSyncFile | None, + # prev_sync_file: DBSyncFile | None, + # downloaded_file: DownloadedSyncFile, + # source: str, + # source_link: str, + # ) -> Knowledge: + # sync_id = None + # # TODO: THIS IS A HACK!! 
Remove all of this + # if prev_sync_file: + # prev_knowledge = await self.get_knowledge_sync(sync_id=prev_sync_file.id) + # if len(prev_knowledge.brains) > 1: + # await self.repository.remove_knowledge_from_brain( + # prev_knowledge.id, brain_id + # ) + # else: + # await self.repository.remove_knowledge_by_id(prev_knowledge.id) + # sync_id = prev_sync_file.id - sync_id = new_sync_file.id if new_sync_file else sync_id - knowledge_to_add = CreateKnowledgeProperties( - brain_id=brain_id, - file_name=file.name, - extension=downloaded_file.extension, - source=source, - status=KnowledgeStatus.PROCESSING, - source_link=source_link, - file_size=file.size if file.size else 0, - # FIXME (@aminediro): This is a temporary fix, redo in KMS - file_sha1=None, - metadata={"sync_file_id": str(sync_id)}, - ) - added_knowledge = await self.insert_knowledge_brain( - knowledge_to_add=knowledge_to_add, user_id=user_id - ) - return added_knowledge + # sync_id = new_sync_file.id if new_sync_file else sync_id + # knowledge_to_add = CreateKnowledgeProperties( + # brain_id=brain_id, + # file_name=file.name, + # extension=downloaded_file.extension, + # source=source, + # status=KnowledgeStatus.PROCESSING, + # source_link=source_link, + # file_size=file.size if file.size else 0, + # # FIXME (@aminediro): This is a temporary fix, redo in KMS + # file_sha1=None, + # metadata={"sync_file_id": str(sync_id)}, + # ) + # added_knowledge = await self.insert_knowledge_brain( + # knowledge_to_add=knowledge_to_add, user_id=user_id + # ) + # return added_knowledge diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py index b9732aa1fba2..7376559ebc39 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py @@ -209,8 +209,6 @@ async def test_knowledge_dto(session, user, brain): km_dto = await km.to_dto() 
- breakpoint() - assert km_dto.file_name == km.file_name assert km_dto.url == km.url assert km_dto.extension == km.extension diff --git a/backend/api/quivr_api/modules/sync/controller/azure_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/azure_sync_routes.py index 4ec6ce4f137a..9a4585c45ce6 100644 --- a/backend/api/quivr_api/modules/sync/controller/azure_sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/azure_sync_routes.py @@ -1,14 +1,17 @@ import os import requests -from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.responses import HTMLResponse from msal import ConfidentialClientApplication from quivr_api.logger import get_logger from quivr_api.middlewares.auth import AuthBearer, get_current_user +from quivr_api.modules.dependencies import get_service from quivr_api.modules.sync.dto.inputs import SyncCreateInput, SyncUpdateInput from quivr_api.modules.sync.service.sync_service import SyncsService +from quivr_api.modules.sync.utils.oauth2 import Oauth2State +from quivr_api.modules.sync.utils.sync_exceptions import SyncNotFoundException from quivr_api.modules.user.entity.user_identity import UserIdentity from .successfull_connection import successfullConnectionPage @@ -16,9 +19,8 @@ # Initialize logger logger = get_logger(__name__) -# Initialize sync service -sync_service = SyncsService() -sync_user_service = SyncsService() + +syncs_service_dep = get_service(SyncsService) # Initialize API router azure_sync_router = APIRouter() @@ -41,8 +43,11 @@ dependencies=[Depends(AuthBearer())], tags=["Sync"], ) -def authorize_azure( - request: Request, name: str, current_user: UserIdentity = Depends(get_current_user) +async def authorize_azure( + request: Request, + name: str, + syncs_service: SyncsService = Depends(syncs_service_dep), + current_user: UserIdentity = Depends(get_current_user), ): """ Authorize Azure sync for the current user. 
@@ -58,25 +63,38 @@ def authorize_azure( CLIENT_ID, client_credential=CLIENT_SECRET, authority=AUTHORITY ) logger.debug(f"Authorizing Azure sync for user: {current_user.id}") - state = f"user_id={current_user.id}, name={name}" - flow = client.initiate_auth_code_flow( - scopes=SCOPE, redirect_uri=REDIRECT_URI, state=state, prompt="select_account" - ) + state_struct = Oauth2State(name=name, user_id=current_user.id) + state = state_struct.model_dump_json() sync_user_input = SyncCreateInput( - user_id=str(current_user.id), + user_id=current_user.id, name=name, provider="Azure", credentials={}, state={"state": state}, - additional_data={"flow": flow}, ) - sync_user_service.create_sync_user(sync_user_input) + sync = await syncs_service.create_sync_user(sync_user_input) + state_struct.sync_id = sync.id + state = state_struct.model_dump_json() + + flow = client.initiate_auth_code_flow( + scopes=SCOPE, redirect_uri=REDIRECT_URI, state=state, prompt="select_account" + ) + + sync = await syncs_service.update_sync( + sync_id=sync.id, + sync_user_input=SyncUpdateInput( + **{**sync.model_dump(), "additional_data": {"flow": flow}} + ), + ) return {"authorization_url": flow["auth_uri"]} @azure_sync_router.get("/sync/azure/oauth2callback", tags=["Sync"]) -def oauth2callback_azure(request: Request): +async def oauth2callback_azure( + request: Request, + syncs_service: SyncsService = Depends(syncs_service_dep), +): """ Handle OAuth2 callback from Azure. 
@@ -90,28 +108,45 @@ def oauth2callback_azure(request: Request): CLIENT_ID, client_credential=CLIENT_SECRET, authority=AUTHORITY ) state = request.query_params.get("state") - state_split = state.split(",") - current_user = state_split[0].split("=")[1] # Extract user_id from state - name = state_split[1].split("=")[1] if state else None - state_dict = {"state": state} + + if not state: + raise HTTPException(status_code=400, detail="Invalid state parameter") + + state = Oauth2State.model_validate_json(state) + + if state.sync_id is None: + raise HTTPException( + status_code=400, detail="Invalid state parameter. Unknown sync" + ) + logger.debug( - f"Handling OAuth2 callback for user: {current_user} with state: {state}" + f"Handling OAuth2 callback for user: {state.user_id} with state: {state}" ) - sync_user_state = sync_user_service.get_sync_by_state(state_dict) - logger.info(f"Retrieved sync user state: {sync_user_state}") - if not sync_user_state or state_dict != sync_user_state.state: + try: + sync = await syncs_service.get_sync_by_id(state.sync_id) + except SyncNotFoundException as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=f"{e.message}" + ) + if ( + not sync + or not sync.state + or state.model_dump(exclude={"sync_id"}) != sync.state["state"] + ): logger.error("Invalid state parameter") raise HTTPException(status_code=400, detail="Invalid state parameter") - if str(sync_user_state.user_id) != current_user: - logger.info(f"Sync user state: {sync_user_state}") - logger.info(f"Current user: {current_user}") - logger.info(f"Sync user state user_id: {sync_user_state.user_id}") - logger.error("Invalid user") + + if sync.user_id != state.user_id: raise HTTPException(status_code=400, detail="Invalid user") + if sync.additional_data is None: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Invalid sync data" + ) + result = client.acquire_token_by_auth_code_flow( - sync_user_state.additional_data["flow"], 
dict(request.query_params) + sync.additional_data["flow"], dict(request.query_params) ) if "access_token" not in result: logger.error(f"Failed to acquire token: {result}") @@ -123,7 +158,7 @@ def oauth2callback_azure(request: Request): access_token = result["access_token"] creds = result - logger.info(f"Fetched OAuth2 token for user: {current_user}") + logger.info(f"Fetched OAuth2 token for user: {state.user_id}") # Fetch user email from Microsoft Graph API graph_url = "https://graph.microsoft.com/v1.0/me" @@ -135,10 +170,10 @@ def oauth2callback_azure(request: Request): user_info = response.json() user_email = user_info.get("mail") or user_info.get("userPrincipalName") - logger.info(f"Retrieved email for user: {current_user} - {user_email}") + logger.info(f"Retrieved email for user: {state.user_id} - {user_email}") sync_user_input = SyncUpdateInput(credentials=result, state={}, email=user_email) - sync_user_service.update_sync(current_user, state_dict, sync_user_input) - logger.info(f"Azure sync created successfully for user: {current_user}") + await syncs_service.update_sync(state.sync_id, sync_user_input) + logger.info(f"Azure sync created successfully for user: {state.user_id}") return HTMLResponse(successfullConnectionPage) diff --git a/backend/api/quivr_api/modules/sync/controller/dropbox_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/dropbox_sync_routes.py index 5bd480864005..f461dde461eb 100644 --- a/backend/api/quivr_api/modules/sync/controller/dropbox_sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/dropbox_sync_routes.py @@ -1,14 +1,16 @@ import os -from uuid import UUID from dropbox import Dropbox, DropboxOAuth2Flow -from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.responses import HTMLResponse from quivr_api.logger import get_logger from quivr_api.middlewares.auth import AuthBearer, get_current_user +from 
quivr_api.modules.dependencies import get_service from quivr_api.modules.sync.dto.inputs import SyncCreateInput, SyncUpdateInput from quivr_api.modules.sync.service.sync_service import SyncsService +from quivr_api.modules.sync.utils.oauth2 import Oauth2State +from quivr_api.modules.sync.utils.sync_exceptions import SyncNotFoundException from quivr_api.modules.user.entity.user_identity import UserIdentity from .successfull_connection import successfullConnectionPage @@ -20,8 +22,7 @@ SCOPE = ["files.metadata.read", "account_info.read", "files.content.read"] # Initialize sync service -sync_service = SyncsService() -sync_user_service = SyncsService() +syncs_service_dep = get_service(SyncsService) logger = get_logger(__name__) @@ -34,8 +35,11 @@ dependencies=[Depends(AuthBearer())], tags=["Sync"], ) -def authorize_dropbox( - request: Request, name: str, current_user: UserIdentity = Depends(get_current_user) +async def authorize_dropbox( + request: Request, + name: str, + current_user: UserIdentity = Depends(get_current_user), + syncs_service: SyncsService = Depends(syncs_service_dep), ): """ Authorize DropBox sync for the current user. 
@@ -59,26 +63,30 @@ def authorize_dropbox( token_access_type="offline", scope=SCOPE, ) - state: str = f"user_id={current_user.id}, name={name}" - authorize_url = auth_flow.start(state) - - logger.info( - f"Generated authorization URL: {authorize_url} for user: {current_user.id}" - ) + state_struct = Oauth2State(name=name, user_id=current_user.id) sync_user_input = SyncCreateInput( name=name, - user_id=str(current_user.id), + user_id=current_user.id, provider="DropBox", credentials={}, - state={"state": state}, + state={"state": state_struct.model_dump_json()}, additional_data={}, ) - sync_user_service.create_sync_user(sync_user_input) + sync = await syncs_service.create_sync_user(sync_user_input) + state_struct.sync_id = sync.id + state = state_struct.model_dump_json() + authorize_url = auth_flow.start(state) + logger.info( + f"Generated authorization URL: {authorize_url} for user: {current_user.id}" + ) return {"authorization_url": authorize_url} @dropbox_sync_router.get("/sync/dropbox/oauth2callback", tags=["Sync"]) -def oauth2callback_dropbox(request: Request): +async def oauth2callback_dropbox( + request: Request, + syncs_service: SyncsService = Depends(syncs_service_dep), +): """ Handle OAuth2 callback from DropBox. @@ -97,24 +105,32 @@ def oauth2callback_dropbox(request: Request): logger.debug("Keys in session : %s", session.keys()) logger.debug("Value in session : %s", session.values()) - state = state.split("|")[1] if "|" in state else state # type: ignore - state_dict = {"state": state} - state_split = state.split(",") # type: ignore - current_user = UUID(state_split[0].split("=")[1]) if state else None + state = Oauth2State.model_validate_json(state) + + if state.sync_id is None: + raise HTTPException( + status_code=400, detail="Invalid state parameter. 
Unknown sync" + ) + logger.debug( - f"Handling OAuth2 callback for user: {current_user} with state: {state} and state_dict: {state_dict}" + f"Handling OAuth2 callback for user: {state.user_id} with state: {state} " ) - sync_user_state = sync_user_service.get_sync_by_state(state_dict) + try: + sync = await syncs_service.get_sync_by_id(state.sync_id) + except SyncNotFoundException as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=f"{e.message}" + ) - if not sync_user_state or state_dict != sync_user_state.state: + if ( + not sync + or not sync.state + or state.model_dump(exclude={"sync_id"}) != sync.state["state"] + ): logger.error("Invalid state parameter") raise HTTPException(status_code=400, detail="Invalid state parameter") - else: - logger.info( - f"CURRENT USER: {current_user}, SYNC USER STATE USER: {sync_user_state.user_id}" - ) - if sync_user_state.user_id != current_user: + if sync.user_id != state.user_id: raise HTTPException(status_code=400, detail="Invalid user") auth_flow = DropboxOAuth2Flow( @@ -138,7 +154,7 @@ def oauth2callback_dropbox(request: Request): user_email = account_info.email # type: ignore account_id = account_info.account_id # type: ignore - result: dict[str, str] = { + credentials: dict[str, str] = { "access_token": oauth_result.access_token, "refresh_token": oauth_result.refresh_token, "account_id": account_id, @@ -146,12 +162,12 @@ def oauth2callback_dropbox(request: Request): } sync_user_input = SyncUpdateInput( - credentials=result, + credentials=credentials, state={}, email=user_email, ) - sync_user_service.update_sync(current_user, state_dict, sync_user_input) - logger.info(f"DropBox sync created successfully for user: {current_user}") + await syncs_service.update_sync(state.sync_id, sync_user_input) + logger.info(f"DropBox sync created successfully for user: {state.user_id}") return HTMLResponse(successfullConnectionPage) except Exception as e: logger.error(f"Error: {e}") diff --git 
a/backend/api/quivr_api/modules/sync/controller/github_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/github_sync_routes.py index 04ede0394a40..048fa5ad67d8 100644 --- a/backend/api/quivr_api/modules/sync/controller/github_sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/github_sync_routes.py @@ -1,13 +1,16 @@ import os import requests -from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.responses import HTMLResponse from quivr_api.logger import get_logger from quivr_api.middlewares.auth import AuthBearer, get_current_user +from quivr_api.modules.dependencies import get_service from quivr_api.modules.sync.dto.inputs import SyncCreateInput, SyncUpdateInput from quivr_api.modules.sync.service.sync_service import SyncsService +from quivr_api.modules.sync.utils.oauth2 import Oauth2State +from quivr_api.modules.sync.utils.sync_exceptions import SyncNotFoundException from quivr_api.modules.user.entity.user_identity import UserIdentity from .successfull_connection import successfullConnectionPage @@ -16,8 +19,7 @@ logger = get_logger(__name__) # Initialize sync service -sync_service = SyncsService() -sync_user_service = SyncsService() +syncs_service_dep = get_service(SyncsService) # Initialize API router github_sync_router = APIRouter() @@ -35,8 +37,11 @@ dependencies=[Depends(AuthBearer())], tags=["Sync"], ) -def authorize_github( - request: Request, name: str, current_user: UserIdentity = Depends(get_current_user) +async def authorize_github( + request: Request, + name: str, + syncs_service: SyncsService = Depends(syncs_service_dep), + current_user: UserIdentity = Depends(get_current_user), ): """ Authorize GitHub sync for the current user. @@ -49,25 +54,32 @@ def authorize_github( dict: A dictionary containing the authorization URL. 
""" logger.debug(f"Authorizing GitHub sync for user: {current_user.id}") - state = f"user_id={current_user.id},name={name}" - authorization_url = ( - f"https://github.com/login/oauth/authorize?client_id={CLIENT_ID}" - f"&redirect_uri={REDIRECT_URI}&scope={SCOPE}&state={state}" - ) + + state_struct = Oauth2State(name=name, user_id=current_user.id) + state = state_struct.model_dump_json() sync_user_input = SyncCreateInput( - user_id=str(current_user.id), + user_id=current_user.id, name=name, provider="GitHub", credentials={}, state={"state": state}, ) - sync_user_service.create_sync_user(sync_user_input) + sync = await syncs_service.create_sync_user(sync_user_input) + state_struct.sync_id = sync.id + state = state_struct.model_dump_json() + + authorization_url = ( + f"https://github.com/login/oauth/authorize?client_id={CLIENT_ID}" + f"&redirect_uri={REDIRECT_URI}&scope={SCOPE}&state={state}" + ) return {"authorization_url": authorization_url} @github_sync_router.get("/sync/github/oauth2callback", tags=["Sync"]) -def oauth2callback_github(request: Request): +async def oauth2callback_github( + request: Request, syncs_service: SyncsService = Depends(syncs_service_dep) +): """ Handle OAuth2 callback from GitHub. @@ -78,21 +90,36 @@ def oauth2callback_github(request: Request): dict: A dictionary containing a success message. """ state = request.query_params.get("state") - state_split = state.split(",") - current_user = state_split[0].split("=")[1] # Extract user_id from state - name = state_split[1].split("=")[1] if state else None - state_dict = {"state": state} + + if not state: + raise HTTPException(status_code=400, detail="Invalid state parameter") + + state = Oauth2State.model_validate_json(state) + + if state.sync_id is None: + raise HTTPException( + status_code=400, detail="Invalid state parameter. 
Unknown sync" + ) + logger.debug( - f"Handling OAuth2 callback for user: {current_user} with state: {state}" + f"Handling OAuth2 callback for user: {state.user_id} with state: {state}" ) - sync_user_state = sync_user_service.get_sync_by_state(state_dict) - logger.info(f"Retrieved sync user state: {sync_user_state}") - if state_dict != sync_user_state["state"]: + try: + sync = await syncs_service.get_sync_by_id(state.sync_id) + except SyncNotFoundException as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=f"{e.message}" + ) + if ( + not sync + or not sync.state + or state.model_dump(exclude={"sync_id"}) != sync.state["state"] + ): logger.error("Invalid state parameter") raise HTTPException(status_code=400, detail="Invalid state parameter") - if sync_user_state.get("user_id") != current_user: - logger.error("Invalid user") + + if sync.user_id != state.user_id: raise HTTPException(status_code=400, detail="Invalid user") token_url = "https://github.com/login/oauth/access_token" @@ -122,7 +149,7 @@ def oauth2callback_github(request: Request): ) creds = result - logger.info(f"Fetched OAuth2 token for user: {current_user}") + logger.info(f"Fetched OAuth2 token for user: {state.user_id}") # Fetch user email from GitHub API github_api_url = "https://api.github.com/user" @@ -145,10 +172,10 @@ def oauth2callback_github(request: Request): logger.error("Failed to fetch user email from GitHub API") raise HTTPException(status_code=400, detail="Failed to fetch user email") - logger.info(f"Retrieved email for user: {current_user} - {user_email}") + logger.info(f"Retrieved email for user: {state.user_id} - {user_email}") sync_user_input = SyncUpdateInput(credentials=result, state={}, email=user_email) - sync_user_service.update_sync(current_user, state_dict, sync_user_input) - logger.info(f"GitHub sync created successfully for user: {current_user}") + await syncs_service.update_sync(state.sync_id, sync_user_input) + logger.info(f"GitHub sync created 
successfully for user: {state.user_id}") return HTMLResponse(successfullConnectionPage) diff --git a/backend/api/quivr_api/modules/sync/controller/google_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/google_sync_routes.py index 7a1d437842c6..75e9c2a3b1df 100644 --- a/backend/api/quivr_api/modules/sync/controller/google_sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/google_sync_routes.py @@ -1,16 +1,18 @@ import json import os -from uuid import UUID -from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.responses import HTMLResponse from google_auth_oauthlib.flow import Flow from googleapiclient.discovery import build from quivr_api.logger import get_logger from quivr_api.middlewares.auth import AuthBearer, get_current_user +from quivr_api.modules.dependencies import get_service from quivr_api.modules.sync.dto.inputs import SyncCreateInput, SyncUpdateInput from quivr_api.modules.sync.service.sync_service import SyncsService +from quivr_api.modules.sync.utils.oauth2 import Oauth2State +from quivr_api.modules.sync.utils.sync_exceptions import SyncNotFoundException from quivr_api.modules.user.entity.user_identity import UserIdentity from .successfull_connection import successfullConnectionPage @@ -22,8 +24,7 @@ logger = get_logger(__name__) # Initialize sync service -sync_service = SyncsService() -sync_user_service = SyncsService() +syncs_service_dep = get_service(SyncsService) # Initialize API router google_sync_router = APIRouter() @@ -62,8 +63,11 @@ dependencies=[Depends(AuthBearer())], tags=["Sync"], ) -def authorize_google( - request: Request, name: str, current_user: UserIdentity = Depends(get_current_user) +async def authorize_google( + request: Request, + name: str, + current_user: UserIdentity = Depends(get_current_user), + syncs_service: SyncsService = Depends(syncs_service_dep), ): """ Authorize Google Drive sync for the current 
user. @@ -84,7 +88,19 @@ def authorize_google( scopes=SCOPES, redirect_uri=redirect_uri, ) - state = f"user_id={current_user.id}, name={name}" + state_struct = Oauth2State(name=name, user_id=current_user.id) + state = state_struct.model_dump_json() + sync_user_input = SyncCreateInput( + name=name, + user_id=current_user.id, + provider="Google", + credentials={}, + state={"state": state}, + additional_data={}, + ) + sync = await syncs_service.create_sync_user(sync_user_input) + state_struct.sync_id = sync.id + state = state_struct.model_dump_json() authorization_url, state = flow.authorization_url( access_type="offline", include_granted_scopes="true", @@ -94,20 +110,14 @@ def authorize_google( logger.info( f"Generated authorization URL: {authorization_url} for user: {current_user.id}" ) - sync_user_input = SyncCreateInput( - name=name, - user_id=str(current_user.id), - provider="Google", - credentials={}, - state={"state": state}, - additional_data={}, - ) - sync_user_service.create_sync_user(sync_user_input) return {"authorization_url": authorization_url} @google_sync_router.get("/sync/google/oauth2callback", tags=["Sync"]) -def oauth2callback_google(request: Request): +async def oauth2callback_google( + request: Request, + syncs_service: SyncsService = Depends(syncs_service_dep), +): """ Handle OAuth2 callback from Google. @@ -118,23 +128,35 @@ def oauth2callback_google(request: Request): dict: A dictionary containing a success message. 
""" state = request.query_params.get("state") - state_dict = {"state": state} - logger.info(f"State: {state}") - state_split = state.split(",") - current_user = UUID(state_split[0].split("=")[1]) if state else None - assert current_user, f"oauth2callback_googl empty current_user in {request}" + logger.debug(f"request state: {state}") + if not state: + raise HTTPException(status_code=400, detail="Invalid state parameter") + + state = Oauth2State.model_validate_json(state) + if state.sync_id is None: + raise HTTPException( + status_code=400, detail="Invalid state parameter. Unknown sync" + ) + logger.debug( - f"Handling OAuth2 callback for user: {current_user} with state: {state}" + f"Handling OAuth2 callback for user: {state.user_id} with state: {state}" ) - sync_user_state = sync_user_service.get_sync_by_state(state_dict) - logger.info(f"Retrieved sync user state: {sync_user_state}") - if not sync_user_state or state_dict != sync_user_state.state: + try: + sync = await syncs_service.get_sync_by_id(state.sync_id) + except SyncNotFoundException as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=f"{e.message}" + ) + if ( + not sync + or not sync.state + or state.model_dump(exclude={"sync_id"}) != sync.state["state"] + ): logger.error("Invalid state parameter") raise HTTPException(status_code=400, detail="Invalid state parameter") - if sync_user_state.user_id != current_user: - logger.error("Invalid user") - logger.info(f"Invalid user: {current_user}") + + if sync.user_id != state.user_id: raise HTTPException(status_code=400, detail="Invalid user") redirect_uri = f"{BASE_REDIRECT_URI}" @@ -146,19 +168,19 @@ def oauth2callback_google(request: Request): ) flow.fetch_token(authorization_response=str(request.url)) creds = flow.credentials - logger.info(f"Fetched OAuth2 token for user: {current_user}") + logger.info(f"Fetched OAuth2 token for user: {state.user_id}") # Use the credentials to get the user's email service = build("oauth2", "v2", 
credentials=creds) user_info = service.userinfo().get().execute() user_email = user_info.get("email") - logger.info(f"Retrieved email for user: {current_user} - {user_email}") + logger.info(f"Retrieved email for user: {state.user_id} - {user_email}") sync_user_input = SyncUpdateInput( credentials=json.loads(creds.to_json()), state={}, email=user_email, ) - sync_user_service.update_sync(current_user, state_dict, sync_user_input) - logger.info(f"Google Drive sync created successfully for user: {current_user}") + sync = await syncs_service.update_sync(state.sync_id, sync_user_input) + logger.info(f"Google Drive sync created successfully for user: {state.user_id}") return HTMLResponse(successfullConnectionPage) diff --git a/backend/api/quivr_api/modules/sync/controller/notion_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/notion_sync_routes.py index c4fd1a034664..9aec1a3e32b9 100644 --- a/backend/api/quivr_api/modules/sync/controller/notion_sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/notion_sync_routes.py @@ -1,17 +1,19 @@ import base64 import os -from uuid import UUID import requests -from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request +from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.responses import HTMLResponse from notion_client import Client from quivr_api.celery_config import celery from quivr_api.logger import get_logger from quivr_api.middlewares.auth import AuthBearer, get_current_user +from quivr_api.modules.dependencies import get_service from quivr_api.modules.sync.dto.inputs import SyncCreateInput, SyncUpdateInput from quivr_api.modules.sync.service.sync_service import SyncsService +from quivr_api.modules.sync.utils.oauth2 import Oauth2State +from quivr_api.modules.sync.utils.sync_exceptions import SyncNotFoundException from quivr_api.modules.user.entity.user_identity import UserIdentity from .successfull_connection import successfullConnectionPage @@ 
-25,8 +27,7 @@ # Initialize sync service -sync_service = SyncsService() -sync_user_service = SyncsService() +syncs_service_dep = get_service(SyncsService) logger = get_logger(__name__) @@ -39,8 +40,11 @@ dependencies=[Depends(AuthBearer())], tags=["Sync"], ) -def authorize_notion( - request: Request, name: str, current_user: UserIdentity = Depends(get_current_user) +async def authorize_notion( + request: Request, + name: str, + current_user: UserIdentity = Depends(get_current_user), + syncs_service: SyncsService = Depends(syncs_service_dep), ): """ Authorize Notion sync for the current user. @@ -53,25 +57,31 @@ def authorize_notion( dict: A dictionary containing the authorization URL. """ logger.debug(f"Authorizing Notion sync for user: {current_user.id}, name : {name}") - state: str = f"user_id={current_user.id}, name={name}" - authorize_url = str(NOTION_AUTH_URL) + f"&state={state}" - - logger.info( - f"Generated authorization URL: {authorize_url} for user: {current_user.id}" - ) + state_struct = Oauth2State(name=name, user_id=current_user.id) + state = state_struct.model_dump_json() sync_user_input = SyncCreateInput( name=name, - user_id=str(current_user.id), + user_id=current_user.id, provider="Notion", credentials={}, state={"state": state}, ) - sync_user_service.create_sync_user(sync_user_input) + sync = await syncs_service.create_sync_user(sync_user_input) + state_struct.sync_id = sync.id + state = state_struct.model_dump_json() + # Finalize the state + authorize_url = str(NOTION_AUTH_URL) + f"&state={state}" + logger.debug( + f"Generated authorization URL: {authorize_url} for user: {current_user.id}" + ) return {"authorization_url": authorize_url} @notion_sync_router.get("/sync/notion/oauth2callback", tags=["Sync"]) -def oauth2callback_notion(request: Request, background_tasks: BackgroundTasks): +async def oauth2callback_notion( + request: Request, + syncs_service: SyncsService = Depends(syncs_service_dep), +): """ Handle OAuth2 callback from Notion. 
@@ -83,27 +93,36 @@ def oauth2callback_notion(request: Request, background_tasks: BackgroundTasks): """ code = request.query_params.get("code") state = request.query_params.get("state") + if not state: raise HTTPException(status_code=400, detail="Invalid state parameter") - state_dict = {"state": state} - state_split = state.split(",") # type: ignore - current_user = UUID(state_split[0].split("=")[1]) if state else None - assert current_user, "Oauth callback user is None" + state = Oauth2State.model_validate_json(state) + + if state.sync_id is None: + raise HTTPException( + status_code=400, detail="Invalid state parameter. Unknown sync" + ) + logger.debug( - f"Handling OAuth2 callback for user: {current_user} with state: {state} and state_dict: {state_dict}" + f"Handling OAuth2 callback for user: {state.user_id} with state: {state}" ) - sync_user_state = sync_user_service.get_sync_by_state(state_dict) - if not sync_user_state or state_dict != sync_user_state.state: - logger.error(f"Invalid state parameter for {sync_user_state}") - raise HTTPException(status_code=400, detail="Invalid state parameter") - else: - logger.info( - f"Current user: {current_user}, sync user state: {sync_user_state.state}" + try: + sync = await syncs_service.get_sync_by_id(state.sync_id) + except SyncNotFoundException as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=f"{e.message}" ) + if ( + not sync + or not sync.state + or state.model_dump(exclude={"sync_id"}) != sync.state["state"] + ): + logger.error("Invalid state parameter") + raise HTTPException(status_code=400, detail="Invalid state parameter") - if sync_user_state.user_id != current_user: + if sync.user_id != state.user_id: raise HTTPException(status_code=400, detail="Invalid user") try: @@ -148,12 +167,13 @@ def oauth2callback_notion(request: Request, background_tasks: BackgroundTasks): state={}, email=user_email, ) - sync_user_service.update_sync(current_user, state_dict, sync_user_input) - 
logger.info(f"Notion sync created successfully for user: {current_user}") + await syncs_service.update_sync(state.sync_id, sync_user_input) + + logger.info(f"Notion sync created successfully for user: {state.user_id}") # launch celery task to sync notion data celery.send_task( "fetch_and_store_notion_files_task", - kwargs={"access_token": access_token, "user_id": current_user}, + kwargs={"access_token": access_token, "user_id": state.user_id}, ) return HTMLResponse(successfullConnectionPage) diff --git a/backend/api/quivr_api/modules/sync/controller/successfull_connection.py b/backend/api/quivr_api/modules/sync/controller/successfull_connection.py index ffdb877e8081..0e9f00852ebd 100644 --- a/backend/api/quivr_api/modules/sync/controller/successfull_connection.py +++ b/backend/api/quivr_api/modules/sync/controller/successfull_connection.py @@ -50,4 +50,4 @@ -""" \ No newline at end of file +""" diff --git a/backend/api/quivr_api/modules/sync/controller/sync_routes.py b/backend/api/quivr_api/modules/sync/controller/sync_routes.py index 76b18e94303e..ccea41a4dc5f 100644 --- a/backend/api/quivr_api/modules/sync/controller/sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/sync_routes.py @@ -1,15 +1,11 @@ import os -import uuid from typing import List -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, status -from quivr_api.celery_config import celery from quivr_api.logger import get_logger from quivr_api.middlewares.auth import AuthBearer, get_current_user from quivr_api.modules.dependencies import get_service -from quivr_api.modules.notification.dto.inputs import CreateNotification -from quivr_api.modules.notification.entity.notification import NotificationsStatusEnum from quivr_api.modules.notification.service.notification_service import ( NotificationService, ) @@ -19,10 +15,7 @@ from quivr_api.modules.sync.controller.google_sync_routes import google_sync_router from 
quivr_api.modules.sync.controller.notion_sync_routes import notion_sync_router from quivr_api.modules.sync.dto import SyncsDescription -from quivr_api.modules.sync.dto.inputs import SyncsActiveInput, SyncsActiveUpdateInput from quivr_api.modules.sync.dto.outputs import AuthMethodEnum -from quivr_api.modules.sync.entity.sync_models import SyncsActive -from quivr_api.modules.sync.service.sync_notion import SyncNotionService from quivr_api.modules.sync.service.sync_service import SyncsService from quivr_api.modules.user.entity.user_identity import UserIdentity @@ -35,8 +28,7 @@ logger = get_logger(__name__) # Initialize sync service -sync_service = SyncsService() -sync_user_service = SyncsService() +syncs_service_dep = get_service(SyncsService) # Initialize API router @@ -107,7 +99,10 @@ async def get_syncs(current_user: UserIdentity = Depends(get_current_user)): dependencies=[Depends(AuthBearer())], tags=["Sync"], ) -async def get_user_syncs(current_user: UserIdentity = Depends(get_current_user)): +async def get_user_syncs( + current_user: UserIdentity = Depends(get_current_user), + syncs_service: SyncsService = Depends(syncs_service_dep), +): """ Get syncs for the current user. @@ -118,7 +113,7 @@ async def get_user_syncs(current_user: UserIdentity = Depends(get_current_user)) List: A list of syncs for the user. """ logger.debug(f"Fetching user syncs for user: {current_user.id}") - return sync_user_service.get_syncs(current_user.id) + return await syncs_service.get_syncs(current_user.id) @sync_router.delete( @@ -128,7 +123,9 @@ async def get_user_syncs(current_user: UserIdentity = Depends(get_current_user)) tags=["Sync"], ) async def delete_user_sync( - sync_id: int, current_user: UserIdentity = Depends(get_current_user) + sync_id: int, + current_user: UserIdentity = Depends(get_current_user), + syncs_service: SyncsService = Depends(syncs_service_dep), ): """ Delete a sync for the current user. 
@@ -143,227 +140,10 @@ async def delete_user_sync( logger.debug( f"Deleting user sync for user: {current_user.id} with sync ID: {sync_id}" ) - sync_user_service.delete_sync(sync_id, str(current_user.id)) # type: ignore + await syncs_service.delete_sync(sync_id, current_user.id) return None -@sync_router.post( - "/sync/active", - response_model=SyncsActive, - dependencies=[Depends(AuthBearer())], - tags=["Sync"], -) -async def create_sync_active( - sync_active_input: SyncsActiveInput, - current_user: UserIdentity = Depends(get_current_user), -): - """ - Create a new active sync for the current user. - - Args: - sync_active_input (SyncsActiveInput): The sync active input data. - current_user (UserIdentity): The current authenticated user. - - Returns: - SyncsActive: The created sync active data. - """ - logger.debug( - f"Creating active sync for user: {current_user.id} with data: {sync_active_input}" - ) - bulk_id = uuid.uuid4() - notification = notification_service.add_notification( - CreateNotification( - user_id=current_user.id, - status=NotificationsStatusEnum.INFO, - title="Synchronization created! ", - description="Your brain is preparing to sync files. 
This may take a few minutes before proceeding.", - category="generic", - bulk_id=bulk_id, - brain_id=sync_active_input.brain_id, - ) - ) - sync_active_input.notification_id = str(notification.id) - sync_active = sync_service.create_sync_active( - sync_active_input, str(current_user.id) - ) - if not sync_active: - raise HTTPException( - status_code=500, detail=f"Error creating sync active for {current_user}" - ) - - celery.send_task( - "process_sync_task", - kwargs={ - "sync_id": sync_active.id, - "user_id": sync_active.user_id, - "files_ids": sync_active_input.settings.files, - "folder_ids": sync_active_input.settings.folders, - }, - ) - - return sync_active - - -@sync_router.put( - "/sync/active/{sync_id}", - response_model=SyncsActive | None, - dependencies=[Depends(AuthBearer())], - tags=["Sync"], -) -async def update_sync_active( - sync_id: int, - sync_active_input: SyncsActiveUpdateInput, - current_user: UserIdentity = Depends(get_current_user), -): - """ - Update an existing active sync for the current user. - - Args: - sync_id (str): The ID of the active sync to update. - sync_active_input (SyncsActiveUpdateInput): The updated sync active input data. - current_user (UserIdentity): The current authenticated user. - - Returns: - SyncsActive: The updated sync active data. 
- """ - logger.info( - f"Updating active sync for user: {current_user.id} with data: {sync_active_input}" - ) - - details_sync_active = sync_service.get_details_sync_active(sync_id) - - if details_sync_active is None: - raise HTTPException( - status_code=500, - detail="Error updating sync", - ) - - if sync_active_input.settings is None: - return {"message": "No modification to sync active"} - - input_file_ids = ( - sync_active_input.settings.files if sync_active_input.settings.files else [] - ) - input_folder_ids = ( - sync_active_input.settings.folders if sync_active_input.settings.folders else [] - ) - - if (input_file_ids == details_sync_active["settings"]["files"]) and ( - input_folder_ids == details_sync_active["settings"]["folders"] - ): - logger.info({"message": "No modification to sync active"}) - return None - - logger.debug( - f"Updating sync_id {details_sync_active['id']}. Sync prev_settings={details_sync_active['settings'] }, Sync active input={sync_active_input.settings}" - ) - - bulk_id = uuid.uuid4() - sync_active_input.force_sync = True - notification = notification_service.add_notification( - CreateNotification( - user_id=current_user.id, - status=NotificationsStatusEnum.INFO, - title="Sync updated! Synchronization takes a few minutes to complete", - description="Your brain is syncing files. 
This may take a few minutes before proceeding.", - category="generic", - bulk_id=bulk_id, - brain_id=details_sync_active["brain_id"], # type: ignore - ) - ) - sync_active_input.notification_id = str(notification.id) - sync_active = sync_service.update_sync_active(sync_id, sync_active_input) - if not sync_active: - raise HTTPException( - status_code=500, - detail=f"Error updating sync active for {current_user.id}", - ) - logger.debug( - f"Sending task process_sync_task for sync_id={sync_id}, user_id={current_user.id}" - ) - - added_files_ids = set(input_file_ids).difference( - set(details_sync_active["settings"]["files"]) - ) - added_folder_ids = set(input_folder_ids).difference( - set(details_sync_active["settings"]["folders"]) - ) - if len(added_files_ids) + len(added_folder_ids) > 0: - celery.send_task( - "process_sync_task", - kwargs={ - "sync_id": sync_active.id, - "user_id": sync_active.user_id, - "files_ids": list(added_files_ids), - "folder_ids": list(added_folder_ids), - }, - ) - - else: - return None - - -@sync_router.delete( - "/sync/active/{sync_id}", - status_code=status.HTTP_204_NO_CONTENT, - dependencies=[Depends(AuthBearer())], - tags=["Sync"], -) -async def delete_sync_active( - sync_id: int, current_user: UserIdentity = Depends(get_current_user) -): - """ - Delete an existing active sync for the current user. - - Args: - sync_id (str): The ID of the active sync to delete. - current_user (UserIdentity): The current authenticated user. 
- - Returns: - None - """ - logger.debug( - f"Deleting active sync for user: {current_user.id} with sync ID: {sync_id}" - ) - - details_sync_active = sync_service.get_details_sync_active(sync_id) - notification_service.add_notification( - CreateNotification( - user_id=current_user.id, - status=NotificationsStatusEnum.SUCCESS, - title="Sync deleted!", - description="Sync deleted!", - category="generic", - bulk_id=uuid.uuid4(), - brain_id=details_sync_active["brain_id"], # type: ignore - ) - ) - sync_service.delete_sync_active(sync_id, str(current_user.id)) # type: ignore - return None - - -@sync_router.get( - "/sync/active", - response_model=List[SyncsActive], - dependencies=[Depends(AuthBearer())], - tags=["Sync"], -) -async def get_active_syncs_for_user( - current_user: UserIdentity = Depends(get_current_user), -): - """ - Get all active syncs for the current user. - - Args: - current_user (UserIdentity): The current authenticated user. - - Returns: - List[SyncsActive]: A list of active syncs for the current user. - """ - logger.debug(f"Fetching active syncs for user: {current_user.id}") - return sync_service.get_syncs_active(str(current_user.id)) - - @sync_router.get( "/sync/{sync_id}/files", dependencies=[Depends(AuthBearer())], @@ -372,8 +152,8 @@ async def get_active_syncs_for_user( async def get_files_folder_user_sync( user_sync_id: int, folder_id: str | None = None, - notion_service: SyncNotionService = Depends(get_service(SyncNotionService)), current_user: UserIdentity = Depends(get_current_user), + syncs_service: SyncsService = Depends(syncs_service_dep), ): """ Get files for an active sync. 
@@ -389,22 +169,8 @@ async def get_files_folder_user_sync( logger.debug( f"Fetching files for user sync: {user_sync_id} for user: {current_user.id}" ) - return await sync_user_service.get_files_folder_user_sync( - user_sync_id, current_user.id, folder_id, notion_service=notion_service + return await syncs_service.get_files_folder_user_sync( + user_sync_id, + current_user.id, + folder_id, ) - - -@sync_router.get( - "/sync/active/interval", - dependencies=[Depends(AuthBearer())], - tags=["Sync"], -) -async def get_syncs_active_in_interval() -> List[SyncsActive]: - """ - Get all active syncs that need to be synced. - - Returns: - List: A list of active syncs that need to be synced. - """ - logger.debug("Fetching active syncs in interval") - return await sync_service.get_syncs_active_in_interval() diff --git a/backend/api/quivr_api/modules/sync/dto/inputs.py b/backend/api/quivr_api/modules/sync/dto/inputs.py index a267304dbaa9..41ec8e71f72e 100644 --- a/backend/api/quivr_api/modules/sync/dto/inputs.py +++ b/backend/api/quivr_api/modules/sync/dto/inputs.py @@ -32,6 +32,7 @@ class SyncUpdateInput(BaseModel): state (dict): The updated state information for the sync user. 
""" - credentials: dict - state: dict - email: str + additional_data: dict | None = None + credentials: dict | None = None + state: dict | None = None + email: str | None = None diff --git a/backend/api/quivr_api/modules/sync/dto/outputs.py b/backend/api/quivr_api/modules/sync/dto/outputs.py index 6330abacba2d..5750e8f60137 100644 --- a/backend/api/quivr_api/modules/sync/dto/outputs.py +++ b/backend/api/quivr_api/modules/sync/dto/outputs.py @@ -23,6 +23,7 @@ class SyncsDescription(BaseModel): class SyncsOutput(BaseModel): + id: int user_id: UUID provider: SyncProvider state: dict | None diff --git a/backend/api/quivr_api/modules/sync/entity/notion_page.py b/backend/api/quivr_api/modules/sync/entity/notion_page.py index f84f89fd26ec..7a42f190250f 100644 --- a/backend/api/quivr_api/modules/sync/entity/notion_page.py +++ b/backend/api/quivr_api/modules/sync/entity/notion_page.py @@ -3,6 +3,7 @@ from uuid import UUID from pydantic import BaseModel, ConfigDict, Field, field_validator + from quivr_api.modules.sync.entity.sync_models import NotionSyncFile diff --git a/backend/api/quivr_api/modules/sync/entity/sync_models.py b/backend/api/quivr_api/modules/sync/entity/sync_models.py index e8a01eaf37ba..8612dba21c9f 100644 --- a/backend/api/quivr_api/modules/sync/entity/sync_models.py +++ b/backend/api/quivr_api/modules/sync/entity/sync_models.py @@ -51,19 +51,24 @@ class SyncFile(BaseModel): class Syncs(SQLModel, table=True): - __tablename__ = "syns_user" # type: ignore - id: UUID | None = Field(default=None, primary_key=True) + __tablename__ = "syncs_user" # type: ignore + + id: int | None = Field(default=None, primary_key=True) user_id: UUID = Field(foreign_key="users.id", nullable=False) name: str provider: str credentials: Dict[str, str] | None = Field( - default=None, sa_column=Column("state", JSON) + default=None, sa_column=Column("credentials", JSON) ) state: Dict[str, str] | None = Field(default=None, sa_column=Column("state", JSON)) - additional_data: dict + 
additional_data: dict | None = Field( + default=None, sa_column=Column("additional_data", JSON) + ) def to_dto(self) -> SyncsOutput: + assert self.id, "can't create create output if sync isn't inserted" return SyncsOutput( + id=self.id, user_id=self.user_id, provider=SyncProvider(self.provider), credentials=self.credentials, diff --git a/backend/api/quivr_api/modules/sync/service/sync_service.py b/backend/api/quivr_api/modules/sync/service/sync_service.py index 62324d652979..4ec7259ac67a 100644 --- a/backend/api/quivr_api/modules/sync/service/sync_service.py +++ b/backend/api/quivr_api/modules/sync/service/sync_service.py @@ -23,17 +23,13 @@ async def create_sync_user(self, sync_user_input: SyncCreateInput) -> SyncsOutpu return sync.to_dto() async def get_syncs(self, user_id: UUID, sync_id: int | None = None): - return self.repository.get_syncs(user_id, sync_id) + return await self.repository.get_syncs(user_id, sync_id) async def delete_sync(self, sync_id: int, user_id: UUID): await self.repository.delete_sync(sync_id, user_id) - async def get_sync_by_state(self, state: dict) -> SyncsOutput: - sync = await self.repository.get_sync_user_by_state(state) - return sync.to_dto() - async def get_sync_by_id(self, sync_id: int): - return self.repository.get_sync_id(sync_id) + return await self.repository.get_sync_id(sync_id) async def update_sync( self, sync_id: int, sync_user_input: SyncUpdateInput diff --git a/backend/api/quivr_api/modules/sync/tests/conftest.py b/backend/api/quivr_api/modules/sync/tests/conftest.py index bd52fbc4d99b..4ace95a7abf1 100644 --- a/backend/api/quivr_api/modules/sync/tests/conftest.py +++ b/backend/api/quivr_api/modules/sync/tests/conftest.py @@ -1,20 +1,14 @@ import json import os -from collections import defaultdict -from datetime import datetime, timedelta +from datetime import datetime from io import BytesIO from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Union from 
uuid import UUID, uuid4 import pytest -import pytest_asyncio from sqlmodel import select -from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType -from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors -from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository -from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService from quivr_api.modules.notification.dto.inputs import ( CreateNotification, NotificationUpdatableProperties, @@ -26,17 +20,6 @@ from quivr_api.modules.notification.repository.notifications_interface import ( NotificationInterface, ) -from quivr_api.modules.notification.service.notification_service import ( - NotificationService, -) -from quivr_api.modules.sync.dto.inputs import ( - SyncCreateInput, - SyncFileInput, - SyncFileUpdateInput, - SyncsActiveInput, - SyncsActiveUpdateInput, - SyncUpdateInput, -) from quivr_api.modules.sync.entity.notion_page import ( BlockParent, DatabaseParent, @@ -50,23 +33,11 @@ WorkspaceParent, ) from quivr_api.modules.sync.entity.sync_models import ( - DBSyncFile, SyncFile, - SyncsActive, - SyncsUser, -) -from quivr_api.modules.sync.repository.sync_interfaces import SyncFileInterface -from quivr_api.modules.sync.service.sync_notion import SyncNotionService -from quivr_api.modules.sync.service.sync_service import ( - ISyncService, - ISyncUserService, ) from quivr_api.modules.sync.utils.sync import ( BaseSync, ) -from quivr_api.modules.sync.utils.syncutils import ( - SyncUtils, -) from quivr_api.modules.user.entity.user_identity import User pg_database_base_url = "postgres:postgres@localhost:54322/postgres" @@ -432,320 +403,320 @@ def remove_notification_by_id(self, notification_id: UUID): del self.received[notification_id] -class MockSyncService(ISyncService): - def __init__(self, sync_active: SyncsActive): - self.syncs_active_user = {} - self.syncs_active_id = {} - self.syncs_active_user[sync_active.user_id] = sync_active - 
self.syncs_active_id[sync_active.id] = sync_active - - def create_sync_active( - self, - sync_active_input: SyncsActiveInput, - user_id: str, - ) -> SyncsActive | None: - sactive = SyncsActive( - id=len(self.syncs_active_user) + 1, - user_id=UUID(user_id), - **sync_active_input.model_dump(), - ) - self.syncs_active_user[user_id] = sactive - return sactive - - def get_syncs_active(self, user_id: str) -> List[SyncsActive]: - return self.syncs_active_user[user_id] - - def update_sync_active( - self, sync_id: int, sync_active_input: SyncsActiveUpdateInput - ): - sync = self.syncs_active_id[sync_id] - sync = SyncsActive(**sync.model_dump(), **sync_active_input.model_dump()) - self.syncs_active_id[sync_id] = sync - return sync - - def delete_sync_active(self, sync_active_id: int, user_id: UUID): - del self.syncs_active_id[sync_active_id] - del self.syncs_active_user[user_id] - - async def get_syncs_active_in_interval(self) -> List[SyncsActive]: - return list(self.syncs_active_id.values()) - - def get_details_sync_active(self, sync_active_id: int): - return - - -class MockSyncUserService(ISyncUserService): - def __init__(self, sync_user: SyncsUser): - self.map_id = {} - self.map_userid = {} - self.map_id[sync_user.id] = sync_user - self.map_userid[sync_user.id] = sync_user - - def get_syncs_user(self, user_id: UUID, sync_user_id: int | None = None): - return self.map_userid[user_id] - - def get_sync_user_by_id(self, sync_id: int): - return self.map_id[sync_id] - - def create_sync_user(self, sync_user_input: SyncCreateInput): - id = len(self.map_userid) + 1 - self.map_userid[sync_user_input.user_id] = SyncsUser( - id=id, **sync_user_input.model_dump() - ) - self.map_id[id] = self.map_userid[sync_user_input.user_id] - return self.map_id[id] - - def delete_sync_user(self, sync_id: int, user_id: str): - del self.map_userid[user_id] - del self.map_userid[sync_id] - - def get_sync_user_by_state(self, state: dict) -> SyncsUser | None: - return list(self.map_userid.values())[-1] 
- - def update_sync_user( - self, sync_user_id: UUID, state: dict, sync_user_input: SyncUpdateInput - ): - return - - def get_all_notion_user_syncs(self): - return - - async def get_files_folder_user_sync( - self, - sync_active_id: int, - user_id: UUID, - folder_id: str | None = None, - recursive: bool = False, - notion_service: SyncNotionService | None = None, - ): - return - - -class MockSyncFilesRepository(SyncFileInterface): - def __init__(self): - self.files_store = defaultdict(list) - self.next_id = 1 - - def create_sync_file(self, sync_file_input: SyncFileInput) -> Optional[DBSyncFile]: - supported = sync_file_input.supported if sync_file_input.supported else True - new_file = DBSyncFile( - id=self.next_id, - path=sync_file_input.path, - syncs_active_id=sync_file_input.syncs_active_id, - last_modified=sync_file_input.last_modified, - brain_id=sync_file_input.brain_id, - supported=supported, - ) - self.files_store[sync_file_input.syncs_active_id].append(new_file) - self.next_id += 1 - return new_file - - def get_sync_files(self, sync_active_id: int) -> List[DBSyncFile]: - """ - Retrieve sync files from the mock database. - - Args: - sync_active_id (int): The ID of the active sync. - - Returns: - List[DBSyncFile]: A list of sync files matching the criteria. 
- """ - return self.files_store[sync_active_id] - - def update_sync_file( - self, sync_file_id: int, sync_file_input: SyncFileUpdateInput - ) -> None: - for sync_files in self.files_store.values(): - for file in sync_files: - if file.id == sync_file_id: - update_data = sync_file_input.model_dump(exclude_unset=True) - if "last_modified" in update_data: - file.last_modified = update_data["last_modified"] - if "supported" in update_data: - file.supported = update_data["supported"] - return - - def update_or_create_sync_file( - self, - file: SyncFile, - sync_active: SyncsActive, - previous_file: Optional[DBSyncFile], - supported: bool, - ) -> Optional[DBSyncFile]: - if previous_file: - self.update_sync_file( - previous_file.id, - SyncFileUpdateInput( - last_modified=file.last_modified, - supported=previous_file.supported or supported, - ), - ) - return previous_file - else: - return self.create_sync_file( - SyncFileInput( - path=file.name, - syncs_active_id=sync_active.id, - last_modified=file.last_modified, - brain_id=str(sync_active.brain_id), - supported=supported, - ) - ) - - def delete_sync_file(self, sync_file_id: int) -> None: - for sync_active_id, sync_files in self.files_store.items(): - self.files_store[sync_active_id] = [ - file for file in sync_files if file.id != sync_file_id - ] - - -@pytest.fixture -def sync_file(): - file = SyncFile( - id=str(uuid4()), - name="test_file.txt", - is_folder=False, - last_modified=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), - mime_type=".txt", - web_view_link="", - notification_id=uuid4(), # - ) - return file - - -@pytest.fixture -def prev_file(): - file = SyncFile( - id=str(uuid4()), - name="test_file.txt", - is_folder=False, - last_modified=(datetime.now() - timedelta(hours=1)).strftime( - "%Y-%m-%d %H:%M:%S" - ), - mime_type="txt", - web_view_link="", - notification_id=uuid4(), # - ) - return file - - -@pytest_asyncio.fixture(scope="function") -async def brain_user_setup( - session, -) -> Tuple[Brain, User]: - user_1 = 
( - await session.exec(select(User).where(User.email == "admin@quivr.app")) - ).one() - # Brain data - brain_1 = Brain( - name="test_brain", - description="this is a test brain", - brain_type=BrainType.integration, - ) - - session.add(brain_1) - await session.refresh(user_1) - await session.commit() - assert user_1 - assert brain_1.brain_id - return brain_1, user_1 - - -@pytest_asyncio.fixture(scope="function") -async def setup_syncs_data( - brain_user_setup, -) -> Tuple[SyncsUser, SyncsActive]: - brain_1, user_1 = brain_user_setup - - sync_user = SyncsUser( - id=0, - user_id=user_1.id, - name="c8xfz3g566b8xa1ajiesdh", - provider="mock", - credentials={}, - state={}, - additional_data={}, - ) - sync_active = SyncsActive( - id=0, - name="test", - syncs_user_id=sync_user.id, - user_id=sync_user.user_id, - settings={}, - last_synced=str(datetime.now() - timedelta(hours=5)), - sync_interval_minutes=1, - brain_id=brain_1.brain_id, - ) - - return (sync_user, sync_active) - - -@pytest.fixture -def syncutils( - sync_file: SyncFile, - prev_file: SyncFile, - setup_syncs_data: Tuple[SyncsUser, SyncsActive], - session, -) -> SyncUtils: - (sync_user, sync_active) = setup_syncs_data - assert sync_file.notification_id - sync_active_service = MockSyncService(sync_active) - sync_user_service = MockSyncUserService(sync_user) - sync_files_repo_service = MockSyncFilesRepository() - knowledge_service = KnowledgeService(KnowledgeRepository(session)) - notification_service = NotificationService( - repository=MockNotification( - [sync_file.notification_id, prev_file.notification_id], # type: ignore - sync_user.user_id, - sync_active.brain_id, - ) - ) - brain_vectors = BrainsVectors() - sync_cloud = MockSyncCloud() - - sync_util = SyncUtils( - sync_user_service=sync_user_service, - sync_active_service=sync_active_service, - sync_files_repo=sync_files_repo_service, - sync_cloud=sync_cloud, - notification_service=notification_service, - brain_vectors=brain_vectors, - 
knowledge_service=knowledge_service, - ) - - return sync_util - - -@pytest.fixture -def syncutils_notion( - sync_file: SyncFile, - prev_file: SyncFile, - setup_syncs_data: Tuple[SyncsUser, SyncsActive], - session, -) -> SyncUtils: - (sync_user, sync_active) = setup_syncs_data - assert sync_file.notification_id - sync_active_service = MockSyncService(sync_active) - sync_user_service = MockSyncUserService(sync_user) - sync_files_repo_service = MockSyncFilesRepository() - knowledge_service = KnowledgeService(KnowledgeRepository(session)) - notification_service = NotificationService( - repository=MockNotification( - [sync_file.notification_id, prev_file.notification_id], # type: ignore - sync_user.user_id, - sync_active.brain_id, - ) - ) - brain_vectors = BrainsVectors() - sync_cloud = MockSyncCloudNotion() - sync_util = SyncUtils( - sync_user_service=sync_user_service, - sync_active_service=sync_active_service, - sync_files_repo=sync_files_repo_service, - sync_cloud=sync_cloud, - notification_service=notification_service, - brain_vectors=brain_vectors, - knowledge_service=knowledge_service, - ) - - return sync_util +# class MockSyncService(ISyncService): +# def __init__(self, sync_active: SyncsActive): +# self.syncs_active_user = {} +# self.syncs_active_id = {} +# self.syncs_active_user[sync_active.user_id] = sync_active +# self.syncs_active_id[sync_active.id] = sync_active + +# def create_sync_active( +# self, +# sync_active_input: SyncsActiveInput, +# user_id: str, +# ) -> SyncsActive | None: +# sactive = SyncsActive( +# id=len(self.syncs_active_user) + 1, +# user_id=UUID(user_id), +# **sync_active_input.model_dump(), +# ) +# self.syncs_active_user[user_id] = sactive +# return sactive + +# def get_syncs_active(self, user_id: str) -> List[SyncsActive]: +# return self.syncs_active_user[user_id] + +# def update_sync_active( +# self, sync_id: int, sync_active_input: SyncsActiveUpdateInput +# ): +# sync = self.syncs_active_id[sync_id] +# sync = 
SyncsActive(**sync.model_dump(), **sync_active_input.model_dump()) +# self.syncs_active_id[sync_id] = sync +# return sync + +# def delete_sync_active(self, sync_active_id: int, user_id: UUID): +# del self.syncs_active_id[sync_active_id] +# del self.syncs_active_user[user_id] + +# async def get_syncs_active_in_interval(self) -> List[SyncsActive]: +# return list(self.syncs_active_id.values()) + +# def get_details_sync_active(self, sync_active_id: int): +# return + + +# class MockSyncUserService(ISyncUserService): +# def __init__(self, sync_user: SyncsUser): +# self.map_id = {} +# self.map_userid = {} +# self.map_id[sync_user.id] = sync_user +# self.map_userid[sync_user.id] = sync_user + +# def get_syncs_user(self, user_id: UUID, sync_user_id: int | None = None): +# return self.map_userid[user_id] + +# def get_sync_user_by_id(self, sync_id: int): +# return self.map_id[sync_id] + +# def create_sync_user(self, sync_user_input: SyncCreateInput): +# id = len(self.map_userid) + 1 +# self.map_userid[sync_user_input.user_id] = SyncsUser( +# id=id, **sync_user_input.model_dump() +# ) +# self.map_id[id] = self.map_userid[sync_user_input.user_id] +# return self.map_id[id] + +# def delete_sync_user(self, sync_id: int, user_id: str): +# del self.map_userid[user_id] +# del self.map_userid[sync_id] + +# def get_sync_user_by_state(self, state: dict) -> SyncsUser | None: +# return list(self.map_userid.values())[-1] + +# def update_sync_user( +# self, sync_user_id: UUID, state: dict, sync_user_input: SyncUpdateInput +# ): +# return + +# def get_all_notion_user_syncs(self): +# return + +# async def get_files_folder_user_sync( +# self, +# sync_active_id: int, +# user_id: UUID, +# folder_id: str | None = None, +# recursive: bool = False, +# notion_service: SyncNotionService | None = None, +# ): +# return + + +# class MockSyncFilesRepository(SyncFileInterface): +# def __init__(self): +# self.files_store = defaultdict(list) +# self.next_id = 1 + +# def create_sync_file(self, 
sync_file_input: SyncFileInput) -> Optional[DBSyncFile]: +# supported = sync_file_input.supported if sync_file_input.supported else True +# new_file = DBSyncFile( +# id=self.next_id, +# path=sync_file_input.path, +# syncs_active_id=sync_file_input.syncs_active_id, +# last_modified=sync_file_input.last_modified, +# brain_id=sync_file_input.brain_id, +# supported=supported, +# ) +# self.files_store[sync_file_input.syncs_active_id].append(new_file) +# self.next_id += 1 +# return new_file + +# def get_sync_files(self, sync_active_id: int) -> List[DBSyncFile]: +# """ +# Retrieve sync files from the mock database. + +# Args: +# sync_active_id (int): The ID of the active sync. + +# Returns: +# List[DBSyncFile]: A list of sync files matching the criteria. +# """ +# return self.files_store[sync_active_id] + +# def update_sync_file( +# self, sync_file_id: int, sync_file_input: SyncFileUpdateInput +# ) -> None: +# for sync_files in self.files_store.values(): +# for file in sync_files: +# if file.id == sync_file_id: +# update_data = sync_file_input.model_dump(exclude_unset=True) +# if "last_modified" in update_data: +# file.last_modified = update_data["last_modified"] +# if "supported" in update_data: +# file.supported = update_data["supported"] +# return + +# def update_or_create_sync_file( +# self, +# file: SyncFile, +# sync_active: SyncsActive, +# previous_file: Optional[DBSyncFile], +# supported: bool, +# ) -> Optional[DBSyncFile]: +# if previous_file: +# self.update_sync_file( +# previous_file.id, +# SyncFileUpdateInput( +# last_modified=file.last_modified, +# supported=previous_file.supported or supported, +# ), +# ) +# return previous_file +# else: +# return self.create_sync_file( +# SyncFileInput( +# path=file.name, +# syncs_active_id=sync_active.id, +# last_modified=file.last_modified, +# brain_id=str(sync_active.brain_id), +# supported=supported, +# ) +# ) + +# def delete_sync_file(self, sync_file_id: int) -> None: +# for sync_active_id, sync_files in 
self.files_store.items(): +# self.files_store[sync_active_id] = [ +# file for file in sync_files if file.id != sync_file_id +# ] + + +# @pytest.fixture +# def sync_file(): +# file = SyncFile( +# id=str(uuid4()), +# name="test_file.txt", +# is_folder=False, +# last_modified=datetime.now().strftime("%Y-%m-%d %H:%M:%S"), +# mime_type=".txt", +# web_view_link="", +# notification_id=uuid4(), # +# ) +# return file + + +# @pytest.fixture +# def prev_file(): +# file = SyncFile( +# id=str(uuid4()), +# name="test_file.txt", +# is_folder=False, +# last_modified=(datetime.now() - timedelta(hours=1)).strftime( +# "%Y-%m-%d %H:%M:%S" +# ), +# mime_type="txt", +# web_view_link="", +# notification_id=uuid4(), # +# ) +# return file + + +# @pytest_asyncio.fixture(scope="function") +# async def brain_user_setup( +# session, +# ) -> Tuple[Brain, User]: +# user_1 = ( +# await session.exec(select(User).where(User.email == "admin@quivr.app")) +# ).one() +# # Brain data +# brain_1 = Brain( +# name="test_brain", +# description="this is a test brain", +# brain_type=BrainType.integration, +# ) + +# session.add(brain_1) +# await session.refresh(user_1) +# await session.commit() +# assert user_1 +# assert brain_1.brain_id +# return brain_1, user_1 + + +# @pytest_asyncio.fixture(scope="function") +# async def setup_syncs_data( +# brain_user_setup, +# ) -> Tuple[SyncsUser, SyncsActive]: +# brain_1, user_1 = brain_user_setup + +# sync_user = SyncsUser( +# id=0, +# user_id=user_1.id, +# name="c8xfz3g566b8xa1ajiesdh", +# provider="mock", +# credentials={}, +# state={}, +# additional_data={}, +# ) +# sync_active = SyncsActive( +# id=0, +# name="test", +# syncs_user_id=sync_user.id, +# user_id=sync_user.user_id, +# settings={}, +# last_synced=str(datetime.now() - timedelta(hours=5)), +# sync_interval_minutes=1, +# brain_id=brain_1.brain_id, +# ) + +# return (sync_user, sync_active) + + +# @pytest.fixture +# def syncutils( +# sync_file: SyncFile, +# prev_file: SyncFile, +# setup_syncs_data: 
Tuple[SyncsUser, SyncsActive], +# session, +# ) -> SyncUtils: +# (sync_user, sync_active) = setup_syncs_data +# assert sync_file.notification_id +# sync_active_service = MockSyncService(sync_active) +# sync_user_service = MockSyncUserService(sync_user) +# sync_files_repo_service = MockSyncFilesRepository() +# knowledge_service = KnowledgeService(KnowledgeRepository(session)) +# notification_service = NotificationService( +# repository=MockNotification( +# [sync_file.notification_id, prev_file.notification_id], # type: ignore +# sync_user.user_id, +# sync_active.brain_id, +# ) +# ) +# brain_vectors = BrainsVectors() +# sync_cloud = MockSyncCloud() + +# sync_util = SyncUtils( +# sync_user_service=sync_user_service, +# sync_active_service=sync_active_service, +# sync_files_repo=sync_files_repo_service, +# sync_cloud=sync_cloud, +# notification_service=notification_service, +# brain_vectors=brain_vectors, +# knowledge_service=knowledge_service, +# ) + +# return sync_util + + +# @pytest.fixture +# def syncutils_notion( +# sync_file: SyncFile, +# prev_file: SyncFile, +# setup_syncs_data: Tuple[SyncsUser, SyncsActive], +# session, +# ) -> SyncUtils: +# (sync_user, sync_active) = setup_syncs_data +# assert sync_file.notification_id +# sync_active_service = MockSyncService(sync_active) +# sync_user_service = MockSyncUserService(sync_user) +# sync_files_repo_service = MockSyncFilesRepository() +# knowledge_service = KnowledgeService(KnowledgeRepository(session)) +# notification_service = NotificationService( +# repository=MockNotification( +# [sync_file.notification_id, prev_file.notification_id], # type: ignore +# sync_user.user_id, +# sync_active.brain_id, +# ) +# ) +# brain_vectors = BrainsVectors() +# sync_cloud = MockSyncCloudNotion() +# sync_util = SyncUtils( +# sync_user_service=sync_user_service, +# sync_active_service=sync_active_service, +# sync_files_repo=sync_files_repo_service, +# sync_cloud=sync_cloud, +# notification_service=notification_service, +# 
brain_vectors=brain_vectors, +# knowledge_service=knowledge_service, +# ) + +# return sync_util diff --git a/backend/api/quivr_api/modules/sync/tests/test_sync_service.py b/backend/api/quivr_api/modules/sync/tests/test_sync_service.py index e69de29bb2d1..cf398efb3d82 100644 --- a/backend/api/quivr_api/modules/sync/tests/test_sync_service.py +++ b/backend/api/quivr_api/modules/sync/tests/test_sync_service.py @@ -0,0 +1,40 @@ +import pytest +import pytest_asyncio +from sqlmodel import select + +from quivr_api.modules.sync.dto.outputs import SyncProvider +from quivr_api.modules.sync.entity.sync_models import Syncs +from quivr_api.modules.sync.repository.sync_repository import SyncsRepository +from quivr_api.modules.sync.service.sync_service import SyncsService +from quivr_api.modules.user.entity.user_identity import User + + +@pytest_asyncio.fixture(scope="function") +async def user(session): + user_1 = ( + await session.exec(select(User).where(User.email == "admin@quivr.app")) + ).one() + return user_1 + + +@pytest_asyncio.fixture(scope="function") +async def test_sync(session, user): + assert user.id + + sync = Syncs( + user_id=user.id, + name="test_sync", + provider=SyncProvider.GOOGLE, + ) + + session.add(sync) + await session.commit() + await session.refresh(sync) + return sync + + +@pytest.mark.asyncio(loop_scope="session") +async def test_sync_delete_sync(session, test_sync): + assert test_sync.id + service = SyncsService(SyncsRepository(session)) + await service.delete_sync(test_sync.id, test_sync.user_id) diff --git a/backend/api/quivr_api/modules/sync/tests/test_syncutils.py b/backend/api/quivr_api/modules/sync/tests/test_syncutils.py index 63b212128e11..95d1717a2c65 100644 --- a/backend/api/quivr_api/modules/sync/tests/test_syncutils.py +++ b/backend/api/quivr_api/modules/sync/tests/test_syncutils.py @@ -1,71 +1,60 @@ -from datetime import datetime, timedelta, timezone -from typing import Tuple -from uuid import uuid4 - -import pytest - -from 
quivr_api.modules.brain.entity.brain_entity import Brain -from quivr_api.modules.notification.entity.notification import NotificationsStatusEnum -from quivr_api.modules.sync.entity.sync_models import ( - DBSyncFile, - SyncFile, - SyncsActive, - SyncsUser, -) -from quivr_api.modules.sync.utils.syncutils import ( - SyncUtils, - filter_on_supported_files, - should_download_file, -) -from quivr_api.modules.upload.service.upload_file import check_file_exists -from quivr_api.modules.user.entity.user_identity import User - - -def test_filter_on_supported_files_empty_existing(): - files = [ - SyncFile( - id="1", - name="file_name", - is_folder=True, - last_modified=str(datetime.now()), - mime_type="txt", - web_view_link="link", - ) - ] - existing_file = {} - - assert [(files[0], None)] == filter_on_supported_files(files, existing_file) - - -def test_filter_on_supported_files_prev_not_supported(): - files = [ - SyncFile( - id=f"{idx}", - name=f"file_name_{idx}", - is_folder=False, - last_modified=str(datetime.now()), - mime_type="txt", - web_view_link="link", - ) - for idx in range(3) - ] - existing_files = { - file.name: DBSyncFile( - id=idx, - path=file.name, - syncs_active_id=1, - last_modified=str(datetime.now()), - brain_id=str(uuid4()), - supported=idx % 2 == 0, - ) - for idx, file in enumerate(files) - } - - assert [ - (files[idx], existing_files[f"file_name_{idx}"]) - for idx in range(3) - if idx % 2 == 0 - ] == filter_on_supported_files(files, existing_files) +# from datetime import datetime, timedelta, timezone + + +# from quivr_api.modules.sync.entity.sync_models import ( +# SyncFile, +# ) +# from quivr_api.modules.sync.utils.syncutils import ( +# filter_on_supported_files, +# should_download_file, +# ) + + +# def test_filter_on_supported_files_empty_existing(): +# files = [ +# SyncFile( +# id="1", +# name="file_name", +# is_folder=True, +# last_modified=str(datetime.now()), +# mime_type="txt", +# web_view_link="link", +# ) +# ] +# existing_file = {} + +# assert 
[(files[0], None)] == filter_on_supported_files(files, existing_file) + + +# def test_filter_on_supported_files_prev_not_supported(): +# files = [ +# SyncFile( +# id=f"{idx}", +# name=f"file_name_{idx}", +# is_folder=False, +# last_modified=str(datetime.now()), +# mime_type="txt", +# web_view_link="link", +# ) +# for idx in range(3) +# ] +# existing_files = { +# file.name: DBSyncFile( +# id=idx, +# path=file.name, +# syncs_active_id=1, +# last_modified=str(datetime.now()), +# brain_id=str(uuid4()), +# supported=idx % 2 == 0, +# ) +# for idx, file in enumerate(files) +# } + +# assert [ +# (files[idx], existing_files[f"file_name_{idx}"]) +# for idx in range(3) +# if idx % 2 == 0 +# ] == filter_on_supported_files(files, existing_files) def test_should_download_file_no_sync_time_not_folder(): @@ -164,276 +153,276 @@ def test_should_download_file_lastsynctime_before(): ) -def test_should_download_file_lastsynctime_after(): - datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ" - file_not_folder = SyncFile( - id="1", - name="file_name", - is_folder=False, - last_modified=(datetime.now() - timedelta(hours=5)).strftime(datetime_format), - mime_type="txt", - web_view_link="link", - ) - last_sync_time = datetime.now().astimezone(timezone.utc) - - assert not should_download_file( - file=file_not_folder, - last_updated_sync_active=last_sync_time, - provider_name="google", - datetime_format=datetime_format, - ) - - -@pytest.mark.asyncio(loop_scope="session") -async def test_get_syncfiles_from_ids_nofolder(syncutils: SyncUtils): - files = await syncutils.get_syncfiles_from_ids( - credentials={}, files_ids=[str(uuid4())], folder_ids=[] - ) - assert len(files) == 1 - - -@pytest.mark.asyncio(loop_scope="session") -async def test_get_syncfiles_from_ids_folder(syncutils: SyncUtils): - files = await syncutils.get_syncfiles_from_ids( - credentials={}, files_ids=[str(uuid4())], folder_ids=[str(uuid4())] - ) - assert len(files) == 2 - - -@pytest.mark.asyncio(loop_scope="session") -async def 
test_get_syncfiles_from_ids_notion(syncutils_notion: SyncUtils): - files = await syncutils_notion.get_syncfiles_from_ids( - credentials={}, files_ids=[str(uuid4())], folder_ids=[str(uuid4())] - ) - assert len(files) == 3 - - -@pytest.mark.asyncio(loop_scope="session") -async def test_download_file(syncutils: SyncUtils): - file = SyncFile( - id=str(uuid4()), - name="test_file.txt", - is_folder=False, - last_modified=datetime.now().strftime(syncutils.sync_cloud.datetime_format), - mime_type="txt", - web_view_link="", - ) - dfile = await syncutils.download_file(file, {}) - assert dfile.extension == ".txt" - assert dfile.file_name == file.name - assert len(dfile.file_data.read()) > 0 - - -@pytest.mark.asyncio(loop_scope="session") -async def test_process_sync_file_not_supported(syncutils: SyncUtils): - file = SyncFile( - id=str(uuid4()), - name="test_file.asldkjfalsdkjf", - is_folder=False, - last_modified=datetime.now().strftime(syncutils.sync_cloud.datetime_format), - mime_type="txt", - web_view_link="", - notification_id=uuid4(), # - ) - brain_id = uuid4() - sync_user = SyncsUser( - id=1, - user_id=uuid4(), - name="c8xfz3g566b8xa1ajiesdh", - provider="mock", - credentials={}, - state={}, - additional_data={}, - ) - sync_active = SyncsActive( - id=1, - name="test", - syncs_user_id=1, - user_id=sync_user.user_id, - settings={}, - last_synced=str(datetime.now() - timedelta(hours=5)), - sync_interval_minutes=1, - brain_id=brain_id, - ) - - with pytest.raises(ValueError): - await syncutils.process_sync_file( - file=file, - previous_file=None, - current_user=sync_user, - sync_active=sync_active, - ) - - -@pytest.mark.asyncio(loop_scope="session") -async def test_process_sync_file_noprev( - monkeypatch, - brain_user_setup: Tuple[Brain, User], - setup_syncs_data: Tuple[SyncsUser, SyncsActive], - syncutils: SyncUtils, - sync_file: SyncFile, -): - task = {} - - def _send_task(*args, **kwargs): - task["args"] = args - task["kwargs"] = {**kwargs["kwargs"]} - - 
monkeypatch.setattr("quivr_api.celery_config.celery.send_task", _send_task) - - brain_1, _ = brain_user_setup - assert brain_1.brain_id - (sync_user, sync_active) = setup_syncs_data - await syncutils.process_sync_file( - file=sync_file, - previous_file=None, - current_user=sync_user, - sync_active=sync_active, - ) - - # Check notification inserted - assert ( - sync_file.notification_id in syncutils.notification_service.repository.received # type: ignore - ) - assert ( - syncutils.notification_service.repository.received[ # type: ignore - sync_file.notification_id # type: ignore - ].status - == NotificationsStatusEnum.SUCCESS - ) - - # Check Syncfile created - dbfiles: list[DBSyncFile] = syncutils.sync_files_repo.get_sync_files(sync_active.id) - assert len(dbfiles) == 1 - assert dbfiles[0].brain_id == str(brain_1.brain_id) - assert dbfiles[0].syncs_active_id == sync_active.id - assert dbfiles[0].supported - - # Check knowledge created - all_km = await syncutils.knowledge_service.get_all_knowledge_in_brain( - brain_1.brain_id - ) - assert len(all_km) == 1 - created_km = all_km[0] - assert created_km.file_name == sync_file.name - assert created_km.extension == ".txt" - assert created_km.file_sha1 is None - assert created_km.created_at is not None - assert created_km.metadata == {"sync_file_id": "1"} - assert len(created_km.brains)> 0 - assert created_km.brains[0]["brain_id"]== brain_1.brain_id - - # Assert celery task in correct - assert task["args"] == ("process_file_task",) - minimal_task_kwargs = { - "brain_id": brain_1.brain_id, - "knowledge_id": created_km.id, - "file_original_name": sync_file.name, - "source": syncutils.sync_cloud.name, - "notification_id": sync_file.notification_id, - } - all( - minimal_task_kwargs[key] == task["kwargs"][key] # type: ignore - for key in minimal_task_kwargs - ) - - -@pytest.mark.asyncio(loop_scope="session") -async def test_process_sync_file_with_prev( - monkeypatch, - supabase_client, - brain_user_setup: Tuple[Brain, User], - 
setup_syncs_data: Tuple[SyncsUser, SyncsActive], - syncutils: SyncUtils, - sync_file: SyncFile, - prev_file: SyncFile, -): - task = {} - - def _send_task(*args, **kwargs): - task["args"] = args - task["kwargs"] = {**kwargs["kwargs"]} - - monkeypatch.setattr("quivr_api.celery_config.celery.send_task", _send_task) - brain_1, _ = brain_user_setup - assert brain_1.brain_id - (sync_user, sync_active) = setup_syncs_data - - # Run process_file on prev_file first - await syncutils.process_sync_file( - file=prev_file, - previous_file=None, - current_user=sync_user, - sync_active=sync_active, - ) - dbfiles: list[DBSyncFile] = syncutils.sync_files_repo.get_sync_files(sync_active.id) - assert len(dbfiles) == 1 - prev_dbfile = dbfiles[0] - - assert check_file_exists(str(brain_1.brain_id), prev_file.name) - prev_file_data = supabase_client.storage.from_("quivr").download( - f"{brain_1.brain_id}/{prev_file.name}" - ) - - ##### - # Run process_file on newer file - await syncutils.process_sync_file( - file=sync_file, - previous_file=prev_dbfile, - current_user=sync_user, - sync_active=sync_active, - ) - - # Check notification inserted - assert ( - sync_file.notification_id in syncutils.notification_service.repository.received # type: ignore - ) - assert ( - syncutils.notification_service.repository.received[ # type: ignore - sync_file.notification_id # type: ignore - ].status - == NotificationsStatusEnum.SUCCESS - ) - - # Check Syncfile created - dbfiles: list[DBSyncFile] = syncutils.sync_files_repo.get_sync_files(sync_active.id) - assert len(dbfiles) == 1 - assert dbfiles[0].brain_id == str(brain_1.brain_id) - assert dbfiles[0].syncs_active_id == sync_active.id - assert dbfiles[0].supported - - # Check prev file was deleted and replaced with the new - all_km = await syncutils.knowledge_service.get_all_knowledge_in_brain( - brain_1.brain_id - ) - assert len(all_km) == 1 - created_km = all_km[0] - assert created_km.file_name == sync_file.name - assert created_km.extension == ".txt" 
- assert created_km.file_sha1 is None - assert created_km.updated_at - assert created_km.created_at - assert created_km.updated_at == created_km.created_at # new line - assert created_km.metadata == {"sync_file_id": str(dbfiles[0].id)} - assert created_km.brains[0]["brain_id"]== brain_1.brain_id - - # Check file content changed - assert check_file_exists(str(brain_1.brain_id), sync_file.name) - new_file_data = supabase_client.storage.from_("quivr").download( - f"{brain_1.brain_id}/{sync_file.name}" - ) - assert new_file_data != prev_file_data, "Same file in prev_file and new file" - - # Assert celery task in correct - assert task["args"] == ("process_file_task",) - minimal_task_kwargs = { - "brain_id": brain_1.brain_id, - "knowledge_id": created_km.id, - "file_original_name": sync_file.name, - "source": syncutils.sync_cloud.name, - "notification_id": sync_file.notification_id, - } - all( - minimal_task_kwargs[key] == task["kwargs"][key] # type: ignore - for key in minimal_task_kwargs - ) +# def test_should_download_file_lastsynctime_after(): +# datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ" +# file_not_folder = SyncFile( +# id="1", +# name="file_name", +# is_folder=False, +# last_modified=(datetime.now() - timedelta(hours=5)).strftime(datetime_format), +# mime_type="txt", +# web_view_link="link", +# ) +# last_sync_time = datetime.now().astimezone(timezone.utc) + +# assert not should_download_file( +# file=file_not_folder, +# last_updated_sync_active=last_sync_time, +# provider_name="google", +# datetime_format=datetime_format, +# ) + + +# @pytest.mark.asyncio(loop_scope="session") +# async def test_get_syncfiles_from_ids_nofolder(syncutils: SyncUtils): +# files = await syncutils.get_syncfiles_from_ids( +# credentials={}, files_ids=[str(uuid4())], folder_ids=[] +# ) +# assert len(files) == 1 + + +# @pytest.mark.asyncio(loop_scope="session") +# async def test_get_syncfiles_from_ids_folder(syncutils: SyncUtils): +# files = await syncutils.get_syncfiles_from_ids( +# 
credentials={}, files_ids=[str(uuid4())], folder_ids=[str(uuid4())] +# ) +# assert len(files) == 2 + + +# @pytest.mark.asyncio(loop_scope="session") +# async def test_get_syncfiles_from_ids_notion(syncutils_notion: SyncUtils): +# files = await syncutils_notion.get_syncfiles_from_ids( +# credentials={}, files_ids=[str(uuid4())], folder_ids=[str(uuid4())] +# ) +# assert len(files) == 3 + + +# @pytest.mark.asyncio(loop_scope="session") +# async def test_download_file(syncutils: SyncUtils): +# file = SyncFile( +# id=str(uuid4()), +# name="test_file.txt", +# is_folder=False, +# last_modified=datetime.now().strftime(syncutils.sync_cloud.datetime_format), +# mime_type="txt", +# web_view_link="", +# ) +# dfile = await syncutils.download_file(file, {}) +# assert dfile.extension == ".txt" +# assert dfile.file_name == file.name +# assert len(dfile.file_data.read()) > 0 + + +# @pytest.mark.asyncio(loop_scope="session") +# async def test_process_sync_file_not_supported(syncutils: SyncUtils): +# file = SyncFile( +# id=str(uuid4()), +# name="test_file.asldkjfalsdkjf", +# is_folder=False, +# last_modified=datetime.now().strftime(syncutils.sync_cloud.datetime_format), +# mime_type="txt", +# web_view_link="", +# notification_id=uuid4(), # +# ) +# brain_id = uuid4() +# sync_user = SyncsUser( +# id=1, +# user_id=uuid4(), +# name="c8xfz3g566b8xa1ajiesdh", +# provider="mock", +# credentials={}, +# state={}, +# additional_data={}, +# ) +# sync_active = SyncsActive( +# id=1, +# name="test", +# syncs_user_id=1, +# user_id=sync_user.user_id, +# settings={}, +# last_synced=str(datetime.now() - timedelta(hours=5)), +# sync_interval_minutes=1, +# brain_id=brain_id, +# ) + +# with pytest.raises(ValueError): +# await syncutils.process_sync_file( +# file=file, +# previous_file=None, +# current_user=sync_user, +# sync_active=sync_active, +# ) + + +# @pytest.mark.asyncio(loop_scope="session") +# async def test_process_sync_file_noprev( +# monkeypatch, +# brain_user_setup: Tuple[Brain, User], +# 
setup_syncs_data: Tuple[SyncsUser, SyncsActive], +# syncutils: SyncUtils, +# sync_file: SyncFile, +# ): +# task = {} + +# def _send_task(*args, **kwargs): +# task["args"] = args +# task["kwargs"] = {**kwargs["kwargs"]} + +# monkeypatch.setattr("quivr_api.celery_config.celery.send_task", _send_task) + +# brain_1, _ = brain_user_setup +# assert brain_1.brain_id +# (sync_user, sync_active) = setup_syncs_data +# await syncutils.process_sync_file( +# file=sync_file, +# previous_file=None, +# current_user=sync_user, +# sync_active=sync_active, +# ) + +# # Check notification inserted +# assert ( +# sync_file.notification_id in syncutils.notification_service.repository.received # type: ignore +# ) +# assert ( +# syncutils.notification_service.repository.received[ # type: ignore +# sync_file.notification_id # type: ignore +# ].status +# == NotificationsStatusEnum.SUCCESS +# ) + +# # Check Syncfile created +# dbfiles: list[DBSyncFile] = syncutils.sync_files_repo.get_sync_files(sync_active.id) +# assert len(dbfiles) == 1 +# assert dbfiles[0].brain_id == str(brain_1.brain_id) +# assert dbfiles[0].syncs_active_id == sync_active.id +# assert dbfiles[0].supported + +# # Check knowledge created +# all_km = await syncutils.knowledge_service.get_all_knowledge_in_brain( +# brain_1.brain_id +# ) +# assert len(all_km) == 1 +# created_km = all_km[0] +# assert created_km.file_name == sync_file.name +# assert created_km.extension == ".txt" +# assert created_km.file_sha1 is None +# assert created_km.created_at is not None +# assert created_km.metadata == {"sync_file_id": "1"} +# assert len(created_km.brains) > 0 +# assert created_km.brains[0]["brain_id"] == brain_1.brain_id + +# # Assert celery task in correct +# assert task["args"] == ("process_file_task",) +# minimal_task_kwargs = { +# "brain_id": brain_1.brain_id, +# "knowledge_id": created_km.id, +# "file_original_name": sync_file.name, +# "source": syncutils.sync_cloud.name, +# "notification_id": sync_file.notification_id, +# } +# 
all( +# minimal_task_kwargs[key] == task["kwargs"][key] # type: ignore +# for key in minimal_task_kwargs +# ) + + +# @pytest.mark.asyncio(loop_scope="session") +# async def test_process_sync_file_with_prev( +# monkeypatch, +# supabase_client, +# brain_user_setup: Tuple[Brain, User], +# setup_syncs_data: Tuple[SyncsUser, SyncsActive], +# syncutils: SyncUtils, +# sync_file: SyncFile, +# prev_file: SyncFile, +# ): +# task = {} + +# def _send_task(*args, **kwargs): +# task["args"] = args +# task["kwargs"] = {**kwargs["kwargs"]} + +# monkeypatch.setattr("quivr_api.celery_config.celery.send_task", _send_task) +# brain_1, _ = brain_user_setup +# assert brain_1.brain_id +# (sync_user, sync_active) = setup_syncs_data + +# # Run process_file on prev_file first +# await syncutils.process_sync_file( +# file=prev_file, +# previous_file=None, +# current_user=sync_user, +# sync_active=sync_active, +# ) +# dbfiles: list[DBSyncFile] = syncutils.sync_files_repo.get_sync_files(sync_active.id) +# assert len(dbfiles) == 1 +# prev_dbfile = dbfiles[0] + +# assert check_file_exists(str(brain_1.brain_id), prev_file.name) +# prev_file_data = supabase_client.storage.from_("quivr").download( +# f"{brain_1.brain_id}/{prev_file.name}" +# ) + +# ##### +# # Run process_file on newer file +# await syncutils.process_sync_file( +# file=sync_file, +# previous_file=prev_dbfile, +# current_user=sync_user, +# sync_active=sync_active, +# ) + +# # Check notification inserted +# assert ( +# sync_file.notification_id in syncutils.notification_service.repository.received # type: ignore +# ) +# assert ( +# syncutils.notification_service.repository.received[ # type: ignore +# sync_file.notification_id # type: ignore +# ].status +# == NotificationsStatusEnum.SUCCESS +# ) + +# # Check Syncfile created +# dbfiles: list[DBSyncFile] = syncutils.sync_files_repo.get_sync_files(sync_active.id) +# assert len(dbfiles) == 1 +# assert dbfiles[0].brain_id == str(brain_1.brain_id) +# assert dbfiles[0].syncs_active_id == 
sync_active.id +# assert dbfiles[0].supported + +# # Check prev file was deleted and replaced with the new +# all_km = await syncutils.knowledge_service.get_all_knowledge_in_brain( +# brain_1.brain_id +# ) +# assert len(all_km) == 1 +# created_km = all_km[0] +# assert created_km.file_name == sync_file.name +# assert created_km.extension == ".txt" +# assert created_km.file_sha1 is None +# assert created_km.updated_at +# assert created_km.created_at +# assert created_km.updated_at == created_km.created_at # new line +# assert created_km.metadata == {"sync_file_id": str(dbfiles[0].id)} +# assert created_km.brains[0]["brain_id"] == brain_1.brain_id + +# # Check file content changed +# assert check_file_exists(str(brain_1.brain_id), sync_file.name) +# new_file_data = supabase_client.storage.from_("quivr").download( +# f"{brain_1.brain_id}/{sync_file.name}" +# ) +# assert new_file_data != prev_file_data, "Same file in prev_file and new file" + +# # Assert celery task in correct +# assert task["args"] == ("process_file_task",) +# minimal_task_kwargs = { +# "brain_id": brain_1.brain_id, +# "knowledge_id": created_km.id, +# "file_original_name": sync_file.name, +# "source": syncutils.sync_cloud.name, +# "notification_id": sync_file.notification_id, +# } +# all( +# minimal_task_kwargs[key] == task["kwargs"][key] # type: ignore +# for key in minimal_task_kwargs +# ) diff --git a/backend/api/quivr_api/modules/sync/utils/oauth2.py b/backend/api/quivr_api/modules/sync/utils/oauth2.py new file mode 100644 index 000000000000..cc8be15db630 --- /dev/null +++ b/backend/api/quivr_api/modules/sync/utils/oauth2.py @@ -0,0 +1,9 @@ +from uuid import UUID + +from pydantic import BaseModel + + +class Oauth2State(BaseModel): + sync_id: int | None = None + name: str + user_id: UUID diff --git a/backend/api/quivr_api/modules/sync/utils/syncutils.py b/backend/api/quivr_api/modules/sync/utils/syncutils.py index 2a2e5563ef94..861ebf7306fc 100644 --- 
a/backend/api/quivr_api/modules/sync/utils/syncutils.py +++ b/backend/api/quivr_api/modules/sync/utils/syncutils.py @@ -16,18 +16,9 @@ from quivr_api.modules.notification.service.notification_service import ( NotificationService, ) -from quivr_api.modules.sync.dto.inputs import SyncsActiveUpdateInput from quivr_api.modules.sync.entity.sync_models import ( - DBSyncFile, DownloadedSyncFile, SyncFile, - SyncsActive, - SyncsUser, -) -from quivr_api.modules.sync.repository.sync_interfaces import SyncFileInterface -from quivr_api.modules.sync.service.sync_service import ( - ISyncService, - ISyncUserService, ) from quivr_api.modules.sync.utils.sync import BaseSync from quivr_api.modules.upload.service.upload_file import ( diff --git a/backend/worker/tests/test_sync.py b/backend/worker/tests/test_sync.py new file mode 100644 index 000000000000..e69de29bb2d1 From ae70171d34162fc631522b437626eaf9470f2aca Mon Sep 17 00:00:00 2001 From: aminediro Date: Wed, 18 Sep 2024 15:46:54 +0200 Subject: [PATCH 05/63] oauth2 refacto --- .../sync/controller/azure_sync_routes.py | 88 +++++-------------- .../sync/controller/dropbox_sync_routes.py | 75 +++++----------- .../sync/controller/github_sync_routes.py | 61 +++---------- .../sync/controller/google_sync_routes.py | 60 +++---------- .../sync/controller/notion_sync_routes.py | 55 ++---------- .../modules/sync/controller/sync_routes.py | 2 +- .../api/quivr_api/modules/sync/dto/outputs.py | 1 + .../modules/sync/entity/sync_models.py | 4 +- .../sync/repository/sync_repository.py | 12 +-- .../modules/sync/service/sync_service.py | 44 +++++++++- .../quivr_api/modules/sync/utils/oauth2.py | 26 +++++- 11 files changed, 151 insertions(+), 277 deletions(-) diff --git a/backend/api/quivr_api/modules/sync/controller/azure_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/azure_sync_routes.py index 9a4585c45ce6..ad34adf38c4f 100644 --- a/backend/api/quivr_api/modules/sync/controller/azure_sync_routes.py +++ 
b/backend/api/quivr_api/modules/sync/controller/azure_sync_routes.py @@ -8,10 +8,9 @@ from quivr_api.logger import get_logger from quivr_api.middlewares.auth import AuthBearer, get_current_user from quivr_api.modules.dependencies import get_service -from quivr_api.modules.sync.dto.inputs import SyncCreateInput, SyncUpdateInput +from quivr_api.modules.sync.dto.inputs import SyncUpdateInput from quivr_api.modules.sync.service.sync_service import SyncsService -from quivr_api.modules.sync.utils.oauth2 import Oauth2State -from quivr_api.modules.sync.utils.sync_exceptions import SyncNotFoundException +from quivr_api.modules.sync.utils.oauth2 import parse_oauth2_state from quivr_api.modules.user.entity.user_identity import UserIdentity from .successfull_connection import successfullConnectionPage @@ -63,29 +62,19 @@ async def authorize_azure( CLIENT_ID, client_credential=CLIENT_SECRET, authority=AUTHORITY ) logger.debug(f"Authorizing Azure sync for user: {current_user.id}") - state_struct = Oauth2State(name=name, user_id=current_user.id) - state = state_struct.model_dump_json() - - sync_user_input = SyncCreateInput( - user_id=current_user.id, - name=name, - provider="Azure", - credentials={}, - state={"state": state}, + state = await syncs_service.create_oauth2_state( + provider="Azure", name=name, user_id=current_user.id ) - sync = await syncs_service.create_sync_user(sync_user_input) - state_struct.sync_id = sync.id - state = state_struct.model_dump_json() - flow = client.initiate_auth_code_flow( - scopes=SCOPE, redirect_uri=REDIRECT_URI, state=state, prompt="select_account" + scopes=SCOPE, + redirect_uri=REDIRECT_URI, + state=state.model_dump_json(), + prompt="select_account", ) - - sync = await syncs_service.update_sync( - sync_id=sync.id, - sync_user_input=SyncUpdateInput( - **{**sync.model_dump(), "additional_data": {"flow": flow}} - ), + # Azure needs additional data + await syncs_service.update_sync( + sync_id=state.sync_id, + 
sync_user_input=SyncUpdateInput(additional_data={"flow": flow}), ) return {"authorization_url": flow["auth_uri"]} @@ -107,57 +96,27 @@ async def oauth2callback_azure( client = ConfidentialClientApplication( CLIENT_ID, client_credential=CLIENT_SECRET, authority=AUTHORITY ) - state = request.query_params.get("state") - - if not state: - raise HTTPException(status_code=400, detail="Invalid state parameter") - - state = Oauth2State.model_validate_json(state) - - if state.sync_id is None: - raise HTTPException( - status_code=400, detail="Invalid state parameter. Unknown sync" - ) - + state_str = request.query_params.get("state") + state = parse_oauth2_state(state_str) logger.debug( f"Handling OAuth2 callback for user: {state.user_id} with state: {state}" ) - - try: - sync = await syncs_service.get_sync_by_id(state.sync_id) - except SyncNotFoundException as e: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=f"{e.message}" - ) - if ( - not sync - or not sync.state - or state.model_dump(exclude={"sync_id"}) != sync.state["state"] - ): - logger.error("Invalid state parameter") - raise HTTPException(status_code=400, detail="Invalid state parameter") - - if sync.user_id != state.user_id: - raise HTTPException(status_code=400, detail="Invalid user") - + sync = await syncs_service.get_from_oauth2_state(state) if sync.additional_data is None: - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Invalid sync data" - ) + raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY) - result = client.acquire_token_by_auth_code_flow( + flow_data = client.acquire_token_by_auth_code_flow( sync.additional_data["flow"], dict(request.query_params) ) - if "access_token" not in result: - logger.error(f"Failed to acquire token: {result}") + + if "access_token" not in flow_data: + logger.error(f"Failed to acquire token: {flow_data}") raise HTTPException( status_code=400, - detail=f"Failed to acquire token: {result}", + 
detail=f"Failed to acquire token: {flow_data}", ) - access_token = result["access_token"] - - creds = result + access_token = flow_data["access_token"] logger.info(f"Fetched OAuth2 token for user: {state.user_id}") # Fetch user email from Microsoft Graph API @@ -172,8 +131,7 @@ async def oauth2callback_azure( user_email = user_info.get("mail") or user_info.get("userPrincipalName") logger.info(f"Retrieved email for user: {state.user_id} - {user_email}") - sync_user_input = SyncUpdateInput(credentials=result, state={}, email=user_email) - + sync_user_input = SyncUpdateInput(credentials=flow_data, state={}, email=user_email) await syncs_service.update_sync(state.sync_id, sync_user_input) logger.info(f"Azure sync created successfully for user: {state.user_id}") return HTMLResponse(successfullConnectionPage) diff --git a/backend/api/quivr_api/modules/sync/controller/dropbox_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/dropbox_sync_routes.py index f461dde461eb..ce53dca1b060 100644 --- a/backend/api/quivr_api/modules/sync/controller/dropbox_sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/dropbox_sync_routes.py @@ -1,16 +1,16 @@ import os +from typing import Tuple from dropbox import Dropbox, DropboxOAuth2Flow -from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import HTMLResponse from quivr_api.logger import get_logger from quivr_api.middlewares.auth import AuthBearer, get_current_user from quivr_api.modules.dependencies import get_service -from quivr_api.modules.sync.dto.inputs import SyncCreateInput, SyncUpdateInput +from quivr_api.modules.sync.dto.inputs import SyncUpdateInput from quivr_api.modules.sync.service.sync_service import SyncsService -from quivr_api.modules.sync.utils.oauth2 import Oauth2State -from quivr_api.modules.sync.utils.sync_exceptions import SyncNotFoundException +from quivr_api.modules.sync.utils.oauth2 
import parse_oauth2_state from quivr_api.modules.user.entity.user_identity import UserIdentity from .successfull_connection import successfullConnectionPage @@ -63,25 +63,26 @@ async def authorize_dropbox( token_access_type="offline", scope=SCOPE, ) - state_struct = Oauth2State(name=name, user_id=current_user.id) - sync_user_input = SyncCreateInput( - name=name, - user_id=current_user.id, - provider="DropBox", - credentials={}, - state={"state": state_struct.model_dump_json()}, - additional_data={}, + state = await syncs_service.create_oauth2_state( + provider="DropBox", name=name, user_id=current_user.id ) - sync = await syncs_service.create_sync_user(sync_user_input) - state_struct.sync_id = sync.id - state = state_struct.model_dump_json() - authorize_url = auth_flow.start(state) + authorize_url = auth_flow.start(state.model_dump_json()) logger.info( f"Generated authorization URL: {authorize_url} for user: {current_user.id}" ) return {"authorization_url": authorize_url} +def parse_dropbox_oauth2_session(state_str: str | None) -> Tuple[dict[str, str], str]: + if state_str is None: + raise ValueError + session = {} + session["csrf-token"] = state_str.split("|")[0] if "|" in state_str else "" + logger.debug("Keys in session : %s", session.keys()) + logger.debug("Value in session : %s", session.values()) + return session, state_str.split("|")[1] + + @dropbox_sync_router.get("/sync/dropbox/oauth2callback", tags=["Sync"]) async def oauth2callback_dropbox( request: Request, @@ -96,42 +97,10 @@ async def oauth2callback_dropbox( Returns: dict: A dictionary containing a success message. 
""" - state = request.query_params.get("state") - if not state: - raise HTTPException(status_code=400, detail="Invalid state parameter") - session = {} - session["csrf-token"] = state.split("|")[0] if "|" in state else "" - - logger.debug("Keys in session : %s", session.keys()) - logger.debug("Value in session : %s", session.values()) - - state = Oauth2State.model_validate_json(state) - - if state.sync_id is None: - raise HTTPException( - status_code=400, detail="Invalid state parameter. Unknown sync" - ) - - logger.debug( - f"Handling OAuth2 callback for user: {state.user_id} with state: {state} " - ) - try: - sync = await syncs_service.get_sync_by_id(state.sync_id) - except SyncNotFoundException as e: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=f"{e.message}" - ) - - if ( - not sync - or not sync.state - or state.model_dump(exclude={"sync_id"}) != sync.state["state"] - ): - logger.error("Invalid state parameter") - raise HTTPException(status_code=400, detail="Invalid state parameter") - - if sync.user_id != state.user_id: - raise HTTPException(status_code=400, detail="Invalid user") + state_str = request.query_params.get("state") + session, state_str = parse_dropbox_oauth2_session(state_str) + state = parse_oauth2_state(state_str) + sync = await syncs_service.get_from_oauth2_state(state) auth_flow = DropboxOAuth2Flow( DROPBOX_APP_KEY, @@ -166,7 +135,7 @@ async def oauth2callback_dropbox( state={}, email=user_email, ) - await syncs_service.update_sync(state.sync_id, sync_user_input) + await syncs_service.update_sync(sync.id, sync_user_input) logger.info(f"DropBox sync created successfully for user: {state.user_id}") return HTMLResponse(successfullConnectionPage) except Exception as e: diff --git a/backend/api/quivr_api/modules/sync/controller/github_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/github_sync_routes.py index 048fa5ad67d8..933b881b432f 100644 --- 
a/backend/api/quivr_api/modules/sync/controller/github_sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/github_sync_routes.py @@ -1,16 +1,15 @@ import os import requests -from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import HTMLResponse from quivr_api.logger import get_logger from quivr_api.middlewares.auth import AuthBearer, get_current_user from quivr_api.modules.dependencies import get_service -from quivr_api.modules.sync.dto.inputs import SyncCreateInput, SyncUpdateInput +from quivr_api.modules.sync.dto.inputs import SyncUpdateInput from quivr_api.modules.sync.service.sync_service import SyncsService -from quivr_api.modules.sync.utils.oauth2 import Oauth2State -from quivr_api.modules.sync.utils.sync_exceptions import SyncNotFoundException +from quivr_api.modules.sync.utils.oauth2 import parse_oauth2_state from quivr_api.modules.user.entity.user_identity import UserIdentity from .successfull_connection import successfullConnectionPage @@ -54,24 +53,12 @@ async def authorize_github( dict: A dictionary containing the authorization URL. 
""" logger.debug(f"Authorizing GitHub sync for user: {current_user.id}") - - state_struct = Oauth2State(name=name, user_id=current_user.id) - state = state_struct.model_dump_json() - - sync_user_input = SyncCreateInput( - user_id=current_user.id, - name=name, - provider="GitHub", - credentials={}, - state={"state": state}, + state = await syncs_service.create_oauth2_state( + provider="Github", name=name, user_id=current_user.id ) - sync = await syncs_service.create_sync_user(sync_user_input) - state_struct.sync_id = sync.id - state = state_struct.model_dump_json() - authorization_url = ( f"https://github.com/login/oauth/authorize?client_id={CLIENT_ID}" - f"&redirect_uri={REDIRECT_URI}&scope={SCOPE}&state={state}" + f"&redirect_uri={REDIRECT_URI}&scope={SCOPE}&state={state.model_dump_json()}" ) return {"authorization_url": authorization_url} @@ -89,39 +76,12 @@ async def oauth2callback_github( Returns: dict: A dictionary containing a success message. """ - state = request.query_params.get("state") - - if not state: - raise HTTPException(status_code=400, detail="Invalid state parameter") - - state = Oauth2State.model_validate_json(state) - - if state.sync_id is None: - raise HTTPException( - status_code=400, detail="Invalid state parameter. 
Unknown sync" - ) - + state_str = request.query_params.get("state") + state = parse_oauth2_state(state_str) logger.debug( f"Handling OAuth2 callback for user: {state.user_id} with state: {state}" ) - - try: - sync = await syncs_service.get_sync_by_id(state.sync_id) - except SyncNotFoundException as e: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=f"{e.message}" - ) - if ( - not sync - or not sync.state - or state.model_dump(exclude={"sync_id"}) != sync.state["state"] - ): - logger.error("Invalid state parameter") - raise HTTPException(status_code=400, detail="Invalid state parameter") - - if sync.user_id != state.user_id: - raise HTTPException(status_code=400, detail="Invalid user") - + sync = await syncs_service.get_from_oauth2_state(state) token_url = "https://github.com/login/oauth/access_token" data = { "client_id": CLIENT_ID, @@ -176,6 +136,7 @@ async def oauth2callback_github( sync_user_input = SyncUpdateInput(credentials=result, state={}, email=user_email) - await syncs_service.update_sync(state.sync_id, sync_user_input) + # TODO: This an additional select query :( + await syncs_service.update_sync(sync.id, sync_user_input) logger.info(f"GitHub sync created successfully for user: {state.user_id}") return HTMLResponse(successfullConnectionPage) diff --git a/backend/api/quivr_api/modules/sync/controller/google_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/google_sync_routes.py index 75e9c2a3b1df..d2bb10b716e6 100644 --- a/backend/api/quivr_api/modules/sync/controller/google_sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/google_sync_routes.py @@ -1,7 +1,7 @@ import json import os -from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi import APIRouter, Depends, Request from fastapi.responses import HTMLResponse from google_auth_oauthlib.flow import Flow from googleapiclient.discovery import build @@ -9,10 +9,9 @@ from quivr_api.logger import get_logger from 
quivr_api.middlewares.auth import AuthBearer, get_current_user from quivr_api.modules.dependencies import get_service -from quivr_api.modules.sync.dto.inputs import SyncCreateInput, SyncUpdateInput +from quivr_api.modules.sync.dto.inputs import SyncUpdateInput from quivr_api.modules.sync.service.sync_service import SyncsService -from quivr_api.modules.sync.utils.oauth2 import Oauth2State -from quivr_api.modules.sync.utils.sync_exceptions import SyncNotFoundException +from quivr_api.modules.sync.utils.oauth2 import parse_oauth2_state from quivr_api.modules.user.entity.user_identity import UserIdentity from .successfull_connection import successfullConnectionPage @@ -88,23 +87,14 @@ async def authorize_google( scopes=SCOPES, redirect_uri=redirect_uri, ) - state_struct = Oauth2State(name=name, user_id=current_user.id) - state = state_struct.model_dump_json() - sync_user_input = SyncCreateInput( - name=name, - user_id=current_user.id, - provider="Google", - credentials={}, - state={"state": state}, - additional_data={}, + + state = await syncs_service.create_oauth2_state( + provider="Google", name=name, user_id=current_user.id ) - sync = await syncs_service.create_sync_user(sync_user_input) - state_struct.sync_id = sync.id - state = state_struct.model_dump_json() authorization_url, state = flow.authorization_url( access_type="offline", include_granted_scopes="true", - state=state, + state=state.model_dump_json(), prompt="consent", ) logger.info( @@ -127,43 +117,17 @@ async def oauth2callback_google( Returns: dict: A dictionary containing a success message. """ - state = request.query_params.get("state") - logger.debug(f"request state: {state}") - if not state: - raise HTTPException(status_code=400, detail="Invalid state parameter") - - state = Oauth2State.model_validate_json(state) - if state.sync_id is None: - raise HTTPException( - status_code=400, detail="Invalid state parameter. 
Unknown sync" - ) - + state_str = request.query_params.get("state") + state = parse_oauth2_state(state_str) logger.debug( f"Handling OAuth2 callback for user: {state.user_id} with state: {state}" ) - - try: - sync = await syncs_service.get_sync_by_id(state.sync_id) - except SyncNotFoundException as e: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=f"{e.message}" - ) - if ( - not sync - or not sync.state - or state.model_dump(exclude={"sync_id"}) != sync.state["state"] - ): - logger.error("Invalid state parameter") - raise HTTPException(status_code=400, detail="Invalid state parameter") - - if sync.user_id != state.user_id: - raise HTTPException(status_code=400, detail="Invalid user") - + sync = await syncs_service.get_from_oauth2_state(state) redirect_uri = f"{BASE_REDIRECT_URI}" flow = Flow.from_client_config( CLIENT_SECRETS_FILE_CONTENT, scopes=SCOPES, - state=state, + state=state_str, redirect_uri=redirect_uri, ) flow.fetch_token(authorization_response=str(request.url)) @@ -181,6 +145,6 @@ async def oauth2callback_google( state={}, email=user_email, ) - sync = await syncs_service.update_sync(state.sync_id, sync_user_input) + sync = await syncs_service.update_sync(sync.id, sync_user_input) logger.info(f"Google Drive sync created successfully for user: {state.user_id}") return HTMLResponse(successfullConnectionPage) diff --git a/backend/api/quivr_api/modules/sync/controller/notion_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/notion_sync_routes.py index 9aec1a3e32b9..baf4188d9b0f 100644 --- a/backend/api/quivr_api/modules/sync/controller/notion_sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/notion_sync_routes.py @@ -2,7 +2,7 @@ import os import requests -from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import HTMLResponse from notion_client import Client @@ -10,10 +10,9 @@ from quivr_api.logger 
import get_logger from quivr_api.middlewares.auth import AuthBearer, get_current_user from quivr_api.modules.dependencies import get_service -from quivr_api.modules.sync.dto.inputs import SyncCreateInput, SyncUpdateInput +from quivr_api.modules.sync.dto.inputs import SyncUpdateInput from quivr_api.modules.sync.service.sync_service import SyncsService -from quivr_api.modules.sync.utils.oauth2 import Oauth2State -from quivr_api.modules.sync.utils.sync_exceptions import SyncNotFoundException +from quivr_api.modules.sync.utils.oauth2 import parse_oauth2_state from quivr_api.modules.user.entity.user_identity import UserIdentity from .successfull_connection import successfullConnectionPage @@ -57,20 +56,11 @@ async def authorize_notion( dict: A dictionary containing the authorization URL. """ logger.debug(f"Authorizing Notion sync for user: {current_user.id}, name : {name}") - state_struct = Oauth2State(name=name, user_id=current_user.id) - state = state_struct.model_dump_json() - sync_user_input = SyncCreateInput( - name=name, - user_id=current_user.id, - provider="Notion", - credentials={}, - state={"state": state}, + state = await syncs_service.create_oauth2_state( + provider="Notion", name=name, user_id=current_user.id ) - sync = await syncs_service.create_sync_user(sync_user_input) - state_struct.sync_id = sync.id - state = state_struct.model_dump_json() # Finalize the state - authorize_url = str(NOTION_AUTH_URL) + f"&state={state}" + authorize_url = str(NOTION_AUTH_URL) + f"&state={state.model_dump_json()}" logger.debug( f"Generated authorization URL: {authorize_url} for user: {current_user.id}" ) @@ -92,38 +82,9 @@ async def oauth2callback_notion( dict: A dictionary containing a success message. 
""" code = request.query_params.get("code") - state = request.query_params.get("state") - if not state: - raise HTTPException(status_code=400, detail="Invalid state parameter") - - state = Oauth2State.model_validate_json(state) - - if state.sync_id is None: - raise HTTPException( - status_code=400, detail="Invalid state parameter. Unknown sync" - ) - - logger.debug( - f"Handling OAuth2 callback for user: {state.user_id} with state: {state}" - ) - - try: - sync = await syncs_service.get_sync_by_id(state.sync_id) - except SyncNotFoundException as e: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=f"{e.message}" - ) - if ( - not sync - or not sync.state - or state.model_dump(exclude={"sync_id"}) != sync.state["state"] - ): - logger.error("Invalid state parameter") - raise HTTPException(status_code=400, detail="Invalid state parameter") - - if sync.user_id != state.user_id: - raise HTTPException(status_code=400, detail="Invalid user") + state_str = request.query_params.get("state") + state = parse_oauth2_state(state_str) try: token_url = "https://api.notion.com/v1/oauth/token" diff --git a/backend/api/quivr_api/modules/sync/controller/sync_routes.py b/backend/api/quivr_api/modules/sync/controller/sync_routes.py index ccea41a4dc5f..b3434e5026a7 100644 --- a/backend/api/quivr_api/modules/sync/controller/sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/sync_routes.py @@ -113,7 +113,7 @@ async def get_user_syncs( List: A list of syncs for the user. 
""" logger.debug(f"Fetching user syncs for user: {current_user.id}") - return await syncs_service.get_syncs(current_user.id) + return await syncs_service.get_user_syncs(current_user.id) @sync_router.delete( diff --git a/backend/api/quivr_api/modules/sync/dto/outputs.py b/backend/api/quivr_api/modules/sync/dto/outputs.py index 5750e8f60137..3bdf004c2e36 100644 --- a/backend/api/quivr_api/modules/sync/dto/outputs.py +++ b/backend/api/quivr_api/modules/sync/dto/outputs.py @@ -28,3 +28,4 @@ class SyncsOutput(BaseModel): provider: SyncProvider state: dict | None credentials: dict | None + additional_data: dict | None diff --git a/backend/api/quivr_api/modules/sync/entity/sync_models.py b/backend/api/quivr_api/modules/sync/entity/sync_models.py index 8612dba21c9f..2189263fd307 100644 --- a/backend/api/quivr_api/modules/sync/entity/sync_models.py +++ b/backend/api/quivr_api/modules/sync/entity/sync_models.py @@ -54,6 +54,7 @@ class Syncs(SQLModel, table=True): __tablename__ = "syncs_user" # type: ignore id: int | None = Field(default=None, primary_key=True) + email: str | None = Field(default=None) user_id: UUID = Field(foreign_key="users.id", nullable=False) name: str provider: str @@ -70,9 +71,10 @@ def to_dto(self) -> SyncsOutput: return SyncsOutput( id=self.id, user_id=self.user_id, - provider=SyncProvider(self.provider), + provider=SyncProvider(self.provider.lower()), credentials=self.credentials, state=self.state, + additional_data=self.additional_data, ) diff --git a/backend/api/quivr_api/modules/sync/repository/sync_repository.py b/backend/api/quivr_api/modules/sync/repository/sync_repository.py index d81894bb2a67..21c66e1d846c 100644 --- a/backend/api/quivr_api/modules/sync/repository/sync_repository.py +++ b/backend/api/quivr_api/modules/sync/repository/sync_repository.py @@ -101,15 +101,11 @@ async def get_syncs(self, user_id: UUID, sync_id: int | None = None): user_id, sync_id, ) - query = select(Syncs).where(Syncs.id == sync_id).where(Syncs.user_id == 
user_id) + query = select(Syncs).where(Syncs.user_id == user_id) + if sync_id is not None: + query = query.where(Syncs.id == sync_id) result = await self.session.exec(query) - sync = result.first() - if not sync: - logger.error( - f"No sync user found for sync_id: {sync_id}", - ) - raise SyncNotFoundException() - return sync + return result.all() async def get_sync_user_by_state(self, state: dict) -> Syncs: """ diff --git a/backend/api/quivr_api/modules/sync/service/sync_service.py b/backend/api/quivr_api/modules/sync/service/sync_service.py index 4ec7259ac67a..7fee5ed31a5a 100644 --- a/backend/api/quivr_api/modules/sync/service/sync_service.py +++ b/backend/api/quivr_api/modules/sync/service/sync_service.py @@ -1,5 +1,8 @@ +from typing import Any from uuid import UUID +from fastapi import HTTPException + from quivr_api.logger import get_logger from quivr_api.modules.dependencies import BaseService from quivr_api.modules.sync.dto.inputs import ( @@ -8,6 +11,7 @@ ) from quivr_api.modules.sync.dto.outputs import SyncsOutput from quivr_api.modules.sync.repository.sync_repository import SyncsRepository +from quivr_api.modules.sync.utils.oauth2 import Oauth2BaseState, Oauth2State logger = get_logger(__name__) @@ -22,8 +26,8 @@ async def create_sync_user(self, sync_user_input: SyncCreateInput) -> SyncsOutpu sync = await self.repository.create_sync(sync_user_input) return sync.to_dto() - async def get_syncs(self, user_id: UUID, sync_id: int | None = None): - return await self.repository.get_syncs(user_id, sync_id) + async def get_user_syncs(self, user_id: UUID, sync_id: int | None = None): + return await self.repository.get_syncs(user_id=user_id, sync_id=sync_id) async def delete_sync(self, sync_id: int, user_id: UUID): await self.repository.delete_sync(sync_id, user_id) @@ -31,6 +35,42 @@ async def delete_sync(self, sync_id: int, user_id: UUID): async def get_sync_by_id(self, sync_id: int): return await self.repository.get_sync_id(sync_id) + async def 
get_from_oauth2_state(self, state: Oauth2State) -> SyncsOutput: + assert state.sync_id, "state should have associated sync_id" + sync = await self.get_sync_by_id(state.sync_id) + + # TODO: redo these exceptions + if ( + not sync + or not sync.state + or state.model_dump_json(exclude={"sync_id"}) != sync.state["state"] + ): + logger.error("Invalid state parameter") + raise HTTPException(status_code=400, detail="Invalid state parameter") + if sync.user_id != state.user_id: + raise HTTPException(status_code=400, detail="Invalid user") + return sync.to_dto() + + async def create_oauth2_state( + self, + provider: str, + name: str, + user_id: UUID, + additional_data: dict[str, Any] = {}, + ) -> Oauth2State: + state_struct = Oauth2BaseState(name=name, user_id=user_id) + state = state_struct.model_dump_json() + sync_user_input = SyncCreateInput( + name=name, + user_id=user_id, + provider=provider, + credentials={}, + state={"state": state}, + additional_data=additional_data, + ) + sync = await self.create_sync_user(sync_user_input) + return Oauth2State(sync_id=sync.id, **state_struct.model_dump()) + async def update_sync( self, sync_id: int, sync_user_input: SyncUpdateInput ) -> SyncsOutput: diff --git a/backend/api/quivr_api/modules/sync/utils/oauth2.py b/backend/api/quivr_api/modules/sync/utils/oauth2.py index cc8be15db630..e2344caf4225 100644 --- a/backend/api/quivr_api/modules/sync/utils/oauth2.py +++ b/backend/api/quivr_api/modules/sync/utils/oauth2.py @@ -1,9 +1,31 @@ from uuid import UUID +from fastapi import HTTPException, status from pydantic import BaseModel +from quivr_api.logger import get_logger -class Oauth2State(BaseModel): - sync_id: int | None = None +logger = get_logger(__name__) + + +class Oauth2BaseState(BaseModel): name: str user_id: UUID + + +class Oauth2State(Oauth2BaseState): + sync_id: int + + +def parse_oauth2_state(state_str: str | None) -> Oauth2State: + if not state_str: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, 
detail="Invalid state parameter" + ) + + state = Oauth2State.model_validate_json(state_str) + if state.sync_id is None: + raise HTTPException( + status_code=400, detail="Invalid state parameter. Unknown sync" + ) + return state From c5d372ba8016283bb29e48a6dee75a7da12742f1 Mon Sep 17 00:00:00 2001 From: aminediro Date: Wed, 18 Sep 2024 16:06:21 +0200 Subject: [PATCH 06/63] ruff formatter vscode --- .vscode/settings.json | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 700a8799b95a..9cd662a50680 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,8 +1,9 @@ { "editor.codeActionsOnSave": { "source.organizeImports": "explicit", + "source.organizeImports.ruff": "explicit", "source.fixAll": "explicit", - "source.unusedImports": "explicit", + "source.unusedImports": "explicit" }, "editor.formatOnSave": true, "editor.formatOnSaveMode": "file", @@ -16,7 +17,6 @@ "**/.docusaurus/": true, "**/node_modules/": true }, - "json.sortOnSave.enable": true, "[python]": { "editor.defaultFormatter": "charliermarsh.ruff", "editor.formatOnSave": true, @@ -25,19 +25,8 @@ "source.fixAll": "explicit" } }, - "python.formatting.provider": "black", - "python.analysis.extraPaths": [ - "./backend" - ], - "python.sortImports.path": "isort", - "python.linting.mypyEnabled": true, + "python.analysis.extraPaths": ["./backend"], "python.defaultInterpreterPath": "python3", - "python.linting.enabled": true, - "python.linting.flake8Enabled": true, - "python.linting.pycodestyleEnabled": true, - "python.linting.pylintEnabled": true, - "python.linting.pycodestyleCategorySeverity.W": "Error", - "python.linting.flake8CategorySeverity.W": "Error", "python.testing.pytestArgs": [ "-v", "--color=yes", @@ -54,4 +43,4 @@ "reportUnusedImport": "warning", "reportGeneralTypeIssues": "warning" } -} \ No newline at end of file +} From b0e76e307ebd5af6b137ddd53d7d689e460abb6f Mon Sep 17 00:00:00 2001 From: aminediro 
Date: Wed, 18 Sep 2024 18:22:27 +0200 Subject: [PATCH 07/63] syncs + cleanupdb --- backend/api/quivr_api/__init__.py | 3 +- .../modules/knowledge/dto/outputs.py | 23 +++ .../modules/knowledge/entity/knowledge.py | 13 +- .../modules/sync/entity/sync_models.py | 23 ++- .../sync/repository/sync_repository.py | 22 +- .../modules/sync/tests/test_sync_service.py | 4 +- .../modules/sync/tests/test_syncutils.py | 188 +++++++++--------- backend/core/quivr_core/models.py | 5 +- ...20240905153004_knowledge-folders copy.sql} | 0 .../20240918180003_knowledge-sync.sql | 24 +++ backend/supabase/seed.sql | 14 -- 11 files changed, 189 insertions(+), 130 deletions(-) rename backend/supabase/migrations/{20240905153004_knowledge-folders.sql => 20240905153004_knowledge-folders copy.sql} (100%) create mode 100644 backend/supabase/migrations/20240918180003_knowledge-sync.sql diff --git a/backend/api/quivr_api/__init__.py b/backend/api/quivr_api/__init__.py index f25c4b4b9308..c75182e5d2d7 100644 --- a/backend/api/quivr_api/__init__.py +++ b/backend/api/quivr_api/__init__.py @@ -2,7 +2,7 @@ from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB from .modules.chat.entity.chat import Chat, ChatHistory -from .modules.sync.entity.sync_models import NotionSyncFile +from .modules.sync.entity.sync_models import NotionSyncFile, Sync from .modules.user.entity.user_identity import User __all__ = [ @@ -12,4 +12,5 @@ "NotionSyncFile", "KnowledgeDB", "Brain", + "Sync", ] diff --git a/backend/api/quivr_api/modules/knowledge/dto/outputs.py b/backend/api/quivr_api/modules/knowledge/dto/outputs.py index 20218dfce3e6..3e804be9e6cc 100644 --- a/backend/api/quivr_api/modules/knowledge/dto/outputs.py +++ b/backend/api/quivr_api/modules/knowledge/dto/outputs.py @@ -1,9 +1,32 @@ +from datetime import datetime +from typing import Any, Dict, List, Optional, Self from uuid import UUID from pydantic import BaseModel +from quivr_core.models import KnowledgeStatus class 
DeleteKnowledgeResponse(BaseModel): file_name: str | None = None status: str = "DELETED" knowledge_id: UUID + + +class KnowledgeOut(BaseModel): + id: UUID + file_size: int = 0 + status: KnowledgeStatus + file_name: Optional[str] = None + url: Optional[str] = None + extension: str = ".txt" + is_folder: bool = False + updated_at: datetime + created_at: datetime + source: Optional[str] = None + source_link: Optional[str] = None + file_sha1: Optional[str] = None + metadata: Optional[Dict[str, str]] = None + user_id: UUID + brains: List[Dict[str, Any]] + parent: Optional[Self] + children: Optional[list[Self]] diff --git a/backend/api/quivr_api/modules/knowledge/entity/knowledge.py b/backend/api/quivr_api/modules/knowledge/entity/knowledge.py index d890ee42d1c7..5fb2bd70449c 100644 --- a/backend/api/quivr_api/modules/knowledge/entity/knowledge.py +++ b/backend/api/quivr_api/modules/knowledge/entity/knowledge.py @@ -1,6 +1,6 @@ from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Self from uuid import UUID from pydantic import BaseModel @@ -11,6 +11,7 @@ from sqlmodel import Field, Relationship, SQLModel from quivr_api.modules.knowledge.entity.knowledge_brain import KnowledgeBrain +from quivr_api.modules.sync.entity.sync_models import Sync class KnowledgeSource(str, Enum): @@ -37,8 +38,8 @@ class Knowledge(BaseModel): metadata: Optional[Dict[str, str]] = None user_id: UUID brains: List[Dict[str, Any]] - parent: Optional["Knowledge"] - children: Optional[list["Knowledge"]] + parent: Optional[Self] + children: Optional[List[Self]] class KnowledgeUpdate(BaseModel): @@ -113,6 +114,12 @@ class KnowledgeDB(AsyncAttrs, SQLModel, table=True): "cascade": "all, delete-orphan", }, ) + sync_id: int | None = Field( + default=None, foreign_key="syncs.id", ondelete="CASCADE" + ) + sync: Sync | None = Relationship( + back_populates="knowledges", sa_relationship_kwargs={"lazy": "joined"} + ) # 
TODO: nested folder search async def to_dto(self, get_children: bool = True) -> Knowledge: diff --git a/backend/api/quivr_api/modules/sync/entity/sync_models.py b/backend/api/quivr_api/modules/sync/entity/sync_models.py index 2189263fd307..9c6cf437bf32 100644 --- a/backend/api/quivr_api/modules/sync/entity/sync_models.py +++ b/backend/api/quivr_api/modules/sync/entity/sync_models.py @@ -2,7 +2,7 @@ import io from dataclasses import dataclass from datetime import datetime -from typing import Dict, Optional +from typing import Dict, List, Optional from uuid import UUID from pydantic import BaseModel @@ -50,8 +50,8 @@ class SyncFile(BaseModel): type: Optional[str] = None -class Syncs(SQLModel, table=True): - __tablename__ = "syncs_user" # type: ignore +class Sync(SQLModel, table=True): + __tablename__ = "syncs" # type: ignore id: int | None = Field(default=None, primary_key=True) email: str | None = Field(default=None) @@ -62,9 +62,26 @@ class Syncs(SQLModel, table=True): default=None, sa_column=Column("credentials", JSON) ) state: Dict[str, str] | None = Field(default=None, sa_column=Column("state", JSON)) + created_at: datetime | None = Field( + default=None, + sa_column=Column( + TIMESTAMP(timezone=False), + server_default=text("CURRENT_TIMESTAMP"), + ), + ) + updated_at: datetime | None = Field( + default=None, + sa_column=Column( + TIMESTAMP(timezone=False), + server_default=text("CURRENT_TIMESTAMP"), + onupdate=datetime.utcnow, + ), + ) + last_synced_at: datetime | None = Field(default=None) additional_data: dict | None = Field( default=None, sa_column=Column("additional_data", JSON) ) + knowledges: List["KnowledgeDB"] | None = Relationship(back_populates="sync") def to_dto(self) -> SyncsOutput: assert self.id, "can't create create output if sync isn't inserted" diff --git a/backend/api/quivr_api/modules/sync/repository/sync_repository.py b/backend/api/quivr_api/modules/sync/repository/sync_repository.py index 21c66e1d846c..bda27d9ecf6e 100644 --- 
a/backend/api/quivr_api/modules/sync/repository/sync_repository.py +++ b/backend/api/quivr_api/modules/sync/repository/sync_repository.py @@ -9,7 +9,7 @@ from quivr_api.modules.dependencies import BaseRepository, get_supabase_client from quivr_api.modules.sync.dto.inputs import SyncCreateInput, SyncUpdateInput from quivr_api.modules.sync.dto.outputs import SyncProvider -from quivr_api.modules.sync.entity.sync_models import SyncFile, Syncs +from quivr_api.modules.sync.entity.sync_models import Sync, SyncFile from quivr_api.modules.sync.repository.notion_repository import NotionRepository from quivr_api.modules.sync.service.sync_notion import SyncNotionService from quivr_api.modules.sync.utils.sync import ( @@ -48,7 +48,7 @@ def __init__(self, session: AsyncSession): async def create_sync( self, sync_user_input: SyncCreateInput, - ) -> Syncs: + ) -> Sync: """ Create a new sync user in the database. @@ -59,7 +59,7 @@ async def create_sync( """ logger.info("Creating sync user with input: %s", sync_user_input) try: - sync = Syncs.model_validate(sync_user_input.model_dump()) + sync = Sync.model_validate(sync_user_input.model_dump()) self.session.add(sync) await self.session.commit() await self.session.refresh(sync) @@ -71,11 +71,11 @@ async def create_sync( await self.session.rollback() raise - async def get_sync_id(self, sync_id: int) -> Syncs: + async def get_sync_id(self, sync_id: int) -> Sync: """ Retrieve sync users from the database. 
""" - query = select(Syncs).where(Syncs.id == sync_id) + query = select(Sync).where(Sync.id == sync_id) result = await self.session.exec(query) sync = result.first() if not sync: @@ -101,13 +101,13 @@ async def get_syncs(self, user_id: UUID, sync_id: int | None = None): user_id, sync_id, ) - query = select(Syncs).where(Syncs.user_id == user_id) + query = select(Sync).where(Sync.user_id == user_id) if sync_id is not None: - query = query.where(Syncs.id == sync_id) + query = query.where(Sync.id == sync_id) result = await self.session.exec(query) return result.all() - async def get_sync_user_by_state(self, state: dict) -> Syncs: + async def get_sync_user_by_state(self, state: dict) -> Sync: """ Retrieve a sync user by their state. @@ -119,7 +119,7 @@ async def get_sync_user_by_state(self, state: dict) -> Syncs: """ logger.info("Getting sync user by state: %s", state) - query = select(Syncs).where(Syncs.state == state) + query = select(Sync).where(Sync.state == state) result = await self.session.exec(query) sync = result.first() if not sync: @@ -133,12 +133,12 @@ async def delete_sync(self, sync_id: int, user_id: UUID): "Deleting sync user with sync_id: %s, user_id: %s", sync_id, user_id ) await self.session.execute( - delete(Syncs).where(Syncs.id == sync_id).where(Syncs.user_id == user_id) + delete(Sync).where(Sync.id == sync_id).where(Sync.user_id == user_id) ) logger.info("Sync user deleted successfully") async def update_sync( - self, sync: Syncs, sync_input: SyncUpdateInput | dict[str, Any] + self, sync: Sync, sync_input: SyncUpdateInput | dict[str, Any] ): logger.debug( "Updating sync user with user_id: %s, state: %s, input: %s", diff --git a/backend/api/quivr_api/modules/sync/tests/test_sync_service.py b/backend/api/quivr_api/modules/sync/tests/test_sync_service.py index cf398efb3d82..19fc5f041f7d 100644 --- a/backend/api/quivr_api/modules/sync/tests/test_sync_service.py +++ b/backend/api/quivr_api/modules/sync/tests/test_sync_service.py @@ -3,7 +3,7 @@ from 
sqlmodel import select from quivr_api.modules.sync.dto.outputs import SyncProvider -from quivr_api.modules.sync.entity.sync_models import Syncs +from quivr_api.modules.sync.entity.sync_models import Sync from quivr_api.modules.sync.repository.sync_repository import SyncsRepository from quivr_api.modules.sync.service.sync_service import SyncsService from quivr_api.modules.user.entity.user_identity import User @@ -21,7 +21,7 @@ async def user(session): async def test_sync(session, user): assert user.id - sync = Syncs( + sync = Sync( user_id=user.id, name="test_sync", provider=SyncProvider.GOOGLE, diff --git a/backend/api/quivr_api/modules/sync/tests/test_syncutils.py b/backend/api/quivr_api/modules/sync/tests/test_syncutils.py index 95d1717a2c65..5685965fddf5 100644 --- a/backend/api/quivr_api/modules/sync/tests/test_syncutils.py +++ b/backend/api/quivr_api/modules/sync/tests/test_syncutils.py @@ -57,100 +57,100 @@ # ] == filter_on_supported_files(files, existing_files) -def test_should_download_file_no_sync_time_not_folder(): - datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ" - file_not_folder = SyncFile( - id="1", - name="file_name", - is_folder=False, - last_modified=datetime.now().strftime(datetime_format), - mime_type="txt", - web_view_link="link", - ) - assert should_download_file( - file=file_not_folder, - last_updated_sync_active=None, - provider_name="google", - datetime_format=datetime_format, - ) - - -def test_should_download_file_no_sync_time_folder(): - datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ" - file_not_folder = SyncFile( - id="1", - name="file_name", - is_folder=True, - last_modified=datetime.now().strftime(datetime_format), - mime_type="txt", - web_view_link="link", - ) - assert not should_download_file( - file=file_not_folder, - last_updated_sync_active=None, - provider_name="google", - datetime_format=datetime_format, - ) - - -def test_should_download_file_notiondb(): - datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ" - file_not_folder = SyncFile( 
- id="1", - name="file_name", - is_folder=False, - last_modified=datetime.now().strftime(datetime_format), - mime_type="db", - web_view_link="link", - ) - - assert not should_download_file( - file=file_not_folder, - last_updated_sync_active=(datetime.now() - timedelta(hours=5)).astimezone( - timezone.utc - ), - provider_name="notion", - datetime_format=datetime_format, - ) - - -def test_should_download_file_not_notiondb(): - datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ" - file_not_folder = SyncFile( - id="1", - name="file_name", - is_folder=False, - last_modified=datetime.now().strftime(datetime_format), - mime_type="md", - web_view_link="link", - ) - - assert should_download_file( - file=file_not_folder, - last_updated_sync_active=None, - provider_name="notion", - datetime_format=datetime_format, - ) - - -def test_should_download_file_lastsynctime_before(): - datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ" - file_not_folder = SyncFile( - id="1", - name="file_name", - is_folder=False, - last_modified=datetime.now().strftime(datetime_format), - mime_type="txt", - web_view_link="link", - ) - last_sync_time = (datetime.now() - timedelta(hours=5)).astimezone(timezone.utc) - - assert should_download_file( - file=file_not_folder, - last_updated_sync_active=last_sync_time, - provider_name="google", - datetime_format=datetime_format, - ) +# def test_should_download_file_no_sync_time_not_folder(): +# datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ" +# file_not_folder = SyncFile( +# id="1", +# name="file_name", +# is_folder=False, +# last_modified=datetime.now().strftime(datetime_format), +# mime_type="txt", +# web_view_link="link", +# ) +# assert should_download_file( +# file=file_not_folder, +# last_updated_sync_active=None, +# provider_name="google", +# datetime_format=datetime_format, +# ) + + +# def test_should_download_file_no_sync_time_folder(): +# datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ" +# file_not_folder = SyncFile( +# id="1", +# name="file_name", +# 
is_folder=True, +# last_modified=datetime.now().strftime(datetime_format), +# mime_type="txt", +# web_view_link="link", +# ) +# assert not should_download_file( +# file=file_not_folder, +# last_updated_sync_active=None, +# provider_name="google", +# datetime_format=datetime_format, +# ) + + +# def test_should_download_file_notiondb(): +# datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ" +# file_not_folder = SyncFile( +# id="1", +# name="file_name", +# is_folder=False, +# last_modified=datetime.now().strftime(datetime_format), +# mime_type="db", +# web_view_link="link", +# ) + +# assert not should_download_file( +# file=file_not_folder, +# last_updated_sync_active=(datetime.now() - timedelta(hours=5)).astimezone( +# timezone.utc +# ), +# provider_name="notion", +# datetime_format=datetime_format, +# ) + + +# def test_should_download_file_not_notiondb(): +# datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ" +# file_not_folder = SyncFile( +# id="1", +# name="file_name", +# is_folder=False, +# last_modified=datetime.now().strftime(datetime_format), +# mime_type="md", +# web_view_link="link", +# ) + +# assert should_download_file( +# file=file_not_folder, +# last_updated_sync_active=None, +# provider_name="notion", +# datetime_format=datetime_format, +# ) + + +# def test_should_download_file_lastsynctime_before(): +# datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ" +# file_not_folder = SyncFile( +# id="1", +# name="file_name", +# is_folder=False, +# last_modified=datetime.now().strftime(datetime_format), +# mime_type="txt", +# web_view_link="link", +# ) +# last_sync_time = (datetime.now() - timedelta(hours=5)).astimezone(timezone.utc) + +# assert should_download_file( +# file=file_not_folder, +# last_updated_sync_active=last_sync_time, +# provider_name="google", +# datetime_format=datetime_format, +# ) # def test_should_download_file_lastsynctime_after(): diff --git a/backend/core/quivr_core/models.py b/backend/core/quivr_core/models.py index 8ebf2bbe23b8..dd0c08ce6230 100644 
--- a/backend/core/quivr_core/models.py +++ b/backend/core/quivr_core/models.py @@ -39,10 +39,11 @@ class ChatMessage(BaseModelV1): class KnowledgeStatus(str, Enum): - PROCESSING = "PROCESSING" - UPLOADED = "UPLOADED" ERROR = "ERROR" RESERVED = "RESERVED" + PROCESSING = "PROCESSING" + PROCESSED = "PROCESSED" + UPLOADED = "UPLOADED" class Source(BaseModel): diff --git a/backend/supabase/migrations/20240905153004_knowledge-folders.sql b/backend/supabase/migrations/20240905153004_knowledge-folders copy.sql similarity index 100% rename from backend/supabase/migrations/20240905153004_knowledge-folders.sql rename to backend/supabase/migrations/20240905153004_knowledge-folders copy.sql diff --git a/backend/supabase/migrations/20240918180003_knowledge-sync.sql b/backend/supabase/migrations/20240918180003_knowledge-sync.sql new file mode 100644 index 000000000000..73d3eb746475 --- /dev/null +++ b/backend/supabase/migrations/20240918180003_knowledge-sync.sql @@ -0,0 +1,24 @@ +-- Renamed syncs +ALTER TABLE syncs_user RENAME TO syncs; + +-- Add column sync in knowledge +ALTER TABLE "public"."knowledge" +ADD COLUMN "sync_id" INTEGER; + +ALTER TABLE "public"."knowledge" +ADD CONSTRAINT "public_knowledge_sync_id_fkey" FOREIGN KEY (sync_id) REFERENCES syncs(id) ON DELETE CASCADE; + +-- Add columns syncs +alter table "public"."syncs" +add column "created_at" timestamp with time zone default now(); + +alter table "public"."syncs" +add column "updated_at" timestamp with time zone default now(); + +alter table "public"."syncs" +add column "last_synced_at" timestamp with time zone; + + +-- Drop files +DROP TABLE IF EXISTS "public"."syncs_active" CASCADE; +DROP TABLE IF EXISTS "public"."syncs_files" CASCADE; diff --git a/backend/supabase/seed.sql b/backend/supabase/seed.sql index b15761bf6421..a0b4ae777eba 100644 --- a/backend/supabase/seed.sql +++ b/backend/supabase/seed.sql @@ -330,20 +330,6 @@ SELECT pg_catalog.setval('"public"."integrations_user_id_seq"', 6, true); SELECT 
pg_catalog.setval('"public"."product_to_features_id_seq"', 1, false); --- --- Name: syncs_active_id_seq; Type: SEQUENCE SET; Schema: public; Owner: postgres --- - -SELECT pg_catalog.setval('"public"."syncs_active_id_seq"', 1, false); - - --- --- Name: syncs_files_id_seq; Type: SEQUENCE SET; Schema: public; Owner: postgres --- - -SELECT pg_catalog.setval('"public"."syncs_files_id_seq"', 1, false); - - -- -- Name: syncs_user_id_seq; Type: SEQUENCE SET; Schema: public; Owner: postgres -- From 339c6b69a0d73d71c4cf916d1376e8f630c57dc6 Mon Sep 17 00:00:00 2001 From: aminediro Date: Thu, 19 Sep 2024 19:15:19 +0200 Subject: [PATCH 08/63] sync list return knowledge info --- .../modules/knowledge/entity/knowledge.py | 16 +- .../knowledge/repository/knowledges.py | 12 +- .../knowledge/service/knowledge_service.py | 11 + .../knowledge/tests/test_knowledge_entity.py | 26 ++- .../modules/sync/controller/sync_routes.py | 81 +++++-- .../modules/sync/entity/sync_models.py | 9 +- .../modules/sync/repository/sync_files.py | 126 ----------- .../sync/repository/sync_repository.py | 29 ++- .../modules/sync/service/sync_service.py | 2 +- .../quivr_api/modules/sync/tests/conftest.py | 8 +- .../sync/tests/test_sync_controller.py | 208 ++++++++++++++++++ .../api/quivr_api/modules/sync/utils/sync.py | 66 +++--- .../quivr_api/modules/sync/utils/syncutils.py | 4 +- .../20240918180003_knowledge-sync.sql | 14 +- 14 files changed, 404 insertions(+), 208 deletions(-) delete mode 100644 backend/api/quivr_api/modules/sync/repository/sync_files.py create mode 100644 backend/api/quivr_api/modules/sync/tests/test_sync_controller.py diff --git a/backend/api/quivr_api/modules/knowledge/entity/knowledge.py b/backend/api/quivr_api/modules/knowledge/entity/knowledge.py index 5fb2bd70449c..e387c5fe16d3 100644 --- a/backend/api/quivr_api/modules/knowledge/entity/knowledge.py +++ b/backend/api/quivr_api/modules/knowledge/entity/knowledge.py @@ -17,15 +17,18 @@ class KnowledgeSource(str, Enum): LOCAL = "local" 
WEB = "web" - GDRIVE = "google drive" + NOTETAKER = "notetaker" + GOOGLE = "google" + AZURE = "azure" DROPBOX = "dropbox" - SHAREPOINT = "sharepoint" + NOTION = "notion" + GITHUB = "github" class Knowledge(BaseModel): - id: UUID + id: Optional[UUID] file_size: int = 0 - status: KnowledgeStatus + status: Optional[KnowledgeStatus] file_name: Optional[str] = None url: Optional[str] = None extension: str = ".txt" @@ -40,6 +43,8 @@ class Knowledge(BaseModel): brains: List[Dict[str, Any]] parent: Optional[Self] children: Optional[List[Self]] + sync_id: int | None + sync_file_id: str | None class KnowledgeUpdate(BaseModel): @@ -120,6 +125,7 @@ class KnowledgeDB(AsyncAttrs, SQLModel, table=True): sync: Sync | None = Relationship( back_populates="knowledges", sa_relationship_kwargs={"lazy": "joined"} ) + sync_file_id: str | None = Field(default=None) # TODO: nested folder search async def to_dto(self, get_children: bool = True) -> Knowledge: @@ -154,4 +160,6 @@ async def to_dto(self, get_children: bool = True) -> Knowledge: parent=parent, children=[await c.to_dto(get_children=False) for c in children], user_id=self.user_id, + sync_id=self.sync_id, + sync_file_id=self.sync_file_id, ) diff --git a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py index 427b3be063f2..3a2c4642fae8 100644 --- a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py +++ b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py @@ -1,4 +1,4 @@ -from typing import Any, Sequence +from typing import Any, List, Sequence from uuid import UUID from fastapi import HTTPException @@ -153,6 +153,16 @@ async def get_knowledge_by_sync_id(self, sync_id: int) -> KnowledgeDB: return knowledge + async def get_all_knowledge_sync_user( + self, sync_id: int, user_id: UUID | None = None + ) -> List[KnowledgeDB]: + query = select(KnowledgeDB).where(KnowledgeDB.sync_id == sync_id) + if user_id: + query = 
query.where(KnowledgeDB.user_id == user_id) + + result = await self.session.exec(query) + return list(result.all()) + async def get_knowledge_by_file_name_brain_id( self, file_name: str, brain_id: UUID ) -> KnowledgeDB: diff --git a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py index 4c7b48571026..deef902b0e59 100644 --- a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py @@ -68,6 +68,17 @@ async def get_knowledge_storage_path( except NoResultFound: raise FileNotFoundError(f"No knowledge for file_name: {file_name}") + async def map_syncs_knowledge_user( + self, sync_id: int, user_id: UUID + ) -> dict[str, Knowledge]: + list_kms = await self.repository.get_all_knowledge_sync_user( + sync_id=sync_id, user_id=user_id + ) + return { + k.sync_file_id: k + for k in await asyncio.gather(*[k.to_dto() for k in list_kms]) + } + async def list_knowledge( self, knowledge_id: UUID | None, user_id: UUID | None = None ) -> list[KnowledgeDB]: diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py index 7376559ebc39..50e996028f63 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py @@ -9,6 +9,8 @@ from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB +from quivr_api.modules.sync.dto.outputs import SyncProvider +from quivr_api.modules.sync.entity.sync_models import Sync from quivr_api.modules.user.entity.user_identity import User TestData = Tuple[Brain, List[KnowledgeDB]] @@ -38,6 +40,22 @@ async def user(session): return user_1 +@pytest_asyncio.fixture(scope="function") +async def sync(session: 
AsyncSession, user: User) -> User: + sync = Sync( + name="test_sync", + email="test@test.com", + user_id=user.id, + credentials={"test": "test"}, + provider=SyncProvider.GOOGLE, + ) + + session.add(sync) + await session.commit() + await session.refresh(sync) + return sync + + @pytest_asyncio.fixture(scope="function") async def brain(session): brain_1 = Brain( @@ -175,7 +193,7 @@ async def test_knowledge_remove_folder_cascade( @pytest.mark.asyncio(loop_scope="session") -async def test_knowledge_dto(session, user, brain): +async def test_knowledge_dto(session, user, brain, sync): # add folder in brain folder = KnowledgeDB( file_name="folder_1", @@ -201,6 +219,8 @@ async def test_knowledge_dto(session, user, brain): user_id=user.id, brains=[brain], parent=folder, + sync_file_id="file1", + sync=sync, ) session.add(km) session.add(km) @@ -223,6 +243,10 @@ async def test_knowledge_dto(session, user, brain): assert km_dto.metadata == km.metadata_ # type: ignor assert km_dto.parent assert km_dto.parent.id == folder.id + # Syncs + + assert km_dto.sync_id == km.sync_id + assert km_dto.sync_file_id == km.sync_file_id folder_dto = await folder.to_dto() assert folder_dto.brains[0] == brain.model_dump() diff --git a/backend/api/quivr_api/modules/sync/controller/sync_routes.py b/backend/api/quivr_api/modules/sync/controller/sync_routes.py index b3434e5026a7..bf4102ac7383 100644 --- a/backend/api/quivr_api/modules/sync/controller/sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/sync_routes.py @@ -1,11 +1,14 @@ +import asyncio import os -from typing import List +from typing import List, Tuple from fastapi import APIRouter, Depends, status from quivr_api.logger import get_logger from quivr_api.middlewares.auth import AuthBearer, get_current_user from quivr_api.modules.dependencies import get_service +from quivr_api.modules.knowledge.entity.knowledge import Knowledge +from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService from 
quivr_api.modules.notification.service.notification_service import ( NotificationService, ) @@ -16,6 +19,7 @@ from quivr_api.modules.sync.controller.notion_sync_routes import notion_sync_router from quivr_api.modules.sync.dto import SyncsDescription from quivr_api.modules.sync.dto.outputs import AuthMethodEnum +from quivr_api.modules.sync.entity.sync_models import SyncFile from quivr_api.modules.sync.service.sync_service import SyncsService from quivr_api.modules.user.entity.user_identity import UserIdentity @@ -28,7 +32,8 @@ logger = get_logger(__name__) # Initialize sync service -syncs_service_dep = get_service(SyncsService) +get_sync_service = get_service(SyncsService) +get_knowledge_service = get_service(KnowledgeService) # Initialize API router @@ -80,7 +85,7 @@ dependencies=[Depends(AuthBearer())], tags=["Sync"], ) -async def get_syncs(current_user: UserIdentity = Depends(get_current_user)): +async def get_all_sync_typs(current_user: UserIdentity = Depends(get_current_user)): """ Get all available sync descriptions. @@ -101,7 +106,7 @@ async def get_syncs(current_user: UserIdentity = Depends(get_current_user)): ) async def get_user_syncs( current_user: UserIdentity = Depends(get_current_user), - syncs_service: SyncsService = Depends(syncs_service_dep), + syncs_service: SyncsService = Depends(get_sync_service), ): """ Get syncs for the current user. @@ -125,7 +130,7 @@ async def get_user_syncs( async def delete_user_sync( sync_id: int, current_user: UserIdentity = Depends(get_current_user), - syncs_service: SyncsService = Depends(syncs_service_dep), + syncs_service: SyncsService = Depends(get_sync_service), ): """ Delete a sync for the current user. 
@@ -146,14 +151,15 @@ async def delete_user_sync( @sync_router.get( "/sync/{sync_id}/files", - dependencies=[Depends(AuthBearer())], + response_model=List[Knowledge] | None, tags=["Sync"], ) -async def get_files_folder_user_sync( - user_sync_id: int, +async def list_sync_files( + sync_id: int, folder_id: str | None = None, current_user: UserIdentity = Depends(get_current_user), - syncs_service: SyncsService = Depends(syncs_service_dep), + syncs_service: SyncsService = Depends(get_sync_service), + knowledge_service: KnowledgeService = Depends(get_knowledge_service), ): """ Get files for an active sync. @@ -166,11 +172,52 @@ async def get_files_folder_user_sync( Returns: SyncsActive: The active sync data. """ - logger.debug( - f"Fetching files for user sync: {user_sync_id} for user: {current_user.id}" - ) - return await syncs_service.get_files_folder_user_sync( - user_sync_id, - current_user.id, - folder_id, - ) + logger.debug(f"Fetching files for user sync: {sync_id} for user: {current_user.id}") + + # TODO: check to see if this is inefficient + # Gets knowledge for each call to list the files, + # The logic is that getting from DB will be faster than provider repsonse ? 
+ # NOTE: asyncio.gather didn't correcly typecheck + async def fetch_data() -> Tuple[dict[str, Knowledge], List[SyncFile] | None]: + map_knowledges_task = knowledge_service.map_syncs_knowledge_user( + sync_id=sync_id, user_id=current_user.id + ) + sync_files_task = syncs_service.get_files_folder_user_sync( + sync_id, + current_user.id, + folder_id, + ) + return await asyncio.gather(map_knowledges_task, sync_files_task) + + sync = await syncs_service.get_sync_by_id(sync_id=sync_id) + map_knowledges, sync_files = await fetch_data() + if not sync_files: + return None + + kms = [] + for file in sync_files: + existing_km = map_knowledges.get(file.id) + if existing_km: + kms.append(existing_km) + else: + kms.append( + Knowledge( + id=None, + file_name=file.name, + is_folder=file.is_folder, + extension=file.extension, + source=sync.provider, + source_link=file.web_view_link, + user_id=current_user.id, + brains=[], + parent=None, + children=None, + status=None, # TODO: Handle a sync not added status + # TODO: retrieve created at from sync provider + created_at=file.last_modified_at, + updated_at=file.last_modified_at, + sync_id=sync_id, + sync_file_id=file.id, + ) + ) + return kms diff --git a/backend/api/quivr_api/modules/sync/entity/sync_models.py b/backend/api/quivr_api/modules/sync/entity/sync_models.py index 9c6cf437bf32..e467a21db5b0 100644 --- a/backend/api/quivr_api/modules/sync/entity/sync_models.py +++ b/backend/api/quivr_api/modules/sync/entity/sync_models.py @@ -40,11 +40,10 @@ class SyncFile(BaseModel): id: str name: str is_folder: bool - last_modified: str - mime_type: str + last_modified_at: Optional[datetime] + extension: str web_view_link: str size: Optional[int] = None - notification_id: UUID | None = None icon: Optional[str] = None parent_id: Optional[str] = None type: Optional[str] = None @@ -54,10 +53,10 @@ class Sync(SQLModel, table=True): __tablename__ = "syncs" # type: ignore id: int | None = Field(default=None, primary_key=True) - email: str | None 
= Field(default=None) - user_id: UUID = Field(foreign_key="users.id", nullable=False) name: str provider: str + email: str | None = Field(default=None) + user_id: UUID = Field(foreign_key="users.id", nullable=False) credentials: Dict[str, str] | None = Field( default=None, sa_column=Column("credentials", JSON) ) diff --git a/backend/api/quivr_api/modules/sync/repository/sync_files.py b/backend/api/quivr_api/modules/sync/repository/sync_files.py deleted file mode 100644 index 9fe5ed223270..000000000000 --- a/backend/api/quivr_api/modules/sync/repository/sync_files.py +++ /dev/null @@ -1,126 +0,0 @@ -from quivr_api.logger import get_logger -from quivr_api.modules.dependencies import get_supabase_client -from quivr_api.modules.sync.dto.inputs import SyncFileInput, SyncFileUpdateInput -from quivr_api.modules.sync.entity.sync_models import DBSyncFile, SyncFile, SyncsActive -from quivr_api.modules.sync.repository.sync_interfaces import SyncFileInterface - -logger = get_logger(__name__) - - -class SyncFilesRepository(SyncFileInterface): - def __init__(self): - """ - Initialize the SyncFiles class with a Supabase client. - """ - supabase_client = get_supabase_client() - self.db = supabase_client # type: ignore - logger.debug("Supabase client initialized") - - def create_sync_file(self, sync_file_input: SyncFileInput) -> DBSyncFile | None: - """ - Create a new sync file in the database. - - Args: - sync_file_input (SyncFileInput): The input data for creating a sync file. - - Returns: - SyncsFiles: The created sync file data. 
- """ - logger.info("Creating sync file with input: %s", sync_file_input) - response = ( - self.db.from_("syncs_files") - .insert( - { - "path": sync_file_input.path, - "syncs_active_id": sync_file_input.syncs_active_id, - "last_modified": sync_file_input.last_modified, - "brain_id": sync_file_input.brain_id, - } - ) - .execute() - ) - if response.data: - logger.info("Sync file created successfully: %s", response.data[0]) - return DBSyncFile(**response.data[0]) - logger.warning("Failed to create sync file") - - def get_sync_files(self, sync_active_id: int) -> list[DBSyncFile]: - """ - Retrieve sync files from the database. - - Args: - sync_active_id (int): The ID of the active sync. - - Returns: - list[SyncsFiles]: A list of sync files matching the criteria. - """ - logger.info("Retrieving sync files for sync_active_id: %s", sync_active_id) - response = ( - self.db.from_("syncs_files") - .select("*") - .eq("syncs_active_id", sync_active_id) - .execute() - ) - if response.data: - # logger.info("Sync files retrieved successfully: %s", response.data) - return [DBSyncFile(**file) for file in response.data] - logger.warning("No sync files found for sync_active_id: %s", sync_active_id) - return [] - - def update_sync_file(self, sync_file_id: int, sync_file_input: SyncFileUpdateInput): - """ - Update a sync file in the database. - - Args: - sync_file_id (int): The ID of the sync file. - sync_file_input (SyncFileUpdateInput): The input data for updating the sync file. 
- """ - logger.info( - "Updating sync file with sync_file_id: %s, input: %s", - sync_file_id, - sync_file_input, - ) - self.db.from_("syncs_files").update( - sync_file_input.model_dump(exclude_unset=True) - ).eq("id", sync_file_id).execute() - logger.info("Sync file updated successfully") - - def update_or_create_sync_file( - self, - file: SyncFile, - sync_active: SyncsActive, - previous_file: DBSyncFile | None, - supported: bool, - ) -> DBSyncFile | None: - if previous_file: - logger.debug(f"Upserting file {previous_file} in database.") - sync_file = self.update_sync_file( - previous_file.id, - SyncFileUpdateInput( - last_modified=file.last_modified, - supported=previous_file.supported or supported, - ), - ) - else: - logger.debug("Creating new file in database.") - sync_file = self.create_sync_file( - SyncFileInput( - path=file.name, - syncs_active_id=sync_active.id, - last_modified=file.last_modified, - brain_id=str(sync_active.brain_id), - supported=supported, - ) - ) - return sync_file - - def delete_sync_file(self, sync_file_id: int): - """ - Delete a sync file from the database. - - Args: - sync_file_id (int): The ID of the sync file. - """ - logger.info("Deleting sync file with sync_file_id: %s", sync_file_id) - self.db.from_("syncs_files").delete().eq("id", sync_file_id).execute() - logger.info("Sync file deleted successfully") diff --git a/backend/api/quivr_api/modules/sync/repository/sync_repository.py b/backend/api/quivr_api/modules/sync/repository/sync_repository.py index bda27d9ecf6e..234223adeb74 100644 --- a/backend/api/quivr_api/modules/sync/repository/sync_repository.py +++ b/backend/api/quivr_api/modules/sync/repository/sync_repository.py @@ -71,13 +71,17 @@ async def create_sync( await self.session.rollback() raise - async def get_sync_id(self, sync_id: int) -> Sync: + async def get_sync_id(self, sync_id: int, user_id: UUID | None = None) -> Sync: """ Retrieve sync users from the database. 
""" query = select(Sync).where(Sync.id == sync_id) + + if user_id: + query = query.where(Sync.user_id == user_id) result = await self.session.exec(query) sync = result.first() + if not sync: logger.error( f"No sync user found for sync_id: {sync_id}", @@ -105,7 +109,7 @@ async def get_syncs(self, user_id: UUID, sync_id: int | None = None): if sync_id is not None: query = query.where(Sync.id == sync_id) result = await self.session.exec(query) - return result.all() + return list(result.all()) async def get_sync_user_by_state(self, state: dict) -> Sync: """ @@ -133,7 +137,7 @@ async def delete_sync(self, sync_id: int, user_id: UUID): "Deleting sync user with sync_id: %s, user_id: %s", sync_id, user_id ) await self.session.execute( - delete(Sync).where(Sync.id == sync_id).where(Sync.user_id == user_id) + delete(Sync).where(Sync.id == sync_id).where(Sync.user_id == user_id) # type: ignore ) logger.info("Sync user deleted successfully") @@ -180,35 +184,36 @@ def get_all_notion_user_syncs(self): async def get_files_folder_user_sync( self, - sync_active_id: int, + sync_id: int, user_id: UUID, folder_id: str | None = None, recursive: bool = False, ) -> List[SyncFile] | None: logger.info( "Retrieving files for user sync with sync_active_id: %s, user_id: %s, folder_id: %s", - sync_active_id, + sync_id, user_id, folder_id, ) - sync_user = await self.get_syncs(user_id=user_id, sync_id=sync_active_id) - if not sync_user: + sync = await self.get_sync_id(sync_id=sync_id, user_id=user_id) + if not sync: logger.error( "No sync user found for sync_active_id: %s, user_id: %s", - sync_active_id, + sync_id, user_id, ) return None - provider = sync_user.provider.lower() try: - sync_provider = self.sync_provider_mapping[SyncProvider(provider)] + sync_provider = self.sync_provider_mapping[ + SyncProvider(sync.provider.lower()) + ] except KeyError: raise SyncProviderError - if sync_user.credentials is None: + if sync.credentials is None: raise SyncEmptyCredentials return await 
sync_provider.aget_files( - sync_user.credentials, folder_id if folder_id else "", recursive + sync.credentials, folder_id if folder_id else "", recursive ) diff --git a/backend/api/quivr_api/modules/sync/service/sync_service.py b/backend/api/quivr_api/modules/sync/service/sync_service.py index 7fee5ed31a5a..cee888b57add 100644 --- a/backend/api/quivr_api/modules/sync/service/sync_service.py +++ b/backend/api/quivr_api/modules/sync/service/sync_service.py @@ -89,7 +89,7 @@ async def get_files_folder_user_sync( recursive: bool = False, ): return await self.repository.get_files_folder_user_sync( - sync_active_id=sync_active_id, + sync_id=sync_active_id, user_id=user_id, folder_id=folder_id, recursive=recursive, diff --git a/backend/api/quivr_api/modules/sync/tests/conftest.py b/backend/api/quivr_api/modules/sync/tests/conftest.py index 4ace95a7abf1..6bf84b170cf0 100644 --- a/backend/api/quivr_api/modules/sync/tests/conftest.py +++ b/backend/api/quivr_api/modules/sync/tests/conftest.py @@ -323,8 +323,8 @@ async def aget_files_by_id( id=fid, name=f"file_{fid}", is_folder=False, - last_modified=datetime.now().strftime(self.datetime_format), - mime_type="txt", + last_modified_at=datetime.now().strftime(self.datetime_format), + extension="txt", web_view_link=f"{self.name}/{fid}", ) for fid in file_ids @@ -339,8 +339,8 @@ async def aget_files( id=str(uuid4()), name=f"file_in_{folder_id}", is_folder=False, - last_modified=datetime.now().strftime(self.datetime_format), - mime_type="txt", + last_modified_at=datetime.now().strftime(self.datetime_format), + extension="txt", web_view_link=f"{self.name}/{fid}", ) for fid in range(n_files) diff --git a/backend/api/quivr_api/modules/sync/tests/test_sync_controller.py b/backend/api/quivr_api/modules/sync/tests/test_sync_controller.py new file mode 100644 index 000000000000..cbc9efdd284c --- /dev/null +++ b/backend/api/quivr_api/modules/sync/tests/test_sync_controller.py @@ -0,0 +1,208 @@ +from datetime import datetime +from io 
import BytesIO +from typing import Dict, List, Union +from uuid import uuid4 + +import pytest +import pytest_asyncio +from httpx import ASGITransport, AsyncClient +from quivr_core.models import KnowledgeStatus +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession + +from quivr_api.main import app +from quivr_api.middlewares.auth.auth_bearer import get_current_user +from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType +from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB +from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository +from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService +from quivr_api.modules.knowledge.tests.conftest import FakeStorage +from quivr_api.modules.sync.controller.sync_routes import ( + get_knowledge_service, + get_sync_service, +) +from quivr_api.modules.sync.dto.outputs import SyncProvider +from quivr_api.modules.sync.entity.sync_models import Sync, SyncFile +from quivr_api.modules.sync.repository.sync_repository import SyncsRepository +from quivr_api.modules.sync.service.sync_service import SyncsService +from quivr_api.modules.sync.utils.sync import BaseSync +from quivr_api.modules.user.entity.user_identity import User, UserIdentity + +N_GET_FILES = 2 + +FOLDER_SYNC_FILE_IDS = [str(uuid4())[:8] for _ in range(N_GET_FILES)] + + +class BaseFakeSync(BaseSync): + name = "FakeProvider" + lower_name = "google" + datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ" + + def get_files_by_id(self, credentials: Dict, file_ids: List[str]) -> List[SyncFile]: + return [ + SyncFile( + id=str(fid), + name=f"file_{fid}", + extension=".txt", + web_view_link=f"test.com/{fid}", + is_folder=False, + last_modified_at=datetime.now(), + ) + for fid in file_ids + ] + + async def aget_files_by_id( + self, credentials: Dict, file_ids: List[str] + ) -> List[SyncFile]: + return self.get_files_by_id( + credentials=credentials, + file_ids=file_ids, + ) + + 
def get_files( + self, credentials: Dict, folder_id: str | None = None, recursive: bool = False + ) -> List[SyncFile]: + return [ + SyncFile( + id=fid, + name=f"file_{fid}", + extension=".txt", + web_view_link=f"test.com/{fid}", + parent_id=folder_id, + is_folder=idx % 2 == 0, + last_modified_at=datetime.now(), + ) + for idx, fid in enumerate(FOLDER_SYNC_FILE_IDS) + ] + + async def aget_files( + self, credentials: Dict, folder_id: str | None = None, recursive: bool = False + ) -> List[SyncFile]: + return self.get_files( + credentials=credentials, folder_id=folder_id, recursive=recursive + ) + + def check_and_refresh_access_token(self, credentials: dict) -> Dict: + raise NotImplementedError + + def download_file( + self, credentials: Dict, file: SyncFile + ) -> Dict[str, Union[str, BytesIO]]: + raise NotImplementedError + + async def adownload_file( + self, credentials: Dict, file: SyncFile + ) -> Dict[str, Union[str, BytesIO]]: + pass + + +@pytest_asyncio.fixture(scope="function") +async def user(session: AsyncSession) -> User: + user_1 = ( + await session.exec(select(User).where(User.email == "admin@quivr.app")) + ).one() + assert user_1.id + return user_1 + + +@pytest_asyncio.fixture(scope="function") +async def sync(session: AsyncSession, user: User) -> User: + sync = Sync( + name="test_sync", + email="test@test.com", + user_id=user.id, + credentials={"test": "test"}, + provider=SyncProvider.GOOGLE, + ) + + session.add(sync) + await session.commit() + await session.refresh(sync) + return sync + + +@pytest_asyncio.fixture(scope="function") +async def brain(session): + brain_1 = Brain( + name="test_brain", + description="this is a test brain", + brain_type=BrainType.integration, + ) + session.add(brain_1) + await session.commit() + return brain_1 + + +@pytest_asyncio.fixture(scope="function") +async def knowledge_sync(session, user: User, sync: Sync, brain: Brain): + km = KnowledgeDB( + file_name="sync_file_1.txt", + extension=".txt", + 
status=KnowledgeStatus.PROCESSED, + source="test_source", + source_link="test_source_link", + file_size=100, + file_sha1="test_sha1", + brains=[brain], + user_id=user.id, + sync=sync, + sync_file_id=FOLDER_SYNC_FILE_IDS[0], + ) + session.add(km) + await session.commit() + await session.refresh(km) + return km + + +@pytest_asyncio.fixture(scope="function") +async def test_client(session: AsyncSession, user: User): + def default_current_user() -> UserIdentity: + assert user.id + return UserIdentity(email=user.email, id=user.id) + + async def _sync_service(): + fake_provider = {provider: BaseFakeSync() for provider in list(SyncProvider)} + repository = SyncsRepository(session) + repository.sync_provider_mapping = fake_provider + return SyncsService(repository) + + async def _km_service(): + storage = FakeStorage() + repository = KnowledgeRepository(session) + return KnowledgeService(repository, storage) + + app.dependency_overrides[get_current_user] = default_current_user + app.dependency_overrides[get_knowledge_service] = _km_service + app.dependency_overrides[get_sync_service] = _sync_service + + async with AsyncClient( + transport=ASGITransport(app=app), base_url="http://test" + ) as ac: + yield ac + app.dependency_overrides = {} + + +@pytest.mark.asyncio(loop_scope="session") +async def test_list_sync_no_knowledge(test_client: AsyncClient, sync: Sync): + params = {"folder_id": 12} + response = await test_client.get(f"/sync/{sync.id}/files", params=params) + assert response.status_code == 200 + kms = response.json() + assert len(kms) == N_GET_FILES + + +@pytest.mark.asyncio(loop_scope="session") +async def test_list_sync_with_knowledge( + test_client: AsyncClient, sync: Sync, knowledge_sync +): + params = {"folder_id": 12} + response = await test_client.get(f"/sync/{sync.id}/files", params=params) + assert response.status_code == 200 + kms = response.json() + + assert len(kms) == N_GET_FILES + km = next( + filter(lambda x: x["id"] == str(knowledge_sync.id), kms), + 
) + assert km, "at least one knowledge should " + assert len(km["brains"]) == 1 diff --git a/backend/api/quivr_api/modules/sync/utils/sync.py b/backend/api/quivr_api/modules/sync/utils/sync.py index 4580f2250806..2ca886663cdd 100644 --- a/backend/api/quivr_api/modules/sync/utils/sync.py +++ b/backend/api/quivr_api/modules/sync/utils/sync.py @@ -91,7 +91,7 @@ def download_file( ) -> Dict[str, Union[str, BytesIO]]: file_id = file.id file_name = file.name - mime_type = file.mime_type + mime_type = file.extension if not self.creds: self.check_and_refresh_access_token(credentials) if not self.service: @@ -191,8 +191,8 @@ def get_files_by_id(self, credentials: Dict, file_ids: List[str]) -> List[SyncFi is_folder=( result["mimeType"] == "application/vnd.google-apps.folder" ), - last_modified=result["modifiedTime"], - mime_type=result["mimeType"], + last_modified_at=result["modifiedTime"], + extension=result["mimeType"], web_view_link=result["webViewLink"], size=result.get("size", None), ) @@ -267,8 +267,8 @@ def get_files( is_folder=( item["mimeType"] == "application/vnd.google-apps.folder" ), - last_modified=item["modifiedTime"], - mime_type=item["mimeType"], + last_modified_at=item["modifiedTime"], + extension=item["mimeType"], web_view_link=item["webViewLink"], size=item.get("size", None), ) @@ -441,8 +441,10 @@ def fetch_files(endpoint, headers, max_retries=1): else f'{site_id}:{item.get("id")}' ), is_folder="folder" in item or not site_folder_id, - last_modified=item.get("lastModifiedDateTime"), - mime_type=item.get("file", {}).get("mimeType", "folder"), + last_modified_at=datetime.strptime( + item.get("lastModifiedDateTime"), self.datetime_format + ), + extension=item.get("file", {}).get("mimeType", "folder"), web_view_link=item.get("webUrl"), size=item.get("size", None), ) @@ -461,8 +463,8 @@ def fetch_files(endpoint, headers, max_retries=1): name="My Drive", id="root:", is_folder=True, - last_modified="", - mime_type="folder", + last_modified_at=None, + 
extension="folder", web_view_link="https://onedrive.live.com", ) ) @@ -520,8 +522,10 @@ def get_files_by_id(self, credentials: dict, file_ids: List[str]) -> List[SyncFi name=result.get("name"), id=f'{site_id}:{result.get("id")}', is_folder="folder" in result, - last_modified=result.get("lastModifiedDateTime"), - mime_type=result.get("file", {}).get("mimeType", "folder"), + last_modified_at=datetime.strptime( + result.get("lastModifiedDateTime"), self.datetime_format + ), + extension=result.get("file", {}).get("mimeType", "folder"), web_view_link=result.get("webUrl"), size=result.get("size", None), ) @@ -631,10 +635,14 @@ def fetch_files(metadata): name=file.name, id=file.id, is_folder=is_folder, - last_modified=( - str(file.client_modified) if not is_folder else "" + last_modified_at=( + datetime.strptime( + file.client_modified, self.datetime_format + ) + if not is_folder + else None ), - mime_type=( + extension=( file.path_lower.split(".")[-1] if not is_folder else "" ), web_view_link=shared_link, @@ -706,10 +714,14 @@ def get_files_by_id( name=metadata.name, id=metadata.id, is_folder=is_folder, - last_modified=( - str(metadata.client_modified) if not is_folder else "" + last_modified_at=( + datetime.strptime( + metadata.client_modified, self.datetime_format + ) + if not is_folder + else None ), - mime_type=( + extension=( metadata.path_lower.split(".")[-1] if not is_folder else "" ), web_view_link=shared_link, @@ -798,8 +810,8 @@ async def aget_files( name=page.name, id=str(page.notion_id), is_folder=await self.notion_service.is_folder_page(page.notion_id), - last_modified=str(page.last_modified), - mime_type=page.mime_type, + last_modified_at=str(page.last_modified), + extension=page.mime_type, web_view_link=page.web_view_link, icon=page.icon, ) @@ -837,8 +849,8 @@ async def aget_files_by_id( name=page.name, id=str(page.notion_id), is_folder=await self.notion_service.is_folder_page(page.notion_id), - last_modified=str(page.last_modified), - 
mime_type=page.mime_type, + last_modified_at=str(page.last_modified), + extension=page.mime_type, web_view_link=page.web_view_link, icon=page.icon, ) @@ -1049,8 +1061,8 @@ def get_files_by_id(self, credentials: Dict, file_ids: List[str]) -> List[SyncFi name=remove_special_characters(result.get("name")), id=f"{repo_name}:{result.get('path')}", is_folder=False, - last_modified=datetime.now().strftime(self.datetime_format), - mime_type=result.get("type"), + last_modified_at=datetime.now(), + extension=result.get("type"), web_view_link=result.get("html_url"), size=result.get("size", None), ) @@ -1123,8 +1135,8 @@ def fetch_repos(endpoint, headers): name=remove_special_characters(item.get("name")), id=f"{item.get('full_name')}:", is_folder=True, - last_modified=str(item.get("updated_at")), - mime_type="repository", + last_modified_at=item.get("updated_at"), + extension="repository", web_view_link=item.get("html_url"), size=item.get("size", None), ) @@ -1170,8 +1182,8 @@ def fetch_files(endpoint, headers): name=remove_special_characters(item.get("name")), id=f"{repo_name}:{item.get('path')}", is_folder=item.get("type") == "dir", - last_modified=str(item.get("updated_at")), - mime_type=item.get("type"), + last_modified_at=str(item.get("updated_at")), + extension=item.get("type"), web_view_link=item.get("html_url"), size=item.get("size", None), ) diff --git a/backend/api/quivr_api/modules/sync/utils/syncutils.py b/backend/api/quivr_api/modules/sync/utils/syncutils.py index 861ebf7306fc..ff86e396ecca 100644 --- a/backend/api/quivr_api/modules/sync/utils/syncutils.py +++ b/backend/api/quivr_api/modules/sync/utils/syncutils.py @@ -51,7 +51,7 @@ def should_download_file( datetime_format: str, ) -> bool: file_last_modified_utc = datetime.strptime( - file.last_modified, datetime_format + file.last_modified_at, datetime_format ).replace(tzinfo=timezone.utc) should_download = ( @@ -61,7 +61,7 @@ def should_download_file( # TODO: Handle notion database if provider_name == "notion": 
- should_download &= file.mime_type != "db" + should_download &= file.extension != "db" else: should_download &= not file.is_folder diff --git a/backend/supabase/migrations/20240918180003_knowledge-sync.sql b/backend/supabase/migrations/20240918180003_knowledge-sync.sql index 73d3eb746475..2d7c71d0137e 100644 --- a/backend/supabase/migrations/20240918180003_knowledge-sync.sql +++ b/backend/supabase/migrations/20240918180003_knowledge-sync.sql @@ -1,24 +1,22 @@ -- Renamed syncs -ALTER TABLE syncs_user RENAME TO syncs; - +ALTER TABLE syncs_user + RENAME TO syncs; -- Add column sync in knowledge ALTER TABLE "public"."knowledge" ADD COLUMN "sync_id" INTEGER; - ALTER TABLE "public"."knowledge" ADD CONSTRAINT "public_knowledge_sync_id_fkey" FOREIGN KEY (sync_id) REFERENCES syncs(id) ON DELETE CASCADE; - +ALTER TABLE "public"."knowledge" +ADD COLUMN "sync_file_id" TEXT; +CREATE UNIQUE INDEX knowledge_sync_id_pkey ON public.knowledge USING btree (sync_id); +CREATE UNIQUE INDEX knowledge_sync_file_id_pkey ON public.knowledge USING btree (sync_file_id); -- Add columns syncs alter table "public"."syncs" add column "created_at" timestamp with time zone default now(); - alter table "public"."syncs" add column "updated_at" timestamp with time zone default now(); - alter table "public"."syncs" add column "last_synced_at" timestamp with time zone; - - -- Drop files DROP TABLE IF EXISTS "public"."syncs_active" CASCADE; DROP TABLE IF EXISTS "public"."syncs_files" CASCADE; From 625dfa29e484cbb02d821098fa71458c614e150b Mon Sep 17 00:00:00 2001 From: aminediro Date: Fri, 20 Sep 2024 10:28:53 +0200 Subject: [PATCH 09/63] brain dto join --- backend/api/quivr_api/modules/knowledge/entity/knowledge.py | 2 +- .../quivr_api/modules/knowledge/tests/test_knowledge_entity.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/backend/api/quivr_api/modules/knowledge/entity/knowledge.py b/backend/api/quivr_api/modules/knowledge/entity/knowledge.py index 
e387c5fe16d3..3071c8c9dc9b 100644 --- a/backend/api/quivr_api/modules/knowledge/entity/knowledge.py +++ b/backend/api/quivr_api/modules/knowledge/entity/knowledge.py @@ -103,7 +103,7 @@ class KnowledgeDB(AsyncAttrs, SQLModel, table=True): brains: List["Brain"] = Relationship( back_populates="knowledges", link_model=KnowledgeBrain, - sa_relationship_kwargs={"lazy": "select"}, + sa_relationship_kwargs={"lazy": "joined"}, ) parent_id: UUID | None = Field( diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py index 50e996028f63..217532ebd510 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py @@ -243,8 +243,7 @@ async def test_knowledge_dto(session, user, brain, sync): assert km_dto.metadata == km.metadata_ # type: ignor assert km_dto.parent assert km_dto.parent.id == folder.id - # Syncs - + # Syncs fields assert km_dto.sync_id == km.sync_id assert km_dto.sync_file_id == km.sync_file_id From 08d7b339614ab593b3d8f52fdccb56ca8f58ace4 Mon Sep 17 00:00:00 2001 From: aminediro Date: Fri, 20 Sep 2024 10:55:15 +0200 Subject: [PATCH 10/63] sync knowledge --- .../knowledge/controller/knowledge_routes.py | 43 +++++++++++-------- .../quivr_api/modules/knowledge/dto/inputs.py | 18 ++++++++ .../modules/knowledge/dto/outputs.py | 10 +++-- .../modules/knowledge/entity/knowledge.py | 42 ++---------------- .../knowledge/entity/knowledge_brain.py | 1 - .../knowledge/repository/knowledges.py | 12 +++--- .../knowledge/service/knowledge_service.py | 18 ++++---- .../modules/knowledge/tests/conftest.py | 9 ++-- .../tests/test_knowledge_controller.py | 6 ++- .../knowledge/tests/test_knowledge_entity.py | 3 +- .../knowledge/tests/test_knowledge_service.py | 8 +++- .../modules/sync/controller/sync_routes.py | 17 +++++--- .../sync/tests/test_sync_controller.py | 11 +++-- 13 files 
changed, 106 insertions(+), 92 deletions(-) diff --git a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py index 68d01afb0c5a..f1184cc6c4e7 100644 --- a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py +++ b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py @@ -1,5 +1,5 @@ from http import HTTPStatus -from typing import Annotated, List, Optional +from typing import List, Optional from uuid import UUID from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile, status @@ -12,8 +12,8 @@ validate_brain_authorization, ) from quivr_api.modules.dependencies import get_service -from quivr_api.modules.knowledge.dto.inputs import AddKnowledge -from quivr_api.modules.knowledge.entity.knowledge import Knowledge, KnowledgeUpdate +from quivr_api.modules.knowledge.dto.inputs import AddKnowledge, KnowledgeUpdate +from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO from quivr_api.modules.knowledge.service.knowledge_exceptions import ( KnowledgeDeleteError, KnowledgeForbiddenAccess, @@ -29,15 +29,14 @@ knowledge_router = APIRouter() logger = get_logger(__name__) -get_km_service = get_service(KnowledgeService) -KnowledgeServiceDep = Annotated[KnowledgeService, Depends(get_km_service)] +get_knowledge_service = get_service(KnowledgeService) @knowledge_router.get( "/knowledge", dependencies=[Depends(AuthBearer())], tags=["Knowledge"] ) async def list_knowledge_in_brain_endpoint( - knowledge_service: KnowledgeServiceDep, + knowledge_service: KnowledgeService = Depends(get_knowledge_service), brain_id: UUID = Query(..., description="The ID of the brain"), current_user: UserIdentity = Depends(get_current_user), ): @@ -62,7 +61,7 @@ async def list_knowledge_in_brain_endpoint( ) async def delete_knowledge_brain( knowledge_id: UUID, - knowledge_service: KnowledgeServiceDep, + knowledge_service: KnowledgeService = 
Depends(get_knowledge_service), current_user: UserIdentity = Depends(get_current_user), brain_id: UUID = Query(..., description="The ID of the brain"), ): @@ -86,7 +85,7 @@ async def delete_knowledge_brain( ) async def generate_signed_url_endpoint( knowledge_id: UUID, - knowledge_service: KnowledgeServiceDep, + knowledge_service: KnowledgeService = Depends(get_knowledge_service), current_user: UserIdentity = Depends(get_current_user), ): """ @@ -120,12 +119,12 @@ async def generate_signed_url_endpoint( @knowledge_router.post( "/knowledge/", tags=["Knowledge"], - response_model=Knowledge, + response_model=KnowledgeDTO, ) async def create_knowledge( knowledge_data: str = File(...), file: Optional[UploadFile] = None, - knowledge_service: KnowledgeService = Depends(get_km_service), + knowledge_service: KnowledgeService = Depends(get_knowledge_service), current_user: UserIdentity = Depends(get_current_user), ): knowledge = AddKnowledge.model_validate_json(knowledge_data) @@ -160,12 +159,12 @@ async def create_knowledge( @knowledge_router.get( "/knowledge/children", - response_model=List[Knowledge] | None, + response_model=List[KnowledgeDTO] | None, tags=["Knowledge"], ) async def list_knowledge( parent_id: UUID | None = None, - knowledge_service: KnowledgeService = Depends(get_km_service), + knowledge_service: KnowledgeService = Depends(get_knowledge_service), current_user: UserIdentity = Depends(get_current_user), ): try: @@ -186,12 +185,12 @@ async def list_knowledge( @knowledge_router.get( "/knowledge/{knowledge_id}", - response_model=Knowledge, + response_model=KnowledgeDTO, tags=["Knowledge"], ) async def get_knowledge( knowledge_id: UUID, - knowledge_service: KnowledgeService = Depends(get_km_service), + knowledge_service: KnowledgeService = Depends(get_knowledge_service), current_user: UserIdentity = Depends(get_current_user), ): try: @@ -213,13 +212,13 @@ async def get_knowledge( @knowledge_router.patch( "/knowledge/{knowledge_id}", 
status_code=status.HTTP_202_ACCEPTED, - response_model=Knowledge, + response_model=KnowledgeDTO, tags=["Knowledge"], ) async def update_knowledge( knowledge_id: UUID, payload: KnowledgeUpdate, - knowledge_service: KnowledgeService = Depends(get_km_service), + knowledge_service: KnowledgeService = Depends(get_knowledge_service), current_user: UserIdentity = Depends(get_current_user), ): try: @@ -246,7 +245,7 @@ async def update_knowledge( ) async def delete_knowledge( knowledge_id: UUID, - knowledge_service: KnowledgeService = Depends(get_km_service), + knowledge_service: KnowledgeService = Depends(get_knowledge_service), current_user: UserIdentity = Depends(get_current_user), ): try: @@ -265,3 +264,13 @@ async def delete_knowledge( ) except KnowledgeDeleteError: raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) + + +@knowledge_router.post("/link_to_brain/") +async def link_knowledge_to_brain( + brain_id: UUID, + knowledge: KnowledgeDTO, + knowledge_service: KnowledgeService = Depends(get_knowledge_service), + current_user: UserIdentity = Depends(get_current_user), +): + pass diff --git a/backend/api/quivr_api/modules/knowledge/dto/inputs.py b/backend/api/quivr_api/modules/knowledge/dto/inputs.py index 85a2438e9205..ad478b716fa6 100644 --- a/backend/api/quivr_api/modules/knowledge/dto/inputs.py +++ b/backend/api/quivr_api/modules/knowledge/dto/inputs.py @@ -4,6 +4,8 @@ from pydantic import BaseModel from quivr_core.models import KnowledgeStatus +from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO + class CreateKnowledgeProperties(BaseModel): brain_id: UUID @@ -29,3 +31,19 @@ class AddKnowledge(BaseModel): metadata: Optional[Dict[str, str]] = None is_folder: bool = False parent_id: Optional[UUID] = None + + +class KnowledgeUpdate(BaseModel): + file_name: Optional[str] = None + status: Optional[KnowledgeStatus] = None + url: Optional[str] = None + file_sha1: Optional[str] = None + extension: Optional[str] = None + parent_id: 
Optional[UUID] = None + source: Optional[str] = None + source_link: Optional[str] = None + metadata: Optional[Dict[str, str]] = None + + +class LinkKnowledge(BaseModel): + knowledge: KnowledgeDTO diff --git a/backend/api/quivr_api/modules/knowledge/dto/outputs.py b/backend/api/quivr_api/modules/knowledge/dto/outputs.py index 3e804be9e6cc..34c1860c7231 100644 --- a/backend/api/quivr_api/modules/knowledge/dto/outputs.py +++ b/backend/api/quivr_api/modules/knowledge/dto/outputs.py @@ -12,10 +12,10 @@ class DeleteKnowledgeResponse(BaseModel): knowledge_id: UUID -class KnowledgeOut(BaseModel): - id: UUID +class KnowledgeDTO(BaseModel): + id: Optional[UUID] file_size: int = 0 - status: KnowledgeStatus + status: Optional[KnowledgeStatus] file_name: Optional[str] = None url: Optional[str] = None extension: str = ".txt" @@ -29,4 +29,6 @@ class KnowledgeOut(BaseModel): user_id: UUID brains: List[Dict[str, Any]] parent: Optional[Self] - children: Optional[list[Self]] + children: Optional[List[Self]] + sync_id: int | None + sync_file_id: str | None diff --git a/backend/api/quivr_api/modules/knowledge/entity/knowledge.py b/backend/api/quivr_api/modules/knowledge/entity/knowledge.py index 3071c8c9dc9b..759bbbffddbd 100644 --- a/backend/api/quivr_api/modules/knowledge/entity/knowledge.py +++ b/backend/api/quivr_api/modules/knowledge/entity/knowledge.py @@ -1,15 +1,15 @@ from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional, Self +from typing import Dict, List, Optional from uuid import UUID -from pydantic import BaseModel from quivr_core.models import KnowledgeStatus from sqlalchemy import JSON, TIMESTAMP, Column, text from sqlalchemy.ext.asyncio import AsyncAttrs from sqlmodel import UUID as PGUUID from sqlmodel import Field, Relationship, SQLModel +from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO from quivr_api.modules.knowledge.entity.knowledge_brain import KnowledgeBrain from quivr_api.modules.sync.entity.sync_models 
import Sync @@ -25,40 +25,6 @@ class KnowledgeSource(str, Enum): GITHUB = "github" -class Knowledge(BaseModel): - id: Optional[UUID] - file_size: int = 0 - status: Optional[KnowledgeStatus] - file_name: Optional[str] = None - url: Optional[str] = None - extension: str = ".txt" - is_folder: bool = False - updated_at: datetime - created_at: datetime - source: Optional[str] = None - source_link: Optional[str] = None - file_sha1: Optional[str] = None - metadata: Optional[Dict[str, str]] = None - user_id: UUID - brains: List[Dict[str, Any]] - parent: Optional[Self] - children: Optional[List[Self]] - sync_id: int | None - sync_file_id: str | None - - -class KnowledgeUpdate(BaseModel): - file_name: Optional[str] = None - status: Optional[KnowledgeStatus] = None - url: Optional[str] = None - file_sha1: Optional[str] = None - extension: Optional[str] = None - parent_id: Optional[UUID] = None - source: Optional[str] = None - source_link: Optional[str] = None - metadata: Optional[Dict[str, str]] = None - - class KnowledgeDB(AsyncAttrs, SQLModel, table=True): __tablename__ = "knowledge" # type: ignore @@ -128,7 +94,7 @@ class KnowledgeDB(AsyncAttrs, SQLModel, table=True): sync_file_id: str | None = Field(default=None) # TODO: nested folder search - async def to_dto(self, get_children: bool = True) -> Knowledge: + async def to_dto(self, get_children: bool = True) -> KnowledgeDTO: assert ( self.updated_at ), "knowledge should be inserted before transforming to dto" @@ -142,7 +108,7 @@ async def to_dto(self, get_children: bool = True) -> Knowledge: parent = await self.awaitable_attrs.parent parent = await parent.to_dto(get_children=False) if parent else None - return Knowledge( + return KnowledgeDTO( id=self.id, # type: ignore file_name=self.file_name, url=self.url, diff --git a/backend/api/quivr_api/modules/knowledge/entity/knowledge_brain.py b/backend/api/quivr_api/modules/knowledge/entity/knowledge_brain.py index 0f9b8e8ae771..017f6fb98386 100644 --- 
a/backend/api/quivr_api/modules/knowledge/entity/knowledge_brain.py +++ b/backend/api/quivr_api/modules/knowledge/entity/knowledge_brain.py @@ -1,7 +1,6 @@ from datetime import datetime from uuid import UUID -from sqlalchemy import TIMESTAMP, Column, text from sqlmodel import TIMESTAMP, Column, Field, SQLModel, text from sqlmodel import UUID as PGUUID diff --git a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py index 3a2c4642fae8..6105a6ddccdf 100644 --- a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py +++ b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py @@ -11,11 +11,13 @@ from quivr_api.logger import get_logger from quivr_api.modules.brain.entity.brain_entity import Brain from quivr_api.modules.dependencies import BaseRepository, get_supabase_client -from quivr_api.modules.knowledge.dto.outputs import DeleteKnowledgeResponse +from quivr_api.modules.knowledge.dto.inputs import KnowledgeUpdate +from quivr_api.modules.knowledge.dto.outputs import ( + DeleteKnowledgeResponse, + KnowledgeDTO, +) from quivr_api.modules.knowledge.entity.knowledge import ( - Knowledge, KnowledgeDB, - KnowledgeUpdate, ) from quivr_api.modules.knowledge.service.knowledge_exceptions import ( KnowledgeNotFoundException, @@ -47,7 +49,7 @@ async def create_knowledge(self, knowledge: KnowledgeDB) -> KnowledgeDB: async def update_knowledge( self, knowledge: KnowledgeDB, - payload: Knowledge | KnowledgeUpdate | dict[str, Any], + payload: KnowledgeDTO | KnowledgeUpdate | dict[str, Any], ) -> KnowledgeDB: try: logger.debug(f"updating {knowledge.id} with payload {payload}") @@ -161,7 +163,7 @@ async def get_all_knowledge_sync_user( query = query.where(KnowledgeDB.user_id == user_id) result = await self.session.exec(query) - return list(result.all()) + return list(result.unique().all()) async def get_knowledge_by_file_name_brain_id( self, file_name: str, brain_id: UUID diff --git 
a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py index deef902b0e59..fb0110371167 100644 --- a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py @@ -12,13 +12,15 @@ from quivr_api.modules.knowledge.dto.inputs import ( AddKnowledge, CreateKnowledgeProperties, + KnowledgeUpdate, +) +from quivr_api.modules.knowledge.dto.outputs import ( + DeleteKnowledgeResponse, + KnowledgeDTO, ) -from quivr_api.modules.knowledge.dto.outputs import DeleteKnowledgeResponse from quivr_api.modules.knowledge.entity.knowledge import ( - Knowledge, KnowledgeDB, KnowledgeSource, - KnowledgeUpdate, ) from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage @@ -44,7 +46,7 @@ def __init__( self.repository = repository self.storage = storage - async def get_knowledge_sync(self, sync_id: int) -> Knowledge: + async def get_knowledge_sync(self, sync_id: int) -> KnowledgeDTO: km = await self.repository.get_knowledge_by_sync_id(sync_id) assert km.id, "Knowledge ID not generated" km = await km.to_dto() @@ -70,7 +72,7 @@ async def get_knowledge_storage_path( async def map_syncs_knowledge_user( self, sync_id: int, user_id: UUID - ) -> dict[str, Knowledge]: + ) -> dict[str, KnowledgeDTO]: list_kms = await self.repository.get_all_knowledge_sync_user( sync_id=sync_id, user_id=user_id ) @@ -100,7 +102,7 @@ async def get_knowledge( async def update_knowledge( self, knowledge: KnowledgeDB, - payload: Knowledge | KnowledgeUpdate | dict[str, Any], + payload: KnowledgeDTO | KnowledgeUpdate | dict[str, Any], ): return await self.repository.update_knowledge(knowledge, payload) @@ -191,7 +193,7 @@ async def insert_knowledge_brain( self, user_id: UUID, knowledge_to_add: CreateKnowledgeProperties, # FIXME: (later) @Amine brain id 
should not be in CreateKnowledgeProperties but since storage is brain_id/file_name - ) -> Knowledge: + ) -> KnowledgeDTO: knowledge = KnowledgeDB( file_name=knowledge_to_add.file_name, url=knowledge_to_add.url, @@ -213,7 +215,7 @@ async def insert_knowledge_brain( inserted_knowledge = await knowledge_db.to_dto() return inserted_knowledge - async def get_all_knowledge_in_brain(self, brain_id: UUID) -> List[Knowledge]: + async def get_all_knowledge_in_brain(self, brain_id: UUID) -> List[KnowledgeDTO]: brain = await self.repository.get_brain_by_id(brain_id) all_knowledges: List[KnowledgeDB] = await brain.awaitable_attrs.knowledges knowledges = [await knowledge.to_dto() for knowledge in all_knowledges] diff --git a/backend/api/quivr_api/modules/knowledge/tests/conftest.py b/backend/api/quivr_api/modules/knowledge/tests/conftest.py index 2074110f6b5c..b88cf2c2c356 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/conftest.py +++ b/backend/api/quivr_api/modules/knowledge/tests/conftest.py @@ -1,6 +1,7 @@ from io import BufferedReader, FileIO -from quivr_api.modules.knowledge.entity.knowledge import Knowledge, KnowledgeDB +from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO +from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB from quivr_api.modules.knowledge.repository.storage_interface import StorageInterface @@ -15,7 +16,7 @@ async def upload_file_storage( def get_storage_path( self, - knowledge: KnowledgeDB | Knowledge, + knowledge: KnowledgeDB | KnowledgeDTO, ) -> str: if knowledge.id is None: raise ValueError("knowledge should have a valid id") @@ -31,7 +32,7 @@ def __init__(self): def get_storage_path( self, - knowledge: KnowledgeDB | Knowledge, + knowledge: KnowledgeDB | KnowledgeDTO, ) -> str: if knowledge.id is None: raise ValueError("knowledge should have a valid id") @@ -60,7 +61,7 @@ def get_file(self, storage_path: str) -> FileIO | BufferedReader | bytes: raise FileNotFoundError(f"File not found at {storage_path}") 
return self.storage[storage_path] - def knowledge_exists(self, knowledge: KnowledgeDB | Knowledge) -> bool: + def knowledge_exists(self, knowledge: KnowledgeDB | KnowledgeDTO) -> bool: return self.get_storage_path(knowledge) in self.storage def clear_storage(self): diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py index cf6313e97a19..bd55f5de9896 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py @@ -8,7 +8,9 @@ from quivr_api.main import app from quivr_api.middlewares.auth.auth_bearer import get_current_user -from quivr_api.modules.knowledge.controller.knowledge_routes import get_km_service +from quivr_api.modules.knowledge.controller.knowledge_routes import ( + get_knowledge_service, +) from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService from quivr_api.modules.knowledge.tests.conftest import FakeStorage @@ -36,7 +38,7 @@ async def test_service(): return KnowledgeService(repository, storage) app.dependency_overrides[get_current_user] = default_current_user - app.dependency_overrides[get_km_service] = test_service + app.dependency_overrides[get_knowledge_service] = test_service # app.dependency_overrides[get_async_session] = lambda: session async with AsyncClient( diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py index 217532ebd510..02b9a0d24261 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py @@ -41,7 +41,8 @@ async def user(session): @pytest_asyncio.fixture(scope="function") -async def sync(session: 
AsyncSession, user: User) -> User: +async def sync(session: AsyncSession, user: User) -> Sync: + assert user.id sync = Sync( name="test_sync", email="test@test.com", diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py index 169b9bef2ca9..244c328dea98 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py @@ -11,8 +11,12 @@ from sqlmodel.ext.asyncio.session import AsyncSession from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType -from quivr_api.modules.knowledge.dto.inputs import AddKnowledge, KnowledgeStatus -from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB, KnowledgeUpdate +from quivr_api.modules.knowledge.dto.inputs import ( + AddKnowledge, + KnowledgeStatus, + KnowledgeUpdate, +) +from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB from quivr_api.modules.knowledge.entity.knowledge_brain import KnowledgeBrain from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository from quivr_api.modules.knowledge.service.knowledge_exceptions import ( diff --git a/backend/api/quivr_api/modules/sync/controller/sync_routes.py b/backend/api/quivr_api/modules/sync/controller/sync_routes.py index bf4102ac7383..cb615ee19f05 100644 --- a/backend/api/quivr_api/modules/sync/controller/sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/sync_routes.py @@ -1,5 +1,6 @@ import asyncio import os +from datetime import datetime from typing import List, Tuple from fastapi import APIRouter, Depends, status @@ -7,7 +8,7 @@ from quivr_api.logger import get_logger from quivr_api.middlewares.auth import AuthBearer, get_current_user from quivr_api.modules.dependencies import get_service -from quivr_api.modules.knowledge.entity.knowledge import Knowledge +from 
quivr_api.modules.knowledge.entity.knowledge import KnowledgeDTO from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService from quivr_api.modules.notification.service.notification_service import ( NotificationService, @@ -151,7 +152,7 @@ async def delete_user_sync( @sync_router.get( "/sync/{sync_id}/files", - response_model=List[Knowledge] | None, + response_model=List[KnowledgeDTO] | None, tags=["Sync"], ) async def list_sync_files( @@ -178,7 +179,7 @@ async def list_sync_files( # Gets knowledge for each call to list the files, # The logic is that getting from DB will be faster than provider repsonse ? # NOTE: asyncio.gather didn't correcly typecheck - async def fetch_data() -> Tuple[dict[str, Knowledge], List[SyncFile] | None]: + async def fetch_data() -> Tuple[dict[str, KnowledgeDTO], List[SyncFile] | None]: map_knowledges_task = knowledge_service.map_syncs_knowledge_user( sync_id=sync_id, user_id=current_user.id ) @@ -201,7 +202,7 @@ async def fetch_data() -> Tuple[dict[str, Knowledge], List[SyncFile] | None]: kms.append(existing_km) else: kms.append( - Knowledge( + KnowledgeDTO( id=None, file_name=file.name, is_folder=file.is_folder, @@ -214,8 +215,12 @@ async def fetch_data() -> Tuple[dict[str, Knowledge], List[SyncFile] | None]: children=None, status=None, # TODO: Handle a sync not added status # TODO: retrieve created at from sync provider - created_at=file.last_modified_at, - updated_at=file.last_modified_at, + created_at=file.last_modified_at + if file.last_modified_at + else datetime.now(), + updated_at=file.last_modified_at + if file.last_modified_at + else datetime.now(), sync_id=sync_id, sync_file_id=file.id, ) diff --git a/backend/api/quivr_api/modules/sync/tests/test_sync_controller.py b/backend/api/quivr_api/modules/sync/tests/test_sync_controller.py index cbc9efdd284c..6d14dc897056 100644 --- a/backend/api/quivr_api/modules/sync/tests/test_sync_controller.py +++ 
b/backend/api/quivr_api/modules/sync/tests/test_sync_controller.py @@ -93,7 +93,7 @@ def download_file( async def adownload_file( self, credentials: Dict, file: SyncFile ) -> Dict[str, Union[str, BytesIO]]: - pass + raise NotImplementedError @pytest_asyncio.fixture(scope="function") @@ -106,7 +106,7 @@ async def user(session: AsyncSession) -> User: @pytest_asyncio.fixture(scope="function") -async def sync(session: AsyncSession, user: User) -> User: +async def sync(session: AsyncSession, user: User) -> Sync: sync = Sync( name="test_sync", email="test@test.com", @@ -161,7 +161,9 @@ def default_current_user() -> UserIdentity: return UserIdentity(email=user.email, id=user.id) async def _sync_service(): - fake_provider = {provider: BaseFakeSync() for provider in list(SyncProvider)} + fake_provider: dict[SyncProvider, BaseSync] = { + provider: BaseFakeSync() for provider in list(SyncProvider) + } repository = SyncsRepository(session) repository.sync_provider_mapping = fake_provider return SyncsService(repository) @@ -176,7 +178,8 @@ async def _km_service(): app.dependency_overrides[get_sync_service] = _sync_service async with AsyncClient( - transport=ASGITransport(app=app), base_url="http://test" + transport=ASGITransport(app=app), # type: ignore + base_url="http://test", ) as ac: yield ac app.dependency_overrides = {} From 9bf4519b18a27df78a4f338977ea63cea1848bad Mon Sep 17 00:00:00 2001 From: aminediro Date: Fri, 20 Sep 2024 18:26:01 +0200 Subject: [PATCH 11/63] knowledge link to brain DONE --- backend/api/quivr_api/__init__.py | 2 + .../modules/brain/entity/brain_entity.py | 5 + .../modules/brain/entity/brain_user.py | 14 ++ .../modules/brain/repository/brains_users.py | 6 +- .../interfaces/brains_users_interface.py | 6 +- .../modules/brain/service/brain_service.py | 4 +- .../brain/service/brain_user_service.py | 5 +- .../knowledge/controller/knowledge_routes.py | 91 +++++++- .../quivr_api/modules/knowledge/dto/inputs.py | 10 +- .../modules/knowledge/dto/outputs.py | 
2 +- .../modules/knowledge/entity/knowledge.py | 3 +- .../knowledge/repository/knowledges.py | 108 +++++---- .../knowledge/service/knowledge_service.py | 65 ++---- .../tests/test_knowledge_controller.py | 213 ++++++++++++++++++ .../knowledge/tests/test_knowledge_service.py | 75 ++++++ .../modules/sync/controller/sync_routes.py | 11 +- .../modules/user/entity/user_identity.py | 5 + backend/worker/quivr_worker/celery_worker.py | 16 +- .../quivr_worker/process/process_s3_file.py | 19 +- 19 files changed, 515 insertions(+), 145 deletions(-) create mode 100644 backend/api/quivr_api/modules/brain/entity/brain_user.py diff --git a/backend/api/quivr_api/__init__.py b/backend/api/quivr_api/__init__.py index c75182e5d2d7..92c5ed5104e5 100644 --- a/backend/api/quivr_api/__init__.py +++ b/backend/api/quivr_api/__init__.py @@ -1,4 +1,5 @@ from quivr_api.modules.brain.entity.brain_entity import Brain +from quivr_api.modules.brain.entity.brain_user import BrainUserDB from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB from .modules.chat.entity.chat import Chat, ChatHistory @@ -8,6 +9,7 @@ __all__ = [ "Chat", "ChatHistory", + "BrainUserDB", "User", "NotionSyncFile", "KnowledgeDB", diff --git a/backend/api/quivr_api/modules/brain/entity/brain_entity.py b/backend/api/quivr_api/modules/brain/entity/brain_entity.py index a9a618733636..fbe90c06a972 100644 --- a/backend/api/quivr_api/modules/brain/entity/brain_entity.py +++ b/backend/api/quivr_api/modules/brain/entity/brain_entity.py @@ -9,6 +9,7 @@ from sqlmodel import TIMESTAMP, Column, Field, Relationship, SQLModel, text from sqlmodel import UUID as PGUUID +from quivr_api.modules.brain.entity.brain_user import BrainUserDB from quivr_api.modules.brain.entity.integration_brain import ( IntegrationDescriptionEntity, IntegrationEntity, @@ -68,6 +69,10 @@ class Brain(AsyncAttrs, SQLModel, table=True): knowledges: List[KnowledgeDB] = Relationship( back_populates="brains", link_model=KnowledgeBrain ) + users: List["User"] = 
Relationship( + back_populates="brains", + link_model=BrainUserDB, + ) # TODO : add # "meaning" "public"."vector", # "tags" "public"."tags"[] diff --git a/backend/api/quivr_api/modules/brain/entity/brain_user.py b/backend/api/quivr_api/modules/brain/entity/brain_user.py new file mode 100644 index 000000000000..24b1b029c307 --- /dev/null +++ b/backend/api/quivr_api/modules/brain/entity/brain_user.py @@ -0,0 +1,14 @@ +from uuid import UUID + +from sqlmodel import Field, SQLModel + + +class BrainUserDB(SQLModel, table=True): + __tablename__ = "brains_users" # type: ignore + + brain_id: UUID = Field( + nullable=False, foreign_key="brains.brain_id", primary_key=True + ) + user_id: UUID = Field(nullable=False, foreign_key="users.id", primary_key=True) + default_brain: bool + rights: str diff --git a/backend/api/quivr_api/modules/brain/repository/brains_users.py b/backend/api/quivr_api/modules/brain/repository/brains_users.py index 9176eeb35ce6..cdbc69f903b8 100644 --- a/backend/api/quivr_api/modules/brain/repository/brains_users.py +++ b/backend/api/quivr_api/modules/brain/repository/brains_users.py @@ -2,7 +2,7 @@ from quivr_api.logger import get_logger from quivr_api.modules.brain.entity.brain_entity import ( - BrainUser, + BrainUserDB, MinimalUserBrainEntity, ) from quivr_api.modules.brain.repository.interfaces.brains_users_interface import ( @@ -161,7 +161,7 @@ def get_user_default_brain_id(self, user_id: UUID) -> UUID | None: return None return UUID(response[0].get("brain_id")) - def get_brain_users(self, brain_id: UUID) -> list[BrainUser]: + def get_brain_users(self, brain_id: UUID) -> list[BrainUserDB]: response = ( self.db.table("brains_users") .select("id:brain_id, *") @@ -169,7 +169,7 @@ def get_brain_users(self, brain_id: UUID) -> list[BrainUser]: .execute() ) - return [BrainUser(**item) for item in response.data] + return [BrainUserDB(**item) for item in response.data] def delete_brain_subscribers(self, brain_id: UUID): results = ( diff --git 
a/backend/api/quivr_api/modules/brain/repository/interfaces/brains_users_interface.py b/backend/api/quivr_api/modules/brain/repository/interfaces/brains_users_interface.py index dabe8ef924b7..d87365239a41 100644 --- a/backend/api/quivr_api/modules/brain/repository/interfaces/brains_users_interface.py +++ b/backend/api/quivr_api/modules/brain/repository/interfaces/brains_users_interface.py @@ -3,7 +3,7 @@ from uuid import UUID from quivr_api.modules.brain.entity.brain_entity import ( - BrainUser, + BrainUserDB, MinimalUserBrainEntity, ) @@ -56,7 +56,7 @@ def get_user_default_brain_id(self, user_id: UUID) -> UUID | None: pass @abstractmethod - def get_brain_users(self, brain_id: UUID) -> List[BrainUser]: + def get_brain_users(self, brain_id: UUID) -> List[BrainUserDB]: """ Get all users for a brain """ @@ -88,7 +88,7 @@ def update_brain_user_default_status( @abstractmethod def update_brain_user_rights( self, brain_id: UUID, user_id: UUID, rights: str - ) -> BrainUser: + ) -> BrainUserDB: """ Update the rights for a user in a brain """ diff --git a/backend/api/quivr_api/modules/brain/service/brain_service.py b/backend/api/quivr_api/modules/brain/service/brain_service.py index 891dc8ea4119..764b1928c1ee 100644 --- a/backend/api/quivr_api/modules/brain/service/brain_service.py +++ b/backend/api/quivr_api/modules/brain/service/brain_service.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Tuple from uuid import UUID from fastapi import HTTPException @@ -54,7 +54,7 @@ def find_brain_from_question( chat_id: UUID, history, vector_store: CustomSupabaseVectorStore, - ) -> (Optional[BrainEntity], dict[str, str]): + ) -> Tuple[Optional[BrainEntity], dict[str, str]]: """Find the brain to use for a question. 
Args: diff --git a/backend/api/quivr_api/modules/brain/service/brain_user_service.py b/backend/api/quivr_api/modules/brain/service/brain_user_service.py index b1bf15038723..61699fcededf 100644 --- a/backend/api/quivr_api/modules/brain/service/brain_user_service.py +++ b/backend/api/quivr_api/modules/brain/service/brain_user_service.py @@ -2,10 +2,11 @@ from uuid import UUID from fastapi import HTTPException + from quivr_api.logger import get_logger from quivr_api.modules.brain.entity.brain_entity import ( BrainEntity, - BrainUser, + BrainUserDB, MinimalUserBrainEntity, RoleEnum, ) @@ -73,7 +74,7 @@ def get_user_brains(self, user_id: UUID) -> list[MinimalUserBrainEntity]: return results # type: ignore - def get_brain_users(self, brain_id: UUID) -> List[BrainUser]: + def get_brain_users(self, brain_id: UUID) -> List[BrainUserDB]: return self.brain_user_repository.get_brain_users(brain_id) def update_brain_user_rights( diff --git a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py index f1184cc6c4e7..9fbcd838a86d 100644 --- a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py +++ b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py @@ -1,9 +1,12 @@ +import asyncio from http import HTTPStatus from typing import List, Optional from uuid import UUID from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile, status +from quivr_core.models import KnowledgeStatus +from quivr_api.celery_config import celery from quivr_api.logger import get_logger from quivr_api.middlewares.auth import AuthBearer, get_current_user from quivr_api.modules.brain.entity.brain_entity import RoleEnum @@ -12,8 +15,13 @@ validate_brain_authorization, ) from quivr_api.modules.dependencies import get_service -from quivr_api.modules.knowledge.dto.inputs import AddKnowledge, KnowledgeUpdate +from quivr_api.modules.knowledge.dto.inputs import ( + 
AddKnowledge, + KnowledgeUpdate, + LinkKnowledgeBrain, +) from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO +from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB from quivr_api.modules.knowledge.service.knowledge_exceptions import ( KnowledgeDeleteError, KnowledgeForbiddenAccess, @@ -21,15 +29,23 @@ UploadError, ) from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService +from quivr_api.modules.notification.dto.inputs import CreateNotification +from quivr_api.modules.notification.entity.notification import NotificationsStatusEnum +from quivr_api.modules.notification.service.notification_service import ( + NotificationService, +) +from quivr_api.modules.sync.service.sync_service import SyncsService from quivr_api.modules.upload.service.generate_file_signed_url import ( generate_file_signed_url, ) from quivr_api.modules.user.entity.user_identity import UserIdentity -knowledge_router = APIRouter() logger = get_logger(__name__) +knowledge_router = APIRouter() +notification_service = NotificationService() get_knowledge_service = get_service(KnowledgeService) +get_sync_service = get_service(SyncsService) @knowledge_router.get( @@ -43,7 +59,6 @@ async def list_knowledge_in_brain_endpoint( """ Retrieve and list all the knowledge in a brain. 
""" - validate_brain_authorization(brain_id=brain_id, user_id=current_user.id) knowledges = await knowledge_service.get_all_knowledge_in_brain(brain_id) @@ -266,11 +281,73 @@ async def delete_knowledge( raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) -@knowledge_router.post("/link_to_brain/") +@knowledge_router.post( + "/knowledge/link_to_brains/", + status_code=status.HTTP_201_CREATED, + response_model=List[KnowledgeDTO], +) async def link_knowledge_to_brain( - brain_id: UUID, - knowledge: KnowledgeDTO, + link_request: LinkKnowledgeBrain, knowledge_service: KnowledgeService = Depends(get_knowledge_service), current_user: UserIdentity = Depends(get_current_user), ): - pass + brains_ids, knowledge_dto, bulk_id = ( + link_request.brain_ids, + link_request.knowledge, + link_request.bulk_id, + ) + if len(brains_ids) == 0: + return "empty brain list" + + async def _send_knowledge_process(knowledge: KnowledgeDB): + assert knowledge.id + knowledge = await knowledge_service.update_knowledge( + knowledge=knowledge, + payload=KnowledgeUpdate(status=KnowledgeStatus.PROCESSING), + ) + upload_notification = notification_service.add_notification( + CreateNotification( + user_id=current_user.id, + bulk_id=bulk_id, + status=NotificationsStatusEnum.INFO, + title=f"{knowledge.file_name}", + category="process", + ) + ) + celery.send_task( + "process_file_task", + kwargs={ + "knowledge_id": knowledge.id, + "file_name": knowledge.file_name, + "notification_id": upload_notification.id, + "source": knowledge.source, + "source_link": knowledge.source_link, + }, + ) + + if knowledge_dto.id is None: + if knowledge_dto.sync_file_id is None: + raise HTTPException( + status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Unknown knowledge entity" + ) + # Create a knowledge from this sync + knowledge = await knowledge_service.create_knowledge( + user_id=current_user.id, + knowledge_to_add=AddKnowledge(**knowledge_dto.model_dump()), + upload_file=None, + ) + linked_kms = await 
knowledge_service.link_knowledge_tree_brains( + knowledge, brains_ids=brains_ids, user_id=current_user.id + ) + + else: + linked_kms = await knowledge_service.link_knowledge_tree_brains( + knowledge_dto.id, brains_ids=brains_ids, user_id=current_user.id + ) + + for knowledge in filter( + lambda k: k.status != KnowledgeStatus.PROCESSED, linked_kms + ): + await _send_knowledge_process(knowledge=knowledge) + + return await asyncio.gather(*[k.to_dto() for k in linked_kms]) diff --git a/backend/api/quivr_api/modules/knowledge/dto/inputs.py b/backend/api/quivr_api/modules/knowledge/dto/inputs.py index ad478b716fa6..cc0ab8958389 100644 --- a/backend/api/quivr_api/modules/knowledge/dto/inputs.py +++ b/backend/api/quivr_api/modules/knowledge/dto/inputs.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Dict, List, Optional from uuid import UUID from pydantic import BaseModel @@ -30,7 +30,9 @@ class AddKnowledge(BaseModel): source_link: Optional[str] = None metadata: Optional[Dict[str, str]] = None is_folder: bool = False - parent_id: Optional[UUID] = None + parent_id: UUID | None = None + sync_id: int | None = None + sync_file_id: str | None = None class KnowledgeUpdate(BaseModel): @@ -45,5 +47,7 @@ class KnowledgeUpdate(BaseModel): metadata: Optional[Dict[str, str]] = None -class LinkKnowledge(BaseModel): +class LinkKnowledgeBrain(BaseModel): + bulk_id: UUID knowledge: KnowledgeDTO + brain_ids: List[UUID] diff --git a/backend/api/quivr_api/modules/knowledge/dto/outputs.py b/backend/api/quivr_api/modules/knowledge/dto/outputs.py index 34c1860c7231..6d361dec12cb 100644 --- a/backend/api/quivr_api/modules/knowledge/dto/outputs.py +++ b/backend/api/quivr_api/modules/knowledge/dto/outputs.py @@ -29,6 +29,6 @@ class KnowledgeDTO(BaseModel): user_id: UUID brains: List[Dict[str, Any]] parent: Optional[Self] - children: Optional[List[Self]] + children: List[Self] sync_id: int | None sync_file_id: str | None diff --git 
a/backend/api/quivr_api/modules/knowledge/entity/knowledge.py b/backend/api/quivr_api/modules/knowledge/entity/knowledge.py index 759bbbffddbd..19c2a3f2defe 100644 --- a/backend/api/quivr_api/modules/knowledge/entity/knowledge.py +++ b/backend/api/quivr_api/modules/knowledge/entity/knowledge.py @@ -105,6 +105,7 @@ async def to_dto(self, get_children: bool = True) -> KnowledgeDTO: children: list[KnowledgeDB] = ( await self.awaitable_attrs.children if get_children else [] ) + children_dto = [await c.to_dto(get_children=False) for c in children] parent = await self.awaitable_attrs.parent parent = await parent.to_dto(get_children=False) if parent else None @@ -124,7 +125,7 @@ async def to_dto(self, get_children: bool = True) -> KnowledgeDTO: metadata=self.metadata_, # type: ignore brains=[b.model_dump() for b in brains], parent=parent, - children=[await c.to_dto(get_children=False) for c in children], + children=children_dto, user_id=self.user_id, sync_id=self.sync_id, sync_file_id=self.sync_file_id, diff --git a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py index 6105a6ddccdf..18b4f4595c6e 100644 --- a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py +++ b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py @@ -5,11 +5,11 @@ from quivr_core.models import KnowledgeStatus from sqlalchemy.exc import IntegrityError, NoResultFound from sqlalchemy.orm import joinedload -from sqlmodel import select, text +from sqlmodel import and_, col, select, text from sqlmodel.ext.asyncio.session import AsyncSession from quivr_api.logger import get_logger -from quivr_api.modules.brain.entity.brain_entity import Brain +from quivr_api.modules.brain.entity.brain_entity import Brain, BrainUserDB, RoleEnum from quivr_api.modules.dependencies import BaseRepository, get_supabase_client from quivr_api.modules.knowledge.dto.inputs import KnowledgeUpdate from 
quivr_api.modules.knowledge.dto.outputs import ( @@ -20,6 +20,7 @@ KnowledgeDB, ) from quivr_api.modules.knowledge.service.knowledge_exceptions import ( + KnowledgeCreationError, KnowledgeNotFoundException, KnowledgeUpdateError, ) @@ -69,6 +70,47 @@ async def update_knowledge( logger.error(f"Error updating knowledge {e}") raise KnowledgeUpdateError + async def link_knowledge_tree_brains( + self, knowledge: KnowledgeDB, brains_ids: List[UUID], user_id: UUID + ) -> list[KnowledgeDB]: + assert knowledge.id, "can't link knowledge not in db" + # TODO(@aminediro @StanGirard): verification should be done elsewhere + # should rewrite BrainService and Brain Authorization to be as middleware + try: + stmt = ( + select(Brain) + .join(BrainUserDB, col(Brain.brain_id) == col(BrainUserDB.brain_id)) + .where( + and_( + col(Brain.brain_id).in_(brains_ids), + BrainUserDB.user_id == user_id, + BrainUserDB.rights == RoleEnum.Owner, + ) + ) + ) + brains = list((await self.session.exec(stmt)).unique().all()) + if len(brains) == 0: + logger.error( + f"No brains for user_id={user_id}, brains_list={brains_ids}" + ) + raise KnowledgeCreationError("can't associate knowledge to brains") + children = await self.get_knowledge_tree(knowledge.id) + all_kms = [knowledge, *children] + for k in all_kms: + for b in brains: + k.brains.append(b) + for k in all_kms: + await self.session.merge(k) + await self.session.commit() + [await self.session.refresh(k) for k in all_kms] + return all_kms + except IntegrityError: + await self.session.rollback() + raise + except Exception: + await self.session.rollback() + raise + async def insert_knowledge_brain( self, knowledge: KnowledgeDB, brain_id: UUID ) -> KnowledgeDB: @@ -191,44 +233,34 @@ async def get_knowledge_by_sha1(self, sha1: str) -> KnowledgeDB: return knowledge - async def get_all_children(self, parent_id: UUID) -> list[KnowledgeDB]: - query = text(""" - WITH RECURSIVE knowledge_tree AS ( - SELECT * - FROM knowledge - WHERE parent_id = :parent_id - 
UNION ALL - SELECT k.* - FROM knowledge k - JOIN knowledge_tree kt ON k.parent_id = kt.id - ) - SELECT * FROM knowledge_tree - """) - - result = await self.session.execute(query, params={"parent_id": parent_id}) - rows = result.fetchall() - knowledge_list = [] - for row in rows: - knowledge = KnowledgeDB( - id=row.id, - parent_id=row.parent_id, - file_name=row.file_name, - url=row.url, - extension=row.extension, - status=row.status, - source=row.source, - source_link=row.source_link, - file_size=row.file_size, - file_sha1=row.file_sha1, - created_at=row.created_at, - updated_at=row.updated_at, - metadata_=row.metadata, - is_folder=row.is_folder, - user_id=row.user_id, + async def get_knowledge_tree(self, parent_id: UUID) -> list[KnowledgeDB]: + from sqlalchemy.orm import aliased + + Knowledge = aliased(KnowledgeDB) + KnowledgeRecursive = aliased(KnowledgeDB) + + recursive_cte = ( + select(KnowledgeRecursive) + .where(KnowledgeRecursive.parent_id == parent_id) + .cte(name="knowledge_tree", recursive=True) + ) + + recursive_cte = recursive_cte.union_all( + select(Knowledge).join( + recursive_cte, col(Knowledge.parent_id) == col(recursive_cte.c.id) ) - knowledge_list.append(knowledge) + ) + # TODO(@AmineDiro): Optimize get_knowledge_tree + query = ( + select(KnowledgeDB) + .join(recursive_cte, col(KnowledgeDB.id) == recursive_cte.c.id) + .options(joinedload(KnowledgeDB.brains)) + ) + + result = await self.session.exec(query) + knowledge_list = result.unique().all() - return knowledge_list + return list(knowledge_list) async def get_root_knowledge_user(self, user_id: UUID) -> list[KnowledgeDB]: query = ( diff --git a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py index fb0110371167..1c79d9f99115 100644 --- a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py @@ -79,6 +79,7 @@ async def 
map_syncs_knowledge_user( return { k.sync_file_id: k for k in await asyncio.gather(*[k.to_dto() for k in list_kms]) + if k.sync_file_id } async def list_knowledge( @@ -167,8 +168,12 @@ async def create_knowledge( metadata_=knowledge_to_add.metadata, # type: ignore status=KnowledgeStatus.RESERVED, parent_id=knowledge_to_add.parent_id, + sync_id=knowledge_to_add.sync_id, + sync_file_id=knowledge_to_add.sync_file_id, ) + knowledge_db = await self.repository.create_knowledge(knowledgedb) + try: if knowledgedb.source == KnowledgeSource.LOCAL and upload_file: # NOTE(@aminediro): Unnecessary mem buffer because supabase doesnt accept FileIO.. @@ -177,10 +182,10 @@ async def create_knowledge( knowledgedb, buff_reader ) knowledgedb.source_link = storage_path - knowledge_db = await self.repository.update_knowledge( - knowledge_db, - KnowledgeUpdate(status=KnowledgeStatus.UPLOADED), # type: ignore - ) + knowledge_db = await self.repository.update_knowledge( + knowledge_db, + KnowledgeUpdate(status=KnowledgeStatus.UPLOADED), # type: ignore + ) return knowledge_db except Exception as e: logger.exception( @@ -251,7 +256,7 @@ async def remove_knowledge(self, knowledge: KnowledgeDB) -> DeleteKnowledgeRespo try: # TODO: # - Notion folders are special, they are themselves files and should be removed from storage - children = await self.repository.get_all_children(knowledge.id) + children = await self.repository.get_knowledge_tree(knowledge.id) km_paths = [ self.storage.get_storage_path(k) for k in children if not k.is_folder ] @@ -300,45 +305,11 @@ async def remove_all_knowledges_from_brain(self, brain_id: UUID) -> None: f"All knowledge in brain {brain_id} removed successfully from table" ) - # TODO: REDO THIS MESS !!!! 
- # REMOVE ALL SYNC TABLES and start from scratch - # async def update_or_create_knowledge_sync( - # self, - # brain_id: UUID, - # user_id: UUID, - # file: SyncFile, - # new_sync_file: DBSyncFile | None, - # prev_sync_file: DBSyncFile | None, - # downloaded_file: DownloadedSyncFile, - # source: str, - # source_link: str, - # ) -> Knowledge: - # sync_id = None - # # TODO: THIS IS A HACK!! Remove all of this - # if prev_sync_file: - # prev_knowledge = await self.get_knowledge_sync(sync_id=prev_sync_file.id) - # if len(prev_knowledge.brains) > 1: - # await self.repository.remove_knowledge_from_brain( - # prev_knowledge.id, brain_id - # ) - # else: - # await self.repository.remove_knowledge_by_id(prev_knowledge.id) - # sync_id = prev_sync_file.id - - # sync_id = new_sync_file.id if new_sync_file else sync_id - # knowledge_to_add = CreateKnowledgeProperties( - # brain_id=brain_id, - # file_name=file.name, - # extension=downloaded_file.extension, - # source=source, - # status=KnowledgeStatus.PROCESSING, - # source_link=source_link, - # file_size=file.size if file.size else 0, - # # FIXME (@aminediro): This is a temporary fix, redo in KMS - # file_sha1=None, - # metadata={"sync_file_id": str(sync_id)}, - # ) - # added_knowledge = await self.insert_knowledge_brain( - # knowledge_to_add=knowledge_to_add, user_id=user_id - # ) - # return added_knowledge + async def link_knowledge_tree_brains( + self, knowledge: KnowledgeDB | UUID, brains_ids: List[UUID], user_id: UUID + ) -> List[KnowledgeDB]: + if isinstance(knowledge, UUID): + knowledge = await self.repository.get_knowledge_by_id(knowledge) + return await self.repository.link_knowledge_tree_brains( + knowledge, brains_ids=brains_ids, user_id=user_id + ) diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py index bd55f5de9896..ea7b1a479d01 100644 --- 
a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py @@ -1,19 +1,28 @@ import json +from datetime import datetime +from uuid import uuid4 import pytest import pytest_asyncio from httpx import ASGITransport, AsyncClient +from quivr_core.models import KnowledgeStatus from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession from quivr_api.main import app from quivr_api.middlewares.auth.auth_bearer import get_current_user +from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType +from quivr_api.modules.brain.entity.brain_user import BrainUserDB from quivr_api.modules.knowledge.controller.knowledge_routes import ( get_knowledge_service, ) +from quivr_api.modules.knowledge.dto.inputs import LinkKnowledgeBrain +from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService from quivr_api.modules.knowledge.tests.conftest import FakeStorage +from quivr_api.modules.sync.dto.outputs import SyncProvider +from quivr_api.modules.sync.entity.sync_models import Sync from quivr_api.modules.user.entity.user_identity import User, UserIdentity @@ -26,6 +35,43 @@ async def user(session: AsyncSession) -> User: return user_1 +@pytest_asyncio.fixture(scope="function") +async def brain(session, user): + assert user.id + brain_1 = Brain( + name="test_brain", + description="this is a test brain", + brain_type=BrainType.integration, + ) + session.add(brain_1) + await session.commit() + await session.refresh(brain_1) + assert brain_1.brain_id + brain_user = BrainUserDB( + brain_id=brain_1.brain_id, user_id=user.id, default_brain=True, rights="Owner" + ) + session.add(brain_user) + await session.commit() + return brain_1 + + +@pytest_asyncio.fixture(scope="function") +async def 
sync(session: AsyncSession, user: User) -> Sync: + assert user.id + sync = Sync( + name="test_sync", + email="test@test.com", + user_id=user.id, + credentials={"test": "test"}, + provider=SyncProvider.GOOGLE, + ) + + session.add(sync) + await session.commit() + await session.refresh(sync) + return sync + + @pytest_asyncio.fixture(scope="function") async def test_client(session: AsyncSession, user: User): def default_current_user() -> UserIdentity: @@ -48,6 +94,33 @@ async def test_service(): app.dependency_overrides = {} +@pytest.mark.asyncio(loop_scope="session") +async def test_post_knowledge_folder(test_client: AsyncClient): + km_data = { + "file_name": "test_file.txt", + "source": "local", + "is_folder": True, + "parent_id": None, + } + + multipart_data = { + "knowledge_data": (None, json.dumps(km_data), "application/json"), + } + + response = await test_client.post( + "/knowledge/", + files=multipart_data, + ) + + assert response.status_code == 200 + km = KnowledgeDTO.model_validate(response.json()) + + assert km.id + assert km.is_folder + assert km.parent is None + assert km.children == [] + + @pytest.mark.asyncio(loop_scope="session") async def test_post_knowledge(test_client: AsyncClient): km_data = { @@ -74,3 +147,143 @@ async def test_post_knowledge(test_client: AsyncClient): async def test_add_knowledge_invalid_input(test_client): response = await test_client.post("/knowledge/", files={}) assert response.status_code == 422 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_link_knowledge_sync_file( + monkeypatch, + session: AsyncSession, + test_client: AsyncClient, + brain: Brain, + user: User, + sync: Sync, +): + tasks = {} + + def _send_task(*args, **kwargs): + tasks["args"] = args + tasks["kwargs"] = {**kwargs["kwargs"]} + + monkeypatch.setattr("quivr_api.celery_config.celery.send_task", _send_task) + + assert user.id + assert brain.brain_id + km = KnowledgeDTO( + id=None, + file_name="test.txt", + extension=".txt", + status=None, + 
user_id=user.id, + created_at=datetime.now(), + updated_at=datetime.now(), + brains=[], + source=SyncProvider.GOOGLE, + source_link="drive://test.txt", + sync_id=sync.id, + sync_file_id="sync_file_id_1", + parent=None, + children=[], + ) + json_data = LinkKnowledgeBrain( + bulk_id=uuid4(), brain_ids=[brain.brain_id], knowledge=km + ).model_dump_json() + response = await test_client.post( + "/knowledge/link_to_brains/", + content=json_data, + headers={"Content-Type": "application/json"}, + ) + + assert response.status_code == 201 + km = KnowledgeDTO.model_validate(response.json()[0]) + assert km.id + assert km.status == KnowledgeStatus.PROCESSING + assert len(km.brains) == 1 + + # Assert task added to celery + assert len(tasks) > 0 + assert tasks["args"] == ("process_file_task",) + + minimal_task_kwargs = { + "knowledge_id": km.id, + } + all( + minimal_task_kwargs[key] == tasks["kwargs"][key] # type: ignore + for key in minimal_task_kwargs + ) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_link_knowledge_folder( + monkeypatch, + session: AsyncSession, + test_client: AsyncClient, + brain: Brain, + user: User, + sync: Sync, +): + assert brain.brain_id + tasks = {} + + def _send_task(*args, **kwargs): + tasks["args"] = args + tasks["kwargs"] = {**kwargs["kwargs"]} + + monkeypatch.setattr("quivr_api.celery_config.celery.send_task", _send_task) + + folder_data = { + "file_name": "folder", + "source": "local", + "is_folder": True, + "parent_id": None, + } + response = await test_client.post( + "/knowledge/", + files={ + "knowledge_data": (None, json.dumps(folder_data), "application/json"), + }, + ) + # 1. Insert folder + folder_km = KnowledgeDTO.model_validate(response.json()) + file_data = { + "file_name": "test_file.txt", + "source": "local", + "is_folder": True, + "parent_id": str(folder_km.id), + } + + multipart_data = { + "knowledge_data": (None, json.dumps(file_data), "application/json"), + } + # 2. 
Insert file in folder + response = await test_client.post( + "/knowledge/", + files=multipart_data, + ) + file_km = KnowledgeDTO.model_validate(response.json()) + + json_data = LinkKnowledgeBrain( + bulk_id=uuid4(), brain_ids=[brain.brain_id], knowledge=folder_km + ).model_dump_json() + + response = await test_client.post( + "/knowledge/link_to_brains/", + content=json_data, + headers={"Content-Type": "application/json"}, + ) + + assert response.status_code == 201 + + updated_kms = [KnowledgeDTO.model_validate(d) for d in response.json()] + + assert len(updated_kms) == 2 + assert next( + filter(lambda k: k.id == folder_km.id, updated_kms) + ), "file not in updated list" + assert next( + filter(lambda k: k.id == file_km.id, updated_kms) + ), "file not in updated list" + for km in updated_kms: + assert len(km.brains) == 1 + + # Assert both files are being processed + assert len(tasks) == 2 diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py index 244c328dea98..75364ab0016b 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py @@ -11,6 +11,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType +from quivr_api.modules.brain.entity.brain_user import BrainUserDB from quivr_api.modules.knowledge.dto.inputs import ( AddKnowledge, KnowledgeStatus, @@ -58,6 +59,26 @@ async def user(session: AsyncSession) -> User: return user_1 +@pytest_asyncio.fixture(scope="function") +async def brain_user(session, user: User) -> Brain: + assert user.id + brain_1 = Brain( + name="test_brain", + description="this is a test brain", + brain_type=BrainType.integration, + ) + session.add(brain_1) + await session.commit() + await session.refresh(brain_1) + assert brain_1.brain_id + brain_user = BrainUserDB( + 
brain_id=brain_1.brain_id, user_id=user.id, default_brain=True, rights="Owner" + ) + session.add(brain_user) + await session.commit() + return brain_1 + + @pytest_asyncio.fixture(scope="function") async def test_data(session: AsyncSession) -> TestData: user_1 = ( @@ -1021,3 +1042,57 @@ async def test_list_knowledge(session: AsyncSession, user: User): assert len(kms) == 1 assert kms[0].id == nested_file.id + + +@pytest.mark.asyncio(loop_scope="session") +async def test_link_knowledge_brain( + session: AsyncSession, user: User, brain_user: Brain +): + assert user.id + assert brain_user.brain_id + + root_folder = KnowledgeDB( + file_name="folder", + extension="", + status="UPLOADED", + source="local", + source_link="local", + file_size=4, + file_sha1=None, + brains=[], + children=[], + user_id=user.id, + is_folder=True, + ) + nested_file = KnowledgeDB( + file_name="file_2", + extension="", + status="UPLOADED", + source="local", + source_link="local", + file_size=10, + file_sha1=None, + user_id=user.id, + parent=root_folder, + ) + session.add(nested_file) + session.add(root_folder) + await session.commit() + await session.refresh(root_folder) + await session.refresh(nested_file) + + storage = FakeStorage() + repository = KnowledgeRepository(session) + service = KnowledgeService(repository, storage) + + await service.link_knowledge_tree_brains( + root_folder, brains_ids=[brain_user.brain_id], user_id=user.id + ) + kms = await service.get_all_knowledge_in_brain(brain_id=brain_user.brain_id) + assert len(kms) == 2 + assert {k.id for k in kms} == {root_folder.id, nested_file.id} + + +@pytest.mark.asyncio(loop_scope="session") +async def test_link_knowledge_brain_existing_brains(): + """test knowledge already in brain and we add it to the same brain because we added his parent""" diff --git a/backend/api/quivr_api/modules/sync/controller/sync_routes.py b/backend/api/quivr_api/modules/sync/controller/sync_routes.py index cb615ee19f05..136ae96d6754 100644 --- 
a/backend/api/quivr_api/modules/sync/controller/sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/sync_routes.py @@ -201,6 +201,9 @@ async def fetch_data() -> Tuple[dict[str, KnowledgeDTO], List[SyncFile] | None]: if existing_km: kms.append(existing_km) else: + last_modified_at = ( + file.last_modified_at if file.last_modified_at else datetime.now() + ) kms.append( KnowledgeDTO( id=None, @@ -215,12 +218,8 @@ async def fetch_data() -> Tuple[dict[str, KnowledgeDTO], List[SyncFile] | None]: children=None, status=None, # TODO: Handle a sync not added status # TODO: retrieve created at from sync provider - created_at=file.last_modified_at - if file.last_modified_at - else datetime.now(), - updated_at=file.last_modified_at - if file.last_modified_at - else datetime.now(), + created_at=last_modified_at, + updated_at=last_modified_at, sync_id=sync_id, sync_file_id=file.id, ) diff --git a/backend/api/quivr_api/modules/user/entity/user_identity.py b/backend/api/quivr_api/modules/user/entity/user_identity.py index 3f734f1a66cc..22e4940af882 100644 --- a/backend/api/quivr_api/modules/user/entity/user_identity.py +++ b/backend/api/quivr_api/modules/user/entity/user_identity.py @@ -2,6 +2,7 @@ from uuid import UUID, uuid4 from pydantic import BaseModel +from quivr_api.modules.brain.entity.brain_user import BrainUserDB from sqlmodel import Field, Relationship, SQLModel @@ -17,6 +18,10 @@ class User(SQLModel, table=True): onboarded: bool | None = None chats: List["Chat"] | None = Relationship(back_populates="user") # type: ignore notion_syncs: List["NotionSyncFile"] | None = Relationship(back_populates="user") # type: ignore + brains: List["Brain"] = Relationship( + back_populates="users", + link_model=BrainUserDB, + ) class UserIdentity(BaseModel): diff --git a/backend/worker/quivr_worker/celery_worker.py b/backend/worker/quivr_worker/celery_worker.py index d1bd6fb6e67a..9ad58dd5640c 100644 --- a/backend/worker/quivr_worker/celery_worker.py +++ 
b/backend/worker/quivr_worker/celery_worker.py @@ -18,9 +18,7 @@ from quivr_api.modules.notification.service.notification_service import ( NotificationService, ) -from quivr_api.modules.sync.repository.sync_files import SyncFilesRepository from quivr_api.modules.sync.service.sync_notion import SyncNotionService -from quivr_api.modules.sync.service.sync_service import SyncsService from quivr_api.modules.vector.repository.vectors_repository import VectorRepository from quivr_api.modules.vector.service.vector_service import VectorService from quivr_api.utils.telemetry import maybe_send_telemetry @@ -53,9 +51,6 @@ supabase_client = get_supabase_client() # document_vector_store = get_documents_vector_store() notification_service = NotificationService() -sync_active_service = SyncsService() -sync_user_service = SyncsService() -sync_files_repo_service = SyncFilesRepository() brain_service = BrainService() brain_vectors = BrainsVectors() storage = SupabaseS3Storage() @@ -99,11 +94,9 @@ def init_worker(**kwargs): dont_autoretry_for=(FileExistsError,), ) def process_file_task( + knowledge_id: UUID, file_name: str, - file_original_name: str, - brain_id: UUID, notification_id: UUID, - knowledge_id: UUID, source: str | None = None, source_link: str | None = None, delete_file: bool = False, @@ -119,8 +112,6 @@ def process_file_task( loop.run_until_complete( aprocess_file_task( file_name=file_name, - file_original_name=file_original_name, - brain_id=brain_id, notification_id=notification_id, knowledge_id=knowledge_id, source=source, @@ -132,8 +123,6 @@ def process_file_task( async def aprocess_file_task( file_name: str, - file_original_name: str, - brain_id: UUID, notification_id: UUID, knowledge_id: UUID, source: str | None = None, @@ -163,12 +152,9 @@ async def aprocess_file_task( vector_service=vector_service, knowledge_service=knowledge_service, file_name=file_name, - brain_id=brain_id, - file_original_name=file_original_name, knowledge_id=knowledge_id, integration=source, 
integration_link=source_link, - delete_file=delete_file, ) session.commit() await async_session.commit() diff --git a/backend/worker/quivr_worker/process/process_s3_file.py b/backend/worker/quivr_worker/process/process_s3_file.py index 99bc4e7360d1..526a62ad50cb 100644 --- a/backend/worker/quivr_worker/process/process_s3_file.py +++ b/backend/worker/quivr_worker/process/process_s3_file.py @@ -2,7 +2,7 @@ from quivr_api.logger import get_logger from quivr_api.modules.brain.service.brain_service import BrainService -from quivr_api.modules.knowledge.entity.knowledge import KnowledgeUpdate +from quivr_api.modules.knowledge.dto.inputs import KnowledgeUpdate from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService from quivr_api.modules.vector.service.vector_service import VectorService @@ -19,27 +19,12 @@ async def process_uploaded_file( vector_service: VectorService, knowledge_service: KnowledgeService, file_name: str, - brain_id: UUID, - file_original_name: str, knowledge_id: UUID, integration: str | None = None, integration_link: str | None = None, - delete_file: bool = False, bucket_name: str = "quivr", ): - brain = brain_service.get_brain_by_id(brain_id) - if brain is None: - logger.exception( - "It seems like you're uploading knowledge to an unknown brain." 
- ) - raise ValueError("unknown brain") - assert brain file_data = supabase_client.storage.from_(bucket_name).download(file_name) - # TODO: Have the whole logic on do we process file or not - # Don't process a file that already exists (file_sha1 in the table with STATUS=UPLOADED) - # - # - Check on file_sha1 and status - # If we have some knowledge with error with build_file(file_data, knowledge_id, file_name) as file_instance: knowledge = await knowledge_service.get_knowledge(knowledge_id=knowledge_id) await knowledge_service.update_knowledge( @@ -48,7 +33,7 @@ async def process_uploaded_file( ) await process_file( file_instance=file_instance, - brain=brain, + brain=knowledge.brains[0], # FIXME: this is temporary brain_service=brain_service, vector_service=vector_service, integration=integration, From 6d1dd8d2a21482dba69b5e337860628f875bcc86 Mon Sep 17 00:00:00 2001 From: aminediro Date: Fri, 20 Sep 2024 18:27:44 +0200 Subject: [PATCH 12/63] fixed sync children none --- backend/api/quivr_api/modules/sync/controller/sync_routes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/api/quivr_api/modules/sync/controller/sync_routes.py b/backend/api/quivr_api/modules/sync/controller/sync_routes.py index 136ae96d6754..a62ee7426eb9 100644 --- a/backend/api/quivr_api/modules/sync/controller/sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/sync_routes.py @@ -215,7 +215,7 @@ async def fetch_data() -> Tuple[dict[str, KnowledgeDTO], List[SyncFile] | None]: user_id=current_user.id, brains=[], parent=None, - children=None, + children=[], status=None, # TODO: Handle a sync not added status # TODO: retrieve created at from sync provider created_at=last_modified_at, From 6e9158329f4dd5aa4763ec66470cc0430c9c8a99 Mon Sep 17 00:00:00 2001 From: aminediro Date: Mon, 23 Sep 2024 10:55:35 +0200 Subject: [PATCH 13/63] link knowledge on creation --- .../knowledge/service/knowledge_service.py | 69 ++---- .../tests/test_knowledge_controller.py | 5 
+- .../knowledge/tests/test_knowledge_service.py | 234 ++++++------------ 3 files changed, 103 insertions(+), 205 deletions(-) diff --git a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py index 1c79d9f99115..63670e163b95 100644 --- a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py @@ -7,6 +7,7 @@ from quivr_core.models import KnowledgeStatus from sqlalchemy.exc import NoResultFound +from quivr_api.celery_config import celery from quivr_api.logger import get_logger from quivr_api.modules.dependencies import BaseService from quivr_api.modules.knowledge.dto.inputs import ( @@ -107,55 +108,16 @@ async def update_knowledge( ): return await self.repository.update_knowledge(knowledge, payload) - # TODO: Remove all of this - # TODO (@aminediro): Replace with ON CONFLICT smarter query... - # there is a chance of race condition but for now we let it crash in worker - # the tasks will be dealt with on retry - async def update_sha1_conflict( - self, knowledge: KnowledgeDB, brain_id: UUID, file_sha1: str - ) -> bool: - assert knowledge.id - knowledge.file_sha1 = file_sha1 - - try: - existing_knowledge = await self.repository.get_knowledge_by_sha1( - knowledge.file_sha1 - ) - logger.debug("The content of the knowledge already exists in the brain. 
") - # Get existing knowledge sha1 and brains - if ( - existing_knowledge.status == KnowledgeStatus.UPLOADED - or existing_knowledge.status == KnowledgeStatus.PROCESSING - ): - existing_brains = await existing_knowledge.awaitable_attrs.brains - if brain_id in [b.brain_id for b in existing_brains]: - logger.debug("Added file to brain that already has the knowledge") - raise FileExistsError( - f"Existing file in brain {brain_id} with name {existing_knowledge.file_name}" - ) - else: - await self.repository.link_to_brain(existing_knowledge, brain_id) - await self.remove_knowledge_brain(brain_id, knowledge.id) - return False - else: - logger.debug(f"Removing previous errored file {existing_knowledge.id}") - assert existing_knowledge.id - await self.remove_knowledge_brain(brain_id, existing_knowledge.id) - await self.update_file_sha1_knowledge(knowledge.id, knowledge.file_sha1) - return True - except NoResultFound: - logger.debug( - f"First knowledge with sha1. Updating file_sha1 of {knowledge.id}" - ) - await self.update_file_sha1_knowledge(knowledge.id, knowledge.file_sha1) - return True - async def create_knowledge( self, user_id: UUID, knowledge_to_add: AddKnowledge, upload_file: UploadFile | None = None, ) -> KnowledgeDB: + brains = [] + if knowledge_to_add.parent_id: + parent_knowledge = await self.get_knowledge(knowledge_to_add.parent_id) + brains = await parent_knowledge.awaitable_attrs.brains knowledgedb = KnowledgeDB( user_id=user_id, file_name=knowledge_to_add.file_name, @@ -170,6 +132,7 @@ async def create_knowledge( parent_id=knowledge_to_add.parent_id, sync_id=knowledge_to_add.sync_id, sync_file_id=knowledge_to_add.sync_file_id, + brains=brains, ) knowledge_db = await self.repository.create_knowledge(knowledgedb) @@ -184,8 +147,26 @@ async def create_knowledge( knowledgedb.source_link = storage_path knowledge_db = await self.repository.update_knowledge( knowledge_db, - KnowledgeUpdate(status=KnowledgeStatus.UPLOADED), # type: ignore + 
KnowledgeUpdate(status=KnowledgeStatus.UPLOADED), + ) + if knowledge_db.brains and len(knowledge_db.brains) > 0: + # Schedule this new knowledge to be processed + knowledge_db = await self.repository.update_knowledge( + knowledge_db, + KnowledgeUpdate(status=KnowledgeStatus.PROCESSING), + ) + celery.send_task( + "process_file_task", + kwargs={ + "knowledge_id": knowledge_db.id, + "file_name": knowledge_db.file_name, + "source": knowledge_db.source, + "source_link": knowledge_db.source_link, + # TODO: Notification on notification + # "notification_id": upload_notification.id, + }, ) + return knowledge_db except Exception as e: logger.exception( diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py index ea7b1a479d01..63ecd4dd8d24 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py @@ -270,11 +270,10 @@ def _send_task(*args, **kwargs): content=json_data, headers={"Content-Type": "application/json"}, ) - assert response.status_code == 201 - updated_kms = [KnowledgeDTO.model_validate(d) for d in response.json()] + # 3. Validate that created knowledges are correct assert len(updated_kms) == 2 assert next( filter(lambda k: k.id == folder_km.id, updated_kms) @@ -285,5 +284,5 @@ def _send_task(*args, **kwargs): for km in updated_kms: assert len(km.brains) == 1 - # Assert both files are being processed + # 4. 
Assert both files are being scheduled for processing assert len(tasks) == 2 diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py index 75364ab0016b..39c1f8759d20 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py @@ -213,6 +213,31 @@ async def folder_km(session: AsyncSession, user: User): return folder +@pytest_asyncio.fixture(scope="function") +async def folder_km_brain(session: AsyncSession, brain_user: Brain): + "local folder linked to a brain" + user: User = (await brain_user.awaitable_attrs.users)[0] + assert user.id + folder = KnowledgeDB( + file_name="folder_1", + extension="", + status="UPLOADED", + source="local", + source_link="local", + file_size=0, + file_sha1=None, + brains=[brain_user], + children=[], + user_id=user.id, + is_folder=True, + parent_id=None, + ) + session.add(folder) + await session.commit() + await session.refresh(folder) + return folder + + @pytest.mark.asyncio(loop_scope="session") async def test_updates_knowledge_status(session: AsyncSession, test_data: TestData): brain, knowledges = test_data @@ -394,164 +419,6 @@ async def test_get_knowledge_in_brain(session: AsyncSession, test_data: TestData assert brain.brain_id in brains_of_knowledge -@pytest.mark.asyncio(loop_scope="session") -async def test_should_process_knowledge_exists( - session: AsyncSession, test_data: TestData -): - brain, [existing_knowledge, _] = test_data - assert brain.brain_id - new = KnowledgeDB( - file_name="new", - extension="txt", - status="PROCESSING", - source="test_source", - source_link="test_source_link", - file_size=100, - file_sha1=None, - brains=[brain], - user_id=existing_knowledge.user_id, - ) - session.add(new) - await session.commit() - await session.refresh(new) - repo = KnowledgeRepository(session) - service = KnowledgeService(repo) 
- assert existing_knowledge.file_sha1 - with pytest.raises(FileExistsError): - await service.update_sha1_conflict( - new, brain.brain_id, file_sha1=existing_knowledge.file_sha1 - ) - - -@pytest.mark.asyncio(loop_scope="session") -async def test_should_process_knowledge_link_brain( - session: AsyncSession, test_data: TestData -): - repo = KnowledgeRepository(session) - service = KnowledgeService(repo) - brain, [existing_knowledge, _] = test_data - user_id = existing_knowledge.user_id - assert brain.brain_id - prev = KnowledgeDB( - file_name="prev", - extension=".txt", - status=KnowledgeStatus.UPLOADED, - source="test_source", - source_link="test_source_link", - file_size=100, - file_sha1="test1", - brains=[brain], - user_id=user_id, - ) - brain_2 = Brain( - name="test_brain", - description="this is a test brain", - brain_type=BrainType.integration, - ) - session.add(brain_2) - session.add(prev) - await session.commit() - await session.refresh(prev) - await session.refresh(brain_2) - - assert prev.id - assert brain_2.brain_id - - new = KnowledgeDB( - file_name="new", - extension="txt", - status="PROCESSING", - source="test_source", - source_link="test_source_link", - file_size=100, - file_sha1=None, - brains=[brain_2], - user_id=user_id, - ) - session.add(new) - await session.commit() - await session.refresh(new) - - incoming_knowledge = await new.to_dto() - assert prev.file_sha1 - - should_process = await service.update_sha1_conflict( - incoming_knowledge, brain_2.brain_id, file_sha1=prev.file_sha1 - ) - assert not should_process - - # Check prev knowledge was linked - assert incoming_knowledge.file_sha1 - prev_knowledge = await service.repository.get_knowledge_by_id(prev.id) - prev_brains = await prev_knowledge.awaitable_attrs.brains - assert {b.brain_id for b in prev_brains} == { - brain.brain_id, - brain_2.brain_id, - } - # Check new knowledge was removed - assert new.id - with pytest.raises(KnowledgeNotFoundException): - await 
service.repository.get_knowledge_by_id(new.id) - - -@pytest.mark.asyncio(loop_scope="session") -async def test_should_process_knowledge_prev_error( - session: AsyncSession, test_data: TestData -): - repo = KnowledgeRepository(session) - service = KnowledgeService(repo) - brain, [existing_knowledge, _] = test_data - user_id = existing_knowledge.user_id - assert brain.brain_id - prev = KnowledgeDB( - file_name="prev", - extension="txt", - status=KnowledgeStatus.ERROR, - source="test_source", - source_link="test_source_link", - file_size=100, - file_sha1="test1", - brains=[brain], - user_id=user_id, - ) - session.add(prev) - await session.commit() - await session.refresh(prev) - - assert prev.id - - new = KnowledgeDB( - file_name="new", - extension="txt", - status="PROCESSING", - source="test_source", - source_link="test_source_link", - file_size=100, - file_sha1=None, - brains=[brain], - user_id=user_id, - ) - session.add(new) - await session.commit() - await session.refresh(new) - - incoming_knowledge = await new.to_dto() - assert prev.file_sha1 - should_process = await service.update_sha1_conflict( - incoming_knowledge, brain.brain_id, file_sha1=prev.file_sha1 - ) - - # Checks we should process this file - assert should_process - # Previous errored file is cleaned up - with pytest.raises(KnowledgeNotFoundException): - await service.repository.get_knowledge_by_id(prev.id) - - assert new.id - new = await service.repository.get_knowledge_by_id(new.id) - assert new.file_sha1 - - @pytest.mark.asyncio(loop_scope="session") async def test_get_knowledge_storage_path(session: AsyncSession, test_data: TestData): _, [knowledge, _] = test_data @@ -644,6 +511,57 @@ async def test_create_knowledge_folder(session: AsyncSession, user: User): assert storage.knowledge_exists(km) +@pytest.mark.asyncio(loop_scope="session") +async def test_create_knowledge_file_in_folder( + monkeypatch, session: AsyncSession, user: User, folder_km_brain: KnowledgeDB +): + tasks = {} + + def 
_send_task(*args, **kwargs): + tasks["args"] = args + tasks["kwargs"] = {**kwargs["kwargs"]} + + monkeypatch.setattr("quivr_api.celery_config.celery.send_task", _send_task) + assert user.id + storage = FakeStorage() + repository = KnowledgeRepository(session) + service = KnowledgeService(repository, storage) + + km_to_add = AddKnowledge( + file_name="test", + source="local", + is_folder=True, + parent_id=folder_km_brain.id, + ) + km_data = BytesIO(os.urandom(128)) + km = await service.create_knowledge( + user_id=user.id, + knowledge_to_add=km_to_add, + upload_file=UploadFile(file=km_data, size=128, filename=km_to_add.file_name), + ) + + assert km.file_name == km_to_add.file_name + assert km.id + # Knowledge properties + assert km.file_name == km_to_add.file_name + assert km.is_folder == km_to_add.is_folder + assert km.url == km_to_add.url + assert km.extension == km_to_add.extension + assert km.source == km_to_add.source + assert km.file_size == 128 + assert km.metadata_ == km_to_add.metadata + assert km.is_folder == km_to_add.is_folder + assert km.status == KnowledgeStatus.PROCESSING + # Knowledge was saved + assert storage.knowledge_exists(km) + assert km.brains + assert len(km.brains) > 0 + assert km.brains[0].brain_id == folder_km_brain.brains[0].brain_id + + # Scheduled + assert len(tasks) > 0 + + @pytest.mark.asyncio(loop_scope="session") async def test_create_knowledge_upload_error(session: AsyncSession, user: User): assert user.id From 50a0800ef99bef7e0def51f8a4b5dad4d3843708 Mon Sep 17 00:00:00 2001 From: aminediro Date: Mon, 23 Sep 2024 11:14:14 +0200 Subject: [PATCH 14/63] merged n+1 query optimization --- .../knowledge/controller/knowledge_routes.py | 2 +- .../quivr_api/modules/knowledge/dto/outputs.py | 1 + .../modules/knowledge/entity/knowledge.py | 15 ++++++++++----- .../modules/knowledge/repository/knowledges.py | 8 +++++++- .../knowledge/service/knowledge_service.py | 9 +++++---- .../knowledge/tests/test_knowledge_entity.py | 2 +- 
backend/worker/quivr_worker/celery_worker.py | 2 +- 7 files changed, 26 insertions(+), 13 deletions(-) diff --git a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py index 9fbcd838a86d..0a8a188ec5dc 100644 --- a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py +++ b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py @@ -173,7 +173,7 @@ async def create_knowledge( @knowledge_router.get( - "/knowledge/children", + "/knowledge/files", response_model=List[KnowledgeDTO] | None, tags=["Knowledge"], ) diff --git a/backend/api/quivr_api/modules/knowledge/dto/outputs.py b/backend/api/quivr_api/modules/knowledge/dto/outputs.py index 6d361dec12cb..5cf0a0ebf911 100644 --- a/backend/api/quivr_api/modules/knowledge/dto/outputs.py +++ b/backend/api/quivr_api/modules/knowledge/dto/outputs.py @@ -27,6 +27,7 @@ class KnowledgeDTO(BaseModel): file_sha1: Optional[str] = None metadata: Optional[Dict[str, str]] = None user_id: UUID + # TODO: brain dto here not the brain nor the model_dump brains: List[Dict[str, Any]] parent: Optional[Self] children: List[Self] diff --git a/backend/api/quivr_api/modules/knowledge/entity/knowledge.py b/backend/api/quivr_api/modules/knowledge/entity/knowledge.py index 19c2a3f2defe..e6c4b8cc714e 100644 --- a/backend/api/quivr_api/modules/knowledge/entity/knowledge.py +++ b/backend/api/quivr_api/modules/knowledge/entity/knowledge.py @@ -1,3 +1,4 @@ +import asyncio from datetime import datetime from enum import Enum from typing import Dict, List, Optional @@ -94,7 +95,9 @@ class KnowledgeDB(AsyncAttrs, SQLModel, table=True): sync_file_id: str | None = Field(default=None) # TODO: nested folder search - async def to_dto(self, get_children: bool = True) -> KnowledgeDTO: + async def to_dto( + self, get_children: bool = True, get_parent: bool = True + ) -> KnowledgeDTO: assert ( self.updated_at ), "knowledge should be inserted 
before transforming to dto" @@ -105,9 +108,11 @@ async def to_dto(self, get_children: bool = True) -> KnowledgeDTO: children: list[KnowledgeDB] = ( await self.awaitable_attrs.children if get_children else [] ) - children_dto = [await c.to_dto(get_children=False) for c in children] - parent = await self.awaitable_attrs.parent - parent = await parent.to_dto(get_children=False) if parent else None + children_dto = await asyncio.gather( + *[c.to_dto(get_children=False) for c in children] + ) + parent = await self.awaitable_attrs.parent if get_parent else None + parent_dto = await parent.to_dto(get_children=False) if parent else None return KnowledgeDTO( id=self.id, # type: ignore @@ -124,7 +129,7 @@ async def to_dto(self, get_children: bool = True) -> KnowledgeDTO: created_at=self.created_at, metadata=self.metadata_, # type: ignore brains=[b.model_dump() for b in brains], - parent=parent, + parent=parent_dto, children=children_dto, user_id=self.user_id, sync_id=self.sync_id, diff --git a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py index 18b4f4595c6e..6893efe7bdaf 100644 --- a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py +++ b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py @@ -289,8 +289,14 @@ async def get_knowledge_by_id( raise KnowledgeNotFoundException("Knowledge not found") return knowledge - async def get_brain_by_id(self, brain_id: UUID) -> Brain: + async def get_brain_by_id( + self, brain_id: UUID, get_knowledge: bool = True + ) -> Brain: query = select(Brain).where(Brain.brain_id == brain_id) + if get_knowledge: + query = query.options( + joinedload(Brain.knowledges).joinedload(KnowledgeDB.brains) + ) result = await self.session.exec(query) brain = result.first() if not brain: diff --git a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py index 
63670e163b95..0c371be505cf 100644 --- a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py @@ -162,8 +162,6 @@ async def create_knowledge( "file_name": knowledge_db.file_name, "source": knowledge_db.source, "source_link": knowledge_db.source_link, - # TODO: Notification on notification - # "notification_id": upload_notification.id, }, ) @@ -202,9 +200,12 @@ async def insert_knowledge_brain( return inserted_knowledge async def get_all_knowledge_in_brain(self, brain_id: UUID) -> List[KnowledgeDTO]: - brain = await self.repository.get_brain_by_id(brain_id) + brain = await self.repository.get_brain_by_id(brain_id, get_knowledge=True) all_knowledges: List[KnowledgeDB] = await brain.awaitable_attrs.knowledges - knowledges = [await knowledge.to_dto() for knowledge in all_knowledges] + knowledges = [ + await knowledge.to_dto(get_children=False, get_parent=False) + for knowledge in all_knowledges + ] return knowledges diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py index 02b9a0d24261..18a440be8118 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py @@ -189,7 +189,7 @@ async def test_knowledge_remove_folder_cascade( await session.commit() statement = select(KnowledgeDB) - results = (await session.exec(statement)).all() + results = (await session.exec(statement)).unique().all() assert results == [] diff --git a/backend/worker/quivr_worker/celery_worker.py b/backend/worker/quivr_worker/celery_worker.py index 9ad58dd5640c..ee652e6ab699 100644 --- a/backend/worker/quivr_worker/celery_worker.py +++ b/backend/worker/quivr_worker/celery_worker.py @@ -96,7 +96,7 @@ def init_worker(**kwargs): def process_file_task( knowledge_id: UUID, file_name: str, - notification_id: UUID, + 
notification_id: UUID | None = None, source: str | None = None, source_link: str | None = None, delete_file: bool = False, From 211fcaa0bc9b354bb2126d47b20aa6c240c3751d Mon Sep 17 00:00:00 2001 From: aminediro Date: Mon, 23 Sep 2024 18:47:52 +0200 Subject: [PATCH 15/63] processor class worker --- .../knowledge/controller/knowledge_routes.py | 8 +- .../modules/knowledge/repository/storage.py | 16 +- .../knowledge/repository/storage_interface.py | 4 + .../knowledge/service/knowledge_service.py | 13 +- .../modules/knowledge/tests/conftest.py | 4 + .../modules/sync/controller/sync_routes.py | 30 +- .../quivr_api/modules/sync/utils/syncutils.py | 11 +- backend/core/quivr_core/files/file.py | 2 +- backend/worker/quivr_worker/celery_worker.py | 40 +-- backend/worker/quivr_worker/files.py | 26 +- .../worker/quivr_worker/process/processor.py | 300 ++++++++++++++++++ .../syncs/process_active_syncs.py | 17 + 12 files changed, 387 insertions(+), 84 deletions(-) create mode 100644 backend/worker/quivr_worker/process/processor.py diff --git a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py index 0a8a188ec5dc..5d690d1e8c01 100644 --- a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py +++ b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py @@ -301,10 +301,6 @@ async def link_knowledge_to_brain( async def _send_knowledge_process(knowledge: KnowledgeDB): assert knowledge.id - knowledge = await knowledge_service.update_knowledge( - knowledge=knowledge, - payload=KnowledgeUpdate(status=KnowledgeStatus.PROCESSING), - ) upload_notification = notification_service.add_notification( CreateNotification( user_id=current_user.id, @@ -324,6 +320,10 @@ async def _send_knowledge_process(knowledge: KnowledgeDB): "source_link": knowledge.source_link, }, ) + knowledge = await knowledge_service.update_knowledge( + knowledge=knowledge, + 
payload=KnowledgeUpdate(status=KnowledgeStatus.PROCESSING), + ) if knowledge_dto.id is None: if knowledge_dto.sync_file_id is None: diff --git a/backend/api/quivr_api/modules/knowledge/repository/storage.py b/backend/api/quivr_api/modules/knowledge/repository/storage.py index 0e58e25d9e48..a47ee8689dc8 100644 --- a/backend/api/quivr_api/modules/knowledge/repository/storage.py +++ b/backend/api/quivr_api/modules/knowledge/repository/storage.py @@ -5,18 +5,30 @@ from quivr_api.modules.dependencies import get_supabase_async_client from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB from quivr_api.modules.knowledge.repository.storage_interface import StorageInterface +from supabase.client import AsyncClient logger = get_logger(__name__) class SupabaseS3Storage(StorageInterface): - def __init__(self): - self.client = None + def __init__(self, client: AsyncClient | None = None): + self.client = client async def _set_client(self): if self.client is None: self.client = await get_supabase_async_client() + async def download_file( + self, + knowledge: KnowledgeDB, + bucket_name: str = "quivr", + ) -> bytes: + await self._set_client() + assert self.client + path = self.get_storage_path(knowledge) + file_data = await self.client.storage.from_(bucket_name).download(path) + return file_data + def get_storage_path( self, knowledge: KnowledgeDB, diff --git a/backend/api/quivr_api/modules/knowledge/repository/storage_interface.py b/backend/api/quivr_api/modules/knowledge/repository/storage_interface.py index bd5a3debc03a..3a3e8cb8cc06 100644 --- a/backend/api/quivr_api/modules/knowledge/repository/storage_interface.py +++ b/backend/api/quivr_api/modules/knowledge/repository/storage_interface.py @@ -12,6 +12,10 @@ def get_storage_path( ) -> str: pass + @abstractmethod + async def download_file(self, knowledge: KnowledgeDB, **kwargs) -> bytes: + pass + @abstractmethod async def upload_file_storage( self, diff --git 
a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py index 0c371be505cf..6ed5d125ee1f 100644 --- a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py @@ -42,7 +42,7 @@ class KnowledgeService(BaseService[KnowledgeRepository]): def __init__( self, repository: KnowledgeRepository, - storage: StorageInterface = SupabaseS3Storage(), + storage: StorageInterface = SupabaseS3Storage(client=None), ): self.repository = repository self.storage = storage @@ -73,15 +73,11 @@ async def get_knowledge_storage_path( async def map_syncs_knowledge_user( self, sync_id: int, user_id: UUID - ) -> dict[str, KnowledgeDTO]: + ) -> dict[str, KnowledgeDB]: list_kms = await self.repository.get_all_knowledge_sync_user( sync_id=sync_id, user_id=user_id ) - return { - k.sync_file_id: k - for k in await asyncio.gather(*[k.to_dto() for k in list_kms]) - if k.sync_file_id - } + return {k.sync_file_id: k for k in list_kms if k.sync_file_id} async def list_knowledge( self, knowledge_id: UUID | None, user_id: UUID | None = None @@ -113,6 +109,7 @@ async def create_knowledge( user_id: UUID, knowledge_to_add: AddKnowledge, upload_file: UploadFile | None = None, + status: KnowledgeStatus = KnowledgeStatus.RESERVED, ) -> KnowledgeDB: brains = [] if knowledge_to_add.parent_id: @@ -128,7 +125,7 @@ async def create_knowledge( source_link=knowledge_to_add.source_link, file_size=upload_file.size if upload_file else 0, metadata_=knowledge_to_add.metadata, # type: ignore - status=KnowledgeStatus.RESERVED, + status=status, parent_id=knowledge_to_add.parent_id, sync_id=knowledge_to_add.sync_id, sync_file_id=knowledge_to_add.sync_file_id, diff --git a/backend/api/quivr_api/modules/knowledge/tests/conftest.py b/backend/api/quivr_api/modules/knowledge/tests/conftest.py index b88cf2c2c356..404c8798d58c 100644 --- 
a/backend/api/quivr_api/modules/knowledge/tests/conftest.py +++ b/backend/api/quivr_api/modules/knowledge/tests/conftest.py @@ -66,3 +66,7 @@ def knowledge_exists(self, knowledge: KnowledgeDB | KnowledgeDTO) -> bool: def clear_storage(self): self.storage.clear() + + async def download_file(self, knowledge: KnowledgeDB, **kwargs) -> bytes: + storage_path = self.get_storage_path(knowledge) + return self.storage[storage_path] diff --git a/backend/api/quivr_api/modules/sync/controller/sync_routes.py b/backend/api/quivr_api/modules/sync/controller/sync_routes.py index a62ee7426eb9..910d46e6ae37 100644 --- a/backend/api/quivr_api/modules/sync/controller/sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/sync_routes.py @@ -1,7 +1,6 @@ -import asyncio import os from datetime import datetime -from typing import List, Tuple +from typing import List from fastapi import APIRouter, Depends, status @@ -20,8 +19,8 @@ from quivr_api.modules.sync.controller.notion_sync_routes import notion_sync_router from quivr_api.modules.sync.dto import SyncsDescription from quivr_api.modules.sync.dto.outputs import AuthMethodEnum -from quivr_api.modules.sync.entity.sync_models import SyncFile from quivr_api.modules.sync.service.sync_service import SyncsService +from quivr_api.modules.sync.utils.syncutils import fetch_sync_knowledge from quivr_api.modules.user.entity.user_identity import UserIdentity notification_service = NotificationService() @@ -179,27 +178,23 @@ async def list_sync_files( # Gets knowledge for each call to list the files, # The logic is that getting from DB will be faster than provider repsonse ? 
# NOTE: asyncio.gather didn't correcly typecheck - async def fetch_data() -> Tuple[dict[str, KnowledgeDTO], List[SyncFile] | None]: - map_knowledges_task = knowledge_service.map_syncs_knowledge_user( - sync_id=sync_id, user_id=current_user.id - ) - sync_files_task = syncs_service.get_files_folder_user_sync( - sync_id, - current_user.id, - folder_id, - ) - return await asyncio.gather(map_knowledges_task, sync_files_task) sync = await syncs_service.get_sync_by_id(sync_id=sync_id) - map_knowledges, sync_files = await fetch_data() + syncfile_to_knowledge, sync_files = await fetch_sync_knowledge( + sync_id=sync_id, + user_id=current_user.id, + folder_id=folder_id, + knowledge_service=knowledge_service, + syncs_service=syncs_service, + ) if not sync_files: return None kms = [] for file in sync_files: - existing_km = map_knowledges.get(file.id) + existing_km = syncfile_to_knowledge.get(file.id) if existing_km: - kms.append(existing_km) + kms.append(await existing_km.to_dto(get_children=False, get_parent=False)) else: last_modified_at = ( file.last_modified_at if file.last_modified_at else datetime.now() @@ -216,7 +211,8 @@ async def fetch_data() -> Tuple[dict[str, KnowledgeDTO], List[SyncFile] | None]: brains=[], parent=None, children=[], - status=None, # TODO: Handle a sync not added status + # TODO: Handle a sync not added status + status=None, # TODO: retrieve created at from sync provider created_at=last_modified_at, updated_at=last_modified_at, diff --git a/backend/api/quivr_api/modules/sync/utils/syncutils.py b/backend/api/quivr_api/modules/sync/utils/syncutils.py index ff86e396ecca..fa2d426bae71 100644 --- a/backend/api/quivr_api/modules/sync/utils/syncutils.py +++ b/backend/api/quivr_api/modules/sync/utils/syncutils.py @@ -40,7 +40,6 @@ def filter_on_supported_files( prev_file = existing_files.get(new_file.name, None) if (prev_file and prev_file.supported) or prev_file is None: res.append((new_file, prev_file)) - return res @@ -71,18 +70,18 @@ def 
should_download_file( class SyncUtils: def __init__( self, - sync_user_service: ISyncUserService, - sync_active_service: ISyncService, - knowledge_service: KnowledgeService, - sync_files_repo: SyncFileInterface, + # sync_user_service: ISyncUserService, + # sync_active_service: ISyncService, + # sync_files_repo: SyncFileInterface, sync_cloud: BaseSync, + knowledge_service: KnowledgeService, notification_service: NotificationService, brain_vectors: BrainsVectors, ) -> None: self.sync_user_service = sync_user_service self.sync_active_service = sync_active_service - self.knowledge_service = knowledge_service self.sync_files_repo = sync_files_repo + self.knowledge_service = knowledge_service self.sync_cloud = sync_cloud self.notification_service = notification_service self.brain_vectors = brain_vectors diff --git a/backend/core/quivr_core/files/file.py b/backend/core/quivr_core/files/file.py index 9f4089b103fb..0a778f176567 100644 --- a/backend/core/quivr_core/files/file.py +++ b/backend/core/quivr_core/files/file.py @@ -112,9 +112,9 @@ def __init__( id: UUID, original_filename: str, path: Path, - brain_id: UUID, file_sha1: str, file_extension: FileExtension | str, + brain_id: UUID | None = None, file_size: int | None = None, metadata: dict[str, Any] | None = None, ) -> None: diff --git a/backend/worker/quivr_worker/celery_worker.py b/backend/worker/quivr_worker/celery_worker.py index ee652e6ab699..6a06b856b937 100644 --- a/backend/worker/quivr_worker/celery_worker.py +++ b/backend/worker/quivr_worker/celery_worker.py @@ -112,7 +112,6 @@ def process_file_task( loop.run_until_complete( aprocess_file_task( file_name=file_name, - notification_id=notification_id, knowledge_id=knowledge_id, source=source, source_link=source_link, @@ -123,7 +122,6 @@ def process_file_task( async def aprocess_file_task( file_name: str, - notification_id: UUID, knowledge_id: UUID, source: str | None = None, source_link: str | None = None, @@ -136,34 +134,26 @@ async def aprocess_file_task( 
await async_session.execute( text("SET SESSION idle_in_transaction_session_timeout = '5min';") ) - with Session(engine, expire_on_commit=False, autoflush=False) as session: - session.execute( - text("SET SESSION idle_in_transaction_session_timeout = '5min';") - ) - vector_repository = VectorRepository(session) - vector_service = VectorService( - vector_repository - ) # FIXME @amine: fix to need AsyncSession in vector Service - knowledge_repository = KnowledgeRepository(async_session) - knowledge_service = KnowledgeService(knowledge_repository) - await process_uploaded_file( - supabase_client=supabase_client, - brain_service=brain_service, - vector_service=vector_service, - knowledge_service=knowledge_service, - file_name=file_name, - knowledge_id=knowledge_id, - integration=source, - integration_link=source_link, - ) - session.commit() + # FIXME @amine: fix to need AsyncSession in vector Service + vector_repository = VectorRepository(async_session.sync_session) + vector_service = VectorService(vector_repository) + knowledge_repository = KnowledgeRepository(async_session) + knowledge_service = KnowledgeService(knowledge_repository) + await process_uploaded_file( + supabase_client=supabase_client, + brain_service=brain_service, + vector_service=vector_service, + knowledge_service=knowledge_service, + file_name=file_name, + knowledge_id=knowledge_id, + integration=source, + integration_link=source_link, + ) await async_session.commit() except Exception as e: - session.rollback() await async_session.rollback() raise e finally: - session.close() await async_session.close() diff --git a/backend/worker/quivr_worker/files.py b/backend/worker/quivr_worker/files.py index 8aefe51d9c4e..d27385763b4d 100644 --- a/backend/worker/quivr_worker/files.py +++ b/backend/worker/quivr_worker/files.py @@ -23,33 +23,18 @@ def compute_sha1(content: bytes) -> str: @contextmanager def build_file( file_data: bytes, - knowledge_id: UUID, - file_name: str, - original_file_name: str | None = 
None, + file_name_ext: str, ): try: # TODO(@aminediro) : Maybe use fsspec file to be agnostic to where files are stored :? # We are reading the whole file to memory, which doesn't scale - tmp_name, base_file_name, file_extension = get_tmp_name(file_name) + tmp_name, _, _ = get_tmp_name(file_name_ext) tmp_file = NamedTemporaryFile( suffix="_" + tmp_name, # pyright: ignore reportPrivateUsage=none ) tmp_file.write(file_data) tmp_file.flush() - file_sha1 = compute_sha1(file_data) - - file_instance = File( - knowledge_id=knowledge_id, - file_name=base_file_name, - original_file_name=( - original_file_name if original_file_name else base_file_name - ), - tmp_file_path=Path(tmp_file.name), - file_size=len(file_data), - file_extension=file_extension, - file_sha1=file_sha1, - ) - yield file_instance + yield Path(tmp_file.name) finally: # Code to release resource, e.g.: tmp_file.close() @@ -85,14 +70,13 @@ def __init__( self.original_file_name = original_file_name def is_empty(self): - return self.file_size < 1 # pyright: ignore reportPrivateUsage=none + return self.file_size < 1 - def to_qfile(self, brain_id: UUID, metadata: dict[str, Any] = {}) -> QuivrFile: + def to_qfile(self, metadata: dict[str, Any] = {}) -> QuivrFile: return QuivrFile( id=self.id, original_filename=self.file_name, path=self.tmp_file_path, - brain_id=brain_id, file_sha1=self.file_sha1, file_extension=self.file_extension, file_size=self.file_size, diff --git a/backend/worker/quivr_worker/process/processor.py b/backend/worker/quivr_worker/process/processor.py new file mode 100644 index 000000000000..096ccf3bc4ab --- /dev/null +++ b/backend/worker/quivr_worker/process/processor.py @@ -0,0 +1,300 @@ +import time +from contextlib import asynccontextmanager, contextmanager +from dataclasses import dataclass +from io import BytesIO +from typing import Any, AsyncGenerator, Generator, List, Tuple +from uuid import UUID + +from quivr_api.logger import get_logger +from quivr_api.modules.dependencies import 
get_supabase_async_client +from quivr_api.modules.knowledge.dto.inputs import AddKnowledge +from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO +from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB, KnowledgeSource +from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository +from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage +from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService +from quivr_api.modules.sync.dto.outputs import SyncProvider +from quivr_api.modules.sync.entity.sync_models import SyncFile +from quivr_api.modules.sync.repository.sync_repository import SyncsRepository +from quivr_api.modules.sync.service.sync_service import SyncsService +from quivr_api.modules.sync.utils.sync import ( + AzureDriveSync, + BaseSync, + DropboxSync, + GitHubSync, + GoogleDriveSync, +) +from quivr_api.modules.vector.repository.vectors_repository import VectorRepository +from quivr_api.modules.vector.service.vector_service import VectorService +from quivr_core.files.file import FileExtension, QuivrFile +from quivr_core.models import KnowledgeStatus +from sqlalchemy.ext.asyncio import AsyncEngine +from sqlmodel import text +from sqlmodel.ext.asyncio.session import AsyncSession + +from quivr_worker.files import build_file, compute_sha1 + +logger = get_logger("celery_worker") + + +def skip_process(knowledge: KnowledgeDTO) -> bool: + return knowledge.is_folder and knowledge.source != KnowledgeSource.NOTION + + +def build_syncprovider_mapping() -> dict[str, BaseSync]: + mapping_sync_utils = { + "google": GoogleDriveSync(), + "azure": AzureDriveSync(), + "dropbox": DropboxSync(), + "github": GitHubSync(), + # "notion", NotionSync(notion_service=notion_service), + } + return mapping_sync_utils + + +@dataclass +class ProcessorServices: + sync_service: SyncsService + vector_service: VectorService + knowledge_service: KnowledgeService + syncprovider_mapping: dict[str, BaseSync] + + 
+@asynccontextmanager +async def _start_session(engine: AsyncEngine) -> AsyncGenerator[AsyncSession, None]: + async with AsyncSession(engine) as session: + try: + await session.execute( + text("SET SESSION idle_in_transaction_session_timeout = '5min';") + ) + yield session + await session.commit() + except Exception as e: + await session.rollback() + raise e + finally: + await session.close() + + +@asynccontextmanager +async def build_processor_services(engine: AsyncEngine): + async_client = await get_supabase_async_client() + storage = SupabaseS3Storage(async_client) + try: + async with _start_session(engine) as async_session: + vector_repository = VectorRepository(async_session.sync_session) + vector_service = VectorService(vector_repository) + knowledge_repository = KnowledgeRepository(async_session) + knowledge_service = KnowledgeService(knowledge_repository, storage=storage) + sync_repository = SyncsRepository(async_session) + sync_service = SyncsService(sync_repository) + yield ProcessorServices( + knowledge_service=knowledge_service, + vector_service=vector_service, + sync_service=sync_service, + syncprovider_mapping=build_syncprovider_mapping(), + ) + finally: + logger.info("Closing processor services") + + +async def download_sync_file( + sync_provider: BaseSync, file: SyncFile, credentials: dict[str, Any] +) -> bytes: + logger.info(f"Downloading {file} using {sync_provider}") + file_response = await sync_provider.adownload_file(credentials, file) + logger.debug(f"Fetch sync file response: {file_response}") + raw_data = file_response["content"] + if isinstance(raw_data, BytesIO): + file_data = raw_data.read() + else: + file_data = raw_data.encode("utf-8") + logger.debug(f"Successfully downloaded sync file : {file}") + return file_data + + +@contextmanager +def build_qfile( + knowledge: KnowledgeDB, file_data: bytes +) -> Generator[QuivrFile, None, None]: + assert knowledge.id + assert knowledge.file_name + assert knowledge.file_sha1 + with build_file( + 
file_data=file_data, file_name_ext=knowledge.file_name + ) as tmp_file_path: + qfile = QuivrFile( + id=knowledge.id, + original_filename=knowledge.file_name, + path=tmp_file_path, + file_sha1=knowledge.file_sha1, + file_extension=FileExtension(knowledge.extension), + file_size=knowledge.file_size, + metadata={ + "date": time.strftime("%Y%m%d"), + "file_name": knowledge.file_name, + "knowledge_id": knowledge.id, + }, + ) + if knowledge.metadata_: + qfile.additional_metadata = { + **qfile.metadata, + **knowledge.metadata_, + } + yield qfile + + +class KnowledgeProcessor: + def __init__(self, services: ProcessorServices): + self.services = services + + async def fetch_sync_knowledge( + self, + sync_id: int, + user_id: UUID, + folder_id: str | None, + ) -> Tuple[dict[str, KnowledgeDB], List[SyncFile] | None]: + map_knowledges_task = self.services.knowledge_service.map_syncs_knowledge_user( + sync_id=sync_id, user_id=user_id + ) + sync_files_task = self.services.sync_service.get_files_folder_user_sync( + sync_id, + user_id, + folder_id, + ) + return await asyncio.gather(*[map_knowledges_task, sync_files_task]) # type: ignore # noqa: F821 + + @asynccontextmanager + async def build_processable( + self, knowledge: KnowledgeDTO + ) -> AsyncGenerator[Tuple[KnowledgeDB, QuivrFile] | None, None]: + if knowledge.source == KnowledgeSource.LOCAL: + async with self._build_local(knowledge) as to_process: + yield to_process + elif knowledge.source in ( + KnowledgeSource.AZURE, + KnowledgeSource.GITHUB, + KnowledgeSource.GOOGLE, + KnowledgeSource.NOTION, + ): + async with self._build_sync(knowledge) as to_process: + yield to_process + elif knowledge.source == KnowledgeSource.WEB: + raise NotImplementedError + else: + logger.error( + f"received knowledge : {knowledge.id} with unknown source: {knowledge.source}" + ) + raise ValueError("Unknown knowledge source : {knoledge.source}") + + @asynccontextmanager + async def _build_local( + self, knowledge: KnowledgeDTO + ) -> 
AsyncGenerator[Tuple[KnowledgeDB, QuivrFile] | None, None]: + if knowledge.id is None or knowledge.file_name is None: + logger.error(f"received unprocessable local knowledge : {knowledge.id} ") + raise ValueError( + f"received unprocessable local knowledge : {knowledge.id} " + ) + knowledge_db = await self.services.knowledge_service.get_knowledge(knowledge.id) + file_data = await self.services.knowledge_service.storage.download_file( + knowledge_db + ) + knowledge_db.file_sha1 = compute_sha1(file_data) + with build_qfile(knowledge_db, file_data) as qfile: + yield (knowledge_db, qfile) + + @asynccontextmanager + async def _build_sync( + self, knowledge_dto: KnowledgeDTO + ) -> AsyncGenerator[Tuple[KnowledgeDB, QuivrFile] | None, None]: + if knowledge_dto.id is None: + logger.error(f"received unprocessable knowledge: {knowledge_dto.id} ") + raise ValueError + + parent_knowledge = await self.services.knowledge_service.get_knowledge( + knowledge_dto.id + ) + if parent_knowledge.file_name is None: + logger.error(f"received unprocessable knowledge : {parent_knowledge.id} ") + raise ValueError( + f"received unprocessable knowledge : {parent_knowledge.id} " + ) + if ( + parent_knowledge.sync_file_id is None + or parent_knowledge.sync_id is None + or parent_knowledge.source_link is None + ): + logger.error( + f"unprocessable sync knowledge : {parent_knowledge.id}. no sync_file_id" + ) + raise ValueError( + f"received unprocessable knowledge : {parent_knowledge.id} " + ) + # Get associated sync + sync = await self.services.sync_service.get_sync_by_id(parent_knowledge.sync_id) + if sync.credentials is None: + logger.error(f"can't process sync file. 
sync {sync.id} has no credentials") + return + provider_name = SyncProvider(sync.provider.lower()) + sync_provider = self.services.syncprovider_mapping[provider_name] + + syncfile_to_knowledge, sync_files = await self.fetch_sync_knowledge( + sync_id=parent_knowledge.sync_id, + user_id=parent_knowledge.user_id, + folder_id=parent_knowledge.sync_file_id, + ) + if not sync_files: + return + + # Yield parent knowledge to process + file_data = await download_sync_file( + sync_provider=sync_provider, + file=SyncFile( + id=parent_knowledge.sync_file_id, + name=parent_knowledge.file_name, + extension=parent_knowledge.extension, + web_view_link=parent_knowledge.source_link, + is_folder=parent_knowledge.is_folder, + last_modified_at=parent_knowledge.updated_at, + ), + credentials=sync.credentials, + ) + parent_knowledge.file_sha1 = compute_sha1(file_data) + with build_qfile(parent_knowledge, file_data) as qfile: + yield (parent_knowledge, qfile) + + for sync_file in sync_files: + existing_km = syncfile_to_knowledge.get(sync_file.id) + if existing_km: + file_knowledge = existing_km + else: + # create sync file knowledge + # automagically gets the brains associated with the parent + file_knowledge = await self.services.knowledge_service.create_knowledge( + user_id=parent_knowledge.user_id, + knowledge_to_add=AddKnowledge( + file_name=sync_file.name, + is_folder=sync_file.is_folder, + extension=sync_file.extension, + source=parent_knowledge.source, # same as parent + source_link=sync_file.web_view_link, + parent_id=parent_knowledge.id, + sync_id=parent_knowledge.sync_id, + sync_file_id=sync_file.id, + ), + status=KnowledgeStatus.PROCESSING, + upload_file=None, + ) + file_data = await download_sync_file( + sync_provider=sync_provider, + file=sync_file, + credentials=sync.credentials, + ) + file_knowledge.file_sha1 = compute_sha1(file_data) + with build_qfile(file_knowledge, file_data) as qfile: + yield (file_knowledge, qfile) + + async def process_knowledge(self, knowledge_dto: 
KnowledgeDTO): + async for (knowledge, qfile) in self.build_processable(knowledge_dto): + pass diff --git a/backend/worker/quivr_worker/syncs/process_active_syncs.py b/backend/worker/quivr_worker/syncs/process_active_syncs.py index 92733e466d63..0314367b8e5e 100644 --- a/backend/worker/quivr_worker/syncs/process_active_syncs.py +++ b/backend/worker/quivr_worker/syncs/process_active_syncs.py @@ -14,6 +14,12 @@ update_notion_pages, ) from quivr_api.modules.sync.service.sync_service import SyncsService +from quivr_api.modules.sync.utils.sync import ( + AzureDriveSync, + DropboxSync, + GitHubSync, + GoogleDriveSync, +) from quivr_api.modules.sync.utils.syncutils import SyncUtils from sqlalchemy.ext.asyncio import AsyncEngine from sqlmodel import text @@ -26,6 +32,17 @@ logger = get_logger("celery_worker") +async def build_syncprovider_mapping(): + mapping_sync_utils = { + "google": GoogleDriveSync(), + "azure": AzureDriveSync(), + "dropbox": DropboxSync(), + "github": GitHubSync(), + # "notion", NotionSync(notion_service=notion_service), + } + return mapping_sync_utils + + async def process_sync( sync: SyncsActive, files_ids: list[str], From 51421c0f103f4669c3ad532a1fb7f09d9d6b9192 Mon Sep 17 00:00:00 2001 From: aminediro Date: Tue, 24 Sep 2024 15:55:50 +0200 Subject: [PATCH 16/63] working process knowledge --- .../modules/knowledge/tests/conftest.py | 8 +- .../modules/sync/controller/sync_routes.py | 25 +- .../sync/tests/test_sync_controller.py | 5 + .../quivr_api/modules/sync/utils/syncutils.py | 83 ++++--- .../quivr_api/modules/vector/entity/vector.py | 8 +- .../vector/repository/vectors_repository.py | 33 ++- .../modules/vector/service/vector_service.py | 24 +- .../modules/vector/tests/test_vectors.py | 89 ++++---- backend/api/quivr_api/vectorstore/supabase.py | 5 +- backend/core/quivr_core/files/file.py | 1 - .../quivr_core/processor/processor_base.py | 3 - backend/worker/quivr_worker/celery_worker.py | 215 +++--------------- backend/worker/quivr_worker/files.py | 
90 -------- backend/worker/quivr_worker/parsers/audio.py | 8 +- .../quivr_worker/process/process_file.py | 86 ++----- .../quivr_worker/process/process_s3_file.py | 41 ---- .../worker/quivr_worker/process/processor.py | 109 ++++----- backend/worker/quivr_worker/process/utils.py | 105 +++++++++ backend/worker/quivr_worker/syncs/utils.py | 91 -------- backend/worker/quivr_worker/utils.py | 10 - backend/worker/tests/conftest.py | 173 ++++++++++++-- backend/worker/tests/test_process_file.py | 36 +-- .../worker/tests/test_process_file_task.py | 49 ++++ backend/worker/tests/test_sync.py | 0 backend/worker/tests/test_utils.py | 1 - 25 files changed, 589 insertions(+), 709 deletions(-) delete mode 100644 backend/worker/quivr_worker/files.py delete mode 100644 backend/worker/quivr_worker/process/process_s3_file.py create mode 100644 backend/worker/quivr_worker/process/utils.py delete mode 100644 backend/worker/quivr_worker/syncs/utils.py delete mode 100644 backend/worker/tests/test_sync.py diff --git a/backend/api/quivr_api/modules/knowledge/tests/conftest.py b/backend/api/quivr_api/modules/knowledge/tests/conftest.py index 404c8798d58c..63c9768ddc70 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/conftest.py +++ b/backend/api/quivr_api/modules/knowledge/tests/conftest.py @@ -47,7 +47,13 @@ async def upload_file_storage( storage_path = f"{knowledge.id}" if not upsert and storage_path in self.storage: raise ValueError(f"File already exists at {storage_path}") - self.storage[storage_path] = knowledge_data + if isinstance(knowledge_data, FileIO) or isinstance( + knowledge_data, BufferedReader + ): + self.storage[storage_path] = knowledge_data.read() + else: + self.storage[storage_path] = knowledge_data + return storage_path async def remove_file(self, storage_path: str): diff --git a/backend/api/quivr_api/modules/sync/controller/sync_routes.py b/backend/api/quivr_api/modules/sync/controller/sync_routes.py index 910d46e6ae37..17ee0b61c247 100644 --- 
a/backend/api/quivr_api/modules/sync/controller/sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/sync_routes.py @@ -1,13 +1,15 @@ +import asyncio import os from datetime import datetime -from typing import List +from typing import List, Tuple +from uuid import UUID from fastapi import APIRouter, Depends, status from quivr_api.logger import get_logger from quivr_api.middlewares.auth import AuthBearer, get_current_user from quivr_api.modules.dependencies import get_service -from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDTO +from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB, KnowledgeDTO from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService from quivr_api.modules.notification.service.notification_service import ( NotificationService, @@ -19,8 +21,8 @@ from quivr_api.modules.sync.controller.notion_sync_routes import notion_sync_router from quivr_api.modules.sync.dto import SyncsDescription from quivr_api.modules.sync.dto.outputs import AuthMethodEnum +from quivr_api.modules.sync.entity.sync_models import SyncFile from quivr_api.modules.sync.service.sync_service import SyncsService -from quivr_api.modules.sync.utils.syncutils import fetch_sync_knowledge from quivr_api.modules.user.entity.user_identity import UserIdentity notification_service = NotificationService() @@ -179,13 +181,26 @@ async def list_sync_files( # The logic is that getting from DB will be faster than provider repsonse ? 
# NOTE: asyncio.gather didn't correcly typecheck + async def fetch_sync_knowledge( + sync_id: int, + user_id: UUID, + folder_id: str | None, + ) -> Tuple[dict[str, KnowledgeDB], List[SyncFile] | None]: + map_knowledges_task = knowledge_service.map_syncs_knowledge_user( + sync_id=sync_id, user_id=user_id + ) + sync_files_task = syncs_service.get_files_folder_user_sync( + sync_id, + user_id, + folder_id, + ) + return await asyncio.gather(*[map_knowledges_task, sync_files_task]) # type: ignore # noqa: F821 + sync = await syncs_service.get_sync_by_id(sync_id=sync_id) syncfile_to_knowledge, sync_files = await fetch_sync_knowledge( sync_id=sync_id, user_id=current_user.id, folder_id=folder_id, - knowledge_service=knowledge_service, - syncs_service=syncs_service, ) if not sync_files: return None diff --git a/backend/api/quivr_api/modules/sync/tests/test_sync_controller.py b/backend/api/quivr_api/modules/sync/tests/test_sync_controller.py index 6d14dc897056..057c0ae57d14 100644 --- a/backend/api/quivr_api/modules/sync/tests/test_sync_controller.py +++ b/backend/api/quivr_api/modules/sync/tests/test_sync_controller.py @@ -38,6 +38,11 @@ class BaseFakeSync(BaseSync): lower_name = "google" datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ" + def __init__(self, provider_name: str | None = None): + super().__init__() + if provider_name: + self.lower_name = provider_name + def get_files_by_id(self, credentials: Dict, file_ids: List[str]) -> List[SyncFile]: return [ SyncFile( diff --git a/backend/api/quivr_api/modules/sync/utils/syncutils.py b/backend/api/quivr_api/modules/sync/utils/syncutils.py index fa2d426bae71..08419049db72 100644 --- a/backend/api/quivr_api/modules/sync/utils/syncutils.py +++ b/backend/api/quivr_api/modules/sync/utils/syncutils.py @@ -31,40 +31,57 @@ celery_inspector = celery.control.inspect() -# NOTE: we are filtering based on file path names in sync ! 
-def filter_on_supported_files( - files: list[SyncFile], existing_files: dict[str, DBSyncFile] -) -> list[Tuple[SyncFile, DBSyncFile | None]]: - res = [] - for new_file in files: - prev_file = existing_files.get(new_file.name, None) - if (prev_file and prev_file.supported) or prev_file is None: - res.append((new_file, prev_file)) - return res - - -def should_download_file( - file: SyncFile, - last_updated_sync_active: datetime | None, - provider_name: str, - datetime_format: str, -) -> bool: - file_last_modified_utc = datetime.strptime( - file.last_modified_at, datetime_format - ).replace(tzinfo=timezone.utc) - - should_download = ( - last_updated_sync_active is None - or file_last_modified_utc > last_updated_sync_active +async def fetch_sync_knowledge( + self, + sync_id: int, + user_id: UUID, + folder_id: str | None, +) -> Tuple[dict[str, KnowledgeDB], List[SyncFile] | None]: + map_knowledges_task = self.services.knowledge_service.map_syncs_knowledge_user( + sync_id=sync_id, user_id=user_id ) - - # TODO: Handle notion database - if provider_name == "notion": - should_download &= file.extension != "db" - else: - should_download &= not file.is_folder - - return should_download + sync_files_task = self.services.sync_service.get_files_folder_user_sync( + sync_id, + user_id, + folder_id, + ) + return await asyncio.gather(*[map_knowledges_task, sync_files_task]) # type: ignore # noqa: F821 + + +# # NOTE: we are filtering based on file path names in sync ! 
+# def filter_on_supported_files( +# files: list[SyncFile], existing_files: dict[str, DBSyncFile] +# ) -> list[Tuple[SyncFile, DBSyncFile | None]]: +# res = [] +# for new_file in files: +# prev_file = existing_files.get(new_file.name, None) +# if (prev_file and prev_file.supported) or prev_file is None: +# res.append((new_file, prev_file)) +# return res + + +# def should_download_file( +# file: SyncFile, +# last_updated_sync_active: datetime | None, +# provider_name: str, +# datetime_format: str, +# ) -> bool: +# file_last_modified_utc = datetime.strptime( +# file.last_modified_at, datetime_format +# ).replace(tzinfo=timezone.utc) + +# should_download = ( +# last_updated_sync_active is None +# or file_last_modified_utc > last_updated_sync_active +# ) + +# # TODO: Handle notion database +# if provider_name == "notion": +# should_download &= file.extension != "db" +# else: +# should_download &= not file.is_folder + +# return should_download class SyncUtils: diff --git a/backend/api/quivr_api/modules/vector/entity/vector.py b/backend/api/quivr_api/modules/vector/entity/vector.py index a0d46baa41a8..c8c2bec775ff 100644 --- a/backend/api/quivr_api/modules/vector/entity/vector.py +++ b/backend/api/quivr_api/modules/vector/entity/vector.py @@ -3,12 +3,10 @@ from pgvector.sqlalchemy import Vector as PGVector from pydantic import BaseModel -from sqlalchemy import Column +from quivr_api.models.settings import settings from sqlmodel import JSON, Column, Field, SQLModel, text from sqlmodel import UUID as PGUUID -from quivr_api.models.settings import settings - class Vector(SQLModel, table=True): __tablename__ = "vectors" # type: ignore @@ -21,10 +19,10 @@ class Vector(SQLModel, table=True): ), ) content: str = Field(default=None) - metadata_: dict = Field(default={}, sa_column=Column("metadata", JSON, default={})) embedding: Optional[PGVector] = Field( sa_column=Column(PGVector(settings.embedding_dim)), - ) # Verify with text_ada -> put it in Env variabme + ) + metadata_: dict 
= Field(default={}, sa_column=Column("metadata", JSON, default={})) knowledge_id: UUID = Field(default=None, foreign_key="knowledge.id") class Config: diff --git a/backend/api/quivr_api/modules/vector/repository/vectors_repository.py b/backend/api/quivr_api/modules/vector/repository/vectors_repository.py index a7a5a4b63b85..21906360b1e4 100644 --- a/backend/api/quivr_api/modules/vector/repository/vectors_repository.py +++ b/backend/api/quivr_api/modules/vector/repository/vectors_repository.py @@ -5,42 +5,39 @@ from quivr_api.modules.dependencies import BaseRepository from quivr_api.modules.vector.entity.vector import SimilaritySearchOutput, Vector from sqlalchemy import exc, text -from sqlmodel import Session, select +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession logger = get_logger(__name__) class VectorRepository(BaseRepository): - def __init__(self, session: Session): + def __init__(self, session: AsyncSession): super().__init__(session) self.session = session - def create_vectors(self, new_vectors: List[Vector]) -> List[Vector]: + async def create_vectors(self, new_vectors: List[Vector]) -> List[Vector]: try: - # Use SQLAlchemy session to add and commit the new vector self.session.add_all(new_vectors) - self.session.commit() + await self.session.commit() + for vector in new_vectors: + await self.session.refresh(vector) + return new_vectors except exc.IntegrityError: # Rollback the session if there’s an IntegrityError - self.session.rollback() + await self.session.rollback() raise Exception("Integrity error occurred while creating vector.") except Exception as e: - self.session.rollback() + await self.session.rollback() print(f"Error: {e}") raise Exception(f"An error occurred while creating vector: {e}") - # Refresh the session to get any updated fields (like auto-generated IDs) - for vector in new_vectors: - self.session.refresh(vector) - - return new_vectors - - def get_vectors_by_knowledge_id(self, knowledge_id: UUID) -> 
Sequence[Vector]: + async def get_vectors_by_knowledge_id(self, knowledge_id: UUID) -> Sequence[Vector]: query = select(Vector).where(Vector.knowledge_id == knowledge_id) - results = self.session.execute(query) + results = await self.session.execute(query) return results.scalars().all() - def similarity_search( + async def similarity_search( self, query_embedding: List[float], brain_id: UUID, @@ -94,13 +91,13 @@ def similarity_search( """) params = { - "query_embedding": query_embedding, + "query_embedding": str(query_embedding), "p_brain_id": brain_id, "k": k, "max_chunk_sum": max_chunk_sum, } - result = self.session.execute(sql_query, params=params) + result = await self.session.execute(sql_query, params=params) full_results = result.all() formated_result = [ SimilaritySearchOutput( diff --git a/backend/api/quivr_api/modules/vector/service/vector_service.py b/backend/api/quivr_api/modules/vector/service/vector_service.py index 6a775dd6e1fb..d627837d7fad 100644 --- a/backend/api/quivr_api/modules/vector/service/vector_service.py +++ b/backend/api/quivr_api/modules/vector/service/vector_service.py @@ -13,18 +13,26 @@ class VectorService(BaseService[VectorRepository]): repository_cls = VectorRepository - _embedding: Embeddings = get_embedding_client() - def __init__(self, repository: VectorRepository): + def __init__( + self, repository: VectorRepository, embedder: Embeddings | None = None + ): + if embedder is None: + self.embedder = get_embedding_client() + else: + self.embedder = embedder + self.repository = repository - def create_vectors(self, chunks: List[Document], knowledge_id: UUID) -> List[UUID]: + async def create_vectors( + self, chunks: List[Document], knowledge_id: UUID + ) -> List[UUID]: # Vector is created upon the user's first question asked logger.info( f"New vector entry in vectors table for knowledge_id {knowledge_id}" ) # FIXME ADD a check in case of failure - embeddings = self._embedding.embed_documents( + embeddings = 
self.embedder.embed_documents( [chunk.page_content for chunk in chunks] ) new_vectors = [ @@ -36,14 +44,14 @@ def create_vectors(self, chunks: List[Document], knowledge_id: UUID) -> List[UUI ) for i, chunk in enumerate(chunks) ] - created_vector = self.repository.create_vectors(new_vectors) + created_vector = await self.repository.create_vectors(new_vectors) return [vector.id for vector in created_vector if vector.id] - def similarity_search(self, query: str, brain_id: UUID, k: int = 40): - vectors = self._embedding.embed_documents([query]) + async def similarity_search(self, query: str, brain_id: UUID, k: int = 40): + vectors = self.embedder.embed_documents([query]) query_embedding = vectors[0] - vectors = self.repository.similarity_search( + vectors = await self.repository.similarity_search( query_embedding=query_embedding, brain_id=brain_id, k=k ) diff --git a/backend/api/quivr_api/modules/vector/tests/test_vectors.py b/backend/api/quivr_api/modules/vector/tests/test_vectors.py index ce4b6f04bf2a..d64d772b04bb 100644 --- a/backend/api/quivr_api/modules/vector/tests/test_vectors.py +++ b/backend/api/quivr_api/modules/vector/tests/test_vectors.py @@ -1,10 +1,13 @@ from typing import List, Tuple import pytest +import pytest_asyncio from langchain.docstore.document import Document from langchain_core.embeddings import DeterministicFakeEmbedding -from sqlmodel import Session, select +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession +from quivr_api.models.settings import settings from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB from quivr_api.modules.user.entity.user_identity import User @@ -19,13 +22,13 @@ @pytest.fixture(scope="module") def embedder(): - return DeterministicFakeEmbedding(size=1536) + return DeterministicFakeEmbedding(size=settings.embedding_dim) -@pytest.fixture(scope="function") -def test_data(sync_session: Session, 
embedder) -> TestData: +@pytest_asyncio.fixture(scope="function") +async def test_data(session: AsyncSession, embedder) -> TestData: user_1 = ( - sync_session.exec(select(User).where(User.email == "admin@quivr.app")) + await session.exec(select(User).where(User.email == "admin@quivr.app")) ).one() assert user_1.id vectors = embedder.embed_documents( @@ -51,9 +54,9 @@ def test_data(sync_session: Session, embedder) -> TestData: brains=[brain_1], user_id=user_1.id, ) - sync_session.add(knowledge_1) - sync_session.commit() - sync_session.refresh(knowledge_1) + session.add(knowledge_1) + await session.commit() + await session.refresh(knowledge_1) assert knowledge_1.id, "Knowledge ID not generated" @@ -71,39 +74,44 @@ def test_data(sync_session: Session, embedder) -> TestData: knowledge_id=knowledge_1.id, ) - sync_session.add(vector_1) - sync_session.add(vector_2) + session.add(vector_1) + session.add(vector_2) - sync_session.commit() + await session.commit() return ([vector_1, vector_2], knowledge_1, brain_1) -def test_create_vectors_service(sync_session: Session, test_data: TestData, embedder): +@pytest.mark.asyncio(loop_scope="session") +async def test_create_vectors_service( + session: AsyncSession, test_data: TestData, embedder +): _, knowledge, _ = test_data assert knowledge.id - repo = VectorRepository(sync_session) + repo = VectorRepository(session) service = VectorService(repo) - service._embedding = embedder + service.embedder = embedder chunk_1 = Document(page_content="I love eating pasta with tomato sauce") chunk_2 = Document(page_content="I love eating pizza with extra cheese") # Create vectors from documents - new_vectors_id: List[int] = service.create_vectors([chunk_1, chunk_2], knowledge.id) # type: ignore + new_vectors_id: List[int] = await service.create_vectors( + [chunk_1, chunk_2], knowledge.id + ) # type: ignore # Verify the correct number of vectors were created assert len(new_vectors_id) == 2, f"Expected 2 vectors, got {len(new_vectors_id)}" # 
Verify the content of the first vector matches the corresponding document vector_1_content = ( - sync_session.execute(select(Vector).where(Vector.id == new_vectors_id[0])) + (await session.execute(select(Vector).where(Vector.id == new_vectors_id[0]))) .scalars() .first() .content ) vector_2_content = ( - sync_session.execute(select(Vector).where(Vector.id == new_vectors_id[1])) + (await session.execute(select(Vector).where(Vector.id == new_vectors_id[1]))) .scalars() .first() .content @@ -117,12 +125,13 @@ def test_create_vectors_service(sync_session: Session, test_data: TestData, embe ), "The content of the second vector does not match" -def test_get_vectors_by_knowledge_id(sync_session: Session, test_data: TestData): +@pytest.mark.asyncio(loop_scope="session") +async def test_get_vectors_by_knowledge_id(session: AsyncSession, test_data: TestData): vectors, knowledge, _ = test_data assert knowledge.id - repo = VectorRepository(sync_session) - results = repo.get_vectors_by_knowledge_id(knowledge.id) # type: ignore + repo = VectorRepository(session) + results = await repo.get_vectors_by_knowledge_id(knowledge.id) # type: ignore assert len(results) == 2, f"Expected 2 vectors, got {len(results)}" assert ( @@ -133,71 +142,75 @@ def test_get_vectors_by_knowledge_id(sync_session: Session, test_data: TestData) ), f"Expected {vectors[1].content}, got {results[1].content}" -def test_service_similarity_search( - sync_session: Session, test_data: TestData, embedder +@pytest.mark.asyncio(loop_scope="session") +async def test_service_similarity_search( + session: AsyncSession, test_data: TestData, embedder ): vectors, knowledge, brain = test_data assert knowledge.id assert brain.brain_id - repo = VectorRepository(sync_session) - service = VectorService(repo) - service._embedding = embedder + repo = VectorRepository(session) + service = VectorService(repo, embedder=embedder) k = 2 - results = service.similarity_search(vectors[0].content, brain.brain_id, k=k) # type: ignore + 
results = await service.similarity_search(vectors[0].content, brain.brain_id, k=k) # type: ignore assert len(results) == k assert results[0].page_content == vectors[0].content - results = service.similarity_search(vectors[1].content, brain.brain_id, k=k) # type: ignore + results = await service.similarity_search(vectors[1].content, brain.brain_id, k=k) # type: ignore assert results[0].page_content == vectors[1].content k = 1 - results = service.similarity_search(vectors[0].content, brain.brain_id, k=k) # type: ignore + results = await service.similarity_search(vectors[0].content, brain.brain_id, k=k) # type: ignore assert len(results) == k assert results[0].page_content == vectors[0].content - results = service.similarity_search(vectors[1].content, brain.brain_id, k=k) # type: ignore + results = await service.similarity_search(vectors[1].content, brain.brain_id, k=k) # type: ignore assert results[0].page_content == vectors[1].content -def test_similarity_search(sync_session: Session, test_data: TestData): +@pytest.mark.asyncio(loop_scope="session") +async def test_similarity_search(session: AsyncSession, test_data: TestData): vectors, knowledge, brain = test_data assert knowledge.id assert brain.brain_id - repo = VectorRepository(sync_session) + repo = VectorRepository(session) k = 2 - results = repo.similarity_search(vectors[0].embedding, brain.brain_id, k=k) # type: ignore + results = await repo.similarity_search(vectors[0].embedding, brain.brain_id, k=k) # type: ignore assert len(results) == k assert results[0].content == vectors[0].content - results = repo.similarity_search(vectors[1].embedding, brain.brain_id, k=k) # type: ignore + results = await repo.similarity_search(vectors[1].embedding, brain.brain_id, k=k) # type: ignore assert results[0].content == vectors[1].content k = 1 - results = repo.similarity_search(vectors[0].embedding, brain.brain_id, k=k) # type: ignore + results = await repo.similarity_search(vectors[0].embedding, brain.brain_id, k=k) # 
type: ignore assert len(results) == k assert results[0].content == vectors[0].content - results = repo.similarity_search(vectors[1].embedding, brain.brain_id, k=k) # type: ignore + results = await repo.similarity_search(vectors[1].embedding, brain.brain_id, k=k) # type: ignore assert results[0].content == vectors[1].content -def test_similarity_with_oversized_chunk(sync_session: Session, test_data: TestData): +@pytest.mark.asyncio(loop_scope="session") +async def test_similarity_with_oversized_chunk( + session: AsyncSession, test_data: TestData +): vectors, knowledge, brain = test_data assert knowledge.id assert brain.brain_id - repo = VectorRepository(sync_session) + repo = VectorRepository(session) k = 2 - results = repo.similarity_search( + results = await repo.similarity_search( vectors[0].embedding, # type: ignore brain.brain_id, k=k, diff --git a/backend/api/quivr_api/vectorstore/supabase.py b/backend/api/quivr_api/vectorstore/supabase.py index dfca682f0099..a4dcd8f79183 100644 --- a/backend/api/quivr_api/vectorstore/supabase.py +++ b/backend/api/quivr_api/vectorstore/supabase.py @@ -69,18 +69,17 @@ def find_brain_closest_query( ] return brain_details - def similarity_search( + async def asimilarity_search( self, query: str, k: int = 40, - table: str = "match_vectors", threshold: float = 0.5, **kwargs: Any, ) -> List[Document]: logger.debug(f"Similarity search for query: {query}") assert self.brain_id, "Brain ID is required for similarity search" - match_result = self.vector_service.similarity_search( + match_result = await self.vector_service.similarity_search( query, brain_id=self.brain_id, k=k ) diff --git a/backend/core/quivr_core/files/file.py b/backend/core/quivr_core/files/file.py index 0a778f176567..e7923d34df5b 100644 --- a/backend/core/quivr_core/files/file.py +++ b/backend/core/quivr_core/files/file.py @@ -153,7 +153,6 @@ def metadata(self) -> dict[str, Any]: def serialize(self) -> QuivrFileSerialized: return QuivrFileSerialized( id=self.id, - 
brain_id=self.brain_id, path=self.path.absolute(), original_filename=self.original_filename, file_size=self.file_size, diff --git a/backend/core/quivr_core/processor/processor_base.py b/backend/core/quivr_core/processor/processor_base.py index 1b8cbbe39423..a108dca85eed 100644 --- a/backend/core/quivr_core/processor/processor_base.py +++ b/backend/core/quivr_core/processor/processor_base.py @@ -2,7 +2,6 @@ from abc import ABC, abstractmethod from importlib.metadata import PackageNotFoundError, version from typing import Any -from uuid import uuid4 from langchain_core.documents import Document @@ -13,7 +12,6 @@ # TODO: processors should be cached somewhere ? # The processor should be cached by processor type -# The cache should use a single class ProcessorBase(ABC): supported_extensions: list[FileExtension | str] @@ -43,7 +41,6 @@ async def process_file(self, file: QuivrFile) -> list[Document]: "utf-8" ) doc.metadata = { - "id": uuid4(), "chunk_index": idx, "quivr_core_version": qvr_version, **file.metadata, diff --git a/backend/worker/quivr_worker/celery_worker.py b/backend/worker/quivr_worker/celery_worker.py index 6a06b856b937..b7dee74d3b52 100644 --- a/backend/worker/quivr_worker/celery_worker.py +++ b/backend/worker/quivr_worker/celery_worker.py @@ -2,7 +2,6 @@ import os from uuid import UUID -from celery.schedules import crontab from celery.signals import worker_process_init from dotenv import load_dotenv from quivr_api.celery_config import celery @@ -12,29 +11,19 @@ from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors from quivr_api.modules.brain.service.brain_service import BrainService from quivr_api.modules.dependencies import get_supabase_client -from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository +from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage -from quivr_api.modules.knowledge.service.knowledge_service import 
KnowledgeService from quivr_api.modules.notification.service.notification_service import ( NotificationService, ) from quivr_api.modules.sync.service.sync_notion import SyncNotionService -from quivr_api.modules.vector.repository.vectors_repository import VectorRepository -from quivr_api.modules.vector.service.vector_service import VectorService from quivr_api.utils.telemetry import maybe_send_telemetry -from sqlalchemy import Engine, create_engine from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine -from sqlmodel import Session, text -from sqlmodel.ext.asyncio.session import AsyncSession from quivr_worker.check_premium import check_is_premium -from quivr_worker.process.process_s3_file import process_uploaded_file -from quivr_worker.process.process_url import process_url_func +from quivr_worker.process.processor import KnowledgeProcessor, build_processor_services from quivr_worker.syncs.process_active_syncs import ( - SyncServices, - process_all_active_syncs, process_notion_sync, - process_sync, ) from quivr_worker.syncs.store_notion import fetch_and_store_notion_files_async from quivr_worker.utils import _patch_json @@ -56,13 +45,11 @@ storage = SupabaseS3Storage() notion_service: SyncNotionService | None = None async_engine: AsyncEngine | None = None -engine: Engine | None = None @worker_process_init.connect def init_worker(**kwargs): global async_engine - global engine if not async_engine: async_engine = create_async_engine( settings.pg_database_async_url, @@ -74,17 +61,6 @@ def init_worker(**kwargs): pool_recycle=1800, ) - if not engine: - engine = create_engine( - settings.pg_database_url, - echo=True if os.getenv("ORM_DEBUG") else False, - future=True, - # NOTE: pessimistic bound on - pool_pre_ping=True, - pool_size=10, # NOTE: no bouncer for now, if 6 process workers => 6 - pool_recycle=1800, - ) - @celery.task( retries=3, @@ -94,109 +70,26 @@ def init_worker(**kwargs): dont_autoretry_for=(FileExistsError,), ) def process_file_task( - 
knowledge_id: UUID, - file_name: str, + knowledge_dto: KnowledgeDTO, notification_id: UUID | None = None, - source: str | None = None, - source_link: str | None = None, - delete_file: bool = False, ): if async_engine is None: init_worker() logger.info( - f"Task process_file started for file_name={file_name}, knowledge_id={knowledge_id}, brain_id={brain_id}, notification_id={notification_id}" + f"Task process_file started for knowledge_id={knowledge_dto.id}, notification_id={notification_id}" ) loop = asyncio.get_event_loop() - loop.run_until_complete( - aprocess_file_task( - file_name=file_name, - knowledge_id=knowledge_id, - source=source, - source_link=source_link, - delete_file=delete_file, - ) - ) + loop.run_until_complete(aprocess_file_task(knowledge_dto)) -async def aprocess_file_task( - file_name: str, - knowledge_id: UUID, - source: str | None = None, - source_link: str | None = None, - delete_file: bool = False, -): - global engine - assert engine - async with AsyncSession(async_engine) as async_session: - try: - await async_session.execute( - text("SET SESSION idle_in_transaction_session_timeout = '5min';") - ) - # FIXME @amine: fix to need AsyncSession in vector Service - vector_repository = VectorRepository(async_session.sync_session) - vector_service = VectorService(vector_repository) - knowledge_repository = KnowledgeRepository(async_session) - knowledge_service = KnowledgeService(knowledge_repository) - await process_uploaded_file( - supabase_client=supabase_client, - brain_service=brain_service, - vector_service=vector_service, - knowledge_service=knowledge_service, - file_name=file_name, - knowledge_id=knowledge_id, - integration=source, - integration_link=source_link, - ) - await async_session.commit() - except Exception as e: - await async_session.rollback() - raise e - finally: - await async_session.close() - - -@celery.task( - retries=3, - default_retry_delay=1, - name="process_crawl_task", - autoretry_for=(Exception,), -) -def 
process_crawl_task( - crawl_website_url: str, - brain_id: UUID, - knowledge_id: UUID, - notification_id: UUID | None = None, -): - logger.info( - f"Task process_crawl_task started for url={crawl_website_url}, knowledge_id={knowledge_id}, brain_id={brain_id}, notification_id={notification_id}" - ) - global engine - assert engine - try: - with Session(engine, expire_on_commit=False, autoflush=False) as session: - session.execute( - text("SET SESSION idle_in_transaction_session_timeout = '5min';") - ) - vector_repository = VectorRepository(session) - vector_service = VectorService(vector_repository) - loop = asyncio.get_event_loop() - loop.run_until_complete( - process_url_func( - url=crawl_website_url, - brain_id=brain_id, - knowledge_id=knowledge_id, - brain_service=brain_service, - vector_service=vector_service, - ) - ) - session.commit() - except Exception as e: - session.rollback() - raise e - finally: - session.close() +async def aprocess_file_task(knowledge_dto: KnowledgeDTO): + global async_engine + assert async_engine + async with build_processor_services(async_engine) as processor_services: + km_processor = KnowledgeProcessor(processor_services) + await km_processor.process_knowledge(knowledge_dto) @celery.task(name="NotionConnectorLoad") @@ -222,54 +115,6 @@ def check_is_premium_task(): check_is_premium(supabase_client) -@celery.task(name="process_sync_task") -def process_sync_task( - sync_id: int, user_id: str, files_ids: list[str], folder_ids: list[str] -): - global async_engine - assert async_engine - sync = next( - filter(lambda s: s.id == sync_id, sync_active_service.get_syncs_active(user_id)) - ) - loop = asyncio.get_event_loop() - loop.run_until_complete( - process_sync( - sync=sync, - files_ids=files_ids, - folder_ids=folder_ids, - services=SyncServices( - async_engine=async_engine, - sync_active_service=sync_active_service, - sync_user_service=sync_user_service, - sync_files_repo_service=sync_files_repo_service, - storage=storage, - 
brain_vectors=brain_vectors, - notification_service=notification_service, - ), - ) - ) - - -@celery.task(name="process_active_syncs_task") -def process_active_syncs_task(): - global async_engine - assert async_engine - loop = asyncio.get_event_loop() - loop.run_until_complete( - process_all_active_syncs( - SyncServices( - async_engine=async_engine, - sync_active_service=sync_active_service, - sync_user_service=sync_user_service, - sync_files_repo_service=sync_files_repo_service, - storage=storage, - brain_vectors=brain_vectors, - notification_service=notification_service, - ), - ) - ) - - @celery.task(name="process_notion_sync_task") def process_notion_sync_task(): global async_engine @@ -290,21 +135,23 @@ def fetch_and_store_notion_files_task(access_token: str, user_id: UUID): ) -celery.conf.beat_schedule = { - "ping_telemetry": { - "task": f"{__name__}.ping_telemetry", - "schedule": crontab(minute="*/30", hour="*"), - }, - "process_active_syncs": { - "task": "process_active_syncs_task", - "schedule": crontab(minute="*/1", hour="*"), - }, - "process_premium_users": { - "task": "check_is_premium_task", - "schedule": crontab(minute="*/1", hour="*"), - }, - "process_notion_sync": { - "task": "process_notion_sync_task", - "schedule": crontab(minute="0", hour="*/6"), - }, -} +# from celery.schedules import crontab + +# celery.conf.beat_schedule = { +# "ping_telemetry": { +# "task": f"{__name__}.ping_telemetry", +# "schedule": crontab(minute="*/30", hour="*"), +# }, +# "process_active_syncs": { +# "task": "process_active_syncs_task", +# "schedule": crontab(minute="*/1", hour="*"), +# }, +# "process_premium_users": { +# "task": "check_is_premium_task", +# "schedule": crontab(minute="*/1", hour="*"), +# }, +# "process_notion_sync": { +# "task": "process_notion_sync_task", +# "schedule": crontab(minute="0", hour="*/6"), +# }, +# } diff --git a/backend/worker/quivr_worker/files.py b/backend/worker/quivr_worker/files.py deleted file mode 100644 index 
d27385763b4d..000000000000 --- a/backend/worker/quivr_worker/files.py +++ /dev/null @@ -1,90 +0,0 @@ -import hashlib -import time -from contextlib import contextmanager -from pathlib import Path -from tempfile import NamedTemporaryFile -from typing import Any -from uuid import UUID - -from quivr_api.logger import get_logger -from quivr_core.files.file import FileExtension, QuivrFile - -from quivr_worker.utils import get_tmp_name - -logger = get_logger("celery_worker") - - -def compute_sha1(content: bytes) -> str: - m = hashlib.sha1() - m.update(content) - return m.hexdigest() - - -@contextmanager -def build_file( - file_data: bytes, - file_name_ext: str, -): - try: - # TODO(@aminediro) : Maybe use fsspec file to be agnostic to where files are stored :? - # We are reading the whole file to memory, which doesn't scale - tmp_name, _, _ = get_tmp_name(file_name_ext) - tmp_file = NamedTemporaryFile( - suffix="_" + tmp_name, # pyright: ignore reportPrivateUsage=none - ) - tmp_file.write(file_data) - tmp_file.flush() - yield Path(tmp_file.name) - finally: - # Code to release resource, e.g.: - tmp_file.close() - - -class File: - __slots__ = [ - "id", - "file_name", - "tmp_file_path", - "file_size", - "file_extension", - "file_sha1", - "original_file_name", - ] - - def __init__( - self, - knowledge_id: UUID, - file_name: str, - tmp_file_path: Path, - file_size: int, - file_extension: str, - file_sha1: str, - original_file_name: str, - ): - self.id = knowledge_id - self.file_name = file_name - self.tmp_file_path = tmp_file_path - self.file_size = file_size - self.file_sha1 = file_sha1 - self.file_extension = FileExtension(file_extension) - self.original_file_name = original_file_name - - def is_empty(self): - return self.file_size < 1 - - def to_qfile(self, metadata: dict[str, Any] = {}) -> QuivrFile: - return QuivrFile( - id=self.id, - original_filename=self.file_name, - path=self.tmp_file_path, - file_sha1=self.file_sha1, - file_extension=self.file_extension, - 
file_size=self.file_size, - metadata={ - "date": time.strftime("%Y%m%d"), - "file_name": self.file_name, - "original_file_name": self.original_file_name, - "knowledge_id": self.id, - **metadata, - }, - ) diff --git a/backend/worker/quivr_worker/parsers/audio.py b/backend/worker/quivr_worker/parsers/audio.py index 533357e28080..f4c6890500e2 100644 --- a/backend/worker/quivr_worker/parsers/audio.py +++ b/backend/worker/quivr_worker/parsers/audio.py @@ -3,11 +3,12 @@ from langchain.schema import Document from langchain.text_splitter import RecursiveCharacterTextSplitter from openai import OpenAI +from quivr_core.files.file import QuivrFile -from quivr_worker.files import File, compute_sha1 +from quivr_worker.process.utils import compute_sha1 -def process_audio(file: File, model: str = "whisper=1"): +def process_audio(file: QuivrFile, model: str = "whisper=1"): # TODO(@aminediro): These should apear in the class processor # Should be instanciated once per Processor chunk_size = 500 @@ -19,7 +20,8 @@ def process_audio(file: File, model: str = "whisper=1"): dateshort = time.strftime("%Y%m%d-%H%M%S") file_meta_name = f"audiotranscript_{dateshort}.txt" - with open(file.tmp_file_path, "rb") as audio_file: + # TODO: This reopens the file adding an additional FD + with open(file.path, "rb") as audio_file: transcript = client.audio.transcriptions.create(model=model, file=audio_file) transcript_txt = transcript.text.encode("utf-8") diff --git a/backend/worker/quivr_worker/process/process_file.py b/backend/worker/quivr_worker/process/process_file.py index a13eb833f8ac..bff5e72401ba 100644 --- a/backend/worker/quivr_worker/process/process_file.py +++ b/backend/worker/quivr_worker/process/process_file.py @@ -1,14 +1,12 @@ from typing import Any -from uuid import UUID from langchain_core.documents import Document from quivr_api.logger import get_logger -from quivr_api.modules.brain.entity.brain_entity import BrainEntity -from quivr_api.modules.brain.service.brain_service import 
BrainService +from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB from quivr_api.modules.vector.service.vector_service import VectorService +from quivr_core.files.file import QuivrFile from quivr_core.processor.registry import get_processor_class -from quivr_worker.files import File from quivr_worker.parsers.audio import process_audio logger = get_logger("celery_worker") @@ -25,92 +23,54 @@ } -async def process_file( - file_instance: File, - brain: BrainEntity, - brain_service: BrainService, - vector_service: VectorService, - integration: str | None, - integration_link: str | None, -): - chunks = await parse_file( - file=file_instance, - brain=brain, - integration=integration, - integration_link=integration_link, - ) - store_chunks( - file=file_instance, - brain_id=brain.brain_id, - chunks=chunks, - brain_service=brain_service, - vector_service=vector_service, - ) - - -def store_chunks( +async def store_chunks( *, - file: File, - brain_id: UUID, + knowledge: KnowledgeDB, chunks: list[Document], - brain_service: BrainService, vector_service: VectorService, ): - # vector_ids = document_vector_store.add_documents(chunks) - vector_ids = vector_service.create_vectors(chunks, file.id) - logger.debug(f"Inserted {len(chunks)} chunks in vectors table for {file}") - + assert knowledge.id + vector_ids = await vector_service.create_vectors(chunks, knowledge.id) + logger.debug( + f"Inserted {len(chunks)} chunks in vectors table for knowledge: {knowledge.id}" + ) if vector_ids is None or len(vector_ids) == 0: - raise Exception(f"Error inserting chunks for file {file.file_name}") - - brain_service.update_brain_last_update_time(brain_id) + raise Exception(f"Error inserting chunks for knowledge {knowledge.id}") -async def parse_file( - file: File, - brain: BrainEntity, - integration: str | None = None, - integration_link: str | None = None, +async def parse_qfile( + *, + qfile: QuivrFile, **processor_kwargs: dict[str, Any], ) -> list[Document]: try: # 
TODO(@aminediro): add audio procesors to quivr-core - if file.file_extension in audio_extensions: - logger.debug(f"processing audio file {file}") - audio_docs = process_audio_file(file, brain) + if qfile.file_extension in audio_extensions: + logger.debug(f"processing audio file {qfile}") + audio_docs = process_audio_file(qfile) return audio_docs else: - qfile = file.to_qfile( - brain.brain_id, - { - "integration": integration or "", - "integration_link": integration_link or "", - }, - ) - processor_cls = get_processor_class(file.file_extension) + processor_cls = get_processor_class(qfile.file_extension) processor = processor_cls(**processor_kwargs) docs = await processor.process_file(qfile) logger.debug(f"Parsed {qfile} to : {docs}") return docs except KeyError as e: - raise ValueError(f"Can't parse {file}. No available processor") from e + raise ValueError(f"Can't parse {qfile}. No available processor") from e +# TODO: Move this to quivr-core def process_audio_file( - file: File, - brain: BrainEntity, + qfile: QuivrFile, ): try: - result = process_audio(file=file) + result = process_audio(file=qfile) if result is None or result == 0: logger.info( - f"{file.file_name} has been uploaded to brain. There might have been an error while reading it, please make sure the file is not illformed or just an image", # pyright: ignore reportPrivateUsage=none + f"{qfile.file_name} has been uploaded to brain. 
There might have been an error while reading it, please make sure the file is not illformed or just an image", # pyright: ignore reportPrivateUsage=none ) return [] - logger.info( - f"{file.file_name} has been uploaded to brain {brain.name} in {result} chunks", # pyright: ignore reportPrivateUsage=none - ) return result except Exception as e: - logger.exception(f"Error processing audio file {file}: {e}") + logger.exception(f"Error processing audio file {qfile}: {e}") raise e diff --git a/backend/worker/quivr_worker/process/process_s3_file.py b/backend/worker/quivr_worker/process/process_s3_file.py deleted file mode 100644 index 526a62ad50cb..000000000000 --- a/backend/worker/quivr_worker/process/process_s3_file.py +++ /dev/null @@ -1,41 +0,0 @@ -from uuid import UUID - -from quivr_api.logger import get_logger -from quivr_api.modules.brain.service.brain_service import BrainService -from quivr_api.modules.knowledge.dto.inputs import KnowledgeUpdate -from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService -from quivr_api.modules.vector.service.vector_service import VectorService - -from quivr_worker.files import build_file -from quivr_worker.process.process_file import process_file -from supabase import Client - -logger = get_logger("celery_worker") - - -async def process_uploaded_file( - supabase_client: Client, - brain_service: BrainService, - vector_service: VectorService, - knowledge_service: KnowledgeService, - file_name: str, - knowledge_id: UUID, - integration: str | None = None, - integration_link: str | None = None, - bucket_name: str = "quivr", -): - file_data = supabase_client.storage.from_(bucket_name).download(file_name) - with build_file(file_data, knowledge_id, file_name) as file_instance: - knowledge = await knowledge_service.get_knowledge(knowledge_id=knowledge_id) - await knowledge_service.update_knowledge( - knowledge, - KnowledgeUpdate(file_sha1=file_instance.file_sha1), # type: ignore - ) - await process_file( - 
file_instance=file_instance, - brain=knowledge.brains[0], # FIXME: this is temporary - brain_service=brain_service, - vector_service=vector_service, - integration=integration, - integration_link=integration_link, - ) diff --git a/backend/worker/quivr_worker/process/processor.py b/backend/worker/quivr_worker/process/processor.py index 096ccf3bc4ab..6a2e9a2c5088 100644 --- a/backend/worker/quivr_worker/process/processor.py +++ b/backend/worker/quivr_worker/process/processor.py @@ -1,13 +1,12 @@ -import time -from contextlib import asynccontextmanager, contextmanager +from contextlib import asynccontextmanager from dataclasses import dataclass from io import BytesIO -from typing import Any, AsyncGenerator, Generator, List, Tuple +from typing import Any, AsyncGenerator, List, Optional, Tuple from uuid import UUID from quivr_api.logger import get_logger from quivr_api.modules.dependencies import get_supabase_async_client -from quivr_api.modules.knowledge.dto.inputs import AddKnowledge +from quivr_api.modules.knowledge.dto.inputs import AddKnowledge, KnowledgeUpdate from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB, KnowledgeSource from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository @@ -18,46 +17,33 @@ from quivr_api.modules.sync.repository.sync_repository import SyncsRepository from quivr_api.modules.sync.service.sync_service import SyncsService from quivr_api.modules.sync.utils.sync import ( - AzureDriveSync, BaseSync, - DropboxSync, - GitHubSync, - GoogleDriveSync, ) from quivr_api.modules.vector.repository.vectors_repository import VectorRepository from quivr_api.modules.vector.service.vector_service import VectorService -from quivr_core.files.file import FileExtension, QuivrFile +from quivr_core.files.file import QuivrFile from quivr_core.models import KnowledgeStatus from sqlalchemy.ext.asyncio import AsyncEngine from sqlmodel import text from 
sqlmodel.ext.asyncio.session import AsyncSession -from quivr_worker.files import build_file, compute_sha1 +from quivr_worker.process.process_file import parse_qfile, store_chunks +from quivr_worker.process.utils import ( + build_qfile, + build_syncprovider_mapping, + compute_sha1, + skip_process, +) logger = get_logger("celery_worker") -def skip_process(knowledge: KnowledgeDTO) -> bool: - return knowledge.is_folder and knowledge.source != KnowledgeSource.NOTION - - -def build_syncprovider_mapping() -> dict[str, BaseSync]: - mapping_sync_utils = { - "google": GoogleDriveSync(), - "azure": AzureDriveSync(), - "dropbox": DropboxSync(), - "github": GitHubSync(), - # "notion", NotionSync(notion_service=notion_service), - } - return mapping_sync_utils - - @dataclass class ProcessorServices: sync_service: SyncsService vector_service: VectorService knowledge_service: KnowledgeService - syncprovider_mapping: dict[str, BaseSync] + syncprovider_mapping: dict[SyncProvider, BaseSync] @asynccontextmanager @@ -81,12 +67,12 @@ async def build_processor_services(engine: AsyncEngine): async_client = await get_supabase_async_client() storage = SupabaseS3Storage(async_client) try: - async with _start_session(engine) as async_session: - vector_repository = VectorRepository(async_session.sync_session) + async with _start_session(engine) as session: + vector_repository = VectorRepository(session) vector_service = VectorService(vector_repository) - knowledge_repository = KnowledgeRepository(async_session) + knowledge_repository = KnowledgeRepository(session) knowledge_service = KnowledgeService(knowledge_repository, storage=storage) - sync_repository = SyncsRepository(async_session) + sync_repository = SyncsRepository(session) sync_service = SyncsService(sync_repository) yield ProcessorServices( knowledge_service=knowledge_service, @@ -113,37 +99,6 @@ async def download_sync_file( return file_data -@contextmanager -def build_qfile( - knowledge: KnowledgeDB, file_data: bytes -) -> 
Generator[QuivrFile, None, None]: - assert knowledge.id - assert knowledge.file_name - assert knowledge.file_sha1 - with build_file( - file_data=file_data, file_name_ext=knowledge.file_name - ) as tmp_file_path: - qfile = QuivrFile( - id=knowledge.id, - original_filename=knowledge.file_name, - path=tmp_file_path, - file_sha1=knowledge.file_sha1, - file_extension=FileExtension(knowledge.extension), - file_size=knowledge.file_size, - metadata={ - "date": time.strftime("%Y%m%d"), - "file_name": knowledge.file_name, - "knowledge_id": knowledge.id, - }, - ) - if knowledge.metadata_: - qfile.additional_metadata = { - **qfile.metadata, - **knowledge.metadata_, - } - yield qfile - - class KnowledgeProcessor: def __init__(self, services: ProcessorServices): self.services = services @@ -164,12 +119,11 @@ async def fetch_sync_knowledge( ) return await asyncio.gather(*[map_knowledges_task, sync_files_task]) # type: ignore # noqa: F821 - @asynccontextmanager - async def build_processable( + async def yield_processable_kms( self, knowledge: KnowledgeDTO ) -> AsyncGenerator[Tuple[KnowledgeDB, QuivrFile] | None, None]: if knowledge.source == KnowledgeSource.LOCAL: - async with self._build_local(knowledge) as to_process: + async for to_process in self._build_local(knowledge): yield to_process elif knowledge.source in ( KnowledgeSource.AZURE, @@ -177,7 +131,7 @@ async def build_processable( KnowledgeSource.GOOGLE, KnowledgeSource.NOTION, ): - async with self._build_sync(knowledge) as to_process: + async for to_process in self._build_sync(knowledge): yield to_process elif knowledge.source == KnowledgeSource.WEB: raise NotImplementedError @@ -187,7 +141,6 @@ async def build_processable( ) raise ValueError("Unknown knowledge source : {knoledge.source}") - @asynccontextmanager async def _build_local( self, knowledge: KnowledgeDTO ) -> AsyncGenerator[Tuple[KnowledgeDB, QuivrFile] | None, None]: @@ -204,10 +157,9 @@ async def _build_local( with build_qfile(knowledge_db, file_data) as 
qfile: yield (knowledge_db, qfile) - @asynccontextmanager async def _build_sync( self, knowledge_dto: KnowledgeDTO - ) -> AsyncGenerator[Tuple[KnowledgeDB, QuivrFile] | None, None]: + ) -> AsyncGenerator[Optional[Tuple[KnowledgeDB, QuivrFile]], None]: if knowledge_dto.id is None: logger.error(f"received unprocessable knowledge: {knowledge_dto.id} ") raise ValueError @@ -247,7 +199,7 @@ async def _build_sync( if not sync_files: return - # Yield parent knowledge to process + # Yield parent_knowledge as the first knowledge to process file_data = await download_sync_file( sync_provider=sync_provider, file=SyncFile( @@ -296,5 +248,20 @@ async def _build_sync( yield (file_knowledge, qfile) async def process_knowledge(self, knowledge_dto: KnowledgeDTO): - async for (knowledge, qfile) in self.build_processable(knowledge_dto): - pass + async for knowledge_tuple in self.yield_processable_kms(knowledge_dto): + if knowledge_tuple is None: + continue + knowledge, qfile = knowledge_tuple + if not skip_process(knowledge): + chunks = await parse_qfile(qfile=qfile) + await store_chunks( + knowledge=knowledge, + chunks=chunks, + vector_service=self.services.vector_service, + ) + await self.services.knowledge_service.update_knowledge( + knowledge, + KnowledgeUpdate( + status=KnowledgeStatus.PROCESSED, file_sha1=knowledge.file_sha1 + ), + ) diff --git a/backend/worker/quivr_worker/process/utils.py b/backend/worker/quivr_worker/process/utils.py new file mode 100644 index 000000000000..5a11375df358 --- /dev/null +++ b/backend/worker/quivr_worker/process/utils.py @@ -0,0 +1,105 @@ +import hashlib +import os +import time +from contextlib import contextmanager +from pathlib import Path +from tempfile import NamedTemporaryFile +from typing import Generator, Tuple + +from quivr_api.celery_config import celery +from quivr_api.logger import get_logger +from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO +from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB, 
KnowledgeSource +from quivr_api.modules.sync.dto.outputs import SyncProvider +from quivr_api.modules.sync.utils.sync import ( + AzureDriveSync, + BaseSync, + DropboxSync, + GitHubSync, + GoogleDriveSync, +) +from quivr_core.files.file import FileExtension, QuivrFile + +celery_inspector = celery.control.inspect() + +logger = get_logger("celery_worker") + + +def skip_process(knowledge: KnowledgeDTO | KnowledgeDB) -> bool: + return knowledge.is_folder and knowledge.source != KnowledgeSource.NOTION + + +def build_syncprovider_mapping() -> dict[SyncProvider, BaseSync]: + mapping_sync_utils = { + SyncProvider.GOOGLE: GoogleDriveSync(), + SyncProvider.AZURE: AzureDriveSync(), + SyncProvider.DROPBOX: DropboxSync(), + SyncProvider.GITHUB: GitHubSync(), + # SyncProvider.NOTION: NotionSync(notion_service=notion_service), + } + return mapping_sync_utils + + +def compute_sha1(content: bytes) -> str: + m = hashlib.sha1() + m.update(content) + return m.hexdigest() + + +def get_tmp_name(file_name: str) -> Tuple[str, str, str]: + # Filepath is S3 based + tmp_name = file_name.replace("/", "_") + base_file_name = os.path.basename(file_name) + _, file_extension = os.path.splitext(base_file_name) + return tmp_name, base_file_name, file_extension + + +@contextmanager +def create_temp_file( + file_data: bytes, + file_name_ext: str, +): + # TODO(@aminediro) : + # Maybe use fsspec file to be agnostic to where files are stored + # We are reading the whole file to memory, which doesn't scale + try: + tmp_name, _, _ = get_tmp_name(file_name_ext) + tmp_file = NamedTemporaryFile( + suffix="_" + tmp_name, + ) + tmp_file.write(file_data) + tmp_file.flush() + yield Path(tmp_file.name) + finally: + tmp_file.close() + + +@contextmanager +def build_qfile( + knowledge: KnowledgeDB, file_data: bytes +) -> Generator[QuivrFile, None, None]: + assert knowledge.id + assert knowledge.file_name + assert knowledge.file_sha1 + with create_temp_file( + file_data=file_data, file_name_ext=knowledge.file_name + ) 
as tmp_file_path: + qfile = QuivrFile( + id=knowledge.id, + original_filename=knowledge.file_name, + path=tmp_file_path, + file_sha1=knowledge.file_sha1, + file_extension=FileExtension(knowledge.extension), + file_size=knowledge.file_size, + metadata={ + "date": time.strftime("%Y%m%d"), + "file_name": knowledge.file_name, + "knowledge_id": knowledge.id, + }, + ) + if knowledge.metadata_: + qfile.additional_metadata = { + **qfile.metadata, + **knowledge.metadata_, + } + yield qfile diff --git a/backend/worker/quivr_worker/syncs/utils.py b/backend/worker/quivr_worker/syncs/utils.py deleted file mode 100644 index da8a6af83b3b..000000000000 --- a/backend/worker/quivr_worker/syncs/utils.py +++ /dev/null @@ -1,91 +0,0 @@ -from contextlib import asynccontextmanager -from dataclasses import dataclass -from typing import AsyncGenerator - -from quivr_api.celery_config import celery -from quivr_api.logger import get_logger -from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors -from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository -from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage -from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService -from quivr_api.modules.notification.service.notification_service import ( - NotificationService, -) -from quivr_api.modules.sync.repository.notion_repository import NotionRepository -from quivr_api.modules.sync.repository.sync_files import SyncFilesRepository -from quivr_api.modules.sync.service.sync_notion import ( - SyncNotionService, -) -from quivr_api.modules.sync.service.sync_service import SyncsService -from quivr_api.modules.sync.utils.sync import ( - AzureDriveSync, - DropboxSync, - GitHubSync, - GoogleDriveSync, - NotionSync, -) -from quivr_api.modules.sync.utils.syncutils import SyncUtils -from sqlalchemy.ext.asyncio import AsyncEngine -from sqlmodel import text -from sqlmodel.ext.asyncio.session import AsyncSession - 
-celery_inspector = celery.control.inspect() - -logger = get_logger("celery_worker") - - -@dataclass -class SyncServices: - async_engine: AsyncEngine - sync_active_service: SyncsService - sync_user_service: SyncsService - sync_files_repo_service: SyncFilesRepository - notification_service: NotificationService - brain_vectors: BrainsVectors - storage: SupabaseS3Storage - - -@asynccontextmanager -async def build_syncs_utils( - deps: SyncServices, -) -> AsyncGenerator[dict[str, SyncUtils], None]: - try: - async with AsyncSession( - deps.async_engine, expire_on_commit=False, autoflush=False - ) as session: - await session.execute( - text("SET SESSION idle_in_transaction_session_timeout = '5min';") - ) - notion_repository = NotionRepository(session) - notion_service = SyncNotionService(notion_repository) - knowledge_service = KnowledgeService(KnowledgeRepository(session)) - - mapping_sync_utils = {} - for provider_name, sync_cloud in [ - ("google", GoogleDriveSync()), - ("azure", AzureDriveSync()), - ("dropbox", DropboxSync()), - ("github", GitHubSync()), - ( - "notion", - NotionSync(notion_service=notion_service), - ), # Fixed duplicate "github" key - ]: - provider_sync_util = SyncUtils( - sync_user_service=deps.sync_user_service, - sync_active_service=deps.sync_active_service, - sync_files_repo=deps.sync_files_repo_service, - sync_cloud=sync_cloud, - notification_service=deps.notification_service, - brain_vectors=deps.brain_vectors, - knowledge_service=knowledge_service, - ) - mapping_sync_utils[provider_name] = provider_sync_util - - yield mapping_sync_utils - await session.commit() - except Exception as e: - await session.rollback() - raise e - finally: - await session.close() diff --git a/backend/worker/quivr_worker/utils.py b/backend/worker/quivr_worker/utils.py index 75978b27cd9c..80dfb034a917 100644 --- a/backend/worker/quivr_worker/utils.py +++ b/backend/worker/quivr_worker/utils.py @@ -1,16 +1,6 @@ -import os import uuid from json import JSONEncoder from 
pathlib import PosixPath -from typing import Tuple - - -def get_tmp_name(file_name: str) -> Tuple[str, str, str]: - # Filepath is S3 based - tmp_name = file_name.replace("/", "_") - base_file_name = os.path.basename(file_name) - _, file_extension = os.path.splitext(base_file_name) - return tmp_name, base_file_name, file_extension # TODO: This is a hack for making uuid work with supabase clients diff --git a/backend/worker/tests/conftest.py b/backend/worker/tests/conftest.py index 596b81c062de..f40d2961fded 100644 --- a/backend/worker/tests/conftest.py +++ b/backend/worker/tests/conftest.py @@ -1,40 +1,187 @@ import os +from io import BytesIO from uuid import uuid4 import pytest +import pytest_asyncio +import sqlalchemy +from fastapi import UploadFile +from langchain_core.embeddings import DeterministicFakeEmbedding +from quivr_api.models.settings import settings +from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType +from quivr_api.modules.brain.entity.brain_user import BrainUserDB +from quivr_api.modules.dependencies import get_supabase_client +from quivr_api.modules.knowledge.dto.inputs import AddKnowledge +from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB +from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository +from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService +from quivr_api.modules.knowledge.tests.conftest import FakeStorage +from quivr_api.modules.sync.dto.outputs import SyncProvider +from quivr_api.modules.sync.repository.sync_repository import SyncsRepository +from quivr_api.modules.sync.service.sync_service import SyncsService +from quivr_api.modules.sync.tests.test_sync_controller import BaseFakeSync +from quivr_api.modules.sync.utils.sync import BaseSync +from quivr_api.modules.user.entity.user_identity import User +from quivr_api.modules.vector.repository.vectors_repository import VectorRepository +from quivr_api.modules.vector.service.vector_service import 
VectorService +from quivr_core.files.file import QuivrFile +from quivr_worker.process.processor import KnowledgeProcessor, ProcessorServices +from sqlalchemy.ext.asyncio import create_async_engine +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession -from quivr_worker.files import File +pg_database_base_url = "postgres:postgres@localhost:54322/postgres" + + +async_engine = create_async_engine( + "postgresql+asyncpg://" + pg_database_base_url, + echo=True if os.getenv("ORM_DEBUG") else False, + future=True, +) + + +@pytest_asyncio.fixture(scope="function") +async def session(): + async with async_engine.connect() as conn: + trans = await conn.begin() + nested = await conn.begin_nested() + async_session = AsyncSession( + conn, + expire_on_commit=False, + autoflush=False, + autocommit=False, + ) + + @sqlalchemy.event.listens_for( + async_session.sync_session, "after_transaction_end" + ) + def end_savepoint(session, transaction): + nonlocal nested + if not nested.is_active: + nested = conn.sync_connection.begin_nested() # type: ignore + + yield async_session + await trans.rollback() + await async_session.close() + + +@pytest.fixture(scope="session") +def supabase_client(): + return get_supabase_client() + + +@pytest_asyncio.fixture(scope="function") +async def user(session: AsyncSession) -> User: + user_1 = ( + await session.exec(select(User).where(User.email == "admin@quivr.app")) + ).one() + assert user_1.id + return user_1 + + +@pytest_asyncio.fixture(scope="function") +async def brain_user(session, user: User) -> Brain: + assert user.id + brain_1 = Brain( + name="test_brain", + description="this is a test brain", + brain_type=BrainType.integration, + ) + session.add(brain_1) + await session.commit() + await session.refresh(brain_1) + assert brain_1.brain_id + brain_user = BrainUserDB( + brain_id=brain_1.brain_id, user_id=user.id, default_brain=True, rights="Owner" + ) + session.add(brain_user) + await session.commit() + return brain_1 
+ + +@pytest_asyncio.fixture(scope="function") +async def proc_services(session: AsyncSession) -> ProcessorServices: + storage = FakeStorage() + embedder = DeterministicFakeEmbedding(size=settings.embedding_dim) + sync_provider_mapping: dict[SyncProvider, BaseSync] = { + provider: BaseFakeSync(provider_name=str(provider)) + for provider in list(SyncProvider) + } + vector_repository = VectorRepository(session) + vector_service = VectorService(vector_repository, embedder=embedder) + knowledge_repository = KnowledgeRepository(session) + knowledge_service = KnowledgeService(knowledge_repository, storage=storage) + sync_repository = SyncsRepository(session) + sync_service = SyncsService(sync_repository) + + return ProcessorServices( + knowledge_service=knowledge_service, + vector_service=vector_service, + sync_service=sync_service, + syncprovider_mapping=sync_provider_mapping, + ) + + +@pytest_asyncio.fixture(scope="function") +async def km_processor(proc_services: ProcessorServices): + return KnowledgeProcessor(proc_services) + + +@pytest_asyncio.fixture(scope="function") +async def local_knowledge_file( + proc_services: ProcessorServices, user: User, brain_user: Brain +) -> KnowledgeDB: + assert user.id + assert brain_user.brain_id + service = proc_services.knowledge_service + km_to_add = AddKnowledge( + file_name="test", + source="local", + is_folder=False, + parent_id=None, + ) + km_data = BytesIO(os.urandom(24)) + + km = await service.create_knowledge( + user_id=user.id, + knowledge_to_add=km_to_add, + upload_file=UploadFile(file=km_data, size=128, filename=km_to_add.file_name), + ) + + # Link it to the brain + await service.link_knowledge_tree_brains( + km, brains_ids=[brain_user.brain_id], user_id=user.id + ) + return km @pytest.fixture -def file_instance(tmp_path) -> File: +def qfile_instance(tmp_path) -> QuivrFile: data = "This is some test data." 
temp_file = tmp_path / "data.txt" temp_file.write_text(data) knowledge_id = uuid4() - return File( - knowledge_id=knowledge_id, + return QuivrFile( + id=knowledge_id, file_sha1="124", file_extension=".txt", - file_name=temp_file.name, - original_file_name=temp_file.name, + original_filename=temp_file.name, + path=temp_file.absolute(), file_size=len(data), - tmp_file_path=temp_file.absolute(), ) @pytest.fixture -def audio_file(tmp_path) -> File: +def audio_file(tmp_path) -> QuivrFile: data = os.urandom(128) temp_file = tmp_path / "data.mp4" temp_file.write_bytes(data) knowledge_id = uuid4() - return File( - knowledge_id=knowledge_id, + return QuivrFile( + id=knowledge_id, file_sha1="124", file_extension=".mp4", - file_name=temp_file.name, - original_file_name="data.mp4", + original_filename="data.mp4", + path=temp_file.absolute(), file_size=len(data), - tmp_file_path=temp_file.absolute(), ) diff --git a/backend/worker/tests/test_process_file.py b/backend/worker/tests/test_process_file.py index 9ca8a49b7153..ebe30702a29a 100644 --- a/backend/worker/tests/test_process_file.py +++ b/backend/worker/tests/test_process_file.py @@ -6,41 +6,23 @@ import pytest from quivr_api.modules.brain.entity.brain_entity import BrainEntity, BrainType from quivr_core.files.file import FileExtension -from quivr_worker.files import File, build_file -from quivr_worker.parsers.crawler import URL, slugify -from quivr_worker.process.process_file import parse_file +from quivr_worker.process.process_file import parse_qfile +from quivr_worker.process.utils import build_qfile -def test_build_file(): +def test_build_qfile(): random_bytes = os.urandom(128) brain_id = uuid4() file_name = f"{brain_id}/test_file.txt" knowledge_id = uuid4() - with build_file(random_bytes, knowledge_id, file_name) as file: + with build_qfile(random_bytes, file_name) as file: assert file.file_size == 128 assert file.file_name == "test_file.txt" assert file.id == knowledge_id assert file.file_extension == FileExtension.txt 
-def test_build_url(): - random_bytes = os.urandom(128) - crawl_website = URL(url="http://url.url") - file_name = slugify(crawl_website.url) + ".txt" - knowledge_id = uuid4() - - with build_file( - random_bytes, - knowledge_id, - file_name=file_name, - original_file_name=crawl_website.url, - ) as file: - qfile = file.to_qfile(brain_id=uuid4()) - assert qfile.metadata["original_file_name"] == crawl_website.url - assert qfile.metadata["file_name"] == file_name - - @pytest.mark.asyncio async def test_parse_audio(monkeypatch, audio_file): from openai.resources.audio.transcriptions import Transcriptions @@ -56,7 +38,7 @@ def transcribe(*args, **kwargs): brain_type=BrainType.doc, last_update=datetime.datetime.now(), ) - chunks = await parse_file( + chunks = await parse_qfile( file=audio_file, brain=brain, ) @@ -65,15 +47,15 @@ def transcribe(*args, **kwargs): @pytest.mark.asyncio -async def test_parse_file(file_instance): +async def test_parse_file(qfile_instance): brain = BrainEntity( brain_id=uuid4(), name="test", brain_type=BrainType.doc, last_update=datetime.datetime.now(), ) - chunks = await parse_file( - file=file_instance, + chunks = await parse_qfile( + file=qfile_instance, brain=brain, ) assert len(chunks) > 0 @@ -96,7 +78,7 @@ async def test_parse_file_pdf(): brain_type=BrainType.doc, last_update=datetime.datetime.now(), ) - chunks = await parse_file( + chunks = await parse_qfile( file=file_instance, brain=brain, ) diff --git a/backend/worker/tests/test_process_file_task.py b/backend/worker/tests/test_process_file_task.py index e69de29bb2d1..83b029488737 100644 --- a/backend/worker/tests/test_process_file_task.py +++ b/backend/worker/tests/test_process_file_task.py @@ -0,0 +1,49 @@ +from typing import Any + +import pytest +from langchain_core.documents import Document +from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB +from quivr_api.modules.vector.entity.vector import Vector +from quivr_core.files.file import QuivrFile +from 
quivr_core.models import KnowledgeStatus +from quivr_worker.process.processor import KnowledgeProcessor +from sqlmodel import col, select +from sqlmodel.ext.asyncio.session import AsyncSession + + +@pytest.mark.asyncio(loop_scope="session") +async def test_process_local_file( + monkeypatch, + session: AsyncSession, + km_processor: KnowledgeProcessor, + local_knowledge_file: KnowledgeDB, +): + async def _parse_file_mock( + qfile: QuivrFile, + **processor_kwargs: dict[str, Any], + ) -> list[Document]: + with open(qfile.path, "rb") as f: + return [Document(page_content=str(f.read()), metadata={})] + + monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock) + km_dto = await local_knowledge_file.to_dto(get_children=False, get_parent=False) + await km_processor.process_knowledge(km_dto) + + # Check knowledge set to processed + assert km_dto.id + assert km_dto.brains + knowledge_service = km_processor.services.knowledge_service + km = await knowledge_service.get_knowledge(km_dto.id) + assert km.status == KnowledgeStatus.PROCESSED + assert km.brains[0].brain_id == km_dto.brains[0]["brain_id"] + + # Check vectors where added + vecs = list( + ( + await session.exec( + select(Vector).where(col(Vector.knowledge_id) == km_dto.id) + ) + ).all() + ) + assert len(vecs) > 0 + assert vecs[0].metadata_ is not None diff --git a/backend/worker/tests/test_sync.py b/backend/worker/tests/test_sync.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/backend/worker/tests/test_utils.py b/backend/worker/tests/test_utils.py index 3b0cfdab51f5..d0a764037b8d 100644 --- a/backend/worker/tests/test_utils.py +++ b/backend/worker/tests/test_utils.py @@ -4,7 +4,6 @@ import pytest from langchain_core.documents import Document - from quivr_worker.utils import _patch_json From 5b39026fba002647e3c14bbabd543cab3bc89245 Mon Sep 17 00:00:00 2001 From: aminediro Date: Wed, 25 Sep 2024 10:47:52 +0200 Subject: [PATCH 17/63] all tests working --- 
.../modules/knowledge/tests/conftest.py | 3 + .../sync/repository/sync_repository.py | 27 +++-- .../sync/tests/test_sync_controller.py | 20 ++-- .../worker/quivr_worker/process/processor.py | 38 ++++--- backend/worker/tests/conftest.py | 102 ++++++++++++++++-- .../worker/tests/test_process_file_task.py | 43 +++++++- 6 files changed, 196 insertions(+), 37 deletions(-) diff --git a/backend/api/quivr_api/modules/knowledge/tests/conftest.py b/backend/api/quivr_api/modules/knowledge/tests/conftest.py index 63c9768ddc70..dd16af4d1b89 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/conftest.py +++ b/backend/api/quivr_api/modules/knowledge/tests/conftest.py @@ -25,6 +25,9 @@ def get_storage_path( async def remove_file(self, storage_path: str): raise SystemError + async def download_file(self, knowledge: KnowledgeDB, **kwargs) -> bytes: + raise NotImplementedError + class FakeStorage(StorageInterface): def __init__(self): diff --git a/backend/api/quivr_api/modules/sync/repository/sync_repository.py b/backend/api/quivr_api/modules/sync/repository/sync_repository.py index 234223adeb74..8c7dffde3dbc 100644 --- a/backend/api/quivr_api/modules/sync/repository/sync_repository.py +++ b/backend/api/quivr_api/modules/sync/repository/sync_repository.py @@ -31,19 +31,26 @@ class SyncsRepository(BaseRepository): - def __init__(self, session: AsyncSession): + def __init__( + self, + session: AsyncSession, + sync_provider_mapping: dict[SyncProvider, BaseSync] | None = None, + ): self.session = session self.db = get_supabase_client() - self.sync_provider_mapping: dict[SyncProvider, BaseSync] = { - SyncProvider.GOOGLE: GoogleDriveSync(), - SyncProvider.DROPBOX: DropboxSync(), - SyncProvider.AZURE: AzureDriveSync(), - SyncProvider.NOTION: NotionSync( - notion_service=SyncNotionService(NotionRepository(self.session)) - ), - SyncProvider.GITHUB: GitHubSync(), - } + if sync_provider_mapping is None: + self.sync_provider_mapping: dict[SyncProvider, BaseSync] = { + 
SyncProvider.GOOGLE: GoogleDriveSync(), + SyncProvider.DROPBOX: DropboxSync(), + SyncProvider.AZURE: AzureDriveSync(), + SyncProvider.NOTION: NotionSync( + notion_service=SyncNotionService(NotionRepository(self.session)) + ), + SyncProvider.GITHUB: GitHubSync(), + } + else: + self.sync_provider_mapping = sync_provider_mapping async def create_sync( self, diff --git a/backend/api/quivr_api/modules/sync/tests/test_sync_controller.py b/backend/api/quivr_api/modules/sync/tests/test_sync_controller.py index 057c0ae57d14..38ad64be46e3 100644 --- a/backend/api/quivr_api/modules/sync/tests/test_sync_controller.py +++ b/backend/api/quivr_api/modules/sync/tests/test_sync_controller.py @@ -1,3 +1,4 @@ +import os from datetime import datetime from io import BytesIO from typing import Dict, List, Union @@ -28,9 +29,9 @@ from quivr_api.modules.sync.utils.sync import BaseSync from quivr_api.modules.user.entity.user_identity import User, UserIdentity +MAX_SYNC_FILES = 1000 N_GET_FILES = 2 - -FOLDER_SYNC_FILE_IDS = [str(uuid4())[:8] for _ in range(N_GET_FILES)] +FOLDER_SYNC_FILE_IDS = [str(uuid4())[:8] for _ in range(MAX_SYNC_FILES)] class BaseFakeSync(BaseSync): @@ -38,8 +39,12 @@ class BaseFakeSync(BaseSync): lower_name = "google" datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ" - def __init__(self, provider_name: str | None = None): + def __init__(self, provider_name: str | None = None, n_get_files: int = 2): super().__init__() + self.n_get_files = n_get_files + if n_get_files > MAX_SYNC_FILES: + raise ValueError("can't create fake sync") + self.folder_sync_file_ids = FOLDER_SYNC_FILE_IDS[:n_get_files] if provider_name: self.lower_name = provider_name @@ -77,7 +82,7 @@ def get_files( is_folder=idx % 2 == 0, last_modified_at=datetime.now(), ) - for idx, fid in enumerate(FOLDER_SYNC_FILE_IDS) + for idx, fid in enumerate(self.folder_sync_file_ids) ] async def aget_files( @@ -98,7 +103,7 @@ def download_file( async def adownload_file( self, credentials: Dict, file: SyncFile ) -> 
Dict[str, Union[str, BytesIO]]: - raise NotImplementedError + return {"content": str(os.urandom(24))} @pytest_asyncio.fixture(scope="function") @@ -112,6 +117,7 @@ async def user(session: AsyncSession) -> User: @pytest_asyncio.fixture(scope="function") async def sync(session: AsyncSession, user: User) -> Sync: + assert user.id sync = Sync( name="test_sync", email="test@test.com", @@ -140,6 +146,7 @@ async def brain(session): @pytest_asyncio.fixture(scope="function") async def knowledge_sync(session, user: User, sync: Sync, brain: Brain): + assert user.id km = KnowledgeDB( file_name="sync_file_1.txt", extension=".txt", @@ -167,7 +174,8 @@ def default_current_user() -> UserIdentity: async def _sync_service(): fake_provider: dict[SyncProvider, BaseSync] = { - provider: BaseFakeSync() for provider in list(SyncProvider) + provider: BaseFakeSync(n_get_files=N_GET_FILES) + for provider in list(SyncProvider) } repository = SyncsRepository(session) repository.sync_provider_mapping = fake_provider diff --git a/backend/worker/quivr_worker/process/processor.py b/backend/worker/quivr_worker/process/processor.py index 6a2e9a2c5088..f4e5f67f9565 100644 --- a/backend/worker/quivr_worker/process/processor.py +++ b/backend/worker/quivr_worker/process/processor.py @@ -1,3 +1,4 @@ +import asyncio from contextlib import asynccontextmanager from dataclasses import dataclass from io import BytesIO @@ -123,7 +124,7 @@ async def yield_processable_kms( self, knowledge: KnowledgeDTO ) -> AsyncGenerator[Tuple[KnowledgeDB, QuivrFile] | None, None]: if knowledge.source == KnowledgeSource.LOCAL: - async for to_process in self._build_local(knowledge): + async for to_process in self._yield_local(knowledge): yield to_process elif knowledge.source in ( KnowledgeSource.AZURE, @@ -131,7 +132,7 @@ async def yield_processable_kms( KnowledgeSource.GOOGLE, KnowledgeSource.NOTION, ): - async for to_process in self._build_sync(knowledge): + async for to_process in self._yield_syncs(knowledge): yield 
to_process elif knowledge.source == KnowledgeSource.WEB: raise NotImplementedError @@ -141,7 +142,7 @@ async def yield_processable_kms( ) raise ValueError("Unknown knowledge source : {knoledge.source}") - async def _build_local( + async def _yield_local( self, knowledge: KnowledgeDTO ) -> AsyncGenerator[Tuple[KnowledgeDB, QuivrFile] | None, None]: if knowledge.id is None or knowledge.file_name is None: @@ -157,7 +158,7 @@ async def _build_local( with build_qfile(knowledge_db, file_data) as qfile: yield (knowledge_db, qfile) - async def _build_sync( + async def _yield_syncs( self, knowledge_dto: KnowledgeDTO ) -> AsyncGenerator[Optional[Tuple[KnowledgeDB, QuivrFile]], None]: if knowledge_dto.id is None: @@ -191,14 +192,6 @@ async def _build_sync( provider_name = SyncProvider(sync.provider.lower()) sync_provider = self.services.syncprovider_mapping[provider_name] - syncfile_to_knowledge, sync_files = await self.fetch_sync_knowledge( - sync_id=parent_knowledge.sync_id, - user_id=parent_knowledge.user_id, - folder_id=parent_knowledge.sync_file_id, - ) - if not sync_files: - return - # Yield parent_knowledge as the first knowledge to process file_data = await download_sync_file( sync_provider=sync_provider, @@ -216,10 +209,28 @@ async def _build_sync( with build_qfile(parent_knowledge, file_data) as qfile: yield (parent_knowledge, qfile) + # Fetch children + syncfile_to_knowledge, sync_files = await self.fetch_sync_knowledge( + sync_id=parent_knowledge.sync_id, + user_id=parent_knowledge.user_id, + folder_id=parent_knowledge.sync_file_id, + ) + if not sync_files: + return + for sync_file in sync_files: existing_km = syncfile_to_knowledge.get(sync_file.id) if existing_km: - file_knowledge = existing_km + # SyncKnowledge already exists => + # It's already processed in some other brain so just link it and move on if it is Processed + # ELSE reprocess the file + for brain in parent_knowledge.brains: + await self.services.knowledge_service.repository.link_to_brain( + 
existing_km, brain_id=brain.brain_id + ) + # Don't reprocess already added syncs + if existing_km.status == KnowledgeStatus.PROCESSED: + continue else: # create sync file knowledge # automagically gets the brains associated with the parent @@ -238,6 +249,7 @@ async def _build_sync( status=KnowledgeStatus.PROCESSING, upload_file=None, ) + # TODO: cache maybe processed files or skip them ? file_data = await download_sync_file( sync_provider=sync_provider, file=sync_file, diff --git a/backend/worker/tests/conftest.py b/backend/worker/tests/conftest.py index f40d2961fded..0c8a35841b6a 100644 --- a/backend/worker/tests/conftest.py +++ b/backend/worker/tests/conftest.py @@ -17,6 +17,7 @@ from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService from quivr_api.modules.knowledge.tests.conftest import FakeStorage from quivr_api.modules.sync.dto.outputs import SyncProvider +from quivr_api.modules.sync.entity.sync_models import Sync from quivr_api.modules.sync.repository.sync_repository import SyncsRepository from quivr_api.modules.sync.service.sync_service import SyncsService from quivr_api.modules.sync.tests.test_sync_controller import BaseFakeSync @@ -25,6 +26,7 @@ from quivr_api.modules.vector.repository.vectors_repository import VectorRepository from quivr_api.modules.vector.service.vector_service import VectorService from quivr_core.files.file import QuivrFile +from quivr_core.models import KnowledgeStatus from quivr_worker.process.processor import KnowledgeProcessor, ProcessorServices from sqlalchemy.ext.asyncio import create_async_engine from sqlmodel import select @@ -99,19 +101,22 @@ async def brain_user(session, user: User) -> Brain: return brain_1 +# NOTE: param sets the number of sync file the provider returns @pytest_asyncio.fixture(scope="function") -async def proc_services(session: AsyncSession) -> ProcessorServices: +async def proc_services(session: AsyncSession, request) -> ProcessorServices: storage = FakeStorage() embedder = 
DeterministicFakeEmbedding(size=settings.embedding_dim) - sync_provider_mapping: dict[SyncProvider, BaseSync] = { - provider: BaseFakeSync(provider_name=str(provider)) - for provider in list(SyncProvider) - } vector_repository = VectorRepository(session) vector_service = VectorService(vector_repository, embedder=embedder) knowledge_repository = KnowledgeRepository(session) knowledge_service = KnowledgeService(knowledge_repository, storage=storage) - sync_repository = SyncsRepository(session) + sync_provider_mapping: dict[SyncProvider, BaseSync] = { + provider: BaseFakeSync(provider_name=str(provider), n_get_files=request.param) + for provider in list(SyncProvider) + } + sync_repository = SyncsRepository( + session, sync_provider_mapping=sync_provider_mapping + ) sync_service = SyncsService(sync_repository) return ProcessorServices( @@ -127,6 +132,23 @@ async def km_processor(proc_services: ProcessorServices): return KnowledgeProcessor(proc_services) +@pytest_asyncio.fixture(scope="function") +async def sync(session: AsyncSession, user: User) -> Sync: + assert user.id + sync = Sync( + name="test_sync", + email="test@test.com", + user_id=user.id, + credentials={"test": "test"}, + provider=SyncProvider.GOOGLE, + ) + + session.add(sync) + await session.commit() + await session.refresh(sync) + return sync + + @pytest_asyncio.fixture(scope="function") async def local_knowledge_file( proc_services: ProcessorServices, user: User, brain_user: Brain @@ -141,7 +163,6 @@ async def local_knowledge_file( parent_id=None, ) km_data = BytesIO(os.urandom(24)) - km = await service.create_knowledge( user_id=user.id, knowledge_to_add=km_to_add, @@ -155,6 +176,73 @@ async def local_knowledge_file( return km +@pytest_asyncio.fixture(scope="function") +async def sync_knowledge_file( + session: AsyncSession, + proc_services: ProcessorServices, + user: User, + brain_user: Brain, + sync: Sync, +) -> KnowledgeDB: + assert user.id + assert brain_user.brain_id + + km = KnowledgeDB( + 
file_name="test_file_1.txt", + extension=".txt", + status=KnowledgeStatus.PROCESSING, + source=SyncProvider.GOOGLE, + source_link="drive://test/test", + file_size=0, + file_sha1=None, + user_id=user.id, + brains=[brain_user], + parent=None, + sync_file_id="id1", + sync=sync, + ) + + session.add(km) + await session.commit() + await session.refresh(km) + + return km + + +@pytest_asyncio.fixture(scope="function") +async def sync_knowledge_folder( + session: AsyncSession, + proc_services: ProcessorServices, + user: User, + brain_user: Brain, + sync: Sync, +) -> KnowledgeDB: + assert user.id + assert brain_user.brain_id + + km = KnowledgeDB( + file_name="folder1", + extension=".txt", + status=KnowledgeStatus.PROCESSING, + source=SyncProvider.GOOGLE, + source_link="drive://test/folder1", + file_size=0, + file_sha1=None, + user_id=user.id, + brains=[brain_user], + parent=None, + is_folder=True, + sync_file_id="id1", + sync=sync, + ) + + session.add(km) + await session.commit() + await session.refresh(km) + + return km + + @pytest.fixture def qfile_instance(tmp_path) -> QuivrFile: data = "This is some test data." 
diff --git a/backend/worker/tests/test_process_file_task.py b/backend/worker/tests/test_process_file_task.py index 83b029488737..6ba8559d43e3 100644 --- a/backend/worker/tests/test_process_file_task.py +++ b/backend/worker/tests/test_process_file_task.py @@ -6,12 +6,13 @@ from quivr_api.modules.vector.entity.vector import Vector from quivr_core.files.file import QuivrFile from quivr_core.models import KnowledgeStatus -from quivr_worker.process.processor import KnowledgeProcessor +from quivr_worker.process.processor import KnowledgeProcessor, ProcessorServices from sqlmodel import col, select from sqlmodel.ext.asyncio.session import AsyncSession @pytest.mark.asyncio(loop_scope="session") +@pytest.mark.parametrize("proc_services", [0], indirect=True) async def test_process_local_file( monkeypatch, session: AsyncSession, @@ -47,3 +48,43 @@ async def _parse_file_mock( ) assert len(vecs) > 0 assert vecs[0].metadata_ is not None + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.parametrize("proc_services", [0], indirect=True) +async def test_process_sync_file( + monkeypatch, + session: AsyncSession, + proc_services: ProcessorServices, + sync_knowledge_file: KnowledgeDB, +): + async def _parse_file_mock( + qfile: QuivrFile, + **processor_kwargs: dict[str, Any], + ) -> list[Document]: + with open(qfile.path, "rb") as f: + return [Document(page_content=str(f.read()), metadata={})] + + km_processor = KnowledgeProcessor(proc_services) + monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock) + km_dto = await sync_knowledge_file.to_dto(get_children=False, get_parent=False) + await km_processor.process_knowledge(km_dto) + + # Check knowledge set to processed + assert km_dto.id + assert km_dto.brains + knowledge_service = km_processor.services.knowledge_service + km = await knowledge_service.get_knowledge(km_dto.id) + assert km.status == KnowledgeStatus.PROCESSED + assert km.brains[0].brain_id == km_dto.brains[0]["brain_id"] + + # Check 
vectors where added + vecs = list( + ( + await session.exec( + select(Vector).where(col(Vector.knowledge_id) == km_dto.id) + ) + ).all() + ) + assert len(vecs) > 0 + assert vecs[0].metadata_ is not None From 0a8faaf73adbc05564abf6a07198aaafe674cfeb Mon Sep 17 00:00:00 2001 From: aminediro Date: Wed, 25 Sep 2024 11:14:52 +0200 Subject: [PATCH 18/63] fixed tests --- .../modules/knowledge/tests/test_knowledge_controller.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py index 63ecd4dd8d24..a75b75b5aec3 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py @@ -88,7 +88,8 @@ async def test_service(): # app.dependency_overrides[get_async_session] = lambda: session async with AsyncClient( - transport=ASGITransport(app=app), base_url="http://test" + transport=ASGITransport(app=app), # type: ignore + base_url="http://test", ) as ac: yield ac app.dependency_overrides = {} From 0654860a4ef6134c8ee5fa5dcbf17f7433288e36 Mon Sep 17 00:00:00 2001 From: aminediro Date: Wed, 25 Sep 2024 11:22:15 +0200 Subject: [PATCH 19/63] link brain appends unique brains --- .../quivr_api/modules/knowledge/repository/knowledges.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py index 6893efe7bdaf..361bc1f475bf 100644 --- a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py +++ b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py @@ -97,7 +97,11 @@ async def link_knowledge_tree_brains( children = await self.get_knowledge_tree(knowledge.id) all_kms = [knowledge, *children] for k in all_kms: - for b in brains: + km_brains = 
{km_brain.brain_id for km_brain in k.brains} + for b in filter( + lambda b: b.brain_id not in km_brains, + brains, + ): k.brains.append(b) for k in all_kms: await self.session.merge(k) From 8fc821918492d67e2c265a72ef5de98532346e2f Mon Sep 17 00:00:00 2001 From: aminediro Date: Wed, 25 Sep 2024 12:13:46 +0200 Subject: [PATCH 20/63] added sync status --- .../sync/controller/azure_sync_routes.py | 10 ++- .../sync/controller/dropbox_sync_routes.py | 3 +- .../sync/controller/github_sync_routes.py | 9 ++- .../sync/controller/google_sync_routes.py | 3 +- .../sync/controller/notion_sync_routes.py | 3 +- .../api/quivr_api/modules/sync/dto/inputs.py | 10 ++- .../modules/sync/service/sync_service.py | 2 + ...l => 20240920153014_knowledge-folders.sql} | 0 ....sql => 20240920180003_knowledge-sync.sql} | 7 ++- .../worker/quivr_worker/process/processor.py | 62 +++++++++++-------- .../worker/tests/test_process_file_task.py | 41 ++++++++++++ 11 files changed, 112 insertions(+), 38 deletions(-) rename backend/supabase/migrations/{20240905153004_knowledge-folders copy.sql => 20240920153014_knowledge-folders.sql} (100%) rename backend/supabase/migrations/{20240918180003_knowledge-sync.sql => 20240920180003_knowledge-sync.sql} (75%) diff --git a/backend/api/quivr_api/modules/sync/controller/azure_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/azure_sync_routes.py index ad34adf38c4f..65d3224386d1 100644 --- a/backend/api/quivr_api/modules/sync/controller/azure_sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/azure_sync_routes.py @@ -8,7 +8,7 @@ from quivr_api.logger import get_logger from quivr_api.middlewares.auth import AuthBearer, get_current_user from quivr_api.modules.dependencies import get_service -from quivr_api.modules.sync.dto.inputs import SyncUpdateInput +from quivr_api.modules.sync.dto.inputs import SyncStatus, SyncUpdateInput from quivr_api.modules.sync.service.sync_service import SyncsService from quivr_api.modules.sync.utils.oauth2 
import parse_oauth2_state from quivr_api.modules.user.entity.user_identity import UserIdentity @@ -74,7 +74,9 @@ async def authorize_azure( # Azure needs additional data await syncs_service.update_sync( sync_id=state.sync_id, - sync_user_input=SyncUpdateInput(additional_data={"flow": flow}), + sync_user_input=SyncUpdateInput( + additional_data={"flow": flow}, status=SyncStatus.SYNCED + ), ) return {"authorization_url": flow["auth_uri"]} @@ -131,7 +133,9 @@ async def oauth2callback_azure( user_email = user_info.get("mail") or user_info.get("userPrincipalName") logger.info(f"Retrieved email for user: {state.user_id} - {user_email}") - sync_user_input = SyncUpdateInput(credentials=flow_data, state={}, email=user_email) + sync_user_input = SyncUpdateInput( + credentials=flow_data, state={}, email=user_email, status=SyncStatus.SYNCED + ) await syncs_service.update_sync(state.sync_id, sync_user_input) logger.info(f"Azure sync created successfully for user: {state.user_id}") return HTMLResponse(successfullConnectionPage) diff --git a/backend/api/quivr_api/modules/sync/controller/dropbox_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/dropbox_sync_routes.py index ce53dca1b060..568837cfe9af 100644 --- a/backend/api/quivr_api/modules/sync/controller/dropbox_sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/dropbox_sync_routes.py @@ -8,7 +8,7 @@ from quivr_api.logger import get_logger from quivr_api.middlewares.auth import AuthBearer, get_current_user from quivr_api.modules.dependencies import get_service -from quivr_api.modules.sync.dto.inputs import SyncUpdateInput +from quivr_api.modules.sync.dto.inputs import SyncStatus, SyncUpdateInput from quivr_api.modules.sync.service.sync_service import SyncsService from quivr_api.modules.sync.utils.oauth2 import parse_oauth2_state from quivr_api.modules.user.entity.user_identity import UserIdentity @@ -134,6 +134,7 @@ async def oauth2callback_dropbox( credentials=credentials, state={}, 
email=user_email, + status=SyncStatus.SYNCED, ) await syncs_service.update_sync(sync.id, sync_user_input) logger.info(f"DropBox sync created successfully for user: {state.user_id}") diff --git a/backend/api/quivr_api/modules/sync/controller/github_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/github_sync_routes.py index 44c05e60df59..7dbb866c17d5 100644 --- a/backend/api/quivr_api/modules/sync/controller/github_sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/github_sync_routes.py @@ -7,7 +7,7 @@ from quivr_api.logger import get_logger from quivr_api.middlewares.auth import AuthBearer, get_current_user from quivr_api.modules.dependencies import get_service -from quivr_api.modules.sync.dto.inputs import SyncUpdateInput +from quivr_api.modules.sync.dto.inputs import SyncStatus, SyncUpdateInput from quivr_api.modules.sync.service.sync_service import SyncsService from quivr_api.modules.sync.utils.oauth2 import parse_oauth2_state from quivr_api.modules.user.entity.user_identity import UserIdentity @@ -133,7 +133,12 @@ async def oauth2callback_github( logger.info(f"Retrieved email for user: {state.user_id} - {user_email}") - sync_user_input = SyncUpdateInput(credentials=result, state={}, email=user_email) + sync_user_input = SyncUpdateInput( + credentials=result, + state={}, + email=user_email, + status=SyncStatus.SYNCED, + ) # TODO: This an additional select query :( await syncs_service.update_sync(sync.id, sync_user_input) diff --git a/backend/api/quivr_api/modules/sync/controller/google_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/google_sync_routes.py index d2bb10b716e6..4d02de674d4c 100644 --- a/backend/api/quivr_api/modules/sync/controller/google_sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/google_sync_routes.py @@ -9,7 +9,7 @@ from quivr_api.logger import get_logger from quivr_api.middlewares.auth import AuthBearer, get_current_user from quivr_api.modules.dependencies import get_service 
-from quivr_api.modules.sync.dto.inputs import SyncUpdateInput +from quivr_api.modules.sync.dto.inputs import SyncStatus, SyncUpdateInput from quivr_api.modules.sync.service.sync_service import SyncsService from quivr_api.modules.sync.utils.oauth2 import parse_oauth2_state from quivr_api.modules.user.entity.user_identity import UserIdentity @@ -144,6 +144,7 @@ async def oauth2callback_google( credentials=json.loads(creds.to_json()), state={}, email=user_email, + status=SyncStatus.SYNCED, ) sync = await syncs_service.update_sync(sync.id, sync_user_input) logger.info(f"Google Drive sync created successfully for user: {state.user_id}") diff --git a/backend/api/quivr_api/modules/sync/controller/notion_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/notion_sync_routes.py index baf4188d9b0f..81b0be95abaf 100644 --- a/backend/api/quivr_api/modules/sync/controller/notion_sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/notion_sync_routes.py @@ -10,7 +10,7 @@ from quivr_api.logger import get_logger from quivr_api.middlewares.auth import AuthBearer, get_current_user from quivr_api.modules.dependencies import get_service -from quivr_api.modules.sync.dto.inputs import SyncUpdateInput +from quivr_api.modules.sync.dto.inputs import SyncStatus, SyncUpdateInput from quivr_api.modules.sync.service.sync_service import SyncsService from quivr_api.modules.sync.utils.oauth2 import parse_oauth2_state from quivr_api.modules.user.entity.user_identity import UserIdentity @@ -127,6 +127,7 @@ async def oauth2callback_notion( credentials=result, state={}, email=user_email, + status=SyncStatus.SYNCED, ) await syncs_service.update_sync(state.sync_id, sync_user_input) diff --git a/backend/api/quivr_api/modules/sync/dto/inputs.py b/backend/api/quivr_api/modules/sync/dto/inputs.py index c487ed25903e..4a3824f7c67a 100644 --- a/backend/api/quivr_api/modules/sync/dto/inputs.py +++ b/backend/api/quivr_api/modules/sync/dto/inputs.py @@ -1,8 +1,16 @@ +import enum from 
uuid import UUID from pydantic import BaseModel +class SyncStatus(str, enum.Enum): + SYNCED = "SYNCED" + SYNCING = "SYNCING" + ERROR = "ERROR" + REMOVED = "REMOVED" + + class SyncCreateInput(BaseModel): """ Input model for creating a new sync user. @@ -38,4 +46,4 @@ class SyncUpdateInput(BaseModel): credentials: dict | None = None state: dict | None = None email: str | None = None - status: str | None = None + status: SyncStatus diff --git a/backend/api/quivr_api/modules/sync/service/sync_service.py b/backend/api/quivr_api/modules/sync/service/sync_service.py index cee888b57add..18428fd29c99 100644 --- a/backend/api/quivr_api/modules/sync/service/sync_service.py +++ b/backend/api/quivr_api/modules/sync/service/sync_service.py @@ -7,6 +7,7 @@ from quivr_api.modules.dependencies import BaseService from quivr_api.modules.sync.dto.inputs import ( SyncCreateInput, + SyncStatus, SyncUpdateInput, ) from quivr_api.modules.sync.dto.outputs import SyncsOutput @@ -67,6 +68,7 @@ async def create_oauth2_state( credentials={}, state={"state": state}, additional_data=additional_data, + status=SyncStatus.SYNCING, ) sync = await self.create_sync_user(sync_user_input) return Oauth2State(sync_id=sync.id, **state_struct.model_dump()) diff --git a/backend/supabase/migrations/20240905153004_knowledge-folders copy.sql b/backend/supabase/migrations/20240920153014_knowledge-folders.sql similarity index 100% rename from backend/supabase/migrations/20240905153004_knowledge-folders copy.sql rename to backend/supabase/migrations/20240920153014_knowledge-folders.sql diff --git a/backend/supabase/migrations/20240918180003_knowledge-sync.sql b/backend/supabase/migrations/20240920180003_knowledge-sync.sql similarity index 75% rename from backend/supabase/migrations/20240918180003_knowledge-sync.sql rename to backend/supabase/migrations/20240920180003_knowledge-sync.sql index 2d7c71d0137e..f8897e23a7ec 100644 --- a/backend/supabase/migrations/20240918180003_knowledge-sync.sql +++ 
b/backend/supabase/migrations/20240920180003_knowledge-sync.sql @@ -1,15 +1,16 @@ -- Renamed syncs ALTER TABLE syncs_user RENAME TO syncs; --- Add column sync in knowledge +-- Add column foreign key sync in knowledge ALTER TABLE "public"."knowledge" ADD COLUMN "sync_id" INTEGER; ALTER TABLE "public"."knowledge" ADD CONSTRAINT "public_knowledge_sync_id_fkey" FOREIGN KEY (sync_id) REFERENCES syncs(id) ON DELETE CASCADE; +-- Add column for sync_file_ids ALTER TABLE "public"."knowledge" ADD COLUMN "sync_file_id" TEXT; -CREATE UNIQUE INDEX knowledge_sync_id_pkey ON public.knowledge USING btree (sync_id); -CREATE UNIQUE INDEX knowledge_sync_file_id_pkey ON public.knowledge USING btree (sync_file_id); +CREATE INDEX knowledge_sync_id_pkey ON public.knowledge USING btree (sync_id); +CREATE INDEX knowledge_sync_file_id_pkey ON public.knowledge USING btree (sync_file_id); -- Add columns syncs alter table "public"."syncs" add column "created_at" timestamp with time zone default now(); diff --git a/backend/worker/quivr_worker/process/processor.py b/backend/worker/quivr_worker/process/processor.py index f4e5f67f9565..099a968ebba9 100644 --- a/backend/worker/quivr_worker/process/processor.py +++ b/backend/worker/quivr_worker/process/processor.py @@ -121,26 +121,33 @@ async def fetch_sync_knowledge( return await asyncio.gather(*[map_knowledges_task, sync_files_task]) # type: ignore # noqa: F821 async def yield_processable_kms( - self, knowledge: KnowledgeDTO + self, knowledge_dto: KnowledgeDTO ) -> AsyncGenerator[Tuple[KnowledgeDB, QuivrFile] | None, None]: - if knowledge.source == KnowledgeSource.LOCAL: - async for to_process in self._yield_local(knowledge): + """Should only yield ready to process knowledges: + Knowledge ready to process: + - Is either Local or Sync + - Is in a status: PROCESSING | ERROR + - Has an associated QuivrFile that is parsable + """ + if knowledge_dto.source == KnowledgeSource.LOCAL: + async for to_process in self._yield_local(knowledge_dto): yield 
to_process - elif knowledge.source in ( + elif knowledge_dto.source in ( KnowledgeSource.AZURE, - KnowledgeSource.GITHUB, KnowledgeSource.GOOGLE, + KnowledgeSource.DROPBOX, + KnowledgeSource.GITHUB, KnowledgeSource.NOTION, ): - async for to_process in self._yield_syncs(knowledge): + async for to_process in self._yield_syncs(knowledge_dto): yield to_process - elif knowledge.source == KnowledgeSource.WEB: + elif knowledge_dto.source == KnowledgeSource.WEB: raise NotImplementedError else: logger.error( - f"received knowledge : {knowledge.id} with unknown source: {knowledge.source}" + f"received knowledge : {knowledge_dto.id} with unknown source: {knowledge_dto.source}" ) - raise ValueError("Unknown knowledge source : {knoledge.source}") + raise ValueError(f"Unknown knowledge source : {knowledge_dto.source}") async def _yield_local( self, knowledge: KnowledgeDTO @@ -218,6 +225,7 @@ async def _yield_syncs( if not sync_files: return + breakpoint() for sync_file in sync_files: existing_km = syncfile_to_knowledge.get(sync_file.id) if existing_km: @@ -228,7 +236,7 @@ async def _yield_syncs( await self.services.knowledge_service.repository.link_to_brain( existing_km, brain_id=brain.brain_id ) - # Don't reprocess already added syncs + # Don't reprocess already added syncs knowledges if existing_km.status == KnowledgeStatus.PROCESSED: continue else: @@ -249,7 +257,6 @@ async def _yield_syncs( status=KnowledgeStatus.PROCESSING, upload_file=None, ) - # TODO: cache maybe processed files or skip them ? 
file_data = await download_sync_file( sync_provider=sync_provider, file=sync_file, @@ -261,19 +268,22 @@ async def _yield_syncs( async def process_knowledge(self, knowledge_dto: KnowledgeDTO): async for knowledge_tuple in self.yield_processable_kms(knowledge_dto): - if knowledge_tuple is None: - continue - knowledge, qfile = knowledge_tuple - if not skip_process(knowledge): - chunks = await parse_qfile(qfile=qfile) - await store_chunks( - knowledge=knowledge, - chunks=chunks, - vector_service=self.services.vector_service, + try: + if knowledge_tuple is None: + continue + knowledge, qfile = knowledge_tuple + if not skip_process(knowledge): + chunks = await parse_qfile(qfile=qfile) + await store_chunks( + knowledge=knowledge, + chunks=chunks, + vector_service=self.services.vector_service, + ) + await self.services.knowledge_service.update_knowledge( + knowledge, + KnowledgeUpdate( + status=KnowledgeStatus.PROCESSED, file_sha1=knowledge.file_sha1 + ), ) - await self.services.knowledge_service.update_knowledge( - knowledge, - KnowledgeUpdate( - status=KnowledgeStatus.PROCESSED, file_sha1=knowledge.file_sha1 - ), - ) + except Exception as e: + logger.error(f"Error processing knowledge {knowledge_dto.id} : {e}") diff --git a/backend/worker/tests/test_process_file_task.py b/backend/worker/tests/test_process_file_task.py index 6ba8559d43e3..139917315b61 100644 --- a/backend/worker/tests/test_process_file_task.py +++ b/backend/worker/tests/test_process_file_task.py @@ -88,3 +88,44 @@ async def _parse_file_mock( ) assert len(vecs) > 0 assert vecs[0].metadata_ is not None + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.parametrize("proc_services", [2], indirect=True) +async def test_process_sync_folder( + monkeypatch, + session: AsyncSession, + proc_services: ProcessorServices, + sync_knowledge_folder: KnowledgeDB, +): + async def _parse_file_mock( + qfile: QuivrFile, + **processor_kwargs: dict[str, Any], + ) -> list[Document]: + with open(qfile.path, "rb") as 
f: + return [Document(page_content=str(f.read()), metadata={})] + + km_processor = KnowledgeProcessor(proc_services) + monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock) + km_dto = await sync_knowledge_folder.to_dto(get_children=False, get_parent=False) + await km_processor.process_knowledge(km_dto) + + # Check knowledge set to processed + assert km_dto.id + assert km_dto.brains + assert km_dto.brains[0] + knowledge_service = km_processor.services.knowledge_service + # FIXME (@AmineDiro): brain dto!! + kms = await knowledge_service.get_all_knowledge_in_brain( + km_dto.brains[0]["brain_id"] + ) + + for km in kms: + assert km.status == KnowledgeStatus.PROCESSED + assert km.brains[0]["brain_id"] + assert km.brains[0]["brain_id"] == km_dto.brains[0]["brain_id"] + + # Check vectors where added + vecs = list((await session.exec(select(Vector))).all()) + assert len(vecs) > 0 + assert vecs[0].metadata_ is not None From c9d8a504cce77b68d560e9dd6cd84b18fa5a95f3 Mon Sep 17 00:00:00 2001 From: aminediro Date: Wed, 25 Sep 2024 13:53:01 +0200 Subject: [PATCH 21/63] process file sync --- .../sync/tests/test_sync_controller.py | 10 ++-- .../worker/quivr_worker/process/processor.py | 9 ++- backend/worker/tests/conftest.py | 55 ++++++++++++++++++- .../worker/tests/test_process_file_task.py | 52 +++++++++++++++++- 4 files changed, 114 insertions(+), 12 deletions(-) diff --git a/backend/api/quivr_api/modules/sync/tests/test_sync_controller.py b/backend/api/quivr_api/modules/sync/tests/test_sync_controller.py index 38ad64be46e3..df4780ba0af0 100644 --- a/backend/api/quivr_api/modules/sync/tests/test_sync_controller.py +++ b/backend/api/quivr_api/modules/sync/tests/test_sync_controller.py @@ -2,7 +2,6 @@ from datetime import datetime from io import BytesIO from typing import Dict, List, Union -from uuid import uuid4 import pytest import pytest_asyncio @@ -29,12 +28,13 @@ from quivr_api.modules.sync.utils.sync import BaseSync from 
quivr_api.modules.user.entity.user_identity import User, UserIdentity +# TODO: move to top layer MAX_SYNC_FILES = 1000 N_GET_FILES = 2 -FOLDER_SYNC_FILE_IDS = [str(uuid4())[:8] for _ in range(MAX_SYNC_FILES)] +FOLDER_SYNC_FILE_IDS = [f"file-{str(idx)}" for idx in range(MAX_SYNC_FILES)] -class BaseFakeSync(BaseSync): +class FakeSync(BaseSync): name = "FakeProvider" lower_name = "google" datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ" @@ -79,7 +79,7 @@ def get_files( extension=".txt", web_view_link=f"test.com/{fid}", parent_id=folder_id, - is_folder=idx % 2 == 0, + is_folder=idx % 2 == 1, last_modified_at=datetime.now(), ) for idx, fid in enumerate(self.folder_sync_file_ids) @@ -174,7 +174,7 @@ def default_current_user() -> UserIdentity: async def _sync_service(): fake_provider: dict[SyncProvider, BaseSync] = { - provider: BaseFakeSync(n_get_files=N_GET_FILES) + provider: FakeSync(n_get_files=N_GET_FILES) for provider in list(SyncProvider) } repository = SyncsRepository(session) diff --git a/backend/worker/quivr_worker/process/processor.py b/backend/worker/quivr_worker/process/processor.py index 099a968ebba9..62773695bce9 100644 --- a/backend/worker/quivr_worker/process/processor.py +++ b/backend/worker/quivr_worker/process/processor.py @@ -225,14 +225,17 @@ async def _yield_syncs( if not sync_files: return - breakpoint() for sync_file in sync_files: existing_km = syncfile_to_knowledge.get(sync_file.id) - if existing_km: + if existing_km is not None: # SyncKnowledge already exists => # It's already processed in some other brain so just link it and move on if it is Processed # ELSE reprocess the file - for brain in parent_knowledge.brains: + km_brains = {km_brain.brain_id for km_brain in existing_km.brains} + for brain in filter( + lambda b: b.brain_id not in km_brains, + parent_knowledge.brains, + ): await self.services.knowledge_service.repository.link_to_brain( existing_km, brain_id=brain.brain_id ) diff --git a/backend/worker/tests/conftest.py 
b/backend/worker/tests/conftest.py index 0c8a35841b6a..cb02a65334ca 100644 --- a/backend/worker/tests/conftest.py +++ b/backend/worker/tests/conftest.py @@ -20,7 +20,7 @@ from quivr_api.modules.sync.entity.sync_models import Sync from quivr_api.modules.sync.repository.sync_repository import SyncsRepository from quivr_api.modules.sync.service.sync_service import SyncsService -from quivr_api.modules.sync.tests.test_sync_controller import BaseFakeSync +from quivr_api.modules.sync.tests.test_sync_controller import FakeSync from quivr_api.modules.sync.utils.sync import BaseSync from quivr_api.modules.user.entity.user_identity import User from quivr_api.modules.vector.repository.vectors_repository import VectorRepository @@ -111,7 +111,7 @@ async def proc_services(session: AsyncSession, request) -> ProcessorServices: knowledge_repository = KnowledgeRepository(session) knowledge_service = KnowledgeService(knowledge_repository, storage=storage) sync_provider_mapping: dict[SyncProvider, BaseSync] = { - provider: BaseFakeSync(provider_name=str(provider), n_get_files=request.param) + provider: FakeSync(provider_name=str(provider), n_get_files=request.param) for provider in list(SyncProvider) } sync_repository = SyncsRepository( @@ -243,6 +243,57 @@ async def sync_knowledge_folder( return km +@pytest_asyncio.fixture(scope="function") +async def sync_knowledge_folder_with_file_in_brain( + session: AsyncSession, + proc_services: ProcessorServices, + user: User, + brain_user: Brain, + sync: Sync, +) -> KnowledgeDB: + assert user.id + assert brain_user.brain_id + file = KnowledgeDB( + file_name="file", + extension=".txt", + status=KnowledgeStatus.PROCESSED, + source=SyncProvider.GOOGLE, + source_link="drive://test/file1", + file_size=10, + file_sha1="test", + user_id=user.id, + brains=[brain_user], + parent=None, + is_folder=False, + # NOTE: See FakeSync Implementation + sync_file_id="file-0", + sync=sync, + ) + + km = KnowledgeDB( + file_name="folder1", + extension=".txt", + 
status=KnowledgeStatus.PROCESSING, + source=SyncProvider.GOOGLE, + source_link="drive://test/folder1", + file_size=0, + file_sha1=None, + user_id=user.id, + brains=[brain_user], + parent=None, + is_folder=True, + sync_file_id="id1", + sync=sync, + ) + + session.add(file) + session.add(km) + await session.commit() + await session.refresh(km) + + return km + + @pytest.fixture def qfile_instance(tmp_path) -> QuivrFile: data = "This is some test data." diff --git a/backend/worker/tests/test_process_file_task.py b/backend/worker/tests/test_process_file_task.py index 139917315b61..c3ad23f2ebb9 100644 --- a/backend/worker/tests/test_process_file_task.py +++ b/backend/worker/tests/test_process_file_task.py @@ -91,7 +91,7 @@ async def _parse_file_mock( @pytest.mark.asyncio(loop_scope="session") -@pytest.mark.parametrize("proc_services", [2], indirect=True) +@pytest.mark.parametrize("proc_services", [4], indirect=True) async def test_process_sync_folder( monkeypatch, session: AsyncSession, @@ -120,6 +120,8 @@ async def _parse_file_mock( km_dto.brains[0]["brain_id"] ) + # NOTE : this knowledge + 2 remote sync files + assert len(kms) == 5 for km in kms: assert km.status == KnowledgeStatus.PROCESSED assert km.brains[0]["brain_id"] @@ -127,5 +129,51 @@ async def _parse_file_mock( # Check vectors where added vecs = list((await session.exec(select(Vector))).all()) - assert len(vecs) > 0 + # Fake sync return a folder half the time, we skip folders + assert len(vecs) >= 2 assert vecs[0].metadata_ is not None + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.parametrize("proc_services", [1], indirect=True) +async def test_process_sync_folder_with_file_in_brain( + monkeypatch, + session: AsyncSession, + proc_services: ProcessorServices, + sync_knowledge_folder_with_file_in_brain: KnowledgeDB, +): + async def _parse_file_mock( + qfile: QuivrFile, + **processor_kwargs: dict[str, Any], + ) -> list[Document]: + with open(qfile.path, "rb") as f: + return 
[Document(page_content=str(f.read()), metadata={})] + + km_processor = KnowledgeProcessor(proc_services) + monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock) + km_dto = await sync_knowledge_folder_with_file_in_brain.to_dto( + get_children=False, get_parent=False + ) + await km_processor.process_knowledge(km_dto) + + # Check knowledge set to processed + assert km_dto.id + assert km_dto.brains + assert km_dto.brains[0] + knowledge_service = km_processor.services.knowledge_service + # FIXME (@AmineDiro): brain dto!! + kms = await knowledge_service.get_all_knowledge_in_brain( + km_dto.brains[0]["brain_id"] + ) + + # NOTE : this knowledge + 2 remote sync files + assert len(kms) == 2 + for km in kms: + assert km.status == KnowledgeStatus.PROCESSED + assert len(km.brains) == 1, "File added to the same brain multiple times" + assert km.brains[0]["brain_id"] + assert km.brains[0]["brain_id"] == km_dto.brains[0]["brain_id"] + + # Check vectors + vecs = list((await session.exec(select(Vector))).all()) + assert len(vecs) == 0, "File reprocessed, or folder processed " From ae3a5ef41fb003fcbd7135fc9350fe9d1107bc3f Mon Sep 17 00:00:00 2001 From: aminediro Date: Wed, 25 Sep 2024 14:25:56 +0200 Subject: [PATCH 22/63] worker tests working --- backend/worker/tests/conftest.py | 22 ++- backend/worker/tests/test_process_file.py | 88 ++++++------ backend/worker/tests/test_process_url_task.py | 125 ------------------ 3 files changed, 57 insertions(+), 178 deletions(-) delete mode 100644 backend/worker/tests/test_process_url_task.py diff --git a/backend/worker/tests/conftest.py b/backend/worker/tests/conftest.py index cb02a65334ca..75cd24343999 100644 --- a/backend/worker/tests/conftest.py +++ b/backend/worker/tests/conftest.py @@ -1,5 +1,6 @@ import os from io import BytesIO +from pathlib import Path from uuid import uuid4 import pytest @@ -104,6 +105,8 @@ async def brain_user(session, user: User) -> Brain: # NOTE: param sets the number of sync file the 
provider returns @pytest_asyncio.fixture(scope="function") async def proc_services(session: AsyncSession, request) -> ProcessorServices: + n_get_files = getattr(request, "param", 0) + storage = FakeStorage() embedder = DeterministicFakeEmbedding(size=settings.embedding_dim) vector_repository = VectorRepository(session) @@ -111,7 +114,7 @@ async def proc_services(session: AsyncSession, request) -> ProcessorServices: knowledge_repository = KnowledgeRepository(session) knowledge_service = KnowledgeService(knowledge_repository, storage=storage) sync_provider_mapping: dict[SyncProvider, BaseSync] = { - provider: FakeSync(provider_name=str(provider), n_get_files=request.param) + provider: FakeSync(provider_name=str(provider), n_get_files=n_get_files) for provider in list(SyncProvider) } sync_repository = SyncsRepository( @@ -311,7 +314,7 @@ def qfile_instance(tmp_path) -> QuivrFile: @pytest.fixture -def audio_file(tmp_path) -> QuivrFile: +def audio_qfile(tmp_path) -> QuivrFile: data = os.urandom(128) temp_file = tmp_path / "data.mp4" temp_file.write_bytes(data) @@ -324,3 +327,18 @@ def audio_file(tmp_path) -> QuivrFile: path=temp_file.absolute(), file_size=len(data), ) + + +@pytest.fixture +def pdf_qfile(tmp_path) -> QuivrFile: + data = "This is some test data." 
+ temp_file = tmp_path / "data.txt" + temp_file.write_text(data) + return QuivrFile( + id=uuid4(), + file_extension=".pdf", + original_filename="sample.pdf", + file_sha1="124", + file_size=1000, + path=Path("./tests/sample.pdf"), + ) diff --git a/backend/worker/tests/test_process_file.py b/backend/worker/tests/test_process_file.py index ebe30702a29a..cd27813b04d7 100644 --- a/backend/worker/tests/test_process_file.py +++ b/backend/worker/tests/test_process_file.py @@ -1,30 +1,47 @@ -import datetime import os -from pathlib import Path from uuid import uuid4 import pytest -from quivr_api.modules.brain.entity.brain_entity import BrainEntity, BrainType -from quivr_core.files.file import FileExtension +from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB from quivr_worker.process.process_file import parse_qfile from quivr_worker.process.utils import build_qfile -def test_build_qfile(): +def test_build_qfile_fail(local_knowledge_file: KnowledgeDB): random_bytes = os.urandom(128) - brain_id = uuid4() - file_name = f"{brain_id}/test_file.txt" - knowledge_id = uuid4() + local_knowledge_file.file_sha1 = None + with pytest.raises(AssertionError): + with build_qfile(knowledge=local_knowledge_file, file_data=random_bytes) as _: + pass + local_knowledge_file.file_sha1 = "sha1" + local_knowledge_file.id = None + with pytest.raises(AssertionError): + with build_qfile(knowledge=local_knowledge_file, file_data=random_bytes) as _: + pass - with build_qfile(random_bytes, file_name) as file: + local_knowledge_file.id = uuid4() + local_knowledge_file.file_name = None + with pytest.raises(AssertionError): + with build_qfile(knowledge=local_knowledge_file, file_data=random_bytes) as _: + pass + + +def test_build_qfile(local_knowledge_file: KnowledgeDB): + random_bytes = os.urandom(128) + local_knowledge_file.file_sha1 = "sha1" + + with build_qfile(knowledge=local_knowledge_file, file_data=random_bytes) as file: + assert file.id == local_knowledge_file.id assert 
file.file_size == 128 - assert file.file_name == "test_file.txt" - assert file.id == knowledge_id - assert file.file_extension == FileExtension.txt + assert file.original_filename == local_knowledge_file.file_name + assert file.file_extension == local_knowledge_file.extension + if local_knowledge_file.metadata_: + assert local_knowledge_file.metadata_.items() <= file.metadata.items() + assert file.brain_id is None -@pytest.mark.asyncio -async def test_parse_audio(monkeypatch, audio_file): +@pytest.mark.asyncio(loop_scope="session") +async def test_parse_audio(monkeypatch, audio_qfile): from openai.resources.audio.transcriptions import Transcriptions from openai.types.audio.transcription import Transcription @@ -32,56 +49,25 @@ def transcribe(*args, **kwargs): return Transcription(text="audio data") monkeypatch.setattr(Transcriptions, "create", transcribe) - brain = BrainEntity( - brain_id=uuid4(), - name="test", - brain_type=BrainType.doc, - last_update=datetime.datetime.now(), - ) chunks = await parse_qfile( - file=audio_file, - brain=brain, + qfile=audio_qfile, ) assert len(chunks) > 0 assert chunks[0].page_content == "audio data" -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="session") async def test_parse_file(qfile_instance): - brain = BrainEntity( - brain_id=uuid4(), - name="test", - brain_type=BrainType.doc, - last_update=datetime.datetime.now(), - ) chunks = await parse_qfile( - file=qfile_instance, - brain=brain, + qfile=qfile_instance, ) assert len(chunks) > 0 -@pytest.mark.asyncio -async def test_parse_file_pdf(): - file_instance = File( - knowledge_id=uuid4(), - file_sha1="124", - file_extension=".pdf", - file_name="test", - original_file_name="test", - file_size=1000, - tmp_file_path=Path("./tests/sample.pdf"), - ) - brain = BrainEntity( - brain_id=uuid4(), - name="test", - brain_type=BrainType.doc, - last_update=datetime.datetime.now(), - ) +@pytest.mark.asyncio(loop_scope="session") +async def test_parse_file_pdf(pdf_qfile): chunks = await 
parse_qfile( - file=file_instance, - brain=brain, + qfile=pdf_qfile, ) - assert len(chunks[0].page_content) > 0 assert len(chunks) > 0 diff --git a/backend/worker/tests/test_process_url_task.py b/backend/worker/tests/test_process_url_task.py deleted file mode 100644 index a34501b52001..000000000000 --- a/backend/worker/tests/test_process_url_task.py +++ /dev/null @@ -1,125 +0,0 @@ -import asyncio -import os -from typing import List, Tuple -from uuid import uuid4 - -import pytest -import pytest_asyncio -import sqlalchemy -from quivr_api.celery_config import celery -from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType -from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB -from quivr_api.modules.user.entity.user_identity import User -from quivr_worker.parsers.crawler import URL, extract_from_url -from sqlalchemy.ext.asyncio import create_async_engine -from sqlmodel import select -from sqlmodel.ext.asyncio.session import AsyncSession - -pg_database_base_url = "postgres:postgres@localhost:54322/postgres" - -async_engine = create_async_engine( - "postgresql+asyncpg://" + pg_database_base_url, - echo=True if os.getenv("ORM_DEBUG") else False, - future=True, - pool_pre_ping=True, - pool_size=10, - pool_recycle=0.1, -) - - -TestData = Tuple[Brain, List[KnowledgeDB]] - - -@pytest_asyncio.fixture(scope="function") -async def session(): - print("\nSESSION_EVEN_LOOP", id(asyncio.get_event_loop())) - async with async_engine.connect() as conn: - trans = await conn.begin() - nested = await conn.begin_nested() - async_session = AsyncSession( - conn, - expire_on_commit=False, - autoflush=False, - autocommit=False, - ) - - @sqlalchemy.event.listens_for( - async_session.sync_session, "after_transaction_end" - ) - def end_savepoint(session, transaction): - nonlocal nested - if not nested.is_active: - nested = conn.sync_connection.begin_nested() - - yield async_session - await trans.rollback() - await async_session.close() - - 
-@pytest_asyncio.fixture() -async def test_data(session: AsyncSession) -> TestData: - user_1 = ( - await session.exec(select(User).where(User.email == "admin@quivr.app")) - ).one() - assert user_1.id - # Brain data - brain_1 = Brain( - name="test_brain", - description="this is a test brain", - brain_type=BrainType.integration, - ) - - knowledge_brain_1 = KnowledgeDB( - file_name="test_file_1", - extension="txt", - status="UPLOADED", - source="test_source", - source_link="test_source_link", - file_size=100, - file_sha1="test_sha1", - brains=[brain_1], - user_id=user_1.id, - ) - - knowledge_brain_2 = KnowledgeDB( - file_name="test_file_2", - extension="txt", - status="UPLOADED", - source="test_source", - source_link="test_source_link", - file_size=100, - file_sha1="test_sha2", - brains=[], - user_id=user_1.id, - ) - - session.add(brain_1) - session.add(knowledge_brain_1) - session.add(knowledge_brain_2) - await session.commit() - return brain_1, [knowledge_brain_1, knowledge_brain_2] - - -@pytest.mark.skip -def test_crawl(): - url = "https://en.wikipedia.org/wiki/Python_(programming_language)" - crawl_website = URL(url=url) - extracted_content = extract_from_url(crawl_website) - - assert len(extracted_content) > 1 - - -@pytest.mark.skip -def test_process_crawl_task(test_data: TestData): - brain, [knowledge, _] = test_data - url = "https://en.wikipedia.org/wiki/Python_(programming_language)" - task = celery.send_task( - "process_crawl_task", - kwargs={ - "crawl_website_url": url, - "brain_id": brain.brain_id, - "knowledge_id": knowledge.id, - "notification_id": uuid4(), - }, - ) - result = task.wait() # noqa: F841 From da160ae4da7b549e4596d2ca9a5dc05ffe3718f9 Mon Sep 17 00:00:00 2001 From: aminediro Date: Wed, 25 Sep 2024 16:08:32 +0200 Subject: [PATCH 23/63] small refacto --- .../knowledge/controller/knowledge_routes.py | 3 - .../knowledge/service/knowledge_service.py | 1 + .../knowledge/tests/integration_test.py | 50 ++ .../quivr_api/modules/sync/utils/syncutils.py 
| 694 +++++++++--------- .../quivr_worker/assistants/assistants.py | 36 +- backend/worker/quivr_worker/celery_worker.py | 197 ++--- .../worker/quivr_worker/process/__init__.py | 11 + .../worker/quivr_worker/process/processor.py | 46 +- backend/worker/tests/conftest.py | 7 +- .../worker/tests/test_process_file_task.py | 109 ++- 10 files changed, 605 insertions(+), 549 deletions(-) create mode 100644 backend/api/quivr_api/modules/knowledge/tests/integration_test.py diff --git a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py index 5d690d1e8c01..aa089eebae00 100644 --- a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py +++ b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py @@ -314,10 +314,7 @@ async def _send_knowledge_process(knowledge: KnowledgeDB): "process_file_task", kwargs={ "knowledge_id": knowledge.id, - "file_name": knowledge.file_name, "notification_id": upload_notification.id, - "source": knowledge.source, - "source_link": knowledge.source_link, }, ) knowledge = await knowledge_service.update_knowledge( diff --git a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py index 6ed5d125ee1f..d54e375022bd 100644 --- a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py @@ -175,6 +175,7 @@ async def insert_knowledge_brain( user_id: UUID, knowledge_to_add: CreateKnowledgeProperties, # FIXME: (later) @Amine brain id should not be in CreateKnowledgeProperties but since storage is brain_id/file_name ) -> KnowledgeDTO: + # TODO: check input knowledge = KnowledgeDB( file_name=knowledge_to_add.file_name, url=knowledge_to_add.url, diff --git a/backend/api/quivr_api/modules/knowledge/tests/integration_test.py 
b/backend/api/quivr_api/modules/knowledge/tests/integration_test.py new file mode 100644 index 000000000000..023378621d21 --- /dev/null +++ b/backend/api/quivr_api/modules/knowledge/tests/integration_test.py @@ -0,0 +1,50 @@ +import asyncio +import json +from uuid import UUID, uuid4 + +from httpx import AsyncClient + +from quivr_api.modules.knowledge.dto.inputs import LinkKnowledgeBrain +from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO + + +async def main(): + url = "http://localhost:5050" + km_data = { + "file_name": "test_file.txt", + "source": "local", + "is_folder": False, + "parent_id": None, + } + + multipart_data = { + "knowledge_data": (None, json.dumps(km_data), "application/json"), + "file": ("test_file.txt", b"Test file content", "application/octet-stream"), + } + + async with AsyncClient( + base_url=url, headers={"Authorization": "Bearer 123"} + ) as test_client: + response = await test_client.post( + "/knowledge/", + files=multipart_data, + ) + response.raise_for_status() + km = KnowledgeDTO.model_validate(response.json()) + + json_data = LinkKnowledgeBrain( + bulk_id=uuid4(), + brain_ids=[UUID("40ba47d7-51b2-4b2a-9247-89e29619efb0")], + knowledge=km, + ).model_dump_json() + response = await test_client.post( + "/knowledge/link_to_brains/", + content=json_data, + headers={"Content-Type": "application/json"}, + ) + response.raise_for_status() + print(response.json()) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/backend/api/quivr_api/modules/sync/utils/syncutils.py b/backend/api/quivr_api/modules/sync/utils/syncutils.py index 8ffdeaa3c0cd..1edaaa26dfc4 100644 --- a/backend/api/quivr_api/modules/sync/utils/syncutils.py +++ b/backend/api/quivr_api/modules/sync/utils/syncutils.py @@ -1,30 +1,12 @@ -import io -import os -from datetime import datetime, timezone -from typing import Any, List, Tuple -from uuid import UUID, uuid4 +from typing import List, Tuple +from uuid import UUID from quivr_api.celery_config import 
celery from quivr_api.logger import get_logger -from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors -from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService -from quivr_api.modules.notification.dto.inputs import ( - CreateNotification, - NotificationUpdatableProperties, -) -from quivr_api.modules.notification.entity.notification import NotificationsStatusEnum -from quivr_api.modules.notification.service.notification_service import ( - NotificationService, -) +from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB from quivr_api.modules.sync.entity.sync_models import ( - DownloadedSyncFile, SyncFile, ) -from quivr_api.modules.sync.utils.sync import BaseSync -from quivr_api.modules.upload.service.upload_file import ( - check_file_exists, - upload_file_storage, -) logger = get_logger(__name__) @@ -84,338 +66,338 @@ async def fetch_sync_knowledge( # return should_download -class SyncUtils: - def __init__( - self, - # sync_user_service: ISyncUserService, - # sync_active_service: ISyncService, - # sync_files_repo: SyncFileInterface, - sync_cloud: BaseSync, - knowledge_service: KnowledgeService, - notification_service: NotificationService, - brain_vectors: BrainsVectors, - ) -> None: - self.sync_user_service = sync_user_service - self.sync_active_service = sync_active_service - self.sync_files_repo = sync_files_repo - self.knowledge_service = knowledge_service - self.sync_cloud = sync_cloud - self.notification_service = notification_service - self.brain_vectors = brain_vectors - - # TODO: This modifies the file, we should treat it as such - def create_sync_bulk_notification( - self, files: list[SyncFile], current_user: UUID, brain_id: UUID, bulk_id: UUID - ) -> list[SyncFile]: - res = [] - # TODO: bulk insert in batch - for file in files: - upload_notification = self.notification_service.add_notification( - CreateNotification( - user_id=current_user, - bulk_id=bulk_id, - status=NotificationsStatusEnum.INFO, 
- title=file.name, - category="sync", - brain_id=str(brain_id), - ) - ) - file.notification_id = upload_notification.id - res.append(file) - return res - - async def download_file( - self, file: SyncFile, credentials: dict[str, Any] - ) -> DownloadedSyncFile: - logger.info(f"Downloading {file} using {self.sync_cloud}") - file_response = await self.sync_cloud.adownload_file(credentials, file) - logger.debug(f"Fetch sync file response: {file_response}") - file_name = str(file_response["file_name"]) - raw_data = file_response["content"] - file_data = ( - io.BufferedReader(raw_data) # type: ignore - if isinstance(raw_data, io.BytesIO) - else io.BufferedReader(raw_data.encode("utf-8")) # type: ignore - ) - extension = os.path.splitext(file_name)[-1].lower() - dfile = DownloadedSyncFile( - file_name=file_name, - file_data=file_data, - extension=extension, - ) - logger.debug(f"Successfully downloaded sync file : {dfile}") - return dfile - - # TODO: REDO THIS MESS !!!! - # REMOVE ALL SYNC TABLES and start from scratch - - async def process_sync_file( - self, - file: SyncFile, - previous_file: DBSyncFile | None, - current_user: SyncsUser, - sync_active: SyncsActive, - ): - logger.info("Processing file: %s", file.name) - brain_id = sync_active.brain_id - source, source_link = self.sync_cloud.name, file.web_view_link - downloaded_file = await self.download_file(file, current_user.credentials) - storage_path = f"{brain_id}/{downloaded_file.file_name}" - exists_in_storage = check_file_exists(str(brain_id), file.name) - - if downloaded_file.extension not in [ - ".pdf", - ".txt", - ".md", - ".csv", - ".docx", - ".xlsx", - ".pptx", - ".doc", - ]: - raise ValueError(f"Incompatible file extension for {downloaded_file}") - - response = await upload_file_storage( - downloaded_file.file_data, - storage_path, - upsert=exists_in_storage, - ) - assert response, f"Error uploading {downloaded_file} to {storage_path}" - self.notification_service.update_notification_by_id( - 
file.notification_id, - NotificationUpdatableProperties( - status=NotificationsStatusEnum.SUCCESS, - description="File downloaded successfully", - ), - ) - # TODO : why knowledge + syncfile, drop syncfile ... - # FIXME : Simplify this logic in KMS plzzz - sync_file_db = self.sync_files_repo.update_or_create_sync_file( - file=file, - previous_file=previous_file, - sync_active=sync_active, - supported=True, - ) - knowledge = await self.knowledge_service.update_or_create_knowledge_sync( - brain_id=brain_id, - file=file, - new_sync_file=sync_file_db, - prev_sync_file=previous_file, - downloaded_file=downloaded_file, - source=source, - source_link=source_link, - user_id=current_user.user_id, - ) - - # Send file for processing - celery.send_task( - "process_file_task", - kwargs={ - "brain_id": brain_id, - "knowledge_id": knowledge.id, - "file_name": storage_path, - "file_original_name": file.name, - "source": source, - "source_link": source_link, - "notification_id": file.notification_id, - }, - ) - return file - - async def process_sync_files( - self, - files: List[SyncFile], - current_user: SyncsUser, - sync_active: SyncsActive, - ): - logger.info(f"Processing {len(files)} for sync_active: {sync_active.id}") - current_user.credentials = self.sync_cloud.check_and_refresh_access_token( - current_user.credentials - ) - - bulk_id = uuid4() - downloaded_files = [] - list_existing_files = self.sync_files_repo.get_sync_files(sync_active.id) - existing_files = {f.path: f for f in list_existing_files} - - supported_files = filter_on_supported_files(files, existing_files) - - files = self.create_sync_bulk_notification( - files, current_user.user_id, sync_active.brain_id, bulk_id - ) - - for file, prev_file in supported_files: - try: - result = await self.process_sync_file( - file=file, - previous_file=prev_file, - current_user=current_user, - sync_active=sync_active, - ) - if result is not None: - downloaded_files.append(result) - - 
self.notification_service.update_notification_by_id( - file.notification_id, - NotificationUpdatableProperties( - status=NotificationsStatusEnum.SUCCESS, - description="File downloaded successfully", - ), - ) - - except Exception as e: - logger.error( - "An error occurred while syncing %s files: %s", - self.sync_cloud.name, - e, - ) - # TODO: this process_sync_file could fail for a LOT of reason redo this logic - # File isn't supported so we set it as so ? - self.sync_files_repo.update_or_create_sync_file( - file=file, - sync_active=sync_active, - previous_file=prev_file, - supported=False, - ) - self.notification_service.update_notification_by_id( - file.notification_id, - NotificationUpdatableProperties( - status=NotificationsStatusEnum.ERROR, - description="Error downloading file", - ), - ) - - return {"downloaded_files": downloaded_files} - - async def get_files_to_download( - self, sync_active: SyncsActive, user_sync: SyncsUser - ) -> list[SyncFile]: - # Get the folder id from the settings from sync_active - folders = sync_active.settings.get("folders", []) - files_ids = sync_active.settings.get("files", []) - - files = await self.get_syncfiles_from_ids( - user_sync.credentials, - files_ids=files_ids, - folder_ids=folders, - sync_user_id=user_sync.id, - ) - - logger.debug(f"original files to download for {sync_active.id} : {files}") - - last_synced_time = ( - datetime.fromisoformat(sync_active.last_synced).astimezone(timezone.utc) - if sync_active.last_synced - else None - ) - - files_ids = [ - file - for file in files - if should_download_file( - file=file, - last_updated_sync_active=last_synced_time, - provider_name=self.sync_cloud.lower_name, - datetime_format=self.sync_cloud.datetime_format, - ) - ] - - logger.debug(f"filter files to download for {sync_active} : {files_ids}") - return files_ids - - async def get_syncfiles_from_ids( - self, - credentials: dict[str, Any], - files_ids: list[str], - folder_ids: list[str], - sync_user_id: int, - ) -> 
list[SyncFile]: - files = [] - if self.sync_cloud.lower_name == "notion": - files_ids += folder_ids - - for folder_id in folder_ids: - logger.debug( - f"Recursively getting file_ids from {self.sync_cloud.name}. folder_id={folder_id}" - ) - files.extend( - await self.sync_cloud.aget_files( - credentials=credentials, - sync_user_id=sync_user_id, - folder_id=folder_id, - recursive=True, - ) - ) - if len(files_ids) > 0: - files.extend( - await self.sync_cloud.aget_files_by_id( - credentials=credentials, - file_ids=files_ids, - ) - ) - return files - - async def direct_sync( - self, - sync_active: SyncsActive, - sync_user: SyncsUser, - files_ids: list[str], - folder_ids: list[str], - ): - files = await self.get_syncfiles_from_ids( - sync_user.credentials, files_ids, folder_ids - ) - processed_files = await self.process_sync_files( - files=files, - current_user=sync_user, - sync_active=sync_active, - ) - - # Update the last_synced timestamp - self.sync_active_service.update_sync_active( - sync_active.id, - SyncsActiveUpdateInput( - last_synced=datetime.now().astimezone().isoformat(), force_sync=False - ), - ) - logger.info( - f"{self.sync_cloud.lower_name} sync completed for sync_active: {sync_active.id}. Synced all {len(processed_files)} files.", - ) - return processed_files - - async def sync( - self, - sync_active: SyncsActive, - user_sync: SyncsUser, - ): - """ - Check if the Specific sync has not been synced and download the folders and files based on the settings. - - Args: - sync_active_id (int): The ID of the active sync. - user_id (str): The user ID associated with the active sync. 
- """ - logger.info( - "Starting %s sync for sync_active: %s", - self.sync_cloud.lower_name, - sync_active, - ) - - files_to_download = await self.get_files_to_download(sync_active, user_sync) - processed_files = await self.process_sync_files( - files=files_to_download, - current_user=user_sync, - sync_active=sync_active, - ) - - # Update the last_synced timestamp - self.sync_active_service.update_sync_active( - sync_active.id, - SyncsActiveUpdateInput( - last_synced=datetime.now().astimezone().isoformat(), force_sync=False - ), - ) - logger.info( - f"{self.sync_cloud.lower_name} sync completed for sync_active: {sync_active.id}. Synced all {len(processed_files)} files.", - ) - return processed_files +# class SyncUtils: +# def __init__( +# self, +# # sync_user_service: ISyncUserService, +# # sync_active_service: ISyncService, +# # sync_files_repo: SyncFileInterface, +# sync_cloud: BaseSync, +# knowledge_service: KnowledgeService, +# notification_service: NotificationService, +# brain_vectors: BrainsVectors, +# ) -> None: +# self.sync_user_service = sync_user_service +# self.sync_active_service = sync_active_service +# self.sync_files_repo = sync_files_repo +# self.knowledge_service = knowledge_service +# self.sync_cloud = sync_cloud +# self.notification_service = notification_service +# self.brain_vectors = brain_vectors + +# # TODO: This modifies the file, we should treat it as such +# def create_sync_bulk_notification( +# self, files: list[SyncFile], current_user: UUID, brain_id: UUID, bulk_id: UUID +# ) -> list[SyncFile]: +# res = [] +# # TODO: bulk insert in batch +# for file in files: +# upload_notification = self.notification_service.add_notification( +# CreateNotification( +# user_id=current_user, +# bulk_id=bulk_id, +# status=NotificationsStatusEnum.INFO, +# title=file.name, +# category="sync", +# brain_id=str(brain_id), +# ) +# ) +# file.notification_id = upload_notification.id +# res.append(file) +# return res + +# async def download_file( +# self, file: 
SyncFile, credentials: dict[str, Any] +# ) -> DownloadedSyncFile: +# logger.info(f"Downloading {file} using {self.sync_cloud}") +# file_response = await self.sync_cloud.adownload_file(credentials, file) +# logger.debug(f"Fetch sync file response: {file_response}") +# file_name = str(file_response["file_name"]) +# raw_data = file_response["content"] +# file_data = ( +# io.BufferedReader(raw_data) # type: ignore +# if isinstance(raw_data, io.BytesIO) +# else io.BufferedReader(raw_data.encode("utf-8")) # type: ignore +# ) +# extension = os.path.splitext(file_name)[-1].lower() +# dfile = DownloadedSyncFile( +# file_name=file_name, +# file_data=file_data, +# extension=extension, +# ) +# logger.debug(f"Successfully downloaded sync file : {dfile}") +# return dfile + +# # TODO: REDO THIS MESS !!!! +# # REMOVE ALL SYNC TABLES and start from scratch + +# async def process_sync_file( +# self, +# file: SyncFile, +# previous_file: DBSyncFile | None, +# current_user: SyncsUser, +# sync_active: SyncsActive, +# ): +# logger.info("Processing file: %s", file.name) +# brain_id = sync_active.brain_id +# source, source_link = self.sync_cloud.name, file.web_view_link +# downloaded_file = await self.download_file(file, current_user.credentials) +# storage_path = f"{brain_id}/{downloaded_file.file_name}" +# exists_in_storage = check_file_exists(str(brain_id), file.name) + +# if downloaded_file.extension not in [ +# ".pdf", +# ".txt", +# ".md", +# ".csv", +# ".docx", +# ".xlsx", +# ".pptx", +# ".doc", +# ]: +# raise ValueError(f"Incompatible file extension for {downloaded_file}") + +# response = await upload_file_storage( +# downloaded_file.file_data, +# storage_path, +# upsert=exists_in_storage, +# ) +# assert response, f"Error uploading {downloaded_file} to {storage_path}" +# self.notification_service.update_notification_by_id( +# file.notification_id, +# NotificationUpdatableProperties( +# status=NotificationsStatusEnum.SUCCESS, +# description="File downloaded successfully", +# ), +# ) 
+# # TODO : why knowledge + syncfile, drop syncfile ... +# # FIXME : Simplify this logic in KMS plzzz +# sync_file_db = self.sync_files_repo.update_or_create_sync_file( +# file=file, +# previous_file=previous_file, +# sync_active=sync_active, +# supported=True, +# ) +# knowledge = await self.knowledge_service.update_or_create_knowledge_sync( +# brain_id=brain_id, +# file=file, +# new_sync_file=sync_file_db, +# prev_sync_file=previous_file, +# downloaded_file=downloaded_file, +# source=source, +# source_link=source_link, +# user_id=current_user.user_id, +# ) + +# # Send file for processing +# celery.send_task( +# "process_file_task", +# kwargs={ +# "brain_id": brain_id, +# "knowledge_id": knowledge.id, +# "file_name": storage_path, +# "file_original_name": file.name, +# "source": source, +# "source_link": source_link, +# "notification_id": file.notification_id, +# }, +# ) +# return file + +# async def process_sync_files( +# self, +# files: List[SyncFile], +# current_user: SyncsUser, +# sync_active: SyncsActive, +# ): +# logger.info(f"Processing {len(files)} for sync_active: {sync_active.id}") +# current_user.credentials = self.sync_cloud.check_and_refresh_access_token( +# current_user.credentials +# ) + +# bulk_id = uuid4() +# downloaded_files = [] +# list_existing_files = self.sync_files_repo.get_sync_files(sync_active.id) +# existing_files = {f.path: f for f in list_existing_files} + +# supported_files = filter_on_supported_files(files, existing_files) + +# files = self.create_sync_bulk_notification( +# files, current_user.user_id, sync_active.brain_id, bulk_id +# ) + +# for file, prev_file in supported_files: +# try: +# result = await self.process_sync_file( +# file=file, +# previous_file=prev_file, +# current_user=current_user, +# sync_active=sync_active, +# ) +# if result is not None: +# downloaded_files.append(result) + +# self.notification_service.update_notification_by_id( +# file.notification_id, +# NotificationUpdatableProperties( +# 
status=NotificationsStatusEnum.SUCCESS, +# description="File downloaded successfully", +# ), +# ) + +# except Exception as e: +# logger.error( +# "An error occurred while syncing %s files: %s", +# self.sync_cloud.name, +# e, +# ) +# # TODO: this process_sync_file could fail for a LOT of reason redo this logic +# # File isn't supported so we set it as so ? +# self.sync_files_repo.update_or_create_sync_file( +# file=file, +# sync_active=sync_active, +# previous_file=prev_file, +# supported=False, +# ) +# self.notification_service.update_notification_by_id( +# file.notification_id, +# NotificationUpdatableProperties( +# status=NotificationsStatusEnum.ERROR, +# description="Error downloading file", +# ), +# ) + +# return {"downloaded_files": downloaded_files} + +# async def get_files_to_download( +# self, sync_active: SyncsActive, user_sync: SyncsUser +# ) -> list[SyncFile]: +# # Get the folder id from the settings from sync_active +# folders = sync_active.settings.get("folders", []) +# files_ids = sync_active.settings.get("files", []) + +# files = await self.get_syncfiles_from_ids( +# user_sync.credentials, +# files_ids=files_ids, +# folder_ids=folders, +# sync_user_id=user_sync.id, +# ) + +# logger.debug(f"original files to download for {sync_active.id} : {files}") + +# last_synced_time = ( +# datetime.fromisoformat(sync_active.last_synced).astimezone(timezone.utc) +# if sync_active.last_synced +# else None +# ) + +# files_ids = [ +# file +# for file in files +# if should_download_file( +# file=file, +# last_updated_sync_active=last_synced_time, +# provider_name=self.sync_cloud.lower_name, +# datetime_format=self.sync_cloud.datetime_format, +# ) +# ] + +# logger.debug(f"filter files to download for {sync_active} : {files_ids}") +# return files_ids + +# async def get_syncfiles_from_ids( +# self, +# credentials: dict[str, Any], +# files_ids: list[str], +# folder_ids: list[str], +# sync_user_id: int, +# ) -> list[SyncFile]: +# files = [] +# if self.sync_cloud.lower_name 
== "notion": +# files_ids += folder_ids + +# for folder_id in folder_ids: +# logger.debug( +# f"Recursively getting file_ids from {self.sync_cloud.name}. folder_id={folder_id}" +# ) +# files.extend( +# await self.sync_cloud.aget_files( +# credentials=credentials, +# sync_user_id=sync_user_id, +# folder_id=folder_id, +# recursive=True, +# ) +# ) +# if len(files_ids) > 0: +# files.extend( +# await self.sync_cloud.aget_files_by_id( +# credentials=credentials, +# file_ids=files_ids, +# ) +# ) +# return files + +# async def direct_sync( +# self, +# sync_active: SyncsActive, +# sync_user: SyncsUser, +# files_ids: list[str], +# folder_ids: list[str], +# ): +# files = await self.get_syncfiles_from_ids( +# sync_user.credentials, files_ids, folder_ids +# ) +# processed_files = await self.process_sync_files( +# files=files, +# current_user=sync_user, +# sync_active=sync_active, +# ) + +# # Update the last_synced timestamp +# self.sync_active_service.update_sync_active( +# sync_active.id, +# SyncsActiveUpdateInput( +# last_synced=datetime.now().astimezone().isoformat(), force_sync=False +# ), +# ) +# logger.info( +# f"{self.sync_cloud.lower_name} sync completed for sync_active: {sync_active.id}. Synced all {len(processed_files)} files.", +# ) +# return processed_files + +# async def sync( +# self, +# sync_active: SyncsActive, +# user_sync: SyncsUser, +# ): +# """ +# Check if the Specific sync has not been synced and download the folders and files based on the settings. + +# Args: +# sync_active_id (int): The ID of the active sync. +# user_id (str): The user ID associated with the active sync. 
+# """ +# logger.info( +# "Starting %s sync for sync_active: %s", +# self.sync_cloud.lower_name, +# sync_active, +# ) + +# files_to_download = await self.get_files_to_download(sync_active, user_sync) +# processed_files = await self.process_sync_files( +# files=files_to_download, +# current_user=user_sync, +# sync_active=sync_active, +# ) + +# # Update the last_synced timestamp +# self.sync_active_service.update_sync_active( +# sync_active.id, +# SyncsActiveUpdateInput( +# last_synced=datetime.now().astimezone().isoformat(), force_sync=False +# ), +# ) +# logger.info( +# f"{self.sync_cloud.lower_name} sync completed for sync_active: {sync_active.id}. Synced all {len(processed_files)} files.", +# ) +# return processed_files diff --git a/backend/worker/quivr_worker/assistants/assistants.py b/backend/worker/quivr_worker/assistants/assistants.py index b44f7273ebbb..7310384b0a85 100644 --- a/backend/worker/quivr_worker/assistants/assistants.py +++ b/backend/worker/quivr_worker/assistants/assistants.py @@ -1,13 +1,43 @@ import os +from quivr_api.modules.assistant.repository.tasks import TasksRepository from quivr_api.modules.assistant.services.tasks_service import TasksService from quivr_api.modules.upload.service.upload_file import ( upload_file_storage, ) +from sqlalchemy.ext.asyncio import AsyncEngine +from quivr_worker.process.processor import _start_session from quivr_worker.utils.pdf_generator.pdf_generator import PDFGenerator, PDFModel +async def aprocess_assistant_task( + engine: AsyncEngine, + assistant_id: str, + notification_uuid: str, + task_id: int, + user_id: str, +): + async with _start_session(engine) as async_session: + try: + tasks_repository = TasksRepository(async_session) + tasks_service = TasksService(tasks_repository) + + await process_assistant( + assistant_id, + notification_uuid, + task_id, + tasks_service, + user_id, + ) + + except Exception as e: + await async_session.rollback() + raise e + finally: + await async_session.close() + + async def 
process_assistant( assistant_id: str, notification_uuid: str, @@ -16,10 +46,8 @@ async def process_assistant( user_id: str, ): task = await tasks_service.get_task_by_id(task_id, user_id) # type: ignore - - await tasks_service.update_task(task_id, {"status": "in_progress"}) - - print(task) + assert task.id + await tasks_service.update_task(task.id, {"status": "in_progress"}) task_result = {"status": "completed", "answer": "#### Assistant answer"} diff --git a/backend/worker/quivr_worker/celery_worker.py b/backend/worker/quivr_worker/celery_worker.py index ae37a8dcf153..aa4c0c8d3366 100644 --- a/backend/worker/quivr_worker/celery_worker.py +++ b/backend/worker/quivr_worker/celery_worker.py @@ -7,28 +7,14 @@ from quivr_api.celery_config import celery from quivr_api.logger import get_logger from quivr_api.models.settings import settings -from quivr_api.modules.assistant.repository.tasks import TasksRepository -from quivr_api.modules.assistant.services.tasks_service import TasksService from quivr_api.modules.brain.integrations.Notion.Notion_connector import NotionConnector -from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors -from quivr_api.modules.brain.service.brain_service import BrainService from quivr_api.modules.dependencies import get_supabase_client -from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO -from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage -from quivr_api.modules.notification.service.notification_service import ( - NotificationService, -) -from quivr_api.modules.sync.service.sync_notion import SyncNotionService from quivr_api.utils.telemetry import maybe_send_telemetry from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine -from quivr_worker.assistants.assistants import process_assistant +from quivr_worker.assistants.assistants import aprocess_assistant_task from quivr_worker.check_premium import check_is_premium -from quivr_worker.process.processor import 
KnowledgeProcessor, build_processor_services -from quivr_worker.syncs.process_active_syncs import ( - process_notion_sync, -) -from quivr_worker.syncs.store_notion import fetch_and_store_notion_files_async +from quivr_worker.process import aprocess_file_task from quivr_worker.utils.utils import _patch_json load_dotenv() @@ -37,16 +23,7 @@ logger = get_logger("celery_worker") _patch_json() - -# FIXME: load at init time -# Services supabase_client = get_supabase_client() -# document_vector_store = get_documents_vector_store() -notification_service = NotificationService() -brain_service = BrainService() -brain_vectors = BrainsVectors() -storage = SupabaseS3Storage() -notion_service: SyncNotionService | None = None async_engine: AsyncEngine | None = None @@ -56,11 +33,15 @@ def init_worker(**kwargs): if not async_engine: async_engine = create_async_engine( settings.pg_database_async_url, + connect_args={ + "server_settings": {"application_name": f"quivr-worker-{os.getpid()}"} + }, echo=True if os.getenv("ORM_DEBUG") else False, future=True, - # NOTE: pessimistic bound on + # NOTE: pessimistic bound on reconnect pool_pre_ping=True, - pool_size=10, # NOTE: no bouncer for now, if 6 process workers => 6 + # NOTE: no bouncer for now + pool_size=1, pool_recycle=1800, ) @@ -68,88 +49,54 @@ def init_worker(**kwargs): @celery.task( retries=3, default_retry_delay=1, - name="process_assistant_task", + name="process_file_task", autoretry_for=(Exception,), + dont_autoretry_for=(FileExistsError,), ) -def process_assistant_task( - assistant_id: str, - notification_uuid: str, - task_id: int, - user_id: str, +def process_file_task( + knowledge_id: UUID, + notification_id: UUID | None = None, ): + if async_engine is None: + init_worker() + assert async_engine logger.info( - f"process_assistant_task started for assistant_id={assistant_id}, notification_uuid={notification_uuid}, task_id={task_id}" + f"Task process_file started for knowledge_id={knowledge_id}, 
notification_id={notification_id}" ) - print("process_assistant_task") - loop = asyncio.get_event_loop() loop.run_until_complete( - aprocess_assistant_task( - assistant_id, - notification_uuid, - task_id, - user_id, - ) + aprocess_file_task(async_engine=async_engine, knowledge_id=knowledge_id) ) -async def aprocess_assistant_task( - assistant_id: str, - notification_uuid: str, - task_id: int, - user_id: str, -): - async with AsyncSession(async_engine) as async_session: - try: - await async_session.execute( - text("SET SESSION idle_in_transaction_session_timeout = '5min';") - ) - tasks_repository = TasksRepository(async_session) - tasks_service = TasksService(tasks_repository) - - await process_assistant( - assistant_id, - notification_uuid, - task_id, - tasks_service, - user_id, - ) - - except Exception as e: - await async_session.rollback() - raise e - finally: - await async_session.close() - - @celery.task( retries=3, default_retry_delay=1, - name="process_file_task", + name="process_assistant_task", autoretry_for=(Exception,), - dont_autoretry_for=(FileExistsError,), ) -def process_file_task( - knowledge_dto: KnowledgeDTO, - notification_id: UUID | None = None, +def process_assistant_task( + assistant_id: str, + notification_uuid: str, + task_id: int, + user_id: str, ): if async_engine is None: init_worker() - + assert async_engine logger.info( - f"Task process_file started for knowledge_id={knowledge_dto.id}, notification_id={notification_id}" + f"process_assistant_task started for assistant_id={assistant_id}, notification_uuid={notification_uuid}, task_id={task_id}" ) - loop = asyncio.get_event_loop() - loop.run_until_complete(aprocess_file_task(knowledge_dto)) - - -async def aprocess_file_task(knowledge_dto: KnowledgeDTO): - global async_engine - assert async_engine - async with build_processor_services(async_engine) as processor_services: - km_processor = KnowledgeProcessor(processor_services) - await km_processor.process_knowledge(knowledge_dto) + 
loop.run_until_complete( + aprocess_assistant_task( + async_engine, + assistant_id, + notification_uuid, + task_id, + user_id, + ) + ) @celery.task(name="NotionConnectorLoad") @@ -175,43 +122,43 @@ def check_is_premium_task(): check_is_premium(supabase_client) -@celery.task(name="process_notion_sync_task") -def process_notion_sync_task(): - global async_engine - assert async_engine - loop = asyncio.get_event_loop() - loop.run_until_complete(process_notion_sync(async_engine)) - - -@celery.task(name="fetch_and_store_notion_files_task") -def fetch_and_store_notion_files_task( - access_token: str, user_id: UUID, sync_user_id: int -): - if async_engine is None: - init_worker() - assert async_engine - try: - logger.debug("Fetching and storing Notion files") - loop = asyncio.get_event_loop() - loop.run_until_complete( - fetch_and_store_notion_files_async( - async_engine, access_token, user_id, sync_user_id - ) - ) - sync_user_service.update_sync_user_status( - sync_user_id=sync_user_id, status=str(SyncsUserStatus.SYNCED) - ) - except Exception: - logger.error("Error fetching and storing Notion files") - sync_user_service.update_sync_user_status( - sync_user_id=sync_user_id, status=str(SyncsUserStatus.ERROR) - ) - - -@celery.task(name="clean_notion_user_syncs") -def clean_notion_user_syncs(): - logger.debug("Cleaning Notion user syncs") - sync_user_service.clean_notion_user_syncs() +# @celery.task(name="process_notion_sync_task") +# def process_notion_sync_task(): +# global async_engine +# assert async_engine +# loop = asyncio.get_event_loop() +# loop.run_until_complete(process_notion_sync(async_engine)) + + +# @celery.task(name="fetch_and_store_notion_files_task") +# def fetch_and_store_notion_files_task( +# access_token: str, user_id: UUID, sync_user_id: int +# ): +# if async_engine is None: +# init_worker() +# assert async_engine +# try: +# logger.debug("Fetching and storing Notion files") +# loop = asyncio.get_event_loop() +# loop.run_until_complete( +# 
fetch_and_store_notion_files_async( +# async_engine, access_token, user_id, sync_user_id +# ) +# ) +# sync_user_service.update_sync_user_status( +# sync_user_id=sync_user_id, status=str(SyncStatus.SYNCED) +# ) +# except Exception: +# logger.error("Error fetching and storing Notion files") +# sync_user_service.update_sync_user_status( +# sync_user_id=sync_user_id, status=str(SyncStatus.ERROR) +# ) + + +# @celery.task(name="clean_notion_user_syncs") +# def clean_notion_user_syncs(): +# logger.debug("Cleaning Notion user syncs") +# sync_user_service.clean_notion_user_syncs() # from celery.schedules import crontab diff --git a/backend/worker/quivr_worker/process/__init__.py b/backend/worker/quivr_worker/process/__init__.py index e69de29bb2d1..a8d86c84f78c 100644 --- a/backend/worker/quivr_worker/process/__init__.py +++ b/backend/worker/quivr_worker/process/__init__.py @@ -0,0 +1,11 @@ +from uuid import UUID + +from sqlalchemy.ext.asyncio import AsyncEngine + +from quivr_worker.process.processor import KnowledgeProcessor, build_processor_services + + +async def aprocess_file_task(async_engine: AsyncEngine, knowledge_id: UUID): + async with build_processor_services(async_engine) as processor_services: + km_processor = KnowledgeProcessor(services=processor_services) + await km_processor.process_knowledge(knowledge_id) diff --git a/backend/worker/quivr_worker/process/processor.py b/backend/worker/quivr_worker/process/processor.py index 62773695bce9..f77647b8e6d5 100644 --- a/backend/worker/quivr_worker/process/processor.py +++ b/backend/worker/quivr_worker/process/processor.py @@ -8,7 +8,6 @@ from quivr_api.logger import get_logger from quivr_api.modules.dependencies import get_supabase_async_client from quivr_api.modules.knowledge.dto.inputs import AddKnowledge, KnowledgeUpdate -from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB, KnowledgeSource from 
quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage @@ -64,7 +63,9 @@ async def _start_session(engine: AsyncEngine) -> AsyncGenerator[AsyncSession, No @asynccontextmanager -async def build_processor_services(engine: AsyncEngine): +async def build_processor_services( + engine: AsyncEngine, +) -> AsyncGenerator[ProcessorServices, None]: async_client = await get_supabase_async_client() storage = SupabaseS3Storage(async_client) try: @@ -121,7 +122,7 @@ async def fetch_sync_knowledge( return await asyncio.gather(*[map_knowledges_task, sync_files_task]) # type: ignore # noqa: F821 async def yield_processable_kms( - self, knowledge_dto: KnowledgeDTO + self, knowledge_id: UUID ) -> AsyncGenerator[Tuple[KnowledgeDB, QuivrFile] | None, None]: """Should only yield ready to process knowledges: Knowledge ready to process: @@ -129,35 +130,35 @@ async def yield_processable_kms( - Is in a status: PROCESSING | ERROR - Has an associated QuivrFile that is parsable """ - if knowledge_dto.source == KnowledgeSource.LOCAL: - async for to_process in self._yield_local(knowledge_dto): + knowledge = await self.services.knowledge_service.get_knowledge(knowledge_id) + if knowledge.source == KnowledgeSource.LOCAL: + async for to_process in self._yield_local(knowledge): yield to_process - elif knowledge_dto.source in ( + elif knowledge.source in ( KnowledgeSource.AZURE, KnowledgeSource.GOOGLE, KnowledgeSource.DROPBOX, KnowledgeSource.GITHUB, KnowledgeSource.NOTION, ): - async for to_process in self._yield_syncs(knowledge_dto): + async for to_process in self._yield_syncs(knowledge): yield to_process - elif knowledge_dto.source == KnowledgeSource.WEB: + elif knowledge.source == KnowledgeSource.WEB: raise NotImplementedError else: logger.error( - f"received knowledge : {knowledge_dto.id} with unknown source: {knowledge_dto.source}" + f"received knowledge : {knowledge.id} with unknown source: 
{knowledge.source}" ) - raise ValueError(f"Unknown knowledge source : {knowledge_dto.source}") + raise ValueError(f"Unknown knowledge source : {knowledge.source}") async def _yield_local( - self, knowledge: KnowledgeDTO + self, knowledge_db: KnowledgeDB ) -> AsyncGenerator[Tuple[KnowledgeDB, QuivrFile] | None, None]: - if knowledge.id is None or knowledge.file_name is None: - logger.error(f"received unprocessable local knowledge : {knowledge.id} ") + if knowledge_db.id is None or knowledge_db.file_name is None: + logger.error(f"received unprocessable local knowledge : {knowledge_db.id} ") raise ValueError( - f"received unprocessable local knowledge : {knowledge.id} " + f"received unprocessable local knowledge : {knowledge_db.id} " ) - knowledge_db = await self.services.knowledge_service.get_knowledge(knowledge.id) file_data = await self.services.knowledge_service.storage.download_file( knowledge_db ) @@ -166,15 +167,12 @@ async def _yield_local( yield (knowledge_db, qfile) async def _yield_syncs( - self, knowledge_dto: KnowledgeDTO + self, parent_knowledge: KnowledgeDB ) -> AsyncGenerator[Optional[Tuple[KnowledgeDB, QuivrFile]], None]: - if knowledge_dto.id is None: - logger.error(f"received unprocessable knowledge: {knowledge_dto.id} ") + if parent_knowledge.id is None: + logger.error(f"received unprocessable knowledge: {parent_knowledge.id} ") raise ValueError - parent_knowledge = await self.services.knowledge_service.get_knowledge( - knowledge_dto.id - ) if parent_knowledge.file_name is None: logger.error(f"received unprocessable knowledge : {parent_knowledge.id} ") raise ValueError( @@ -269,8 +267,8 @@ async def _yield_syncs( with build_qfile(file_knowledge, file_data) as qfile: yield (file_knowledge, qfile) - async def process_knowledge(self, knowledge_dto: KnowledgeDTO): - async for knowledge_tuple in self.yield_processable_kms(knowledge_dto): + async def process_knowledge(self, knowledge_id: UUID): + async for knowledge_tuple in 
self.yield_processable_kms(knowledge_id): try: if knowledge_tuple is None: continue @@ -289,4 +287,4 @@ async def process_knowledge(self, knowledge_dto: KnowledgeDTO): ), ) except Exception as e: - logger.error(f"Error processing knowledge {knowledge_dto.id} : {e}") + logger.error(f"Error processing knowledge {knowledge_id} : {e}") diff --git a/backend/worker/tests/conftest.py b/backend/worker/tests/conftest.py index 75cd24343999..8b339eb417af 100644 --- a/backend/worker/tests/conftest.py +++ b/backend/worker/tests/conftest.py @@ -28,7 +28,7 @@ from quivr_api.modules.vector.service.vector_service import VectorService from quivr_core.files.file import QuivrFile from quivr_core.models import KnowledgeStatus -from quivr_worker.process.processor import KnowledgeProcessor, ProcessorServices +from quivr_worker.process.processor import ProcessorServices from sqlalchemy.ext.asyncio import create_async_engine from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -130,11 +130,6 @@ async def proc_services(session: AsyncSession, request) -> ProcessorServices: ) -@pytest_asyncio.fixture(scope="function") -async def km_processor(proc_services: ProcessorServices): - return KnowledgeProcessor(proc_services) - - @pytest_asyncio.fixture(scope="function") async def sync(session: AsyncSession, user: User) -> Sync: assert user.id diff --git a/backend/worker/tests/test_process_file_task.py b/backend/worker/tests/test_process_file_task.py index c3ad23f2ebb9..4291bdb5db74 100644 --- a/backend/worker/tests/test_process_file_task.py +++ b/backend/worker/tests/test_process_file_task.py @@ -16,9 +16,11 @@ async def test_process_local_file( monkeypatch, session: AsyncSession, - km_processor: KnowledgeProcessor, + proc_services: ProcessorServices, local_knowledge_file: KnowledgeDB, ): + input_km = local_knowledge_file + async def _parse_file_mock( qfile: QuivrFile, **processor_kwargs: dict[str, Any], @@ -27,22 +29,22 @@ async def _parse_file_mock( return 
[Document(page_content=str(f.read()), metadata={})] monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock) - km_dto = await local_knowledge_file.to_dto(get_children=False, get_parent=False) - await km_processor.process_knowledge(km_dto) + assert input_km.id + assert input_km.brains + km_processor = KnowledgeProcessor(proc_services) + await km_processor.process_knowledge(input_km.id) # Check knowledge set to processed - assert km_dto.id - assert km_dto.brains knowledge_service = km_processor.services.knowledge_service - km = await knowledge_service.get_knowledge(km_dto.id) + km = await knowledge_service.get_knowledge(input_km.id) assert km.status == KnowledgeStatus.PROCESSED - assert km.brains[0].brain_id == km_dto.brains[0]["brain_id"] + assert km.brains[0].brain_id == input_km.brains[0].brain_id # Check vectors where added vecs = list( ( await session.exec( - select(Vector).where(col(Vector.knowledge_id) == km_dto.id) + select(Vector).where(col(Vector.knowledge_id) == input_km.id) ) ).all() ) @@ -58,6 +60,10 @@ async def test_process_sync_file( proc_services: ProcessorServices, sync_knowledge_file: KnowledgeDB, ): + input_km = sync_knowledge_file + assert input_km.id + assert input_km.brains + async def _parse_file_mock( qfile: QuivrFile, **processor_kwargs: dict[str, Any], @@ -67,22 +73,19 @@ async def _parse_file_mock( km_processor = KnowledgeProcessor(proc_services) monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock) - km_dto = await sync_knowledge_file.to_dto(get_children=False, get_parent=False) - await km_processor.process_knowledge(km_dto) + await km_processor.process_knowledge(input_km.id) # Check knowledge set to processed - assert km_dto.id - assert km_dto.brains knowledge_service = km_processor.services.knowledge_service - km = await knowledge_service.get_knowledge(km_dto.id) + km = await knowledge_service.get_knowledge(input_km.id) assert km.status == KnowledgeStatus.PROCESSED - assert 
km.brains[0].brain_id == km_dto.brains[0]["brain_id"] + assert km.brains[0].brain_id == input_km.brains[0].brain_id # Check vectors where added vecs = list( ( await session.exec( - select(Vector).where(col(Vector.knowledge_id) == km_dto.id) + select(Vector).where(col(Vector.knowledge_id) == input_km.id) ) ).all() ) @@ -98,6 +101,10 @@ async def test_process_sync_folder( proc_services: ProcessorServices, sync_knowledge_folder: KnowledgeDB, ): + input_km = sync_knowledge_folder + assert input_km.id + assert input_km.brains + async def _parse_file_mock( qfile: QuivrFile, **processor_kwargs: dict[str, Any], @@ -107,17 +114,16 @@ async def _parse_file_mock( km_processor = KnowledgeProcessor(proc_services) monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock) - km_dto = await sync_knowledge_folder.to_dto(get_children=False, get_parent=False) - await km_processor.process_knowledge(km_dto) + await km_processor.process_knowledge(input_km.id) # Check knowledge set to processed - assert km_dto.id - assert km_dto.brains - assert km_dto.brains[0] + assert input_km.id + assert input_km.brains + assert input_km.brains[0] knowledge_service = km_processor.services.knowledge_service # FIXME (@AmineDiro): brain dto!! 
kms = await knowledge_service.get_all_knowledge_in_brain( - km_dto.brains[0]["brain_id"] + input_km.brains[0].brain_id ) # NOTE : this knowledge + 2 remote sync files @@ -125,7 +131,7 @@ async def _parse_file_mock( for km in kms: assert km.status == KnowledgeStatus.PROCESSED assert km.brains[0]["brain_id"] - assert km.brains[0]["brain_id"] == km_dto.brains[0]["brain_id"] + assert km.brains[0]["brain_id"] == input_km.brains[0].brain_id # Check vectors where added vecs = list((await session.exec(select(Vector))).all()) @@ -142,6 +148,10 @@ async def test_process_sync_folder_with_file_in_brain( proc_services: ProcessorServices, sync_knowledge_folder_with_file_in_brain: KnowledgeDB, ): + input_km = sync_knowledge_folder_with_file_in_brain + assert input_km.id + assert input_km.brains + async def _parse_file_mock( qfile: QuivrFile, **processor_kwargs: dict[str, Any], @@ -151,19 +161,16 @@ async def _parse_file_mock( km_processor = KnowledgeProcessor(proc_services) monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock) - km_dto = await sync_knowledge_folder_with_file_in_brain.to_dto( - get_children=False, get_parent=False - ) - await km_processor.process_knowledge(km_dto) + await km_processor.process_knowledge(input_km.id) # Check knowledge set to processed - assert km_dto.id - assert km_dto.brains - assert km_dto.brains[0] + assert input_km.id + assert input_km.brains + assert input_km.brains[0] knowledge_service = km_processor.services.knowledge_service # FIXME (@AmineDiro): brain dto!! 
kms = await knowledge_service.get_all_knowledge_in_brain( - km_dto.brains[0]["brain_id"] + input_km.brains[0].brain_id ) # NOTE : this knowledge + 2 remote sync files @@ -172,8 +179,48 @@ async def _parse_file_mock( assert km.status == KnowledgeStatus.PROCESSED assert len(km.brains) == 1, "File added to the same brain multiple times" assert km.brains[0]["brain_id"] - assert km.brains[0]["brain_id"] == km_dto.brains[0]["brain_id"] + assert km.brains[0]["brain_id"] == input_km.brains[0].brain_id # Check vectors vecs = list((await session.exec(select(Vector))).all()) assert len(vecs) == 0, "File reprocessed, or folder processed " + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.parametrize("proc_services", [0], indirect=True) +async def test_process_km_rollback( + monkeypatch, + session: AsyncSession, + proc_services: ProcessorServices, + local_knowledge_file: KnowledgeDB, +): + input_km = local_knowledge_file + assert input_km.id + assert input_km.brains + + async def _parse_file_mock( + qfile: QuivrFile, + **processor_kwargs: dict[str, Any], + ) -> list[Document]: + with open(qfile.path, "rb") as f: + return [Document(page_content=str(f.read()), metadata={})] + + async def _update_km_error(*args, **kwargs): + raise Exception("Error") + + monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock) + + km_processor = KnowledgeProcessor(proc_services) + + # Set error at the end + km_processor.services.knowledge_service.update_knowledge = _update_km_error + + await km_processor.process_knowledge(input_km.id) + + # Check knowledge set to processed + knowledge_service = km_processor.services.knowledge_service + km = await knowledge_service.get_knowledge(input_km.id) + assert km.status == KnowledgeStatus.PROCESSING + vecs = list((await session.exec(select(Vector))).all()) + # Check we remove the vectors + assert len(vecs) == 0 From e318ad6f2af53047de3cdda95553c7fc0b56d915 Mon Sep 17 00:00:00 2001 From: aminediro Date: Wed, 25 Sep 2024 
17:05:11 +0200 Subject: [PATCH 24/63] rollback on error working --- .../modules/vector/repository/vectors_repository.py | 13 +++++++++---- .../modules/vector/service/vector_service.py | 4 ++-- backend/worker/quivr_worker/process/process_file.py | 4 +++- backend/worker/quivr_worker/process/processor.py | 5 +++++ backend/worker/tests/conftest.py | 7 +++++-- backend/worker/tests/test_process_file_task.py | 7 +++++-- 6 files changed, 29 insertions(+), 11 deletions(-) diff --git a/backend/api/quivr_api/modules/vector/repository/vectors_repository.py b/backend/api/quivr_api/modules/vector/repository/vectors_repository.py index 21906360b1e4..73cf281f11eb 100644 --- a/backend/api/quivr_api/modules/vector/repository/vectors_repository.py +++ b/backend/api/quivr_api/modules/vector/repository/vectors_repository.py @@ -16,12 +16,17 @@ def __init__(self, session: AsyncSession): super().__init__(session) self.session = session - async def create_vectors(self, new_vectors: List[Vector]) -> List[Vector]: + async def create_vectors( + self, new_vectors: List[Vector], autocommit: bool + ) -> List[Vector]: try: self.session.add_all(new_vectors) - await self.session.commit() - for vector in new_vectors: - await self.session.refresh(vector) + # FIXME: @AmineDiro : check if this is possible with nested transactions + if autocommit: + await self.session.commit() + for vector in new_vectors: + await self.session.refresh(vector) + await self.session.flush() return new_vectors except exc.IntegrityError: # Rollback the session if there’s an IntegrityError diff --git a/backend/api/quivr_api/modules/vector/service/vector_service.py b/backend/api/quivr_api/modules/vector/service/vector_service.py index d627837d7fad..2a0d577c44f1 100644 --- a/backend/api/quivr_api/modules/vector/service/vector_service.py +++ b/backend/api/quivr_api/modules/vector/service/vector_service.py @@ -25,7 +25,7 @@ def __init__( self.repository = repository async def create_vectors( - self, chunks: List[Document], 
knowledge_id: UUID + self, chunks: List[Document], knowledge_id: UUID, autocommit: bool = True ) -> List[UUID]: # Vector is created upon the user's first question asked logger.info( @@ -44,7 +44,7 @@ async def create_vectors( ) for i, chunk in enumerate(chunks) ] - created_vector = await self.repository.create_vectors(new_vectors) + created_vector = await self.repository.create_vectors(new_vectors, autocommit) return [vector.id for vector in created_vector if vector.id] diff --git a/backend/worker/quivr_worker/process/process_file.py b/backend/worker/quivr_worker/process/process_file.py index bff5e72401ba..01fb9513f632 100644 --- a/backend/worker/quivr_worker/process/process_file.py +++ b/backend/worker/quivr_worker/process/process_file.py @@ -30,7 +30,9 @@ async def store_chunks( vector_service: VectorService, ): assert knowledge.id - vector_ids = await vector_service.create_vectors(chunks, knowledge.id) + vector_ids = await vector_service.create_vectors( + chunks, knowledge.id, autocommit=False + ) logger.debug( f"Inserted {len(chunks)} chunks in vectors table for knowledge: {knowledge.id}" ) diff --git a/backend/worker/quivr_worker/process/processor.py b/backend/worker/quivr_worker/process/processor.py index f77647b8e6d5..7c40a4b7e6b4 100644 --- a/backend/worker/quivr_worker/process/processor.py +++ b/backend/worker/quivr_worker/process/processor.py @@ -269,6 +269,10 @@ async def _yield_syncs( async def process_knowledge(self, knowledge_id: UUID): async for knowledge_tuple in self.yield_processable_kms(knowledge_id): + savepoint = ( + await self.services.knowledge_service.repository.session.begin_nested() + ) + try: if knowledge_tuple is None: continue @@ -287,4 +291,5 @@ async def process_knowledge(self, knowledge_id: UUID): ), ) except Exception as e: + await savepoint.rollback() logger.error(f"Error processing knowledge {knowledge_id} : {e}") diff --git a/backend/worker/tests/conftest.py b/backend/worker/tests/conftest.py index 8b339eb417af..eb5616910569 
100644 --- a/backend/worker/tests/conftest.py +++ b/backend/worker/tests/conftest.py @@ -12,7 +12,7 @@ from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType from quivr_api.modules.brain.entity.brain_user import BrainUserDB from quivr_api.modules.dependencies import get_supabase_client -from quivr_api.modules.knowledge.dto.inputs import AddKnowledge +from quivr_api.modules.knowledge.dto.inputs import AddKnowledge, KnowledgeUpdate from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService @@ -166,11 +166,14 @@ async def local_knowledge_file( knowledge_to_add=km_to_add, upload_file=UploadFile(file=km_data, size=128, filename=km_to_add.file_name), ) - # Link it to the brain await service.link_knowledge_tree_brains( km, brains_ids=[brain_user.brain_id], user_id=user.id ) + km = await service.update_knowledge( + knowledge=km, + payload=KnowledgeUpdate(status=KnowledgeStatus.PROCESSING), + ) return km diff --git a/backend/worker/tests/test_process_file_task.py b/backend/worker/tests/test_process_file_task.py index 4291bdb5db74..e13d2daa4b7b 100644 --- a/backend/worker/tests/test_process_file_task.py +++ b/backend/worker/tests/test_process_file_task.py @@ -34,11 +34,12 @@ async def _parse_file_mock( km_processor = KnowledgeProcessor(proc_services) await km_processor.process_knowledge(input_km.id) - # Check knowledge set to processed + # Check knowledge processed knowledge_service = km_processor.services.knowledge_service km = await knowledge_service.get_knowledge(input_km.id) assert km.status == KnowledgeStatus.PROCESSED assert km.brains[0].brain_id == input_km.brains[0].brain_id + assert km.file_sha1 is not None # Check vectors where added vecs = list( @@ -80,6 +81,7 @@ async def _parse_file_mock( km = await knowledge_service.get_knowledge(input_km.id) assert km.status == 
KnowledgeStatus.PROCESSED assert km.brains[0].brain_id == input_km.brains[0].brain_id + assert km.file_sha1 is not None # Check vectors where added vecs = list( @@ -132,6 +134,7 @@ async def _parse_file_mock( assert km.status == KnowledgeStatus.PROCESSED assert km.brains[0]["brain_id"] assert km.brains[0]["brain_id"] == input_km.brains[0].brain_id + assert km.file_sha1 is not None # Check vectors where added vecs = list((await session.exec(select(Vector))).all()) @@ -220,7 +223,7 @@ async def _update_km_error(*args, **kwargs): # Check knowledge set to processed knowledge_service = km_processor.services.knowledge_service km = await knowledge_service.get_knowledge(input_km.id) - assert km.status == KnowledgeStatus.PROCESSING + assert km.status == KnowledgeStatus.PROCESSING # tests are just uploaded vecs = list((await session.exec(select(Vector))).all()) # Check we remove the vectors assert len(vecs) == 0 From 6c2ba279e359779cd71390a67eedf3a7c0194c6e Mon Sep 17 00:00:00 2001 From: aminediro Date: Wed, 25 Sep 2024 17:05:23 +0200 Subject: [PATCH 25/63] rollback on error working --- backend/worker/quivr_worker/process/processor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/backend/worker/quivr_worker/process/processor.py b/backend/worker/quivr_worker/process/processor.py index 7c40a4b7e6b4..bd5105e1658b 100644 --- a/backend/worker/quivr_worker/process/processor.py +++ b/backend/worker/quivr_worker/process/processor.py @@ -269,6 +269,7 @@ async def _yield_syncs( async def process_knowledge(self, knowledge_id: UUID): async for knowledge_tuple in self.yield_processable_kms(knowledge_id): + # FIXME savepoint = ( await self.services.knowledge_service.repository.session.begin_nested() ) From 5a88bf01fe6f94f51531bff5aa29399fd7b63e19 Mon Sep 17 00:00:00 2001 From: aminediro Date: Wed, 25 Sep 2024 17:50:37 +0200 Subject: [PATCH 26/63] knowledge sync --- .../modules/knowledge/entity/knowledge.py | 2 +- .../worker/quivr_worker/parsers/crawler.py | 1 - 
.../quivr_worker/process/process_url.py | 41 ----------------- .../worker/quivr_worker/process/processor.py | 35 +++++++++++--- backend/worker/quivr_worker/process/utils.py | 19 ++++++-- backend/worker/tests/conftest.py | 32 ++++++++++++- backend/worker/tests/test_process_file.py | 19 ++++++-- .../worker/tests/test_process_file_task.py | 46 +++++++++++++++++++ 8 files changed, 136 insertions(+), 59 deletions(-) delete mode 100644 backend/worker/quivr_worker/process/process_url.py diff --git a/backend/api/quivr_api/modules/knowledge/entity/knowledge.py b/backend/api/quivr_api/modules/knowledge/entity/knowledge.py index 32752ca28637..2174272dfa40 100644 --- a/backend/api/quivr_api/modules/knowledge/entity/knowledge.py +++ b/backend/api/quivr_api/modules/knowledge/entity/knowledge.py @@ -37,7 +37,7 @@ class KnowledgeDB(AsyncAttrs, SQLModel, table=True): primary_key=True, ), ) - file_name: str = Field(default="", max_length=255) + file_name: Optional[str] = Field(default=None, max_length=255) url: Optional[str] = Field(default=None, max_length=2048) extension: str = Field(default=".txt", max_length=100) status: str = Field(max_length=50) diff --git a/backend/worker/quivr_worker/parsers/crawler.py b/backend/worker/quivr_worker/parsers/crawler.py index b6bec671c231..d60f3b36de5b 100644 --- a/backend/worker/quivr_worker/parsers/crawler.py +++ b/backend/worker/quivr_worker/parsers/crawler.py @@ -20,7 +20,6 @@ class URL(BaseModel): async def extract_from_url(url: URL) -> str: # Extract and combine content recursively loader = PlaywrightURLLoader(urls=[url.url], remove_selectors=["header", "footer"]) - data = await loader.aload() # Now turn the data into a string logger.info(f"Extracted content from {len(data)} pages") diff --git a/backend/worker/quivr_worker/process/process_url.py b/backend/worker/quivr_worker/process/process_url.py deleted file mode 100644 index a5dabecd361d..000000000000 --- a/backend/worker/quivr_worker/process/process_url.py +++ /dev/null @@ -1,41 +0,0 
@@ -from uuid import UUID - -from quivr_api.logger import get_logger -from quivr_api.modules.brain.service.brain_service import BrainService -from quivr_api.modules.vector.service.vector_service import VectorService - -from quivr_worker.files import build_file -from quivr_worker.parsers.crawler import URL, extract_from_url, slugify -from quivr_worker.process.process_file import process_file - -logger = get_logger("celery_worker") - - -async def process_url_func( - url: str, - brain_id: UUID, - knowledge_id: UUID, - brain_service: BrainService, - vector_service: VectorService, -): - crawl_website = URL(url=url) - extracted_content = await extract_from_url(crawl_website) - extracted_content_bytes = extracted_content.encode("utf-8") - file_name = slugify(crawl_website.url) + ".txt" - - brain = brain_service.get_brain_by_id(brain_id) - if brain is None: - logger.error("It seems like you're uploading knowledge to an unknown brain.") - return 1 - - with build_file(extracted_content_bytes, knowledge_id, file_name) as file_instance: - # TODO(@StanGirard): fix bug - # NOTE (@aminediro): I think this might be related to knowledge delete timeouts ? 
- await process_file( - file_instance=file_instance, - brain=brain, - brain_service=brain_service, - integration=None, - integration_link=None, - vector_service=vector_service, - ) diff --git a/backend/worker/quivr_worker/process/processor.py b/backend/worker/quivr_worker/process/processor.py index bd5105e1658b..35caa6c5b999 100644 --- a/backend/worker/quivr_worker/process/processor.py +++ b/backend/worker/quivr_worker/process/processor.py @@ -27,6 +27,7 @@ from sqlmodel import text from sqlmodel.ext.asyncio.session import AsyncSession +from quivr_worker.parsers.crawler import URL, extract_from_url from quivr_worker.process.process_file import parse_qfile, store_chunks from quivr_worker.process.utils import ( build_qfile, @@ -144,7 +145,8 @@ async def yield_processable_kms( async for to_process in self._yield_syncs(knowledge): yield to_process elif knowledge.source == KnowledgeSource.WEB: - raise NotImplementedError + async for to_process in self._yield_web(knowledge): + yield to_process else: logger.error( f"received knowledge : {knowledge.id} with unknown source: {knowledge.source}" @@ -166,6 +168,22 @@ async def _yield_local( with build_qfile(knowledge_db, file_data) as qfile: yield (knowledge_db, qfile) + async def _yield_web( + self, knowledge_db: KnowledgeDB + ) -> AsyncGenerator[Tuple[KnowledgeDB, QuivrFile] | None, None]: + if knowledge_db.id is None or knowledge_db.url is None: + logger.error(f"received unprocessable web knowledge : {knowledge_db.id} ") + raise ValueError( + f"received unprocessable web knowledge : {knowledge_db.id} " + ) + crawl_website = URL(url=knowledge_db.url) + extracted_content = await extract_from_url(crawl_website) + extracted_content_bytes = extracted_content.encode("utf-8") + knowledge_db.file_sha1 = compute_sha1(extracted_content_bytes) + knowledge_db.file_size = len(extracted_content_bytes) + with build_qfile(knowledge_db, extracted_content_bytes) as qfile: + yield (knowledge_db, qfile) + async def _yield_syncs( self, 
parent_knowledge: KnowledgeDB ) -> AsyncGenerator[Optional[Tuple[KnowledgeDB, QuivrFile]], None]: @@ -192,8 +210,11 @@ async def _yield_syncs( # Get associated sync sync = await self.services.sync_service.get_sync_by_id(parent_knowledge.sync_id) if sync.credentials is None: - logger.error(f"can't process sync file. sync {sync.id} has no credentials") - return + logger.error( + f"can't process knowledge: {parent_knowledge.id}. sync {sync.id} has no credentials" + ) + raise ValueError("no associated credentials") + provider_name = SyncProvider(sync.provider.lower()) sync_provider = self.services.syncprovider_mapping[provider_name] @@ -228,7 +249,8 @@ async def _yield_syncs( if existing_km is not None: # SyncKnowledge already exists => # It's already processed in some other brain so just link it and move on if it is Processed - # ELSE reprocess the file + # ELSE + # reprocess the file km_brains = {km_brain.brain_id for km_brain in existing_km.brains} for brain in filter( lambda b: b.brain_id not in km_brains, @@ -264,6 +286,7 @@ async def _yield_syncs( credentials=sync.credentials, ) file_knowledge.file_sha1 = compute_sha1(file_data) + file_knowledge.file_size = len(file_data) with build_qfile(file_knowledge, file_data) as qfile: yield (file_knowledge, qfile) @@ -273,7 +296,6 @@ async def process_knowledge(self, knowledge_id: UUID): savepoint = ( await self.services.knowledge_service.repository.session.begin_nested() ) - try: if knowledge_tuple is None: continue @@ -288,7 +310,8 @@ async def process_knowledge(self, knowledge_id: UUID): await self.services.knowledge_service.update_knowledge( knowledge, KnowledgeUpdate( - status=KnowledgeStatus.PROCESSED, file_sha1=knowledge.file_sha1 + status=KnowledgeStatus.PROCESSED, + file_sha1=knowledge.file_sha1, ), ) except Exception as e: diff --git a/backend/worker/quivr_worker/process/utils.py b/backend/worker/quivr_worker/process/utils.py index 5a11375df358..e9c73362311a 100644 --- 
a/backend/worker/quivr_worker/process/utils.py +++ b/backend/worker/quivr_worker/process/utils.py @@ -20,6 +20,8 @@ ) from quivr_core.files.file import FileExtension, QuivrFile +from quivr_worker.parsers.crawler import slugify + celery_inspector = celery.control.inspect() logger = get_logger("celery_worker") @@ -79,18 +81,25 @@ def build_qfile( knowledge: KnowledgeDB, file_data: bytes ) -> Generator[QuivrFile, None, None]: assert knowledge.id - assert knowledge.file_name assert knowledge.file_sha1 + if knowledge.source == KnowledgeSource.WEB: + file_name = slugify(knowledge.url) + ".txt" + extension = FileExtension.txt + else: + assert knowledge.file_name + file_name = knowledge.file_name + extension = FileExtension(knowledge.extension) + with create_temp_file( - file_data=file_data, file_name_ext=knowledge.file_name + file_data=file_data, file_name_ext=file_name ) as tmp_file_path: qfile = QuivrFile( id=knowledge.id, - original_filename=knowledge.file_name, + original_filename=file_name, path=tmp_file_path, file_sha1=knowledge.file_sha1, - file_extension=FileExtension(knowledge.extension), - file_size=knowledge.file_size, + file_extension=extension, + file_size=len(file_data), metadata={ "date": time.strftime("%Y%m%d"), "file_name": knowledge.file_name, diff --git a/backend/worker/tests/conftest.py b/backend/worker/tests/conftest.py index eb5616910569..4bc7c9f8d90a 100644 --- a/backend/worker/tests/conftest.py +++ b/backend/worker/tests/conftest.py @@ -13,7 +13,7 @@ from quivr_api.modules.brain.entity.brain_user import BrainUserDB from quivr_api.modules.dependencies import get_supabase_client from quivr_api.modules.knowledge.dto.inputs import AddKnowledge, KnowledgeUpdate -from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB +from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB, KnowledgeSource from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository from quivr_api.modules.knowledge.service.knowledge_service 
import KnowledgeService from quivr_api.modules.knowledge.tests.conftest import FakeStorage @@ -295,6 +295,36 @@ async def sync_knowledge_folder_with_file_in_brain( return km +@pytest_asyncio.fixture(scope="function") +async def web_knowledge( + session: AsyncSession, + user: User, + brain_user: Brain, +) -> KnowledgeDB: + assert user.id + assert brain_user.brain_id + + km = KnowledgeDB( + file_name=None, + url="www.quivr.app", + extension=".html", + status=KnowledgeStatus.PROCESSING, + source=KnowledgeSource.WEB, + source_link="www.quivr.app", + file_size=0, + file_sha1=None, + user_id=user.id, + brains=[brain_user], + is_folder=False, + ) + + session.add(km) + await session.commit() + await session.refresh(km) + + return km + + @pytest.fixture def qfile_instance(tmp_path) -> QuivrFile: data = "This is some test data." diff --git a/backend/worker/tests/test_process_file.py b/backend/worker/tests/test_process_file.py index cd27813b04d7..7d57a5304c11 100644 --- a/backend/worker/tests/test_process_file.py +++ b/backend/worker/tests/test_process_file.py @@ -3,6 +3,7 @@ import pytest from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB +from quivr_worker.parsers.crawler import slugify from quivr_worker.process.process_file import parse_qfile from quivr_worker.process.utils import build_qfile @@ -20,10 +21,20 @@ def test_build_qfile_fail(local_knowledge_file: KnowledgeDB): pass local_knowledge_file.id = uuid4() - local_knowledge_file.file_name = None - with pytest.raises(AssertionError): - with build_qfile(knowledge=local_knowledge_file, file_data=random_bytes) as _: - pass + + +def test_build_qfile_web(web_knowledge: KnowledgeDB): + random_bytes = os.urandom(128) + web_knowledge.file_sha1 = "sha1" + + with build_qfile(knowledge=web_knowledge, file_data=random_bytes) as file: + assert file.id == web_knowledge.id + assert file.file_size == 128 + assert file.original_filename == slugify(web_knowledge.url) + ".txt" + assert file.file_extension == ".txt" + if 
web_knowledge.metadata_: + assert web_knowledge.metadata_.items() <= file.metadata.items() + assert file.brain_id is None def test_build_qfile(local_knowledge_file: KnowledgeDB): diff --git a/backend/worker/tests/test_process_file_task.py b/backend/worker/tests/test_process_file_task.py index e13d2daa4b7b..b1d73d156ad1 100644 --- a/backend/worker/tests/test_process_file_task.py +++ b/backend/worker/tests/test_process_file_task.py @@ -53,6 +53,52 @@ async def _parse_file_mock( assert vecs[0].metadata_ is not None +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.parametrize("proc_services", [0], indirect=True) +async def test_process_web_file( + monkeypatch, + session: AsyncSession, + proc_services: ProcessorServices, + web_knowledge: KnowledgeDB, +): + input_km = web_knowledge + + async def _extract_url(url: str) -> str: + return "quivr has the best rag" + + async def _parse_file_mock( + qfile: QuivrFile, + **processor_kwargs: dict[str, Any], + ) -> list[Document]: + with open(qfile.path, "rb") as f: + return [Document(page_content=str(f.read()), metadata={})] + + monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock) + monkeypatch.setattr("quivr_worker.process.processor.extract_from_url", _extract_url) + assert input_km.id + assert input_km.brains + km_processor = KnowledgeProcessor(proc_services) + await km_processor.process_knowledge(input_km.id) + + # Check knowledge processed + knowledge_service = km_processor.services.knowledge_service + km = await knowledge_service.get_knowledge(input_km.id) + assert km.status == KnowledgeStatus.PROCESSED + assert km.brains[0].brain_id == input_km.brains[0].brain_id + assert km.file_sha1 is not None + + # Check vectors where added + vecs = list( + ( + await session.exec( + select(Vector).where(col(Vector.knowledge_id) == input_km.id) + ) + ).all() + ) + assert len(vecs) > 0 + assert vecs[0].metadata_ is not None + + @pytest.mark.asyncio(loop_scope="session") 
@pytest.mark.parametrize("proc_services", [0], indirect=True) async def test_process_sync_file( From 1d5761d185e5370757bc3c397afd8dc0b0dbd910 Mon Sep 17 00:00:00 2001 From: aminediro Date: Thu, 26 Sep 2024 10:27:09 +0200 Subject: [PATCH 27/63] added test sync file in different brain --- .../knowledge/controller/knowledge_routes.py | 1 - .../knowledge/service/knowledge_service.py | 2 +- backend/worker/tests/conftest.py | 71 +++++++++++++++++++ .../worker/tests/test_process_file_task.py | 47 ++++++++++++ 4 files changed, 119 insertions(+), 2 deletions(-) diff --git a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py index aa089eebae00..20adfe94c047 100644 --- a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py +++ b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py @@ -265,7 +265,6 @@ async def delete_knowledge( ): try: km = await knowledge_service.get_knowledge(knowledge_id) - if km.user_id != current_user.id: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, diff --git a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py index d54e375022bd..6a680aad1d4b 100644 --- a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py @@ -245,8 +245,8 @@ async def remove_knowledge(self, knowledge: KnowledgeDB) -> DeleteKnowledgeRespo # recursively deletes files deleted_km = await self.repository.remove_knowledge(knowledge) + # TODO: remove storage asynchronously in background task or in some task await asyncio.gather(*[self.storage.remove_file(p) for p in km_paths]) - return deleted_km except Exception as e: logger.error(f"Error while remove knowledge : {e}") diff --git a/backend/worker/tests/conftest.py b/backend/worker/tests/conftest.py index 
4bc7c9f8d90a..53b3f6688cd4 100644 --- a/backend/worker/tests/conftest.py +++ b/backend/worker/tests/conftest.py @@ -102,6 +102,26 @@ async def brain_user(session, user: User) -> Brain: return brain_1 +@pytest_asyncio.fixture(scope="function") +async def brain_user2(session, user: User) -> Brain: + assert user.id + brain = Brain( + name="test_brain2", + description="this is a test brain", + brain_type=BrainType.integration, + ) + session.add(brain) + await session.commit() + await session.refresh(brain) + assert brain.brain_id + brain_user = BrainUserDB( + brain_id=brain.brain_id, user_id=user.id, default_brain=True, rights="Owner" + ) + session.add(brain_user) + await session.commit() + return brain + + # NOTE: param sets the number of sync file the provider returns @pytest_asyncio.fixture(scope="function") async def proc_services(session: AsyncSession, request) -> ProcessorServices: @@ -244,6 +264,57 @@ async def sync_knowledge_folder( return km +@pytest_asyncio.fixture(scope="function") +async def sync_knowledge_folder_with_file_in_other_brain( + session: AsyncSession, + user: User, + brain_user: Brain, + brain_user2: Brain, + sync: Sync, +) -> KnowledgeDB: + assert user.id + assert brain_user.brain_id + file = KnowledgeDB( + file_name="file", + extension=".txt", + status=KnowledgeStatus.PROCESSED, + source=SyncProvider.GOOGLE, + source_link="drive://test/file1", + file_size=10, + file_sha1="test", + user_id=user.id, + brains=[brain_user2], + parent=None, + is_folder=False, + # NOTE: See FakeSync Implementation + sync_file_id="file-0", + sync=sync, + ) + + km = KnowledgeDB( + file_name="folder1", + extension=".txt", + status=KnowledgeStatus.PROCESSING, + source=SyncProvider.GOOGLE, + source_link="drive://test/folder1", + file_size=0, + file_sha1=None, + user_id=user.id, + brains=[brain_user], + parent=None, + is_folder=True, + sync_file_id="id1", + sync=sync, + ) + + session.add(file) + session.add(km) + await session.commit() + await session.refresh(km) + + 
return km + + @pytest_asyncio.fixture(scope="function") async def sync_knowledge_folder_with_file_in_brain( session: AsyncSession, diff --git a/backend/worker/tests/test_process_file_task.py b/backend/worker/tests/test_process_file_task.py index b1d73d156ad1..25408a24fc6a 100644 --- a/backend/worker/tests/test_process_file_task.py +++ b/backend/worker/tests/test_process_file_task.py @@ -235,6 +235,53 @@ async def _parse_file_mock( assert len(vecs) == 0, "File reprocessed, or folder processed " +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.parametrize("proc_services", [1], indirect=True) +async def test_process_sync_folder_with_file_in_other_brain( + monkeypatch, + session: AsyncSession, + proc_services: ProcessorServices, + sync_knowledge_folder_with_file_in_other_brain: KnowledgeDB, +): + input_km = sync_knowledge_folder_with_file_in_other_brain + assert input_km.id + assert input_km.brains + + async def _parse_file_mock( + qfile: QuivrFile, + **processor_kwargs: dict[str, Any], + ) -> list[Document]: + with open(qfile.path, "rb") as f: + return [Document(page_content=str(f.read()), metadata={})] + + km_processor = KnowledgeProcessor(proc_services) + monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock) + await km_processor.process_knowledge(input_km.id) + + # Check knowledge set to processed + assert input_km.id + assert input_km.brains + assert input_km.brains[0] + knowledge_service = km_processor.services.knowledge_service + # FIXME (@AmineDiro): brain dto!! 
+ kms = await knowledge_service.get_all_knowledge_in_brain( + input_km.brains[0].brain_id + ) + + assert len(kms) == 2 + for km in kms: + assert km.status == KnowledgeStatus.PROCESSED + assert len(km.brains) >= 1, "File added to the same brain multiple times" + assert km.brains[0]["brain_id"] + assert input_km.brains[0].brain_id in {b["brain_id"] for b in km.brains} + if len(km.brains) > 1: + assert len({b["brain_id"] for b in km.brains}) == 2 + + # Check vectors + vecs = list((await session.exec(select(Vector))).all()) + assert len(vecs) == 0, "File reprocessed, or folder processed " + + @pytest.mark.asyncio(loop_scope="session") @pytest.mark.parametrize("proc_services", [0], indirect=True) async def test_process_km_rollback( From 1b3915c07fcae2e8d467e415a2620d5efe0c861d Mon Sep 17 00:00:00 2001 From: aminediro Date: Thu, 26 Sep 2024 11:26:05 +0200 Subject: [PATCH 28/63] unlink from brain --- .../knowledge/controller/knowledge_routes.py | 130 +++++++++++------- .../quivr_api/modules/knowledge/dto/inputs.py | 5 + .../knowledge/repository/knowledges.py | 42 ++++++ .../knowledge/service/knowledge_service.py | 9 ++ .../knowledge/tests/test_knowledge_service.py | 109 ++++++++++++++- .../quivr_api/modules/rag_service/utils.py | 2 +- backend/pyproject.toml | 15 +- .../worker/quivr_worker/process/processor.py | 9 +- 8 files changed, 260 insertions(+), 61 deletions(-) diff --git a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py index 20adfe94c047..53525c5b5921 100644 --- a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py +++ b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py @@ -19,9 +19,9 @@ AddKnowledge, KnowledgeUpdate, LinkKnowledgeBrain, + UnlinkKnowledgeBrain, ) from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO -from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB from 
quivr_api.modules.knowledge.service.knowledge_exceptions import ( KnowledgeDeleteError, KnowledgeForbiddenAccess, @@ -66,33 +66,6 @@ async def list_knowledge_in_brain_endpoint( return {"knowledges": knowledges} -@knowledge_router.delete( - "/knowledge/{knowledge_id}", - dependencies=[ - Depends(AuthBearer()), - Depends(has_brain_authorization(RoleEnum.Owner)), - ], - tags=["Knowledge"], -) -async def delete_knowledge_brain( - knowledge_id: UUID, - knowledge_service: KnowledgeService = Depends(get_knowledge_service), - current_user: UserIdentity = Depends(get_current_user), - brain_id: UUID = Query(..., description="The ID of the brain"), -): - """ - Delete a specific knowledge from a brain. - """ - - knowledge = await knowledge_service.get_knowledge(knowledge_id) - file_name = knowledge.file_name if knowledge.file_name else knowledge.url - await knowledge_service.remove_knowledge_brain(brain_id, knowledge_id) - - return { - "message": f"{file_name} of brain {brain_id} has been deleted by user {current_user.email}." - } - - @knowledge_router.get( "/knowledge/{knowledge_id}/signed_download_url", dependencies=[Depends(AuthBearer())], @@ -253,6 +226,33 @@ async def update_knowledge( raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) +@knowledge_router.delete( + "/knowledge/{knowledge_id}", + dependencies=[ + Depends(AuthBearer()), + Depends(has_brain_authorization(RoleEnum.Owner)), + ], + tags=["Knowledge"], +) +async def delete_knowledge_brain( + knowledge_id: UUID, + knowledge_service: KnowledgeService = Depends(get_knowledge_service), + current_user: UserIdentity = Depends(get_current_user), + brain_id: UUID = Query(..., description="The ID of the brain"), +): + """ + Delete a specific knowledge from a brain. 
+ """ + + knowledge = await knowledge_service.get_knowledge(knowledge_id) + file_name = knowledge.file_name if knowledge.file_name else knowledge.url + await knowledge_service.remove_knowledge_brain(brain_id, knowledge_id) + + return { + "message": f"{file_name} of brain {brain_id} has been deleted by user {current_user.email}." + } + + @knowledge_router.delete( "/knowledge/{knowledge_id}", status_code=status.HTTP_202_ACCEPTED, @@ -284,6 +284,7 @@ async def delete_knowledge( "/knowledge/link_to_brains/", status_code=status.HTTP_201_CREATED, response_model=List[KnowledgeDTO], + tags=["Knowledge"], ) async def link_knowledge_to_brain( link_request: LinkKnowledgeBrain, @@ -298,7 +299,29 @@ async def link_knowledge_to_brain( if len(brains_ids) == 0: return "empty brain list" - async def _send_knowledge_process(knowledge: KnowledgeDB): + if knowledge_dto.id is None: + if knowledge_dto.sync_file_id is None: + raise HTTPException( + status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Unknown knowledge entity" + ) + # Create a knowledge from this sync + knowledge = await knowledge_service.create_knowledge( + user_id=current_user.id, + knowledge_to_add=AddKnowledge(**knowledge_dto.model_dump()), + upload_file=None, + ) + linked_kms = await knowledge_service.link_knowledge_tree_brains( + knowledge, brains_ids=brains_ids, user_id=current_user.id + ) + + else: + linked_kms = await knowledge_service.link_knowledge_tree_brains( + knowledge_dto.id, brains_ids=brains_ids, user_id=current_user.id + ) + + for knowledge in filter( + lambda k: k.status != KnowledgeStatus.PROCESSED, linked_kms + ): assert knowledge.id upload_notification = notification_service.add_notification( CreateNotification( @@ -321,29 +344,36 @@ async def _send_knowledge_process(knowledge: KnowledgeDB): payload=KnowledgeUpdate(status=KnowledgeStatus.PROCESSING), ) - if knowledge_dto.id is None: - if knowledge_dto.sync_file_id is None: - raise HTTPException( - status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Unknown 
knowledge entity" - ) - # Create a knowledge from this sync - knowledge = await knowledge_service.create_knowledge( - user_id=current_user.id, - knowledge_to_add=AddKnowledge(**knowledge_dto.model_dump()), - upload_file=None, - ) - linked_kms = await knowledge_service.link_knowledge_tree_brains( - knowledge, brains_ids=brains_ids, user_id=current_user.id + return await asyncio.gather(*[k.to_dto() for k in linked_kms]) + + +@knowledge_router.delete( + "/knowledge/unlink_from_brains/", + response_model=List[KnowledgeDTO] | None, + tags=["Knowledge"], +) +async def unlink_knowledge_from_brain( + unlink_request: UnlinkKnowledgeBrain, + knowledge_service: KnowledgeService = Depends(get_knowledge_service), + current_user: UserIdentity = Depends(get_current_user), +): + brains_ids, knowledge_id = unlink_request.brain_ids, unlink_request.knowledge_id + + if len(brains_ids) == 0: + raise HTTPException( + status_code=status.HTTP_204_NO_CONTENT, ) - else: - linked_kms = await knowledge_service.link_knowledge_tree_brains( - knowledge_dto.id, brains_ids=brains_ids, user_id=current_user.id + km = await knowledge_service.get_knowledge(knowledge_id) + if km.user_id != current_user.id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You do not have permission to remove this knowledge.", ) - for knowledge in filter( - lambda k: k.status != KnowledgeStatus.PROCESSED, linked_kms - ): - await _send_knowledge_process(knowledge=knowledge) + unlinked_kms = await knowledge_service.unlink_knowledge_tree_brains( + knowledge=knowledge_id, brains_ids=brains_ids, user_id=current_user.id + ) - return await asyncio.gather(*[k.to_dto() for k in linked_kms]) + if unlinked_kms: + return await asyncio.gather(*[k.to_dto() for k in unlinked_kms]) diff --git a/backend/api/quivr_api/modules/knowledge/dto/inputs.py b/backend/api/quivr_api/modules/knowledge/dto/inputs.py index cc0ab8958389..0290e057ae0c 100644 --- a/backend/api/quivr_api/modules/knowledge/dto/inputs.py +++ 
b/backend/api/quivr_api/modules/knowledge/dto/inputs.py @@ -51,3 +51,8 @@ class LinkKnowledgeBrain(BaseModel): bulk_id: UUID knowledge: KnowledgeDTO brain_ids: List[UUID] + + +class UnlinkKnowledgeBrain(BaseModel): + knowledge_id: UUID + brain_ids: List[UUID] diff --git a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py index 361bc1f475bf..418656f8cdaf 100644 --- a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py +++ b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py @@ -70,6 +70,48 @@ async def update_knowledge( logger.error(f"Error updating knowledge {e}") raise KnowledgeUpdateError + async def unlink_knowledge_tree_brains( + self, knowledge: KnowledgeDB, brains_ids: List[UUID], user_id: UUID + ) -> list[KnowledgeDB] | None: + assert knowledge.id, "can't link knowledge not in db" + try: + # TODO: Move check somewhere else + stmt = ( + select(Brain) + .join(BrainUserDB, col(Brain.brain_id) == col(BrainUserDB.brain_id)) + .where( + and_( + col(Brain.brain_id).in_(brains_ids), + BrainUserDB.user_id == user_id, + BrainUserDB.rights == RoleEnum.Owner, + ) + ) + ) + unlink_brains = list((await self.session.exec(stmt)).unique().all()) + unlink_brain_ids = {b.brain_id for b in unlink_brains} + + if len(unlink_brains) == 0: + logger.info( + f"No brains for user_id={user_id}, brains_list={brains_ids}" + ) + return + children = await self.get_knowledge_tree(knowledge.id) + all_kms = [knowledge, *children] + for k in all_kms: + k.brains = list( + filter(lambda b: b.brain_id not in unlink_brain_ids, k.brains) + ) + [self.session.add(k) for k in all_kms] + await self.session.commit() + [await self.session.refresh(k) for k in all_kms] + return all_kms + except IntegrityError: + await self.session.rollback() + raise + except Exception: + await self.session.rollback() + raise + async def link_knowledge_tree_brains( self, knowledge: KnowledgeDB, brains_ids: 
List[UUID], user_id: UUID ) -> list[KnowledgeDB]: diff --git a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py index 6a680aad1d4b..3d260c6edc4e 100644 --- a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py @@ -293,3 +293,12 @@ async def link_knowledge_tree_brains( return await self.repository.link_knowledge_tree_brains( knowledge, brains_ids=brains_ids, user_id=user_id ) + + async def unlink_knowledge_tree_brains( + self, knowledge: KnowledgeDB | UUID, brains_ids: List[UUID], user_id: UUID + ) -> List[KnowledgeDB] | None: + if isinstance(knowledge, UUID): + knowledge = await self.repository.get_knowledge_by_id(knowledge) + return await self.repository.unlink_knowledge_tree_brains( + knowledge, brains_ids=brains_ids, user_id=user_id + ) diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py index 39c1f8759d20..c764c92833f9 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py @@ -79,6 +79,46 @@ async def brain_user(session, user: User) -> Brain: return brain_1 +@pytest_asyncio.fixture(scope="function") +async def brain_user2(session, user: User) -> Brain: + assert user.id + brain = Brain( + name="test_brain2", + description="this is a test brain", + brain_type=BrainType.integration, + ) + session.add(brain) + await session.commit() + await session.refresh(brain) + assert brain.brain_id + brain_user = BrainUserDB( + brain_id=brain.brain_id, user_id=user.id, default_brain=True, rights="Owner" + ) + session.add(brain_user) + await session.commit() + return brain + + +@pytest_asyncio.fixture(scope="function") +async def brain_user3(session, user: User) -> Brain: + assert 
user.id + brain = Brain( + name="test_brain2", + description="this is a test brain", + brain_type=BrainType.integration, + ) + session.add(brain) + await session.commit() + await session.refresh(brain) + assert brain.brain_id + brain_user = BrainUserDB( + brain_id=brain.brain_id, user_id=user.id, default_brain=True, rights="Owner" + ) + session.add(brain_user) + await session.commit() + return brain + + @pytest_asyncio.fixture(scope="function") async def test_data(session: AsyncSession) -> TestData: user_1 = ( @@ -1012,5 +1052,72 @@ async def test_link_knowledge_brain( @pytest.mark.asyncio(loop_scope="session") -async def test_link_knowledge_brain_existing_brains(): +async def test_link_knowledge_brain_existing_brains( + session: AsyncSession, user: User, brain_user: Brain +): """test knowledge already in brain and we add it to the same brain because we added his parent""" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_unlink_knowledge_brain( + session: AsyncSession, + user: User, + brain_user: Brain, + brain_user2: Brain, + brain_user3: Brain, +): + assert user.id + assert brain_user.brain_id + assert brain_user2.brain_id + assert brain_user3.brain_id + + root_folder = KnowledgeDB( + file_name="folder", + extension="", + status="UPLOADED", + source="local", + source_link="local", + file_size=4, + file_sha1=None, + brains=[brain_user, brain_user2], + children=[], + user_id=user.id, + is_folder=True, + ) + file = KnowledgeDB( + file_name="file_2", + extension="", + status="UPLOADED", + source="local", + source_link="local", + file_size=10, + file_sha1=None, + user_id=user.id, + parent=root_folder, + # 1 additional brain + brains=[brain_user, brain_user2, brain_user3], + ) + session.add(file) + session.add(root_folder) + await session.commit() + await session.refresh(root_folder) + await session.refresh(file) + + storage = FakeStorage() + repository = KnowledgeRepository(session) + service = KnowledgeService(repository, storage) + + await 
service.unlink_knowledge_tree_brains( + root_folder, + brains_ids=[brain_user.brain_id, brain_user2.brain_id], + user_id=user.id, + ) + kms = await service.get_all_knowledge_in_brain(brain_id=brain_user.brain_id) + assert len(kms) == 0 + + kms = await service.get_all_knowledge_in_brain(brain_id=brain_user2.brain_id) + assert len(kms) == 0 + + kms = await service.get_all_knowledge_in_brain(brain_id=brain_user3.brain_id) + assert len(kms) == 1 + assert kms[0].id == file.id diff --git a/backend/api/quivr_api/modules/rag_service/utils.py b/backend/api/quivr_api/modules/rag_service/utils.py index 068a2db28c5e..afc12082eac8 100644 --- a/backend/api/quivr_api/modules/rag_service/utils.py +++ b/backend/api/quivr_api/modules/rag_service/utils.py @@ -68,7 +68,7 @@ async def generate_source( try: file_name = doc.metadata["file_name"] file_path = await knowledge_service.get_knowledge_storage_path( - file_name=file_name, brain_id=brain_id + file_name=file_name, brain_id=brain_id ) if file_path in generated_urls: source_url = generated_urls[file_path] diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 491e98cfe6f4..a196d32a2570 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -8,10 +8,7 @@ authors = [ { name = "Chloé Daems", email = "chloe@quivr.app" }, { name = "Jacopo Chevallard", email = "jacopo@quivr.app" }, ] -dependencies = [ - "packaging>=22.0", - "langchain-anthropic>=0.1.23", -] +dependencies = ["packaging>=22.0", "langchain-anthropic>=0.1.23"] readme = "README.md" requires-python = ">= 3.11" @@ -41,7 +38,15 @@ dev-dependencies = [ ] [tool.rye.workspace] -members = [".", "core", "worker", "api", "docs", "core/examples/chatbot", "core/MegaParse"] +members = [ + ".", + "core", + "worker", + "api", + "docs", + "core/examples/chatbot", + "core/MegaParse", +] [tool.hatch.metadata] allow-direct-references = true diff --git a/backend/worker/quivr_worker/process/processor.py b/backend/worker/quivr_worker/process/processor.py index 
35caa6c5b999..6e5a6107c414 100644 --- a/backend/worker/quivr_worker/process/processor.py +++ b/backend/worker/quivr_worker/process/processor.py @@ -247,10 +247,11 @@ async def _yield_syncs( for sync_file in sync_files: existing_km = syncfile_to_knowledge.get(sync_file.id) if existing_km is not None: - # SyncKnowledge already exists => - # It's already processed in some other brain so just link it and move on if it is Processed - # ELSE - # reprocess the file + # NOTE: + # The parent_knowledge was just added (we are processing it) + # This implies that we could have sync children that were processed before + # IF SyncKnowledge already exists => It's already processed in some other brain + # => Link it to the parent brains and move on if it is PROCESSED ELSE Reprocess the file km_brains = {km_brain.brain_id for km_brain in existing_km.brains} for brain in filter( lambda b: b.brain_id not in km_brains, From 74a8989720b7d40b8e164a6b6e6874b17fbc9612 Mon Sep 17 00:00:00 2001 From: aminediro Date: Thu, 26 Sep 2024 11:43:35 +0200 Subject: [PATCH 29/63] profiler for tests --- backend/pyproject.toml | 1 + backend/requirements-dev.lock | 5 +++++ backend/worker/tests/test_process_file.py | 8 ++++---- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/backend/pyproject.toml b/backend/pyproject.toml index a196d32a2570..d27c31ce4b67 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -35,6 +35,7 @@ dev-dependencies = [ "pytest-cov>=5.0.0", "tox>=4.0.0", "chainlit>=1.1.306", + "pytest-profiling>=1.7.0", ] [tool.rye.workspace] diff --git a/backend/requirements-dev.lock b/backend/requirements-dev.lock index 0623ba330324..dbd87016b7d4 100644 --- a/backend/requirements-dev.lock +++ b/backend/requirements-dev.lock @@ -281,6 +281,8 @@ googleapis-common-protos==1.63.2 # via opentelemetry-exporter-otlp-proto-http gotrue==2.7.0 # via supabase +gprof2dot==2024.6.6 + # via pytest-profiling greenlet==3.0.3 # via playwright # via sqlalchemy @@ -892,12 +894,14 @@ 
pytest==8.3.2 # via pytest-cov # via pytest-dotenv # via pytest-mock + # via pytest-profiling # via pytest-xdist pytest-asyncio==0.24.0 pytest-benchmark==4.0.0 pytest-cov==5.0.0 pytest-dotenv==0.5.2 pytest-mock==3.14.0 +pytest-profiling==1.7.0 pytest-xdist==3.6.1 python-dateutil==2.9.0.post0 # via botocore @@ -1040,6 +1044,7 @@ six==1.16.0 # via langdetect # via markdownify # via posthog + # via pytest-profiling # via python-dateutil # via stone # via unstructured-client diff --git a/backend/worker/tests/test_process_file.py b/backend/worker/tests/test_process_file.py index 7d57a5304c11..2e41acd99fa8 100644 --- a/backend/worker/tests/test_process_file.py +++ b/backend/worker/tests/test_process_file.py @@ -1,4 +1,4 @@ -import os +from random import randbytes from uuid import uuid4 import pytest @@ -9,7 +9,7 @@ def test_build_qfile_fail(local_knowledge_file: KnowledgeDB): - random_bytes = os.urandom(128) + random_bytes = randbytes(128) local_knowledge_file.file_sha1 = None with pytest.raises(AssertionError): with build_qfile(knowledge=local_knowledge_file, file_data=random_bytes) as _: @@ -24,7 +24,7 @@ def test_build_qfile_fail(local_knowledge_file: KnowledgeDB): def test_build_qfile_web(web_knowledge: KnowledgeDB): - random_bytes = os.urandom(128) + random_bytes = randbytes(128) web_knowledge.file_sha1 = "sha1" with build_qfile(knowledge=web_knowledge, file_data=random_bytes) as file: @@ -38,7 +38,7 @@ def test_build_qfile_web(web_knowledge: KnowledgeDB): def test_build_qfile(local_knowledge_file: KnowledgeDB): - random_bytes = os.urandom(128) + random_bytes = randbytes(128) local_knowledge_file.file_sha1 = "sha1" with build_qfile(knowledge=local_knowledge_file, file_data=random_bytes) as file: From 868058f54ef1b1d41c88a6586e912ae1481549ee Mon Sep 17 00:00:00 2001 From: aminediro Date: Thu, 26 Sep 2024 17:22:18 +0200 Subject: [PATCH 30/63] notifier update --- .../knowledge/controller/knowledge_routes.py | 4 +- .../knowledge/service/knowledge_service.py | 43 +++++ 
backend/worker/quivr_worker/celery_monitor.py | 179 ++++++++++-------- backend/worker/quivr_worker/celery_worker.py | 39 ++-- backend/worker/quivr_worker/process/README.md | 67 +++++++ .../worker/quivr_worker/process/processor.py | 134 ++++++------- backend/worker/quivr_worker/process/utils.py | 40 +++- backend/worker/tests/conftest.py | 66 +++++++ .../worker/tests/test_process_file_task.py | 125 +++++++----- 9 files changed, 464 insertions(+), 233 deletions(-) create mode 100644 backend/worker/quivr_worker/process/README.md diff --git a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py index 53525c5b5921..d7991827162d 100644 --- a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py +++ b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py @@ -320,7 +320,9 @@ async def link_knowledge_to_brain( ) for knowledge in filter( - lambda k: k.status != KnowledgeStatus.PROCESSED, linked_kms + lambda k: k.status + not in [KnowledgeStatus.PROCESSED, KnowledgeStatus.PROCESSING], + linked_kms, ): assert knowledge.id upload_notification = notification_service.add_notification( diff --git a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py index 3d260c6edc4e..c5b03aec2500 100644 --- a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py @@ -31,6 +31,7 @@ KnowledgeForbiddenAccess, UploadError, ) +from quivr_api.modules.sync.entity.sync_models import SyncFile from quivr_api.modules.upload.service.upload_file import check_file_exists logger = get_logger(__name__) @@ -53,6 +54,48 @@ async def get_knowledge_sync(self, sync_id: int) -> KnowledgeDTO: km = await km.to_dto() return km + async def create_or_link_sync_knowledge( + self, + syncfile_to_knowledge: dict[str, KnowledgeDB], 
+ parent_knowledge: KnowledgeDB, + sync_file: SyncFile, + ): + existing_km = syncfile_to_knowledge.get(sync_file.id) + if existing_km is not None: + # NOTE: function called in worker processor + # The parent_knowledge was just added (we are processing it) + # This implies that we could have sync children that were processed before + # IF SyncKnowledge already exists => It's already processed in some other brain + # => Link it to the parent brains and move on if it is PROCESSED ELSE Reprocess the file + km_brains = {km_brain.brain_id for km_brain in existing_km.brains} + for brain in filter( + lambda b: b.brain_id not in km_brains, + parent_knowledge.brains, + ): + await self.repository.link_to_brain( + existing_km, brain_id=brain.brain_id + ) + return existing_km + else: + # create sync file knowledge + # automagically gets the brains associated with the parent + file_knowledge = await self.create_knowledge( + user_id=parent_knowledge.user_id, + knowledge_to_add=AddKnowledge( + file_name=sync_file.name, + is_folder=sync_file.is_folder, + extension=sync_file.extension, + source=parent_knowledge.source, # same as parent + source_link=sync_file.web_view_link, + parent_id=parent_knowledge.id, + sync_id=parent_knowledge.sync_id, + sync_file_id=sync_file.id, + ), + status=KnowledgeStatus.PROCESSING, + upload_file=None, + ) + return file_knowledge + # TODO: this is temporary fix for getting knowledge path. 
# KM storage path should be unrelated to brain async def get_knowledge_storage_path( diff --git a/backend/worker/quivr_worker/celery_monitor.py b/backend/worker/quivr_worker/celery_monitor.py index 245e8dcc1914..0d8a6c94b32d 100644 --- a/backend/worker/quivr_worker/celery_monitor.py +++ b/backend/worker/quivr_worker/celery_monitor.py @@ -1,4 +1,5 @@ import asyncio +import os import threading from enum import Enum from queue import Queue @@ -8,7 +9,7 @@ from celery.result import AsyncResult from quivr_api.celery_config import celery from quivr_api.logger import get_logger -from quivr_api.modules.dependencies import async_engine +from quivr_api.models.settings import settings from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService from quivr_api.modules.notification.dto.inputs import NotificationUpdatableProperties @@ -17,10 +18,23 @@ NotificationService, ) from quivr_core.models import KnowledgeStatus +from sqlalchemy.ext.asyncio import create_async_engine +from sqlmodel import text from sqlmodel.ext.asyncio.session import AsyncSession logger = get_logger("notifier_service", "notifier_service.log") -notification_service = NotificationService() + +async_engine = create_async_engine( + settings.pg_database_async_url, + connect_args={"server_settings": {"application_name": "quivr-monitor"}}, + echo=True if os.getenv("ORM_DEBUG") else False, + future=True, + pool_pre_ping=True, + pool_size=5, # NOTE: no bouncer for now, if 6 process workers => 6 + pool_recycle=1800, + isolation_level="AUTOCOMMIT", +) + queue = Queue() @@ -31,80 +45,102 @@ class TaskStatus(str, Enum): class TaskIdentifier(str, Enum): PROCESS_FILE_TASK = "process_file_task" - PROCESS_CRAWL_TASK = "process_crawl_task" @dataclass class TaskEvent: task_id: str - brain_id: UUID task_name: TaskIdentifier notification_id: str knowledge_id: UUID status: TaskStatus -async def handler_loop(): - session = 
AsyncSession(async_engine, expire_on_commit=False, autoflush=False) - knowledge_service = KnowledgeService(KnowledgeRepository(session)) +async def handle_error_task( + task: TaskEvent, + knowledge_service: KnowledgeService, + notification_service: NotificationService, +): + logger.error( + f"task {task.task_id} process_file_task. Sending notification {task.notification_id}" + ) + notification_service.update_notification_by_id( + task.notification_id, + NotificationUpdatableProperties( + status=NotificationsStatusEnum.ERROR, + description=("An error occurred while processing the file"), + ), + ) + logger.error( + f"task {task.task_id} process_file_task failed. Updating knowledge {task.knowledge_id} to Error" + ) + if task.knowledge_id: + await knowledge_service.update_status_knowledge( + task.knowledge_id, KnowledgeStatus.ERROR + ) + logger.error( + f"task {task.task_id} process_file_task. Updating knowledge {task.knowledge_id} status to Error" + ) + + +async def handle_success_task( + task: TaskEvent, + knowledge_service: KnowledgeService, + notification_service: NotificationService, +): + logger.info( + f"task {task.task_id} process_file_task succeeded. Sending notification {task.notification_id}" + ) + notification_service.update_notification_by_id( + task.notification_id, + NotificationUpdatableProperties( + status=NotificationsStatusEnum.SUCCESS, + description=( + "Your file has been properly uploaded!" + if task.task_name == TaskIdentifier.PROCESS_FILE_TASK + else "Your URL has been properly crawled!" + ), + ), + ) + if task.knowledge_id: + await knowledge_service.update_status_knowledge( + knowledge_id=task.knowledge_id, + status=KnowledgeStatus.UPLOADED, + ) + logger.info( + f"task {task.task_id} process_file_task succeeded. Updating knowledge {task.knowledge_id} to UPLOADED" + ) - logger.info("Initialized knowledge_service. 
Listening to task event...") - while True: - try: - event: TaskEvent = queue.get() - if event.status == TaskStatus.FAILED: - logger.error( - f"task {event.task_id} process_file_task. Sending notifition {event.notification_id}" - ) - notification_service.update_notification_by_id( - event.notification_id, - NotificationUpdatableProperties( - status=NotificationsStatusEnum.ERROR, - description=( - "An error occurred while processing the file" - if event.task_name == TaskIdentifier.PROCESS_FILE_TASK - else "An error occurred while processing the URL" - ), - ), - ) - logger.error( - f"task {event.task_id} process_file_task failed. Updating knowledge {event.knowledge_id} to Error" - ) - if event.knowledge_id: - await knowledge_service.update_status_knowledge( - event.knowledge_id, KnowledgeStatus.ERROR + +async def handler_loop(): + async with AsyncSession( + async_engine, expire_on_commit=False, autoflush=False + ) as session: + await session.execute( + text("SET SESSION idle_in_transaction_session_timeout = '1min';") + ) + knowledge_service = KnowledgeService(KnowledgeRepository(session)) + notification_service = NotificationService() + logger.info("Initialized knowledge_service. Listening to task event...") + while True: + try: + event: TaskEvent = queue.get() + if event.status == TaskStatus.FAILED: + await handle_error_task( + task=event, + knowledge_service=knowledge_service, + notification_service=notification_service, ) - logger.error( - f"task {event.task_id} process_file_task . Updating knowledge {event.knowledge_id} status to Error" - ) - if event.status == TaskStatus.SUCCESS: - logger.info( - f"task {event.task_id} process_file_task succeeded. Sending notification {event.notification_id}" - ) - notification_service.update_notification_by_id( - event.notification_id, - NotificationUpdatableProperties( - status=NotificationsStatusEnum.SUCCESS, - description=( - "Your file has been properly uploaded!" 
- if event.task_name == TaskIdentifier.PROCESS_FILE_TASK - else "Your URL has been properly crawled!" - ), - ), - ) - if event.knowledge_id: - await knowledge_service.update_status_knowledge( - knowledge_id=event.knowledge_id, - status=KnowledgeStatus.UPLOADED, - brain_id=event.brain_id, + if event.status == TaskStatus.SUCCESS: + await handle_success_task( + task=event, + knowledge_service=knowledge_service, + notification_service=notification_service, ) - logger.info( - f"task {event.task_id} process_file_task failed. Updating knowledge {event.knowledge_id} to UPLOADED" - ) - except Exception as e: - logger.error(f"Excpetion occured handling event {event}: {e}") + except Exception as e: + logger.error(f"Exception occurred handling event {event}: {e}") def notifier(app): @@ -117,16 +153,14 @@ def handle_task_event(event): task_result = AsyncResult(task.id, app=app) task_name, task_kwargs = task_result.name, task_result.kwargs - if task_name == "process_file_task" or task_name == "process_crawl_task": + if task_name == TaskIdentifier.PROCESS_FILE_TASK: logger.debug(f"Received Event : {task} - {task_name} {task_kwargs} ") - notification_id = task_kwargs["notification_id"] knowledge_id = task_kwargs.get("knowledge_id", None) - brain_id = task_kwargs.get("brain_id", None) + notification_id = task_kwargs.get("notification_id", None) event = TaskEvent( task_id=task, task_name=TaskIdentifier(task_name), knowledge_id=knowledge_id, - brain_id=brain_id, notification_id=notification_id, status=TaskStatus(event["type"]), ) @@ -146,23 +180,6 @@ def handle_task_event(event): recv.capture(limit=None, timeout=None, wakeup=True) -def is_being_executed(task_name: str) -> bool: - """Returns whether the task with given task_name is already being executed. - - Args: - task_name: Name of the task to check if it is running currently. - Returns: A boolean indicating whether the task with the given task name is - running currently. 
- """ - active_tasks = celery.control.inspect().active() - for worker, running_tasks in active_tasks.items(): - for task in running_tasks: - if task["name"] == task_name: # type: ignore - return True - - return False - - if __name__ == "__main__": logger.info("Started quivr-notifier service...") diff --git a/backend/worker/quivr_worker/celery_worker.py b/backend/worker/quivr_worker/celery_worker.py index aa4c0c8d3366..25a51c10d933 100644 --- a/backend/worker/quivr_worker/celery_worker.py +++ b/backend/worker/quivr_worker/celery_worker.py @@ -2,6 +2,7 @@ import os from uuid import UUID +from celery.schedules import crontab from celery.signals import worker_process_init from dotenv import load_dotenv from quivr_api.celery_config import celery @@ -161,23 +162,21 @@ def check_is_premium_task(): # sync_user_service.clean_notion_user_syncs() -# from celery.schedules import crontab - -# celery.conf.beat_schedule = { -# "ping_telemetry": { -# "task": f"{__name__}.ping_telemetry", -# "schedule": crontab(minute="*/30", hour="*"), -# }, -# "process_active_syncs": { -# "task": "process_active_syncs_task", -# "schedule": crontab(minute="*/1", hour="*"), -# }, -# "process_premium_users": { -# "task": "check_is_premium_task", -# "schedule": crontab(minute="*/1", hour="*"), -# }, -# "process_notion_sync": { -# "task": "process_notion_sync_task", -# "schedule": crontab(minute="0", hour="*/6"), -# }, -# } +celery.conf.beat_schedule = { + "ping_telemetry": { + "task": f"{__name__}.ping_telemetry", + "schedule": crontab(minute="*/30", hour="*"), + }, + # "process_active_syncs": { + # "task": "process_active_syncs_task", + # "schedule": crontab(minute="*/1", hour="*"), + # }, + "process_premium_users": { + "task": "check_is_premium_task", + "schedule": crontab(minute="*/1", hour="*"), + }, + "process_notion_sync": { + "task": "process_notion_sync_task", + "schedule": crontab(minute="0", hour="*/6"), + }, +} diff --git a/backend/worker/quivr_worker/process/README.md 
b/backend/worker/quivr_worker/process/README.md new file mode 100644 index 000000000000..0bd35b56fdd4 --- /dev/null +++ b/backend/worker/quivr_worker/process/README.md @@ -0,0 +1,67 @@ +# Processing Knowledges + +## Processing steps + +- Task received a `knowledge_id : UUID`. +- `KnowledgeProcessor.process_knowledge` processes the knowledge: + - Builds a processable tuple of [Knowledge,QuivrFile] stream: + - Gets the `KnowledgeDB` object from db: + - Matches based on the knowledge source: + - **local**: + - Downloads the knowledge data from S3 storage and writes it to tempfile + - Yields the [knowledge,QuivrFile] + - **web**: works a lot like **local**... + - **syncs**: + - Gets the associated sync and checks the credentials + - Concurrently fetch all knowledges for user that are in db associated with this sync and the tree of sync files which the knowledge is the parent (using the sync provider) + - Downloads knowledge and yields the [knowledge,QuivrFile] + - For all children of this knowledge (those fetched from the sync): + - If child in db (ie we have knowledge where `knowledge.sync_id == sync_file.id`): + - Implies that we could have sync children that were processed before in some other brain + - Link it to the parent brains and move on if it is PROCESSED ELSE Reprocess the file + - We are done here + - Else: + - Create the knowledge associated with this sync file and set it to Processing + - Downloads syncfile data and yields the [knowledge,quivr_file] + +In the processing loop for each processable [KnowledgeDB,Quivrfile], if an exception is raised we need to deal with this: + +### Catchable error: + +1. Rollback (only affects the vectors) if they were set. + +- Stateful operations are in order: + + - Creating knowledges (with processing status) + - Updating knowledges: linking to brains + - Creating vectors + - Updating knowledges + +- Creation operations and linking to brains can be retried safely. Knowledge is only recreated if they do not exist in DB. 
This means we can safely retry this operation + +- Linking km to brain only links the brain if it's not already associated with km. Safe for retry + +- Creating vectors: + + - This operation should be rolled back if an error occurs afterwards. Because we would have a knowledge in Processing/ERROR status with associated vectors. + + - Reprocessing the knowledge would mean reinserting vectors in the db, which would insert duplicate vectors! + +2. Set knowledge to error + +3. Continue processing + +| This would mean that some knowledges would be errored. For now we don't automatically reprocess the knowledge right after. + +### Uncatchable error ie worker process fails: + +- The task will be automatically retried 3 times. +- Notifier will receive event task as failed +- Notifier sets knowledge status to ERROR for the task + +**NOTE**: for the v0.1 version: +For `process_knowledge` tasks that need to process a sync folder, the folder will be set to ERROR. 
+ +## Notification steps + +TO discuss @StanGirard @Zewed diff --git a/backend/worker/quivr_worker/process/processor.py b/backend/worker/quivr_worker/process/processor.py index 6e5a6107c414..1305a4d40ca9 100644 --- a/backend/worker/quivr_worker/process/processor.py +++ b/backend/worker/quivr_worker/process/processor.py @@ -1,13 +1,13 @@ import asyncio from contextlib import asynccontextmanager from dataclasses import dataclass -from io import BytesIO -from typing import Any, AsyncGenerator, List, Optional, Tuple +from pathlib import Path +from typing import AsyncGenerator, List, Optional, Tuple from uuid import UUID from quivr_api.logger import get_logger from quivr_api.modules.dependencies import get_supabase_async_client -from quivr_api.modules.knowledge.dto.inputs import AddKnowledge, KnowledgeUpdate +from quivr_api.modules.knowledge.dto.inputs import KnowledgeUpdate from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB, KnowledgeSource from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage @@ -31,6 +31,7 @@ from quivr_worker.process.process_file import parse_qfile, store_chunks from quivr_worker.process.utils import ( build_qfile, + build_sync_file, build_syncprovider_mapping, compute_sha1, skip_process, @@ -87,21 +88,6 @@ async def build_processor_services( logger.info("Closing processor services") -async def download_sync_file( - sync_provider: BaseSync, file: SyncFile, credentials: dict[str, Any] -) -> bytes: - logger.info(f"Downloading {file} using {sync_provider}") - file_response = await sync_provider.adownload_file(credentials, file) - logger.debug(f"Fetch sync file response: {file_response}") - raw_data = file_response["content"] - if isinstance(raw_data, BytesIO): - file_data = raw_data.read() - else: - file_data = raw_data.encode("utf-8") - logger.debug(f"Successfully downloaded sync file : {file}") - return file_data - - class 
KnowledgeProcessor: def __init__(self, services: ProcessorServices): self.services = services @@ -122,7 +108,7 @@ async def fetch_sync_knowledge( ) return await asyncio.gather(*[map_knowledges_task, sync_files_task]) # type: ignore # noqa: F821 - async def yield_processable_kms( + async def yield_processable_knowledge( self, knowledge_id: UUID ) -> AsyncGenerator[Tuple[KnowledgeDB, QuivrFile] | None, None]: """Should only yield ready to process knowledges: @@ -154,19 +140,31 @@ async def yield_processable_kms( raise ValueError(f"Unknown knowledge source : {knowledge.source}") async def _yield_local( - self, knowledge_db: KnowledgeDB + self, knowledge: KnowledgeDB ) -> AsyncGenerator[Tuple[KnowledgeDB, QuivrFile] | None, None]: - if knowledge_db.id is None or knowledge_db.file_name is None: - logger.error(f"received unprocessable local knowledge : {knowledge_db.id} ") + if knowledge.id is None or knowledge.file_name is None: + logger.error(f"received unprocessable local knowledge : {knowledge.id} ") raise ValueError( - f"received unprocessable local knowledge : {knowledge_db.id} " + f"received unprocessable local knowledge : {knowledge.id} " ) - file_data = await self.services.knowledge_service.storage.download_file( - knowledge_db - ) - knowledge_db.file_sha1 = compute_sha1(file_data) - with build_qfile(knowledge_db, file_data) as qfile: - yield (knowledge_db, qfile) + if knowledge.is_folder: + yield ( + knowledge, + QuivrFile( + id=knowledge.id, + original_filename=knowledge.file_name, + file_extension=knowledge.extension, + file_sha1="", + path=Path(), + ), + ) + else: + file_data = await self.services.knowledge_service.storage.download_file( + knowledge + ) + knowledge.file_sha1 = compute_sha1(file_data) + with build_qfile(knowledge, file_data) as qfile: + yield (knowledge, qfile) async def _yield_web( self, knowledge_db: KnowledgeDB @@ -219,9 +217,11 @@ async def _yield_syncs( sync_provider = self.services.syncprovider_mapping[provider_name] # Yield 
parent_knowledge as the first knowledge to process - file_data = await download_sync_file( + async with build_sync_file( + file_knowledge=parent_knowledge, + sync=sync, sync_provider=sync_provider, - file=SyncFile( + sync_file=SyncFile( id=parent_knowledge.sync_file_id, name=parent_knowledge.file_name, extension=parent_knowledge.extension, @@ -229,11 +229,8 @@ async def _yield_syncs( is_folder=parent_knowledge.is_folder, last_modified_at=parent_knowledge.updated_at, ), - credentials=sync.credentials, - ) - parent_knowledge.file_sha1 = compute_sha1(file_data) - with build_qfile(parent_knowledge, file_data) as qfile: - yield (parent_knowledge, qfile) + ) as f: + yield f # Fetch children syncfile_to_knowledge, sync_files = await self.fetch_sync_knowledge( @@ -245,54 +242,25 @@ async def _yield_syncs( return for sync_file in sync_files: - existing_km = syncfile_to_knowledge.get(sync_file.id) - if existing_km is not None: - # NOTE: - # The parent_knowledge was just added (we are processing it) - # This implies that we could have sync children that were processed before - # IF SyncKnowledge already exists => It's already processed in some other brain - # => Link it to the parent brains and move on if it is PROCESSED ELSE Reprocess the file - km_brains = {km_brain.brain_id for km_brain in existing_km.brains} - for brain in filter( - lambda b: b.brain_id not in km_brains, - parent_knowledge.brains, - ): - await self.services.knowledge_service.repository.link_to_brain( - existing_km, brain_id=brain.brain_id - ) - # Don't reprocess already added syncs knowledges - if existing_km.status == KnowledgeStatus.PROCESSED: - continue - else: - # create sync file knowledge - # automagically gets the brains associated with the parent - file_knowledge = await self.services.knowledge_service.create_knowledge( - user_id=parent_knowledge.user_id, - knowledge_to_add=AddKnowledge( - file_name=sync_file.name, - is_folder=sync_file.is_folder, - extension=sync_file.extension, - 
source=parent_knowledge.source, # same as parent - source_link=sync_file.web_view_link, - parent_id=parent_knowledge.id, - sync_id=parent_knowledge.sync_id, - sync_file_id=sync_file.id, - ), - status=KnowledgeStatus.PROCESSING, - upload_file=None, + file_knowledge = ( + await self.services.knowledge_service.create_or_link_sync_knowledge( + syncfile_to_knowledge=syncfile_to_knowledge, + parent_knowledge=parent_knowledge, + sync_file=sync_file, ) - file_data = await download_sync_file( - sync_provider=sync_provider, - file=sync_file, - credentials=sync.credentials, ) - file_knowledge.file_sha1 = compute_sha1(file_data) - file_knowledge.file_size = len(file_data) - with build_qfile(file_knowledge, file_data) as qfile: - yield (file_knowledge, qfile) + if file_knowledge.status == KnowledgeStatus.PROCESSED: + continue + async with build_sync_file( + file_knowledge=file_knowledge, + sync=sync, + sync_provider=sync_provider, + sync_file=sync_file, + ) as f: + yield f async def process_knowledge(self, knowledge_id: UUID): - async for knowledge_tuple in self.yield_processable_kms(knowledge_id): + async for knowledge_tuple in self.yield_processable_knowledge(knowledge_id): # FIXME savepoint = ( await self.services.knowledge_service.repository.session.begin_nested() @@ -318,3 +286,9 @@ async def process_knowledge(self, knowledge_id: UUID): except Exception as e: await savepoint.rollback() logger.error(f"Error processing knowledge {knowledge_id} : {e}") + await self.services.knowledge_service.update_knowledge( + knowledge, + KnowledgeUpdate( + status=KnowledgeStatus.ERROR, + ), + ) diff --git a/backend/worker/quivr_worker/process/utils.py b/backend/worker/quivr_worker/process/utils.py index e9c73362311a..34adb5147c3b 100644 --- a/backend/worker/quivr_worker/process/utils.py +++ b/backend/worker/quivr_worker/process/utils.py @@ -1,16 +1,18 @@ import hashlib import os import time -from contextlib import contextmanager +from contextlib import asynccontextmanager, contextmanager 
+from io import BytesIO from pathlib import Path from tempfile import NamedTemporaryFile -from typing import Generator, Tuple +from typing import Any, AsyncGenerator, Generator, Tuple from quivr_api.celery_config import celery from quivr_api.logger import get_logger from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB, KnowledgeSource from quivr_api.modules.sync.dto.outputs import SyncProvider +from quivr_api.modules.sync.entity.sync_models import Sync, SyncFile from quivr_api.modules.sync.utils.sync import ( AzureDriveSync, BaseSync, @@ -76,6 +78,40 @@ def create_temp_file( tmp_file.close() +async def download_sync_file( + sync_provider: BaseSync, file: SyncFile, credentials: dict[str, Any] +) -> bytes: + logger.info(f"Downloading {file} using {sync_provider}") + file_response = await sync_provider.adownload_file(credentials, file) + logger.debug(f"Fetch sync file response: {file_response}") + raw_data = file_response["content"] + if isinstance(raw_data, BytesIO): + file_data = raw_data.read() + else: + file_data = raw_data.encode("utf-8") + logger.debug(f"Successfully downloaded sync file : {file}") + return file_data + + +@asynccontextmanager +async def build_sync_file( + file_knowledge: KnowledgeDB, + sync_file: SyncFile, + sync_provider: BaseSync, + sync: Sync, +) -> AsyncGenerator[Tuple[KnowledgeDB, QuivrFile], None]: + assert sync.credentials + file_data = await download_sync_file( + sync_provider=sync_provider, + file=sync_file, + credentials=sync.credentials, + ) + file_knowledge.file_sha1 = compute_sha1(file_data) + file_knowledge.file_size = len(file_data) + with build_qfile(file_knowledge, file_data) as qfile: + yield (file_knowledge, qfile) + + @contextmanager def build_qfile( knowledge: KnowledgeDB, file_data: bytes diff --git a/backend/worker/tests/conftest.py b/backend/worker/tests/conftest.py index 53b3f6688cd4..ebe44a02241a 100644 --- 
a/backend/worker/tests/conftest.py +++ b/backend/worker/tests/conftest.py @@ -167,6 +167,72 @@ async def sync(session: AsyncSession, user: User) -> Sync: return sync +@pytest_asyncio.fixture(scope="function") +async def local_knowledge_folder( + proc_services: ProcessorServices, user: User, brain_user: Brain +) -> KnowledgeDB: + assert user.id + assert brain_user.brain_id + service = proc_services.knowledge_service + km_to_add = AddKnowledge( + file_name="test", + source="local", + is_folder=True, + parent_id=None, + ) + km = await service.create_knowledge( + user_id=user.id, knowledge_to_add=km_to_add, upload_file=None + ) + # Link it to the brain + await service.link_knowledge_tree_brains( + km, brains_ids=[brain_user.brain_id], user_id=user.id + ) + km = await service.update_knowledge( + knowledge=km, + payload=KnowledgeUpdate(status=KnowledgeStatus.PROCESSING), + ) + return km + + +@pytest_asyncio.fixture(scope="function") +async def local_knowledge_folder_with_file( + proc_services: ProcessorServices, user: User, brain_user: Brain +) -> KnowledgeDB: + assert user.id + assert brain_user.brain_id + service = proc_services.knowledge_service + km_to_add = AddKnowledge( + file_name="test", + source="local", + is_folder=True, + parent_id=None, + ) + folder_km = await service.create_knowledge( + user_id=user.id, knowledge_to_add=km_to_add, upload_file=None + ) + km_to_add = AddKnowledge( + file_name="test_file", + source=KnowledgeSource.LOCAL, + is_folder=False, + parent_id=folder_km.id, + ) + km_data = BytesIO(os.urandom(24)) + km = await service.create_knowledge( + user_id=user.id, + knowledge_to_add=km_to_add, + upload_file=UploadFile(file=km_data, size=24, filename=km_to_add.file_name), + ) + # Link it to the brain + await service.link_knowledge_tree_brains( + folder_km, brains_ids=[brain_user.brain_id], user_id=user.id + ) + await service.update_knowledge( + knowledge=folder_km, + payload=KnowledgeUpdate(status=KnowledgeStatus.PROCESSING), + ) + return folder_km 
+ + @pytest_asyncio.fixture(scope="function") async def local_knowledge_file( proc_services: ProcessorServices, user: User, brain_user: Brain diff --git a/backend/worker/tests/test_process_file_task.py b/backend/worker/tests/test_process_file_task.py index 25408a24fc6a..4f497b03751a 100644 --- a/backend/worker/tests/test_process_file_task.py +++ b/backend/worker/tests/test_process_file_task.py @@ -11,6 +11,14 @@ from sqlmodel.ext.asyncio.session import AsyncSession +async def _parse_file_mock( + qfile: QuivrFile, + **processor_kwargs: dict[str, Any], +) -> list[Document]: + with open(qfile.path, "rb") as f: + return [Document(page_content=str(f.read()), metadata={})] + + @pytest.mark.asyncio(loop_scope="session") @pytest.mark.parametrize("proc_services", [0], indirect=True) async def test_process_local_file( @@ -21,13 +29,6 @@ async def test_process_local_file( ): input_km = local_knowledge_file - async def _parse_file_mock( - qfile: QuivrFile, - **processor_kwargs: dict[str, Any], - ) -> list[Document]: - with open(qfile.path, "rb") as f: - return [Document(page_content=str(f.read()), metadata={})] - monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock) assert input_km.id assert input_km.brains @@ -53,6 +54,74 @@ async def _parse_file_mock( assert vecs[0].metadata_ is not None +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.parametrize("proc_services", [0], indirect=True) +async def test_process_local_folder( + monkeypatch, + session: AsyncSession, + proc_services: ProcessorServices, + local_knowledge_folder: KnowledgeDB, +): + input_km = local_knowledge_folder + + monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock) + assert input_km.id + assert input_km.brains + km_processor = KnowledgeProcessor(proc_services) + await km_processor.process_knowledge(input_km.id) + + # Check knowledge processed + knowledge_service = km_processor.services.knowledge_service + km = await 
knowledge_service.get_knowledge(input_km.id) + assert km.status == KnowledgeStatus.PROCESSED + assert km.brains[0].brain_id == input_km.brains[0].brain_id + assert km.file_sha1 is None + + # Check vectors where added + vecs = list( + ( + await session.exec( + select(Vector).where(col(Vector.knowledge_id) == input_km.id) + ) + ).all() + ) + assert len(vecs) == 0 + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.parametrize("proc_services", [0], indirect=True) +async def test_process_local_folder_with_file( + monkeypatch, + session: AsyncSession, + proc_services: ProcessorServices, + local_knowledge_folder_with_file: KnowledgeDB, +): + input_km = local_knowledge_folder_with_file + + monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock) + assert input_km.id + assert input_km.brains + km_processor = KnowledgeProcessor(proc_services) + await km_processor.process_knowledge(input_km.id) + + # Check knowledge processed + knowledge_service = km_processor.services.knowledge_service + km = await knowledge_service.get_knowledge(input_km.id) + assert km.status == KnowledgeStatus.PROCESSED + assert km.brains[0].brain_id == input_km.brains[0].brain_id + assert km.file_sha1 is None + + # Check vectors where added + vecs = list( + ( + await session.exec( + select(Vector).where(col(Vector.knowledge_id) == input_km.id) + ) + ).all() + ) + assert len(vecs) == 0 + + @pytest.mark.asyncio(loop_scope="session") @pytest.mark.parametrize("proc_services", [0], indirect=True) async def test_process_web_file( @@ -66,13 +135,6 @@ async def test_process_web_file( async def _extract_url(url: str) -> str: return "quivr has the best rag" - async def _parse_file_mock( - qfile: QuivrFile, - **processor_kwargs: dict[str, Any], - ) -> list[Document]: - with open(qfile.path, "rb") as f: - return [Document(page_content=str(f.read()), metadata={})] - monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock) 
monkeypatch.setattr("quivr_worker.process.processor.extract_from_url", _extract_url) assert input_km.id @@ -111,13 +173,6 @@ async def test_process_sync_file( assert input_km.id assert input_km.brains - async def _parse_file_mock( - qfile: QuivrFile, - **processor_kwargs: dict[str, Any], - ) -> list[Document]: - with open(qfile.path, "rb") as f: - return [Document(page_content=str(f.read()), metadata={})] - km_processor = KnowledgeProcessor(proc_services) monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock) await km_processor.process_knowledge(input_km.id) @@ -153,13 +208,6 @@ async def test_process_sync_folder( assert input_km.id assert input_km.brains - async def _parse_file_mock( - qfile: QuivrFile, - **processor_kwargs: dict[str, Any], - ) -> list[Document]: - with open(qfile.path, "rb") as f: - return [Document(page_content=str(f.read()), metadata={})] - km_processor = KnowledgeProcessor(proc_services) monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock) await km_processor.process_knowledge(input_km.id) @@ -201,13 +249,6 @@ async def test_process_sync_folder_with_file_in_brain( assert input_km.id assert input_km.brains - async def _parse_file_mock( - qfile: QuivrFile, - **processor_kwargs: dict[str, Any], - ) -> list[Document]: - with open(qfile.path, "rb") as f: - return [Document(page_content=str(f.read()), metadata={})] - km_processor = KnowledgeProcessor(proc_services) monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock) await km_processor.process_knowledge(input_km.id) @@ -247,13 +288,6 @@ async def test_process_sync_folder_with_file_in_other_brain( assert input_km.id assert input_km.brains - async def _parse_file_mock( - qfile: QuivrFile, - **processor_kwargs: dict[str, Any], - ) -> list[Document]: - with open(qfile.path, "rb") as f: - return [Document(page_content=str(f.read()), metadata={})] - km_processor = KnowledgeProcessor(proc_services) 
monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock) await km_processor.process_knowledge(input_km.id) @@ -294,13 +328,6 @@ async def test_process_km_rollback( assert input_km.id assert input_km.brains - async def _parse_file_mock( - qfile: QuivrFile, - **processor_kwargs: dict[str, Any], - ) -> list[Document]: - with open(qfile.path, "rb") as f: - return [Document(page_content=str(f.read()), metadata={})] - async def _update_km_error(*args, **kwargs): raise Exception("Error") From 970bbac69724702fb871b96c09de46917fdaf8b2 Mon Sep 17 00:00:00 2001 From: aminediro Date: Thu, 26 Sep 2024 19:00:40 +0200 Subject: [PATCH 31/63] readme update --- backend/worker/quivr_worker/process/README.md | 14 ++++++++++---- backend/worker/quivr_worker/process/processor.py | 7 ++++--- backend/worker/tests/test_process_file_task.py | 14 ++++++-------- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/backend/worker/quivr_worker/process/README.md b/backend/worker/quivr_worker/process/README.md index 0bd35b56fdd4..f9c3bd99a537 100644 --- a/backend/worker/quivr_worker/process/README.md +++ b/backend/worker/quivr_worker/process/README.md @@ -51,16 +51,22 @@ In the processing loop for each processable [KnowledgeDB,Quivrfile], if an excep 3. Continue processing -| This would mean that some knowledges would be errored. For now we don't automatically reprocess the knowledge right after. +| This would mean that some knowledges would be errored. For now we don't automatically reschedule them for processing right after. ### Uncatchable error ie worker process fails: -- The task will be automatically retried 3 times. +- The task will be automatically retried 3 times handled by celery - Notifier will receive event task as failed - Notifier sets knowledge status to ERROR for the task -**NOTE**: for the v0.1 version: -For `process_knowledge` tasks that need to process a sync folder, the folder will be set to ERROR. 
If we have created child knowledges associated with sync, we can't really set their status to ERROR. This would mean that they are showed as PROCESSING. +🔴 **NOTE: Sync error handling for the v0.1 version:** + +`process_knowledge` tasks that need to process a sync folder, the folder will be set to ERROR. +If we have created child knowledges associated with sync, we can't really set their status to ERROR. This would mean that they will be stuck at status PROCESSING with their parent with an ERROR status. + +Why can't we set all children to ERROR? This could introduce a subtle race condition: sync knowledge can be added to brain independently from their parent so we can't know for sure if the status PROCESSING is associated with the task that just failed. We could keep a `task_id` associated with knowledge_id but this is bug prone and impacts the db schema which has a large impact. + +The knowledge (syncs) that are added to some brain will be reprocessed after some period of time in the update sync task so their status will be eventually set to the correct state. 
## Notification steps diff --git a/backend/worker/quivr_worker/process/processor.py b/backend/worker/quivr_worker/process/processor.py index 1305a4d40ca9..5e5ace44cc09 100644 --- a/backend/worker/quivr_worker/process/processor.py +++ b/backend/worker/quivr_worker/process/processor.py @@ -265,10 +265,10 @@ async def process_knowledge(self, knowledge_id: UUID): savepoint = ( await self.services.knowledge_service.repository.session.begin_nested() ) + if knowledge_tuple is None: + continue + knowledge, qfile = knowledge_tuple try: - if knowledge_tuple is None: - continue - knowledge, qfile = knowledge_tuple if not skip_process(knowledge): chunks = await parse_qfile(qfile=qfile) await store_chunks( @@ -283,6 +283,7 @@ async def process_knowledge(self, knowledge_id: UUID): file_sha1=knowledge.file_sha1, ), ) + except Exception as e: await savepoint.rollback() logger.error(f"Error processing knowledge {knowledge_id} : {e}") diff --git a/backend/worker/tests/test_process_file_task.py b/backend/worker/tests/test_process_file_task.py index 4f497b03751a..83ed7521fe47 100644 --- a/backend/worker/tests/test_process_file_task.py +++ b/backend/worker/tests/test_process_file_task.py @@ -328,22 +328,20 @@ async def test_process_km_rollback( assert input_km.id assert input_km.brains - async def _update_km_error(*args, **kwargs): - raise Exception("Error") + async def _store_chunks_error(*args, **kwargs): + raise Exception("mock error") - monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock) + monkeypatch.setattr( + "quivr_worker.process.processor.store_chunks", _store_chunks_error + ) km_processor = KnowledgeProcessor(proc_services) - - # Set error at the end - km_processor.services.knowledge_service.update_knowledge = _update_km_error - await km_processor.process_knowledge(input_km.id) # Check knowledge set to processed knowledge_service = km_processor.services.knowledge_service km = await knowledge_service.get_knowledge(input_km.id) - assert km.status 
== KnowledgeStatus.PROCESSING # tests are just uploaded + assert km.status == KnowledgeStatus.ERROR vecs = list((await session.exec(select(Vector))).all()) # Check we remove the vectors assert len(vecs) == 0 From 84b19015eaff738f4608c186d064374b885afffb Mon Sep 17 00:00:00 2001 From: aminediro Date: Thu, 26 Sep 2024 19:14:00 +0200 Subject: [PATCH 32/63] updated readme --- backend/worker/quivr_worker/process/README.md | 36 ++++++++++--------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/backend/worker/quivr_worker/process/README.md b/backend/worker/quivr_worker/process/README.md index f9c3bd99a537..a9f6355e4fc1 100644 --- a/backend/worker/quivr_worker/process/README.md +++ b/backend/worker/quivr_worker/process/README.md @@ -9,33 +9,37 @@ - Matches based on the knowledge source: - **local**: - Downloads the knowledge data from S3 storage and writes it to tempfile - - Yields the + - Yields the [Knowledge,QuivrFile] - **web**: works a lot like **local**... - - **syncs**: + - **[syncs]**: - Get the associated sync and checks the credentials - - Concurrently fetch all knowledges for user that are in db associated with this sync and the tree of sync files which the knowledge is the parent (using the sync provider) - - Downloads knowledge and yields the [knowledge,QuivrFile] - - For all children of this knowledges (those fetched from the sync): + - Concurrently fetches all knowledges for user that are in db associated with this sync and the tree of sync files this knowledge is the parent of (using the sync provider) + - Downloads knowledge and yields the first [knowledge,QuivrFile]. 
This is the one this task received + - For all children of this knowledges (ie: those fetched from the sync): - If child in db (ie we have knowledge where `knowledge.sync_id == sync_file.id`): - Implies that we could have sync children that were processed before in some other brain - - Link it to the parent brains and move on if it is PROCESSED ELSE Reprocess the file - - We are done here + - if it is PROCESSED Link it to the parent brains and move on + - ELSE reprocess the file - Else: - Create the knowledge associated with this sync file and set it to Processing - Downloads syncfile data and yield the [knowledge,quivr_file] + - Skip processing of the tuple if the knowledge is folder + - Parse the QuivrFile using `quivr-core` + - Store the chunks in the DB + - Update knowledge status to PROCESSED -In the processing loop for each processable [KnowledgeDB,Quivrfile], if an exception raised we need to deal with this: +If an exception occurs during the parsing loop we do the following: ### Catchable error: -1. Rollback (only affects the vectors) if they were set. +1. We first the current transaction Rollback (only affects the vectors) if they were set. The processing loop has the following stateful operations in this order: -- Stateful operations are in order: +- Creating knowledges (with processing status) +- Updating knowledges: linking to brains +- Creating vectors +- Updating knowledges - - Creating knowledges (with processing status) - - Updating knowledges: linking to brains - - Creating vectors - - Updating knowledges +Here is the transaction SAFETY for each operation. These could change and we need to keep the transactional garantees in mind: - Creation operations and linking to brains can be retried safely. Knowledge is only recreated if they do not exist in DB. 
Which means we get we can safely retry this operation @@ -45,7 +49,7 @@ In the processing loop for each processable [KnowledgeDB,Quivrfile], if an excep - This operation should be rollback if we have an error after. Because we would have a knowledge in Processing/ ERROR status with associated vectors. - - Reprocessing the knowledge would mean reinserting vectors in the db. which would insert duplicate vectors ! + - Reprocessing the knowledge would mean reinserting vectors in the db. This means ending up with duplicate vectors for the same knowledge ! 2. Set knowledge to error @@ -55,7 +59,7 @@ In the processing loop for each processable [KnowledgeDB,Quivrfile], if an excep ### Uncatchable error ie worker process fails: -- The task will be automatically retried 3 times handled by celery +- The task will be automatically retried 3 times -> handled by celery - Notifier will receive event task as failed - Notifier sets knowledge status to ERROR for the task From 42a2dbb6327f1ecea43d201789d2845db8f00b9c Mon Sep 17 00:00:00 2001 From: aminediro Date: Thu, 26 Sep 2024 19:14:18 +0200 Subject: [PATCH 33/63] readme --- backend/worker/quivr_worker/process/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/worker/quivr_worker/process/README.md b/backend/worker/quivr_worker/process/README.md index a9f6355e4fc1..ea345671a392 100644 --- a/backend/worker/quivr_worker/process/README.md +++ b/backend/worker/quivr_worker/process/README.md @@ -3,7 +3,7 @@ ## Processing steps - Task received a `knowledge_id : UUID`. 
-- `KnowledgeProcessor.porcess_knowlege` processes the knowledge: +- `KnowledgeProcessor.process_knowlege` processes the knowledge: - Builds a processable tuple of [Knowledge,QuivrFile] stream: - Gets the `KnowledgeDB` object from db: - Matches based on the knowledge source: From e9e0dbda7624eb563a32eb991b3c1a7147808dbd Mon Sep 17 00:00:00 2001 From: aminediro Date: Thu, 26 Sep 2024 19:19:09 +0200 Subject: [PATCH 34/63] readme --- backend/worker/quivr_worker/process/README.md | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/backend/worker/quivr_worker/process/README.md b/backend/worker/quivr_worker/process/README.md index ea345671a392..5e92de91d296 100644 --- a/backend/worker/quivr_worker/process/README.md +++ b/backend/worker/quivr_worker/process/README.md @@ -2,7 +2,7 @@ ## Processing steps -- Task received a `knowledge_id : UUID`. +- Task receives a `knowledge_id : UUID`. - `KnowledgeProcessor.process_knowlege` processes the knowledge: - Builds a processable tuple of [Knowledge,QuivrFile] stream: - Gets the `KnowledgeDB` object from db: @@ -10,19 +10,19 @@ - **local**: - Downloads the knowledge data from S3 storage and writes it to tempfile - Yields the [Knowledge,QuivrFile] - - **web**: works a lot like **local**... - - **[syncs]**: - - Get the associated sync and checks the credentials - - Concurrently fetches all knowledges for user that are in db associated with this sync and the tree of sync files this knowledge is the parent of (using the sync provider) - - Downloads knowledge and yields the first [knowledge,QuivrFile]. 
This is the one this task received - - For all children of this knowledges (ie: those fetched from the sync): - - If child in db (ie we have knowledge where `knowledge.sync_id == sync_file.id`): - - Implies that we could have sync children that were processed before in some other brain - - if it is PROCESSED Link it to the parent brains and move on - - ELSE reprocess the file - - Else: - - Create the knowledge associated with this sync file and set it to Processing - - Downloads syncfile data and yield the [knowledge,quivr_file] + - **web**: works a lot like **local**... + - **[syncs]**: + - Get the associated sync and checks the credentials + - Concurrently fetches all knowledges for user that are in db associated with this sync and the tree of sync files this knowledge is the parent of (using the sync provider) + - Downloads knowledge and yields the first [knowledge,QuivrFile]. This is the one this task received + - For all children of this knowledges (ie: those fetched from the sync): + - If child in db (ie we have knowledge where `knowledge.sync_id == sync_file.id`): + - Implies that we could have sync children that were processed before in some other brain + - if it is PROCESSED Link it to the parent brains and move on + - ELSE reprocess the file + - Else: + - Create the knowledge associated with this sync file and set it to Processing + - Downloads syncfile data and yield the [knowledge,quivr_file] - Skip processing of the tuple if the knowledge is folder - Parse the QuivrFile using `quivr-core` - Store the chunks in the DB From 513929492ec33c8a0d4b705eca4a289857dbc667 Mon Sep 17 00:00:00 2001 From: AmineDiro Date: Thu, 26 Sep 2024 19:36:00 +0200 Subject: [PATCH 35/63] Update README.md --- backend/worker/quivr_worker/process/README.md | 116 +++++++++--------- 1 file changed, 59 insertions(+), 57 deletions(-) diff --git a/backend/worker/quivr_worker/process/README.md b/backend/worker/quivr_worker/process/README.md index 5e92de91d296..2499bd7bb872 100644 --- 
a/backend/worker/quivr_worker/process/README.md +++ b/backend/worker/quivr_worker/process/README.md @@ -1,77 +1,79 @@ -# Processing Knowledges +Here's the grammar correction and a more explicit version of your markdown, keeping the original logic intact: -## Processing steps +--- -- Task receives a `knowledge_id : UUID`. -- `KnowledgeProcessor.process_knowlege` processes the knowledge: - - Builds a processable tuple of [Knowledge,QuivrFile] stream: - - Gets the `KnowledgeDB` object from db: - - Matches based on the knowledge source: - - **local**: - - Downloads the knowledge data from S3 storage and writes it to tempfile - - Yields the [Knowledge,QuivrFile] - - **web**: works a lot like **local**... - - **[syncs]**: - - Get the associated sync and checks the credentials - - Concurrently fetches all knowledges for user that are in db associated with this sync and the tree of sync files this knowledge is the parent of (using the sync provider) - - Downloads knowledge and yields the first [knowledge,QuivrFile]. This is the one this task received - - For all children of this knowledges (ie: those fetched from the sync): - - If child in db (ie we have knowledge where `knowledge.sync_id == sync_file.id`): - - Implies that we could have sync children that were processed before in some other brain - - if it is PROCESSED Link it to the parent brains and move on - - ELSE reprocess the file - - Else: - - Create the knowledge associated with this sync file and set it to Processing - - Downloads syncfile data and yield the [knowledge,quivr_file] - - Skip processing of the tuple if the knowledge is folder - - Parse the QuivrFile using `quivr-core` - - Store the chunks in the DB - - Update knowledge status to PROCESSED +# Knowledge Processing -If an exception occurs during the parsing loop we do the following: +## Steps for Processing -### Catchable error: +1. The task receives a `knowledge_id: UUID`. +2. 
The `KnowledgeProcessor.process_knowledge` method processes the knowledge: + - It constructs a processable tuple of `[Knowledge, QuivrFile]` stream: + - Retrieves the `KnowledgeDB` object from the database. + - Determines the processing steps based on the knowledge source: + - **Local**: + - Downloads the knowledge data from S3 storage and writes it to a temporary file. + - Yields the `[Knowledge, QuivrFile]`. + - **Web**: Processes similarly to the **Local** method. + - **[Syncs]**: + - Fetches the associated sync and verifies the credentials. + - Concurrently retrieves all knowledges for the user from the database associated with this sync, as well as the tree of sync files where this knowledge is the parent (using the sync provider). + - Downloads the knowledge and yields the initial `[Knowledge, QuivrFile]` that the task received. + - For all children of this knowledge (i.e., those fetched from the sync): + - If the child exists in the database (i.e., knowledge where `knowledge.sync_id == sync_file.id`): + - This implies that the sync's child knowledge might have been processed earlier in another brain. + - If the knowledge has been PROCESSED, link it to the parent brains and continue. + - If not, reprocess the file. + - If the child does not exist: + - Create the knowledge associated with the sync file and set it to `Processing`. + - Download the sync file's data and yield the `[Knowledge, QuivrFile]`. + - Skip processing of the tuple if the knowledge is a folder. + - Parse the `QuivrFile` using `quivr-core`. + - Store the resulting chunks in the database. + - Update the knowledge status to `PROCESSED`. -1. We first the current transaction Rollback (only affects the vectors) if they were set. 
The processing loop has the following stateful operations in this order: +### Handling Exceptions During Parsing Loop -- Creating knowledges (with processing status) -- Updating knowledges: linking to brains -- Creating vectors -- Updating knowledges +#### Catchable Errors: -Here is the transaction SAFETY for each operation. These could change and we need to keep the transactional garantees in mind: +If an exception occurs during the parsing loop, the following steps are taken: -- Creation operations and linking to brains can be retried safely. Knowledge is only recreated if they do not exist in DB. Which means we get we can safely retry this operation +1. Roll back the current transaction (this only affects the vectors) if they were set. The processing loop performs the following stateful operations in this order: + - Creating knowledges (with `Processing` status). + - Updating knowledges: linking them to brains. + - Creating vectors. + - Updating knowledges. + + **Transaction Safety for Each Operation:** + - **Creating knowledge and linking to brains**: These operations can be retried safely. Knowledge is only recreated if it does not already exist in the database, allowing for safe retry. + - **Linking knowledge to brains**: Only links the brain if it is not already associated with the knowledge. Safe for retry. + - **Creating vectors**: + - This operation should be rolled back if an error occurs afterward. Otherwise, the knowledge could remain in `Processing` or `ERROR` status with associated vectors. + - Reprocessing the knowledge would result in reinserting the vectors into the database, leading to duplicate vectors for the same knowledge. -- Linking km to brain only link brain if it's not already associated with km. Safe for retry +2. Set the knowledge status to `ERROR`. +3. Continue processing. -- Creating vectors : +| Note: This means that some knowledges will remain in an errored state. Currently, they are not automatically rescheduled for processing. 
- - This operation should be rollback if we have an error after. Because we would have a knowledge in Processing/ ERROR status with associated vectors. +#### Uncatchable Errors (e.g., worker process fails): - - Reprocessing the knowledge would mean reinserting vectors in the db. This means ending up with duplicate vectors for the same knowledge ! +- The task will be automatically retried three times, handled by Celery. +- The notifier will receive an event indicating the task has failed. +- The notifier will set the knowledge status to `ERROR` for the task. -2. Set knowledge to error +--- -3. Continue processing +🔴 **NOTE: Sync Error Handling for Version v0.1:** -| This would mean that some knowledges would be errored. For now we don't automatically reschedule them for processing right after. +For `process_knowledge` tasks involving the processing of a sync folder, the folder's status will be set to `ERROR`. If child knowledges associated with the sync have already been created, their status cannot be set to `ERROR`. This would leave them stuck in `PROCESSING` status while their parent has an `ERROR` status. -### Uncatchable error ie worker process fails: +Why can’t we set all children to `ERROR`? This introduces a potential race condition: Sync knowledge can be added to a brain independently from its parent, so it’s unclear if the `PROCESSING` status is tied to the failed task. Although keeping a `task_id` associated with `knowledge_id` could help, it’s error-prone and impacts the database schema, which would have significant consequences. -- The task will be automatically retried 3 times -> handled by celery -- Notifier will receive event task as failed -- Notifier sets knowledge status to ERROR for the task +However, sync knowledge added to a brain will be reprocessed after some time through the sync update task, ensuring that their status will eventually be set to the correct state. 
-🔴 **NOTE: Sync error handling for the v0.1 version:** +--- -`process_knowledge` tasks that need to process a sync folder, the folder will be set to ERROR. -If we have created child knowledges associated with sync, we can't really set their status to ERROR. This would mean that they will be stuck at status PROCESSING with their parent with an ERROR status. +## Notification Steps -Why can't we set all children to ERROR? This could introduce a subtle race condition: sync knowledge can be added to brain independently from their parent so we can't know for sure if the status PROCESSING is associated with the task that just failed. We could keep a `task_id` associated with knowledge_id but this is bug prone and impacts the db schema which has a large impact. - -The knowledge (syncs) that are added to some brain will be reprocessed after some period of time in the update sync task so their status will be eventually set to the correct state. - -## Notification steps - -TO discuss @StanGirard @Zewed +To discuss: @StanGirard @Zewed From 4f365198da23cac16cbb102ffe0573f48f791fc8 Mon Sep 17 00:00:00 2001 From: aminediro Date: Fri, 27 Sep 2024 10:10:49 +0200 Subject: [PATCH 36/63] root list join on children only --- .../api/quivr_api/modules/knowledge/repository/knowledges.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py index 418656f8cdaf..ceeebaec5484 100644 --- a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py +++ b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py @@ -313,7 +313,7 @@ async def get_root_knowledge_user(self, user_id: UUID) -> list[KnowledgeDB]: select(KnowledgeDB) .where(KnowledgeDB.parent_id.is_(None)) # type: ignore .where(KnowledgeDB.user_id == user_id) - .options(joinedload(KnowledgeDB.parent), joinedload(KnowledgeDB.children)) # type: ignore + 
.options(joinedload(KnowledgeDB.children)) # type: ignore ) result = await self.session.exec(query) kms = result.unique().all() From c7491b842ae4509400111844343c85e5c0e3b79d Mon Sep 17 00:00:00 2001 From: aminediro Date: Fri, 27 Sep 2024 11:25:10 +0200 Subject: [PATCH 37/63] serialization benchmark --- backend/benchmarks/serialization_dto.py | 158 ++++++++++++++++++++++++ 1 file changed, 158 insertions(+) create mode 100644 backend/benchmarks/serialization_dto.py diff --git a/backend/benchmarks/serialization_dto.py b/backend/benchmarks/serialization_dto.py new file mode 100644 index 000000000000..77623e25c77e --- /dev/null +++ b/backend/benchmarks/serialization_dto.py @@ -0,0 +1,158 @@ +""" +Small experiment debugging json serializer for KMS. +Compare three serialization libs: pydantic, msgspec, orjson +""" + +import statistics +import timeit +from datetime import datetime +from typing import Any, Dict, List, Optional +from uuid import UUID + +import msgspec +import orjson +from pydantic import BaseModel +from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO +from quivr_core.models import KnowledgeStatus +from rich.console import Console +from rich.table import Table + +n_dto = 1000 +num_runs = 100 + + +class ListKM(BaseModel): + kms: List[KnowledgeDTO] + + +def serialize_orjson(kms: list[KnowledgeDTO]): + return orjson.dumps([k.model_dump() for k in kms]) + + +def serialize_orjson_single(kms: ListKM): + return orjson.dumps(kms.model_dump()) + + +def serialize_pydantic(kms: list[KnowledgeDTO]): + return [km.model_dump_json() for km in kms] + + +def serialize_pydantic_obj(kms: ListKM): + return kms.model_dump_json() + + +def evaluate(name, func): + times = timeit.repeat( + lambda: func(), globals=globals(), repeat=num_runs, number=1 + ) # Change repeat=5 for desired runs + average_time = sum(times) / len(times) + std_dev = statistics.stdev(times) + return name, average_time * 1000, std_dev * 1000 + + +class KnowledgeMsg(msgspec.Struct): + updated_at: 
datetime + created_at: datetime + user_id: UUID + brains: List[Dict[str, Any]] + id: Optional[UUID] = None + status: Optional[KnowledgeStatus] = None + file_size: int = 0 + file_name: Optional[str] = None + url: Optional[str] = None + extension: str = ".txt" + is_folder: bool = False + source: Optional[str] = None + source_link: Optional[str] = None + file_sha1: Optional[str] = None + metadata: Optional[Dict[str, str]] = None + parent: Optional["KnowledgeDTO"] = None + children: List["KnowledgeDTO"] = [] + sync_id: Optional[int] = None + sync_file_id: Optional[str] = None + + +def print_table(results): + console = Console() + table = Table(title=f"Serialization Performance, n_obj={n_dto}", show_lines=True) + + # Define table columns + table.add_column("Function Name", justify="left", style="cyan") + table.add_column("Average Time (ms)", justify="right", style="magenta") + table.add_column("Standard Deviation (ms)", justify="right", style="green") + + # Add rows with evaluation results + for name, avg_time, std_dev in results: + table.add_row(name, f"{avg_time:.6f}", f"{std_dev:.6f}") + + # Print the table to the console + console.print(table) + + +def main(): + data = { + "id": "24185498-9025-44ea-ae70-b5a1a342f97c", + "file_size": 57210, + "status": "UPLOADED", + "file_name": "0000993.pdf", + "url": None, + "extension": ".pdf", + "is_folder": False, + "updated_at": "2024-09-26T19:01:23.881842Z", + "created_at": "2024-09-26T19:00:57.110967Z", + "source": "local", + "source_link": None, + "file_sha1": "1488859a8d85a309b2bff4c669177e688997bfe9", + "metadata": None, + "user_id": "155b9ab3-e649-4f8a-b5cf-8150728a9202", + "brains": [ + { + "name": "all_kms", + "description": "kms", + "temperature": 0, + "brain_type": "doc", + "brain_id": "a035b4e5-a385-468a-8f41-2d8344cc6a8f", + "status": "private", + "model": None, + "max_tokens": 2000, + "last_update": "2024-09-26T19:31:16.352708", + "prompt_id": None, + } + ], + "sync_id": None, + "sync_file_id": None, + "parent": 
None, + "children": [], + } + + km = KnowledgeDTO.model_validate(data) + # print(isinstance([km]*N,BaseModel)) + list_dto = [km] * n_dto + single_obj = ListKM(kms=list_dto) + km_msgspec = msgspec.json.decode(msgspec.json.encode(data), type=KnowledgeMsg) + list_msgspec = [km_msgspec] * n_dto + + # Evaluation + results = [] + results.append(evaluate("serialize_pydantic", lambda: serialize_pydantic(list_dto))) + results.append( + evaluate( + "serialize_pydantic_single_obj", lambda: serialize_pydantic_obj(single_obj) + ) + ) + results.append(evaluate("serialize_orjson", lambda: serialize_orjson(list_dto))) + results.append( + evaluate("serialize_orjson_single", lambda: serialize_orjson_single(single_obj)) + ) + results.append( + evaluate( + "serialize_msgspec", + lambda: [msgspec.json.encode(msg) for msg in list_msgspec], + ) + ) + + print_table(results) + + +if __name__ == "__main__": + main() From f08f6534643f85f11523fb36c4d258fdbdb96ab2 Mon Sep 17 00:00:00 2001 From: aminediro Date: Fri, 27 Sep 2024 13:39:58 +0200 Subject: [PATCH 38/63] url tests --- .../knowledge/service/knowledge_service.py | 15 ++-- .../knowledge/tests/test_knowledge_service.py | 28 ++++++- backend/worker/quivr_worker/process/README.md | 81 +++++++++---------- 3 files changed, 73 insertions(+), 51 deletions(-) diff --git a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py index c5b03aec2500..6c2f332cb630 100644 --- a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py @@ -185,10 +185,6 @@ async def create_knowledge( knowledgedb, buff_reader ) knowledgedb.source_link = storage_path - knowledge_db = await self.repository.update_knowledge( - knowledge_db, - KnowledgeUpdate(status=KnowledgeStatus.UPLOADED), - ) if knowledge_db.brains and len(knowledge_db.brains) > 0: # Schedule this new knowledge to be processed 
knowledge_db = await self.repository.update_knowledge( @@ -199,13 +195,16 @@ async def create_knowledge( "process_file_task", kwargs={ "knowledge_id": knowledge_db.id, - "file_name": knowledge_db.file_name, - "source": knowledge_db.source, - "source_link": knowledge_db.source_link, }, ) + return knowledge_db + else: + knowledge_db = await self.repository.update_knowledge( + knowledge_db, + KnowledgeUpdate(status=KnowledgeStatus.UPLOADED), + ) + return knowledge_db - return knowledge_db except Exception as e: logger.exception( f"Error uploading knowledge {knowledgedb.id} to storage : {e}" diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py index c764c92833f9..1d2a71489c18 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py @@ -17,7 +17,7 @@ KnowledgeStatus, KnowledgeUpdate, ) -from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB +from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB, KnowledgeSource from quivr_api.modules.knowledge.entity.knowledge_brain import KnowledgeBrain from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository from quivr_api.modules.knowledge.service.knowledge_exceptions import ( @@ -514,6 +514,30 @@ async def test_create_knowledge_file(session: AsyncSession, user: User): storage.knowledge_exists(km) +@pytest.mark.asyncio(loop_scope="session") +async def test_create_knowledge_web(session: AsyncSession, user: User): + assert user.id + storage = FakeStorage() + repository = KnowledgeRepository(session) + service = KnowledgeService(repository, storage) + + km_to_add = AddKnowledge( + url="http://quivr.app", + source=KnowledgeSource.WEB, + is_folder=False, + parent_id=None, + ) + + km = await service.create_knowledge( + user_id=user.id, knowledge_to_add=km_to_add, 
upload_file=None + ) + + assert km.id + assert km.url == km_to_add.url + assert km.status == KnowledgeStatus.UPLOADED + assert not km.is_folder + + @pytest.mark.asyncio(loop_scope="session") async def test_create_knowledge_folder(session: AsyncSession, user: User): assert user.id @@ -552,7 +576,7 @@ async def test_create_knowledge_folder(session: AsyncSession, user: User): @pytest.mark.asyncio(loop_scope="session") -async def test_create_knowledge_file_in_folder( +async def test_create_knowledge_file_in_folder_in_brain( monkeypatch, session: AsyncSession, user: User, folder_km_brain: KnowledgeDB ): tasks = {} diff --git a/backend/worker/quivr_worker/process/README.md b/backend/worker/quivr_worker/process/README.md index 2499bd7bb872..fe7ddc0a2716 100644 --- a/backend/worker/quivr_worker/process/README.md +++ b/backend/worker/quivr_worker/process/README.md @@ -1,36 +1,32 @@ -Here's the grammar correction and a more explicit version of your markdown, keeping the original logic intact: - ---- - # Knowledge Processing ## Steps for Processing 1. The task receives a `knowledge_id: UUID`. 2. The `KnowledgeProcessor.process_knowledge` method processes the knowledge: - - It constructs a processable tuple of `[Knowledge, QuivrFile]` stream: - - Retrieves the `KnowledgeDB` object from the database. - - Determines the processing steps based on the knowledge source: - - **Local**: - - Downloads the knowledge data from S3 storage and writes it to a temporary file. - - Yields the `[Knowledge, QuivrFile]`. - - **Web**: Processes similarly to the **Local** method. - - **[Syncs]**: - - Fetches the associated sync and verifies the credentials. - - Concurrently retrieves all knowledges for the user from the database associated with this sync, as well as the tree of sync files where this knowledge is the parent (using the sync provider). - - Downloads the knowledge and yields the initial `[Knowledge, QuivrFile]` that the task received. 
- - For all children of this knowledge (i.e., those fetched from the sync): - - If the child exists in the database (i.e., knowledge where `knowledge.sync_id == sync_file.id`): - - This implies that the sync's child knowledge might have been processed earlier in another brain. - - If the knowledge has been PROCESSED, link it to the parent brains and continue. - - If not, reprocess the file. - - If the child does not exist: - - Create the knowledge associated with the sync file and set it to `Processing`. - - Download the sync file's data and yield the `[Knowledge, QuivrFile]`. - - Skip processing of the tuple if the knowledge is a folder. - - Parse the `QuivrFile` using `quivr-core`. - - Store the resulting chunks in the database. - - Update the knowledge status to `PROCESSED`. + - It constructs a processable tuple of `[Knowledge, QuivrFile]` stream: + - Retrieves the `KnowledgeDB` object from the database. + - Determines the processing steps based on the knowledge source: + - **Local**: + - Downloads the knowledge data from S3 storage and writes it to a temporary file. + - Yields the `[Knowledge, QuivrFile]`. + - **Web**: Processes similarly to the **Local** method. + - **[Syncs]**: + - Fetches the associated sync and verifies the credentials. + - Concurrently retrieves all knowledges for the user from the database associated with this sync, as well as the tree of sync files where this knowledge is the parent (using the sync provider). + - Downloads the knowledge and yields the initial `[Knowledge, QuivrFile]` that the task received. + - For all children of this knowledge (i.e., those fetched from the sync): + - If the child exists in the database (i.e., knowledge where `knowledge.sync_id == sync_file.id`): + - This implies that the sync's child knowledge might have been processed earlier in another brain. + - If the knowledge has been PROCESSED, link it to the parent brains and continue. + - If not, reprocess the file. 
+ - If the child does not exist: + - Create the knowledge associated with the sync file and set it to `Processing`. + - Download the sync file's data and yield the `[Knowledge, QuivrFile]`. + - Skip processing of the tuple if the knowledge is a folder. + - Parse the `QuivrFile` using `quivr-core`. + - Store the resulting chunks in the database. + - Update the knowledge status to `PROCESSED`. ### Handling Exceptions During Parsing Loop @@ -39,20 +35,23 @@ Here's the grammar correction and a more explicit version of your markdown, keep If an exception occurs during the parsing loop, the following steps are taken: 1. Roll back the current transaction (this only affects the vectors) if they were set. The processing loop performs the following stateful operations in this order: - - Creating knowledges (with `Processing` status). - - Updating knowledges: linking them to brains. - - Creating vectors. - - Updating knowledges. - - **Transaction Safety for Each Operation:** - - **Creating knowledge and linking to brains**: These operations can be retried safely. Knowledge is only recreated if it does not already exist in the database, allowing for safe retry. - - **Linking knowledge to brains**: Only links the brain if it is not already associated with the knowledge. Safe for retry. - - **Creating vectors**: - - This operation should be rolled back if an error occurs afterward. Otherwise, the knowledge could remain in `Processing` or `ERROR` status with associated vectors. - - Reprocessing the knowledge would result in reinserting the vectors into the database, leading to duplicate vectors for the same knowledge. - -2. Set the knowledge status to `ERROR`. -3. Continue processing. + - Creating knowledges (with `Processing` status). + - Downloading sync files from sync provider + - Updating knowledges: linking them to brains. + - Creating vectors. + - Updating knowledges. 
+ +**Transaction Safety for Each Operation:** + +- **Creating knowledge and linking to brains**: These operations can be retried safely. Knowledge is only recreated if it does not already exist in the database, allowing for safe retry. +- **Downloading sync files**: This operation is not idempotent but is safe to retry. If a change has occurred, we would download the last version of the file. +- **Linking knowledge to brains**: Only links the brain if it is not already associated with the knowledge. Safe for retry. +- **Creating vectors**: + - This operation should be rolled back if an error occurs afterward. Otherwise, the knowledge could remain in `Processing` or `ERROR` status with associated vectors. + - Reprocessing the knowledge would result in reinserting the vectors into the database, leading to duplicate vectors for the same knowledge. + +1. Set the knowledge status to `ERROR`. +2. Continue processing. | Note: This means that some knowledges will remain in an errored state. Currently, they are not automatically rescheduled for processing. 
From 055a033b9a087373a3f4bfb83887bb539dbe831f Mon Sep 17 00:00:00 2001 From: aminediro Date: Fri, 27 Sep 2024 17:39:23 +0200 Subject: [PATCH 39/63] load tested kms --- .gitignore | 1 + .../knowledge/controller/knowledge_routes.py | 69 ++++--- .../modules/knowledge/entity/knowledge.py | 8 +- .../tests/test_knowledge_controller.py | 22 +-- backend/benchmarks/benchmark_kms.sh | 8 + backend/benchmarks/load_data.py | 173 ++++++++++++++++++ backend/benchmarks/locustfile.py | 101 ++++++++++ backend/benchmarks/locustfile_kms.py | 138 ++++++++++++++ backend/pyproject.toml | 1 + backend/requirements-dev.lock | 47 ++++- 10 files changed, 524 insertions(+), 44 deletions(-) create mode 100755 backend/benchmarks/benchmark_kms.sh create mode 100644 backend/benchmarks/load_data.py create mode 100644 backend/benchmarks/locustfile.py create mode 100644 backend/benchmarks/locustfile_kms.py diff --git a/.gitignore b/.gitignore index c034793f475b..d3b9324517d7 100644 --- a/.gitignore +++ b/.gitignore @@ -103,3 +103,4 @@ backend/core/examples/chatbot/.chainlit/translations/en-US.json .tox Pipfile *.pkl +backend/benchmarks/data.json diff --git a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py index d7991827162d..e8fa2ef4ea2b 100644 --- a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py +++ b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py @@ -3,7 +3,16 @@ from typing import List, Optional from uuid import UUID -from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile, status +from fastapi import ( + APIRouter, + Depends, + File, + HTTPException, + Query, + Response, + UploadFile, + status, +) from quivr_core.models import KnowledgeStatus from quivr_api.celery_config import celery @@ -297,7 +306,7 @@ async def link_knowledge_to_brain( link_request.bulk_id, ) if len(brains_ids) == 0: - return "empty brain list" + return 
Response(status_code=status.HTTP_204_NO_CONTENT) if knowledge_dto.id is None: if knowledge_dto.sync_file_id is None: @@ -319,34 +328,38 @@ async def link_knowledge_to_brain( knowledge_dto.id, brains_ids=brains_ids, user_id=current_user.id ) - for knowledge in filter( - lambda k: k.status - not in [KnowledgeStatus.PROCESSED, KnowledgeStatus.PROCESSING], - linked_kms, - ): - assert knowledge.id - upload_notification = notification_service.add_notification( - CreateNotification( - user_id=current_user.id, - bulk_id=bulk_id, - status=NotificationsStatusEnum.INFO, - title=f"{knowledge.file_name}", - category="process", + if len(linked_kms) > 0: + for knowledge in filter( + lambda k: k.status + not in [KnowledgeStatus.PROCESSED, KnowledgeStatus.PROCESSING], + linked_kms, + ): + assert knowledge.id + upload_notification = notification_service.add_notification( + CreateNotification( + user_id=current_user.id, + bulk_id=bulk_id, + status=NotificationsStatusEnum.INFO, + title=f"{knowledge.file_name}", + category="process", + ) + ) + celery.send_task( + "process_file_task", + kwargs={ + "knowledge_id": knowledge.id, + "notification_id": upload_notification.id, + }, + ) + knowledge = await knowledge_service.update_knowledge( + knowledge=knowledge, + payload=KnowledgeUpdate(status=KnowledgeStatus.PROCESSING), ) - ) - celery.send_task( - "process_file_task", - kwargs={ - "knowledge_id": knowledge.id, - "notification_id": upload_notification.id, - }, - ) - knowledge = await knowledge_service.update_knowledge( - knowledge=knowledge, - payload=KnowledgeUpdate(status=KnowledgeStatus.PROCESSING), - ) - return await asyncio.gather(*[k.to_dto() for k in linked_kms]) + return await asyncio.gather(*[k.to_dto() for k in linked_kms]) + + else: + return Response(status_code=status.HTTP_204_NO_CONTENT) @knowledge_router.delete( diff --git a/backend/api/quivr_api/modules/knowledge/entity/knowledge.py b/backend/api/quivr_api/modules/knowledge/entity/knowledge.py index 
2174272dfa40..fe7d54b478f8 100644 --- a/backend/api/quivr_api/modules/knowledge/entity/knowledge.py +++ b/backend/api/quivr_api/modules/knowledge/entity/knowledge.py @@ -99,10 +99,10 @@ async def to_dto( self, get_children: bool = True, get_parent: bool = True ) -> KnowledgeDTO: assert ( - self.updated_at + await self.awaitable_attrs.updated_at ), "knowledge should be inserted before transforming to dto" assert ( - self.created_at + await self.awaitable_attrs.created_at ), "knowledge should be inserted before transforming to dto" brains = await self.awaitable_attrs.brains children: list[KnowledgeDB] = ( @@ -125,8 +125,8 @@ async def to_dto( is_folder=self.is_folder, file_size=self.file_size or 0, file_sha1=self.file_sha1, - updated_at=self.updated_at, - created_at=self.created_at, + updated_at=await self.awaitable_attrs.updated_at, + created_at=await self.awaitable_attrs.created_at, metadata=self.metadata_, # type: ignore brains=[b.model_dump() for b in brains], parent=parent_dto, diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py index a75b75b5aec3..8a3e4637501c 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py @@ -96,16 +96,17 @@ async def test_service(): @pytest.mark.asyncio(loop_scope="session") -async def test_post_knowledge_folder(test_client: AsyncClient): +async def test_post_knowledge(test_client: AsyncClient): km_data = { "file_name": "test_file.txt", "source": "local", - "is_folder": True, + "is_folder": False, "parent_id": None, } multipart_data = { "knowledge_data": (None, json.dumps(km_data), "application/json"), + "file": ("test_file.txt", b"Test file content", "application/octet-stream"), } response = await test_client.post( @@ -114,26 +115,19 @@ async def test_post_knowledge_folder(test_client: AsyncClient): ) assert 
response.status_code == 200 - km = KnowledgeDTO.model_validate(response.json()) - - assert km.id - assert km.is_folder - assert km.parent is None - assert km.children == [] @pytest.mark.asyncio(loop_scope="session") -async def test_post_knowledge(test_client: AsyncClient): +async def test_post_knowledge_folder(test_client: AsyncClient): km_data = { "file_name": "test_file.txt", "source": "local", - "is_folder": False, + "is_folder": True, "parent_id": None, } multipart_data = { "knowledge_data": (None, json.dumps(km_data), "application/json"), - "file": ("test_file.txt", b"Test file content", "application/octet-stream"), } response = await test_client.post( @@ -142,6 +136,12 @@ async def test_post_knowledge(test_client: AsyncClient): ) assert response.status_code == 200 + km = KnowledgeDTO.model_validate(response.json()) + + assert km.id + assert km.is_folder + assert km.parent is None + assert km.children == [] @pytest.mark.asyncio(loop_scope="session") diff --git a/backend/benchmarks/benchmark_kms.sh b/backend/benchmarks/benchmark_kms.sh new file mode 100755 index 000000000000..6c86ca226c5e --- /dev/null +++ b/backend/benchmarks/benchmark_kms.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +# LOG_LEVEL=info rye run uvicorn quivr_api.main:app --log-level info --host 0.0.0.0 --port 5050 --workers 5 --loop uvloop +# supabase db reset && supabase stop && supabase start +# rm benchmarks/data.json +# rye run python benchmarks/load_data.py +LOG_LEVEL=info rye run uvicorn quivr_api.main:app --log-level info --host 0.0.0.0 --port 5050 --workers 5 --loop uvloop& +rye run locust -f benchmarks/locustfile_kms.py -H http://localhost:5050 diff --git a/backend/benchmarks/load_data.py b/backend/benchmarks/load_data.py new file mode 100644 index 000000000000..4e47b2331754 --- /dev/null +++ b/backend/benchmarks/load_data.py @@ -0,0 +1,173 @@ +import os +from typing import List +from uuid import UUID + +import numpy as np +from pydantic import BaseModel +from quivr_api.logger import get_logger 
+from quivr_api.models.settings import settings +from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType +from quivr_api.modules.brain.entity.brain_user import BrainUserDB +from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB +from quivr_api.modules.user.entity.user_identity import User +from quivr_api.modules.vector.entity.vector import Vector +from sqlmodel import Session, create_engine, select + +N_BRAINS = 100 +N_USERS = 1 +KNOWLEDGE_PER_BRAIN_MAX = 50 +KNOWLEDGE_PER_BRAIN_MIN = 20 +MEAN_VECTORS_PER_KNOWLEDGE = 500 +STD_VECTORS_PER_KNOWLEDGE = 200 +SAVE_PATH = "benchmarks/data.json" + + +logger = get_logger("load_testing") +pg_database_base_url = "postgresql://postgres:postgres@localhost:54322/postgres" + + +class Data(BaseModel): + brains_ids: List[UUID] + knowledges_ids: List[UUID] + vectors_ids: List[UUID] + + +def setup_brains(session: Session, user_id: UUID): + brains = [] + brains_users = [] + + for idx in range(N_BRAINS): + brain = Brain( + name=f"brain_{idx}", + description="this is a test brain", + brain_type=BrainType.integration, + status="private", + ) + brains.append(brain) + + session.add_all(brains) + session.commit() + [session.refresh(b) for b in brains] + + for brain in brains: + brain_user = BrainUserDB( + brain_id=brain.brain_id, + user_id=user_id, + default_brain=True, + rights="Owner", + ) + brains_users.append(brain_user) + session.add_all(brains_users) + session.commit() + + return brains + + +def setup_knowledge_brain(session: Session, brain: Brain, n_km: int, user_id: UUID): + kms = [] + + for idx in range(n_km): + knowledge = KnowledgeDB( + file_name=f"test_file_{idx}_brain_{idx}", + extension="txt", + status="UPLOADED", + source="test_source", + source_link="test_source_link", + file_size=100, + file_sha1=f"{os.urandom(128)}", + brains=[brain], + user_id=user_id, + ) + kms.append(knowledge) + + return kms + + +def setup_vectors_knowledge(session: Session, knowledge: KnowledgeDB, n_vecs: int): + vecs = 
[] + assert knowledge.id + for idx in range(n_vecs): + vector = Vector( + content=f"vector_{idx}", + metadata_={"file_name": f"{knowledge.file_name}", "chunk_size": 96}, + embedding=np.random.randn(settings.embedding_dim), # type: ignore + knowledge_id=knowledge.id, + ) + + vecs.append(vector) + + return vecs + + +def setup_all(session: Session): + user = (session.exec(select(User).where(User.email == "admin@quivr.app"))).one() + assert user.id + brains = setup_brains(session, user.id) + logger.info(f"Inserted all {len(brains)} brains") + # all_km = [] + # all_vecs = [] + # for brain in brains: + # assert brain + # n_knowledges = random.randint(KNOWLEDGE_PER_BRAIN_MIN, KNOWLEDGE_PER_BRAIN_MAX) + # knowledges = setup_knowledge_brain( + # session, brain=brain, n_km=n_knowledges, user_id=user.id + # ) + # logger.info(f"Inserted all {len(knowledges)} kms for {brain.name}") + # all_km.extend(knowledges) + + # session.add_all(all_km) + # session.commit() + # [session.refresh(b) for b in all_km] + + # n_vecs = np.random.normal( + # MEAN_VECTORS_PER_KNOWLEDGE, STD_VECTORS_PER_KNOWLEDGE, len(all_km) + # ).tolist() + # for n_vecs_km, knowledge in zip(n_vecs, all_km, strict=False): + # vecs = setup_vectors_knowledge(session, knowledge, int(n_vecs_km)) + # all_vecs.extend(vecs) + + # logger.info(f"Inserting all {len(all_vecs)} vecs for knowledge {knowledge.id}") + # session.add_all(all_vecs) + # session.commit() + # [session.refresh(b) for b in all_km] + # [session.refresh(b) for b in all_vecs] + + return Data( + brains_ids=[b.brain_id for b in brains], + knowledges_ids=[], # [k.id for k in all_km], + vectors_ids=[], # [v.id for v in all_vecs], + ) + + +def setup_data(): + logger.info(f"""Starting load data script + N_BRAINS = {N_BRAINS}, + N_USERS = {N_USERS}, + KNOWLEDGE_PER_BRAIN_MIN = {KNOWLEDGE_PER_BRAIN_MIN}, + KNOWLEDGE_PER_BRAIN_MAX = {KNOWLEDGE_PER_BRAIN_MAX }, + MEAN_VECTORS_PER_KNOWLEDGE = {MEAN_VECTORS_PER_KNOWLEDGE} + STD_VECTORS_PER_KNOWLEDGE 
={STD_VECTORS_PER_KNOWLEDGE} + """) + sync_engine = create_engine( + pg_database_base_url, + echo=True if os.getenv("ORM_DEBUG") else False, + future=True, + # NOTE: pessimistic bound on + pool_pre_ping=True, + pool_size=10, # NOTE: no bouncer for now, if 6 process workers => 6 + pool_recycle=1800, + ) + + with Session(sync_engine, expire_on_commit=False, autoflush=False) as session: + data = setup_all(session) + + logger.info( + f"Insert {len(data.brains_ids)} brains, {len(data.knowledges_ids)} knowledges, {len(data.vectors_ids)} vectors" + ) + + with open(SAVE_PATH, "w") as f: + f.write(data.model_dump_json()) + + +if __name__ == "__main__": + setup_data() diff --git a/backend/benchmarks/locustfile.py b/backend/benchmarks/locustfile.py new file mode 100644 index 000000000000..065d100501f5 --- /dev/null +++ b/backend/benchmarks/locustfile.py @@ -0,0 +1,101 @@ +import os +import random +import tempfile +from typing import List +from uuid import UUID + +from locust import between, task +from locust.contrib.fasthttp import FastHttpUser +from pydantic import BaseModel + +FILE_SIZE = 1024 * 1024 + + +class Data(BaseModel): + brains_ids: List[UUID] + knowledges_ids: List[UUID] + vectors_ids: List[UUID] + + +with open("data.json", "r") as f: + data = Data.model_validate_json(f.read()) + + +class QuivrUser(FastHttpUser): + wait_time = between(0.2, 1) # Wait 1-5 seconds between tasks + host = "http://localhost:5050" + auth_headers = { + "Authorization": "Bearer 123", + } + query_params = "?brain_id=40ba47d7-51b2-4b2a-9247-89e29619efb0" + + def on_start(self): + # Prepare the file to be uploaded + self.file_path = "test_file.txt" + with open(self.file_path, "wb") as f: + f.write(os.urandom(1024)) # 1 KB + + @task(10) + def upload_file(self): + with tempfile.NamedTemporaryFile(suffix="_file.txt") as fp: + fp.write(os.urandom(1024)) # 1 KB + fp.flush() + files = { + "uploadFile": fp, + } + response = self.client.post( + f"/upload{self.query_params}", + files=files, + 
headers={"Content-Type": "multipart/form-data", **self.auth_headers}, + ) + + # Check if the upload was successful + if response.status_code == 200: + print(f"File uploaded successfully. Response: {response.text}") + else: + print(f"File upload failed. Status code: {response.status_code}") + + upload_file.__name__ = f"{upload_file.__name__}_1MB" + + @task(10) + def get_brains(self): + self.client.get("/brains", headers=self.auth_headers) + + @task(10) + def get_brain_by_id(self): + random_brain_id = random.choice(data.brains_ids) + self.client.get(f"/brains/{random_brain_id}", headers=self.auth_headers) + + @task(10) + def get_knowledge_by_id(self): + random_brain_id = random.choice(data.brains_ids) + self.client.get( + f"/knowledge?brain_id={random_brain_id}", headers=self.auth_headers + ) + + @task(2) + def get_knowledge_signed_url(self): + random_knowledge = random.choice(data.knowledges_ids) + self.client.get( + f"/knowledge/{random_knowledge}/signed_download_url", + headers=self.auth_headers, + ) + + @task(1) + def delete_knowledge(self): + random_knowledge = random.choice(data.knowledges_ids) + data.knowledges_ids.remove(random_knowledge) + self.client.delete( + f"/knowledge/{random_knowledge}", + headers=self.auth_headers, + ) + + def on_stop(self): + # Clean up the test file + if os.path.exists(self.file_path): + os.remove(self.file_path) + + +# GET Knowledge brain +# DELETE knowledge cascades on vectors +# GET /knowledge/{knowledge_id}/signed_download_url diff --git a/backend/benchmarks/locustfile_kms.py b/backend/benchmarks/locustfile_kms.py new file mode 100644 index 000000000000..7a4b3b3df315 --- /dev/null +++ b/backend/benchmarks/locustfile_kms.py @@ -0,0 +1,138 @@ +import io +import json +import os +import random +from typing import List +from uuid import UUID, uuid4 + +from locust import between, task +from locust.contrib.fasthttp import FastHttpUser +from pydantic import BaseModel +from quivr_api.modules.knowledge.dto.inputs import 
LinkKnowledgeBrain +from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO + + +class Data(BaseModel): + brains_ids: List[UUID] + knowledges_ids: List[UUID] + vectors_ids: List[UUID] + + +load_params = { + "data_path": "benchmarks/data.json", + "file_size": 1024 * 1024, # 1MB + "parent_prob": 0.3, + "folder_prob": 0.2, + "km_root_prob": 0.2, + "create_km_rate": 10, + "list_km_rate": 10, + "link_brain_rate": 5, + "max_link_brains": 3, + "delete_km_rate": 2, +} + +with open(load_params["data_path"], "r") as f: + data = Data.model_validate_json(f.read()) + +all_kms: List[KnowledgeDTO] = [] +brains_ids = data.brains_ids + + +def is_folder() -> bool: + return random.random() < load_params["folder_prob"] + + +def get_parent_id() -> str | None: + if random.random() < load_params["parent_prob"] and len(all_kms) > 0: + folders = list(filter(lambda k: k.is_folder, all_kms)) + if len(folders) == 0: + return None + folder = random.choice(folders) + return str(folder.id) + return None + + +class QuivrUser(FastHttpUser): + wait_time = between(0.2, 1) # Wait 1-5 seconds between tasks + host = "http://localhost:5050" + auth_headers = { + "Authorization": "Bearer 123", + } + + data = io.BytesIO(os.urandom(load_params["file_size"])) + + @task(load_params["create_km_rate"]) + def create_knowledge(self): + km_data = { + "file_name": "test_file.txt", + "source": "local", + "is_folder": is_folder(), + "parent_id": get_parent_id(), + } + + multipart_data = { + "knowledge_data": (None, json.dumps(km_data), "application/json"), + "file": ("test_file.txt", self.data, "application/octet-stream"), + } + response = self.client.post( + "/knowledge/", + headers=self.auth_headers, + files=multipart_data, + ) + returned_km = KnowledgeDTO.model_validate_json(response.text) + all_kms.append(returned_km) + + create_knowledge.__name__ = "create_knowledge_1MB" + + @task(load_params["link_brain_rate"]) + def link_to_brains(self): + if len(all_kms) == 0: + return + nb_brains = random.randint(1, 
load_params["max_link_brains"]) + random_brains = [random.choice(brains_ids) for _ in range(nb_brains)] + random_km = random.choice(all_kms) + json_data = LinkKnowledgeBrain( + bulk_id=uuid4(), brain_ids=random_brains, knowledge=random_km + ).model_dump_json() + self.client.post( + "/knowledge/link_to_brains/", + data=json_data, + headers={ + "Content-Type": "application/json", + **self.auth_headers, + }, + ) + + link_to_brains.__name__ = "link_to_brain" + + @task(load_params["list_km_rate"]) + def list_knowledge_files(self): + if random.random() < load_params["km_root_prob"] or len(all_kms) == 0: + self.client.get( + "/knowledge/files", + headers=self.auth_headers, + name="/knowledge/files", + ) + else: + random_km = random.choice(all_kms) + self.client.get( + f"/knowledge/files?parent_id={str(random_km.id)}", + headers=self.auth_headers, + name="/knowledge/files", + ) + + list_knowledge_files.__name__ = "list_knowledge_files" + + # @task(load_params["delete_km_rate"]) + # def delete_knowledge_files(self): + # only_files = [idx for idx, km in enumerate(all_kms) if not km.is_folder] + # if len(only_files) == 0: + # return + # random_index = random.choice(only_files) + # random_km = all_kms.pop(random_index) + # self.client.delete( + # f"/knowledge/{str(random_km.id)}", + # headers=self.auth_headers, + # ) + + # delete_knowledge_files.__name__ = "delete_knowledge_file" diff --git a/backend/pyproject.toml b/backend/pyproject.toml index d27c31ce4b67..66b57ee1f4ef 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -36,6 +36,7 @@ dev-dependencies = [ "tox>=4.0.0", "chainlit>=1.1.306", "pytest-profiling>=1.7.0", + "locust>=2.31.7", ] [tool.rye.workspace] diff --git a/backend/requirements-dev.lock b/backend/requirements-dev.lock index dbd87016b7d4..89bdd3dbd59e 100644 --- a/backend/requirements-dev.lock +++ b/backend/requirements-dev.lock @@ -82,11 +82,15 @@ black==24.8.0 # via flake8-black bleach==6.1.0 # via nbconvert +blinker==1.8.2 + # via flask 
boto3==1.35.2 # via cohere botocore==1.35.2 # via boto3 # via s3transfer +brotli==1.1.0 + # via geventhttpclient cachetools==5.5.0 # via google-auth # via tox @@ -95,6 +99,7 @@ celery==5.4.0 # via quivr-api # via quivr-worker certifi==2022.12.7 + # via geventhttpclient # via httpcore # via httpx # via requests @@ -102,6 +107,7 @@ certifi==2022.12.7 # via unstructured-client cffi==1.17.0 ; platform_python_implementation != 'PyPy' or implementation_name == 'pypy' # via cryptography + # via gevent # via pyzmq cfgv==3.4.0 # via pre-commit @@ -122,6 +128,7 @@ click==8.1.7 # via click-didyoumean # via click-plugins # via click-repl + # via flask # via litellm # via mkdocs # via mkdocstrings @@ -153,6 +160,8 @@ colorlog==6.8.2 # via quivr-api comm==0.2.2 # via ipykernel +configargparse==1.7 + # via locust contourpy==1.2.1 # via matplotlib coverage==7.6.1 @@ -236,6 +245,14 @@ fire==0.6.0 flake8==7.1.1 # via flake8-black flake8-black==0.3.6 +flask==3.0.3 + # via flask-cors + # via flask-login + # via locust +flask-cors==5.0.0 + # via locust +flask-login==0.6.3 + # via locust flatbuffers==24.3.25 # via onnxruntime flower==2.0.1 @@ -254,6 +271,11 @@ fsspec==2024.2.0 # via llama-index-core # via llama-index-legacy # via torch +gevent==24.2.1 + # via geventhttpclient + # via locust +geventhttpclient==2.3.1 + # via locust ghp-import==2.1.0 # via mkdocs google-api-core==2.19.1 @@ -284,6 +306,7 @@ gotrue==2.7.0 gprof2dot==2024.6.6 # via pytest-profiling greenlet==3.0.3 + # via gevent # via playwright # via sqlalchemy griffe==1.2.0 @@ -356,9 +379,12 @@ ipykernel==6.29.5 # via mkdocs-jupyter ipython==8.26.0 # via ipykernel +itsdangerous==2.2.0 + # via flask jedi==0.19.1 # via ipython jinja2==3.1.3 + # via flask # via litellm # via mkdocs # via mkdocs-material @@ -504,6 +530,7 @@ llama-parse==0.4.9 # via llama-index-readers-llama-parse # via megaparse # via quivr-api +locust==2.31.7 lxml==5.3.0 # via pikepdf # via python-docx @@ -530,6 +557,7 @@ markupsafe==2.1.5 # via 
mkdocs-autorefs # via mkdocstrings # via nbconvert + # via werkzeug marshmallow==3.22.0 # via dataclasses-json # via marshmallow-enum @@ -582,6 +610,8 @@ mpmath==1.3.0 # via sympy msal==1.30.0 # via quivr-api +msgpack==1.1.0 + # via locust multidict==6.0.5 # via aiohttp # via yarl @@ -806,6 +836,7 @@ protobuf==4.25.4 # via transformers psutil==6.0.0 # via ipykernel + # via locust # via unstructured psycopg2-binary==2.9.9 # via quivr-api @@ -946,8 +977,9 @@ python-socketio==5.11.3 pytz==2024.1 # via flower # via pandas -pywin32==306 ; (platform_python_implementation != 'PyPy' and sys_platform == 'win32') or platform_system == 'Windows' +pywin32==306 ; sys_platform == 'win32' or platform_system == 'Windows' # via jupyter-core + # via locust # via portalocker pyyaml==6.0.2 # via huggingface-hub @@ -970,6 +1002,7 @@ pyyaml-env-tag==0.1 pyzmq==26.1.1 # via ipykernel # via jupyter-client + # via locust rapidfuzz==3.9.6 # via unstructured # via unstructured-inference @@ -997,6 +1030,7 @@ requests==2.32.3 # via litellm # via llama-index-core # via llama-index-legacy + # via locust # via mkdocs-material # via msal # via opentelemetry-exporter-otlp-proto-http @@ -1033,6 +1067,8 @@ sentry-sdk==2.13.0 # via quivr-api setuptools==70.0.0 # via opentelemetry-instrumentation + # via zope-event + # via zope-interface simple-websocket==1.0.0 # via python-engineio six==1.16.0 @@ -1219,6 +1255,7 @@ uritemplate==4.1.1 # via google-api-python-client urllib3==1.26.13 # via botocore + # via geventhttpclient # via requests # via sentry-sdk # via unstructured-client @@ -1245,6 +1282,10 @@ webencodings==0.5.1 # via tinycss2 websockets==12.0 # via realtime +werkzeug==3.0.4 + # via flask + # via flask-login + # via locust wrapt==1.16.0 # via deprecated # via llama-index-core @@ -1260,3 +1301,7 @@ yarl==1.9.4 # via aiohttp zipp==3.20.0 # via importlib-metadata +zope-event==5.0 + # via gevent +zope-interface==7.0.3 + # via gevent From fa46f3019bcb91ace84d95c43a154d698d438cee Mon Sep 17 00:00:00 
2001 From: aminediro Date: Fri, 27 Sep 2024 18:44:20 +0200 Subject: [PATCH 40/63] fix awaitable attr knowledge --- .../knowledge/controller/knowledge_routes.py | 85 ++++++------------- .../knowledge/service/knowledge_service.py | 4 +- backend/benchmarks/benchmark_kms.sh | 39 +++++++-- backend/benchmarks/locustfile_kms.py | 27 +++--- 4 files changed, 77 insertions(+), 78 deletions(-) diff --git a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py index e8fa2ef4ea2b..137532ce79a0 100644 --- a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py +++ b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py @@ -18,9 +18,7 @@ from quivr_api.celery_config import celery from quivr_api.logger import get_logger from quivr_api.middlewares.auth import AuthBearer, get_current_user -from quivr_api.modules.brain.entity.brain_entity import RoleEnum from quivr_api.modules.brain.service.brain_authorization_service import ( - has_brain_authorization, validate_brain_authorization, ) from quivr_api.modules.dependencies import get_service @@ -235,33 +233,6 @@ async def update_knowledge( raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) -@knowledge_router.delete( - "/knowledge/{knowledge_id}", - dependencies=[ - Depends(AuthBearer()), - Depends(has_brain_authorization(RoleEnum.Owner)), - ], - tags=["Knowledge"], -) -async def delete_knowledge_brain( - knowledge_id: UUID, - knowledge_service: KnowledgeService = Depends(get_knowledge_service), - current_user: UserIdentity = Depends(get_current_user), - brain_id: UUID = Query(..., description="The ID of the brain"), -): - """ - Delete a specific knowledge from a brain. 
- """ - - knowledge = await knowledge_service.get_knowledge(knowledge_id) - file_name = knowledge.file_name if knowledge.file_name else knowledge.url - await knowledge_service.remove_knowledge_brain(brain_id, knowledge_id) - - return { - "message": f"{file_name} of brain {brain_id} has been deleted by user {current_user.email}." - } - - @knowledge_router.delete( "/knowledge/{knowledge_id}", status_code=status.HTTP_202_ACCEPTED, @@ -328,38 +299,34 @@ async def link_knowledge_to_brain( knowledge_dto.id, brains_ids=brains_ids, user_id=current_user.id ) - if len(linked_kms) > 0: - for knowledge in filter( - lambda k: k.status - not in [KnowledgeStatus.PROCESSED, KnowledgeStatus.PROCESSING], - linked_kms, - ): - assert knowledge.id - upload_notification = notification_service.add_notification( - CreateNotification( - user_id=current_user.id, - bulk_id=bulk_id, - status=NotificationsStatusEnum.INFO, - title=f"{knowledge.file_name}", - category="process", - ) - ) - celery.send_task( - "process_file_task", - kwargs={ - "knowledge_id": knowledge.id, - "notification_id": upload_notification.id, - }, - ) - knowledge = await knowledge_service.update_knowledge( - knowledge=knowledge, - payload=KnowledgeUpdate(status=KnowledgeStatus.PROCESSING), + for knowledge in [ + k + for k in linked_kms + if await k.awaitable_attrs.status + not in [KnowledgeStatus.PROCESSED, KnowledgeStatus.PROCESSING] + ]: + upload_notification = notification_service.add_notification( + CreateNotification( + user_id=current_user.id, + bulk_id=bulk_id, + status=NotificationsStatusEnum.INFO, + title=f"{await knowledge.awaitable_attrs.file_name}", + category="process", ) + ) + celery.send_task( + "process_file_task", + kwargs={ + "knowledge_id": await knowledge.awaitable_attrs.id, + "notification_id": upload_notification.id, + }, + ) + knowledge = await knowledge_service.update_knowledge( + knowledge=knowledge, + payload=KnowledgeUpdate(status=KnowledgeStatus.PROCESSING), + ) - return await 
asyncio.gather(*[k.to_dto() for k in linked_kms]) - - else: - return Response(status_code=status.HTTP_204_NO_CONTENT) + return await asyncio.gather(*[k.to_dto() for k in linked_kms]) @knowledge_router.delete( diff --git a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py index c5b03aec2500..f49ab439d1e3 100644 --- a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py @@ -142,9 +142,11 @@ async def get_knowledge( async def update_knowledge( self, - knowledge: KnowledgeDB, + knowledge: KnowledgeDB | UUID, payload: KnowledgeDTO | KnowledgeUpdate | dict[str, Any], ): + if isinstance(knowledge, UUID): + knowledge = await self.repository.get_knowledge_by_id(knowledge) return await self.repository.update_knowledge(knowledge, payload) async def create_knowledge( diff --git a/backend/benchmarks/benchmark_kms.sh b/backend/benchmarks/benchmark_kms.sh index 6c86ca226c5e..21976ec1ec78 100755 --- a/backend/benchmarks/benchmark_kms.sh +++ b/backend/benchmarks/benchmark_kms.sh @@ -1,8 +1,37 @@ #!/bin/bash -# LOG_LEVEL=info rye run uvicorn quivr_api.main:app --log-level info --host 0.0.0.0 --port 5050 --workers 5 --loop uvloop -# supabase db reset && supabase stop && supabase start -# rm benchmarks/data.json -# rye run python benchmarks/load_data.py -LOG_LEVEL=info rye run uvicorn quivr_api.main:app --log-level info --host 0.0.0.0 --port 5050 --workers 5 --loop uvloop& +# Function to handle cleanup on exit +cleanup() { + echo "Cleaning up..." + # Stop Uvicorn server if running + if [[ ! 
-z "$UVICORN_PID" ]]; then + kill "$UVICORN_PID" + wait "$UVICORN_PID" + fi + exit 0 +} + +# Trap signals (like Ctrl+C, SIGTERM) to run the cleanup function +#trap cleanup SIGINT SIGTERM + +# Reset and start Supabase +supabase db reset && supabase stop && supabase start + +# Remove old benchmark data +rm -f benchmarks/data.json + +# Load new data +rye run python benchmarks/load_data.py + +# Start Uvicorn server in the background +LOG_LEVEL=info rye run uvicorn quivr_api.main:app --log-level info --host 0.0.0.0 --port 5050 --workers 5 --loop uvloop & +UVICORN_PID=$! + +# Wait a bit to ensure the server is running +sleep 1 + +# Run Locust for benchmarking rye run locust -f benchmarks/locustfile_kms.py -H http://localhost:5050 + +# Wait for all background processes (including Uvicorn) to finish +wait "$UVICORN_PID" diff --git a/backend/benchmarks/locustfile_kms.py b/backend/benchmarks/locustfile_kms.py index 7a4b3b3df315..8e70ff0e7b5e 100644 --- a/backend/benchmarks/locustfile_kms.py +++ b/backend/benchmarks/locustfile_kms.py @@ -123,16 +123,17 @@ def list_knowledge_files(self): list_knowledge_files.__name__ = "list_knowledge_files" - # @task(load_params["delete_km_rate"]) - # def delete_knowledge_files(self): - # only_files = [idx for idx, km in enumerate(all_kms) if not km.is_folder] - # if len(only_files) == 0: - # return - # random_index = random.choice(only_files) - # random_km = all_kms.pop(random_index) - # self.client.delete( - # f"/knowledge/{str(random_km.id)}", - # headers=self.auth_headers, - # ) - - # delete_knowledge_files.__name__ = "delete_knowledge_file" + @task(load_params["delete_km_rate"]) + def delete_knowledge_files(self): + only_files = [idx for idx, km in enumerate(all_kms) if not km.is_folder] + if len(only_files) == 0: + return + random_index = random.choice(only_files) + random_km = all_kms.pop(random_index) + self.client.delete( + f"/knowledge/{str(random_km.id)}", + headers=self.auth_headers, + name="/knowledge/delete", + ) + + 
delete_knowledge_files.__name__ = "delete_knowledge_file" From 8d0ca965b6962ea8a53149bd342874161d4fc0e5 Mon Sep 17 00:00:00 2001 From: aminediro Date: Sat, 28 Sep 2024 15:18:20 +0200 Subject: [PATCH 41/63] pool size as param + locust on_start/on_stop setup --- .env.example | 2 + backend/api/quivr_api/models/settings.py | 2 + backend/api/quivr_api/modules/dependencies.py | 10 ++- backend/benchmarks/locustfile_kms.py | 84 ++++++++++++++++--- 4 files changed, 82 insertions(+), 16 deletions(-) diff --git a/.env.example b/.env.example index d69382ebe103..22933f62f326 100644 --- a/.env.example +++ b/.env.example @@ -54,6 +54,8 @@ EXTERNAL_SUPABASE_URL=http://localhost:54321 SUPABASE_SERVICE_KEY=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZS1kZW1vIiwicm9sZSI6InNlcnZpY2Vfcm9sZSIsImV4cCI6MTk4MzgxMjk5Nn0.EGIM96RAZx35lJzdJsyH-qQwv8Hdp7fsn3W0YpN81IU PG_DATABASE_URL=postgresql://postgres:postgres@host.docker.internal:54322/postgres PG_DATABASE_ASYNC_URL=postgresql+asyncpg://postgres:postgres@host.docker.internal:54322/postgres +SQLALCHEMY_POOL_SIZE=10 +SQLALCHEMY_MAX_POOL_OVERFLOW=0 JWT_SECRET_KEY=super-secret-jwt-token-with-at-least-32-characters-long AUTHENTICATE=true TELEMETRY_ENABLED=true diff --git a/backend/api/quivr_api/models/settings.py b/backend/api/quivr_api/models/settings.py index 5987addc7e72..7a2a94871e02 100644 --- a/backend/api/quivr_api/models/settings.py +++ b/backend/api/quivr_api/models/settings.py @@ -123,6 +123,8 @@ class BrainSettings(BaseSettings): pg_database_url: str pg_database_async_url: str embedding_dim: int + sqlalchemy_pool_size: int + sqlalchemy_max_pool_overflow: int class ResendSettings(BaseSettings): diff --git a/backend/api/quivr_api/modules/dependencies.py b/backend/api/quivr_api/modules/dependencies.py index fd71696cd151..de1940abf295 100644 --- a/backend/api/quivr_api/modules/dependencies.py +++ b/backend/api/quivr_api/modules/dependencies.py @@ -61,16 +61,20 @@ def get_repository_cls(cls) -> Type[R]: future=True, # NOTE: 
pessimistic bound on pool_pre_ping=True, - pool_size=10, # NOTE: no bouncer for now, if 6 process workers => 6 + pool_size=1, # NOTE: no bouncer for now, if 6 process workers => 6 + max_overflow=0, pool_recycle=1800, ) async_engine = create_async_engine( settings.pg_database_async_url, - connect_args={"server_settings": {"application_name": "quivr-api-async"}}, + connect_args={ + "server_settings": {"application_name": "quivr-api-async"}, + }, echo=True if os.getenv("ORM_DEBUG") else False, future=True, pool_pre_ping=True, - pool_size=5, # NOTE: no bouncer for now, if 6 process workers => 6 + pool_size=settings.sqlalchemy_pool_size, # NOTE: no bouncer for now, if 6 process workers => 6 + max_overflow=settings.sqlalchemy_max_pool_overflow, pool_recycle=1800, isolation_level="AUTOCOMMIT", ) diff --git a/backend/benchmarks/locustfile_kms.py b/backend/benchmarks/locustfile_kms.py index 8e70ff0e7b5e..3fc3b41fe508 100644 --- a/backend/benchmarks/locustfile_kms.py +++ b/backend/benchmarks/locustfile_kms.py @@ -1,4 +1,3 @@ -import io import json import os import random @@ -7,19 +6,19 @@ from locust import between, task from locust.contrib.fasthttp import FastHttpUser -from pydantic import BaseModel +from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType +from quivr_api.modules.brain.entity.brain_user import BrainUserDB +from quivr_api.modules.dependencies import get_supabase_client from quivr_api.modules.knowledge.dto.inputs import LinkKnowledgeBrain from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO +from quivr_api.modules.user.entity.user_identity import User +from sqlmodel import Session, create_engine, select, text - -class Data(BaseModel): - brains_ids: List[UUID] - knowledges_ids: List[UUID] - vectors_ids: List[UUID] +pg_database_base_url = "postgresql://postgres:postgres@localhost:54322/postgres" load_params = { - "data_path": "benchmarks/data.json", + "n_brains": 100, "file_size": 1024 * 1024, # 1MB "parent_prob": 0.3, "folder_prob": 
0.2, @@ -31,11 +30,9 @@ class Data(BaseModel): "delete_km_rate": 2, } -with open(load_params["data_path"], "r") as f: - data = Data.model_validate_json(f.read()) all_kms: List[KnowledgeDTO] = [] -brains_ids = data.brains_ids +brains_ids: List[UUID] = [] def is_folder() -> bool: @@ -52,14 +49,61 @@ def get_parent_id() -> str | None: return None +def setup_brains(session: Session, user_id: UUID) -> List[Brain]: + brains = [] + brains_users = [] + + for idx in range(load_params["n_brains"]): + brain = Brain( + name=f"brain_{idx}", + description="this is a test brain", + brain_type=BrainType.integration, + status="private", + ) + brains.append(brain) + + session.add_all(brains) + session.commit() + [session.refresh(b) for b in brains] + + for brain in brains: + brain_user = BrainUserDB( + brain_id=brain.brain_id, + user_id=user_id, + default_brain=True, + rights="Owner", + ) + brains_users.append(brain_user) + session.add_all(brains_users) + session.commit() + + return brains + + class QuivrUser(FastHttpUser): - wait_time = between(0.2, 1) # Wait 1-5 seconds between tasks + # Wait 1-5 seconds between tasks + wait_time = between(1, 5) host = "http://localhost:5050" auth_headers = { "Authorization": "Bearer 123", } - data = io.BytesIO(os.urandom(load_params["file_size"])) + data = os.urandom(load_params["file_size"]) + sync_engine = create_engine( + pg_database_base_url, + echo=True, + ) + + def on_start(self) -> None: + global brains_ids + + with Session(self.sync_engine) as session: + user = ( + session.exec(select(User).where(User.email == "admin@quivr.app")) + ).one() + assert user.id + brains = setup_brains(session, user.id) + brains_ids = [b.brain_id for b in brains] # type: ignore @task(load_params["create_km_rate"]) def create_knowledge(self): @@ -137,3 +181,17 @@ def delete_knowledge_files(self): ) delete_knowledge_files.__name__ = "delete_knowledge_file" + + def on_stop(self): + global brains_ids + global all_kms + all_kms = [] + brains_ids = [] + # Cleanup db 
+ with Session(self.sync_engine) as session: + session.execute(text("DELETE FROM brains;")) + session.execute(text("DELETE FROM knowledge;")) + session.commit() + # Cleanup storage + client = get_supabase_client() + client.storage.empty_bucket("quivr") From ae5f6a880e00ddbc888de465c497eb68515c0e8b Mon Sep 17 00:00:00 2001 From: aminediro Date: Mon, 30 Sep 2024 13:59:42 +0200 Subject: [PATCH 42/63] sync init --- backend/api/quivr_api/models/settings.py | 11 --- .../quivr_api/models/sqlalchemy_repository.py | 73 ------------------ backend/api/quivr_api/modules/dependencies.py | 5 -- .../modules/knowledge/dto/outputs.py | 1 + .../modules/knowledge/entity/knowledge.py | 12 ++- .../knowledge/repository/knowledges.py | 11 +++ .../knowledge/service/knowledge_service.py | 5 +- .../modules/sync/entity/sync_models.py | 1 - .../20240920180003_knowledge-sync.sql | 6 +- backend/worker/quivr_worker/celery_worker.py | 17 +++++ .../worker/quivr_worker/process/__init__.py | 3 +- .../worker/quivr_worker/process/processor.py | 68 +---------------- .../worker/quivr_worker/syncs/update_syncs.py | 11 +++ backend/worker/quivr_worker/utils/services.py | 74 +++++++++++++++++++ backend/worker/tests/conftest.py | 2 +- 15 files changed, 136 insertions(+), 164 deletions(-) delete mode 100644 backend/api/quivr_api/models/sqlalchemy_repository.py create mode 100644 backend/worker/quivr_worker/syncs/update_syncs.py create mode 100644 backend/worker/quivr_worker/utils/services.py diff --git a/backend/api/quivr_api/models/settings.py b/backend/api/quivr_api/models/settings.py index 7a2a94871e02..2f4c780d39fd 100644 --- a/backend/api/quivr_api/models/settings.py +++ b/backend/api/quivr_api/models/settings.py @@ -1,13 +1,9 @@ -from typing import Optional from uuid import UUID from posthog import Posthog from pydantic_settings import BaseSettings, SettingsConfigDict -from sqlalchemy import Engine from quivr_api.logger import get_logger -from quivr_api.models.databases.supabase.supabase import 
SupabaseDB -from supabase.client import AsyncClient, Client logger = get_logger(__name__) @@ -136,11 +132,4 @@ class ResendSettings(BaseSettings): quivr_smtp_password: str = "" -# Global variables to store the Supabase client and database instances -_supabase_client: Optional[Client] = None -_supabase_async_client: Optional[AsyncClient] = None -_supabase_db: Optional[SupabaseDB] = None -_db_engine: Optional[Engine] = None -_embedding_service = None - settings = BrainSettings() # type: ignore diff --git a/backend/api/quivr_api/models/sqlalchemy_repository.py b/backend/api/quivr_api/models/sqlalchemy_repository.py deleted file mode 100644 index 7b295187973a..000000000000 --- a/backend/api/quivr_api/models/sqlalchemy_repository.py +++ /dev/null @@ -1,73 +0,0 @@ -from datetime import datetime -from uuid import uuid4 - -from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import relationship - -Base = declarative_base() - - -class User(Base): - __tablename__ = "users" - - user_id = Column(String, primary_key=True) - email = Column(String) - date = Column(DateTime) - daily_requests_count = Column(Integer) - - -class Brain(Base): - __tablename__ = "brains" - - brain_id = Column(Integer, primary_key=True) - name = Column(String) - users = relationship("BrainUser", back_populates="brain") - vectors = relationship("BrainVector", back_populates="brain") - - -class BrainUser(Base): - __tablename__ = "brains_users" - - id = Column(Integer, primary_key=True) - user_id = Column(Integer, ForeignKey("users.user_id")) - brain_id = Column(Integer, ForeignKey("brains.brain_id")) - rights = Column(String) - - user = relationship("User") - brain = relationship("Brain", back_populates="users") - - -class BrainVector(Base): - __tablename__ = "brains_vectors" - - vector_id = Column(String, primary_key=True, default=lambda: str(uuid4())) - brain_id = Column(Integer, 
ForeignKey("brains.brain_id")) - file_sha1 = Column(String) - - brain = relationship("Brain", back_populates="vectors") - - -class BrainSubscriptionInvitation(Base): - __tablename__ = "brain_subscription_invitations" - - id = Column(Integer, primary_key=True) # Assuming an integer primary key named 'id' - brain_id = Column(String, ForeignKey("brains.brain_id")) - email = Column(String, ForeignKey("users.email")) - rights = Column(String) - - brain = relationship("Brain") - user = relationship("User", foreign_keys=[email]) - - -class ApiKey(Base): - __tablename__ = "api_keys" - - key_id = Column(String, primary_key=True, default=lambda: str(uuid4())) - user_id = Column(Integer, ForeignKey("users.user_id")) - api_key = Column(String, unique=True) - creation_time = Column(DateTime, default=datetime.utcnow) - is_active = Column(Boolean, default=True) - deleted_time = Column(DateTime, nullable=True) - - user = relationship("User") diff --git a/backend/api/quivr_api/modules/dependencies.py b/backend/api/quivr_api/modules/dependencies.py index de1940abf295..073881e06f73 100644 --- a/backend/api/quivr_api/modules/dependencies.py +++ b/backend/api/quivr_api/modules/dependencies.py @@ -5,12 +5,7 @@ from fastapi import Depends from langchain.embeddings.base import Embeddings from langchain_community.embeddings.ollama import OllamaEmbeddings - -# from langchain_community.vectorstores.supabase import SupabaseVectorStore from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings - -# from quivr_api.modules.vector.service.vector_service import VectorService -# from quivr_api.modules.vectorstore.supabase import CustomSupabaseVectorStore from sqlalchemy import Engine, create_engine from sqlalchemy.ext.asyncio import create_async_engine from sqlmodel import Session, text diff --git a/backend/api/quivr_api/modules/knowledge/dto/outputs.py b/backend/api/quivr_api/modules/knowledge/dto/outputs.py index 5cf0a0ebf911..b60fc0a7b51b 100644 --- 
a/backend/api/quivr_api/modules/knowledge/dto/outputs.py +++ b/backend/api/quivr_api/modules/knowledge/dto/outputs.py @@ -33,3 +33,4 @@ class KnowledgeDTO(BaseModel): children: List[Self] sync_id: int | None sync_file_id: str | None + last_synced_at: datetime | None = None diff --git a/backend/api/quivr_api/modules/knowledge/entity/knowledge.py b/backend/api/quivr_api/modules/knowledge/entity/knowledge.py index fe7d54b478f8..b4ff593e327e 100644 --- a/backend/api/quivr_api/modules/knowledge/entity/knowledge.py +++ b/backend/api/quivr_api/modules/knowledge/entity/knowledge.py @@ -50,18 +50,25 @@ class KnowledgeDB(AsyncAttrs, SQLModel, table=True): created_at: datetime | None = Field( default=None, sa_column=Column( - TIMESTAMP(timezone=False), + TIMESTAMP(timezone=True), server_default=text("CURRENT_TIMESTAMP"), ), ) updated_at: datetime | None = Field( default=None, sa_column=Column( - TIMESTAMP(timezone=False), + TIMESTAMP(timezone=True), server_default=text("CURRENT_TIMESTAMP"), onupdate=datetime.utcnow, ), ) + + last_synced_at: datetime | None = Field( + default=None, + sa_column=Column( + TIMESTAMP(timezone=True), + ), + ) metadata_: Optional[Dict[str, str]] = Field( default=None, sa_column=Column("metadata", JSON) ) @@ -134,4 +141,5 @@ async def to_dto( user_id=self.user_id, sync_id=self.sync_id, sync_file_id=self.sync_file_id, + last_synced_at=self.last_synced_at, ) diff --git a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py index ceeebaec5484..f9c74b15b6a2 100644 --- a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py +++ b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py @@ -435,3 +435,14 @@ async def get_all_knowledge(self) -> Sequence[KnowledgeDB]: query = select(KnowledgeDB) result = await self.session.exec(query) return result.all() + + async def get_sync_knowledges_to_update(self,batch_size: int = 1) -> Sequence[KnowledgeDB]: + query = 
select(KnowledgeDB) + .where( + and_( + col(KnowledgeDB.sync_id).in_(brains_ids), + ) + ) + + result = await self.session.exec(query) + return result.all() diff --git a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py index f49ab439d1e3..57dcc5b335d5 100644 --- a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py @@ -66,12 +66,15 @@ async def create_or_link_sync_knowledge( # The parent_knowledge was just added (we are processing it) # This implies that we could have sync children that were processed before # IF SyncKnowledge already exists => It's already processed in some other brain - # => Link it to the parent brains and move on if it is PROCESSED ELSE Reprocess the file + # => Link it to the parent and add its brains and move on if it is PROCESSED ELSE Reprocess the file km_brains = {km_brain.brain_id for km_brain in existing_km.brains} for brain in filter( lambda b: b.brain_id not in km_brains, parent_knowledge.brains, ): + await self.repository.update_knowledge( + existing_km, KnowledgeUpdate(parent_id=parent_knowledge.id) + ) await self.repository.link_to_brain( existing_km, brain_id=brain.brain_id ) diff --git a/backend/api/quivr_api/modules/sync/entity/sync_models.py b/backend/api/quivr_api/modules/sync/entity/sync_models.py index 7ba040acd1cd..b85a0476e673 100644 --- a/backend/api/quivr_api/modules/sync/entity/sync_models.py +++ b/backend/api/quivr_api/modules/sync/entity/sync_models.py @@ -77,7 +77,6 @@ class Sync(SQLModel, table=True): onupdate=datetime.utcnow, ), ) - last_synced_at: datetime | None = Field(default=None) additional_data: dict | None = Field( default=None, sa_column=Column("additional_data", JSON) ) diff --git a/backend/supabase/migrations/20240920180003_knowledge-sync.sql b/backend/supabase/migrations/20240920180003_knowledge-sync.sql index 
f8897e23a7ec..ff0995141a96 100644 --- a/backend/supabase/migrations/20240920180003_knowledge-sync.sql +++ b/backend/supabase/migrations/20240920180003_knowledge-sync.sql @@ -1,13 +1,15 @@ -- Renamed syncs ALTER TABLE syncs_user RENAME TO syncs; --- Add column foreign key sync in knowledge +-- Add column foreign key sync in knowledge ALTER TABLE "public"."knowledge" ADD COLUMN "sync_id" INTEGER; ALTER TABLE "public"."knowledge" ADD CONSTRAINT "public_knowledge_sync_id_fkey" FOREIGN KEY (sync_id) REFERENCES syncs(id) ON DELETE CASCADE; -- Add column for sync_file_ids ALTER TABLE "public"."knowledge" +ADD COLUMN "last_synced_at" timestamp with time zone; +ALTER TABLE "public"."knowledge" ADD COLUMN "sync_file_id" TEXT; CREATE INDEX knowledge_sync_id_pkey ON public.knowledge USING btree (sync_id); CREATE INDEX knowledge_sync_file_id_pkey ON public.knowledge USING btree (sync_file_id); @@ -16,8 +18,6 @@ alter table "public"."syncs" add column "created_at" timestamp with time zone default now(); alter table "public"."syncs" add column "updated_at" timestamp with time zone default now(); -alter table "public"."syncs" -add column "last_synced_at" timestamp with time zone; -- Drop files DROP TABLE IF EXISTS "public"."syncs_active" CASCADE; DROP TABLE IF EXISTS "public"."syncs_files" CASCADE; diff --git a/backend/worker/quivr_worker/celery_worker.py b/backend/worker/quivr_worker/celery_worker.py index 25a51c10d933..c58b04146bf0 100644 --- a/backend/worker/quivr_worker/celery_worker.py +++ b/backend/worker/quivr_worker/celery_worker.py @@ -16,6 +16,7 @@ from quivr_worker.assistants.assistants import aprocess_assistant_task from quivr_worker.check_premium import check_is_premium from quivr_worker.process import aprocess_file_task +from quivr_worker.syncs.update_syncs import update_sync_files from quivr_worker.utils.utils import _patch_json load_dotenv() @@ -70,6 +71,22 @@ def process_file_task( ) +@celery.task( + retries=3, + default_retry_delay=1, + name="update_sync_task", 
+ autoretry_for=(Exception,), + dont_autoretry_for=(FileExistsError,), +) +def update_sync_task(): + if async_engine is None: + init_worker() + assert async_engine + logger.info("Update sync task started") + loop = asyncio.get_event_loop() + loop.run_until_complete(update_sync_files(async_engine=async_engine)) + + @celery.task( retries=3, default_retry_delay=1, diff --git a/backend/worker/quivr_worker/process/__init__.py b/backend/worker/quivr_worker/process/__init__.py index a8d86c84f78c..d44d4bd6bff5 100644 --- a/backend/worker/quivr_worker/process/__init__.py +++ b/backend/worker/quivr_worker/process/__init__.py @@ -2,7 +2,8 @@ from sqlalchemy.ext.asyncio import AsyncEngine -from quivr_worker.process.processor import KnowledgeProcessor, build_processor_services +from quivr_worker.process.processor import KnowledgeProcessor +from quivr_worker.utils.services import build_processor_services async def aprocess_file_task(async_engine: AsyncEngine, knowledge_id: UUID): diff --git a/backend/worker/quivr_worker/process/processor.py b/backend/worker/quivr_worker/process/processor.py index 5e5ace44cc09..342afe84c947 100644 --- a/backend/worker/quivr_worker/process/processor.py +++ b/backend/worker/quivr_worker/process/processor.py @@ -1,93 +1,29 @@ import asyncio -from contextlib import asynccontextmanager -from dataclasses import dataclass from pathlib import Path from typing import AsyncGenerator, List, Optional, Tuple from uuid import UUID from quivr_api.logger import get_logger -from quivr_api.modules.dependencies import get_supabase_async_client from quivr_api.modules.knowledge.dto.inputs import KnowledgeUpdate from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB, KnowledgeSource -from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository -from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage -from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService from 
quivr_api.modules.sync.dto.outputs import SyncProvider from quivr_api.modules.sync.entity.sync_models import SyncFile -from quivr_api.modules.sync.repository.sync_repository import SyncsRepository -from quivr_api.modules.sync.service.sync_service import SyncsService -from quivr_api.modules.sync.utils.sync import ( - BaseSync, -) -from quivr_api.modules.vector.repository.vectors_repository import VectorRepository -from quivr_api.modules.vector.service.vector_service import VectorService from quivr_core.files.file import QuivrFile from quivr_core.models import KnowledgeStatus -from sqlalchemy.ext.asyncio import AsyncEngine -from sqlmodel import text -from sqlmodel.ext.asyncio.session import AsyncSession from quivr_worker.parsers.crawler import URL, extract_from_url from quivr_worker.process.process_file import parse_qfile, store_chunks from quivr_worker.process.utils import ( build_qfile, build_sync_file, - build_syncprovider_mapping, compute_sha1, skip_process, ) +from quivr_worker.utils.services import ProcessorServices logger = get_logger("celery_worker") -@dataclass -class ProcessorServices: - sync_service: SyncsService - vector_service: VectorService - knowledge_service: KnowledgeService - syncprovider_mapping: dict[SyncProvider, BaseSync] - - -@asynccontextmanager -async def _start_session(engine: AsyncEngine) -> AsyncGenerator[AsyncSession, None]: - async with AsyncSession(engine) as session: - try: - await session.execute( - text("SET SESSION idle_in_transaction_session_timeout = '5min';") - ) - yield session - await session.commit() - except Exception as e: - await session.rollback() - raise e - finally: - await session.close() - - -@asynccontextmanager -async def build_processor_services( - engine: AsyncEngine, -) -> AsyncGenerator[ProcessorServices, None]: - async_client = await get_supabase_async_client() - storage = SupabaseS3Storage(async_client) - try: - async with _start_session(engine) as session: - vector_repository = VectorRepository(session) - 
vector_service = VectorService(vector_repository) - knowledge_repository = KnowledgeRepository(session) - knowledge_service = KnowledgeService(knowledge_repository, storage=storage) - sync_repository = SyncsRepository(session) - sync_service = SyncsService(sync_repository) - yield ProcessorServices( - knowledge_service=knowledge_service, - vector_service=vector_service, - sync_service=sync_service, - syncprovider_mapping=build_syncprovider_mapping(), - ) - finally: - logger.info("Closing processor services") - - class KnowledgeProcessor: def __init__(self, services: ProcessorServices): self.services = services @@ -261,7 +197,7 @@ async def _yield_syncs( async def process_knowledge(self, knowledge_id: UUID): async for knowledge_tuple in self.yield_processable_knowledge(knowledge_id): - # FIXME + # FIXME(@AmineDiro) : nested transaction for making savepoint = ( await self.services.knowledge_service.repository.session.begin_nested() ) diff --git a/backend/worker/quivr_worker/syncs/update_syncs.py b/backend/worker/quivr_worker/syncs/update_syncs.py new file mode 100644 index 000000000000..a2d6660a9cd3 --- /dev/null +++ b/backend/worker/quivr_worker/syncs/update_syncs.py @@ -0,0 +1,11 @@ +from sqlalchemy.ext.asyncio import AsyncEngine + +from quivr_worker.utils.services import build_processor_services + + +async def update_sync_files(async_engine: AsyncEngine): + async with build_processor_services(async_engine) as processor_services: + pass + + +# If knowledge is folder just call the link_knowledge_to_brain diff --git a/backend/worker/quivr_worker/utils/services.py b/backend/worker/quivr_worker/utils/services.py new file mode 100644 index 000000000000..3fe2f6dbdeb5 --- /dev/null +++ b/backend/worker/quivr_worker/utils/services.py @@ -0,0 +1,74 @@ +from contextlib import asynccontextmanager +from dataclasses import dataclass +from typing import AsyncGenerator + +from quivr_api.logger import get_logger +from quivr_api.modules.dependencies import get_supabase_async_client 
+from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository +from quivr_api.modules.knowledge.repository.storage import SupabaseS3Storage +from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService +from quivr_api.modules.sync.dto.outputs import SyncProvider +from quivr_api.modules.sync.repository.sync_repository import SyncsRepository +from quivr_api.modules.sync.service.sync_service import SyncsService +from quivr_api.modules.sync.utils.sync import ( + BaseSync, +) +from quivr_api.modules.vector.repository.vectors_repository import VectorRepository +from quivr_api.modules.vector.service.vector_service import VectorService +from sqlalchemy.ext.asyncio import AsyncEngine +from sqlmodel import text +from sqlmodel.ext.asyncio.session import AsyncSession + +from quivr_worker.process.utils import ( + build_syncprovider_mapping, +) + +logger = get_logger("celery_worker") + + +@dataclass +class ProcessorServices: + sync_service: SyncsService + vector_service: VectorService + knowledge_service: KnowledgeService + syncprovider_mapping: dict[SyncProvider, BaseSync] + + +@asynccontextmanager +async def _start_session(engine: AsyncEngine) -> AsyncGenerator[AsyncSession, None]: + async with AsyncSession(engine) as session: + try: + await session.execute( + text("SET SESSION idle_in_transaction_session_timeout = '5min';") + ) + yield session + await session.commit() + except Exception as e: + await session.rollback() + raise e + finally: + await session.close() + + +@asynccontextmanager +async def build_processor_services( + engine: AsyncEngine, +) -> AsyncGenerator[ProcessorServices, None]: + async_client = await get_supabase_async_client() + storage = SupabaseS3Storage(async_client) + try: + async with _start_session(engine) as session: + vector_repository = VectorRepository(session) + vector_service = VectorService(vector_repository) + knowledge_repository = KnowledgeRepository(session) + knowledge_service = 
KnowledgeService(knowledge_repository, storage=storage) + sync_repository = SyncsRepository(session) + sync_service = SyncsService(sync_repository) + yield ProcessorServices( + knowledge_service=knowledge_service, + vector_service=vector_service, + sync_service=sync_service, + syncprovider_mapping=build_syncprovider_mapping(), + ) + finally: + logger.info("Closing processor services") diff --git a/backend/worker/tests/conftest.py b/backend/worker/tests/conftest.py index ebe44a02241a..e993f50138a5 100644 --- a/backend/worker/tests/conftest.py +++ b/backend/worker/tests/conftest.py @@ -28,7 +28,7 @@ from quivr_api.modules.vector.service.vector_service import VectorService from quivr_core.files.file import QuivrFile from quivr_core.models import KnowledgeStatus -from quivr_worker.process.processor import ProcessorServices +from quivr_worker.utils.services import ProcessorServices from sqlalchemy.ext.asyncio import create_async_engine from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession From 95df9194207096feb635dc954f671c407c25f214 Mon Sep 17 00:00:00 2001 From: aminediro Date: Mon, 30 Sep 2024 14:33:23 +0200 Subject: [PATCH 43/63] sync status added --- backend/api/quivr_api/modules/sync/entity/sync_models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/backend/api/quivr_api/modules/sync/entity/sync_models.py b/backend/api/quivr_api/modules/sync/entity/sync_models.py index 7ba040acd1cd..f799b93f29a6 100644 --- a/backend/api/quivr_api/modules/sync/entity/sync_models.py +++ b/backend/api/quivr_api/modules/sync/entity/sync_models.py @@ -17,6 +17,7 @@ ) from sqlmodel import UUID as PGUUID +from quivr_api.modules.sync.dto.inputs import SyncStatus from quivr_api.modules.sync.dto.outputs import SyncProvider, SyncsOutput from quivr_api.modules.user.entity.user_identity import User @@ -62,6 +63,7 @@ class Sync(SQLModel, table=True): default=None, sa_column=Column("credentials", JSON) ) state: Dict[str, str] | None = Field(default=None, 
sa_column=Column("state", JSON)) + status: str = Field(default=SyncStatus.SYNCING) created_at: datetime | None = Field( default=None, sa_column=Column( From 8d729dee64a5596edf1d344eaef592080c1bb8fd Mon Sep 17 00:00:00 2001 From: aminediro Date: Mon, 30 Sep 2024 15:25:58 +0200 Subject: [PATCH 44/63] get files last synced at --- .../knowledge/repository/knowledges.py | 30 ++-- .../knowledge/service/knowledge_service.py | 7 + .../modules/knowledge/tests/conftest.py | 62 ++++++++ .../knowledge/tests/test_knowledge_entity.py | 60 +------- .../knowledge/tests/test_knowledge_service.py | 145 +++++++++++++++++- 5 files changed, 236 insertions(+), 68 deletions(-) diff --git a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py index f9c74b15b6a2..ccc550c43060 100644 --- a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py +++ b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py @@ -1,3 +1,4 @@ +from datetime import datetime, timedelta, timezone from typing import Any, List, Sequence from uuid import UUID @@ -5,7 +6,8 @@ from quivr_core.models import KnowledgeStatus from sqlalchemy.exc import IntegrityError, NoResultFound from sqlalchemy.orm import joinedload -from sqlmodel import and_, col, select, text +from sqlalchemy.sql.functions import random +from sqlmodel import and_, col, not_, select, text from sqlmodel.ext.asyncio.session import AsyncSession from quivr_api.logger import get_logger @@ -436,13 +438,23 @@ async def get_all_knowledge(self) -> Sequence[KnowledgeDB]: result = await self.session.exec(query) return result.all() - async def get_sync_knowledges_to_update(self,batch_size: int = 1) -> Sequence[KnowledgeDB]: - query = select(KnowledgeDB) - .where( - and_( - col(KnowledgeDB.sync_id).in_(brains_ids), - ) - ) + async def get_sync_knowledges_files_to_update( + self, timedelta_hour: int, batch_size: int + ) -> List[KnowledgeDB]: + time_delta = 
datetime.now(timezone.utc) - timedelta(hours=timedelta_hour) + query = ( + select(KnowledgeDB) + .where( + not_(KnowledgeDB.is_folder), + col(KnowledgeDB.sync_id).isnot(None), + col(KnowledgeDB.last_synced_at) < time_delta, + col(KnowledgeDB.brains).any(), + ) + # Oldest first + .order_by(col(KnowledgeDB.last_synced_at).asc(), random()) + .limit(batch_size) + ) + # Execute the query (assuming you have a session) result = await self.session.exec(query) - return result.all() + return list(result.unique().all()) diff --git a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py index 57dcc5b335d5..1b4ee7ac01be 100644 --- a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py @@ -350,3 +350,10 @@ async def unlink_knowledge_tree_brains( return await self.repository.unlink_knowledge_tree_brains( knowledge, brains_ids=brains_ids, user_id=user_id ) + + async def get_sync_knowledges_files_to_update( + self, timedelta_hour: int, batch_size: int = 1 + ) -> List[KnowledgeDB]: + return await self.repository.get_sync_knowledges_files_to_update( + timedelta_hour, batch_size + ) diff --git a/backend/api/quivr_api/modules/knowledge/tests/conftest.py b/backend/api/quivr_api/modules/knowledge/tests/conftest.py index dd16af4d1b89..e0a8423f0384 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/conftest.py +++ b/backend/api/quivr_api/modules/knowledge/tests/conftest.py @@ -1,8 +1,17 @@ from io import BufferedReader, FileIO +from uuid import uuid4 +import pytest_asyncio +from sqlmodel import select, text +from sqlmodel.ext.asyncio.session import AsyncSession + +from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB from 
quivr_api.modules.knowledge.repository.storage_interface import StorageInterface +from quivr_api.modules.sync.dto.outputs import SyncProvider +from quivr_api.modules.sync.entity.sync_models import Sync +from quivr_api.modules.user.entity.user_identity import User class ErrorStorage(StorageInterface): @@ -79,3 +88,56 @@ def clear_storage(self): async def download_file(self, knowledge: KnowledgeDB, **kwargs) -> bytes: storage_path = self.get_storage_path(knowledge) return self.storage[storage_path] + + +@pytest_asyncio.fixture(scope="function") +async def other_user(session: AsyncSession): + sql = text( + """ + INSERT INTO "auth"."users" ("instance_id", "id", "aud", "role", "email", "encrypted_password", "email_confirmed_at", "invited_at", "confirmation_token", "confirmation_sent_at", "recovery_token", "recovery_sent_at", "email_change_token_new", "email_change", "email_change_sent_at", "last_sign_in_at", "raw_app_meta_data", "raw_user_meta_data", "is_super_admin", "created_at", "updated_at", "phone", "phone_confirmed_at", "phone_change", "phone_change_token", "phone_change_sent_at", "email_change_token_current", "email_change_confirm_status", "banned_until", "reauthentication_token", "reauthentication_sent_at", "is_sso_user", "deleted_at") VALUES + ('00000000-0000-0000-0000-000000000000', :id , 'authenticated', 'authenticated', 'other@quivr.app', '$2a$10$vwKX0eMLlrOZvxQEA3Vl4e5V4/hOuxPjGYn9QK1yqeaZxa.42Uhze', '2024-01-22 22:27:00.166861+00', NULL, '', NULL, 'e91d41043ca2c83c3be5a6ee7a4abc8a4f4fb1afc0a8453c502af931', '2024-03-05 16:22:13.780421+00', '', '', NULL, '2024-03-30 23:21:12.077887+00', '{"provider": "email", "providers": ["email"]}', '{}', NULL, '2024-01-22 22:27:00.158026+00', '2024-04-01 17:40:15.332205+00', NULL, NULL, '', '', NULL, '', 0, NULL, '', NULL, false, NULL); + """ + ) + await session.execute(sql, params={"id": uuid4()}) + + other_user = ( + await session.exec(select(User).where(User.email == "other@quivr.app")) + ).one() + return other_user + 
+ +@pytest_asyncio.fixture(scope="function") +async def user(session): + user_1 = ( + await session.exec(select(User).where(User.email == "admin@quivr.app")) + ).one() + return user_1 + + +@pytest_asyncio.fixture(scope="function") +async def sync(session: AsyncSession, user: User) -> Sync: + assert user.id + sync = Sync( + name="test_sync", + email="test@test.com", + user_id=user.id, + credentials={"test": "test"}, + provider=SyncProvider.GOOGLE, + ) + + session.add(sync) + await session.commit() + await session.refresh(sync) + return sync + + +@pytest_asyncio.fixture(scope="function") +async def brain(session): + brain_1 = Brain( + name="test_brain", + description="this is a test brain", + brain_type=BrainType.integration, + ) + session.add(brain_1) + await session.commit() + return brain_1 diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py index 18a440be8118..857c111e9fae 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py @@ -1,74 +1,18 @@ from typing import List, Tuple -from uuid import uuid4 import pytest import pytest_asyncio from quivr_core.models import KnowledgeStatus -from sqlmodel import select, text +from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession -from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType +from quivr_api.modules.brain.entity.brain_entity import Brain from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB -from quivr_api.modules.sync.dto.outputs import SyncProvider -from quivr_api.modules.sync.entity.sync_models import Sync from quivr_api.modules.user.entity.user_identity import User TestData = Tuple[Brain, List[KnowledgeDB]] -@pytest_asyncio.fixture(scope="function") -async def other_user(session: AsyncSession): - sql = text( - """ - INSERT INTO "auth"."users" 
("instance_id", "id", "aud", "role", "email", "encrypted_password", "email_confirmed_at", "invited_at", "confirmation_token", "confirmation_sent_at", "recovery_token", "recovery_sent_at", "email_change_token_new", "email_change", "email_change_sent_at", "last_sign_in_at", "raw_app_meta_data", "raw_user_meta_data", "is_super_admin", "created_at", "updated_at", "phone", "phone_confirmed_at", "phone_change", "phone_change_token", "phone_change_sent_at", "email_change_token_current", "email_change_confirm_status", "banned_until", "reauthentication_token", "reauthentication_sent_at", "is_sso_user", "deleted_at") VALUES - ('00000000-0000-0000-0000-000000000000', :id , 'authenticated', 'authenticated', 'other@quivr.app', '$2a$10$vwKX0eMLlrOZvxQEA3Vl4e5V4/hOuxPjGYn9QK1yqeaZxa.42Uhze', '2024-01-22 22:27:00.166861+00', NULL, '', NULL, 'e91d41043ca2c83c3be5a6ee7a4abc8a4f4fb1afc0a8453c502af931', '2024-03-05 16:22:13.780421+00', '', '', NULL, '2024-03-30 23:21:12.077887+00', '{"provider": "email", "providers": ["email"]}', '{}', NULL, '2024-01-22 22:27:00.158026+00', '2024-04-01 17:40:15.332205+00', NULL, NULL, '', '', NULL, '', 0, NULL, '', NULL, false, NULL); - """ - ) - await session.execute(sql, params={"id": uuid4()}) - - other_user = ( - await session.exec(select(User).where(User.email == "other@quivr.app")) - ).one() - return other_user - - -@pytest_asyncio.fixture(scope="function") -async def user(session): - user_1 = ( - await session.exec(select(User).where(User.email == "admin@quivr.app")) - ).one() - return user_1 - - -@pytest_asyncio.fixture(scope="function") -async def sync(session: AsyncSession, user: User) -> Sync: - assert user.id - sync = Sync( - name="test_sync", - email="test@test.com", - user_id=user.id, - credentials={"test": "test"}, - provider=SyncProvider.GOOGLE, - ) - - session.add(sync) - await session.commit() - await session.refresh(sync) - return sync - - -@pytest_asyncio.fixture(scope="function") -async def brain(session): - brain_1 = Brain( - 
name="test_brain", - description="this is a test brain", - brain_type=BrainType.integration, - ) - session.add(brain_1) - await session.commit() - return brain_1 - - @pytest_asyncio.fixture(scope="function") async def folder(session, user): folder = KnowledgeDB( diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py index c764c92833f9..2a9700161c06 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py @@ -1,4 +1,5 @@ import os +from datetime import datetime, timedelta from io import BytesIO from typing import List, Tuple from uuid import uuid4 @@ -17,7 +18,7 @@ KnowledgeStatus, KnowledgeUpdate, ) -from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB +from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB, KnowledgeSource from quivr_api.modules.knowledge.entity.knowledge_brain import KnowledgeBrain from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository from quivr_api.modules.knowledge.service.knowledge_exceptions import ( @@ -27,6 +28,8 @@ ) from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService from quivr_api.modules.knowledge.tests.conftest import ErrorStorage, FakeStorage +from quivr_api.modules.sync.dto.outputs import SyncProvider +from quivr_api.modules.sync.entity.sync_models import Sync from quivr_api.modules.upload.service.upload_file import upload_file_storage from quivr_api.modules.user.entity.user_identity import User from quivr_api.modules.vector.entity.vector import Vector @@ -1121,3 +1124,143 @@ async def test_unlink_knowledge_brain( kms = await service.get_all_knowledge_in_brain(brain_id=brain_user3.brain_id) assert len(kms) == 1 assert kms[0].id == file.id + + +@pytest.mark.asyncio(loop_scope="session") +async def test_get_sync_km_files_to_update_date( + 
session: AsyncSession, user: User, brain_user: Brain, sync: Sync +): + assert user.id + assert brain_user.brain_id + + file1 = KnowledgeDB( + file_name="folder", + extension="", + status=KnowledgeStatus.PROCESSED, + source=SyncProvider.GOOGLE, + source_link="drive://file2", + file_size=4, + file_sha1="", + brains=[brain_user], + user_id=user.id, + sync_id=sync.id, + sync_file_id="file2", + last_synced_at=datetime.now() - timedelta(days=2), + ) + file2 = KnowledgeDB( + file_name="file_2", + extension=".txt", + status=KnowledgeStatus.PROCESSED, + source=SyncProvider.GOOGLE, + source_link="drive://file2", + file_size=10, + file_sha1=None, + brains=[brain_user], + user_id=user.id, + sync_id=sync.id, + sync_file_id="file2", + last_synced_at=datetime.now(), + ) + session.add(file2) + session.add(file1) + await session.commit() + await session.refresh(file1) + await session.refresh(file2) + + storage = FakeStorage() + repository = KnowledgeRepository(session) + service = KnowledgeService(repository, storage) + + kms = await service.get_sync_knowledges_files_to_update( + timedelta_hour=4, + batch_size=10, + ) + assert len(kms) == 1 + assert kms[0].id == file1.id + + +@pytest.mark.asyncio(loop_scope="session") +async def test_get_sync_km_files_to_update_brains( + session: AsyncSession, user: User, brain_user: Brain, sync: Sync +): + assert user.id + assert brain_user.brain_id + + file2 = KnowledgeDB( + file_name="file", + extension=".txt", + status=KnowledgeStatus.PROCESSED, + source=KnowledgeSource.LOCAL, + source_link="path", + file_size=4, + file_sha1="", + brains=[], + user_id=user.id, + ) + + file1 = KnowledgeDB( + file_name="file", + extension=".txt", + status=KnowledgeStatus.PROCESSED, + source=SyncProvider.GOOGLE, + source_link="drive://file2", + file_size=4, + file_sha1="", + brains=[], + user_id=user.id, + sync_id=sync.id, + sync_file_id="file2", + last_synced_at=datetime.now() - timedelta(days=2), + ) + session.add(file1) + session.add(file2) + await 
session.commit() + await session.refresh(file1) + await session.refresh(file2) + + storage = FakeStorage() + repository = KnowledgeRepository(session) + service = KnowledgeService(repository, storage) + + kms = await service.get_sync_knowledges_files_to_update( + timedelta_hour=4, + batch_size=10, + ) + assert len(kms) == 0 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_get_sync_km_files_to_update_file_only( + session: AsyncSession, user: User, brain_user: Brain, sync: Sync +): + assert user.id + assert brain_user.brain_id + + file1 = KnowledgeDB( + file_name="folder", + extension="", + status=KnowledgeStatus.PROCESSED, + source=SyncProvider.GOOGLE, + source_link="drive://file2", + file_size=4, + file_sha1="", + brains=[brain_user], + user_id=user.id, + sync_id=sync.id, + sync_file_id="file2", + is_folder=True, + last_synced_at=datetime.now() - timedelta(days=2), + ) + session.add(file1) + await session.commit() + await session.refresh(file1) + + storage = FakeStorage() + repository = KnowledgeRepository(session) + service = KnowledgeService(repository, storage) + + kms = await service.get_sync_knowledges_files_to_update( + timedelta_hour=4, + batch_size=10, + ) + assert len(kms) == 0 From 7fbfab81ba34271cfe728d1fd0acd597885e84b1 Mon Sep 17 00:00:00 2001 From: aminediro Date: Mon, 30 Sep 2024 15:57:50 +0200 Subject: [PATCH 45/63] sync knowledge service --- .../quivr_api/modules/knowledge/dto/inputs.py | 2 ++ .../knowledge/repository/knowledges.py | 29 +++++++++++++++++-- .../knowledge/service/knowledge_service.py | 19 ++++++------ .../worker/quivr_worker/process/processor.py | 13 +++++++-- 4 files changed, 49 insertions(+), 14 deletions(-) diff --git a/backend/api/quivr_api/modules/knowledge/dto/inputs.py b/backend/api/quivr_api/modules/knowledge/dto/inputs.py index 0290e057ae0c..8801ab48a1ef 100644 --- a/backend/api/quivr_api/modules/knowledge/dto/inputs.py +++ b/backend/api/quivr_api/modules/knowledge/dto/inputs.py @@ -1,3 +1,4 @@ +from datetime 
import datetime from typing import Dict, List, Optional from uuid import UUID @@ -45,6 +46,7 @@ class KnowledgeUpdate(BaseModel): source: Optional[str] = None source_link: Optional[str] = None metadata: Optional[Dict[str, str]] = None + last_synced_at: Optional[datetime] = None class LinkKnowledgeBrain(BaseModel): diff --git a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py index ccc550c43060..aed61c041f91 100644 --- a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py +++ b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py @@ -438,8 +438,33 @@ async def get_all_knowledge(self) -> Sequence[KnowledgeDB]: result = await self.session.exec(query) return result.all() - async def get_sync_knowledges_files_to_update( - self, timedelta_hour: int, batch_size: int + async def get_outdated_sync_files( + self, + timedelta_hour: int, + batch_size: int, + ) -> List[KnowledgeDB]: + time_delta = datetime.now(timezone.utc) - timedelta(hours=timedelta_hour) + query = ( + select(KnowledgeDB) + .where( + not_(KnowledgeDB.is_folder), + col(KnowledgeDB.sync_id).isnot(None), + col(KnowledgeDB.last_synced_at) < time_delta, + col(KnowledgeDB.brains).any(), + ) + # Oldest first + .order_by(col(KnowledgeDB.last_synced_at).asc(), random()) + .limit(batch_size) + ) + + # Execute the query (assuming you have a session) + result = await self.session.exec(query) + return list(result.unique().all()) + + async def get_outdated_sync_folders( + self, + timedelta_hour: int, + batch_size: int, ) -> List[KnowledgeDB]: time_delta = datetime.now(timezone.utc) - timedelta(hours=timedelta_hour) query = ( diff --git a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py index 1b4ee7ac01be..87fe63073aa4 100644 --- a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py +++ 
b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py @@ -56,17 +56,20 @@ async def get_knowledge_sync(self, sync_id: int) -> KnowledgeDTO: async def create_or_link_sync_knowledge( self, - syncfile_to_knowledge: dict[str, KnowledgeDB], + syncfile_id_to_knowledge: dict[str, KnowledgeDB], parent_knowledge: KnowledgeDB, sync_file: SyncFile, ): - existing_km = syncfile_to_knowledge.get(sync_file.id) + existing_km = syncfile_id_to_knowledge.get(sync_file.id) if existing_km is not None: # NOTE: function called in worker processor - # The parent_knowledge was just added (we are processing it) - # This implies that we could have sync children that were processed before - # IF SyncKnowledge already exists => It's already processed in some other brain - # => Link it to the parent and add its brains and move on if it is PROCESSED ELSE Reprocess the file + # The parent_knowledge was just added to db (we are processing it) + # This implies that we could have sync children files and folders that were processed before + # If SyncKnowledge already exists + # IF STATUS == PROCESSED: + # => It's already processed in some other brain ! 
+ # => Link it to the parent and update its brains to the correct ones + # ELSE Reprocess the file km_brains = {km_brain.brain_id for km_brain in existing_km.brains} for brain in filter( lambda b: b.brain_id not in km_brains, @@ -354,6 +357,4 @@ async def unlink_knowledge_tree_brains( async def get_sync_knowledges_files_to_update( self, timedelta_hour: int, batch_size: int = 1 ) -> List[KnowledgeDB]: - return await self.repository.get_sync_knowledges_files_to_update( - timedelta_hour, batch_size - ) + return await self.repository.get_outdated_sync_files(timedelta_hour, batch_size) diff --git a/backend/worker/quivr_worker/process/processor.py b/backend/worker/quivr_worker/process/processor.py index 342afe84c947..044e6b314cab 100644 --- a/backend/worker/quivr_worker/process/processor.py +++ b/backend/worker/quivr_worker/process/processor.py @@ -1,4 +1,5 @@ import asyncio +from datetime import datetime, timezone from pathlib import Path from typing import AsyncGenerator, List, Optional, Tuple from uuid import UUID @@ -28,7 +29,7 @@ class KnowledgeProcessor: def __init__(self, services: ProcessorServices): self.services = services - async def fetch_sync_knowledge( + async def fetch_db_knowledges_and_syncprovider( self, sync_id: int, user_id: UUID, @@ -169,7 +170,10 @@ async def _yield_syncs( yield f # Fetch children - syncfile_to_knowledge, sync_files = await self.fetch_sync_knowledge( + ( + syncfile_to_knowledge, + sync_files, + ) = await self.fetch_db_knowledges_and_syncprovider( sync_id=parent_knowledge.sync_id, user_id=parent_knowledge.user_id, folder_id=parent_knowledge.sync_file_id, @@ -180,7 +184,7 @@ async def _yield_syncs( for sync_file in sync_files: file_knowledge = ( await self.services.knowledge_service.create_or_link_sync_knowledge( - syncfile_to_knowledge=syncfile_to_knowledge, + syncfile_id_to_knowledge=syncfile_to_knowledge, parent_knowledge=parent_knowledge, sync_file=sync_file, ) @@ -212,11 +216,14 @@ async def process_knowledge(self, knowledge_id: 
UUID): chunks=chunks, vector_service=self.services.vector_service, ) + + last_synced_at = datetime.now(timezone.utc) await self.services.knowledge_service.update_knowledge( knowledge, KnowledgeUpdate( status=KnowledgeStatus.PROCESSED, file_sha1=knowledge.file_sha1, + last_synced_at=last_synced_at if knowledge.sync_id else None, ), ) From f8db1126291abb56f279d89e8b11d47443bb5f51 Mon Sep 17 00:00:00 2001 From: aminediro Date: Mon, 30 Sep 2024 18:52:34 +0200 Subject: [PATCH 46/63] sync utils --- .../knowledge/controller/knowledge_routes.py | 1 + .../knowledge/repository/knowledges.py | 40 +- .../knowledge/service/knowledge_service.py | 14 +- .../knowledge/tests/test_knowledge_service.py | 83 +- .../modules/sync/entity/sync_models.py | 6 + .../api/quivr_api/modules/sync/utils/sync.py | 22 +- .../quivr_api/modules/sync/utils/syncutils.py | 742 +++++++++--------- .../worker/quivr_worker/process/processor.py | 99 ++- .../worker/quivr_worker/syncs/update_syncs.py | 6 + backend/worker/quivr_worker/utils/services.py | 19 +- .../worker/tests/test_process_file_task.py | 2 + 11 files changed, 582 insertions(+), 452 deletions(-) diff --git a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py index 137532ce79a0..43675a6c41da 100644 --- a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py +++ b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py @@ -290,6 +290,7 @@ async def link_knowledge_to_brain( knowledge_to_add=AddKnowledge(**knowledge_dto.model_dump()), upload_file=None, ) + # TODO (@AmineDiro): Check if tree is necessary or updating this knowledge suffice linked_kms = await knowledge_service.link_knowledge_tree_brains( knowledge, brains_ids=brains_ids, user_id=current_user.id ) diff --git a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py index 
aed61c041f91..4f05af84a07e 100644 --- a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py +++ b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py @@ -1,4 +1,4 @@ -from datetime import datetime, timedelta, timezone +from datetime import datetime from typing import Any, List, Sequence from uuid import UUID @@ -7,7 +7,7 @@ from sqlalchemy.exc import IntegrityError, NoResultFound from sqlalchemy.orm import joinedload from sqlalchemy.sql.functions import random -from sqlmodel import and_, col, not_, select, text +from sqlmodel import and_, col, select, text from sqlmodel.ext.asyncio.session import AsyncSession from quivr_api.logger import get_logger @@ -26,6 +26,7 @@ KnowledgeNotFoundException, KnowledgeUpdateError, ) +from quivr_api.modules.sync.entity.sync_models import SyncType logger = get_logger(__name__) @@ -438,41 +439,20 @@ async def get_all_knowledge(self) -> Sequence[KnowledgeDB]: result = await self.session.exec(query) return result.all() - async def get_outdated_sync_files( + async def get_outdated_syncs( self, - timedelta_hour: int, + limit_time: datetime, batch_size: int, + km_sync_type: SyncType, ) -> List[KnowledgeDB]: - time_delta = datetime.now(timezone.utc) - timedelta(hours=timedelta_hour) + is_folder_check = km_sync_type == SyncType.FOLDER query = ( select(KnowledgeDB) .where( - not_(KnowledgeDB.is_folder), + KnowledgeDB.is_folder == is_folder_check, col(KnowledgeDB.sync_id).isnot(None), - col(KnowledgeDB.last_synced_at) < time_delta, - col(KnowledgeDB.brains).any(), - ) - # Oldest first - .order_by(col(KnowledgeDB.last_synced_at).asc(), random()) - .limit(batch_size) - ) - - # Execute the query (assuming you have a session) - result = await self.session.exec(query) - return list(result.unique().all()) - - async def get_outdated_sync_folders( - self, - timedelta_hour: int, - batch_size: int, - ) -> List[KnowledgeDB]: - time_delta = datetime.now(timezone.utc) - timedelta(hours=timedelta_hour) - query = ( - 
select(KnowledgeDB) - .where( - not_(KnowledgeDB.is_folder), - col(KnowledgeDB.sync_id).isnot(None), - col(KnowledgeDB.last_synced_at) < time_delta, + col(KnowledgeDB.sync_file_id).isnot(None), + col(KnowledgeDB.last_synced_at) < limit_time, col(KnowledgeDB.brains).any(), ) # Oldest first diff --git a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py index 87fe63073aa4..e3f2085602dc 100644 --- a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py @@ -1,5 +1,6 @@ import asyncio import io +from datetime import datetime from typing import Any, List from uuid import UUID @@ -31,7 +32,7 @@ KnowledgeForbiddenAccess, UploadError, ) -from quivr_api.modules.sync.entity.sync_models import SyncFile +from quivr_api.modules.sync.entity.sync_models import SyncFile, SyncType from quivr_api.modules.upload.service.upload_file import check_file_exists logger = get_logger(__name__) @@ -354,7 +355,12 @@ async def unlink_knowledge_tree_brains( knowledge, brains_ids=brains_ids, user_id=user_id ) - async def get_sync_knowledges_files_to_update( - self, timedelta_hour: int, batch_size: int = 1 + async def get_outdated_syncs( + self, + limit_time: datetime, + km_sync_type: SyncType, + batch_size: int = 1, ) -> List[KnowledgeDB]: - return await self.repository.get_outdated_sync_files(timedelta_hour, batch_size) + return await self.repository.get_outdated_syncs( + limit_time=limit_time, batch_size=batch_size, km_sync_type=km_sync_type + ) diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py index 2a9700161c06..3e9335a8630b 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py @@ -1,5 +1,5 @@ import os 
-from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from io import BytesIO from typing import List, Tuple from uuid import uuid4 @@ -29,7 +29,7 @@ from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService from quivr_api.modules.knowledge.tests.conftest import ErrorStorage, FakeStorage from quivr_api.modules.sync.dto.outputs import SyncProvider -from quivr_api.modules.sync.entity.sync_models import Sync +from quivr_api.modules.sync.entity.sync_models import Sync, SyncType from quivr_api.modules.upload.service.upload_file import upload_file_storage from quivr_api.modules.user.entity.user_identity import User from quivr_api.modules.vector.entity.vector import Vector @@ -1127,7 +1127,7 @@ async def test_unlink_knowledge_brain( @pytest.mark.asyncio(loop_scope="session") -async def test_get_sync_km_files_to_update_date( +async def test_get_outdated_sync_update_date( session: AsyncSession, user: User, brain_user: Brain, sync: Sync ): assert user.id @@ -1171,16 +1171,16 @@ async def test_get_sync_km_files_to_update_date( repository = KnowledgeRepository(session) service = KnowledgeService(repository, storage) - kms = await service.get_sync_knowledges_files_to_update( - timedelta_hour=4, - batch_size=10, + last_time = datetime.now(timezone.utc) - timedelta(hours=4) + kms = await service.get_outdated_syncs( + limit_time=last_time, batch_size=10, km_sync_type=SyncType.FILE ) assert len(kms) == 1 assert kms[0].id == file1.id @pytest.mark.asyncio(loop_scope="session") -async def test_get_sync_km_files_to_update_brains( +async def test_get_outdated_sync_file_only_brains( session: AsyncSession, user: User, brain_user: Brain, sync: Sync ): assert user.id @@ -1222,15 +1222,15 @@ async def test_get_sync_km_files_to_update_brains( repository = KnowledgeRepository(session) service = KnowledgeService(repository, storage) - kms = await service.get_sync_knowledges_files_to_update( - timedelta_hour=4, - batch_size=10, + 
last_time = datetime.now(timezone.utc) - timedelta(hours=4) + kms = await service.get_outdated_syncs( + limit_time=last_time, batch_size=10, km_sync_type=SyncType.FILE ) assert len(kms) == 0 @pytest.mark.asyncio(loop_scope="session") -async def test_get_sync_km_files_to_update_file_only( +async def test_get_outdated_sync_file_only( session: AsyncSession, user: User, brain_user: Brain, sync: Sync ): assert user.id @@ -1259,8 +1259,63 @@ async def test_get_sync_km_files_to_update_file_only( repository = KnowledgeRepository(session) service = KnowledgeService(repository, storage) - kms = await service.get_sync_knowledges_files_to_update( - timedelta_hour=4, - batch_size=10, + last_time = datetime.now(timezone.utc) - timedelta(hours=4) + kms = await service.get_outdated_syncs( + limit_time=last_time, batch_size=10, km_sync_type=SyncType.FILE ) assert len(kms) == 0 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_get_outdated_folders_sync( + session: AsyncSession, user: User, brain_user: Brain, sync: Sync +): + assert user.id + assert brain_user.brain_id + + folder = KnowledgeDB( + file_name="folder", + extension="", + status=KnowledgeStatus.PROCESSED, + source=SyncProvider.GOOGLE, + source_link="drive://file2", + file_size=0, + file_sha1="", + brains=[brain_user], + user_id=user.id, + sync_id=sync.id, + sync_file_id="folder1", + is_folder=True, + last_synced_at=datetime.now() - timedelta(days=2), + ) + file = KnowledgeDB( + file_name="file", + extension="", + status=KnowledgeStatus.PROCESSED, + source=SyncProvider.GOOGLE, + source_link="drive://file2", + file_size=4, + file_sha1="", + brains=[brain_user], + user_id=user.id, + sync_id=sync.id, + sync_file_id="file", + is_folder=False, + last_synced_at=datetime.now() - timedelta(days=2), + parent=folder, + ) + session.add(folder) + session.add(file) + await session.commit() + await session.refresh(folder) + + storage = FakeStorage() + repository = KnowledgeRepository(session) + service = 
KnowledgeService(repository, storage) + + last_time = datetime.now(timezone.utc) - timedelta(hours=4) + kms = await service.get_outdated_syncs( + limit_time=last_time, batch_size=10, km_sync_type=SyncType.FOLDER + ) + assert len(kms) == 1 + assert kms[0].id == folder.id diff --git a/backend/api/quivr_api/modules/sync/entity/sync_models.py b/backend/api/quivr_api/modules/sync/entity/sync_models.py index c0e3540c6544..036b84669ffb 100644 --- a/backend/api/quivr_api/modules/sync/entity/sync_models.py +++ b/backend/api/quivr_api/modules/sync/entity/sync_models.py @@ -2,6 +2,7 @@ import io from dataclasses import dataclass from datetime import datetime +from enum import Enum, auto from typing import Dict, List, Optional from uuid import UUID @@ -50,6 +51,11 @@ class SyncFile(BaseModel): type: Optional[str] = None +class SyncType(Enum): + FOLDER = auto() + FILE = auto() + + class Sync(SQLModel, table=True): __tablename__ = "syncs" # type: ignore diff --git a/backend/api/quivr_api/modules/sync/utils/sync.py b/backend/api/quivr_api/modules/sync/utils/sync.py index df5db5d5905f..e71f6d7d3d5b 100644 --- a/backend/api/quivr_api/modules/sync/utils/sync.py +++ b/backend/api/quivr_api/modules/sync/utils/sync.py @@ -3,7 +3,7 @@ import os import time from abc import ABC, abstractmethod -from datetime import datetime +from datetime import datetime, timezone from io import BytesIO from typing import Any, Dict, List, Optional, Union @@ -195,7 +195,9 @@ def get_files_by_id(self, credentials: Dict, file_ids: List[str]) -> List[SyncFi is_folder=( result["mimeType"] == "application/vnd.google-apps.folder" ), - last_modified_at=result["modifiedTime"], + last_modified_at=datetime.strptime( + result["modifiedTime"], self.datetime_format + ).replace(tzinfo=timezone.utc), extension=result["mimeType"], web_view_link=result["webViewLink"], size=result.get("size", None), @@ -208,6 +210,8 @@ def get_files_by_id(self, credentials: Dict, file_ids: List[str]) -> List[SyncFi return files except 
HTTPError as error: + if error.response.status_code == 404: + raise FileNotFoundError logger.error( "An error occurred while retrieving Google Drive files: %s", error ) @@ -271,7 +275,9 @@ def get_files( is_folder=( item["mimeType"] == "application/vnd.google-apps.folder" ), - last_modified_at=item["modifiedTime"], + last_modified_at=datetime.strptime( + item["modifiedTime"], self.datetime_format + ).replace(tzinfo=timezone.utc), extension=item["mimeType"], web_view_link=item["webViewLink"], size=item.get("size", None), @@ -446,8 +452,8 @@ def fetch_files(endpoint, headers, max_retries=1): ), is_folder="folder" in item or not site_folder_id, last_modified_at=datetime.strptime( - item.get("lastModifiedDateTime"), self.datetime_format - ), + item["lastModifiedDateTime"], self.datetime_format + ).replace(tzinfo=timezone.utc), extension=item.get("file", {}).get("mimeType", "folder"), web_view_link=item.get("webUrl"), size=item.get("size", None), @@ -528,7 +534,7 @@ def get_files_by_id(self, credentials: dict, file_ids: List[str]) -> List[SyncFi is_folder="folder" in result, last_modified_at=datetime.strptime( result.get("lastModifiedDateTime"), self.datetime_format - ), + ).replace(tzinfo=timezone.utc), extension=result.get("file", {}).get("mimeType", "folder"), web_view_link=result.get("webUrl"), size=result.get("size", None), @@ -820,7 +826,7 @@ async def aget_files( name=page.name, id=str(page.notion_id), is_folder=await self.notion_service.is_folder_page(page.notion_id), - last_modified_at=str(page.last_modified), + last_modified_at=page.last_modified, extension=page.mime_type, web_view_link=page.web_view_link, icon=page.icon, @@ -864,7 +870,7 @@ async def aget_files_by_id( name=page.name, id=str(page.notion_id), is_folder=await self.notion_service.is_folder_page(page.notion_id), - last_modified_at=str(page.last_modified), + last_modified_at=page.last_modified, extension=page.mime_type, web_view_link=page.web_view_link, icon=page.icon, diff --git
a/backend/api/quivr_api/modules/sync/utils/syncutils.py b/backend/api/quivr_api/modules/sync/utils/syncutils.py index 1edaaa26dfc4..33adb5603962 100644 --- a/backend/api/quivr_api/modules/sync/utils/syncutils.py +++ b/backend/api/quivr_api/modules/sync/utils/syncutils.py @@ -30,374 +30,374 @@ async def fetch_sync_knowledge( return await asyncio.gather(*[map_knowledges_task, sync_files_task]) # type: ignore # noqa: F821 -# # NOTE: we are filtering based on file path names in sync ! -# def filter_on_supported_files( -# files: list[SyncFile], existing_files: dict[str, DBSyncFile] -# ) -> list[Tuple[SyncFile, DBSyncFile | None]]: -# res = [] -# for new_file in files: -# prev_file = existing_files.get(new_file.name, None) -# if (prev_file and prev_file.supported) or prev_file is None: -# res.append((new_file, prev_file)) -# return res - - -# def should_download_file( -# file: SyncFile, -# last_updated_sync_active: datetime | None, -# provider_name: str, -# datetime_format: str, -# ) -> bool: -# file_last_modified_utc = datetime.strptime( -# file.last_modified_at, datetime_format -# ).replace(tzinfo=timezone.utc) - -# should_download = ( -# last_updated_sync_active is None -# or file_last_modified_utc > last_updated_sync_active -# ) - -# # TODO: Handle notion database -# if provider_name == "notion": -# should_download &= file.extension != "db" -# else: -# should_download &= not file.is_folder - -# return should_download - - -# class SyncUtils: -# def __init__( -# self, -# # sync_user_service: ISyncUserService, -# # sync_active_service: ISyncService, -# # sync_files_repo: SyncFileInterface, -# sync_cloud: BaseSync, -# knowledge_service: KnowledgeService, -# notification_service: NotificationService, -# brain_vectors: BrainsVectors, -# ) -> None: -# self.sync_user_service = sync_user_service -# self.sync_active_service = sync_active_service -# self.sync_files_repo = sync_files_repo -# self.knowledge_service = knowledge_service -# self.sync_cloud = sync_cloud -# 
self.notification_service = notification_service -# self.brain_vectors = brain_vectors - -# # TODO: This modifies the file, we should treat it as such -# def create_sync_bulk_notification( -# self, files: list[SyncFile], current_user: UUID, brain_id: UUID, bulk_id: UUID -# ) -> list[SyncFile]: -# res = [] -# # TODO: bulk insert in batch -# for file in files: -# upload_notification = self.notification_service.add_notification( -# CreateNotification( -# user_id=current_user, -# bulk_id=bulk_id, -# status=NotificationsStatusEnum.INFO, -# title=file.name, -# category="sync", -# brain_id=str(brain_id), -# ) -# ) -# file.notification_id = upload_notification.id -# res.append(file) -# return res - -# async def download_file( -# self, file: SyncFile, credentials: dict[str, Any] -# ) -> DownloadedSyncFile: -# logger.info(f"Downloading {file} using {self.sync_cloud}") -# file_response = await self.sync_cloud.adownload_file(credentials, file) -# logger.debug(f"Fetch sync file response: {file_response}") -# file_name = str(file_response["file_name"]) -# raw_data = file_response["content"] -# file_data = ( -# io.BufferedReader(raw_data) # type: ignore -# if isinstance(raw_data, io.BytesIO) -# else io.BufferedReader(raw_data.encode("utf-8")) # type: ignore -# ) -# extension = os.path.splitext(file_name)[-1].lower() -# dfile = DownloadedSyncFile( -# file_name=file_name, -# file_data=file_data, -# extension=extension, -# ) -# logger.debug(f"Successfully downloaded sync file : {dfile}") -# return dfile - -# # TODO: REDO THIS MESS !!!! 
-# # REMOVE ALL SYNC TABLES and start from scratch - -# async def process_sync_file( -# self, -# file: SyncFile, -# previous_file: DBSyncFile | None, -# current_user: SyncsUser, -# sync_active: SyncsActive, -# ): -# logger.info("Processing file: %s", file.name) -# brain_id = sync_active.brain_id -# source, source_link = self.sync_cloud.name, file.web_view_link -# downloaded_file = await self.download_file(file, current_user.credentials) -# storage_path = f"{brain_id}/{downloaded_file.file_name}" -# exists_in_storage = check_file_exists(str(brain_id), file.name) - -# if downloaded_file.extension not in [ -# ".pdf", -# ".txt", -# ".md", -# ".csv", -# ".docx", -# ".xlsx", -# ".pptx", -# ".doc", -# ]: -# raise ValueError(f"Incompatible file extension for {downloaded_file}") - -# response = await upload_file_storage( -# downloaded_file.file_data, -# storage_path, -# upsert=exists_in_storage, -# ) -# assert response, f"Error uploading {downloaded_file} to {storage_path}" -# self.notification_service.update_notification_by_id( -# file.notification_id, -# NotificationUpdatableProperties( -# status=NotificationsStatusEnum.SUCCESS, -# description="File downloaded successfully", -# ), -# ) -# # TODO : why knowledge + syncfile, drop syncfile ... 
-# # FIXME : Simplify this logic in KMS plzzz -# sync_file_db = self.sync_files_repo.update_or_create_sync_file( -# file=file, -# previous_file=previous_file, -# sync_active=sync_active, -# supported=True, -# ) -# knowledge = await self.knowledge_service.update_or_create_knowledge_sync( -# brain_id=brain_id, -# file=file, -# new_sync_file=sync_file_db, -# prev_sync_file=previous_file, -# downloaded_file=downloaded_file, -# source=source, -# source_link=source_link, -# user_id=current_user.user_id, -# ) - -# # Send file for processing -# celery.send_task( -# "process_file_task", -# kwargs={ -# "brain_id": brain_id, -# "knowledge_id": knowledge.id, -# "file_name": storage_path, -# "file_original_name": file.name, -# "source": source, -# "source_link": source_link, -# "notification_id": file.notification_id, -# }, -# ) -# return file - -# async def process_sync_files( -# self, -# files: List[SyncFile], -# current_user: SyncsUser, -# sync_active: SyncsActive, -# ): -# logger.info(f"Processing {len(files)} for sync_active: {sync_active.id}") -# current_user.credentials = self.sync_cloud.check_and_refresh_access_token( -# current_user.credentials -# ) - -# bulk_id = uuid4() -# downloaded_files = [] -# list_existing_files = self.sync_files_repo.get_sync_files(sync_active.id) -# existing_files = {f.path: f for f in list_existing_files} - -# supported_files = filter_on_supported_files(files, existing_files) - -# files = self.create_sync_bulk_notification( -# files, current_user.user_id, sync_active.brain_id, bulk_id -# ) - -# for file, prev_file in supported_files: -# try: -# result = await self.process_sync_file( -# file=file, -# previous_file=prev_file, -# current_user=current_user, -# sync_active=sync_active, -# ) -# if result is not None: -# downloaded_files.append(result) - -# self.notification_service.update_notification_by_id( -# file.notification_id, -# NotificationUpdatableProperties( -# status=NotificationsStatusEnum.SUCCESS, -# description="File downloaded 
successfully", -# ), -# ) - -# except Exception as e: -# logger.error( -# "An error occurred while syncing %s files: %s", -# self.sync_cloud.name, -# e, -# ) -# # TODO: this process_sync_file could fail for a LOT of reason redo this logic -# # File isn't supported so we set it as so ? -# self.sync_files_repo.update_or_create_sync_file( -# file=file, -# sync_active=sync_active, -# previous_file=prev_file, -# supported=False, -# ) -# self.notification_service.update_notification_by_id( -# file.notification_id, -# NotificationUpdatableProperties( -# status=NotificationsStatusEnum.ERROR, -# description="Error downloading file", -# ), -# ) - -# return {"downloaded_files": downloaded_files} - -# async def get_files_to_download( -# self, sync_active: SyncsActive, user_sync: SyncsUser -# ) -> list[SyncFile]: -# # Get the folder id from the settings from sync_active -# folders = sync_active.settings.get("folders", []) -# files_ids = sync_active.settings.get("files", []) - -# files = await self.get_syncfiles_from_ids( -# user_sync.credentials, -# files_ids=files_ids, -# folder_ids=folders, -# sync_user_id=user_sync.id, -# ) - -# logger.debug(f"original files to download for {sync_active.id} : {files}") - -# last_synced_time = ( -# datetime.fromisoformat(sync_active.last_synced).astimezone(timezone.utc) -# if sync_active.last_synced -# else None -# ) - -# files_ids = [ -# file -# for file in files -# if should_download_file( -# file=file, -# last_updated_sync_active=last_synced_time, -# provider_name=self.sync_cloud.lower_name, -# datetime_format=self.sync_cloud.datetime_format, -# ) -# ] - -# logger.debug(f"filter files to download for {sync_active} : {files_ids}") -# return files_ids - -# async def get_syncfiles_from_ids( -# self, -# credentials: dict[str, Any], -# files_ids: list[str], -# folder_ids: list[str], -# sync_user_id: int, -# ) -> list[SyncFile]: -# files = [] -# if self.sync_cloud.lower_name == "notion": -# files_ids += folder_ids - -# for folder_id in 
folder_ids: -# logger.debug( -# f"Recursively getting file_ids from {self.sync_cloud.name}. folder_id={folder_id}" -# ) -# files.extend( -# await self.sync_cloud.aget_files( -# credentials=credentials, -# sync_user_id=sync_user_id, -# folder_id=folder_id, -# recursive=True, -# ) -# ) -# if len(files_ids) > 0: -# files.extend( -# await self.sync_cloud.aget_files_by_id( -# credentials=credentials, -# file_ids=files_ids, -# ) -# ) -# return files - -# async def direct_sync( -# self, -# sync_active: SyncsActive, -# sync_user: SyncsUser, -# files_ids: list[str], -# folder_ids: list[str], -# ): -# files = await self.get_syncfiles_from_ids( -# sync_user.credentials, files_ids, folder_ids -# ) -# processed_files = await self.process_sync_files( -# files=files, -# current_user=sync_user, -# sync_active=sync_active, -# ) - -# # Update the last_synced timestamp -# self.sync_active_service.update_sync_active( -# sync_active.id, -# SyncsActiveUpdateInput( -# last_synced=datetime.now().astimezone().isoformat(), force_sync=False -# ), -# ) -# logger.info( -# f"{self.sync_cloud.lower_name} sync completed for sync_active: {sync_active.id}. Synced all {len(processed_files)} files.", -# ) -# return processed_files - -# async def sync( -# self, -# sync_active: SyncsActive, -# user_sync: SyncsUser, -# ): -# """ -# Check if the Specific sync has not been synced and download the folders and files based on the settings. - -# Args: -# sync_active_id (int): The ID of the active sync. -# user_id (str): The user ID associated with the active sync. 
-# """ -# logger.info( -# "Starting %s sync for sync_active: %s", -# self.sync_cloud.lower_name, -# sync_active, -# ) - -# files_to_download = await self.get_files_to_download(sync_active, user_sync) -# processed_files = await self.process_sync_files( -# files=files_to_download, -# current_user=user_sync, -# sync_active=sync_active, -# ) - -# # Update the last_synced timestamp -# self.sync_active_service.update_sync_active( -# sync_active.id, -# SyncsActiveUpdateInput( -# last_synced=datetime.now().astimezone().isoformat(), force_sync=False -# ), -# ) -# logger.info( -# f"{self.sync_cloud.lower_name} sync completed for sync_active: {sync_active.id}. Synced all {len(processed_files)} files.", -# ) -# return processed_files +# NOTE: we are filtering based on file path names in sync ! +def filter_on_supported_files( + files: list[SyncFile], existing_files: dict[str, DBSyncFile] +) -> list[Tuple[SyncFile, DBSyncFile | None]]: + res = [] + for new_file in files: + prev_file = existing_files.get(new_file.name, None) + if (prev_file and prev_file.supported) or prev_file is None: + res.append((new_file, prev_file)) + return res + + +def should_download_file( + file: SyncFile, + last_updated_sync_active: datetime | None, + provider_name: str, + datetime_format: str, +) -> bool: + file_last_modified_utc = datetime.strptime( + file.last_modified_at, datetime_format + ).replace(tzinfo=timezone.utc) + + should_download = ( + last_updated_sync_active is None + or file_last_modified_utc > last_updated_sync_active + ) + + # TODO: Handle notion database + if provider_name == "notion": + should_download &= file.extension != "db" + else: + should_download &= not file.is_folder + + return should_download + + +class SyncUtils: + def __init__( + self, + # sync_user_service: ISyncUserService, + # sync_active_service: ISyncService, + # sync_files_repo: SyncFileInterface, + sync_cloud: BaseSync, + knowledge_service: KnowledgeService, + notification_service: NotificationService, + 
brain_vectors: BrainsVectors, + ) -> None: + self.sync_user_service = sync_user_service + self.sync_active_service = sync_active_service + self.sync_files_repo = sync_files_repo + self.knowledge_service = knowledge_service + self.sync_cloud = sync_cloud + self.notification_service = notification_service + self.brain_vectors = brain_vectors + + # TODO: This modifies the file, we should treat it as such + def create_sync_bulk_notification( + self, files: list[SyncFile], current_user: UUID, brain_id: UUID, bulk_id: UUID + ) -> list[SyncFile]: + res = [] + # TODO: bulk insert in batch + for file in files: + upload_notification = self.notification_service.add_notification( + CreateNotification( + user_id=current_user, + bulk_id=bulk_id, + status=NotificationsStatusEnum.INFO, + title=file.name, + category="sync", + brain_id=str(brain_id), + ) + ) + file.notification_id = upload_notification.id + res.append(file) + return res + + async def download_file( + self, file: SyncFile, credentials: dict[str, Any] + ) -> DownloadedSyncFile: + logger.info(f"Downloading {file} using {self.sync_cloud}") + file_response = await self.sync_cloud.adownload_file(credentials, file) + logger.debug(f"Fetch sync file response: {file_response}") + file_name = str(file_response["file_name"]) + raw_data = file_response["content"] + file_data = ( + io.BufferedReader(raw_data) # type: ignore + if isinstance(raw_data, io.BytesIO) + else io.BufferedReader(raw_data.encode("utf-8")) # type: ignore + ) + extension = os.path.splitext(file_name)[-1].lower() + dfile = DownloadedSyncFile( + file_name=file_name, + file_data=file_data, + extension=extension, + ) + logger.debug(f"Successfully downloaded sync file : {dfile}") + return dfile + + # TODO: REDO THIS MESS !!!! 
+ # REMOVE ALL SYNC TABLES and start from scratch + + async def process_sync_file( + self, + file: SyncFile, + previous_file: DBSyncFile | None, + current_user: SyncsUser, + sync_active: SyncsActive, + ): + logger.info("Processing file: %s", file.name) + brain_id = sync_active.brain_id + source, source_link = self.sync_cloud.name, file.web_view_link + downloaded_file = await self.download_file(file, current_user.credentials) + storage_path = f"{brain_id}/{downloaded_file.file_name}" + exists_in_storage = check_file_exists(str(brain_id), file.name) + + if downloaded_file.extension not in [ + ".pdf", + ".txt", + ".md", + ".csv", + ".docx", + ".xlsx", + ".pptx", + ".doc", + ]: + raise ValueError(f"Incompatible file extension for {downloaded_file}") + + response = await upload_file_storage( + downloaded_file.file_data, + storage_path, + upsert=exists_in_storage, + ) + assert response, f"Error uploading {downloaded_file} to {storage_path}" + self.notification_service.update_notification_by_id( + file.notification_id, + NotificationUpdatableProperties( + status=NotificationsStatusEnum.SUCCESS, + description="File downloaded successfully", + ), + ) + # TODO : why knowledge + syncfile, drop syncfile ... 
+ # FIXME : Simplify this logic in KMS plzzz + sync_file_db = self.sync_files_repo.update_or_create_sync_file( + file=file, + previous_file=previous_file, + sync_active=sync_active, + supported=True, + ) + knowledge = await self.knowledge_service.update_or_create_knowledge_sync( + brain_id=brain_id, + file=file, + new_sync_file=sync_file_db, + prev_sync_file=previous_file, + downloaded_file=downloaded_file, + source=source, + source_link=source_link, + user_id=current_user.user_id, + ) + + # Send file for processing + celery.send_task( + "process_file_task", + kwargs={ + "brain_id": brain_id, + "knowledge_id": knowledge.id, + "file_name": storage_path, + "file_original_name": file.name, + "source": source, + "source_link": source_link, + "notification_id": file.notification_id, + }, + ) + return file + + async def process_sync_files( + self, + files: List[SyncFile], + current_user: SyncsUser, + sync_active: SyncsActive, + ): + logger.info(f"Processing {len(files)} for sync_active: {sync_active.id}") + current_user.credentials = self.sync_cloud.check_and_refresh_access_token( + current_user.credentials + ) + + bulk_id = uuid4() + downloaded_files = [] + list_existing_files = self.sync_files_repo.get_sync_files(sync_active.id) + existing_files = {f.path: f for f in list_existing_files} + + supported_files = filter_on_supported_files(files, existing_files) + + files = self.create_sync_bulk_notification( + files, current_user.user_id, sync_active.brain_id, bulk_id + ) + + for file, prev_file in supported_files: + try: + result = await self.process_sync_file( + file=file, + previous_file=prev_file, + current_user=current_user, + sync_active=sync_active, + ) + if result is not None: + downloaded_files.append(result) + + self.notification_service.update_notification_by_id( + file.notification_id, + NotificationUpdatableProperties( + status=NotificationsStatusEnum.SUCCESS, + description="File downloaded successfully", + ), + ) + + except Exception as e: + logger.error( + 
"An error occurred while syncing %s files: %s", + self.sync_cloud.name, + e, + ) + # TODO: this process_sync_file could fail for a LOT of reason redo this logic + # File isn't supported so we set it as so ? + self.sync_files_repo.update_or_create_sync_file( + file=file, + sync_active=sync_active, + previous_file=prev_file, + supported=False, + ) + self.notification_service.update_notification_by_id( + file.notification_id, + NotificationUpdatableProperties( + status=NotificationsStatusEnum.ERROR, + description="Error downloading file", + ), + ) + + return {"downloaded_files": downloaded_files} + + async def get_files_to_download( + self, sync_active: SyncsActive, user_sync: SyncsUser + ) -> list[SyncFile]: + # Get the folder id from the settings from sync_active + folders = sync_active.settings.get("folders", []) + files_ids = sync_active.settings.get("files", []) + + files = await self.get_syncfiles_from_ids( + user_sync.credentials, + files_ids=files_ids, + folder_ids=folders, + sync_user_id=user_sync.id, + ) + + logger.debug(f"original files to download for {sync_active.id} : {files}") + + last_synced_time = ( + datetime.fromisoformat(sync_active.last_synced).astimezone(timezone.utc) + if sync_active.last_synced + else None + ) + + files_ids = [ + file + for file in files + if should_download_file( + file=file, + last_updated_sync_active=last_synced_time, + provider_name=self.sync_cloud.lower_name, + datetime_format=self.sync_cloud.datetime_format, + ) + ] + + logger.debug(f"filter files to download for {sync_active} : {files_ids}") + return files_ids + + async def get_syncfiles_from_ids( + self, + credentials: dict[str, Any], + files_ids: list[str], + folder_ids: list[str], + sync_user_id: int, + ) -> list[SyncFile]: + files = [] + if self.sync_cloud.lower_name == "notion": + files_ids += folder_ids + + for folder_id in folder_ids: + logger.debug( + f"Recursively getting file_ids from {self.sync_cloud.name}. 
folder_id={folder_id}" + ) + files.extend( + await self.sync_cloud.aget_files( + credentials=credentials, + sync_user_id=sync_user_id, + folder_id=folder_id, + recursive=True, + ) + ) + if len(files_ids) > 0: + files.extend( + await self.sync_cloud.aget_files_by_id( + credentials=credentials, + file_ids=files_ids, + ) + ) + return files + + async def direct_sync( + self, + sync_active: SyncsActive, + sync_user: SyncsUser, + files_ids: list[str], + folder_ids: list[str], + ): + files = await self.get_syncfiles_from_ids( + sync_user.credentials, files_ids, folder_ids + ) + processed_files = await self.process_sync_files( + files=files, + current_user=sync_user, + sync_active=sync_active, + ) + + # Update the last_synced timestamp + self.sync_active_service.update_sync_active( + sync_active.id, + SyncsActiveUpdateInput( + last_synced=datetime.now().astimezone().isoformat(), force_sync=False + ), + ) + logger.info( + f"{self.sync_cloud.lower_name} sync completed for sync_active: {sync_active.id}. Synced all {len(processed_files)} files.", + ) + return processed_files + + async def sync( + self, + sync_active: SyncsActive, + user_sync: SyncsUser, + ): + """ + Check if the Specific sync has not been synced and download the folders and files based on the settings. + + Args: + sync_active_id (int): The ID of the active sync. + user_id (str): The user ID associated with the active sync. 
+ """ + logger.info( + "Starting %s sync for sync_active: %s", + self.sync_cloud.lower_name, + sync_active, + ) + + files_to_download = await self.get_files_to_download(sync_active, user_sync) + processed_files = await self.process_sync_files( + files=files_to_download, + current_user=user_sync, + sync_active=sync_active, + ) + + # Update the last_synced timestamp + self.sync_active_service.update_sync_active( + sync_active.id, + SyncsActiveUpdateInput( + last_synced=datetime.now().astimezone().isoformat(), force_sync=False + ), + ) + logger.info( + f"{self.sync_cloud.lower_name} sync completed for sync_active: {sync_active.id}. Synced all {len(processed_files)} files.", + ) + return processed_files diff --git a/backend/worker/quivr_worker/process/processor.py b/backend/worker/quivr_worker/process/processor.py index 044e6b314cab..2fd349f5f370 100644 --- a/backend/worker/quivr_worker/process/processor.py +++ b/backend/worker/quivr_worker/process/processor.py @@ -1,5 +1,5 @@ import asyncio -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from pathlib import Path from typing import AsyncGenerator, List, Optional, Tuple from uuid import UUID @@ -8,7 +8,7 @@ from quivr_api.modules.knowledge.dto.inputs import KnowledgeUpdate from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB, KnowledgeSource from quivr_api.modules.sync.dto.outputs import SyncProvider -from quivr_api.modules.sync.entity.sync_models import SyncFile +from quivr_api.modules.sync.entity.sync_models import Sync, SyncFile, SyncType from quivr_core.files.file import QuivrFile from quivr_core.models import KnowledgeStatus @@ -63,7 +63,7 @@ async def yield_processable_knowledge( KnowledgeSource.GOOGLE, KnowledgeSource.DROPBOX, KnowledgeSource.GITHUB, - KnowledgeSource.NOTION, + # KnowledgeSource.NOTION, ): async for to_process in self._yield_syncs(knowledge): yield to_process @@ -209,30 +209,87 @@ async def process_knowledge(self, knowledge_id: UUID): 
continue knowledge, qfile = knowledge_tuple try: - if not skip_process(knowledge): - chunks = await parse_qfile(qfile=qfile) - await store_chunks( - knowledge=knowledge, - chunks=chunks, - vector_service=self.services.vector_service, - ) - - last_synced_at = datetime.now(timezone.utc) - await self.services.knowledge_service.update_knowledge( - knowledge, - KnowledgeUpdate( - status=KnowledgeStatus.PROCESSED, - file_sha1=knowledge.file_sha1, - last_synced_at=last_synced_at if knowledge.sync_id else None, - ), - ) - + await self._process_inner(knowledge=knowledge, qfile=qfile) except Exception as e: await savepoint.rollback() logger.error(f"Error processing knowledge {knowledge_id} : {e}") + # FIXME: This one can also fail if knowledge was deleted await self.services.knowledge_service.update_knowledge( knowledge, KnowledgeUpdate( status=KnowledgeStatus.ERROR, ), ) + + async def _process_inner(self, knowledge: KnowledgeDB, qfile: QuivrFile): + if not skip_process(knowledge): + chunks = await parse_qfile(qfile=qfile) + await store_chunks( + knowledge=knowledge, + chunks=chunks, + vector_service=self.services.vector_service, + ) + last_synced_at = datetime.now(timezone.utc) + await self.services.knowledge_service.update_knowledge( + knowledge, + KnowledgeUpdate( + status=KnowledgeStatus.PROCESSED, + file_sha1=knowledge.file_sha1, + # Update sync + last_synced_at=last_synced_at if knowledge.sync_id else None, + ), + ) + + async def update_outdated_syncs_files( + self, timedelta_hour: int = 8, batch_size: int = 1000 + ): + last_time = datetime.now(timezone.utc) - timedelta(hours=timedelta_hour) + km_sync_files = await self.services.knowledge_service.get_outdated_syncs( + limit_time=last_time, + batch_size=batch_size, + km_sync_type=SyncType.FILE, + ) + sync_cache: dict[int, Sync] = {} + for km in km_sync_files: + assert km.sync_id, "can only update sync files with sync_id" + assert km.sync_file_id, "can only update sync files with sync_file_id " + assert ( + 
km.last_synced_at + ), "can only update sync files with a last_synced_at" + if km.sync_id in sync_cache: + sync = sync_cache[km.sync_id] + else: + sync = await self.services.sync_service.get_sync_by_id(km.sync_id) + assert sync.id + sync_cache[sync.id] = sync + + if sync.credentials is None: + logger.error( + f"can't process knowledge: {km.id}. sync {sync.id} has no credentials" + ) + raise ValueError("no associated credentials") + provider_name = SyncProvider(sync.provider.lower()) + sync_provider = self.services.syncprovider_mapping[provider_name] + try: + new_sync_file = ( + await sync_provider.aget_files_by_id( + credentials=sync.credentials, file_ids=[km.sync_file_id] + ) + )[0] + if ( + new_sync_file.last_modified_at + and new_sync_file.last_modified_at < km.last_synced_at + ) or new_sync_file.last_modified_at is None: + # Create transaction + # - Create new knowledge with brains and parent like the old one + # _process_knowledge + # - Parse it + store the chunks + # - Update the knowledge + # - Remove the older knowledge + pass + else: + continue + + except FileNotFoundError: + # Remote file has been deleted + pass diff --git a/backend/worker/quivr_worker/syncs/update_syncs.py b/backend/worker/quivr_worker/syncs/update_syncs.py index a2d6660a9cd3..ed8d89596ab0 100644 --- a/backend/worker/quivr_worker/syncs/update_syncs.py +++ b/backend/worker/quivr_worker/syncs/update_syncs.py @@ -5,6 +5,12 @@ async def update_sync_files(async_engine: AsyncEngine): async with build_processor_services(async_engine) as processor_services: + # Get folders + ## Fetch the folder children + ## Fetch all knowledge for this sync + ## If this knowledge isn't + ## Get files + # Update sync files pass diff --git a/backend/worker/quivr_worker/utils/services.py b/backend/worker/quivr_worker/utils/services.py index 3fe2f6dbdeb5..5bfbda44c156 100644 --- a/backend/worker/quivr_worker/utils/services.py +++ b/backend/worker/quivr_worker/utils/services.py @@ -11,7 +11,11 @@ from 
quivr_api.modules.sync.repository.sync_repository import SyncsRepository from quivr_api.modules.sync.service.sync_service import SyncsService from quivr_api.modules.sync.utils.sync import ( + AzureDriveSync, BaseSync, + DropboxSync, + GitHubSync, + GoogleDriveSync, ) from quivr_api.modules.vector.repository.vectors_repository import VectorRepository from quivr_api.modules.vector.service.vector_service import VectorService @@ -19,13 +23,20 @@ from sqlmodel import text from sqlmodel.ext.asyncio.session import AsyncSession -from quivr_worker.process.utils import ( - build_syncprovider_mapping, -) - logger = get_logger("celery_worker") +def build_syncprovider_mapping() -> dict[SyncProvider, BaseSync]: + mapping_sync_utils = { + SyncProvider.GOOGLE: GoogleDriveSync(), + SyncProvider.AZURE: AzureDriveSync(), + SyncProvider.DROPBOX: DropboxSync(), + SyncProvider.GITHUB: GitHubSync(), + # SyncProvider.NOTION: NotionSync(notion_service=notion_service), + } + return mapping_sync_utils + + @dataclass class ProcessorServices: sync_service: SyncsService diff --git a/backend/worker/tests/test_process_file_task.py b/backend/worker/tests/test_process_file_task.py index 83ed7521fe47..d98050a0a499 100644 --- a/backend/worker/tests/test_process_file_task.py +++ b/backend/worker/tests/test_process_file_task.py @@ -183,6 +183,7 @@ async def test_process_sync_file( assert km.status == KnowledgeStatus.PROCESSED assert km.brains[0].brain_id == input_km.brains[0].brain_id assert km.file_sha1 is not None + assert km.last_synced_at is not None # Check vectors where added vecs = list( @@ -229,6 +230,7 @@ async def test_process_sync_folder( assert km.brains[0]["brain_id"] assert km.brains[0]["brain_id"] == input_km.brains[0].brain_id assert km.file_sha1 is not None + assert km.last_synced_at is not None # Check vectors where added vecs = list((await session.exec(select(Vector))).all()) From 0b7eecd272aa67a7053c704a84982945ba041b04 Mon Sep 17 00:00:00 2001 From: aminediro Date: Tue, 1 Oct 2024 
11:45:34 +0200 Subject: [PATCH 47/63] working update + tests --- .../knowledge/repository/knowledges.py | 26 ++- .../knowledge/service/knowledge_service.py | 61 +++++-- .../worker/quivr_worker/process/processor.py | 153 ++++++++++++------ backend/worker/quivr_worker/process/utils.py | 7 +- backend/worker/tests/conftest.py | 4 +- backend/worker/tests/test_update_syncs.py | 88 ++++++++++ 6 files changed, 266 insertions(+), 73 deletions(-) create mode 100644 backend/worker/tests/test_update_syncs.py diff --git a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py index 4f05af84a07e..e4b36bfe7033 100644 --- a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py +++ b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py @@ -37,11 +37,15 @@ def __init__(self, session: AsyncSession): supabase_client = get_supabase_client() self.db = supabase_client - async def create_knowledge(self, knowledge: KnowledgeDB) -> KnowledgeDB: + async def create_knowledge( + self, knowledge: KnowledgeDB, autocommit: bool + ) -> KnowledgeDB: try: self.session.add(knowledge) - await self.session.commit() - await self.session.refresh(knowledge) + if autocommit: + await self.session.commit() + await self.session.refresh(knowledge) + await self.session.flush() except IntegrityError: await self.session.rollback() raise @@ -54,6 +58,7 @@ async def update_knowledge( self, knowledge: KnowledgeDB, payload: KnowledgeDTO | KnowledgeUpdate | dict[str, Any], + autocommit: bool, ) -> KnowledgeDB: try: logger.debug(f"updating {knowledge.id} with payload {payload}") @@ -65,8 +70,11 @@ async def update_knowledge( setattr(knowledge, field, update_data[field]) self.session.add(knowledge) - await self.session.commit() - await self.session.refresh(knowledge) + if autocommit: + await self.session.commit() + await self.session.refresh(knowledge) + else: + await self.session.flush() return knowledge except 
IntegrityError as e: await self.session.rollback() @@ -207,10 +215,14 @@ async def remove_knowledge_from_brain( await self.session.refresh(knowledge) return knowledge - async def remove_knowledge(self, knowledge: KnowledgeDB) -> DeleteKnowledgeResponse: + async def remove_knowledge( + self, knowledge: KnowledgeDB, autocommit: bool + ) -> DeleteKnowledgeResponse: assert knowledge.id await self.session.delete(knowledge) - await self.session.commit() + if autocommit: + await self.session.commit() + return DeleteKnowledgeResponse( status="deleted", knowledge_id=knowledge.id, file_name=knowledge.file_name ) diff --git a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py index e3f2085602dc..bdff2236a796 100644 --- a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py @@ -10,6 +10,7 @@ from quivr_api.celery_config import celery from quivr_api.logger import get_logger +from quivr_api.modules.brain.entity.brain_entity import Brain from quivr_api.modules.dependencies import BaseService from quivr_api.modules.knowledge.dto.inputs import ( AddKnowledge, @@ -60,6 +61,7 @@ async def create_or_link_sync_knowledge( syncfile_id_to_knowledge: dict[str, KnowledgeDB], parent_knowledge: KnowledgeDB, sync_file: SyncFile, + autocommit: bool = True, ): existing_km = syncfile_id_to_knowledge.get(sync_file.id) if existing_km is not None: @@ -77,7 +79,9 @@ async def create_or_link_sync_knowledge( parent_knowledge.brains, ): await self.repository.update_knowledge( - existing_km, KnowledgeUpdate(parent_id=parent_knowledge.id) + existing_km, + KnowledgeUpdate(parent_id=parent_knowledge.id), + autocommit=autocommit, ) await self.repository.link_to_brain( existing_km, brain_id=brain.brain_id @@ -151,10 +155,11 @@ async def update_knowledge( self, knowledge: KnowledgeDB | UUID, payload: KnowledgeDTO | KnowledgeUpdate | 
dict[str, Any], + autocommit: bool = True, ): if isinstance(knowledge, UUID): knowledge = await self.repository.get_knowledge_by_id(knowledge) - return await self.repository.update_knowledge(knowledge, payload) + return await self.repository.update_knowledge(knowledge, payload, autocommit) async def create_knowledge( self, @@ -162,11 +167,22 @@ async def create_knowledge( knowledge_to_add: AddKnowledge, upload_file: UploadFile | None = None, status: KnowledgeStatus = KnowledgeStatus.RESERVED, + add_brains: list[Brain] = [], + autocommit: bool = True, ) -> KnowledgeDB: brains = [] if knowledge_to_add.parent_id: parent_knowledge = await self.get_knowledge(knowledge_to_add.parent_id) brains = await parent_knowledge.awaitable_attrs.brains + if len(add_brains) > 0: + brains.extend( + [ + b + for b in add_brains + if b.brain_id not in {b.brain_id for b in brains} + ] + ) + knowledgedb = KnowledgeDB( user_id=user_id, file_name=knowledge_to_add.file_name, @@ -184,7 +200,9 @@ async def create_knowledge( brains=brains, ) - knowledge_db = await self.repository.create_knowledge(knowledgedb) + knowledge_db = await self.repository.create_knowledge( + knowledgedb, autocommit=autocommit + ) try: if knowledgedb.source == KnowledgeSource.LOCAL and upload_file: @@ -197,20 +215,19 @@ async def create_knowledge( knowledge_db = await self.repository.update_knowledge( knowledge_db, KnowledgeUpdate(status=KnowledgeStatus.UPLOADED), + autocommit=autocommit, ) if knowledge_db.brains and len(knowledge_db.brains) > 0: # Schedule this new knowledge to be processed knowledge_db = await self.repository.update_knowledge( knowledge_db, KnowledgeUpdate(status=KnowledgeStatus.PROCESSING), + autocommit=autocommit, ) celery.send_task( "process_file_task", kwargs={ "knowledge_id": knowledge_db.id, - "file_name": knowledge_db.file_name, - "source": knowledge_db.source, - "source_link": knowledge_db.source_link, }, ) @@ -219,7 +236,9 @@ async def create_knowledge( logger.exception( f"Error uploading 
knowledge {knowledgedb.id} to storage : {e}" ) - await self.repository.remove_knowledge(knowledge=knowledge_db) + await self.repository.remove_knowledge( + knowledge=knowledge_db, autocommit=autocommit + ) raise UploadError() async def insert_knowledge_brain( @@ -282,21 +301,31 @@ async def update_status_knowledge( async def update_file_sha1_knowledge(self, knowledge_id: UUID, file_sha1: str): return await self.repository.update_file_sha1_knowledge(knowledge_id, file_sha1) - async def remove_knowledge(self, knowledge: KnowledgeDB) -> DeleteKnowledgeResponse: + async def remove_knowledge( + self, knowledge: KnowledgeDB, autocommit: bool = True + ) -> DeleteKnowledgeResponse: assert knowledge.id try: # TODO: # - Notion folders are special, they are themselves files and should be removed from storage - children = await self.repository.get_knowledge_tree(knowledge.id) - km_paths = [ - self.storage.get_storage_path(k) for k in children if not k.is_folder - ] - if not knowledge.is_folder: - km_paths.append(self.storage.get_storage_path(knowledge)) - + km_paths = [] + if knowledge.source == KnowledgeSource.LOCAL: + if knowledge.is_folder: + children = await self.repository.get_knowledge_tree(knowledge.id) + km_paths.extend( + [ + self.storage.get_storage_path(k) + for k in children + if not k.is_folder + ] + ) + if not knowledge.is_folder: + km_paths.append(self.storage.get_storage_path(knowledge)) # recursively deletes files - deleted_km = await self.repository.remove_knowledge(knowledge) + deleted_km = await self.repository.remove_knowledge( + knowledge, autocommit=autocommit + ) # TODO: remove storage asynchronously in background task or in some task await asyncio.gather(*[self.storage.remove_file(p) for p in km_paths]) return deleted_km diff --git a/backend/worker/quivr_worker/process/processor.py b/backend/worker/quivr_worker/process/processor.py index 2fd349f5f370..dcc9149dd251 100644 --- a/backend/worker/quivr_worker/process/processor.py +++ 
b/backend/worker/quivr_worker/process/processor.py @@ -1,16 +1,19 @@ import asyncio from datetime import datetime, timedelta, timezone +from functools import lru_cache from pathlib import Path -from typing import AsyncGenerator, List, Optional, Tuple +from typing import Any, AsyncGenerator, List, Optional, Tuple from uuid import UUID from quivr_api.logger import get_logger -from quivr_api.modules.knowledge.dto.inputs import KnowledgeUpdate +from quivr_api.modules.knowledge.dto.inputs import AddKnowledge, KnowledgeUpdate from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB, KnowledgeSource from quivr_api.modules.sync.dto.outputs import SyncProvider from quivr_api.modules.sync.entity.sync_models import Sync, SyncFile, SyncType +from quivr_api.modules.sync.utils.sync import BaseSync from quivr_core.files.file import QuivrFile from quivr_core.models import KnowledgeStatus +from sqlalchemy.ext.asyncio import AsyncSessionTransaction from quivr_worker.parsers.crawler import URL, extract_from_url from quivr_worker.process.process_file import parse_qfile, store_chunks @@ -156,7 +159,7 @@ async def _yield_syncs( # Yield parent_knowledge as the first knowledge to process async with build_sync_file( file_knowledge=parent_knowledge, - sync=sync, + credentials=sync.credentials, sync_provider=sync_provider, sync_file=SyncFile( id=parent_knowledge.sync_file_id, @@ -193,23 +196,28 @@ async def _yield_syncs( continue async with build_sync_file( file_knowledge=file_knowledge, - sync=sync, + credentials=sync.credentials, sync_provider=sync_provider, sync_file=sync_file, ) as f: yield f + async def create_savepoint(self) -> AsyncSessionTransaction: + savepoint = ( + await self.services.knowledge_service.repository.session.begin_nested() + ) + return savepoint + async def process_knowledge(self, knowledge_id: UUID): async for knowledge_tuple in self.yield_processable_knowledge(knowledge_id): # FIXME(@AmineDiro) : nested transaction for making - savepoint = ( - await 
self.services.knowledge_service.repository.session.begin_nested() - ) + savepoint = await self.create_savepoint() if knowledge_tuple is None: continue knowledge, qfile = knowledge_tuple try: await self._process_inner(knowledge=knowledge, qfile=qfile) + await savepoint.commit() except Exception as e: await savepoint.rollback() logger.error(f"Error processing knowledge {knowledge_id} : {e}") @@ -238,8 +246,14 @@ async def _process_inner(self, knowledge: KnowledgeDB, qfile: QuivrFile): # Update sync last_synced_at=last_synced_at if knowledge.sync_id else None, ), + autocommit=False, ) + @lru_cache(maxsize=50) # noqa: B019 + async def get_sync_provider(self, sync_id: int) -> Sync: + sync = await self.services.sync_service.get_sync_by_id(sync_id) + return sync + async def update_outdated_syncs_files( self, timedelta_hour: int = 8, batch_size: int = 1000 ): @@ -249,47 +263,96 @@ async def update_outdated_syncs_files( batch_size=batch_size, km_sync_type=SyncType.FILE, ) - sync_cache: dict[int, Sync] = {} - for km in km_sync_files: - assert km.sync_id, "can only update sync files with sync_id" - assert km.sync_file_id, "can only update sync files with sync_file_id " - assert ( - km.last_synced_at - ), "can only update sync files without a last_synced_at" - if km.sync_id in sync_cache: - sync = sync_cache[km.sync_id] - else: - sync = await self.services.sync_service.get_sync_by_id(km.sync_id) - assert sync.id - sync_cache[sync.id] = sync - - if sync.credentials is None: - logger.error( - f"can't process knowledge: {km.id}. 
sync {sync.id} has no credentials" - ) - raise ValueError("no associated credentials") - provider_name = SyncProvider(sync.provider.lower()) - sync_provider = self.services.syncprovider_mapping[provider_name] + for old_km in km_sync_files: try: + assert old_km.sync_id, "can only update sync files with sync_id" + assert ( + old_km.sync_file_id + ), "can only update sync files with sync_file_id " + sync = await self.get_sync_provider(old_km.sync_id) + if sync.credentials is None: + logger.error( + f"can't process knowledge: {old_km.id}. sync {sync.id} has no credentials" + ) + raise ValueError( + f"no associated credentials with knowledge {old_km}" + ) + provider_name = SyncProvider(sync.provider.lower()) + sync_provider = self.services.syncprovider_mapping[provider_name] new_sync_file = ( await sync_provider.aget_files_by_id( - credentials=sync.credentials, file_ids=[km.sync_file_id] + credentials=sync.credentials, file_ids=[old_km.sync_file_id] ) )[0] - if ( - new_sync_file.last_modified_at - and new_sync_file.last_modified_at < km.last_synced_at - ) or new_sync_file.last_modified_at is None: - # Create transaction - # - Create new knowledge with brains and parent like the old one - # _process_knowledge - # - Parse it + store the chunks - # - Update the knowledge - # - Remove the older knowledge - pass - else: - continue - + await self.update_outdated_km( + old_km=old_km, + new_sync_file=new_sync_file, + sync_provider=sync_provider, + sync_credentials=sync.credentials, + ) except FileNotFoundError: - # Remove has been deleted - pass + logger.info( + f"Knowledge {old_km.id} not found in remote sync. 
Removing the knowledge" + ) + await self.services.knowledge_service.remove_knowledge( + old_km, autocommit=True + ) + except Exception: + logger.exception(f"Exception occured processing km: {old_km.id}") + + async def update_outdated_km( + self, + old_km: KnowledgeDB, + new_sync_file: SyncFile, + sync_provider: BaseSync, + sync_credentials: dict[str, Any], + ) -> KnowledgeDB | None: + assert ( + old_km.last_synced_at + ), "can only update sync files without a last_synced_at" + if ( + new_sync_file.last_modified_at + and new_sync_file.last_modified_at > old_km.last_synced_at + ) or new_sync_file.last_modified_at is None: + savepoint = await self.create_savepoint() + try: + new_km = await self.services.knowledge_service.create_knowledge( + user_id=old_km.user_id, + knowledge_to_add=AddKnowledge( + file_name=new_sync_file.name, + is_folder=new_sync_file.is_folder, + extension=new_sync_file.extension, + source=old_km.source, + source_link=new_sync_file.web_view_link, + parent_id=old_km.parent_id, + sync_id=old_km.sync_id, + sync_file_id=new_sync_file.id, + ), + status=KnowledgeStatus.PROCESSING, + add_brains=await old_km.awaitable_attrs.brains, + upload_file=None, + autocommit=False, + ) + async with build_sync_file( + new_km, + new_sync_file, + sync_provider=sync_provider, + credentials=sync_credentials, + ) as ( + new_km, + qfile, + ): + await self._process_inner(new_km, qfile) + await self.services.knowledge_service.remove_knowledge( + old_km, autocommit=False + ) + await savepoint.commit() + await savepoint.session.refresh(new_km) + return new_km + + except Exception as e: + logger.exception( + f"Rolling back. 
Error occured updating sync {old_km.id}: {e}" + ) + await savepoint.rollback() + raise diff --git a/backend/worker/quivr_worker/process/utils.py b/backend/worker/quivr_worker/process/utils.py index 34adb5147c3b..f4b052a25a8a 100644 --- a/backend/worker/quivr_worker/process/utils.py +++ b/backend/worker/quivr_worker/process/utils.py @@ -12,7 +12,7 @@ from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB, KnowledgeSource from quivr_api.modules.sync.dto.outputs import SyncProvider -from quivr_api.modules.sync.entity.sync_models import Sync, SyncFile +from quivr_api.modules.sync.entity.sync_models import SyncFile from quivr_api.modules.sync.utils.sync import ( AzureDriveSync, BaseSync, @@ -98,13 +98,12 @@ async def build_sync_file( file_knowledge: KnowledgeDB, sync_file: SyncFile, sync_provider: BaseSync, - sync: Sync, + credentials: dict[str, Any], ) -> AsyncGenerator[Tuple[KnowledgeDB, QuivrFile], None]: - assert sync.credentials file_data = await download_sync_file( sync_provider=sync_provider, file=sync_file, - credentials=sync.credentials, + credentials=credentials, ) file_knowledge.file_sha1 = compute_sha1(file_data) file_knowledge.file_size = len(file_data) diff --git a/backend/worker/tests/conftest.py b/backend/worker/tests/conftest.py index e993f50138a5..4807f0dd90ac 100644 --- a/backend/worker/tests/conftest.py +++ b/backend/worker/tests/conftest.py @@ -1,4 +1,5 @@ import os +from datetime import datetime, timedelta, timezone from io import BytesIO from pathlib import Path from uuid import uuid4 @@ -217,7 +218,7 @@ async def local_knowledge_folder_with_file( parent_id=folder_km.id, ) km_data = BytesIO(os.urandom(24)) - km = await service.create_knowledge( + _ = await service.create_knowledge( user_id=user.id, knowledge_to_add=km_to_add, upload_file=UploadFile(file=km_data, size=24, filename=km_to_add.file_name), @@ -287,6 +288,7 @@ async def sync_knowledge_file( parent=None, 
sync_file_id="id1", sync=sync, + last_synced_at=datetime.now(timezone.utc) - timedelta(days=2), ) session.add(km) diff --git a/backend/worker/tests/test_update_syncs.py b/backend/worker/tests/test_update_syncs.py new file mode 100644 index 000000000000..4db0d6bf573a --- /dev/null +++ b/backend/worker/tests/test_update_syncs.py @@ -0,0 +1,88 @@ +from datetime import datetime, timedelta, timezone +from typing import Any + +import pytest +from langchain_core.documents import Document +from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB +from quivr_api.modules.sync.entity.sync_models import SyncFile +from quivr_api.modules.sync.tests.test_sync_controller import FakeSync +from quivr_api.modules.vector.entity.vector import Vector +from quivr_core.files.file import QuivrFile +from quivr_core.models import KnowledgeStatus +from quivr_worker.process.processor import KnowledgeProcessor, ProcessorServices +from sqlmodel import col, select +from sqlmodel.ext.asyncio.session import AsyncSession + + +async def _parse_file_mock( + qfile: QuivrFile, + **processor_kwargs: dict[str, Any], +) -> list[Document]: + with open(qfile.path, "rb") as f: + return [Document(page_content=str(f.read()), metadata={})] + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.parametrize("proc_services", [0], indirect=True) +async def test_update_sync_file( + monkeypatch, + session: AsyncSession, + proc_services: ProcessorServices, + sync_knowledge_file: KnowledgeDB, +): + input_km = sync_knowledge_file + assert input_km.id + assert input_km.brains + assert input_km.sync_file_id + assert input_km.file_name + assert input_km.source_link + assert input_km.last_synced_at + + km_processor = KnowledgeProcessor(proc_services) + monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock) + new_sync_file = SyncFile( + id=input_km.sync_file_id, + name=input_km.file_name, + extension=input_km.extension, + is_folder=False, + web_view_link=input_km.source_link, + 
last_modified_at=datetime.now(timezone.utc) - timedelta(hours=1), + ) + sync_provider = FakeSync(provider_name=input_km.source, n_get_files=0) + new_km = await km_processor.update_outdated_km( + old_km=sync_knowledge_file, + new_sync_file=new_sync_file, + sync_provider=sync_provider, + sync_credentials={}, + ) + + # Check knowledge was updated + assert new_km + assert new_km.id + knowledge_service = km_processor.services.knowledge_service + km = await knowledge_service.get_knowledge(new_km.id) + assert km.status == KnowledgeStatus.PROCESSED + assert {b.brain_id for b in km.brains} == {b.brain_id for b in input_km.brains} + assert km.parent_id == input_km.parent_id + assert km.file_sha1 is not None + assert km.last_synced_at + assert km.last_synced_at > input_km.last_synced_at + + # Check vectors where removed + vecs = list( + ( + await session.exec( + select(Vector).where(col(Vector.knowledge_id) == input_km.id) + ) + ).all() + ) + assert len(vecs) == 0 + + # Check vectors where added for the new km + vecs = list( + ( + await session.exec(select(Vector).where(col(Vector.knowledge_id) == km.id)) + ).all() + ) + assert len(vecs) > 0 + assert vecs[0].metadata_ is not None From 3ed0de58e5cdfd2034b0ce9e8d9c4f2ed9fde9a4 Mon Sep 17 00:00:00 2001 From: aminediro Date: Tue, 1 Oct 2024 14:19:29 +0200 Subject: [PATCH 48/63] test sync update single file --- backend/worker/quivr_worker/process/README.md | 64 +++++++++++-------- .../worker/quivr_worker/process/processor.py | 1 - backend/worker/tests/conftest.py | 52 +++++++++++++++ backend/worker/tests/test_update_syncs.py | 63 ++++++++++++++++++ 4 files changed, 151 insertions(+), 29 deletions(-) diff --git a/backend/worker/quivr_worker/process/README.md b/backend/worker/quivr_worker/process/README.md index 2499bd7bb872..170531b048b0 100644 --- a/backend/worker/quivr_worker/process/README.md +++ b/backend/worker/quivr_worker/process/README.md @@ -8,29 +8,29 @@ Here's the grammar correction and a more explicit version of your 
markdown, keep 1. The task receives a `knowledge_id: UUID`. 2. The `KnowledgeProcessor.process_knowledge` method processes the knowledge: - - It constructs a processable tuple of `[Knowledge, QuivrFile]` stream: - - Retrieves the `KnowledgeDB` object from the database. - - Determines the processing steps based on the knowledge source: - - **Local**: - - Downloads the knowledge data from S3 storage and writes it to a temporary file. - - Yields the `[Knowledge, QuivrFile]`. - - **Web**: Processes similarly to the **Local** method. - - **[Syncs]**: - - Fetches the associated sync and verifies the credentials. - - Concurrently retrieves all knowledges for the user from the database associated with this sync, as well as the tree of sync files where this knowledge is the parent (using the sync provider). - - Downloads the knowledge and yields the initial `[Knowledge, QuivrFile]` that the task received. - - For all children of this knowledge (i.e., those fetched from the sync): - - If the child exists in the database (i.e., knowledge where `knowledge.sync_id == sync_file.id`): - - This implies that the sync's child knowledge might have been processed earlier in another brain. - - If the knowledge has been PROCESSED, link it to the parent brains and continue. - - If not, reprocess the file. - - If the child does not exist: - - Create the knowledge associated with the sync file and set it to `Processing`. - - Download the sync file's data and yield the `[Knowledge, QuivrFile]`. - - Skip processing of the tuple if the knowledge is a folder. - - Parse the `QuivrFile` using `quivr-core`. - - Store the resulting chunks in the database. - - Update the knowledge status to `PROCESSED`. + - It constructs a processable tuple of `[Knowledge, QuivrFile]` stream: + - Retrieves the `KnowledgeDB` object from the database. + - Determines the processing steps based on the knowledge source: + - **Local**: + - Downloads the knowledge data from S3 storage and writes it to a temporary file. 
+ - Yields the `[Knowledge, QuivrFile]`. + - **Web**: Processes similarly to the **Local** method. + - **[Syncs]**: + - Fetches the associated sync and verifies the credentials. + - Concurrently retrieves all knowledges for the user from the database associated with this sync, as well as the tree of sync files where this knowledge is the parent (using the sync provider). + - Downloads the knowledge and yields the initial `[Knowledge, QuivrFile]` that the task received. + - For all children of this knowledge (i.e., those fetched from the sync): + - If the child exists in the database (i.e., knowledge where `knowledge.sync_id == sync_file.id`): + - This implies that the sync's child knowledge might have been processed earlier in another brain. + - If the knowledge has been PROCESSED, link it to the parent brains and continue. + - If not, reprocess the file. + - If the child does not exist: + - Create the knowledge associated with the sync file and set it to `Processing`. + - Download the sync file's data and yield the `[Knowledge, QuivrFile]`. + - Skip processing of the tuple if the knowledge is a folder. + - Parse the `QuivrFile` using `quivr-core`. + - Store the resulting chunks in the database. + - Update the knowledge status to `PROCESSED`. ### Handling Exceptions During Parsing Loop @@ -39,12 +39,14 @@ Here's the grammar correction and a more explicit version of your markdown, keep If an exception occurs during the parsing loop, the following steps are taken: 1. Roll back the current transaction (this only affects the vectors) if they were set. The processing loop performs the following stateful operations in this order: - - Creating knowledges (with `Processing` status). - - Updating knowledges: linking them to brains. - - Creating vectors. - - Updating knowledges. - + + - Creating knowledges (with `Processing` status). + - Updating knowledges: linking them to brains. + - Creating vectors. + - Updating knowledges. 
+ **Transaction Safety for Each Operation:** + - **Creating knowledge and linking to brains**: These operations can be retried safely. Knowledge is only recreated if it does not already exist in the database, allowing for safe retry. - **Linking knowledge to brains**: Only links the brain if it is not already associated with the knowledge. Safe for retry. - **Creating vectors**: @@ -77,3 +79,9 @@ However, sync knowledge added to a brain will be reprocessed after some time thr ## Notification Steps To discuss: @StanGirard @Zewed + +# Sync update + +- Task to update all files +- Task to get all folders and proceess +- Rollback doesnt modify the original knowledge diff --git a/backend/worker/quivr_worker/process/processor.py b/backend/worker/quivr_worker/process/processor.py index dcc9149dd251..d9a956117071 100644 --- a/backend/worker/quivr_worker/process/processor.py +++ b/backend/worker/quivr_worker/process/processor.py @@ -355,4 +355,3 @@ async def update_outdated_km( f"Rolling back. Error occured updating sync {old_km.id}: {e}" ) await savepoint.rollback() - raise diff --git a/backend/worker/tests/conftest.py b/backend/worker/tests/conftest.py index 4807f0dd90ac..f87e7597d4e0 100644 --- a/backend/worker/tests/conftest.py +++ b/backend/worker/tests/conftest.py @@ -25,6 +25,7 @@ from quivr_api.modules.sync.tests.test_sync_controller import FakeSync from quivr_api.modules.sync.utils.sync import BaseSync from quivr_api.modules.user.entity.user_identity import User +from quivr_api.modules.vector.entity.vector import Vector from quivr_api.modules.vector.repository.vectors_repository import VectorRepository from quivr_api.modules.vector.service.vector_service import VectorService from quivr_core.files.file import QuivrFile @@ -298,6 +299,57 @@ async def sync_knowledge_file( return km +@pytest.fixture(scope="module") +def embedder(): + return DeterministicFakeEmbedding(size=settings.embedding_dim) + + +@pytest_asyncio.fixture(scope="function") +async def 
sync_knowledge_file_processed( + session: AsyncSession, + proc_services: ProcessorServices, + user: User, + brain_user: Brain, + sync: Sync, + embedder: DeterministicFakeEmbedding, +) -> KnowledgeDB: + assert user.id + assert brain_user.brain_id + + km = KnowledgeDB( + file_name="test_file_1.txt", + extension=".txt", + status=KnowledgeStatus.PROCESSED, + source=SyncProvider.GOOGLE, + source_link="drive://test/test", + file_size=1233, + file_sha1="1234kj", + user_id=user.id, + brains=[brain_user], + parent=None, + sync_file_id="id1", + sync=sync, + last_synced_at=datetime.now(timezone.utc) - timedelta(days=2), + ) + + session.add(km) + await session.commit() + await session.refresh(km) + + assert km.id + + vec = Vector( + content="test", + metadata_={}, + embedding=embedder.embed_query("test"), # type: ignore + knowledge_id=km.id, + ) + session.add(vec) + await session.commit() + + return km + + @pytest_asyncio.fixture(scope="function") async def sync_knowledge_folder( session: AsyncSession, diff --git a/backend/worker/tests/test_update_syncs.py b/backend/worker/tests/test_update_syncs.py index 4db0d6bf573a..6e9b75aad836 100644 --- a/backend/worker/tests/test_update_syncs.py +++ b/backend/worker/tests/test_update_syncs.py @@ -86,3 +86,66 @@ async def test_update_sync_file( ) assert len(vecs) > 0 assert vecs[0].metadata_ is not None + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.parametrize("proc_services", [0], indirect=True) +async def test_update_sync_file_rollback( + monkeypatch, + session: AsyncSession, + proc_services: ProcessorServices, + sync_knowledge_file_processed: KnowledgeDB, +): + input_km = sync_knowledge_file_processed + assert input_km.id + assert input_km.brains + assert input_km.sync_file_id + assert input_km.file_name + assert input_km.source_link + assert input_km.last_synced_at + + async def _parse_file_mock_error( + qfile: QuivrFile, + **processor_kwargs: dict[str, Any], + ) -> list[Document]: + raise Exception("error") + + 
km_processor = KnowledgeProcessor(proc_services) + monkeypatch.setattr( + "quivr_worker.process.processor.parse_qfile", _parse_file_mock_error + ) + new_sync_file = SyncFile( + id=input_km.sync_file_id, + name=input_km.file_name, + extension=input_km.extension, + is_folder=False, + web_view_link=input_km.source_link, + last_modified_at=datetime.now(timezone.utc) - timedelta(hours=1), + ) + sync_provider = FakeSync(provider_name=input_km.source, n_get_files=0) + new_km = await km_processor.update_outdated_km( + old_km=input_km, + new_sync_file=new_sync_file, + sync_provider=sync_provider, + sync_credentials={}, + ) + + # Check knowledge was not removed + assert new_km is None + + # Check vectors where not removed + vecs = list( + ( + await session.exec( + select(Vector).where(col(Vector.knowledge_id) == input_km.id) + ) + ).all() + ) + # Check nothing was added + assert len(vecs) == 1 + + # Check kms statyed correct + all_kms = list((await session.exec(select(KnowledgeDB))).unique().all()) + assert len(all_kms) == 1 + assert all_kms[0].id == input_km.id + assert all_kms[0].last_synced_at == input_km.last_synced_at From b99db0f5a4722985f42278e534e4ecbf67a42225 Mon Sep 17 00:00:00 2001 From: aminediro Date: Tue, 1 Oct 2024 16:37:46 +0200 Subject: [PATCH 49/63] celery worker update syncs --- .../knowledge/service/knowledge_service.py | 9 +- .../quivr_api/modules/sync/utils/syncutils.py | 403 ------------------ .../quivr_worker/assistants/assistants.py | 2 +- backend/worker/quivr_worker/celery_worker.py | 78 ++-- backend/worker/quivr_worker/process/README.md | 47 +- .../worker/quivr_worker/process/processor.py | 98 ++++- backend/worker/quivr_worker/syncs/__init__.py | 3 - .../worker/quivr_worker/syncs/update_syncs.py | 17 +- backend/worker/tests/conftest.py | 52 +++ backend/worker/tests/test_update_syncs.py | 79 +++- 10 files changed, 292 insertions(+), 496 deletions(-) delete mode 100644 backend/api/quivr_api/modules/sync/utils/syncutils.py diff --git 
a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py index bdff2236a796..a11d03ab6340 100644 --- a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py @@ -167,18 +167,19 @@ async def create_knowledge( knowledge_to_add: AddKnowledge, upload_file: UploadFile | None = None, status: KnowledgeStatus = KnowledgeStatus.RESERVED, - add_brains: list[Brain] = [], + link_brains: list[Brain] = [], autocommit: bool = True, + process_async: bool = True, ) -> KnowledgeDB: brains = [] if knowledge_to_add.parent_id: parent_knowledge = await self.get_knowledge(knowledge_to_add.parent_id) brains = await parent_knowledge.awaitable_attrs.brains - if len(add_brains) > 0: + if len(link_brains) > 0: brains.extend( [ b - for b in add_brains + for b in link_brains if b.brain_id not in {b.brain_id for b in brains} ] ) @@ -217,7 +218,7 @@ async def create_knowledge( KnowledgeUpdate(status=KnowledgeStatus.UPLOADED), autocommit=autocommit, ) - if knowledge_db.brains and len(knowledge_db.brains) > 0: + if knowledge_db.brains and len(knowledge_db.brains) > 0 and process_async: # Schedule this new knowledge to be processed knowledge_db = await self.repository.update_knowledge( knowledge_db, diff --git a/backend/api/quivr_api/modules/sync/utils/syncutils.py b/backend/api/quivr_api/modules/sync/utils/syncutils.py deleted file mode 100644 index 33adb5603962..000000000000 --- a/backend/api/quivr_api/modules/sync/utils/syncutils.py +++ /dev/null @@ -1,403 +0,0 @@ -from typing import List, Tuple -from uuid import UUID - -from quivr_api.celery_config import celery -from quivr_api.logger import get_logger -from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB -from quivr_api.modules.sync.entity.sync_models import ( - SyncFile, -) - -logger = get_logger(__name__) - -celery_inspector = celery.control.inspect() - - 
-async def fetch_sync_knowledge( - self, - sync_id: int, - user_id: UUID, - folder_id: str | None, -) -> Tuple[dict[str, KnowledgeDB], List[SyncFile] | None]: - map_knowledges_task = self.services.knowledge_service.map_syncs_knowledge_user( - sync_id=sync_id, user_id=user_id - ) - sync_files_task = self.services.sync_service.get_files_folder_user_sync( - sync_id, - user_id, - folder_id, - ) - return await asyncio.gather(*[map_knowledges_task, sync_files_task]) # type: ignore # noqa: F821 - - -# NOTE: we are filtering based on file path names in sync ! -def filter_on_supported_files( - files: list[SyncFile], existing_files: dict[str, DBSyncFile] -) -> list[Tuple[SyncFile, DBSyncFile | None]]: - res = [] - for new_file in files: - prev_file = existing_files.get(new_file.name, None) - if (prev_file and prev_file.supported) or prev_file is None: - res.append((new_file, prev_file)) - return res - - -def should_download_file( - file: SyncFile, - last_updated_sync_active: datetime | None, - provider_name: str, - datetime_format: str, -) -> bool: - file_last_modified_utc = datetime.strptime( - file.last_modified_at, datetime_format - ).replace(tzinfo=timezone.utc) - - should_download = ( - last_updated_sync_active is None - or file_last_modified_utc > last_updated_sync_active - ) - - # TODO: Handle notion database - if provider_name == "notion": - should_download &= file.extension != "db" - else: - should_download &= not file.is_folder - - return should_download - - -class SyncUtils: - def __init__( - self, - # sync_user_service: ISyncUserService, - # sync_active_service: ISyncService, - # sync_files_repo: SyncFileInterface, - sync_cloud: BaseSync, - knowledge_service: KnowledgeService, - notification_service: NotificationService, - brain_vectors: BrainsVectors, - ) -> None: - self.sync_user_service = sync_user_service - self.sync_active_service = sync_active_service - self.sync_files_repo = sync_files_repo - self.knowledge_service = knowledge_service - self.sync_cloud = 
sync_cloud - self.notification_service = notification_service - self.brain_vectors = brain_vectors - - # TODO: This modifies the file, we should treat it as such - def create_sync_bulk_notification( - self, files: list[SyncFile], current_user: UUID, brain_id: UUID, bulk_id: UUID - ) -> list[SyncFile]: - res = [] - # TODO: bulk insert in batch - for file in files: - upload_notification = self.notification_service.add_notification( - CreateNotification( - user_id=current_user, - bulk_id=bulk_id, - status=NotificationsStatusEnum.INFO, - title=file.name, - category="sync", - brain_id=str(brain_id), - ) - ) - file.notification_id = upload_notification.id - res.append(file) - return res - - async def download_file( - self, file: SyncFile, credentials: dict[str, Any] - ) -> DownloadedSyncFile: - logger.info(f"Downloading {file} using {self.sync_cloud}") - file_response = await self.sync_cloud.adownload_file(credentials, file) - logger.debug(f"Fetch sync file response: {file_response}") - file_name = str(file_response["file_name"]) - raw_data = file_response["content"] - file_data = ( - io.BufferedReader(raw_data) # type: ignore - if isinstance(raw_data, io.BytesIO) - else io.BufferedReader(raw_data.encode("utf-8")) # type: ignore - ) - extension = os.path.splitext(file_name)[-1].lower() - dfile = DownloadedSyncFile( - file_name=file_name, - file_data=file_data, - extension=extension, - ) - logger.debug(f"Successfully downloaded sync file : {dfile}") - return dfile - - # TODO: REDO THIS MESS !!!! 
- # REMOVE ALL SYNC TABLES and start from scratch - - async def process_sync_file( - self, - file: SyncFile, - previous_file: DBSyncFile | None, - current_user: SyncsUser, - sync_active: SyncsActive, - ): - logger.info("Processing file: %s", file.name) - brain_id = sync_active.brain_id - source, source_link = self.sync_cloud.name, file.web_view_link - downloaded_file = await self.download_file(file, current_user.credentials) - storage_path = f"{brain_id}/{downloaded_file.file_name}" - exists_in_storage = check_file_exists(str(brain_id), file.name) - - if downloaded_file.extension not in [ - ".pdf", - ".txt", - ".md", - ".csv", - ".docx", - ".xlsx", - ".pptx", - ".doc", - ]: - raise ValueError(f"Incompatible file extension for {downloaded_file}") - - response = await upload_file_storage( - downloaded_file.file_data, - storage_path, - upsert=exists_in_storage, - ) - assert response, f"Error uploading {downloaded_file} to {storage_path}" - self.notification_service.update_notification_by_id( - file.notification_id, - NotificationUpdatableProperties( - status=NotificationsStatusEnum.SUCCESS, - description="File downloaded successfully", - ), - ) - # TODO : why knowledge + syncfile, drop syncfile ... 
- # FIXME : Simplify this logic in KMS plzzz - sync_file_db = self.sync_files_repo.update_or_create_sync_file( - file=file, - previous_file=previous_file, - sync_active=sync_active, - supported=True, - ) - knowledge = await self.knowledge_service.update_or_create_knowledge_sync( - brain_id=brain_id, - file=file, - new_sync_file=sync_file_db, - prev_sync_file=previous_file, - downloaded_file=downloaded_file, - source=source, - source_link=source_link, - user_id=current_user.user_id, - ) - - # Send file for processing - celery.send_task( - "process_file_task", - kwargs={ - "brain_id": brain_id, - "knowledge_id": knowledge.id, - "file_name": storage_path, - "file_original_name": file.name, - "source": source, - "source_link": source_link, - "notification_id": file.notification_id, - }, - ) - return file - - async def process_sync_files( - self, - files: List[SyncFile], - current_user: SyncsUser, - sync_active: SyncsActive, - ): - logger.info(f"Processing {len(files)} for sync_active: {sync_active.id}") - current_user.credentials = self.sync_cloud.check_and_refresh_access_token( - current_user.credentials - ) - - bulk_id = uuid4() - downloaded_files = [] - list_existing_files = self.sync_files_repo.get_sync_files(sync_active.id) - existing_files = {f.path: f for f in list_existing_files} - - supported_files = filter_on_supported_files(files, existing_files) - - files = self.create_sync_bulk_notification( - files, current_user.user_id, sync_active.brain_id, bulk_id - ) - - for file, prev_file in supported_files: - try: - result = await self.process_sync_file( - file=file, - previous_file=prev_file, - current_user=current_user, - sync_active=sync_active, - ) - if result is not None: - downloaded_files.append(result) - - self.notification_service.update_notification_by_id( - file.notification_id, - NotificationUpdatableProperties( - status=NotificationsStatusEnum.SUCCESS, - description="File downloaded successfully", - ), - ) - - except Exception as e: - logger.error( - 
"An error occurred while syncing %s files: %s", - self.sync_cloud.name, - e, - ) - # TODO: this process_sync_file could fail for a LOT of reason redo this logic - # File isn't supported so we set it as so ? - self.sync_files_repo.update_or_create_sync_file( - file=file, - sync_active=sync_active, - previous_file=prev_file, - supported=False, - ) - self.notification_service.update_notification_by_id( - file.notification_id, - NotificationUpdatableProperties( - status=NotificationsStatusEnum.ERROR, - description="Error downloading file", - ), - ) - - return {"downloaded_files": downloaded_files} - - async def get_files_to_download( - self, sync_active: SyncsActive, user_sync: SyncsUser - ) -> list[SyncFile]: - # Get the folder id from the settings from sync_active - folders = sync_active.settings.get("folders", []) - files_ids = sync_active.settings.get("files", []) - - files = await self.get_syncfiles_from_ids( - user_sync.credentials, - files_ids=files_ids, - folder_ids=folders, - sync_user_id=user_sync.id, - ) - - logger.debug(f"original files to download for {sync_active.id} : {files}") - - last_synced_time = ( - datetime.fromisoformat(sync_active.last_synced).astimezone(timezone.utc) - if sync_active.last_synced - else None - ) - - files_ids = [ - file - for file in files - if should_download_file( - file=file, - last_updated_sync_active=last_synced_time, - provider_name=self.sync_cloud.lower_name, - datetime_format=self.sync_cloud.datetime_format, - ) - ] - - logger.debug(f"filter files to download for {sync_active} : {files_ids}") - return files_ids - - async def get_syncfiles_from_ids( - self, - credentials: dict[str, Any], - files_ids: list[str], - folder_ids: list[str], - sync_user_id: int, - ) -> list[SyncFile]: - files = [] - if self.sync_cloud.lower_name == "notion": - files_ids += folder_ids - - for folder_id in folder_ids: - logger.debug( - f"Recursively getting file_ids from {self.sync_cloud.name}. 
folder_id={folder_id}" - ) - files.extend( - await self.sync_cloud.aget_files( - credentials=credentials, - sync_user_id=sync_user_id, - folder_id=folder_id, - recursive=True, - ) - ) - if len(files_ids) > 0: - files.extend( - await self.sync_cloud.aget_files_by_id( - credentials=credentials, - file_ids=files_ids, - ) - ) - return files - - async def direct_sync( - self, - sync_active: SyncsActive, - sync_user: SyncsUser, - files_ids: list[str], - folder_ids: list[str], - ): - files = await self.get_syncfiles_from_ids( - sync_user.credentials, files_ids, folder_ids - ) - processed_files = await self.process_sync_files( - files=files, - current_user=sync_user, - sync_active=sync_active, - ) - - # Update the last_synced timestamp - self.sync_active_service.update_sync_active( - sync_active.id, - SyncsActiveUpdateInput( - last_synced=datetime.now().astimezone().isoformat(), force_sync=False - ), - ) - logger.info( - f"{self.sync_cloud.lower_name} sync completed for sync_active: {sync_active.id}. Synced all {len(processed_files)} files.", - ) - return processed_files - - async def sync( - self, - sync_active: SyncsActive, - user_sync: SyncsUser, - ): - """ - Check if the Specific sync has not been synced and download the folders and files based on the settings. - - Args: - sync_active_id (int): The ID of the active sync. - user_id (str): The user ID associated with the active sync. 
- """ - logger.info( - "Starting %s sync for sync_active: %s", - self.sync_cloud.lower_name, - sync_active, - ) - - files_to_download = await self.get_files_to_download(sync_active, user_sync) - processed_files = await self.process_sync_files( - files=files_to_download, - current_user=user_sync, - sync_active=sync_active, - ) - - # Update the last_synced timestamp - self.sync_active_service.update_sync_active( - sync_active.id, - SyncsActiveUpdateInput( - last_synced=datetime.now().astimezone().isoformat(), force_sync=False - ), - ) - logger.info( - f"{self.sync_cloud.lower_name} sync completed for sync_active: {sync_active.id}. Synced all {len(processed_files)} files.", - ) - return processed_files diff --git a/backend/worker/quivr_worker/assistants/assistants.py b/backend/worker/quivr_worker/assistants/assistants.py index 7310384b0a85..b050d2adb0c3 100644 --- a/backend/worker/quivr_worker/assistants/assistants.py +++ b/backend/worker/quivr_worker/assistants/assistants.py @@ -7,8 +7,8 @@ ) from sqlalchemy.ext.asyncio import AsyncEngine -from quivr_worker.process.processor import _start_session from quivr_worker.utils.pdf_generator.pdf_generator import PDFGenerator, PDFModel +from quivr_worker.utils.services import _start_session async def aprocess_assistant_task( diff --git a/backend/worker/quivr_worker/celery_worker.py b/backend/worker/quivr_worker/celery_worker.py index c58b04146bf0..a8561b7ac717 100644 --- a/backend/worker/quivr_worker/celery_worker.py +++ b/backend/worker/quivr_worker/celery_worker.py @@ -16,7 +16,7 @@ from quivr_worker.assistants.assistants import aprocess_assistant_task from quivr_worker.check_premium import check_is_premium from quivr_worker.process import aprocess_file_task -from quivr_worker.syncs.update_syncs import update_sync_files +from quivr_worker.syncs.update_syncs import refresh_sync_files, refresh_sync_folders from quivr_worker.utils.utils import _patch_json load_dotenv() @@ -53,7 +53,6 @@ def init_worker(**kwargs): 
default_retry_delay=1, name="process_file_task", autoretry_for=(Exception,), - dont_autoretry_for=(FileExistsError,), ) def process_file_task( knowledge_id: UUID, @@ -74,17 +73,31 @@ def process_file_task( @celery.task( retries=3, default_retry_delay=1, - name="process_file_task", + name="refresh_sync_files_task", + autoretry_for=(Exception,), +) +def refresh_sync_files_task(): + if async_engine is None: + init_worker() + assert async_engine + logger.info("Update sync task started") + loop = asyncio.get_event_loop() + loop.run_until_complete(refresh_sync_files(async_engine=async_engine)) + + +@celery.task( + retries=3, + default_retry_delay=1, + name="refresh_sync_folders_task", autoretry_for=(Exception,), - dont_autoretry_for=(FileExistsError,), ) -def update_sync_task(): +def refresh_sync_folders_task(): if async_engine is None: init_worker() assert async_engine logger.info("Update sync task started") loop = asyncio.get_event_loop() - loop.run_until_complete(update_sync_files(async_engine=async_engine)) + loop.run_until_complete(refresh_sync_folders(async_engine=async_engine)) @celery.task( @@ -140,60 +153,21 @@ def check_is_premium_task(): check_is_premium(supabase_client) -# @celery.task(name="process_notion_sync_task") -# def process_notion_sync_task(): -# global async_engine -# assert async_engine -# loop = asyncio.get_event_loop() -# loop.run_until_complete(process_notion_sync(async_engine)) - - -# @celery.task(name="fetch_and_store_notion_files_task") -# def fetch_and_store_notion_files_task( -# access_token: str, user_id: UUID, sync_user_id: int -# ): -# if async_engine is None: -# init_worker() -# assert async_engine -# try: -# logger.debug("Fetching and storing Notion files") -# loop = asyncio.get_event_loop() -# loop.run_until_complete( -# fetch_and_store_notion_files_async( -# async_engine, access_token, user_id, sync_user_id -# ) -# ) -# sync_user_service.update_sync_user_status( -# sync_user_id=sync_user_id, status=str(SyncStatus.SYNCED) -# ) -# 
except Exception: -# logger.error("Error fetching and storing Notion files") -# sync_user_service.update_sync_user_status( -# sync_user_id=sync_user_id, status=str(SyncStatus.ERROR) -# ) - - -# @celery.task(name="clean_notion_user_syncs") -# def clean_notion_user_syncs(): -# logger.debug("Cleaning Notion user syncs") -# sync_user_service.clean_notion_user_syncs() - - celery.conf.beat_schedule = { "ping_telemetry": { "task": f"{__name__}.ping_telemetry", "schedule": crontab(minute="*/30", hour="*"), }, - # "process_active_syncs": { - # "task": "process_active_syncs_task", - # "schedule": crontab(minute="*/1", hour="*"), - # }, "process_premium_users": { "task": "check_is_premium_task", "schedule": crontab(minute="*/1", hour="*"), }, - "process_notion_sync": { - "task": "process_notion_sync_task", - "schedule": crontab(minute="0", hour="*/6"), + "refresh_sync_files": { + "task": "refresh_sync_files_task", + "schedule": crontab(hour="*/8"), + }, + "refresh_sync_folders": { + "task": "refresh_sync_folders_task", + "schedule": crontab(hour="*/8"), }, } diff --git a/backend/worker/quivr_worker/process/README.md b/backend/worker/quivr_worker/process/README.md index 170531b048b0..75cbb36ae12c 100644 --- a/backend/worker/quivr_worker/process/README.md +++ b/backend/worker/quivr_worker/process/README.md @@ -1,8 +1,4 @@ -Here's the grammar correction and a more explicit version of your markdown, keeping the original logic intact: - ---- - -# Knowledge Processing +# Knowledge Processing Task ## Steps for Processing @@ -74,14 +70,43 @@ Why can’t we set all children to `ERROR`? This introduces a potential race con However, sync knowledge added to a brain will be reprocessed after some time through the sync update task, ensuring that their status will eventually be set to the correct state. 
+# Syncing Knowledge task + +- Task to update all files +- Task to get all folders and proceess +- Rollback doesnt modify the original knowledge +- If we don't find a folder we will delete it -> deleting all the knowledges associated with it ! + +1. **Syncing Knowledge Syncs of Type Files:** + - Outdated file syncs are fetched in batches. + - For each file, if the remote file's `last_modified_at` is newer than the local `last_synced_at`, the file is updated. + - If the file is missing remotely, the db knowledge is deleted. +2. **Syncing Knowledge Folders:** + - Outdated folder syncs are retrieved in batches. + - For each folder, its children (files and subfolders) are fetched from both the database and the remote provider. + - Remote children missing from the local database are added and processed. + - **If a Folder is Not Found:** + - If a folder no longer exists remotely, it is deleted locally, along with all associated knowledge entries. + +🔴 **Key Considerations** + +- **Batch Processing:** + + - Both file and folder syncs are handled in batches, ensuring the system can process large data efficiently. + +- **Error Handling:** + + - The system logs errors such as missing credentials or files, allowing the sync process to continue or fail gracefully. + +- **Savepoints and Rollback:** + + - During file and folder processing, savepoints are created. If an error occurs, the transaction can be rolled back, ensuring the original knowledge remains unmodified. + +- **Deleting Folders:** + - If a folder is missing remotely, it triggers the deletion of the folder and all associated knowledge entries from the local system. 
+ --- ## Notification Steps To discuss: @StanGirard @Zewed - -# Sync update - -- Task to update all files -- Task to get all folders and proceess -- Rollback doesnt modify the original knowledge diff --git a/backend/worker/quivr_worker/process/processor.py b/backend/worker/quivr_worker/process/processor.py index d9a956117071..6828aef13f56 100644 --- a/backend/worker/quivr_worker/process/processor.py +++ b/backend/worker/quivr_worker/process/processor.py @@ -146,7 +146,7 @@ async def _yield_syncs( f"received unprocessable knowledge : {parent_knowledge.id} " ) # Get associated sync - sync = await self.services.sync_service.get_sync_by_id(parent_knowledge.sync_id) + sync = await self._get_sync(parent_knowledge.sync_id) if sync.credentials is None: logger.error( f"can't process knowledge: {parent_knowledge.id}. sync {sync.id} has no credentials" @@ -230,6 +230,7 @@ async def process_knowledge(self, knowledge_id: UUID): ) async def _process_inner(self, knowledge: KnowledgeDB, qfile: QuivrFile): + last_synced_at = datetime.now(timezone.utc) if not skip_process(knowledge): chunks = await parse_qfile(qfile=qfile) await store_chunks( @@ -237,7 +238,6 @@ async def _process_inner(self, knowledge: KnowledgeDB, qfile: QuivrFile): chunks=chunks, vector_service=self.services.vector_service, ) - last_synced_at = datetime.now(timezone.utc) await self.services.knowledge_service.update_knowledge( knowledge, KnowledgeUpdate( @@ -250,11 +250,92 @@ async def _process_inner(self, knowledge: KnowledgeDB, qfile: QuivrFile): ) @lru_cache(maxsize=50) # noqa: B019 - async def get_sync_provider(self, sync_id: int) -> Sync: + async def _get_sync(self, sync_id: int) -> Sync: sync = await self.services.sync_service.get_sync_by_id(sync_id) return sync - async def update_outdated_syncs_files( + async def refresh_sync_folders( + self, timedelta_hour: int = 8, batch_size: int = 100 + ): + last_time = datetime.now(timezone.utc) - timedelta(hours=timedelta_hour) + km_sync_folders = await 
self.services.knowledge_service.get_outdated_syncs( + limit_time=last_time, + batch_size=batch_size, + km_sync_type=SyncType.FOLDER, + ) + for sync_folder_km in km_sync_folders: + await self.refresh_sync_folder(sync_folder_km) + + async def refresh_sync_folder(self, folder_km: KnowledgeDB) -> KnowledgeDB: + assert folder_km.sync_id, "can only update sync files with sync_id" + assert folder_km.sync_file_id, "can only update sync files with sync_file_id " + sync = await self._get_sync(folder_km.sync_id) + if sync.credentials is None: + logger.error( + f"can't process knowledge: {folder_km.id}. sync {sync.id} has no credentials" + ) + raise ValueError(f"no associated credentials with knowledge {folder_km}") + provider_name = SyncProvider(sync.provider.lower()) + sync_provider = self.services.syncprovider_mapping[provider_name] + km_children: List[KnowledgeDB] = await folder_km.awaitable_attrs.children + sync_children = {c.sync_file_id for c in km_children} + try: + sync_files = await sync_provider.aget_files( + credentials=sync.credentials, + folder_id=folder_km.sync_file_id, + recursive=False, + ) + breakpoint() + for sync_entry in filter(lambda s: s.id not in sync_children, sync_files): + await self.add_new_sync_entry(folder=folder_km, sync_entry=sync_entry) + + except FileNotFoundError: + logger.info( + f"Knowledge {folder_km.id} not found in remote sync. 
Removing the folder" + ) + await self.services.knowledge_service.remove_knowledge( + folder_km, autocommit=True + ) + except Exception: + logger.exception(f"Exception occured processing folder: {folder_km.id}") + finally: + await self.services.knowledge_service.update_knowledge( + knowledge=folder_km, + payload=KnowledgeUpdate(last_synced_at=datetime.now(timezone.utc)), + ) + return folder_km + + async def add_new_sync_entry(self, folder: KnowledgeDB, sync_entry: SyncFile): + sync_km = await self.services.knowledge_service.create_knowledge( + user_id=folder.user_id, + knowledge_to_add=AddKnowledge( + file_name=sync_entry.name, + is_folder=sync_entry.is_folder, + extension=sync_entry.extension, + source=folder.source, + source_link=sync_entry.web_view_link, + parent_id=folder.id, + sync_id=folder.sync_id, + sync_file_id=sync_entry.id, + ), + status=KnowledgeStatus.PROCESSING, + upload_file=None, + autocommit=True, + process_async=False, + ) + async for processable_tuple in self._yield_syncs(sync_km): + if processable_tuple is None: + continue + knowledge, qfile = processable_tuple + savepoint = await self.create_savepoint() + try: + await self._process_inner(knowledge=knowledge, qfile=qfile) + await savepoint.commit() + except Exception: + await savepoint.rollback() + logger.exception(f"Error occured processing :{knowledge.id}") + + async def refresh_knowledge_sync_files( self, timedelta_hour: int = 8, batch_size: int = 1000 ): last_time = datetime.now(timezone.utc) - timedelta(hours=timedelta_hour) @@ -269,7 +350,7 @@ async def update_outdated_syncs_files( assert ( old_km.sync_file_id ), "can only update sync files with sync_file_id " - sync = await self.get_sync_provider(old_km.sync_id) + sync = await self._get_sync(old_km.sync_id) if sync.credentials is None: logger.error( f"can't process knowledge: {old_km.id}. 
sync {sync.id} has no credentials" @@ -284,7 +365,7 @@ async def update_outdated_syncs_files( credentials=sync.credentials, file_ids=[old_km.sync_file_id] ) )[0] - await self.update_outdated_km( + await self.refresh_knowledge_entry( old_km=old_km, new_sync_file=new_sync_file, sync_provider=sync_provider, @@ -300,7 +381,7 @@ async def update_outdated_syncs_files( except Exception: logger.exception(f"Exception occured processing km: {old_km.id}") - async def update_outdated_km( + async def refresh_knowledge_entry( self, old_km: KnowledgeDB, new_sync_file: SyncFile, @@ -329,9 +410,10 @@ async def update_outdated_km( sync_file_id=new_sync_file.id, ), status=KnowledgeStatus.PROCESSING, - add_brains=await old_km.awaitable_attrs.brains, + link_brains=await old_km.awaitable_attrs.brains, upload_file=None, autocommit=False, + process_async=False, ) async with build_sync_file( new_km, diff --git a/backend/worker/quivr_worker/syncs/__init__.py b/backend/worker/quivr_worker/syncs/__init__.py index 0ca8a21db9fe..e69de29bb2d1 100644 --- a/backend/worker/quivr_worker/syncs/__init__.py +++ b/backend/worker/quivr_worker/syncs/__init__.py @@ -1,3 +0,0 @@ -from .process_active_syncs import process_all_active_syncs - -__all__ = ["process_all_active_syncs"] diff --git a/backend/worker/quivr_worker/syncs/update_syncs.py b/backend/worker/quivr_worker/syncs/update_syncs.py index ed8d89596ab0..960f38bd1997 100644 --- a/backend/worker/quivr_worker/syncs/update_syncs.py +++ b/backend/worker/quivr_worker/syncs/update_syncs.py @@ -1,17 +1,16 @@ from sqlalchemy.ext.asyncio import AsyncEngine +from quivr_worker.process.processor import KnowledgeProcessor from quivr_worker.utils.services import build_processor_services -async def update_sync_files(async_engine: AsyncEngine): +async def refresh_sync_files(async_engine: AsyncEngine): async with build_processor_services(async_engine) as processor_services: - # Get folders - ## Fetch the folder children - ## Fetch all knowledge for this sync - ## If 
this knowledge isn't - ## Get files - # Update sync files - pass + km_processor = KnowledgeProcessor(services=processor_services) + await km_processor.refresh_knowledge_sync_files() -# If knowledge is folder just call the link_knowledge_to_brain +async def refresh_sync_folders(async_engine: AsyncEngine): + async with build_processor_services(async_engine) as processor_services: + km_processor = KnowledgeProcessor(services=processor_services) + await km_processor.refresh_knowledge_sync_files() diff --git a/backend/worker/tests/conftest.py b/backend/worker/tests/conftest.py index f87e7597d4e0..d5f386f8d444 100644 --- a/backend/worker/tests/conftest.py +++ b/backend/worker/tests/conftest.py @@ -561,3 +561,55 @@ def pdf_qfile(tmp_path) -> QuivrFile: file_size=1000, path=Path("./tests/sample.pdf"), ) + + +@pytest_asyncio.fixture(scope="function") +async def sync_knowledge_folder_processed( + session: AsyncSession, + user: User, + brain_user: Brain, + sync: Sync, +) -> KnowledgeDB: + assert user.id + assert brain_user.brain_id + folder = KnowledgeDB( + file_name="folder", + extension="", + status=KnowledgeStatus.PROCESSED, + source=SyncProvider.GOOGLE, + source_link="drive://test/file1", + file_size=10, + file_sha1="test", + user_id=user.id, + brains=[brain_user], + parent=None, + is_folder=True, + # NOTE: See FakeSync Implementation + sync_file_id="folder-1", + sync=sync, + last_synced_at=datetime.now(timezone.utc) - timedelta(days=2), + ) + + km = KnowledgeDB( + file_name="file", + extension=".txt", + status=KnowledgeStatus.PROCESSED, + source=SyncProvider.GOOGLE, + source_link="drive://test/folder1", + file_size=0, + file_sha1=None, + user_id=user.id, + brains=[brain_user], + parent=folder, + is_folder=False, + sync_file_id="file-1", + sync=sync, + last_synced_at=datetime.now(timezone.utc) - timedelta(days=2), + ) + + session.add(folder) + session.add(km) + await session.commit() + await session.refresh(folder) + + return folder diff --git 
a/backend/worker/tests/test_update_syncs.py b/backend/worker/tests/test_update_syncs.py index 6e9b75aad836..574ab30fb8e7 100644 --- a/backend/worker/tests/test_update_syncs.py +++ b/backend/worker/tests/test_update_syncs.py @@ -1,11 +1,15 @@ +import os from datetime import datetime, timedelta, timezone -from typing import Any +from io import BytesIO +from typing import Any, Dict, List, Union import pytest from langchain_core.documents import Document from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB +from quivr_api.modules.sync.dto.outputs import SyncProvider from quivr_api.modules.sync.entity.sync_models import SyncFile from quivr_api.modules.sync.tests.test_sync_controller import FakeSync +from quivr_api.modules.sync.utils.sync import BaseSync from quivr_api.modules.vector.entity.vector import Vector from quivr_core.files.file import QuivrFile from quivr_core.models import KnowledgeStatus @@ -24,7 +28,7 @@ async def _parse_file_mock( @pytest.mark.asyncio(loop_scope="session") @pytest.mark.parametrize("proc_services", [0], indirect=True) -async def test_update_sync_file( +async def test_refresh_sync_file( monkeypatch, session: AsyncSession, proc_services: ProcessorServices, @@ -49,7 +53,7 @@ async def test_update_sync_file( last_modified_at=datetime.now(timezone.utc) - timedelta(hours=1), ) sync_provider = FakeSync(provider_name=input_km.source, n_get_files=0) - new_km = await km_processor.update_outdated_km( + new_km = await km_processor.refresh_knowledge_entry( old_km=sync_knowledge_file, new_sync_file=new_sync_file, sync_provider=sync_provider, @@ -90,7 +94,7 @@ async def test_update_sync_file( @pytest.mark.asyncio(loop_scope="session") @pytest.mark.parametrize("proc_services", [0], indirect=True) -async def test_update_sync_file_rollback( +async def test_refresh_sync_file_rollback( monkeypatch, session: AsyncSession, proc_services: ProcessorServices, @@ -123,7 +127,7 @@ async def _parse_file_mock_error( 
last_modified_at=datetime.now(timezone.utc) - timedelta(hours=1), ) sync_provider = FakeSync(provider_name=input_km.source, n_get_files=0) - new_km = await km_processor.update_outdated_km( + new_km = await km_processor.refresh_knowledge_entry( old_km=input_km, new_sync_file=new_sync_file, sync_provider=sync_provider, @@ -149,3 +153,68 @@ async def _parse_file_mock_error( assert len(all_kms) == 1 assert all_kms[0].id == input_km.id assert all_kms[0].last_synced_at == input_km.last_synced_at + + +@pytest.mark.asyncio(loop_scope="session") +@pytest.mark.parametrize("proc_services", [2], indirect=True) +async def test_refresh_sync_folder( + monkeypatch, + session: AsyncSession, + proc_services: ProcessorServices, + sync_knowledge_folder_processed: KnowledgeDB, +): + input_km_folder = sync_knowledge_folder_processed + assert input_km_folder.id + assert input_km_folder.brains + assert input_km_folder.sync_file_id + assert input_km_folder.file_name + assert input_km_folder.source_link + assert input_km_folder.last_synced_at + + class _MockSync: + name = "FakeProvider" + lower_name = "google" + datetime_format: str = "%Y-%m-%dT%H:%M:%S.%fZ" + + async def aget_files( + self, credentials: Dict, file_ids: List[str] + ) -> List[SyncFile]: + return self.get_files(credentials, file_ids) + + def get_files(self, credentials: Dict, file_ids: List[str]) -> List[SyncFile]: + return [ + SyncFile( + id="file_id_1", + name="new_file", + extension=".txt", + web_view_link="fake://test.com", + is_folder=False, + last_modified_at=datetime.now(), + ) + ] + + async def adownload_file( + self, credentials: Dict, file: SyncFile + ) -> Dict[str, Union[str, BytesIO]]: + return {"content": str(os.urandom(24))} + + sync_provider_mapping: dict[SyncProvider, BaseSync] = { + provider: _MockSync() # type: ignore + for provider in list(SyncProvider) + } + + input_km_children = await input_km_folder.awaitable_attrs.children + proc_services.syncprovider_mapping = sync_provider_mapping + km_processor = 
KnowledgeProcessor(proc_services) + monkeypatch.setattr("quivr_worker.process.processor.parse_qfile", _parse_file_mock) + await km_processor.refresh_sync_folder(folder_km=input_km_folder) + + # Check knowledge was updated + assert input_km_folder + assert input_km_folder.id + knowledge_service = km_processor.services.knowledge_service + km = await knowledge_service.get_knowledge(input_km_folder.id) + assert km.status == KnowledgeStatus.PROCESSED + assert {k.id for k in await km.awaitable_attrs.children}.issuperset( + {k.id for k in input_km_children} + ) From 4180026b6fa4f4f1385910364603ae60c1bb71e6 Mon Sep 17 00:00:00 2001 From: aminediro Date: Tue, 1 Oct 2024 16:48:53 +0200 Subject: [PATCH 50/63] celery worker time limits --- backend/worker/quivr_worker/celery_worker.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/backend/worker/quivr_worker/celery_worker.py b/backend/worker/quivr_worker/celery_worker.py index a8561b7ac717..31eedc399330 100644 --- a/backend/worker/quivr_worker/celery_worker.py +++ b/backend/worker/quivr_worker/celery_worker.py @@ -2,6 +2,7 @@ import os from uuid import UUID +from celery.exceptions import SoftTimeLimitExceeded, TimeLimitExceeded from celery.schedules import crontab from celery.signals import worker_process_init from dotenv import load_dotenv @@ -52,7 +53,10 @@ def init_worker(**kwargs): retries=3, default_retry_delay=1, name="process_file_task", - autoretry_for=(Exception,), + time_limit=600, # 10 min + soft_time_limit=300, + autoretry_for=(Exception,), # SoftTimeLimitExceeded should not included? 
+ dont_autoretry_for=(SoftTimeLimitExceeded, TimeLimitExceeded), ) def process_file_task( knowledge_id: UUID, @@ -74,6 +78,7 @@ def process_file_task( retries=3, default_retry_delay=1, name="refresh_sync_files_task", + soft_time_limit=3600, autoretry_for=(Exception,), ) def refresh_sync_files_task(): From 53ba2bbd08a87af3861c254246c2f7b73def695e Mon Sep 17 00:00:00 2001 From: aminediro Date: Tue, 1 Oct 2024 18:58:04 +0200 Subject: [PATCH 51/63] modify Readme.ms --- backend/worker/quivr_worker/process/README.md | 5 ----- 1 file changed, 5 deletions(-) diff --git a/backend/worker/quivr_worker/process/README.md b/backend/worker/quivr_worker/process/README.md index 75cbb36ae12c..bfb94ff26de9 100644 --- a/backend/worker/quivr_worker/process/README.md +++ b/backend/worker/quivr_worker/process/README.md @@ -72,11 +72,6 @@ However, sync knowledge added to a brain will be reprocessed after some time thr # Syncing Knowledge task -- Task to update all files -- Task to get all folders and proceess -- Rollback doesnt modify the original knowledge -- If we don't find a folder we will delete it -> deleting all the knowledges associated with it ! - 1. **Syncing Knowledge Syncs of Type Files:** - Outdated file syncs are fetched in batches. - For each file, if the remote file's `last_modified_at` is newer than the local `last_synced_at`, the file is updated. 
From 713fa860a5d5bff5f2a5bde141398e21493fa58b Mon Sep 17 00:00:00 2001 From: aminediro Date: Fri, 4 Oct 2024 16:07:45 +0200 Subject: [PATCH 52/63] added field brains --- backend/api/quivr_api/modules/brain/entity/brain_entity.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/backend/api/quivr_api/modules/brain/entity/brain_entity.py b/backend/api/quivr_api/modules/brain/entity/brain_entity.py index aae2bc2e73fe..28bc55d3759f 100644 --- a/backend/api/quivr_api/modules/brain/entity/brain_entity.py +++ b/backend/api/quivr_api/modules/brain/entity/brain_entity.py @@ -70,10 +70,12 @@ class Brain(AsyncAttrs, SQLModel, table=True): knowledges: List[KnowledgeDB] = Relationship( back_populates="brains", link_model=KnowledgeBrain ) - users: List["User"] = Relationship( + users: List["User"] = Relationship( # type: ignore # noqa: F821 back_populates="brains", link_model=BrainUserDB, ) + snippet_color: str | None = Field(default="#d0c6f2") + snippet_emoji: str | None = Field(default="🧠") # TODO : add # "meaning" "public"."vector", # "tags" "public"."tags"[] From e7cdf6a4917a58aeddf8a3cb7646992d96a35524 Mon Sep 17 00:00:00 2001 From: aminediro Date: Mon, 7 Oct 2024 11:35:01 +0200 Subject: [PATCH 53/63] file size limit --- backend/api/quivr_api/models/settings.py | 5 ++-- .../knowledge/controller/knowledge_routes.py | 6 ++++ .../tests/test_knowledge_controller.py | 28 +++++++++++++++++++ backend/requirements-dev.lock | 3 ++ 4 files changed, 40 insertions(+), 2 deletions(-) diff --git a/backend/api/quivr_api/models/settings.py b/backend/api/quivr_api/models/settings.py index c29205518faa..745677e9415e 100644 --- a/backend/api/quivr_api/models/settings.py +++ b/backend/api/quivr_api/models/settings.py @@ -103,6 +103,8 @@ def set_once_user_properties(self, user_id: UUID, event_name, properties: dict): class BrainSettings(BaseSettings): model_config = SettingsConfigDict(validate_default=False) + pg_database_url: str + pg_database_async_url: str openai_api_key: 
str = "" azure_openai_embeddings_url: str = "" supabase_url: str = "" @@ -112,11 +114,10 @@ class BrainSettings(BaseSettings): ollama_api_base_url: str | None = None langfuse_public_key: str | None = None langfuse_secret_key: str | None = None - pg_database_url: str - pg_database_async_url: str sqlalchemy_pool_size: int = 5 sqlalchemy_max_pool_overflow: int = 5 embedding_dim: int = 1536 + max_file_size: int = int(5e7) class ResendSettings(BaseSettings): diff --git a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py index 43675a6c41da..da7edb760f28 100644 --- a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py +++ b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py @@ -18,6 +18,7 @@ from quivr_api.celery_config import celery from quivr_api.logger import get_logger from quivr_api.middlewares.auth import AuthBearer, get_current_user +from quivr_api.models.settings import settings from quivr_api.modules.brain.service.brain_authorization_service import ( validate_brain_authorization, ) @@ -123,6 +124,11 @@ async def create_knowledge( current_user: UserIdentity = Depends(get_current_user), ): knowledge = AddKnowledge.model_validate_json(knowledge_data) + if file and file.size and file.size > settings.max_file_size: + raise HTTPException( + status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, + detail="Uploaded file is too large", + ) if not knowledge.file_name and not knowledge.url: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py index 8a3e4637501c..3174bfa91cd1 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py @@ -5,6 +5,7 @@ import pytest 
import pytest_asyncio from httpx import ASGITransport, AsyncClient +from quivr_api.models.settings import BrainSettings from quivr_core.models import KnowledgeStatus from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession @@ -144,6 +145,33 @@ async def test_post_knowledge_folder(test_client: AsyncClient): assert km.children == [] +@pytest.mark.asyncio(loop_scope="session") +async def test_add_knowledge_large_file(monkeypatch, test_client): + _settings = BrainSettings() + _settings.max_file_size = 2 + monkeypatch.setattr( + "quivr_api.modules.knowledge.controller.knowledge_routes.settings", _settings + ) + km_data = { + "file_name": "test_file.txt", + "source": "local", + "is_folder": False, + "parent_id": None, + } + + multipart_data = { + "knowledge_data": (None, json.dumps(km_data), "application/json"), + "file": ("test_file.txt", b"Test file content", "application/octet-stream"), + } + + response = await test_client.post( + "/knowledge/", + files=multipart_data, + ) + + assert response.status_code == 413 + + @pytest.mark.asyncio(loop_scope="session") async def test_add_knowledge_invalid_input(test_client): response = await test_client.post("/knowledge/", files={}) diff --git a/backend/requirements-dev.lock b/backend/requirements-dev.lock index 97b7b2e18c53..98e7bcad88a4 100644 --- a/backend/requirements-dev.lock +++ b/backend/requirements-dev.lock @@ -553,6 +553,7 @@ llama-parse==0.5.6 # via quivr-api llvmlite==0.43.0 # via numba +locust==2.31.8 lxml==5.3.0 # via pikepdf # via python-docx @@ -1130,6 +1131,8 @@ sentry-sdk==2.13.0 # via quivr-api setuptools==70.0.0 # via opentelemetry-instrumentation + # via zope-event + # via zope-interface shapely==2.0.6 # via python-doctr simple-websocket==1.0.0 From 233b496ce6b4c5b9fdeaeb2a3016838253641f65 Mon Sep 17 00:00:00 2001 From: AmineDiro Date: Mon, 7 Oct 2024 12:24:26 +0200 Subject: [PATCH 54/63] fix update knowledge (#3334) # Description --- .../knowledge/controller/knowledge_routes.py | 2 
+- .../tests/test_knowledge_controller.py | 131 +++++++++++++++++- .../worker/quivr_worker/process/processor.py | 1 - 3 files changed, 130 insertions(+), 4 deletions(-) diff --git a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py index da7edb760f28..df9bacb7f36a 100644 --- a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py +++ b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py @@ -230,7 +230,7 @@ async def update_knowledge( detail="You do not have permission to access this knowledge.", ) km = await knowledge_service.update_knowledge(km, payload) - return km + return await km.to_dto() except KnowledgeNotFoundException as e: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"{e.message}" diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py index 3174bfa91cd1..0977582b07af 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py @@ -5,19 +5,19 @@ import pytest import pytest_asyncio from httpx import ASGITransport, AsyncClient -from quivr_api.models.settings import BrainSettings from quivr_core.models import KnowledgeStatus from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession from quivr_api.main import app from quivr_api.middlewares.auth.auth_bearer import get_current_user +from quivr_api.models.settings import BrainSettings from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType from quivr_api.modules.brain.entity.brain_user import BrainUserDB from quivr_api.modules.knowledge.controller.knowledge_routes import ( get_knowledge_service, ) -from quivr_api.modules.knowledge.dto.inputs import LinkKnowledgeBrain +from 
quivr_api.modules.knowledge.dto.inputs import KnowledgeUpdate, LinkKnowledgeBrain from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO from quivr_api.modules.knowledge.repository.knowledges import KnowledgeRepository from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService @@ -315,3 +315,130 @@ def _send_task(*args, **kwargs): # 4. Assert both files are being scheduled for processing assert len(tasks) == 2 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_move_knowledge_to_folder( + monkeypatch, + session: AsyncSession, + test_client: AsyncClient, + brain: Brain, + user: User, + sync: Sync, +): + assert brain.brain_id + tasks = {} + + def _send_task(*args, **kwargs): + tasks["args"] = args + tasks["kwargs"] = {**kwargs["kwargs"]} + + monkeypatch.setattr("quivr_api.celery_config.celery.send_task", _send_task) + + folder_data = { + "file_name": "folder", + "source": "local", + "is_folder": True, + "parent_id": None, + } + response = await test_client.post( + "/knowledge/", + files={ + "knowledge_data": (None, json.dumps(folder_data), "application/json"), + }, + ) + # 1. Insert folder + folder_km = KnowledgeDTO.model_validate(response.json()) + file_data = { + "file_name": "test_file.txt", + "source": "local", + "is_folder": True, + "parent_id": None, + } + + multipart_data = { + "knowledge_data": (None, json.dumps(file_data), "application/json"), + } + # 2. Insert file in Root + response = await test_client.post( + "/knowledge/", + files=multipart_data, + ) + file_km = KnowledgeDTO.model_validate(response.json()) + + # Move file to folder + update = KnowledgeUpdate(parent_id=folder_km.id) + response = await test_client.patch( + f"/knowledge/{file_km.id}", + content=update.model_dump_json(exclude_unset=True), + headers={"Content-Type": "application/json"}, + ) + assert response.status_code == 202 + updated_km = KnowledgeDTO.model_validate(response.json()) + + # 3. 
Validate that created knowledges are correct + assert updated_km.parent and updated_km.parent.id + assert updated_km.parent.id == folder_km.id + + +@pytest.mark.asyncio(loop_scope="session") +async def test_move_knowledge_to_root( + monkeypatch, + session: AsyncSession, + test_client: AsyncClient, + brain: Brain, + user: User, + sync: Sync, +): + assert brain.brain_id + tasks = {} + + def _send_task(*args, **kwargs): + tasks["args"] = args + tasks["kwargs"] = {**kwargs["kwargs"]} + + monkeypatch.setattr("quivr_api.celery_config.celery.send_task", _send_task) + + folder_data = { + "file_name": "folder", + "source": "local", + "is_folder": True, + "parent_id": None, + } + response = await test_client.post( + "/knowledge/", + files={ + "knowledge_data": (None, json.dumps(folder_data), "application/json"), + }, + ) + # 1. Insert folder + folder_km = KnowledgeDTO.model_validate(response.json()) + file_data = { + "file_name": "test_file.txt", + "source": "local", + "is_folder": True, + "parent_id": str(folder_km.id), + } + + multipart_data = { + "knowledge_data": (None, json.dumps(file_data), "application/json"), + } + # 2. Insert file in Root + response = await test_client.post( + "/knowledge/", + files=multipart_data, + ) + file_km = KnowledgeDTO.model_validate(response.json()) + + # Move file to Root + update = KnowledgeUpdate(parent_id=None) + response = await test_client.patch( + f"/knowledge/{file_km.id}", + content=update.model_dump_json(exclude_unset=True), + headers={"Content-Type": "application/json"}, + ) + assert response.status_code == 202 + updated_km = KnowledgeDTO.model_validate(response.json()) + + # 3. 
Validate that updated + assert updated_km.parent is None diff --git a/backend/worker/quivr_worker/process/processor.py b/backend/worker/quivr_worker/process/processor.py index 6828aef13f56..83bee5f61d8b 100644 --- a/backend/worker/quivr_worker/process/processor.py +++ b/backend/worker/quivr_worker/process/processor.py @@ -285,7 +285,6 @@ async def refresh_sync_folder(self, folder_km: KnowledgeDB) -> KnowledgeDB: folder_id=folder_km.sync_file_id, recursive=False, ) - breakpoint() for sync_entry in filter(lambda s: s.id not in sync_children, sync_files): await self.add_new_sync_entry(folder=folder_km, sync_entry=sync_entry) From 24f672f27390dc8f2384d1af75b1b4451e10f5e7 Mon Sep 17 00:00:00 2001 From: aminediro Date: Mon, 7 Oct 2024 14:42:22 +0200 Subject: [PATCH 55/63] fix arg length --- .../sync/repository/sync_repository.py | 27 +++++-------------- 1 file changed, 7 insertions(+), 20 deletions(-) diff --git a/backend/api/quivr_api/modules/sync/repository/sync_repository.py b/backend/api/quivr_api/modules/sync/repository/sync_repository.py index 8c7dffde3dbc..584074b1fe29 100644 --- a/backend/api/quivr_api/modules/sync/repository/sync_repository.py +++ b/backend/api/quivr_api/modules/sync/repository/sync_repository.py @@ -64,7 +64,7 @@ async def create_sync( Returns: """ - logger.info("Creating sync user with input: %s", sync_user_input) + logger.info(f"Creating sync user with input: {sync_user_input}") try: sync = Sync.model_validate(sync_user_input.model_dump()) self.session.add(sync) @@ -108,9 +108,7 @@ async def get_syncs(self, user_id: UUID, sync_id: int | None = None): list: A list of sync users matching the criteria. 
""" logger.info( - "Retrieving sync users for user_id: %s, sync_user_id: %s", - user_id, - sync_id, + f"Retrieving sync users for user_id: {user_id}, sync_user_id: {sync_id}", ) query = select(Sync).where(Sync.user_id == user_id) if sync_id is not None: @@ -128,7 +126,7 @@ async def get_sync_user_by_state(self, state: dict) -> Sync: Returns: dict or None: The sync user data matching the state or None if not found. """ - logger.info("Getting sync user by state: %s", state) + logger.info(f"Getting sync user by state: {state}") query = select(Sync).where(Sync.state == state) result = await self.session.exec(query) @@ -140,9 +138,7 @@ async def get_sync_user_by_state(self, state: dict) -> Sync: return None async def delete_sync(self, sync_id: int, user_id: UUID): - logger.info( - "Deleting sync user with sync_id: %s, user_id: %s", sync_id, user_id - ) + logger.info(f"Deleting sync user with sync_id: {sync_id}, user_id: {user_id}") await self.session.execute( delete(Sync).where(Sync.id == sync_id).where(Sync.user_id == user_id) # type: ignore ) @@ -151,11 +147,7 @@ async def delete_sync(self, sync_id: int, user_id: UUID): async def update_sync( self, sync: Sync, sync_input: SyncUpdateInput | dict[str, Any] ): - logger.debug( - "Updating sync user with user_id: %s, state: %s, input: %s", - sync.id, - sync_input, - ) + logger.debug(f"Updating sync user with user_id: {sync.id}") try: if isinstance(sync_input, dict): update_data = sync_input @@ -197,17 +189,12 @@ async def get_files_folder_user_sync( recursive: bool = False, ) -> List[SyncFile] | None: logger.info( - "Retrieving files for user sync with sync_active_id: %s, user_id: %s, folder_id: %s", - sync_id, - user_id, - folder_id, + f"Retrieving files for user sync with sync_active_id: {sync_id}, user_id: {user_id}, folder_id: {folder_id}", ) sync = await self.get_sync_id(sync_id=sync_id, user_id=user_id) if not sync: logger.error( - "No sync user found for sync_active_id: %s, user_id: %s", - sync_id, - user_id, + f"No 
sync user found for sync_active_id: {sync_id}, user_id: {user_id}", ) return None From 6075b836ab545cfc76c51dceb198c614a5051ff3 Mon Sep 17 00:00:00 2001 From: aminediro Date: Mon, 7 Oct 2024 16:23:37 +0200 Subject: [PATCH 56/63] update km date --- .../api/quivr_api/modules/knowledge/repository/knowledges.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py index e4b36bfe7033..dfa2e2c954ea 100644 --- a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py +++ b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import datetime, timezone from typing import Any, List, Sequence from uuid import UUID @@ -68,7 +68,7 @@ async def update_knowledge( update_data = payload.model_dump(exclude_unset=True) for field in update_data: setattr(knowledge, field, update_data[field]) - + knowledge.updated_at = datetime.now(timezone.utc) self.session.add(knowledge) if autocommit: await self.session.commit() From aae36f484a17e745b20a298c5721c086077a8449 Mon Sep 17 00:00:00 2001 From: aminediro Date: Tue, 8 Oct 2024 16:34:24 +0200 Subject: [PATCH 57/63] sort knowledge, remove bulk_id --- .../knowledge/controller/knowledge_routes.py | 26 ++++------ .../quivr_api/modules/knowledge/dto/inputs.py | 1 - .../modules/knowledge/dto/outputs.py | 7 +++ .../modules/knowledge/entity/knowledge.py | 3 +- .../knowledge/tests/integration_test.py | 3 +- .../tests/test_knowledge_controller.py | 5 +- .../knowledge/tests/test_knowledge_entity.py | 47 ++++++++++++++++++- .../knowledge/tests/test_knowledge_service.py | 4 ++ backend/benchmarks/locustfile_kms.py | 4 +- 9 files changed, 71 insertions(+), 29 deletions(-) diff --git a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py 
index df9bacb7f36a..5fc5f62842cb 100644 --- a/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py +++ b/backend/api/quivr_api/modules/knowledge/controller/knowledge_routes.py @@ -29,7 +29,7 @@ LinkKnowledgeBrain, UnlinkKnowledgeBrain, ) -from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO +from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO, sort_knowledge_dtos from quivr_api.modules.knowledge.service.knowledge_exceptions import ( KnowledgeDeleteError, KnowledgeForbiddenAccess, @@ -37,8 +37,6 @@ UploadError, ) from quivr_api.modules.knowledge.service.knowledge_service import KnowledgeService -from quivr_api.modules.notification.dto.inputs import CreateNotification -from quivr_api.modules.notification.entity.notification import NotificationsStatusEnum from quivr_api.modules.notification.service.notification_service import ( NotificationService, ) @@ -171,7 +169,8 @@ async def list_knowledge( try: # TODO: Returns one level of children children = await knowledge_service.list_knowledge(parent_id, current_user.id) - return [await c.to_dto(get_children=False) for c in children] + children_dto = [await c.to_dto(get_children=False) for c in children] + return sort_knowledge_dtos(children_dto) except KnowledgeNotFoundException as e: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=f"{e.message}" @@ -277,10 +276,9 @@ async def link_knowledge_to_brain( knowledge_service: KnowledgeService = Depends(get_knowledge_service), current_user: UserIdentity = Depends(get_current_user), ): - brains_ids, knowledge_dto, bulk_id = ( + brains_ids, knowledge_dto = ( link_request.brain_ids, link_request.knowledge, - link_request.bulk_id, ) if len(brains_ids) == 0: return Response(status_code=status.HTTP_204_NO_CONTENT) @@ -312,20 +310,10 @@ async def link_knowledge_to_brain( if await k.awaitable_attrs.status not in [KnowledgeStatus.PROCESSED, KnowledgeStatus.PROCESSING] ]: - upload_notification = 
notification_service.add_notification( - CreateNotification( - user_id=current_user.id, - bulk_id=bulk_id, - status=NotificationsStatusEnum.INFO, - title=f"{await knowledge.awaitable_attrs.file_name}", - category="process", - ) - ) celery.send_task( "process_file_task", kwargs={ "knowledge_id": await knowledge.awaitable_attrs.id, - "notification_id": upload_notification.id, }, ) knowledge = await knowledge_service.update_knowledge( @@ -333,7 +321,8 @@ async def link_knowledge_to_brain( payload=KnowledgeUpdate(status=KnowledgeStatus.PROCESSING), ) - return await asyncio.gather(*[k.to_dto() for k in linked_kms]) + linked_kms = await asyncio.gather(*[k.to_dto() for k in linked_kms]) + return sort_knowledge_dtos(linked_kms) @knowledge_router.delete( @@ -365,4 +354,5 @@ async def unlink_knowledge_from_brain( ) if unlinked_kms: - return await asyncio.gather(*[k.to_dto() for k in unlinked_kms]) + unlinked_knowledges = await asyncio.gather(*[k.to_dto() for k in unlinked_kms]) + return sort_knowledge_dtos(unlinked_knowledges) diff --git a/backend/api/quivr_api/modules/knowledge/dto/inputs.py b/backend/api/quivr_api/modules/knowledge/dto/inputs.py index 8801ab48a1ef..8ce186311987 100644 --- a/backend/api/quivr_api/modules/knowledge/dto/inputs.py +++ b/backend/api/quivr_api/modules/knowledge/dto/inputs.py @@ -50,7 +50,6 @@ class KnowledgeUpdate(BaseModel): class LinkKnowledgeBrain(BaseModel): - bulk_id: UUID knowledge: KnowledgeDTO brain_ids: List[UUID] diff --git a/backend/api/quivr_api/modules/knowledge/dto/outputs.py b/backend/api/quivr_api/modules/knowledge/dto/outputs.py index b60fc0a7b51b..7de9587689f9 100644 --- a/backend/api/quivr_api/modules/knowledge/dto/outputs.py +++ b/backend/api/quivr_api/modules/knowledge/dto/outputs.py @@ -34,3 +34,10 @@ class KnowledgeDTO(BaseModel): sync_id: int | None sync_file_id: str | None last_synced_at: datetime | None = None + + +def sort_knowledge_dtos(dtos: List[KnowledgeDTO]) -> List[KnowledgeDTO]: + return sorted( + dtos, + 
key=lambda dto: (not dto.is_folder, dto.file_name is None, dto.file_name or ""), + ) diff --git a/backend/api/quivr_api/modules/knowledge/entity/knowledge.py b/backend/api/quivr_api/modules/knowledge/entity/knowledge.py index b4ff593e327e..fb3f482f6cc1 100644 --- a/backend/api/quivr_api/modules/knowledge/entity/knowledge.py +++ b/backend/api/quivr_api/modules/knowledge/entity/knowledge.py @@ -10,7 +10,7 @@ from sqlmodel import UUID as PGUUID from sqlmodel import Field, Relationship, SQLModel -from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO +from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO, sort_knowledge_dtos from quivr_api.modules.knowledge.entity.knowledge_brain import KnowledgeBrain from quivr_api.modules.sync.entity.sync_models import Sync @@ -118,6 +118,7 @@ async def to_dto( children_dto = await asyncio.gather( *[c.to_dto(get_children=False) for c in children] ) + children_dto = sort_knowledge_dtos(children_dto) parent = await self.awaitable_attrs.parent if get_parent else None parent_dto = await parent.to_dto(get_children=False) if parent else None diff --git a/backend/api/quivr_api/modules/knowledge/tests/integration_test.py b/backend/api/quivr_api/modules/knowledge/tests/integration_test.py index 023378621d21..33b552f4039d 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/integration_test.py +++ b/backend/api/quivr_api/modules/knowledge/tests/integration_test.py @@ -1,6 +1,6 @@ import asyncio import json -from uuid import UUID, uuid4 +from uuid import UUID from httpx import AsyncClient @@ -33,7 +33,6 @@ async def main(): km = KnowledgeDTO.model_validate(response.json()) json_data = LinkKnowledgeBrain( - bulk_id=uuid4(), brain_ids=[UUID("40ba47d7-51b2-4b2a-9247-89e29619efb0")], knowledge=km, ).model_dump_json() diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py index 0977582b07af..21f59168f9f0 100644 --- 
a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_controller.py @@ -1,6 +1,5 @@ import json from datetime import datetime -from uuid import uuid4 import pytest import pytest_asyncio @@ -214,7 +213,7 @@ def _send_task(*args, **kwargs): children=[], ) json_data = LinkKnowledgeBrain( - bulk_id=uuid4(), brain_ids=[brain.brain_id], knowledge=km + brain_ids=[brain.brain_id], knowledge=km ).model_dump_json() response = await test_client.post( "/knowledge/link_to_brains/", @@ -291,7 +290,7 @@ def _send_task(*args, **kwargs): file_km = KnowledgeDTO.model_validate(response.json()) json_data = LinkKnowledgeBrain( - bulk_id=uuid4(), brain_ids=[brain.brain_id], knowledge=folder_km + brain_ids=[brain.brain_id], knowledge=folder_km ).model_dump_json() response = await test_client.post( diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py index 857c111e9fae..c84ef5fd827b 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py @@ -1,4 +1,6 @@ +from datetime import datetime from typing import List, Tuple +from uuid import uuid4 import pytest import pytest_asyncio @@ -7,7 +9,8 @@ from sqlmodel.ext.asyncio.session import AsyncSession from quivr_api.modules.brain.entity.brain_entity import Brain -from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB +from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO, sort_knowledge_dtos +from quivr_api.modules.knowledge.entity.knowledge import KnowledgeDB, KnowledgeSource from quivr_api.modules.user.entity.user_identity import User TestData = Tuple[Brain, List[KnowledgeDB]] @@ -185,7 +188,7 @@ async def test_knowledge_dto(session, user, brain, sync): assert km_dto.file_sha1 == km.file_sha1 assert km_dto.updated_at == 
km.updated_at assert km_dto.created_at == km.created_at - assert km_dto.metadata == km.metadata_ # type: ignor + assert km_dto.metadata == km.metadata_ # type: ignore assert km_dto.parent assert km_dto.parent.id == folder.id # Syncs fields @@ -195,3 +198,43 @@ async def test_knowledge_dto(session, user, brain, sync): folder_dto = await folder.to_dto() assert folder_dto.brains[0] == brain.model_dump() assert folder_dto.children == [await km.to_dto()] + + +def test_sort_knowledge_dtos(): + user_id = uuid4() + + data_dict = { + "extension": ".txt", + "status": None, + "user_id": user_id, + "created_at": datetime.now(), + "updated_at": datetime.now(), + "brains": [], + "source": KnowledgeSource.LOCAL, + "source_link": "://test.txt", + "sync_id": None, + "sync_file_id": None, + "parent": None, + "children": [], + } + dtos = [ + KnowledgeDTO(id=uuid4(), is_folder=False, file_name=None, **data_dict), + KnowledgeDTO(id=uuid4(), is_folder=True, file_name="A", **data_dict), + KnowledgeDTO(id=uuid4(), is_folder=True, file_name=None, **data_dict), + KnowledgeDTO(id=uuid4(), is_folder=False, file_name="B", **data_dict), + ] + + sorted_dtos = sort_knowledge_dtos(dtos) + + # First element should be a folder with file_name="A" + assert sorted_dtos[0].is_folder is True + assert sorted_dtos[0].file_name == "A" + # Second element should be a folder with file_name=None + assert sorted_dtos[1].is_folder is True + assert sorted_dtos[1].file_name is None + # Third element should be a file with file_name="B" + assert sorted_dtos[2].is_folder is False + assert sorted_dtos[2].file_name == "B" + # Fourth element should be a file with file_name=None + assert sorted_dtos[3].is_folder is False + assert sorted_dtos[3].file_name is None diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py index 3e9335a8630b..72b39cc7c77a 100644 --- 
a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py @@ -948,6 +948,10 @@ async def test_list_knowledge_root(session: AsyncSession, user: User): assert len(root_kms) == 2 assert {k.id for k in root_kms} == {root_folder.id, root_file.id} + # check order + assert root_kms[0].file_name == "folder" + assert root_kms[1].file_name == "file_1" + @pytest.mark.asyncio(loop_scope="session") async def test_list_knowledge(session: AsyncSession, user: User): diff --git a/backend/benchmarks/locustfile_kms.py b/backend/benchmarks/locustfile_kms.py index 3fc3b41fe508..84092b72f023 100644 --- a/backend/benchmarks/locustfile_kms.py +++ b/backend/benchmarks/locustfile_kms.py @@ -2,7 +2,7 @@ import os import random from typing import List -from uuid import UUID, uuid4 +from uuid import UUID from locust import between, task from locust.contrib.fasthttp import FastHttpUser @@ -136,7 +136,7 @@ def link_to_brains(self): random_brains = [random.choice(brains_ids) for _ in range(nb_brains)] random_km = random.choice(all_kms) json_data = LinkKnowledgeBrain( - bulk_id=uuid4(), brain_ids=random_brains, knowledge=random_km + brain_ids=random_brains, knowledge=random_km ).model_dump_json() self.client.post( "/knowledge/link_to_brains/", From 4c4cd6dec73bca932605151512c5907e59a1a81e Mon Sep 17 00:00:00 2001 From: aminediro Date: Tue, 8 Oct 2024 16:37:36 +0200 Subject: [PATCH 58/63] modified test --- .../quivr_api/modules/knowledge/tests/test_knowledge_entity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py index c84ef5fd827b..dbd563eac138 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py @@ -219,9 +219,9 @@ def 
test_sort_knowledge_dtos(): } dtos = [ KnowledgeDTO(id=uuid4(), is_folder=False, file_name=None, **data_dict), + KnowledgeDTO(id=uuid4(), is_folder=False, file_name="B", **data_dict), KnowledgeDTO(id=uuid4(), is_folder=True, file_name="A", **data_dict), KnowledgeDTO(id=uuid4(), is_folder=True, file_name=None, **data_dict), - KnowledgeDTO(id=uuid4(), is_folder=False, file_name="B", **data_dict), ] sorted_dtos = sort_knowledge_dtos(dtos) From a81f553612f6cc2c4b817d37d5db5300f6c2fda2 Mon Sep 17 00:00:00 2001 From: aminediro Date: Tue, 8 Oct 2024 22:35:22 +0200 Subject: [PATCH 59/63] updated benchmark with routes --- backend/benchmarks/locustfile_kms.py | 82 +++++++++++++++++++++++++--- 1 file changed, 73 insertions(+), 9 deletions(-) diff --git a/backend/benchmarks/locustfile_kms.py b/backend/benchmarks/locustfile_kms.py index 3fc3b41fe508..99ec24be32dd 100644 --- a/backend/benchmarks/locustfile_kms.py +++ b/backend/benchmarks/locustfile_kms.py @@ -9,7 +9,11 @@ from quivr_api.modules.brain.entity.brain_entity import Brain, BrainType from quivr_api.modules.brain.entity.brain_user import BrainUserDB from quivr_api.modules.dependencies import get_supabase_client -from quivr_api.modules.knowledge.dto.inputs import LinkKnowledgeBrain +from quivr_api.modules.knowledge.dto.inputs import ( + KnowledgeUpdate, + LinkKnowledgeBrain, + UnlinkKnowledgeBrain, +) from quivr_api.modules.knowledge.dto.outputs import KnowledgeDTO from quivr_api.modules.user.entity.user_identity import User from sqlmodel import Session, create_engine, select, text @@ -23,11 +27,14 @@ "parent_prob": 0.3, "folder_prob": 0.2, "km_root_prob": 0.2, + # Task rates "create_km_rate": 10, "list_km_rate": 10, "link_brain_rate": 5, "max_link_brains": 3, "delete_km_rate": 2, + "unlink_brain_rate": 5, + "update_km_rate": 2, } @@ -91,7 +98,7 @@ class QuivrUser(FastHttpUser): data = os.urandom(load_params["file_size"]) sync_engine = create_engine( pg_database_base_url, - echo=True, + echo=False, ) def on_start(self) 
-> None: @@ -126,19 +133,19 @@ def create_knowledge(self): returned_km = KnowledgeDTO.model_validate_json(response.text) all_kms.append(returned_km) - create_knowledge.__name__ = "create_knowledge_1MB" - @task(load_params["link_brain_rate"]) def link_to_brains(self): + global all_kms if len(all_kms) == 0: return nb_brains = random.randint(1, load_params["max_link_brains"]) random_brains = [random.choice(brains_ids) for _ in range(nb_brains)] - random_km = random.choice(all_kms) + random_idx = random.choice(range(len(all_kms))) + random_km = all_kms.pop(random_idx) json_data = LinkKnowledgeBrain( bulk_id=uuid4(), brain_ids=random_brains, knowledge=random_km ).model_dump_json() - self.client.post( + response = self.client.post( "/knowledge/link_to_brains/", data=json_data, headers={ @@ -146,8 +153,9 @@ def link_to_brains(self): **self.auth_headers, }, ) - - link_to_brains.__name__ = "link_to_brain" + response.raise_for_status() + kms = [KnowledgeDTO.model_validate(r) for r in response.json()] + all_kms.extend(kms) @task(load_params["list_km_rate"]) def list_knowledge_files(self): @@ -165,22 +173,78 @@ def list_knowledge_files(self): name="/knowledge/files", ) - list_knowledge_files.__name__ = "list_knowledge_files" + @task(load_params["unlink_brain_rate"]) + def unlink_knowledge_brain(self): + global all_kms + if len(all_kms) == 0: + return + random_idx = random.choice(range(len(all_kms))) + random_km = all_kms.pop(random_idx) + if len(random_km.brains) == 0: + return + random_brain = random.choice(random_km.brains) + assert random_km.id + json_data = UnlinkKnowledgeBrain( + brain_ids=[UUID(random_brain["brain_id"])], + knowledge_id=random_km.id, + ).model_dump_json() + self.client.delete( + "/knowledge/unlink_from_brains/", + data=json_data, + headers={ + "Content-Type": "application/json", + **self.auth_headers, + }, + ) @task(load_params["delete_km_rate"]) def delete_knowledge_files(self): + global all_kms only_files = [idx for idx, km in enumerate(all_kms) if not 
km.is_folder] if len(only_files) == 0: return random_index = random.choice(only_files) random_km = all_kms.pop(random_index) + children_ids = [c.id for c in random_km.children] + all_kms[:] = [k for k in all_kms if k.id not in children_ids] self.client.delete( f"/knowledge/{str(random_km.id)}", headers=self.auth_headers, name="/knowledge/delete", ) + @task(load_params["update_km_rate"]) + def update_knowledge(self): + global all_kms + if len(all_kms) == 0: + return + random_idx = random.choice(range(len(all_kms))) + random_km = all_kms.pop(random_idx) + assert random_km.id + json_data = KnowledgeUpdate(file_name=f"file-{uuid4()}").model_dump_json( + exclude_unset=True + ) + response = self.client.patch( + f"/knowledge/{random_km.id}/", + data=json_data, + name="/knowledge/update", + headers={ + "Content-Type": "application/json", + **self.auth_headers, + }, + ) + assert response and response.text + km = KnowledgeDTO.model_validate_json(response.text) + all_kms.append(km) + + # CRUD operations + create_knowledge.__name__ = "create_knowledge_1MB" + update_knowledge.__name__ = "update_knowledge" delete_knowledge_files.__name__ = "delete_knowledge_file" + list_knowledge_files.__name__ = "list_knowledge_files" + # Special linking/unlinking brains + link_to_brains.__name__ = "link_to_brain" + unlink_knowledge_brain.__name__ = "unlink_knowledge_brain" def on_stop(self): global brains_ids From 4e88b7e44031f49450d4a2fa9efa49f1c4685824 Mon Sep 17 00:00:00 2001 From: aminediro Date: Wed, 16 Oct 2024 10:15:39 +0200 Subject: [PATCH 60/63] sort brain name and knowledge list local --- .../modules/knowledge/entity/knowledge.py | 1 + .../modules/knowledge/repository/knowledges.py | 2 ++ .../quivr_api/modules/knowledge/tests/conftest.py | 14 +++++++++++++- .../knowledge/tests/test_knowledge_entity.py | 8 ++++++-- 4 files changed, 22 insertions(+), 3 deletions(-) diff --git a/backend/api/quivr_api/modules/knowledge/entity/knowledge.py 
b/backend/api/quivr_api/modules/knowledge/entity/knowledge.py index fb3f482f6cc1..39f63ac627f2 100644 --- a/backend/api/quivr_api/modules/knowledge/entity/knowledge.py +++ b/backend/api/quivr_api/modules/knowledge/entity/knowledge.py @@ -112,6 +112,7 @@ async def to_dto( await self.awaitable_attrs.created_at ), "knowledge should be inserted before transforming to dto" brains = await self.awaitable_attrs.brains + brains = sorted(brains, key=lambda b: (b is None, b.name)) children: list[KnowledgeDB] = ( await self.awaitable_attrs.children if get_children else [] ) diff --git a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py index dfa2e2c954ea..48b07fd8d6f7 100644 --- a/backend/api/quivr_api/modules/knowledge/repository/knowledges.py +++ b/backend/api/quivr_api/modules/knowledge/repository/knowledges.py @@ -20,6 +20,7 @@ ) from quivr_api.modules.knowledge.entity.knowledge import ( KnowledgeDB, + KnowledgeSource, ) from quivr_api.modules.knowledge.service.knowledge_exceptions import ( KnowledgeCreationError, @@ -328,6 +329,7 @@ async def get_root_knowledge_user(self, user_id: UUID) -> list[KnowledgeDB]: select(KnowledgeDB) .where(KnowledgeDB.parent_id.is_(None)) # type: ignore .where(KnowledgeDB.user_id == user_id) + .where(KnowledgeDB.source == KnowledgeSource.LOCAL) .options(joinedload(KnowledgeDB.children)) # type: ignore ) result = await self.session.exec(query) diff --git a/backend/api/quivr_api/modules/knowledge/tests/conftest.py b/backend/api/quivr_api/modules/knowledge/tests/conftest.py index e0a8423f0384..c87d20af5968 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/conftest.py +++ b/backend/api/quivr_api/modules/knowledge/tests/conftest.py @@ -134,7 +134,19 @@ async def sync(session: AsyncSession, user: User) -> Sync: @pytest_asyncio.fixture(scope="function") async def brain(session): brain_1 = Brain( - name="test_brain", + name="brain_1", + description="this is a 
test brain", + brain_type=BrainType.integration, + ) + session.add(brain_1) + await session.commit() + return brain_1 + + +@pytest_asyncio.fixture(scope="function") +async def brain2(session): + brain_1 = Brain( + name="brain_2", description="this is a test brain", brain_type=BrainType.integration, ) diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py index dbd563eac138..8964e2964384 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_entity.py @@ -141,7 +141,7 @@ async def test_knowledge_remove_folder_cascade( @pytest.mark.asyncio(loop_scope="session") -async def test_knowledge_dto(session, user, brain, sync): +async def test_knowledge_dto(session, user, brain, brain2, sync): # add folder in brain folder = KnowledgeDB( file_name="folder_1", @@ -165,7 +165,7 @@ async def test_knowledge_dto(session, user, brain, sync): file_size=100, file_sha1="test_sha1", user_id=user.id, - brains=[brain], + brains=[brain2, brain], parent=folder, sync_file_id="file1", sync=sync, @@ -194,7 +194,11 @@ async def test_knowledge_dto(session, user, brain, sync): # Syncs fields assert km_dto.sync_id == km.sync_id assert km_dto.sync_file_id == km.sync_file_id + # Check brain_name order + assert len(km_dto.brains) == 2 + assert km_dto.brains[1]["name"] > km_dto.brains[0]["name"] + # Check folder to dto folder_dto = await folder.to_dto() assert folder_dto.brains[0] == brain.model_dump() assert folder_dto.children == [await km.to_dto()] From 505a76548e6e329dc595c0cfbaa50e7946d6f8ad Mon Sep 17 00:00:00 2001 From: AmineDiro Date: Wed, 16 Oct 2024 11:43:29 +0200 Subject: [PATCH 61/63] fix: sync provider naming (#3380) # Description --- .../modules/sync/controller/azure_sync_routes.py | 3 ++- .../modules/sync/controller/dropbox_sync_routes.py | 3 ++- 
.../modules/sync/controller/github_sync_routes.py | 3 ++- .../modules/sync/controller/google_sync_routes.py | 3 ++- .../modules/sync/controller/notion_sync_routes.py | 3 ++- .../quivr_api/modules/sync/controller/sync_routes.py | 12 ++++++------ .../quivr_api/modules/sync/service/sync_service.py | 4 ++-- 7 files changed, 18 insertions(+), 13 deletions(-) diff --git a/backend/api/quivr_api/modules/sync/controller/azure_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/azure_sync_routes.py index 65d3224386d1..dcfdc13074c0 100644 --- a/backend/api/quivr_api/modules/sync/controller/azure_sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/azure_sync_routes.py @@ -9,6 +9,7 @@ from quivr_api.middlewares.auth import AuthBearer, get_current_user from quivr_api.modules.dependencies import get_service from quivr_api.modules.sync.dto.inputs import SyncStatus, SyncUpdateInput +from quivr_api.modules.sync.dto.outputs import SyncProvider from quivr_api.modules.sync.service.sync_service import SyncsService from quivr_api.modules.sync.utils.oauth2 import parse_oauth2_state from quivr_api.modules.user.entity.user_identity import UserIdentity @@ -63,7 +64,7 @@ async def authorize_azure( ) logger.debug(f"Authorizing Azure sync for user: {current_user.id}") state = await syncs_service.create_oauth2_state( - provider="Azure", name=name, user_id=current_user.id + provider=SyncProvider.AZURE, name=name, user_id=current_user.id ) flow = client.initiate_auth_code_flow( scopes=SCOPE, diff --git a/backend/api/quivr_api/modules/sync/controller/dropbox_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/dropbox_sync_routes.py index 568837cfe9af..b715e899ec7b 100644 --- a/backend/api/quivr_api/modules/sync/controller/dropbox_sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/dropbox_sync_routes.py @@ -9,6 +9,7 @@ from quivr_api.middlewares.auth import AuthBearer, get_current_user from quivr_api.modules.dependencies import get_service from 
quivr_api.modules.sync.dto.inputs import SyncStatus, SyncUpdateInput +from quivr_api.modules.sync.dto.outputs import SyncProvider from quivr_api.modules.sync.service.sync_service import SyncsService from quivr_api.modules.sync.utils.oauth2 import parse_oauth2_state from quivr_api.modules.user.entity.user_identity import UserIdentity @@ -64,7 +65,7 @@ async def authorize_dropbox( scope=SCOPE, ) state = await syncs_service.create_oauth2_state( - provider="DropBox", name=name, user_id=current_user.id + provider=SyncProvider.DROPBOX, name=name, user_id=current_user.id ) authorize_url = auth_flow.start(state.model_dump_json()) logger.info( diff --git a/backend/api/quivr_api/modules/sync/controller/github_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/github_sync_routes.py index 7dbb866c17d5..e42f6d4d7e22 100644 --- a/backend/api/quivr_api/modules/sync/controller/github_sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/github_sync_routes.py @@ -8,6 +8,7 @@ from quivr_api.middlewares.auth import AuthBearer, get_current_user from quivr_api.modules.dependencies import get_service from quivr_api.modules.sync.dto.inputs import SyncStatus, SyncUpdateInput +from quivr_api.modules.sync.dto.outputs import SyncProvider from quivr_api.modules.sync.service.sync_service import SyncsService from quivr_api.modules.sync.utils.oauth2 import parse_oauth2_state from quivr_api.modules.user.entity.user_identity import UserIdentity @@ -54,7 +55,7 @@ async def authorize_github( """ logger.debug(f"Authorizing GitHub sync for user: {current_user.id}") state = await syncs_service.create_oauth2_state( - provider="Github", name=name, user_id=current_user.id + provider=SyncProvider.GITHUB, name=name, user_id=current_user.id ) authorization_url = ( f"https://github.com/login/oauth/authorize?client_id={CLIENT_ID}" diff --git a/backend/api/quivr_api/modules/sync/controller/google_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/google_sync_routes.py index 
4d02de674d4c..e9dd95a52d2e 100644 --- a/backend/api/quivr_api/modules/sync/controller/google_sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/google_sync_routes.py @@ -10,6 +10,7 @@ from quivr_api.middlewares.auth import AuthBearer, get_current_user from quivr_api.modules.dependencies import get_service from quivr_api.modules.sync.dto.inputs import SyncStatus, SyncUpdateInput +from quivr_api.modules.sync.dto.outputs import SyncProvider from quivr_api.modules.sync.service.sync_service import SyncsService from quivr_api.modules.sync.utils.oauth2 import parse_oauth2_state from quivr_api.modules.user.entity.user_identity import UserIdentity @@ -89,7 +90,7 @@ async def authorize_google( ) state = await syncs_service.create_oauth2_state( - provider="Google", name=name, user_id=current_user.id + provider=SyncProvider.GOOGLE, name=name, user_id=current_user.id ) authorization_url, state = flow.authorization_url( access_type="offline", diff --git a/backend/api/quivr_api/modules/sync/controller/notion_sync_routes.py b/backend/api/quivr_api/modules/sync/controller/notion_sync_routes.py index 81b0be95abaf..b151a043cb2d 100644 --- a/backend/api/quivr_api/modules/sync/controller/notion_sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/notion_sync_routes.py @@ -11,6 +11,7 @@ from quivr_api.middlewares.auth import AuthBearer, get_current_user from quivr_api.modules.dependencies import get_service from quivr_api.modules.sync.dto.inputs import SyncStatus, SyncUpdateInput +from quivr_api.modules.sync.dto.outputs import SyncProvider from quivr_api.modules.sync.service.sync_service import SyncsService from quivr_api.modules.sync.utils.oauth2 import parse_oauth2_state from quivr_api.modules.user.entity.user_identity import UserIdentity @@ -57,7 +58,7 @@ async def authorize_notion( """ logger.debug(f"Authorizing Notion sync for user: {current_user.id}, name : {name}") state = await syncs_service.create_oauth2_state( - provider="Notion", name=name, 
user_id=current_user.id + provider=SyncProvider.NOTION, name=name, user_id=current_user.id ) # Finalize the state authorize_url = str(NOTION_AUTH_URL) + f"&state={state.model_dump_json()}" diff --git a/backend/api/quivr_api/modules/sync/controller/sync_routes.py b/backend/api/quivr_api/modules/sync/controller/sync_routes.py index 17ee0b61c247..e01d9df8d752 100644 --- a/backend/api/quivr_api/modules/sync/controller/sync_routes.py +++ b/backend/api/quivr_api/modules/sync/controller/sync_routes.py @@ -20,7 +20,7 @@ from quivr_api.modules.sync.controller.google_sync_routes import google_sync_router from quivr_api.modules.sync.controller.notion_sync_routes import notion_sync_router from quivr_api.modules.sync.dto import SyncsDescription -from quivr_api.modules.sync.dto.outputs import AuthMethodEnum +from quivr_api.modules.sync.dto.outputs import AuthMethodEnum, SyncProvider from quivr_api.modules.sync.entity.sync_models import SyncFile from quivr_api.modules.sync.service.sync_service import SyncsService from quivr_api.modules.user.entity.user_identity import UserIdentity @@ -51,31 +51,31 @@ # Google sync description google_sync = SyncsDescription( - name="Google", + name=SyncProvider.GOOGLE, description="Sync your Google Drive with Quivr", auth_method=AuthMethodEnum.URI_WITH_CALLBACK, ) azure_sync = SyncsDescription( - name="Azure", + name=SyncProvider.AZURE, description="Sync your Azure Drive with Quivr", auth_method=AuthMethodEnum.URI_WITH_CALLBACK, ) dropbox_sync = SyncsDescription( - name="DropBox", + name=SyncProvider.DROPBOX, description="Sync your DropBox Drive with Quivr", auth_method=AuthMethodEnum.URI_WITH_CALLBACK, ) notion_sync = SyncsDescription( - name="Notion", + name=SyncProvider.NOTION, description="Sync your Notion with Quivr", auth_method=AuthMethodEnum.URI_WITH_CALLBACK, ) github_sync = SyncsDescription( - name="GitHub", + name=SyncProvider.GITHUB, description="Sync your GitHub Drive with Quivr", auth_method=AuthMethodEnum.URI_WITH_CALLBACK, ) diff 
--git a/backend/api/quivr_api/modules/sync/service/sync_service.py b/backend/api/quivr_api/modules/sync/service/sync_service.py index 18428fd29c99..ea440ac9884e 100644 --- a/backend/api/quivr_api/modules/sync/service/sync_service.py +++ b/backend/api/quivr_api/modules/sync/service/sync_service.py @@ -10,7 +10,7 @@ SyncStatus, SyncUpdateInput, ) -from quivr_api.modules.sync.dto.outputs import SyncsOutput +from quivr_api.modules.sync.dto.outputs import SyncProvider, SyncsOutput from quivr_api.modules.sync.repository.sync_repository import SyncsRepository from quivr_api.modules.sync.utils.oauth2 import Oauth2BaseState, Oauth2State @@ -54,7 +54,7 @@ async def get_from_oauth2_state(self, state: Oauth2State) -> SyncsOutput: async def create_oauth2_state( self, - provider: str, + provider: SyncProvider, name: str, user_id: UUID, additional_data: dict[str, Any] = {}, From 939d28df18755568307689cf7e1cd61ada37802a Mon Sep 17 00:00:00 2001 From: aminediro Date: Wed, 16 Oct 2024 15:13:43 +0200 Subject: [PATCH 62/63] fixed wrong parsing file extension --- backend/api/quivr_api/modules/knowledge/dto/inputs.py | 5 ++--- .../modules/knowledge/service/knowledge_service.py | 10 ++++++++-- .../modules/knowledge/tests/test_knowledge_service.py | 5 ++--- backend/api/quivr_api/utils/knowledge_utils.py | 10 ++++++++++ 4 files changed, 22 insertions(+), 8 deletions(-) create mode 100644 backend/api/quivr_api/utils/knowledge_utils.py diff --git a/backend/api/quivr_api/modules/knowledge/dto/inputs.py b/backend/api/quivr_api/modules/knowledge/dto/inputs.py index 8ce186311987..9bf6da2e6a44 100644 --- a/backend/api/quivr_api/modules/knowledge/dto/inputs.py +++ b/backend/api/quivr_api/modules/knowledge/dto/inputs.py @@ -26,14 +26,13 @@ class CreateKnowledgeProperties(BaseModel): class AddKnowledge(BaseModel): file_name: Optional[str] = None url: Optional[str] = None - extension: str = ".txt" + is_folder: bool = False source: str = "local" source_link: Optional[str] = None - metadata: 
Optional[Dict[str, str]] = None - is_folder: bool = False parent_id: UUID | None = None sync_id: int | None = None sync_file_id: str | None = None + metadata: Optional[Dict[str, str]] = None class KnowledgeUpdate(BaseModel): diff --git a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py index 395d9fc19535..71a96124fa2f 100644 --- a/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/service/knowledge_service.py @@ -35,6 +35,7 @@ ) from quivr_api.modules.sync.entity.sync_models import SyncFile, SyncType from quivr_api.modules.upload.service.upload_file import check_file_exists +from quivr_api.utils.knowledge_utils import parse_file_extension logger = get_logger(__name__) @@ -95,7 +96,6 @@ async def create_or_link_sync_knowledge( knowledge_to_add=AddKnowledge( file_name=sync_file.name, is_folder=sync_file.is_folder, - extension=sync_file.extension, source=parent_knowledge.source, # same as parent source_link=sync_file.web_view_link, parent_id=parent_knowledge.id, @@ -183,13 +183,19 @@ async def create_knowledge( if b.brain_id not in {b.brain_id for b in brains} ] ) + # TODO: slugify url names here !! 
+ extension = ( + parse_file_extension(knowledge_to_add.file_name) + if knowledge_to_add.file_name + else "" + ) knowledgedb = KnowledgeDB( user_id=user_id, file_name=knowledge_to_add.file_name, is_folder=knowledge_to_add.is_folder, url=knowledge_to_add.url, - extension=knowledge_to_add.extension, + extension=extension, source=knowledge_to_add.source, source_link=knowledge_to_add.source_link, file_size=upload_file.size if upload_file else 0, diff --git a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py index 88f78a5193c0..c2cd9450bcbd 100644 --- a/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py +++ b/backend/api/quivr_api/modules/knowledge/tests/test_knowledge_service.py @@ -549,7 +549,7 @@ async def test_create_knowledge_folder(session: AsyncSession, user: User): service = KnowledgeService(repository, storage) km_to_add = AddKnowledge( - file_name="test", + file_name="test.txt", source="local", is_folder=True, parent_id=None, @@ -566,9 +566,9 @@ async def test_create_knowledge_folder(session: AsyncSession, user: User): assert km.id # Knowledge properties assert km.file_name == km_to_add.file_name + assert km.extension == ".txt" assert km.is_folder == km_to_add.is_folder assert km.url == km_to_add.url - assert km.extension == km_to_add.extension assert km.source == km_to_add.source assert km.file_size == 128 assert km.metadata_ == km_to_add.metadata @@ -613,7 +613,6 @@ def _send_task(*args, **kwargs): assert km.file_name == km_to_add.file_name assert km.is_folder == km_to_add.is_folder assert km.url == km_to_add.url - assert km.extension == km_to_add.extension assert km.source == km_to_add.source assert km.file_size == 128 assert km.metadata_ == km_to_add.metadata diff --git a/backend/api/quivr_api/utils/knowledge_utils.py b/backend/api/quivr_api/utils/knowledge_utils.py new file mode 100644 index 000000000000..5a5a24dc7fb6 --- /dev/null +++ 
b/backend/api/quivr_api/utils/knowledge_utils.py @@ -0,0 +1,10 @@ +from quivr_core.files.file import FileExtension + + +def parse_file_extension(file_name: str) -> FileExtension | str: + if file_name.startswith(".") and file_name.count(".") == 1: + return "" + if "." not in file_name or file_name.endswith("."): + return "" + + return FileExtension(f".{file_name.split('.')[-1]}") From fe1680542b557acc47fcd70c1f12f3cc3af16b80 Mon Sep 17 00:00:00 2001 From: Antoine Dewez <44063631+Zewed@users.noreply.github.com> Date: Tue, 22 Oct 2024 09:53:11 +0200 Subject: [PATCH 63/63] feat(frontend): display all knowledge in a kms (#3411) # Description Please include a summary of the changes and the related issue. Please also include relevant motivation and context. ## Checklist before requesting a review Please delete options that are not relevant. - [ ] My code follows the style guidelines of this project - [ ] I have performed a self-review of my code - [ ] I have commented hard-to-understand areas - [ ] I have ideally added tests that prove my fix is effective or that my feature works - [ ] New and existing unit tests pass locally with my changes - [ ] Any dependent changes have been merged ## Screenshots (if appropriate): --- backend/supabase/seed.sql | 2 +- frontend/app/App.tsx | 37 +- .../KnowledgeToFeed.module.scss | 29 - .../KnowledgeToFeed/KnowledgeToFeed.tsx | 132 ---- .../FromConnections/FileLine/FileLine.tsx | 25 - .../FolderLine/FolderLine.module.scss | 0 .../FromConnections/FolderLine/FolderLine.tsx | 28 - .../FromConnections.module.scss | 42 -- .../FromConnections/FromConnections.tsx | 150 ----- .../FromConnection-provider.tsx | 69 --- .../hooks/useFromConnectionContext.tsx | 17 - .../SyncElementLine.module.scss | 40 -- .../SyncElementLine/SyncElementLine.tsx | 128 ---- .../FromDocuments/FromDocuments.module.scss | 47 -- .../FromDocuments/FromDocuments.tsx | 42 -- .../FromWebsites/FromWebsites.module.scss | 0 .../components/FromWebsites/FromWebsites.tsx | 20 - 
.../hooks/useFeedBrainInChat.ts | 121 ---- .../components/KnowledgeToFeed/index.ts | 1 - .../components/ActionsBar/components/index.ts | 1 - .../MessageRow/components/Source/Source.tsx | 8 +- .../ChatDialogue/hooks/useChatDialogue.ts | 4 +- .../hooks/useChatNotificationsSync.ts | 78 --- frontend/app/chat/[chatId]/page.tsx | 26 +- frontend/app/globals.css | 6 +- frontend/app/knowledge/page.module.scss | 13 + frontend/app/knowledge/page.tsx | 21 + .../AssistantTab/AssistantTab.tsx | 6 +- frontend/app/quality-assistant/page.tsx | 2 +- frontend/app/search/page.tsx | 7 +- .../components/Analytics/Analytics.tsx | 2 +- .../BrainManagementTabs.module.scss | 38 +- .../BrainManagementTabs.tsx | 38 +- .../KnowledgeTab/KnowledgeTab.module.scss | 2 + .../components/KnowledgeTab/KnowledgeTab.tsx | 31 +- .../KnowledgeItem/KnowledgeItem.tsx | 28 +- .../KnowledgeItem/hooks/useKnowledgeItem.ts | 4 +- .../KnowledgeTable/KnowledgeTable.module.scss | 105 ++-- .../KnowledgeTable/KnowledgeTable.tsx | 185 +++--- .../KnowledgeTab/hooks/useAddedKnowledge.ts | 4 +- .../KnowledgeTab/hooks/useKnowledge.ts | 4 +- .../app/studio/[brainId]/page.module.scss | 3 +- frontend/app/studio/[brainId]/page.tsx | 13 - frontend/app/studio/page.tsx | 12 - frontend/lib/api/knowledge/knowledge.ts | 155 ++++- frontend/lib/api/knowledge/useKnowledgeApi.ts | 36 +- frontend/lib/api/sync/sync.ts | 30 +- frontend/lib/api/sync/types.ts | 43 +- frontend/lib/api/sync/useSync.ts | 30 +- .../AddBrainModal/AddBrainModal.module.scss | 4 - .../AddBrainModal/AddBrainModal.tsx | 64 +- .../AddBrainModal/brainCreation-provider.tsx | 17 - .../BrainCreationForm.module.scss} | 10 +- .../BrainCreationForm.tsx} | 93 ++- .../BrainRecapCard/BrainRecapCard.module.scss | 24 - .../BrainRecapCard/BrainRecapCard.tsx | 21 - .../BrainRecapStep/BrainRecapStep.module.scss | 76 --- .../BrainRecapStep/BrainRecapStep.tsx | 137 ----- .../FeedBrainStep/FeedBrainStep.module.scss | 42 -- .../FeedBrainStep/FeedBrainStep.tsx | 112 ---- 
.../hooks/useBrainCreationApi.ts | 144 ----- .../components/Stepper/Stepper.module.scss | 118 ---- .../components/Stepper/Stepper.tsx | 63 -- .../components/hooks/useBrainCreationApi.ts | 97 +++ .../hooks/useBrainCreationSteps.ts | 66 -- .../components/AddBrainModal/types/types.ts | 10 - .../ConnectionCards/ConnectionCards.tsx | 12 +- .../ConnectionLine/ConnectionLine.tsx | 6 +- .../ConnectionSection/ConnectionSection.tsx | 133 +--- .../CurrentFolderExplorer.module.scss | 7 + .../CurrentFolderExplorer.tsx | 34 + .../ProviderAccount.module.scss | 25 + .../ProviderAccount/ProviderAccount.tsx | 38 ++ .../ProviderCurrentFolder.module.scss | 29 + .../ProviderCurrentFolder.tsx | 136 ++++ .../AddFolderModal/AddFolderModal.module.scss | 13 + .../AddFolderModal/AddFolderModal.tsx | 92 +++ .../AddKnowledgeModal.module.scss | 60 ++ .../AddKnowledgeModal/AddKnowledgeModal.tsx | 233 +++++++ .../QuivrCurrentFolder.module.scss | 41 ++ .../QuivrCurrentFolder/QuivrCurrentFolder.tsx | 235 +++++++ .../AddToBrainsModal.module.scss | 43 ++ .../AddToBrainsModal/AddToBrainsModal.tsx | 166 +++++ .../ConnectedBrains.module.scss | 73 +++ .../ConnectedBrains/ConnectedBrains.tsx | 171 ++++++ .../CurrentFolderExplorerLine.module.scss | 74 +++ .../CurrentFolderExplorerLine.tsx | 170 +++++ .../FolderExplorerHeader.module.scss | 80 +++ .../FolderExplorerHeader.tsx | 185 ++++++ .../KnowledgeManagementSystem.module.scss | 66 ++ .../KnowledgeManagementSystem.tsx | 93 +++ .../hooks/useKnowledgeContext.tsx | 15 + .../KnowledgeProvider/knowledge-provider.tsx | 76 +++ .../ConnectionAccount.module.scss | 58 ++ .../ConnectionAccount/ConnectionAccount.tsx | 115 ++++ .../ConnectionKnowledges.module.scss | 44 ++ .../ConnectionKnowledges.tsx | 75 +++ .../SyncFolder/SyncFolder.module.scss | 46 ++ .../SyncFolder/SyncFolder.tsx | 116 ++++ .../ConnectionsKnowledge.module.scss | 12 + .../ConnectionsKnowledges.tsx | 58 ++ .../QuivrFolder/QuivrFolder.module.scss | 51 ++ .../QuivrFolder/QuivrFolder.tsx | 131 ++++ 
.../QuivrKnowledges.module.scss | 52 ++ .../Menu/QuivrKnowledge/QuivrKnowledges.tsx | 126 ++++ .../KnowledgeToFeedInput.tsx | 43 -- .../components/Crawler/helpers/isValidUrl.ts | 9 - .../components/Crawler/hooks/useCrawler.ts | 62 -- .../components/Crawler/index.tsx | 43 -- .../components/FeedItems/FeedItems.tsx | 34 - .../FeedItems/components/CrawlFeedItem.tsx | 34 - .../FeedTitleDisplayer/FeedTitleDisplayer.tsx | 24 - .../components/FeedTitleDisplayer/index.ts | 1 - .../utils/enhanceUrlDisplay.ts | 24 - .../utils/removeFileExtension.ts | 8 - .../FeedItems/components/FileFeedItem.tsx | 36 -- .../components/FeedItems/components/index.ts | 1 - .../components/FeedItems/index.ts | 1 - .../FeedItems/styles/StyledFeedItemDiv.tsx | 17 - .../components/FileUploader/index.tsx | 38 -- .../KnowledgeToFeedInput/components/index.ts | 3 - .../hooks/useKnowledgeToFeedInput.ts.ts | 94 --- .../components/KnowledgeToFeedInput/index.ts | 1 - frontend/lib/components/Menu/Menu.tsx | 14 +- .../KnowledgeButton.module.scss} | 0 .../KnowledgeButton/KnowledgeButton.tsx | 21 + .../GenericNotification.tsx | 2 +- .../UploadDocumentModal.module.scss | 22 - .../UploadDocumentModal.tsx | 127 ---- .../hooks/useAddKnowledge.ts | 31 - .../UploadDocumentModal/hooks/useFeedBrain.ts | 82 --- .../hooks/useFeedBrainHandler.ts | 77 --- .../ui/FileInput/FileInput.module.scss | 3 +- .../lib/components/ui/FileInput/FileInput.tsx | 42 +- .../lib/components/ui/Icon/Icon.module.scss | 6 +- frontend/lib/components/ui/Icon/Icon.tsx | 9 +- .../components/ui/QuivrButton/QuivrButton.tsx | 2 +- .../ui/TextInput/TextInput.module.scss | 5 + .../lib/components/ui/TextInput/TextInput.tsx | 63 +- .../lib/components/ui/Tooltip/Tooltip.tsx | 14 +- frontend/lib/context/BrainProvider/types.ts | 1 + .../hooks/useKnowledgeToFeedContext.tsx | 41 -- .../context/KnowledgeToFeedProvider/index.ts | 1 - .../knowledgeToFeed-provider.tsx | 38 -- frontend/lib/context/index.tsx | 1 - .../helpers/formatBrains.ts} | 0 
.../lib/helpers/handleConnectionButtons.ts | 182 ------ frontend/lib/helpers/iconList.ts | 46 +- frontend/lib/helpers/kms.ts | 51 ++ frontend/lib/helpers/providers.ts | 12 + frontend/lib/hooks/useDropzone.ts | 88 --- frontend/lib/hooks/useKnowledgeToFeed.ts | 26 - frontend/lib/types/Knowledge.ts | 32 +- frontend/lib/types/Modal.ts | 6 - frontend/styles/_Variables.module.scss | 1 + frontend/yarn.lock | 579 ++++++++---------- 156 files changed, 4250 insertions(+), 4001 deletions(-) delete mode 100644 frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/KnowledgeToFeed.module.scss delete mode 100644 frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/KnowledgeToFeed.tsx delete mode 100644 frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FileLine/FileLine.tsx delete mode 100644 frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FolderLine/FolderLine.module.scss delete mode 100644 frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FolderLine/FolderLine.tsx delete mode 100644 frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FromConnections.module.scss delete mode 100644 frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FromConnections.tsx delete mode 100644 frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FromConnectionsProvider/FromConnection-provider.tsx delete mode 100644 frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FromConnectionsProvider/hooks/useFromConnectionContext.tsx delete mode 100644 frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/SyncElementLine/SyncElementLine.module.scss delete mode 
100644 frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/SyncElementLine/SyncElementLine.tsx delete mode 100644 frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromDocuments/FromDocuments.module.scss delete mode 100644 frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromDocuments/FromDocuments.tsx delete mode 100644 frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromWebsites/FromWebsites.module.scss delete mode 100644 frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromWebsites/FromWebsites.tsx delete mode 100644 frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/hooks/useFeedBrainInChat.ts delete mode 100644 frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/index.ts delete mode 100644 frontend/app/chat/[chatId]/hooks/useChatNotificationsSync.ts create mode 100644 frontend/app/knowledge/page.module.scss create mode 100644 frontend/app/knowledge/page.tsx rename frontend/lib/components/AddBrainModal/components/{BrainMainInfosStep/BrainMainInfosStep.module.scss => BrainCreationForm/BrainCreationForm.module.scss} (94%) rename frontend/lib/components/AddBrainModal/components/{BrainMainInfosStep/BrainMainInfosStep.tsx => BrainCreationForm/BrainCreationForm.tsx} (52%) delete mode 100644 frontend/lib/components/AddBrainModal/components/BrainRecapStep/BrainRecapCard/BrainRecapCard.module.scss delete mode 100644 frontend/lib/components/AddBrainModal/components/BrainRecapStep/BrainRecapCard/BrainRecapCard.tsx delete mode 100644 frontend/lib/components/AddBrainModal/components/BrainRecapStep/BrainRecapStep.module.scss delete mode 100644 frontend/lib/components/AddBrainModal/components/BrainRecapStep/BrainRecapStep.tsx delete mode 100644 
frontend/lib/components/AddBrainModal/components/FeedBrainStep/FeedBrainStep.module.scss delete mode 100644 frontend/lib/components/AddBrainModal/components/FeedBrainStep/FeedBrainStep.tsx delete mode 100644 frontend/lib/components/AddBrainModal/components/FeedBrainStep/hooks/useBrainCreationApi.ts delete mode 100644 frontend/lib/components/AddBrainModal/components/Stepper/Stepper.module.scss delete mode 100644 frontend/lib/components/AddBrainModal/components/Stepper/Stepper.tsx create mode 100644 frontend/lib/components/AddBrainModal/components/hooks/useBrainCreationApi.ts delete mode 100644 frontend/lib/components/AddBrainModal/hooks/useBrainCreationSteps.ts create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/CurrentFolderExplorer.module.scss create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/CurrentFolderExplorer.tsx create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/ProviderCurrentFolder/ProviderAccount/ProviderAccount.module.scss create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/ProviderCurrentFolder/ProviderAccount/ProviderAccount.tsx create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/ProviderCurrentFolder/ProviderCurrentFolder.module.scss create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/ProviderCurrentFolder/ProviderCurrentFolder.tsx create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/QuivrCurrentFolder/AddFolderModal/AddFolderModal.module.scss create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/QuivrCurrentFolder/AddFolderModal/AddFolderModal.tsx create mode 100644 
frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/QuivrCurrentFolder/AddKnowledgeModal/AddKnowledgeModal.module.scss create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/QuivrCurrentFolder/AddKnowledgeModal/AddKnowledgeModal.tsx create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/QuivrCurrentFolder/QuivrCurrentFolder.module.scss create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/QuivrCurrentFolder/QuivrCurrentFolder.tsx create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/CurrentFolderExplorerLine/ConnectedBrains/AddToBrainsModal/AddToBrainsModal.module.scss create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/CurrentFolderExplorerLine/ConnectedBrains/AddToBrainsModal/AddToBrainsModal.tsx create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/CurrentFolderExplorerLine/ConnectedBrains/ConnectedBrains.module.scss create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/CurrentFolderExplorerLine/ConnectedBrains/ConnectedBrains.tsx create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/CurrentFolderExplorerLine/CurrentFolderExplorerLine.module.scss create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/CurrentFolderExplorerLine/CurrentFolderExplorerLine.tsx create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/FolderExplorerHeader/FolderExplorerHeader.module.scss create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/FolderExplorerHeader/FolderExplorerHeader.tsx create mode 100644 frontend/lib/components/KnowledgeManagementSystem/KnowledgeManagementSystem.module.scss create mode 100644 
frontend/lib/components/KnowledgeManagementSystem/KnowledgeManagementSystem.tsx create mode 100644 frontend/lib/components/KnowledgeManagementSystem/KnowledgeProvider/hooks/useKnowledgeContext.tsx create mode 100644 frontend/lib/components/KnowledgeManagementSystem/KnowledgeProvider/knowledge-provider.tsx create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionKnowledges/ConnectionAccount/ConnectionAccount.module.scss create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionKnowledges/ConnectionAccount/ConnectionAccount.tsx create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionKnowledges/ConnectionKnowledges.module.scss create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionKnowledges/ConnectionKnowledges.tsx create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionKnowledges/SyncFolder/SyncFolder.module.scss create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionKnowledges/SyncFolder/SyncFolder.tsx create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionsKnowledge.module.scss create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionsKnowledges.tsx create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Menu/QuivrKnowledge/QuivrFolder/QuivrFolder.module.scss create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Menu/QuivrKnowledge/QuivrFolder/QuivrFolder.tsx create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Menu/QuivrKnowledge/QuivrKnowledges.module.scss create mode 100644 frontend/lib/components/KnowledgeManagementSystem/Menu/QuivrKnowledge/QuivrKnowledges.tsx delete mode 100644 
frontend/lib/components/KnowledgeToFeedInput/KnowledgeToFeedInput.tsx delete mode 100644 frontend/lib/components/KnowledgeToFeedInput/components/Crawler/helpers/isValidUrl.ts delete mode 100644 frontend/lib/components/KnowledgeToFeedInput/components/Crawler/hooks/useCrawler.ts delete mode 100644 frontend/lib/components/KnowledgeToFeedInput/components/Crawler/index.tsx delete mode 100644 frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/FeedItems.tsx delete mode 100644 frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/components/CrawlFeedItem.tsx delete mode 100644 frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/components/FeedTitleDisplayer/FeedTitleDisplayer.tsx delete mode 100644 frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/components/FeedTitleDisplayer/index.ts delete mode 100644 frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/components/FeedTitleDisplayer/utils/enhanceUrlDisplay.ts delete mode 100644 frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/components/FeedTitleDisplayer/utils/removeFileExtension.ts delete mode 100644 frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/components/FileFeedItem.tsx delete mode 100644 frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/components/index.ts delete mode 100644 frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/index.ts delete mode 100644 frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/styles/StyledFeedItemDiv.tsx delete mode 100644 frontend/lib/components/KnowledgeToFeedInput/components/FileUploader/index.tsx delete mode 100644 frontend/lib/components/KnowledgeToFeedInput/components/index.ts delete mode 100644 frontend/lib/components/KnowledgeToFeedInput/hooks/useKnowledgeToFeedInput.ts.ts delete mode 100644 frontend/lib/components/KnowledgeToFeedInput/index.ts rename 
frontend/{app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FileLine/FileLine.module.scss => lib/components/Menu/components/KnowledgeButton/KnowledgeButton.module.scss} (100%) create mode 100644 frontend/lib/components/Menu/components/KnowledgeButton/KnowledgeButton.tsx delete mode 100644 frontend/lib/components/UploadDocumentModal/UploadDocumentModal.module.scss delete mode 100644 frontend/lib/components/UploadDocumentModal/UploadDocumentModal.tsx delete mode 100644 frontend/lib/components/UploadDocumentModal/hooks/useAddKnowledge.ts delete mode 100644 frontend/lib/components/UploadDocumentModal/hooks/useFeedBrain.ts delete mode 100644 frontend/lib/components/UploadDocumentModal/hooks/useFeedBrainHandler.ts delete mode 100644 frontend/lib/context/KnowledgeToFeedProvider/hooks/useKnowledgeToFeedContext.tsx delete mode 100644 frontend/lib/context/KnowledgeToFeedProvider/index.ts delete mode 100644 frontend/lib/context/KnowledgeToFeedProvider/knowledgeToFeed-provider.tsx rename frontend/{app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/utils/formatMinimalBrainsToSelectComponentInput.ts => lib/helpers/formatBrains.ts} (100%) delete mode 100644 frontend/lib/helpers/handleConnectionButtons.ts create mode 100644 frontend/lib/helpers/kms.ts create mode 100644 frontend/lib/helpers/providers.ts delete mode 100644 frontend/lib/hooks/useDropzone.ts delete mode 100644 frontend/lib/hooks/useKnowledgeToFeed.ts delete mode 100644 frontend/lib/types/Modal.ts diff --git a/backend/supabase/seed.sql b/backend/supabase/seed.sql index a0b4ae777eba..0ccc144eb1e3 100644 --- a/backend/supabase/seed.sql +++ b/backend/supabase/seed.sql @@ -298,7 +298,7 @@ INSERT INTO "public"."user_daily_usage" ("user_id", "email", "date", "daily_requ -- INSERT INTO "public"."user_identity" ("user_id", "openai_api_key", "company", "onboarded", "username", "company_size", "usage_purpose") VALUES - ('39418e3b-0258-4452-af60-7acfcc1263ff', NULL, 
'Stan', true, 'Stan', NULL, ''); + ('39418e3b-0258-4452-af60-7acfcc1263ff', NULL, 'Quivr Local', true, 'Quivr Local', NULL, ''); -- diff --git a/frontend/app/App.tsx b/frontend/app/App.tsx index 401b3258aa1b..f5679effe066 100644 --- a/frontend/app/App.tsx +++ b/frontend/app/App.tsx @@ -10,11 +10,7 @@ import { HelpWindow } from "@/lib/components/HelpWindow/HelpWindow"; import { Menu } from "@/lib/components/Menu/Menu"; import { useOutsideClickListener } from "@/lib/components/Menu/hooks/useOutsideClickListener"; import { SearchModal } from "@/lib/components/SearchModal/SearchModal"; -import { - BrainProvider, - ChatProvider, - KnowledgeToFeedProvider, -} from "@/lib/context"; +import { BrainProvider, ChatProvider } from "@/lib/context"; import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext"; import { ChatsProvider } from "@/lib/context/ChatsProvider"; import { HelpProvider } from "@/lib/context/HelpProvider/help-provider"; @@ -32,7 +28,6 @@ import { usePageTracking } from "@/services/analytics/june/usePageTracking"; import "../lib/config/LocaleConfig/i18n"; import styles from "./App.module.scss"; -import { FromConnectionsProvider } from "./chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FromConnectionsProvider/FromConnection-provider"; if ( process.env.NEXT_PUBLIC_POSTHOG_KEY != null && @@ -106,23 +101,19 @@ const AppWithQueryClient = ({ children }: PropsWithChildren): JSX.Element => { - - - - - - - - - {children} - - - - - - - - + + + + + + + {children} + + + + + + diff --git a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/KnowledgeToFeed.module.scss b/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/KnowledgeToFeed.module.scss deleted file mode 100644 index d30ea4ed6aab..000000000000 --- a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/KnowledgeToFeed.module.scss +++ /dev/null @@ -1,29 +0,0 @@ -@use 
"styles/ScreenSizes.module.scss"; -@use "styles/Spacings.module.scss"; -@use "styles/Typography.module.scss"; - -.knowledge_to_feed_wrapper { - display: flex; - flex-direction: column; - padding-block: Spacings.$spacing05; - width: 100%; - gap: Spacings.$spacing05; - overflow: hidden; - height: 100%; - - .single_selector_wrapper { - width: 30%; - min-width: 250px; - - @media (max-width: ScreenSizes.$small) { - width: 100%; - } - } - - .tabs_content_wrapper { - width: 100%; - height: 80%; - overflow: auto; - padding: Spacings.$spacing01; - } -} diff --git a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/KnowledgeToFeed.tsx b/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/KnowledgeToFeed.tsx deleted file mode 100644 index d85098f2974d..000000000000 --- a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/KnowledgeToFeed.tsx +++ /dev/null @@ -1,132 +0,0 @@ -import { useEffect, useMemo, useState } from "react"; - -import { useSync } from "@/lib/api/sync/useSync"; -import { SingleSelector } from "@/lib/components/ui/SingleSelector/SingleSelector"; -import { Tabs } from "@/lib/components/ui/Tabs/Tabs"; -import { requiredRolesForUpload } from "@/lib/config/upload"; -import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext"; -import { useKnowledgeToFeedContext } from "@/lib/context/KnowledgeToFeedProvider/hooks/useKnowledgeToFeedContext"; -import { Tab } from "@/lib/types/Tab"; - -import styles from "./KnowledgeToFeed.module.scss"; -import { FromConnections } from "./components/FromConnections/FromConnections"; -import { useFromConnectionsContext } from "./components/FromConnections/FromConnectionsProvider/hooks/useFromConnectionContext"; -import { FromDocuments } from "./components/FromDocuments/FromDocuments"; -import { FromWebsites } from "./components/FromWebsites/FromWebsites"; -import { formatMinimalBrainsToSelectComponentInput } from 
"./utils/formatMinimalBrainsToSelectComponentInput"; - -export const KnowledgeToFeed = ({ - hideBrainSelector, -}: { - hideBrainSelector?: boolean; -}): JSX.Element => { - const { allBrains, setCurrentBrainId, currentBrainId, currentBrain } = - useBrainContext(); - const [selectedTab, setSelectedTab] = useState("Documents"); - const { knowledgeToFeed } = useKnowledgeToFeedContext(); - const { openedConnections, setOpenedConnections, setCurrentSyncId } = - useFromConnectionsContext(); - const { getActiveSyncsForBrain } = useSync(); - - const brainsWithUploadRights = formatMinimalBrainsToSelectComponentInput( - useMemo( - () => - allBrains.filter( - (brain) => - requiredRolesForUpload.includes(brain.role) && !!brain.max_files - ), - [allBrains] - ) - ); - - const knowledgesTabs: Tab[] = [ - { - label: "Documents", - isSelected: selectedTab === "Documents", - onClick: () => setSelectedTab("Documents"), - iconName: "file", - badge: knowledgeToFeed.filter( - (knowledge) => knowledge.source === "upload" - ).length, - }, - { - label: "Connections", - isSelected: selectedTab === "Connections", - onClick: () => setSelectedTab("Connections"), - iconName: "sync", - badge: openedConnections.filter((connection) => connection.submitted) - .length, - }, - { - label: "Websites' page", - isSelected: selectedTab === "Websites", - onClick: () => setSelectedTab("Websites"), - iconName: "website", - badge: knowledgeToFeed.filter((knowledge) => knowledge.source === "crawl") - .length, - }, - ]; - - useEffect(() => { - if (currentBrain) { - void (async () => { - try { - const res = await getActiveSyncsForBrain(currentBrain.id); - setCurrentSyncId(undefined); - setOpenedConnections( - res.map((sync) => ({ - user_sync_id: sync.syncs_user_id, - id: sync.id, - provider: sync.syncs_user.provider, - submitted: true, - selectedFiles: { - files: [ - ...(sync.settings.folders?.map((folder) => ({ - id: folder, - name: undefined, - is_folder: true, - })) ?? 
[]), - ...(sync.settings.files?.map((file) => ({ - id: file, - name: undefined, - is_folder: false, - })) ?? []), - ], - }, - name: sync.name, - last_synced: sync.last_synced, - })) - ); - } catch (error) { - console.error(error); - } - })(); - } - }, [currentBrainId]); - - return ( -
- {!hideBrainSelector && ( -
- -
- )} - -
- {selectedTab === "Connections" && } - {selectedTab === "Documents" && } - {selectedTab === "Websites" && } -
-
- ); -}; diff --git a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FileLine/FileLine.tsx b/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FileLine/FileLine.tsx deleted file mode 100644 index 6073a7cf9de2..000000000000 --- a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FileLine/FileLine.tsx +++ /dev/null @@ -1,25 +0,0 @@ -import { SyncElementLine } from "../SyncElementLine/SyncElementLine"; - -interface FileLineProps { - name: string; - selectable: boolean; - id: string; - icon?: string; -} - -export const FileLine = ({ - name, - selectable, - id, - icon -}: FileLineProps): JSX.Element => { - return ( - - ); -}; diff --git a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FolderLine/FolderLine.module.scss b/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FolderLine/FolderLine.module.scss deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FolderLine/FolderLine.tsx b/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FolderLine/FolderLine.tsx deleted file mode 100644 index f0f0c877b397..000000000000 --- a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FolderLine/FolderLine.tsx +++ /dev/null @@ -1,28 +0,0 @@ -import { SyncElementLine } from "../SyncElementLine/SyncElementLine"; - -interface FolderLineProps { - name: string; - selectable: boolean; - id: string; - icon?: string; - isAlsoFile?: boolean; -} - -export const FolderLine = ({ - name, - selectable, - id, - icon, - isAlsoFile, -}: FolderLineProps): JSX.Element => { - return ( - - ); -}; diff --git 
a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FromConnections.module.scss b/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FromConnections.module.scss deleted file mode 100644 index fa882ff49d17..000000000000 --- a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FromConnections.module.scss +++ /dev/null @@ -1,42 +0,0 @@ -@use "styles/Spacings.module.scss"; -@use "styles/Typography.module.scss"; - -.from_connection_container { - overflow: auto; - height: 100%; - padding: Spacings.$spacing01; - - .from_connection_wrapper { - display: flex; - flex-direction: column; - gap: Spacings.$spacing06; - overflow: hidden; - max-height: 100%; - - .header_buttons { - display: flex; - justify-content: space-between; - } - - .connection_content { - overflow: auto; - flex-grow: 1; - - &.disable { - opacity: 0.5; - pointer-events: none; - } - - .loader_icon { - display: flex; - align-items: center; - justify-content: center; - } - - .empty_folder { - font-style: italic; - font-size: Typography.$small; - } - } - } -} diff --git a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FromConnections.tsx b/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FromConnections.tsx deleted file mode 100644 index cdbf8ffc5e50..000000000000 --- a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FromConnections.tsx +++ /dev/null @@ -1,150 +0,0 @@ -import { useEffect, useState } from "react"; - -import { SyncElement } from "@/lib/api/sync/types"; -import { useSync } from "@/lib/api/sync/useSync"; -import { ConnectionCards } from "@/lib/components/ConnectionCards/ConnectionCards"; -import { LoaderIcon } from "@/lib/components/ui/LoaderIcon/LoaderIcon"; -import { TextButton } 
from "@/lib/components/ui/TextButton/TextButton"; -import { useUserData } from "@/lib/hooks/useUserData"; - -import { FileLine } from "./FileLine/FileLine"; -import { FolderLine } from "./FolderLine/FolderLine"; -import styles from "./FromConnections.module.scss"; -import { useFromConnectionsContext } from "./FromConnectionsProvider/hooks/useFromConnectionContext"; - -export const FromConnections = (): JSX.Element => { - const [folderStack, setFolderStack] = useState<(string | null)[]>([]); - const { - currentSyncElements, - setCurrentSyncElements, - currentSyncId, - loadingFirstList, - setCurrentSyncId, - currentProvider, - } = useFromConnectionsContext(); - const [currentFiles, setCurrentFiles] = useState([]); - const [currentFolders, setCurrentFolders] = useState([]); - const { getSyncFiles } = useSync(); - const { userData } = useUserData(); - const [loading, setLoading] = useState(false); - - const isPremium = userData?.is_premium; - - useEffect(() => { - setCurrentFiles( - currentSyncElements?.files.filter((file) => !file.is_folder) ?? [] - ); - setCurrentFolders( - currentSyncElements?.files.filter((file) => file.is_folder) ?? 
[] - ); - setLoading(false); - }, [currentSyncElements]); - - const handleGetSyncFiles = async ( - userSyncId: number, - folderId: string | null - ) => { - try { - setLoading(true); - let res; - if (folderId !== null) { - res = await getSyncFiles(userSyncId, folderId); - } else { - res = await getSyncFiles(userSyncId); - } - setCurrentSyncElements(res); - } catch (error) { - console.error("Failed to get sync files:", error); - } - }; - - const handleBackClick = async () => { - if (folderStack.length > 0 && currentSyncId) { - const newFolderStack = [...folderStack]; - newFolderStack.pop(); - setFolderStack(newFolderStack); - const parentFolderId = newFolderStack[newFolderStack.length - 1]; - await handleGetSyncFiles(currentSyncId, parentFolderId); - } else { - setCurrentSyncElements({ files: [] }); - } - }; - - const handleFolderClick = async (userSyncId: number, folderId: string) => { - setFolderStack([...folderStack, folderId]); - await handleGetSyncFiles(userSyncId, folderId); - }; - - return ( -
- {!currentSyncId && !loadingFirstList ? ( - - ) : ( -
-
- { - if (folderStack.length) { - void handleBackClick(); - } else { - setCurrentSyncId(undefined); - } - }} - small={true} - disabled={loading || loadingFirstList} - /> -
-
- {loading || loadingFirstList ? ( -
- -
- ) : ( - <> - {currentFolders.map((folder) => ( -
{ - if (currentSyncId) { - void handleFolderClick(currentSyncId, folder.id); - } - }} - > - -
- ))} - {currentFiles.map((file) => ( -
- -
- ))} - - )} - {!currentFiles.length && - !currentFolders.length && - !loading && - !loadingFirstList && ( - Empty folder - )} -
-
- )} -
- ); -}; diff --git a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FromConnectionsProvider/FromConnection-provider.tsx b/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FromConnectionsProvider/FromConnection-provider.tsx deleted file mode 100644 index ac83aa192f9f..000000000000 --- a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FromConnectionsProvider/FromConnection-provider.tsx +++ /dev/null @@ -1,69 +0,0 @@ -import { createContext, useState } from "react"; - -import { OpenedConnection, Provider, SyncElements } from "@/lib/api/sync/types"; - -export type FromConnectionsContextType = { - currentSyncElements: SyncElements | undefined; - setCurrentSyncElements: React.Dispatch< - React.SetStateAction - >; - currentSyncId: number | undefined; - setCurrentSyncId: React.Dispatch>; - openedConnections: OpenedConnection[]; - setOpenedConnections: React.Dispatch< - React.SetStateAction - >; - hasToReload: boolean; - setHasToReload: React.Dispatch>; - loadingFirstList: boolean; - setLoadingFirstList: React.Dispatch>; - currentProvider: Provider | null; - setCurrentProvider: React.Dispatch>; -}; - -export const FromConnectionsContext = createContext< - FromConnectionsContextType | undefined ->(undefined); - -export const FromConnectionsProvider = ({ - children, -}: { - children: React.ReactNode; -}): JSX.Element => { - const [currentSyncElements, setCurrentSyncElements] = useState< - SyncElements | undefined - >(undefined); - const [currentSyncId, setCurrentSyncId] = useState( - undefined - ); - const [openedConnections, setOpenedConnections] = useState< - OpenedConnection[] - >([]); - const [currentProvider, setCurrentProvider] = useState(null); - const [hasToReload, setHasToReload] = useState(false); - const [loadingFirstList, setLoadingFirstList] = useState(false); - - return ( - - {children} - - ); -}; 
diff --git a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FromConnectionsProvider/hooks/useFromConnectionContext.tsx b/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FromConnectionsProvider/hooks/useFromConnectionContext.tsx deleted file mode 100644 index 034d2b01c4ac..000000000000 --- a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FromConnectionsProvider/hooks/useFromConnectionContext.tsx +++ /dev/null @@ -1,17 +0,0 @@ -import { useContext } from "react"; - -import { - FromConnectionsContext, - FromConnectionsContextType, -} from "../FromConnection-provider"; - -export const useFromConnectionsContext = (): FromConnectionsContextType => { - const context = useContext(FromConnectionsContext); - if (context === undefined) { - throw new Error( - "useFromConnectionsContext must be used within a FromConnectionsProvider" - ); - } - - return context; -}; diff --git a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/SyncElementLine/SyncElementLine.module.scss b/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/SyncElementLine/SyncElementLine.module.scss deleted file mode 100644 index ae3d6636075b..000000000000 --- a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/SyncElementLine/SyncElementLine.module.scss +++ /dev/null @@ -1,40 +0,0 @@ -@use "styles/Spacings.module.scss"; -@use "styles/Typography.module.scss"; - -.sync_element_line_wrapper { - display: flex; - justify-content: space-between; - padding: Spacings.$spacing03; - border-top: 1px solid var(--border-1); - align-items: center; - cursor: pointer; - font-weight: 500; - - .left { - display: flex; - gap: Spacings.$spacing03; - align-items: center; - overflow: hidden; - - .element_name { - 
font-size: Typography.$small; - @include Typography.EllipsisOverflow; - } - - &.folder { - margin-left: Spacings.$spacing06; - } - } - - &:hover { - background-color: var(--background-3); - } - - &.no_hover { - cursor: default; - - &:hover { - background-color: var(--background-0); - } - } -} diff --git a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/SyncElementLine/SyncElementLine.tsx b/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/SyncElementLine/SyncElementLine.tsx deleted file mode 100644 index 40b698bc8ecd..000000000000 --- a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/SyncElementLine/SyncElementLine.tsx +++ /dev/null @@ -1,128 +0,0 @@ -import { useState } from "react"; - -import { Checkbox } from "@/lib/components/ui/Checkbox/Checkbox"; -import { Icon } from "@/lib/components/ui/Icon/Icon"; -import Tooltip from "@/lib/components/ui/Tooltip/Tooltip"; - -import styles from "./SyncElementLine.module.scss"; - -import { useFromConnectionsContext } from "../FromConnectionsProvider/hooks/useFromConnectionContext"; - -interface SyncElementLineProps { - name: string; - selectable: boolean; - id: string; - isFolder: boolean; - icon?: string; - isAlsoFile?: boolean; -} - -export const SyncElementLine = ({ - name, - selectable, - id, - isFolder, - icon, - isAlsoFile, -}: SyncElementLineProps): JSX.Element => { - const [isCheckboxHovered, setIsCheckboxHovered] = useState(false); - const { currentSyncId, openedConnections, setOpenedConnections } = - useFromConnectionsContext(); - - const initialChecked = (): boolean => { - const currentConnection = openedConnections.find( - (connection) => connection.user_sync_id === currentSyncId - ); - - return currentConnection - ? 
currentConnection.selectedFiles.files.some((file) => file.id === id) - : false; - }; - - const [checked, setChecked] = useState(initialChecked); - - const showCheckbox: boolean = isAlsoFile ?? selectable; - - const handleSetChecked = () => { - setOpenedConnections((prevState) => { - return prevState.map((connection) => { - if (connection.user_sync_id === currentSyncId) { - const isFileSelected = connection.selectedFiles.files.some( - (file) => file.id === id - ); - const updatedFiles = isFileSelected - ? connection.selectedFiles.files.filter((file) => file.id !== id) - : [ - ...connection.selectedFiles.files, - { id, name, is_folder: isFolder }, - ]; - - return { - ...connection, - selectedFiles: { - files: updatedFiles, - }, - }; - } - - return connection; - }); - }); - setChecked((prevChecked) => !prevChecked); - }; - - const content = ( -
{ - if (isFolder && checked) { - event.stopPropagation(); - } - }} - > -
- {showCheckbox && ( -
setIsCheckboxHovered(true)} - onMouseLeave={() => setIsCheckboxHovered(false)} - style={{ pointerEvents: "auto" }} - > - -
- )} - {icon ? ( -
{icon}
- ) : ( - - )} - {name} -
- {isFolder && ( - - )} -
- ); - - return selectable ? ( - content - ) : ( - - {content} - - ); -}; diff --git a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromDocuments/FromDocuments.module.scss b/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromDocuments/FromDocuments.module.scss deleted file mode 100644 index 099fd78d6871..000000000000 --- a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromDocuments/FromDocuments.module.scss +++ /dev/null @@ -1,47 +0,0 @@ -@use "styles/Radius.module.scss"; -@use "styles/ScreenSizes.module.scss"; -@use "styles/Spacings.module.scss"; - -.from_document_wrapper { - width: 100%; - border: 1px dashed var(--border-0); - border-radius: Radius.$big; - box-sizing: border-box; - cursor: pointer; - height: 100%; - overflow: auto; - transition: border-color 0.3s ease 0.2s, border-width 0.1s ease 0.1s; - - &.dragging { - border: 3px dashed var(--accent); - background-color: var(--background-3); - } - - &:hover { - border: 3px dashed var(--accent); - } - - .box_content { - padding: Spacings.$spacing05; - display: flex; - flex-direction: column; - column-gap: Spacings.$spacing05; - justify-content: center; - align-items: center; - height: 100%; - - .input { - display: flex; - gap: Spacings.$spacing02; - padding: Spacings.$spacing05; - - @media (max-width: ScreenSizes.$small) { - flex-direction: column; - } - - .clickable { - font-weight: bold; - } - } - } -} diff --git a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromDocuments/FromDocuments.tsx b/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromDocuments/FromDocuments.tsx deleted file mode 100644 index 4bbe25061265..000000000000 --- a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromDocuments/FromDocuments.tsx +++ /dev/null @@ -1,42 +0,0 @@ -import { useEffect, 
useState } from "react"; - -import { Icon } from "@/lib/components/ui/Icon/Icon"; -import { useKnowledgeToFeedContext } from "@/lib/context/KnowledgeToFeedProvider/hooks/useKnowledgeToFeedContext"; -import { useCustomDropzone } from "@/lib/hooks/useDropzone"; - -import styles from "./FromDocuments.module.scss"; - -export const FromDocuments = (): JSX.Element => { - const [dragging, setDragging] = useState(false); - const { getRootProps, getInputProps, open } = useCustomDropzone(); - const { knowledgeToFeed } = useKnowledgeToFeedContext(); - - useEffect(() => { - setDragging(false); - }, [knowledgeToFeed]); - - return ( -
setDragging(true)} - onDragLeave={() => setDragging(false)} - onMouseLeave={() => setDragging(false)} - onClick={open} - > -
- -
-
- Choose files - -
- or drag it here -
-
-
- ); -}; diff --git a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromWebsites/FromWebsites.module.scss b/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromWebsites/FromWebsites.module.scss deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromWebsites/FromWebsites.tsx b/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromWebsites/FromWebsites.tsx deleted file mode 100644 index d13b6ff693ff..000000000000 --- a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromWebsites/FromWebsites.tsx +++ /dev/null @@ -1,20 +0,0 @@ -import { useCrawler } from "@/lib/components/KnowledgeToFeedInput/components/Crawler/hooks/useCrawler"; -import { TextInput } from "@/lib/components/ui/TextInput/TextInput"; - -import styles from "./FromWebsites.module.scss"; - -export const FromWebsites = (): JSX.Element => { - const { handleSubmit, urlToCrawl, setUrlToCrawl } = useCrawler(); - - return ( -
- handleSubmit()} - /> -
- ); -}; diff --git a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/hooks/useFeedBrainInChat.ts b/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/hooks/useFeedBrainInChat.ts deleted file mode 100644 index 9c2359ebb289..000000000000 --- a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/hooks/useFeedBrainInChat.ts +++ /dev/null @@ -1,121 +0,0 @@ -/*eslint max-lines: ["error", 200 ]*/ - -import { useQueryClient } from "@tanstack/react-query"; -import { UUID } from "crypto"; -import { useParams, useRouter } from "next/navigation"; -import { useState } from "react"; -import { useTranslation } from "react-i18next"; - -import { CHATS_DATA_KEY } from "@/lib/api/chat/config"; -import { useChatApi } from "@/lib/api/chat/useChatApi"; -import { useNotificationApi } from "@/lib/api/notification/useNotificationApi"; -import { useKnowledgeToFeedInput } from "@/lib/components/KnowledgeToFeedInput/hooks/useKnowledgeToFeedInput.ts"; -import { useChatContext } from "@/lib/context"; -import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext"; -import { useKnowledgeToFeedContext } from "@/lib/context/KnowledgeToFeedProvider/hooks/useKnowledgeToFeedContext"; -import { useToast } from "@/lib/hooks"; -import { useOnboarding } from "@/lib/hooks/useOnboarding"; - -import { FeedItemCrawlType, FeedItemUploadType } from "../../../types"; -import { useFromConnectionsContext } from "../components/FromConnections/FromConnectionsProvider/hooks/useFromConnectionContext"; - -// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types -export const useFeedBrainInChat = ({ - dispatchHasPendingRequests, -}: { - dispatchHasPendingRequests: () => void; -}) => { - const { publish } = useToast(); - const queryClient = useQueryClient(); - const { t } = useTranslation(["upload"]); - const router = useRouter(); - const { updateOnboarding, onboarding } = useOnboarding(); - const { 
setShouldDisplayFeedCard } = useKnowledgeToFeedContext(); - const { currentBrainId } = useBrainContext(); - const { setKnowledgeToFeed, knowledgeToFeed } = useKnowledgeToFeedContext(); - const [hasPendingRequests, setHasPendingRequests] = useState(false); - const { createChat } = useChatApi(); - const params = useParams(); - const chatId = params?.chatId as UUID | undefined; - const { setNotifications } = useChatContext(); - const { getChatNotifications } = useNotificationApi(); - const fetchNotifications = async (currentChatId: UUID): Promise => { - const fetchedNotifications = await getChatNotifications(currentChatId); - setNotifications(fetchedNotifications); - }; - const { openedConnections } = useFromConnectionsContext(); - const { crawlWebsiteHandler, uploadFileHandler } = useKnowledgeToFeedInput(); - const files: File[] = ( - knowledgeToFeed.filter((c) => c.source === "upload") as FeedItemUploadType[] - ).map((c) => c.file); - const urls: string[] = ( - knowledgeToFeed.filter((c) => c.source === "crawl") as FeedItemCrawlType[] - ).map((c) => c.url); - const feedBrain = async (): Promise => { - if (currentBrainId === null) { - publish({ - variant: "danger", - text: t("selectBrainFirst"), - }); - - return; - } - if (knowledgeToFeed.length === 0 && !openedConnections.length) { - publish({ - variant: "danger", - text: t("addFiles"), - }); - - return; - } - try { - dispatchHasPendingRequests(); - setShouldDisplayFeedCard(false); - setHasPendingRequests(true); - const currentChatId = chatId ?? 
(await createChat("New Chat")).chat_id; - const uploadPromises = files.map((file) => - uploadFileHandler(file, currentBrainId, currentChatId) - ); - const crawlPromises = urls.map((url) => - crawlWebsiteHandler(url, currentBrainId, currentChatId) - ); - - const updateOnboardingPromise = async () => { - if (onboarding.onboarding_a) { - await updateOnboarding({ - onboarding_a: false, - }); - } - }; - - await Promise.all([ - ...uploadPromises, - ...crawlPromises, - updateOnboardingPromise(), - ]); - - setKnowledgeToFeed([]); - - if (chatId === undefined) { - void queryClient.invalidateQueries({ - queryKey: [CHATS_DATA_KEY], - }); - void router.push(`/chat/${currentChatId}`); - } else { - await fetchNotifications(currentChatId); - } - } catch (e) { - publish({ - variant: "danger", - text: JSON.stringify(e), - }); - } finally { - setHasPendingRequests(false); - } - }; - - return { - feedBrain, - hasPendingRequests, - }; -}; diff --git a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/index.ts b/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/index.ts deleted file mode 100644 index 87c3aadd3774..000000000000 --- a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/index.ts +++ /dev/null @@ -1 +0,0 @@ -export * from "./KnowledgeToFeed"; diff --git a/frontend/app/chat/[chatId]/components/ActionsBar/components/index.ts b/frontend/app/chat/[chatId]/components/ActionsBar/components/index.ts index 77e5a42989a6..924a01f255d8 100644 --- a/frontend/app/chat/[chatId]/components/ActionsBar/components/index.ts +++ b/frontend/app/chat/[chatId]/components/ActionsBar/components/index.ts @@ -1,2 +1 @@ export * from "./ChatInput"; -export * from "./KnowledgeToFeed"; diff --git a/frontend/app/chat/[chatId]/components/ChatDialogueArea/components/ChatDialogue/components/ChatItem/QADisplay/components/MessageRow/components/Source/Source.tsx 
b/frontend/app/chat/[chatId]/components/ChatDialogueArea/components/ChatDialogue/components/ChatItem/QADisplay/components/MessageRow/components/Source/Source.tsx index 658156bea757..092f9f9900df 100644 --- a/frontend/app/chat/[chatId]/components/ChatDialogueArea/components/ChatDialogue/components/ChatItem/QADisplay/components/MessageRow/components/Source/Source.tsx +++ b/frontend/app/chat/[chatId]/components/ChatDialogueArea/components/ChatDialogue/components/ChatItem/QADisplay/components/MessageRow/components/Source/Source.tsx @@ -19,7 +19,7 @@ export const SourceCitations = ({ sourceFile }: SourceProps): JSX.Element => { const [isCitationModalOpened, setIsCitationModalOpened] = useState(false); const [citationIndex, setCitationIndex] = useState(0); - const { integrationIconUrls } = useSync(); + const { providerIconUrls } = useSync(); return (
@@ -44,7 +44,11 @@ export const SourceCitations = ({ sourceFile }: SourceProps): JSX.Element => { {sourceFile.filename} {sourceFile.integration ? ( integration_icon { const chatListRef = useRef(null); const { messages } = useChat(); - const { shouldDisplayFeedCard } = useKnowledgeToFeedContext(); const scrollToBottom = useCallback( _debounce(() => { @@ -45,7 +43,7 @@ export const useChatDialogue = () => { useEffect(() => { scrollToBottom(); - }, [messages, scrollToBottom, shouldDisplayFeedCard]); + }, [messages, scrollToBottom]); return { chatListRef, diff --git a/frontend/app/chat/[chatId]/hooks/useChatNotificationsSync.ts b/frontend/app/chat/[chatId]/hooks/useChatNotificationsSync.ts deleted file mode 100644 index 294b46a3c17c..000000000000 --- a/frontend/app/chat/[chatId]/hooks/useChatNotificationsSync.ts +++ /dev/null @@ -1,78 +0,0 @@ -import { useQuery } from "@tanstack/react-query"; -import { useParams } from "next/navigation"; -import { useEffect } from "react"; - -import { useChatApi } from "@/lib/api/chat/useChatApi"; -import { useNotificationApi } from "@/lib/api/notification/useNotificationApi"; -import { useChatContext } from "@/lib/context"; -import { useKnowledgeToFeedContext } from "@/lib/context/KnowledgeToFeedProvider/hooks/useKnowledgeToFeedContext"; - -import { getChatNotificationsQueryKey } from "../utils/getChatNotificationsQueryKey"; -import { getMessagesFromChatItems } from "../utils/getMessagesFromChatItems"; -import { getNotificationsFromChatItems } from "../utils/getNotificationsFromChatItems"; - -// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types -export const useChatNotificationsSync = () => { - const { setMessages, setNotifications, notifications } = useChatContext(); - const { getChatItems } = useChatApi(); - const { getChatNotifications } = useNotificationApi(); - const { setShouldDisplayFeedCard } = useKnowledgeToFeedContext(); - const params = useParams(); - const chatId = params?.chatId as string | 
undefined; - - const chatNotificationsQueryKey = getChatNotificationsQueryKey(chatId ?? ""); - const { data: fetchedNotifications = [] } = useQuery({ - queryKey: [chatNotificationsQueryKey], - enabled: notifications.length > 0, - queryFn: () => { - if (chatId === undefined) { - return []; - } - - return getChatNotifications(chatId); - }, - refetchInterval: () => { - if (notifications.length === 0) { - return false; - } - const hasAPendingNotification = notifications.find( - (item) => item.status === "Pending" - ); - - if (hasAPendingNotification) { - return 2_000; // in ms - } - - return false; - }, - }); - - useEffect(() => { - if (fetchedNotifications.length === 0) { - return; - } - setNotifications(fetchedNotifications); - }, [fetchedNotifications]); - - useEffect(() => { - setShouldDisplayFeedCard(false); - const fetchHistory = async () => { - if (chatId === undefined) { - setMessages([]); - setNotifications([]); - - return; - } - const chatItems = await getChatItems(chatId); - const messagesFromChatItems = getMessagesFromChatItems(chatItems); - if ( - messagesFromChatItems.length > 1 || - (messagesFromChatItems[0] && messagesFromChatItems[0].assistant !== "") - ) { - setMessages(messagesFromChatItems); - setNotifications(getNotificationsFromChatItems(chatItems)); - } - }; - void fetchHistory(); - }, [chatId]); -}; diff --git a/frontend/app/chat/[chatId]/page.tsx b/frontend/app/chat/[chatId]/page.tsx index af1e8ee4d055..b9ed73b82c01 100644 --- a/frontend/app/chat/[chatId]/page.tsx +++ b/frontend/app/chat/[chatId]/page.tsx @@ -5,31 +5,20 @@ import { useEffect } from "react"; import { AddBrainModal } from "@/lib/components/AddBrainModal"; import { useBrainCreationContext } from "@/lib/components/AddBrainModal/brainCreation-provider"; import { PageHeader } from "@/lib/components/PageHeader/PageHeader"; -import { UploadDocumentModal } from "@/lib/components/UploadDocumentModal/UploadDocumentModal"; import { useChatContext } from "@/lib/context"; import { 
useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext"; -import { useKnowledgeToFeedContext } from "@/lib/context/KnowledgeToFeedProvider/hooks/useKnowledgeToFeedContext"; -import { useCustomDropzone } from "@/lib/hooks/useDropzone"; import { ButtonType } from "@/lib/types/QuivrButton"; import { cn } from "@/lib/utils"; import { ActionsBar } from "./components/ActionsBar"; import { ChatDialogueArea } from "./components/ChatDialogueArea/ChatDialogue"; -import { useChatNotificationsSync } from "./hooks/useChatNotificationsSync"; import styles from "./page.module.scss"; const SelectedChatPage = (): JSX.Element => { - const { getRootProps } = useCustomDropzone(); - - const { setShouldDisplayFeedCard, shouldDisplayFeedCard } = - useKnowledgeToFeedContext(); const { setIsBrainCreationModalOpened } = useBrainCreationContext(); - const { currentBrain, setCurrentBrainId } = useBrainContext(); const { messages } = useChatContext(); - useChatNotificationsSync(); - const buttons: ButtonType[] = [ { label: "Create brain", @@ -39,15 +28,6 @@ const SelectedChatPage = (): JSX.Element => { }, iconName: "brain", }, - { - label: "Add knowledge", - color: "primary", - onClick: () => { - setShouldDisplayFeedCard(true); - }, - iconName: "uploadFile", - hidden: !currentBrain?.max_files, - }, { label: "Manage current brain", color: "primary", @@ -74,10 +54,7 @@ const SelectedChatPage = (): JSX.Element => {
-
+
{
-
diff --git a/frontend/app/globals.css b/frontend/app/globals.css index 5fa74f1c5a44..17394213c544 100644 --- a/frontend/app/globals.css +++ b/frontend/app/globals.css @@ -3,8 +3,6 @@ @import "tailwindcss/utilities"; @import './colors.css'; - - main { @apply max-w-screen-xl mx-auto flex flex-col; } @@ -57,6 +55,7 @@ div:focus { --background-3: var(--grey-4); --background-4: var(--grey-0); --background-5: var(--black-0); + --background-6: var(--grey-4); --background-primary-0: var(--primary-2); --background-primary-1: var(--primary-1); --background-blur: rgba(0, 0, 0, 0.9); @@ -74,6 +73,7 @@ div:focus { --icon-1: var(--grey-2); --icon-2: var(--grey-0); --icon-3: var(--black-0); + --icon-4: var(--grey-1); /* Text */ --text-0: var(--white-0); @@ -94,6 +94,7 @@ body.dark_mode { --background-3: var(--black-3); --background-4: var(--black-4); --background-5: var(--black-5); + --background-6: var(--black-5); --background-primary-0: var(--black-5); --background-primary-1: var(--black-5); --background-opposite: var(--white-0); @@ -112,6 +113,7 @@ body.dark_mode { --icon-1: var(--grey-0); --icon-2: var(--grey-2); --icon-3: var(--white-0); + --icon-4: var(--grey-2); /* Text */ --text-0: var(--black-0); diff --git a/frontend/app/knowledge/page.module.scss b/frontend/app/knowledge/page.module.scss new file mode 100644 index 000000000000..414dff05c156 --- /dev/null +++ b/frontend/app/knowledge/page.module.scss @@ -0,0 +1,13 @@ +.main_container { + display: flex; + flex-direction: column; + width: 100%; + height: 100vh; + overflow: hidden; + + .kms_wrapper { + width: 100%; + height: 100%; + overflow: hidden; + } +} diff --git a/frontend/app/knowledge/page.tsx b/frontend/app/knowledge/page.tsx new file mode 100644 index 000000000000..ae37a85eab9c --- /dev/null +++ b/frontend/app/knowledge/page.tsx @@ -0,0 +1,21 @@ +"use client"; + +import KnowledgeManagementSystem from "@/lib/components/KnowledgeManagementSystem/KnowledgeManagementSystem"; +import { PageHeader } from 
"@/lib/components/PageHeader/PageHeader"; + +import styles from "./page.module.scss"; + +const Knowledge = (): JSX.Element => { + return ( +
+
+ +
+
+ +
+
+ ); +}; + +export default Knowledge; diff --git a/frontend/app/quality-assistant/AssistantTab/AssistantTab.tsx b/frontend/app/quality-assistant/AssistantTab/AssistantTab.tsx index 1fadcbc6d4ee..1ee3d2cc9feb 100644 --- a/frontend/app/quality-assistant/AssistantTab/AssistantTab.tsx +++ b/frontend/app/quality-assistant/AssistantTab/AssistantTab.tsx @@ -5,7 +5,7 @@ import { useEffect, useState } from "react"; import { useAssistants } from "@/lib/api/assistants/useAssistants"; import { FileInput } from "@/lib/components/ui/FileInput/FileInput"; import { Icon } from "@/lib/components/ui/Icon/Icon"; -import QuivrButton from "@/lib/components/ui/QuivrButton/QuivrButton"; +import { QuivrButton } from "@/lib/components/ui/QuivrButton/QuivrButton"; import AssistantCard from "./AssistantCard/AssistantCard"; import styles from "./AssistantTab.module.scss"; @@ -221,7 +221,9 @@ const AssistantTab = ({ setSelectedTab }: AssistantTabProps): JSX.Element => { handleFileChange(input.key, file)} + onFileChange={(files) => + handleFileChange(input.key, files[0]) + } acceptedFileTypes={FILE_TYPES} /> diff --git a/frontend/app/quality-assistant/page.tsx b/frontend/app/quality-assistant/page.tsx index a4da3eb48c40..18a1601dbfe5 100644 --- a/frontend/app/quality-assistant/page.tsx +++ b/frontend/app/quality-assistant/page.tsx @@ -2,7 +2,7 @@ import { useState } from "react"; -import PageHeader from "@/lib/components/PageHeader/PageHeader"; +import { PageHeader } from "@/lib/components/PageHeader/PageHeader"; import { Tabs } from "@/lib/components/ui/Tabs/Tabs"; import { Tab } from "@/lib/types/Tab"; diff --git a/frontend/app/search/page.tsx b/frontend/app/search/page.tsx index 256ec4c95963..e62c6fdff182 100644 --- a/frontend/app/search/page.tsx +++ b/frontend/app/search/page.tsx @@ -7,7 +7,6 @@ import { AddBrainModal } from "@/lib/components/AddBrainModal"; import { useBrainCreationContext } from "@/lib/components/AddBrainModal/brainCreation-provider"; import { OnboardingModal } from 
"@/lib/components/OnboardingModal/OnboardingModal"; import { PageHeader } from "@/lib/components/PageHeader/PageHeader"; -import { UploadDocumentModal } from "@/lib/components/UploadDocumentModal/UploadDocumentModal"; import { SearchBar } from "@/lib/components/ui/SearchBar/SearchBar"; import { SmallTabs } from "@/lib/components/ui/SmallTabs/SmallTabs"; import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext"; @@ -23,7 +22,6 @@ import styles from "./page.module.scss"; const projectName = process.env.NEXT_PUBLIC_PROJECT_NAME; - const Search = (): JSX.Element => { const [selectedTab, setSelectedTab] = useState("Models"); const [isNewBrain, setIsNewBrain] = useState(false); @@ -113,7 +111,9 @@ const Search = (): JSX.Element => {
Talk to - {projectName ? projectName : "Quivr"} + + {projectName ? projectName : "Quivr"} +
@@ -132,7 +132,6 @@ const Search = (): JSX.Element => {
- diff --git a/frontend/app/studio/BrainsTabs/components/Analytics/Analytics.tsx b/frontend/app/studio/BrainsTabs/components/Analytics/Analytics.tsx index f25546b5c96a..54f8bf6bc7b3 100644 --- a/frontend/app/studio/BrainsTabs/components/Analytics/Analytics.tsx +++ b/frontend/app/studio/BrainsTabs/components/Analytics/Analytics.tsx @@ -14,13 +14,13 @@ import { import { useLayoutEffect, useState } from "react"; import { Line } from "react-chartjs-2"; -import { formatMinimalBrainsToSelectComponentInput } from "@/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/utils/formatMinimalBrainsToSelectComponentInput"; import { Range } from "@/lib/api/analytics/types"; import { useAnalytics } from "@/lib/api/analytics/useAnalyticsApi"; import { LoaderIcon } from "@/lib/components/ui/LoaderIcon/LoaderIcon"; import { MessageInfoBox } from "@/lib/components/ui/MessageInfoBox/MessageInfoBox"; import { SingleSelector } from "@/lib/components/ui/SingleSelector/SingleSelector"; import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext"; +import { formatMinimalBrainsToSelectComponentInput } from "@/lib/helpers/formatBrains"; import { useDevice } from "@/lib/hooks/useDevice"; import styles from "./Analytics.module.scss"; diff --git a/frontend/app/studio/[brainId]/BrainManagementTabs/BrainManagementTabs.module.scss b/frontend/app/studio/[brainId]/BrainManagementTabs/BrainManagementTabs.module.scss index 5d884f035b08..d34426540ca3 100644 --- a/frontend/app/studio/[brainId]/BrainManagementTabs/BrainManagementTabs.module.scss +++ b/frontend/app/studio/[brainId]/BrainManagementTabs/BrainManagementTabs.module.scss @@ -1,22 +1,30 @@ @use "styles/Spacings.module.scss"; -.loader { - display: flex; - justify-content: center; - align-items: center; - width: 100%; +.main_container { height: 100%; -} - -.header_wrapper { - display: flex; - width: calc(100% + (Spacings.$spacing05 + Spacings.$spacing03)); - margin-left: -(Spacings.$spacing05 + 
Spacings.$spacing03); - gap: Spacings.$spacing03; - align-items: center; - padding-top: Spacings.$spacing05; - .tabs { + .loader { + display: flex; + justify-content: center; + align-items: center; width: 100%; + height: 100%; + } + + .header_wrapper { + display: flex; + width: calc(100% + (Spacings.$spacing05 + Spacings.$spacing03)); + margin-left: -(Spacings.$spacing05 + Spacings.$spacing03); + gap: Spacings.$spacing03; + align-items: center; + padding-top: Spacings.$spacing05; + + .tabs { + width: 100%; + } + } + + .knowledge_tab { + height: 100%; } } diff --git a/frontend/app/studio/[brainId]/BrainManagementTabs/BrainManagementTabs.tsx b/frontend/app/studio/[brainId]/BrainManagementTabs/BrainManagementTabs.tsx index 69db8747032b..c44024c668dc 100644 --- a/frontend/app/studio/[brainId]/BrainManagementTabs/BrainManagementTabs.tsx +++ b/frontend/app/studio/[brainId]/BrainManagementTabs/BrainManagementTabs.tsx @@ -17,7 +17,7 @@ import { useBrainFetcher } from "./hooks/useBrainFetcher"; import { useBrainManagementTabs } from "./hooks/useBrainManagementTabs"; export const BrainManagementTabs = (): JSX.Element => { - const [selectedTab, setSelectedTab] = useState("Knowledge"); + const [selectedTab, setSelectedTab] = useState("Settings"); const { brainId, hasEditRights } = useBrainManagementTabs(); const { allKnowledge } = useAddedKnowledge({ brainId: brainId ?? undefined }); const router = useRouter(); @@ -27,21 +27,25 @@ export const BrainManagementTabs = (): JSX.Element => { }); const brainManagementTabs: Tab[] = [ + { + label: "Settings", + isSelected: selectedTab === "Settings", + onClick: () => setSelectedTab("Settings"), + iconName: "settings", + }, { label: hasEditRights - ? `Knowledge${allKnowledge.length > 1 ? "s" : ""} (${ - allKnowledge.length + ? `Knowledge${ + allKnowledge.filter((knowledge) => !knowledge.is_folder).length > 1 + ? 
"s" + : "" + } (${ + allKnowledge.filter((knowledge) => !knowledge.is_folder).length })` : "Knowledge", isSelected: selectedTab === "Knowledge", onClick: () => setSelectedTab("Knowledge"), - iconName: "file", - }, - { - label: "Settings", - isSelected: selectedTab === "Settings", - onClick: () => setSelectedTab("Settings"), - iconName: "settings", + iconName: "knowledge", }, { label: "People", @@ -65,7 +69,7 @@ export const BrainManagementTabs = (): JSX.Element => { } return ( -
+
{ )} {selectedTab === "People" && } {selectedTab === "Knowledge" && ( - +
+ +
)}
); diff --git a/frontend/app/studio/[brainId]/BrainManagementTabs/components/KnowledgeTab/KnowledgeTab.module.scss b/frontend/app/studio/[brainId]/BrainManagementTabs/components/KnowledgeTab/KnowledgeTab.module.scss index 3d1f588a6666..39700bcdbee9 100644 --- a/frontend/app/studio/[brainId]/BrainManagementTabs/components/KnowledgeTab/KnowledgeTab.module.scss +++ b/frontend/app/studio/[brainId]/BrainManagementTabs/components/KnowledgeTab/KnowledgeTab.module.scss @@ -4,6 +4,7 @@ .knowledge_tab_container { padding-block: Spacings.$spacing05; + height: 100%; .knowledge_tab_wrapper { display: flex; @@ -11,6 +12,7 @@ width: 100%; gap: Spacings.$spacing05; padding-block: Spacings.$spacing05; + height: 100%; .message { display: inline-flex; diff --git a/frontend/app/studio/[brainId]/BrainManagementTabs/components/KnowledgeTab/KnowledgeTab.tsx b/frontend/app/studio/[brainId]/BrainManagementTabs/components/KnowledgeTab/KnowledgeTab.tsx index 1cf5653ad831..8f8d11507d7e 100644 --- a/frontend/app/studio/[brainId]/BrainManagementTabs/components/KnowledgeTab/KnowledgeTab.tsx +++ b/frontend/app/studio/[brainId]/BrainManagementTabs/components/KnowledgeTab/KnowledgeTab.tsx @@ -1,12 +1,9 @@ "use client"; import { UUID } from "crypto"; -import { AnimatePresence, motion } from "framer-motion"; +import { KMSElement } from "@/lib/api/sync/types"; import { LoaderIcon } from "@/lib/components/ui/LoaderIcon/LoaderIcon"; import { MessageInfoBox } from "@/lib/components/ui/MessageInfoBox/MessageInfoBox"; -import { QuivrButton } from "@/lib/components/ui/QuivrButton/QuivrButton"; -import { useKnowledgeToFeedContext } from "@/lib/context/KnowledgeToFeedProvider/hooks/useKnowledgeToFeedContext"; -import { Knowledge } from "@/lib/types/Knowledge"; import styles from "./KnowledgeTab.module.scss"; import KnowledgeTable from "./KnowledgeTable/KnowledgeTable"; @@ -15,7 +12,7 @@ import { useAddedKnowledge } from "./hooks/useAddedKnowledge"; type KnowledgeTabProps = { brainId: UUID; hasEditRights: 
boolean; - allKnowledge: Knowledge[]; + allKnowledge: KMSElement[]; }; export const KnowledgeTab = ({ brainId, @@ -25,7 +22,6 @@ export const KnowledgeTab = ({ const { isPending } = useAddedKnowledge({ brainId, }); - const { setShouldDisplayFeedCard } = useKnowledgeToFeedContext(); if (!hasEditRights) { return ( @@ -44,21 +40,12 @@ export const KnowledgeTab = ({ return ; } - if (allKnowledge.length === 0) { + if (allKnowledge.filter((knowledge) => !knowledge.is_folder).length === 0) { return (
-
- This brain is empty! You can add knowledge by clicking on - setShouldDisplayFeedCard(true)} - /> - . -
+
This brain is empty!
@@ -68,11 +55,11 @@ export const KnowledgeTab = ({ return (
- - - - - + !knowledge.is_folder + )} + />
); diff --git a/frontend/app/studio/[brainId]/BrainManagementTabs/components/KnowledgeTab/KnowledgeTable/KnowledgeItem/KnowledgeItem.tsx b/frontend/app/studio/[brainId]/BrainManagementTabs/components/KnowledgeTab/KnowledgeTable/KnowledgeItem/KnowledgeItem.tsx index 7a23e7fcb269..b53178e3eef8 100644 --- a/frontend/app/studio/[brainId]/BrainManagementTabs/components/KnowledgeTab/KnowledgeTable/KnowledgeItem/KnowledgeItem.tsx +++ b/frontend/app/studio/[brainId]/BrainManagementTabs/components/KnowledgeTab/KnowledgeTable/KnowledgeItem/KnowledgeItem.tsx @@ -4,6 +4,7 @@ import Image from "next/image"; import React, { useEffect, useRef, useState } from "react"; import { useKnowledgeApi } from "@/lib/api/knowledge/useKnowledgeApi"; +import { KMSElement } from "@/lib/api/sync/types"; import { useSync } from "@/lib/api/sync/useSync"; import { Checkbox } from "@/lib/components/ui/Checkbox/Checkbox"; import { Icon } from "@/lib/components/ui/Icon/Icon"; @@ -12,7 +13,6 @@ import { Tag } from "@/lib/components/ui/Tag/Tag"; import { iconList } from "@/lib/helpers/iconList"; import { useUrlBrain } from "@/lib/hooks/useBrainIdFromUrl"; import { useDevice } from "@/lib/hooks/useDevice"; -import { isUploadedKnowledge, Knowledge } from "@/lib/types/Knowledge"; import { Option } from "@/lib/types/Options"; import { useKnowledgeItem } from "./hooks/useKnowledgeItem"; @@ -25,7 +25,7 @@ const KnowledgeItem = ({ setSelected, lastChild, }: { - knowledge: Knowledge; + knowledge: KMSElement; selected: boolean; setSelected: (selected: boolean, event: React.MouseEvent) => void; lastChild?: boolean; @@ -37,7 +37,7 @@ const KnowledgeItem = ({ const { brain } = useUrlBrain(); const { generateSignedUrlKnowledge } = useKnowledgeApi(); const { isMobile } = useDevice(); - const { integrationIconUrls } = useSync(); + const { providerIconUrls } = useSync(); const getOptions = (): Option[] => [ { @@ -52,12 +52,12 @@ const KnowledgeItem = ({ onClick: () => void downloadFile(), iconName: "download", 
iconColor: "primary", - disabled: brain?.role !== "Owner" || !isUploadedKnowledge(knowledge), + disabled: brain?.role !== "Owner" || !!knowledge.url, }, ]; const downloadFile = async () => { - if (isUploadedKnowledge(knowledge)) { + if (!knowledge.url && knowledge.file_name) { const downloadUrl = await generateSignedUrlKnowledge({ knowledgeId: knowledge.id, }); @@ -71,7 +71,7 @@ const KnowledgeItem = ({ const a = document.createElement("a"); a.href = blobUrl; - a.download = knowledge.fileName; + a.download = knowledge.file_name; document.body.appendChild(a); a.click(); @@ -103,14 +103,10 @@ const KnowledgeItem = ({ }, []); const renderIcon = () => { - if (isUploadedKnowledge(knowledge)) { - return knowledge.integration ? ( + if (!knowledge.url) { + return knowledge.source !== "local" ? ( integration_icon { - if (isUploadedKnowledge(knowledge)) { - return {knowledge.fileName}; + if (!knowledge.url) { + return {knowledge.file_name}; } return ( @@ -161,7 +157,7 @@ const KnowledgeItem = ({ {!isMobile && (
{ const { t } = useTranslation(["translation", "explore"]); - const onDeleteKnowledge = async (knowledge: Knowledge) => { + const onDeleteKnowledge = async (knowledge: KMSElement) => { setIsDeleting(true); void track("DELETE_DOCUMENT"); const knowledge_name = diff --git a/frontend/app/studio/[brainId]/BrainManagementTabs/components/KnowledgeTab/KnowledgeTable/KnowledgeTable.module.scss b/frontend/app/studio/[brainId]/BrainManagementTabs/components/KnowledgeTab/KnowledgeTable/KnowledgeTable.module.scss index 8e5f3ea062d8..f80d46f3cf79 100644 --- a/frontend/app/studio/[brainId]/BrainManagementTabs/components/KnowledgeTab/KnowledgeTable/KnowledgeTable.module.scss +++ b/frontend/app/studio/[brainId]/BrainManagementTabs/components/KnowledgeTab/KnowledgeTable/KnowledgeTable.module.scss @@ -8,9 +8,16 @@ gap: Spacings.$spacing05; padding-bottom: Spacings.$spacing10; border-radius: Radius.$normal; + height: 100%; - .title { - @include Typography.H2; + .content_header { + display: flex; + align-items: center; + justify-content: space-between; + + .title { + @include Typography.H2; + } } .table_header { @@ -18,73 +25,83 @@ justify-content: space-between; align-items: center; gap: Spacings.$spacing03; + padding-bottom: Spacings.$spacing05; .search { width: 250px; } } - .first_line { - display: flex; - justify-content: space-between; - padding-left: calc(Spacings.$spacing06); - padding-right: Spacings.$spacing04; - padding-block: Spacings.$spacing02; - font-weight: 500; - background-color: var(--background-1); - font-size: Typography.$small; - border: 1px solid var(--border-0); - border-radius: Radius.$normal Radius.$normal 0 0; - border-bottom: none; - - &.empty { - border: 1px solid var(--border-0); - border-radius: Radius.$normal; - } + .content { + height: 100%; - .left { + .first_line { display: flex; - align-items: center; - gap: calc(Spacings.$spacing06 + 6px); + justify-content: space-between; + padding-left: calc(Spacings.$spacing06); + padding-right: 
Spacings.$spacing04; + padding-block: Spacings.$spacing02; + font-weight: 500; + background-color: var(--background-1); + font-size: Typography.$small; + border: 1px solid var(--border-0); + border-radius: Radius.$normal Radius.$normal 0 0; + border-bottom: none; + + &.empty { + border: 1px solid var(--border-0); + border-radius: Radius.$normal; + } - .name { + .left { display: flex; align-items: center; - gap: Spacings.$spacing02; - cursor: pointer; + gap: calc(Spacings.$spacing06 + 6px); - .icon { - visibility: hidden; - } + .name { + display: flex; + align-items: center; + gap: Spacings.$spacing02; + cursor: pointer; - &:hover { .icon { - visibility: visible; + visibility: hidden; + } + + &:hover { + .icon { + visibility: visible; + } } } } - } - - .right { - display: flex; - gap: calc(Spacings.$spacing06 + 12px); - .status { + .right { display: flex; - align-items: center; - gap: Spacings.$spacing02; - cursor: pointer; + gap: calc(Spacings.$spacing06 + 12px); - .icon { - visibility: hidden; - } + .status { + display: flex; + align-items: center; + gap: Spacings.$spacing02; + cursor: pointer; - &:hover { .icon { - visibility: visible; + visibility: hidden; + } + + &:hover { + .icon { + visibility: visible; + } } } } } + + .kms { + height: 100vh; + margin-inline: -(Spacings.$spacing09); + } } } diff --git a/frontend/app/studio/[brainId]/BrainManagementTabs/components/KnowledgeTab/KnowledgeTable/KnowledgeTable.tsx b/frontend/app/studio/[brainId]/BrainManagementTabs/components/KnowledgeTab/KnowledgeTable/KnowledgeTable.tsx index 7bfd982907a1..b93231b8fab1 100644 --- a/frontend/app/studio/[brainId]/BrainManagementTabs/components/KnowledgeTab/KnowledgeTable/KnowledgeTable.tsx +++ b/frontend/app/studio/[brainId]/BrainManagementTabs/components/KnowledgeTab/KnowledgeTable/KnowledgeTable.tsx @@ -1,31 +1,32 @@ import React, { useEffect, useState } from "react"; +import { KMSElement } from "@/lib/api/sync/types"; +import KnowledgeManagementSystem from 
"@/lib/components/KnowledgeManagementSystem/KnowledgeManagementSystem"; import { Checkbox } from "@/lib/components/ui/Checkbox/Checkbox"; import { Icon } from "@/lib/components/ui/Icon/Icon"; import { QuivrButton } from "@/lib/components/ui/QuivrButton/QuivrButton"; +import { SwitchButton } from "@/lib/components/ui/SwitchButton/SwitchButton"; import { TextInput } from "@/lib/components/ui/TextInput/TextInput"; import { updateSelectedItems } from "@/lib/helpers/table"; import { useDevice } from "@/lib/hooks/useDevice"; -import { isUploadedKnowledge, Knowledge } from "@/lib/types/Knowledge"; -import { useKnowledgeItem } from "./KnowledgeItem/hooks/useKnowledgeItem"; -// eslint-disable-next-line import/order import KnowledgeItem from "./KnowledgeItem/KnowledgeItem"; +import { useKnowledgeItem } from "./KnowledgeItem/hooks/useKnowledgeItem"; import styles from "./KnowledgeTable.module.scss"; interface KnowledgeTableProps { - knowledgeList: Knowledge[]; + knowledgeList: KMSElement[]; } const filterAndSortKnowledge = ( - knowledgeList: Knowledge[], + knowledgeList: KMSElement[], searchQuery: string, sortConfig: { key: string; direction: string } -): Knowledge[] => { +): KMSElement[] => { let filteredList = knowledgeList.filter((knowledge) => - isUploadedKnowledge(knowledge) - ? knowledge.fileName.toLowerCase().includes(searchQuery.toLowerCase()) - : knowledge.url.toLowerCase().includes(searchQuery.toLowerCase()) + (knowledge.file_name ?? knowledge.url ?? "") + .toLowerCase() + .includes(searchQuery.toLowerCase()) ); if (sortConfig.key) { @@ -40,9 +41,9 @@ const filterAndSortKnowledge = ( return 0; }; - const getComparableValue = (item: Knowledge) => { + const getComparableValue = (item: KMSElement) => { if (sortConfig.key === "name") { - return isUploadedKnowledge(item) ? item.fileName : item.url; + return item.url ?? 
item.file_name; } if (sortConfig.key === "status") { return item.status; @@ -52,7 +53,7 @@ const filterAndSortKnowledge = ( }; filteredList = filteredList.sort((a, b) => - compareStrings(getComparableValue(a), getComparableValue(b)) + compareStrings(getComparableValue(a) ?? "", getComparableValue(b) ?? "") ); } @@ -61,7 +62,9 @@ const filterAndSortKnowledge = ( const KnowledgeTable = React.forwardRef( ({ knowledgeList }, ref) => { - const [selectedKnowledge, setSelectedKnowledge] = useState([]); + const [selectedKnowledge, setSelectedKnowledge] = useState( + [] + ); const [lastSelectedIndex, setLastSelectedIndex] = useState( null ); @@ -69,7 +72,9 @@ const KnowledgeTable = React.forwardRef( const [allChecked, setAllChecked] = useState(false); const [searchQuery, setSearchQuery] = useState(""); const [filteredKnowledgeList, setFilteredKnowledgeList] = - useState(knowledgeList); + useState(knowledgeList); + const [allKnowledgeMode, setAllKnowledgeMode] = useState(false); + const { isMobile } = useDevice(); const [sortConfig, setSortConfig] = useState<{ key: string; @@ -83,11 +88,11 @@ const KnowledgeTable = React.forwardRef( }, [searchQuery, knowledgeList, sortConfig]); const handleSelect = ( - knowledge: Knowledge, + knowledge: KMSElement, index: number, event: React.MouseEvent ) => { - const newSelectedKnowledge = updateSelectedItems({ + const newSelectedKnowledge = updateSelectedItems({ item: knowledge, index, event, @@ -125,78 +130,98 @@ const KnowledgeTable = React.forwardRef( return (
- Uploaded Knowledge -
-
- -
- + Uploaded Knowledge +
-
-
-
- { - setAllChecked(checked); - setSelectedKnowledge(checked ? filteredKnowledgeList : []); - }} - /> -
handleSort("name")}> - Name -
- +
+ {!allKnowledgeMode ? ( + <> +
+
+
+
-
-
- {!isMobile && ( +
+
+ { + setAllChecked(checked); + setSelectedKnowledge( + checked ? filteredKnowledgeList : [] + ); + }} + /> +
handleSort("name")} + > + Name +
+ +
+
+
+
+ {!isMobile && ( +
handleSort("status")} + > + Status +
+ +
+
+ )} + Actions +
+
+ {filteredKnowledgeList.map((knowledge, index) => (
handleSort("status")} + key={knowledge.id} + onClick={(event) => handleSelect(knowledge, index, event)} > - Status -
- -
+ item.id === knowledge.id + )} + setSelected={(_selected, event) => + handleSelect(knowledge, index, event) + } + lastChild={index === filteredKnowledgeList.length - 1} + />
- )} - Actions -
-
- {filteredKnowledgeList.map((knowledge, index) => ( -
handleSelect(knowledge, index, event)} - > - item.id === knowledge.id - )} - setSelected={(_selected, event) => - handleSelect(knowledge, index, event) - } - lastChild={index === filteredKnowledgeList.length - 1} - /> + ))} + + ) : ( +
+
- ))} + )}
); diff --git a/frontend/app/studio/[brainId]/BrainManagementTabs/components/KnowledgeTab/hooks/useAddedKnowledge.ts b/frontend/app/studio/[brainId]/BrainManagementTabs/components/KnowledgeTab/hooks/useAddedKnowledge.ts index 984a997b5143..ceb57a5db23c 100644 --- a/frontend/app/studio/[brainId]/BrainManagementTabs/components/KnowledgeTab/hooks/useAddedKnowledge.ts +++ b/frontend/app/studio/[brainId]/BrainManagementTabs/components/KnowledgeTab/hooks/useAddedKnowledge.ts @@ -8,11 +8,11 @@ import { useKnowledgeApi } from "@/lib/api/knowledge/useKnowledgeApi"; export const useAddedKnowledge = ({ brainId }: { brainId?: UUID }) => { const queryClient = useQueryClient(); - const { getAllKnowledge } = useKnowledgeApi(); + const { getAllBrainKnowledge } = useKnowledgeApi(); const fetchKnowledge = () => { if (brainId !== undefined) { - return getAllKnowledge({ brainId }); + return getAllBrainKnowledge({ brainId }); } }; diff --git a/frontend/app/studio/[brainId]/BrainManagementTabs/components/KnowledgeTab/hooks/useKnowledge.ts b/frontend/app/studio/[brainId]/BrainManagementTabs/components/KnowledgeTab/hooks/useKnowledge.ts index a13a0ffed3ac..cd01d9c9dc58 100644 --- a/frontend/app/studio/[brainId]/BrainManagementTabs/components/KnowledgeTab/hooks/useKnowledge.ts +++ b/frontend/app/studio/[brainId]/BrainManagementTabs/components/KnowledgeTab/hooks/useKnowledge.ts @@ -8,10 +8,10 @@ import { useKnowledgeApi } from "@/lib/api/knowledge/useKnowledgeApi"; export const useKnowledge = ({ brainId }: { brainId?: UUID }) => { const queryClient = useQueryClient(); - const { getAllKnowledge } = useKnowledgeApi(); + const { getAllBrainKnowledge } = useKnowledgeApi(); const { data: allKnowledge, isLoading: isPending } = useQuery({ queryKey: brainId ? [getKnowledgeDataKey(brainId)] : [], - queryFn: () => (brainId ? getAllKnowledge({ brainId: brainId }) : []), + queryFn: () => (brainId ? 
getAllBrainKnowledge({ brainId: brainId }) : []), enabled: brainId !== undefined, }); diff --git a/frontend/app/studio/[brainId]/page.module.scss b/frontend/app/studio/[brainId]/page.module.scss index 2c50d8d85ee8..77b4fda2bef9 100644 --- a/frontend/app/studio/[brainId]/page.module.scss +++ b/frontend/app/studio/[brainId]/page.module.scss @@ -2,10 +2,11 @@ .brain_management_wrapper { width: 100%; - height: 100vh; + min-height: 100%; .content_wrapper { padding-block: Spacings.$spacing05; padding-inline: Spacings.$spacing09; + height: 100%; } } diff --git a/frontend/app/studio/[brainId]/page.tsx b/frontend/app/studio/[brainId]/page.tsx index 93d28e3531be..d7945b23e6a4 100644 --- a/frontend/app/studio/[brainId]/page.tsx +++ b/frontend/app/studio/[brainId]/page.tsx @@ -3,9 +3,7 @@ import { useEffect } from "react"; import { PageHeader } from "@/lib/components/PageHeader/PageHeader"; -import { UploadDocumentModal } from "@/lib/components/UploadDocumentModal/UploadDocumentModal"; import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext"; -import { useKnowledgeToFeedContext } from "@/lib/context/KnowledgeToFeedProvider/hooks/useKnowledgeToFeedContext"; import { useSearchModalContext } from "@/lib/context/SearchModalProvider/hooks/useSearchModalContext"; import { ButtonType } from "@/lib/types/QuivrButton"; @@ -30,7 +28,6 @@ const BrainsManagement = (): JSX.Element => { brainId: brain?.id, userAccessibleBrains: allBrains, }); - const { setShouldDisplayFeedCard } = useKnowledgeToFeedContext(); const { setCurrentBrainId } = useBrainContext(); const buttons: ButtonType[] = [ @@ -45,15 +42,6 @@ const BrainsManagement = (): JSX.Element => { }, iconName: "chat", }, - { - label: "Add knowledge", - color: "primary", - onClick: () => { - setShouldDisplayFeedCard(true); - }, - iconName: "uploadFile", - hidden: !isOwnedByCurrentUser || !brain?.max_files, - }, { label: isOwnedByCurrentUser ? 
"Delete Brain" : "Unsubscribe from Brain", color: "dangerous", @@ -88,7 +76,6 @@ const BrainsManagement = (): JSX.Element => {
- { const [selectedTab, setSelectedTab] = useState("Manage my brains"); - const { setShouldDisplayFeedCard } = useKnowledgeToFeedContext(); const { setIsBrainCreationModalOpened } = useBrainCreationContext(); const { allBrains } = useBrainContext(); const { userData } = useUserData(); @@ -50,14 +47,6 @@ const Studio = (): JSX.Element => { tooltip: "You have reached the maximum number of brains allowed. Please upgrade your plan or delete some brains to create a new one.", }, - { - label: "Add knowledge", - color: "primary", - onClick: () => { - setShouldDisplayFeedCard(true); - }, - iconName: "uploadFile", - }, ]); useEffect(() => { @@ -93,7 +82,6 @@ const Studio = (): JSX.Element => { {selectedTab === "Manage my brains" && } {selectedTab === "Analytics" && }
-
); diff --git a/frontend/lib/api/knowledge/knowledge.ts b/frontend/lib/api/knowledge/knowledge.ts index 8f12d94edc15..0eca473b7e87 100644 --- a/frontend/lib/api/knowledge/knowledge.ts +++ b/frontend/lib/api/knowledge/knowledge.ts @@ -2,11 +2,13 @@ import { AxiosInstance } from "axios"; import { UUID } from "crypto"; import { - CrawledKnowledge, - Knowledge, - UploadedKnowledge, + AddFolderData, + AddKnowledgeFileData, + AddKnowledgeUrlData, } from "@/lib/types/Knowledge"; +import { KMSElement } from "../sync/types"; + export type GetAllKnowledgeInputProps = { brainId: UUID; }; @@ -22,43 +24,19 @@ interface BEKnowledge { integration_link: string; } -export const getAllKnowledge = async ( +export const getAllBrainKnowledge = async ( { brainId }: GetAllKnowledgeInputProps, axiosInstance: AxiosInstance -): Promise => { +): Promise => { const response = await axiosInstance.get<{ - knowledges: BEKnowledge[]; + knowledges: KMSElement[]; }>(`/knowledge?brain_id=${brainId}`); - return response.data.knowledges.map((knowledge) => { - if (knowledge.file_name !== null) { - return { - id: knowledge.id, - brainId: knowledge.brain_id, - fileName: knowledge.file_name, - extension: knowledge.extension, - status: knowledge.status, - integration: knowledge.integration, - integration_link: knowledge.integration_link, - } as UploadedKnowledge; - } else if (knowledge.url !== null) { - return { - id: knowledge.id, - brainId: knowledge.brain_id, - url: knowledge.url, - extension: "URL", - status: knowledge.status, - integration: knowledge.integration, - integration_link: knowledge.integration_link, - } as CrawledKnowledge; - } else { - throw new Error(`Invalid knowledge ${knowledge.id}`); - } - }); + return response.data.knowledges; }; export type DeleteKnowledgeInputProps = { - brainId: UUID; + brainId?: UUID; knowledgeId: UUID; }; @@ -79,3 +57,116 @@ export const generateSignedUrlKnowledge = async ( return response.data.signedURL; }; + +export const getFiles = async ( + parentId: UUID | 
null, + axiosInstance: AxiosInstance +): Promise => { + return ( + await axiosInstance.get(`/knowledge/files`, { + params: { parent_id: parentId }, + }) + ).data; +}; + +export const addFolder = async ( + knowledgeData: AddFolderData, + axiosInstance: AxiosInstance +): Promise => { + const formData = new FormData(); + formData.append("knowledge_data", JSON.stringify(knowledgeData)); + formData.append("file", ""); + + return ( + await axiosInstance.post(`/knowledge/`, formData, { + headers: { + "Content-Type": "multipart/form-data", + }, + }) + ).data; +}; + +export const addKnowledgeFile = async ( + knowledgeData: AddKnowledgeFileData, + file: File, + axiosInstance: AxiosInstance +): Promise => { + const formData = new FormData(); + formData.append("knowledge_data", JSON.stringify(knowledgeData)); + formData.append("file", file); + + return ( + await axiosInstance.post(`/knowledge/`, formData, { + headers: { + "Content-Type": "multipart/form-data", + }, + }) + ).data; +}; + +export const addKnowledgeUrl = async ( + knowledgeData: AddKnowledgeUrlData, + axiosInstance: AxiosInstance +): Promise => { + const formData = new FormData(); + formData.append("knowledge_data", JSON.stringify(knowledgeData)); + + return ( + await axiosInstance.post(`/knowledge/`, formData, { + headers: { + "Content-Type": "multipart/form-data", + }, + }) + ).data; +}; + +export const patchKnowledge = async ( + knowledge_id: UUID, + axiosInstance: AxiosInstance, + parent_id: UUID | null +): Promise => { + const data = { + parent_id, + }; + + const response: { data: KMSElement } = await axiosInstance.patch( + `/knowledge/${knowledge_id}`, + data + ); + + return response.data; +}; + +export const linkKnowledgeToBrains = async ( + knowledge: KMSElement, + brainIds: UUID[], + axiosInstance: AxiosInstance +): Promise => { + const response = await axiosInstance.post( + `/knowledge/link_to_brains/`, + { + knowledge, + brain_ids: brainIds, + } + ); + + return response.data; +}; + +export const 
unlinkKnowledgeFromBrains = async ( + knowledge_id: UUID, + brainIds: UUID[], + axiosInstance: AxiosInstance +): Promise => { + const response = await axiosInstance.delete( + `/knowledge/unlink_from_brains/`, + { + data: { + knowledge_id, + brain_ids: brainIds, + }, + } + ); + + return response.data; +}; diff --git a/frontend/lib/api/knowledge/useKnowledgeApi.ts b/frontend/lib/api/knowledge/useKnowledgeApi.ts index 8872377ee8b4..7a882c4b3fbe 100644 --- a/frontend/lib/api/knowledge/useKnowledgeApi.ts +++ b/frontend/lib/api/knowledge/useKnowledgeApi.ts @@ -1,25 +1,55 @@ import { UUID } from "crypto"; import { useAxios } from "@/lib/hooks"; +import { + AddFolderData, + AddKnowledgeFileData, + AddKnowledgeUrlData, +} from "@/lib/types/Knowledge"; import { + addFolder, + addKnowledgeFile, + addKnowledgeUrl, deleteKnowledge, DeleteKnowledgeInputProps, generateSignedUrlKnowledge, - getAllKnowledge, + getAllBrainKnowledge, GetAllKnowledgeInputProps, + getFiles, + linkKnowledgeToBrains, + patchKnowledge, + unlinkKnowledgeFromBrains, } from "./knowledge"; +import { KMSElement } from "../sync/types"; + // eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types export const useKnowledgeApi = () => { const { axiosInstance } = useAxios(); return { - getAllKnowledge: async (props: GetAllKnowledgeInputProps) => - getAllKnowledge(props, axiosInstance), + getAllBrainKnowledge: async (props: GetAllKnowledgeInputProps) => + getAllBrainKnowledge(props, axiosInstance), deleteKnowledge: async (props: DeleteKnowledgeInputProps) => deleteKnowledge(props, axiosInstance), generateSignedUrlKnowledge: async (props: { knowledgeId: UUID }) => generateSignedUrlKnowledge(props, axiosInstance), + getFiles: async (parentId: UUID | null) => + getFiles(parentId, axiosInstance), + addFolder: async (addFolderData: AddFolderData) => + addFolder(addFolderData, axiosInstance), + addKnowledgeFile: async ( + addKnowledgeData: AddKnowledgeFileData, + file: File + ) => 
addKnowledgeFile(addKnowledgeData, file, axiosInstance), + addKnowledgeUrl: async (addKnowledgeData: AddKnowledgeUrlData) => + addKnowledgeUrl(addKnowledgeData, axiosInstance), + patchKnowledge: async (knowledgeId: UUID, parent_id: UUID | null) => + patchKnowledge(knowledgeId, axiosInstance, parent_id), + linkKnowledgeToBrains: async (knowledge: KMSElement, brainIds: UUID[]) => + linkKnowledgeToBrains(knowledge, brainIds, axiosInstance), + unlinkKnowledgeFromBrains: async (knowledge_id: UUID, brainIds: UUID[]) => + unlinkKnowledgeFromBrains(knowledge_id, brainIds, axiosInstance), }; }; diff --git a/frontend/lib/api/sync/sync.ts b/frontend/lib/api/sync/sync.ts index c2a16004772e..ef40ea194966 100644 --- a/frontend/lib/api/sync/sync.ts +++ b/frontend/lib/api/sync/sync.ts @@ -1,19 +1,13 @@ import { AxiosInstance } from "axios"; import { UUID } from "crypto"; -import { - ActiveSync, - OpenedConnection, - Sync, - SyncElement, - SyncElements, -} from "./types"; +import { ActiveSync, KMSElement, OpenedConnection, Sync } from "./types"; -const createFilesSettings = (files: SyncElement[]) => - files.filter((file) => !file.is_folder).map((file) => file.id); +const createFilesSettings = (files: KMSElement[]) => + files.filter((file) => !file.is_folder).map((file) => file.sync_file_id); -const createFoldersSettings = (files: SyncElement[]) => - files.filter((file) => file.is_folder).map((file) => file.id); +const createFoldersSettings = (files: KMSElement[]) => + files.filter((file) => file.is_folder).map((file) => file.sync_file_id); export const syncGoogleDrive = async ( name: string, @@ -57,7 +51,7 @@ export const syncNotion = async ( `/sync/notion/authorize?name=${name}` ) ).data; -} +}; export const getUserSyncs = async ( axiosInstance: AxiosInstance @@ -69,12 +63,12 @@ export const getSyncFiles = async ( axiosInstance: AxiosInstance, userSyncId: number, folderId?: string -): Promise => { +): Promise => { const url = folderId ? 
`/sync/${userSyncId}/files?user_sync_id=${userSyncId}&folder_id=${folderId}` : `/sync/${userSyncId}/files?user_sync_id=${userSyncId}`; - return (await axiosInstance.get(url)).data; + return (await axiosInstance.get(url)).data; }; export const syncFiles = async ( @@ -87,8 +81,8 @@ export const syncFiles = async ( name: openedConnection.name, syncs_user_id: openedConnection.user_sync_id, settings: { - files: createFilesSettings(openedConnection.selectedFiles.files), - folders: createFoldersSettings(openedConnection.selectedFiles.files), + files: createFilesSettings(openedConnection.selectedFiles), + folders: createFoldersSettings(openedConnection.selectedFiles), }, brain_id: brainId, }) @@ -103,8 +97,8 @@ export const updateActiveSync = async ( await axiosInstance.put(`/sync/active/${openedConnection.id}`, { name: openedConnection.name, settings: { - files: createFilesSettings(openedConnection.selectedFiles.files), - folders: createFoldersSettings(openedConnection.selectedFiles.files), + files: createFilesSettings(openedConnection.selectedFiles), + folders: createFoldersSettings(openedConnection.selectedFiles), }, last_synced: openedConnection.last_synced, }) diff --git a/frontend/lib/api/sync/types.ts b/frontend/lib/api/sync/types.ts index 75fae2764adc..c3ae6da5411d 100644 --- a/frontend/lib/api/sync/types.ts +++ b/frontend/lib/api/sync/types.ts @@ -1,3 +1,7 @@ +import { UUID } from "crypto"; + +import { Brain } from "@/lib/context/BrainProvider/types"; + export type Provider = "Google" | "Azure" | "DropBox" | "Notion" | "GitHub"; export type Integration = @@ -9,19 +13,37 @@ export type Integration = export type SyncStatus = "SYNCING" | "SYNCED" | "ERROR" | "REMOVED"; -export interface SyncElement { - name?: string; - id: string; +export type KnowledgeStatus = + | "ERROR" + | "RESERVED" + | "PROCESSING" + | "PROCESSED" + | "UPLOADED"; + +export interface KMSElement { + brains: Brain[]; + id: UUID; + file_name?: string; + sync_file_id: string | null; is_folder: 
boolean; icon?: string; -} - -export interface SyncElements { - files: SyncElement[]; + sync_id: number | null; + parent_id: string | null; + parentKMSElement?: KMSElement; + fromProvider?: SyncsByProvider; + source: "local" | string; + last_synced_at: string; + url?: string; + extension?: string; + status?: KnowledgeStatus; + file_sha1?: string; + source_link?: string; + metadata?: string; } interface Credentials { token: string; + access_token: string; } export interface Sync { @@ -33,6 +55,11 @@ export interface Sync { status: SyncStatus; } +export interface SyncsByProvider { + provider: Provider; + syncs: Sync[]; +} + export interface SyncSettings { folders?: string[]; files?: string[]; @@ -57,7 +84,7 @@ export interface OpenedConnection { id: number | undefined; provider: Provider; submitted: boolean; - selectedFiles: SyncElements; + selectedFiles: KMSElement[]; name: string; last_synced: string; cleaned?: boolean; diff --git a/frontend/lib/api/sync/useSync.ts b/frontend/lib/api/sync/useSync.ts index 141ce71c97d3..bf5ed2d40a09 100644 --- a/frontend/lib/api/sync/useSync.ts +++ b/frontend/lib/api/sync/useSync.ts @@ -13,37 +13,24 @@ import { syncGoogleDrive, syncNotion, syncSharepoint, - updateActiveSync + updateActiveSync, } from "./sync"; -import { Integration, OpenedConnection, Provider } from "./types"; +import { OpenedConnection } from "./types"; // eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types export const useSync = () => { const { axiosInstance } = useAxios(); - const providerIconUrls: Record = { - Google: + const providerIconUrls: Record = { + google: "https://quivr-cms.s3.eu-west-3.amazonaws.com/gdrive_8316d080fd.png", - Azure: + azure: "https://quivr-cms.s3.eu-west-3.amazonaws.com/sharepoint_8c41cfdb09.png", - DropBox: + dropbox: "https://quivr-cms.s3.eu-west-3.amazonaws.com/dropbox_dce4f3d753.png", - Notion: + notion: "https://quivr-cms.s3.eu-west-3.amazonaws.com/Notion_app_logo_004168672c.png", - GitHub: - 
"https://quivr-cms.s3.eu-west-3.amazonaws.com/dropbox_dce4f3d753.png", - }; - - const integrationIconUrls: Record = { - "Google Drive": - "https://quivr-cms.s3.eu-west-3.amazonaws.com/gdrive_8316d080fd.png", - "Share Point": - "https://quivr-cms.s3.eu-west-3.amazonaws.com/sharepoint_8c41cfdb09.png", - Dropbox: - "https://quivr-cms.s3.eu-west-3.amazonaws.com/dropbox_dce4f3d753.png", - Notion: - "https://quivr-cms.s3.eu-west-3.amazonaws.com/Notion_app_logo_004168672c.png", - GitHub: + github: "https://quivr-cms.s3.eu-west-3.amazonaws.com/dropbox_dce4f3d753.png", }; @@ -62,7 +49,6 @@ export const useSync = () => { getUserSyncs: async () => getUserSyncs(axiosInstance), getSyncFiles: async (userSyncId: number, folderId?: string) => getSyncFiles(axiosInstance, userSyncId, folderId), - integrationIconUrls, providerIconUrls, syncFiles: async (openedConnection: OpenedConnection, brainId: UUID) => syncFiles(axiosInstance, openedConnection, brainId), diff --git a/frontend/lib/components/AddBrainModal/AddBrainModal.module.scss b/frontend/lib/components/AddBrainModal/AddBrainModal.module.scss index 6655484a0cf0..26d34a66ad5c 100644 --- a/frontend/lib/components/AddBrainModal/AddBrainModal.module.scss +++ b/frontend/lib/components/AddBrainModal/AddBrainModal.module.scss @@ -9,10 +9,6 @@ overflow: hidden; gap: Spacings.$spacing05; - .stepper_container { - width: 100%; - } - .content_wrapper { flex-grow: 1; overflow: auto; diff --git a/frontend/lib/components/AddBrainModal/AddBrainModal.tsx b/frontend/lib/components/AddBrainModal/AddBrainModal.tsx index bc4517d5526c..798050992ad4 100644 --- a/frontend/lib/components/AddBrainModal/AddBrainModal.tsx +++ b/frontend/lib/components/AddBrainModal/AddBrainModal.tsx @@ -1,54 +1,25 @@ import { useEffect } from "react"; -import { FormProvider, useForm } from "react-hook-form"; import { useTranslation } from "react-i18next"; -import { useFromConnectionsContext } from 
"@/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FromConnectionsProvider/hooks/useFromConnectionContext"; import { Modal } from "@/lib/components/ui/Modal/Modal"; -import { addBrainDefaultValues } from "@/lib/config/defaultBrainConfig"; import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext"; -import { useKnowledgeToFeedContext } from "@/lib/context/KnowledgeToFeedProvider/hooks/useKnowledgeToFeedContext"; import styles from "./AddBrainModal.module.scss"; import { useBrainCreationContext } from "./brainCreation-provider"; -import { BrainMainInfosStep } from "./components/BrainMainInfosStep/BrainMainInfosStep"; -import { BrainRecapStep } from "./components/BrainRecapStep/BrainRecapStep"; -import { FeedBrainStep } from "./components/FeedBrainStep/FeedBrainStep"; -import { Stepper } from "./components/Stepper/Stepper"; -import { useBrainCreationSteps } from "./hooks/useBrainCreationSteps"; -import { CreateBrainProps } from "./types/types"; +import { BrainCreationForm } from "./components/BrainCreationForm/BrainCreationForm"; export const AddBrainModal = (): JSX.Element => { const { t } = useTranslation(["translation", "brain", "config"]); - const { currentStep, steps } = useBrainCreationSteps(); const { setCurrentBrainId } = useBrainContext(); const { setSnippetColor, setSnippetEmoji } = useBrainCreationContext(); const { isBrainCreationModalOpened, setIsBrainCreationModalOpened, - setCurrentSelectedBrain, setCreating, } = useBrainCreationContext(); - const { setCurrentSyncId, setOpenedConnections } = - useFromConnectionsContext(); - const { removeAllKnowledgeToFeed } = useKnowledgeToFeedContext(); - - const defaultValues: CreateBrainProps = { - ...addBrainDefaultValues, - setDefault: true, - brainCreationStep: "FIRST_STEP", - }; - - const methods = useForm({ - defaultValues, - }); useEffect(() => { - setCurrentSelectedBrain(undefined); - setCurrentSyncId(undefined); setCreating(false); - 
setOpenedConnections([]); - methods.reset(defaultValues); - removeAllKnowledgeToFeed(); if (isBrainCreationModalOpened) { setCurrentBrainId(null); setSnippetColor("#d0c6f2"); @@ -57,26 +28,19 @@ export const AddBrainModal = (): JSX.Element => { }, [isBrainCreationModalOpened]); return ( - - } - > -
-
- -
-
- - - -
+ } + > +
+
+
- - +
+
); }; diff --git a/frontend/lib/components/AddBrainModal/brainCreation-provider.tsx b/frontend/lib/components/AddBrainModal/brainCreation-provider.tsx index c5c761bbac75..6c51a161e2e1 100644 --- a/frontend/lib/components/AddBrainModal/brainCreation-provider.tsx +++ b/frontend/lib/components/AddBrainModal/brainCreation-provider.tsx @@ -1,20 +1,10 @@ import { createContext, useContext, useState } from "react"; -import { IntegrationBrains } from "@/lib/api/brain/types"; - -import { StepValue } from "./types/types"; - interface BrainCreationContextProps { isBrainCreationModalOpened: boolean; setIsBrainCreationModalOpened: React.Dispatch>; creating: boolean; setCreating: React.Dispatch>; - currentSelectedBrain: IntegrationBrains | undefined; - setCurrentSelectedBrain: React.Dispatch< - React.SetStateAction - >; - currentStep: StepValue; - setCurrentStep: React.Dispatch>; snippetColor: string; setSnippetColor: React.Dispatch>; snippetEmoji: string; @@ -32,10 +22,7 @@ export const BrainCreationProvider = ({ }): JSX.Element => { const [isBrainCreationModalOpened, setIsBrainCreationModalOpened] = useState(false); - const [currentSelectedBrain, setCurrentSelectedBrain] = - useState(); const [creating, setCreating] = useState(false); - const [currentStep, setCurrentStep] = useState("FIRST_STEP"); const [snippetColor, setSnippetColor] = useState("#d0c6f2"); const [snippetEmoji, setSnippetEmoji] = useState("🧠"); @@ -46,10 +33,6 @@ export const BrainCreationProvider = ({ setIsBrainCreationModalOpened, creating, setCreating, - currentSelectedBrain, - setCurrentSelectedBrain, - currentStep, - setCurrentStep, snippetColor, setSnippetColor, snippetEmoji, diff --git a/frontend/lib/components/AddBrainModal/components/BrainMainInfosStep/BrainMainInfosStep.module.scss b/frontend/lib/components/AddBrainModal/components/BrainCreationForm/BrainCreationForm.module.scss similarity index 94% rename from 
frontend/lib/components/AddBrainModal/components/BrainMainInfosStep/BrainMainInfosStep.module.scss rename to frontend/lib/components/AddBrainModal/components/BrainCreationForm/BrainCreationForm.module.scss index 077d1e17f822..dc14fba768d8 100644 --- a/frontend/lib/components/AddBrainModal/components/BrainMainInfosStep/BrainMainInfosStep.module.scss +++ b/frontend/lib/components/AddBrainModal/components/BrainCreationForm/BrainCreationForm.module.scss @@ -16,10 +16,6 @@ gap: Spacings.$spacing05; position: relative; - .title { - @include Typography.H3; - } - .inputs_wrapper { display: flex; flex-direction: column; @@ -80,6 +76,12 @@ } } } + + .text_areas_wrapper { + display: flex; + flex-direction: column; + gap: Spacings.$spacing06; + } } .buttons_wrapper { diff --git a/frontend/lib/components/AddBrainModal/components/BrainMainInfosStep/BrainMainInfosStep.tsx b/frontend/lib/components/AddBrainModal/components/BrainCreationForm/BrainCreationForm.tsx similarity index 52% rename from frontend/lib/components/AddBrainModal/components/BrainMainInfosStep/BrainMainInfosStep.tsx rename to frontend/lib/components/AddBrainModal/components/BrainCreationForm/BrainCreationForm.tsx index 9466d23dbabb..28d0c5bd431f 100644 --- a/frontend/lib/components/AddBrainModal/components/BrainMainInfosStep/BrainMainInfosStep.tsx +++ b/frontend/lib/components/AddBrainModal/components/BrainCreationForm/BrainCreationForm.tsx @@ -1,55 +1,42 @@ import { useState } from "react"; -import { Controller, useFormContext } from "react-hook-form"; -import { CreateBrainProps } from "@/lib/components/AddBrainModal/types/types"; import { BrainSnippet } from "@/lib/components/BrainSnippet/BrainSnippet"; import { FieldHeader } from "@/lib/components/ui/FieldHeader/FieldHeader"; import { QuivrButton } from "@/lib/components/ui/QuivrButton/QuivrButton"; import { TextAreaInput } from "@/lib/components/ui/TextAreaInput/TextAreaInput"; import { TextInput } from "@/lib/components/ui/TextInput/TextInput"; -import 
styles from "./BrainMainInfosStep.module.scss"; +import styles from "./BrainCreationForm.module.scss"; import { useBrainCreationContext } from "../../brainCreation-provider"; -import { useBrainCreationSteps } from "../../hooks/useBrainCreationSteps"; +import { useBrainCreationApi } from "../hooks/useBrainCreationApi"; -export const BrainMainInfosStep = (): JSX.Element => { +export const BrainCreationForm = (): JSX.Element => { const [editSnippet, setEditSnippet] = useState(false); - const { currentStepIndex, goToNextStep } = useBrainCreationSteps(); + const [name, setName] = useState(""); + const [description, setDescription] = useState(""); + const [instructions, setInstructions] = useState(""); const { snippetColor, setSnippetColor, snippetEmoji, setSnippetEmoji } = useBrainCreationContext(); + const { setCreating, creating } = useBrainCreationContext(); + const { createBrain } = useBrainCreationApi(); - const { watch } = useFormContext(); - const name = watch("name"); - const description = watch("description"); - - const isDisabled = !name || !description; - - const next = (): void => { - goToNextStep(); + const feed = (): void => { + setCreating(true); + createBrain({ name, description }); }; - if (currentStepIndex !== 0) { - return <>; - } - return (
- Define brain identity
- ( - - )} +
@@ -73,32 +60,44 @@ export const BrainMainInfosStep = (): JSX.Element => { />
-
- - ( - - )} - /> +
+
+ + + +
+
+ + + +
next()} + label="Create" + onClick={() => feed()} iconName="chevronRight" - disabled={isDisabled} important={true} + disabled={!name || !description || !instructions} + isLoading={creating} />
diff --git a/frontend/lib/components/AddBrainModal/components/BrainRecapStep/BrainRecapCard/BrainRecapCard.module.scss b/frontend/lib/components/AddBrainModal/components/BrainRecapStep/BrainRecapCard/BrainRecapCard.module.scss deleted file mode 100644 index e48548da8749..000000000000 --- a/frontend/lib/components/AddBrainModal/components/BrainRecapStep/BrainRecapCard/BrainRecapCard.module.scss +++ /dev/null @@ -1,24 +0,0 @@ -@use "styles/BoxShadow.module.scss"; -@use "styles/Radius.module.scss"; -@use "styles/Spacings.module.scss"; -@use "styles/Typography.module.scss"; - -.brain_recap_card_wrapper { - display: flex; - padding: Spacings.$spacing05; - justify-content: center; - box-shadow: BoxShadow.$small; - border-radius: Radius.$normal; - display: flex; - align-items: center; - gap: Spacings.$spacing03; - - .number_label { - @include Typography.Big; - color: var(--primary-0); - } - - .type { - @include Typography.H1; - } -} diff --git a/frontend/lib/components/AddBrainModal/components/BrainRecapStep/BrainRecapCard/BrainRecapCard.tsx b/frontend/lib/components/AddBrainModal/components/BrainRecapStep/BrainRecapCard/BrainRecapCard.tsx deleted file mode 100644 index b6e7d76ddf26..000000000000 --- a/frontend/lib/components/AddBrainModal/components/BrainRecapStep/BrainRecapCard/BrainRecapCard.tsx +++ /dev/null @@ -1,21 +0,0 @@ -import styles from "./BrainRecapCard.module.scss"; - -interface BrainRecapCardProps { - label: string; - number: number; -} - -export const BrainRecapCard = ({ - label, - number, -}: BrainRecapCardProps): JSX.Element => { - return ( -
- {number.toString()} - - {label} - {number > 1 ? "s" : ""} - -
- ); -}; diff --git a/frontend/lib/components/AddBrainModal/components/BrainRecapStep/BrainRecapStep.module.scss b/frontend/lib/components/AddBrainModal/components/BrainRecapStep/BrainRecapStep.module.scss deleted file mode 100644 index 349323a54417..000000000000 --- a/frontend/lib/components/AddBrainModal/components/BrainRecapStep/BrainRecapStep.module.scss +++ /dev/null @@ -1,76 +0,0 @@ -@use "styles/ScreenSizes.module.scss"; -@use "styles/Spacings.module.scss"; -@use "styles/Typography.module.scss"; - -.brain_recap_wrapper { - display: flex; - justify-content: space-between; - flex-direction: column; - height: 100%; - gap: Spacings.$spacing05; - overflow: hidden; - - .content_wrapper { - display: flex; - flex-direction: column; - gap: Spacings.$spacing05; - overflow: auto; - - .title { - @include Typography.H3; - } - - .subtitle { - font-size: Typography.$small; - font-weight: 500; - } - - .warning_message { - font-size: Typography.$small; - } - - .brain_info_wrapper { - display: flex; - flex-direction: column; - gap: Spacings.$spacing05; - - .name_field { - width: 300px; - } - - @media screen and (max-width: ScreenSizes.$small) { - .name_field { - min-width: 100%; - max-width: 100%; - } - } - } - - .cards_wrapper { - display: flex; - flex-wrap: wrap; - padding: Spacings.$spacing01; - justify-content: space-between; - gap: Spacings.$spacing05; - - > * { - min-width: 120px; - max-width: 200px; - flex: 1; - } - - @media screen and (max-width: ScreenSizes.$small) { - > * { - min-width: 100%; - max-width: 100%; - flex: 1; - } - } - } - } - - .buttons_wrapper { - display: flex; - justify-content: space-between; - } -} diff --git a/frontend/lib/components/AddBrainModal/components/BrainRecapStep/BrainRecapStep.tsx b/frontend/lib/components/AddBrainModal/components/BrainRecapStep/BrainRecapStep.tsx deleted file mode 100644 index 93de8f3be980..000000000000 --- a/frontend/lib/components/AddBrainModal/components/BrainRecapStep/BrainRecapStep.tsx +++ /dev/null @@ -1,137 
+0,0 @@ -import { Controller } from "react-hook-form"; - -import { useFromConnectionsContext } from "@/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FromConnectionsProvider/hooks/useFromConnectionContext"; -import { useUserApi } from "@/lib/api/user/useUserApi"; -import { MessageInfoBox } from "@/lib/components/ui/MessageInfoBox/MessageInfoBox"; -import { QuivrButton } from "@/lib/components/ui/QuivrButton/QuivrButton"; -import { TextAreaInput } from "@/lib/components/ui/TextAreaInput/TextAreaInput"; -import { TextInput } from "@/lib/components/ui/TextInput/TextInput"; -import { useKnowledgeToFeedContext } from "@/lib/context/KnowledgeToFeedProvider/hooks/useKnowledgeToFeedContext"; -import { useOnboardingContext } from "@/lib/context/OnboardingProvider/hooks/useOnboardingContext"; -import { useUserData } from "@/lib/hooks/useUserData"; - -import { BrainRecapCard } from "./BrainRecapCard/BrainRecapCard"; -import styles from "./BrainRecapStep.module.scss"; - -import { useBrainCreationContext } from "../../brainCreation-provider"; -import { useBrainCreationSteps } from "../../hooks/useBrainCreationSteps"; -import { useBrainCreationApi } from "../FeedBrainStep/hooks/useBrainCreationApi"; - -export const BrainRecapStep = (): JSX.Element => { - const { currentStepIndex, goToPreviousStep } = useBrainCreationSteps(); - const { creating, setCreating } = useBrainCreationContext(); - const { knowledgeToFeed } = useKnowledgeToFeedContext(); - const { createBrain } = useBrainCreationApi(); - const { updateUserIdentity } = useUserApi(); - const { userIdentityData } = useUserData(); - const { openedConnections } = useFromConnectionsContext(); - const { setIsBrainCreated } = useOnboardingContext(); - - const feed = async (): Promise => { - if (!userIdentityData?.onboarded) { - await updateUserIdentity({ - ...userIdentityData, - username: userIdentityData?.username ?? 
"", - onboarded: true, - }); - } - setCreating(true); - createBrain(); - }; - - const previous = (): void => { - goToPreviousStep(); - }; - - if (currentStepIndex !== 2) { - return <>; - } - - return ( -
-
- - - Depending on the number of knowledge, the upload can take - few minutes. - - - Brain Recap -
-
- ( - - )} - /> -
-
- ( - - )} - /> -
-
- Knowledge From -
- connection.selectedFiles.files.length - ).length - } - /> - knowledge.source === "crawl" - ).length - } - /> - knowledge.source === "upload" - ).length - } - /> -
-
-
- - { - await feed(); - setIsBrainCreated(true); - }} - isLoading={creating} - important={true} - /> -
-
- ); -}; diff --git a/frontend/lib/components/AddBrainModal/components/FeedBrainStep/FeedBrainStep.module.scss b/frontend/lib/components/AddBrainModal/components/FeedBrainStep/FeedBrainStep.module.scss deleted file mode 100644 index 2ebfe6808460..000000000000 --- a/frontend/lib/components/AddBrainModal/components/FeedBrainStep/FeedBrainStep.module.scss +++ /dev/null @@ -1,42 +0,0 @@ -@use "styles/Spacings.module.scss"; -@use "styles/Typography.module.scss"; - -.brain_knowledge_wrapper { - display: flex; - flex-direction: column; - justify-content: space-between; - height: 100%; - - .tutorial { - padding-bottom: Spacings.$spacing05; - } - - .feed_brain { - display: flex; - flex-direction: column; - overflow: auto; - height: 100%; - - .title { - @include Typography.H3; - } - } - - .message_info_box_wrapper { - align-self: center; - display: flex; - - .message_content { - display: flex; - gap: Spacings.$spacing03; - flex-wrap: wrap; - align-items: center; - align-self: center; - } - } - - .buttons_wrapper { - display: flex; - justify-content: space-between; - } -} diff --git a/frontend/lib/components/AddBrainModal/components/FeedBrainStep/FeedBrainStep.tsx b/frontend/lib/components/AddBrainModal/components/FeedBrainStep/FeedBrainStep.tsx deleted file mode 100644 index 887845e89ade..000000000000 --- a/frontend/lib/components/AddBrainModal/components/FeedBrainStep/FeedBrainStep.tsx +++ /dev/null @@ -1,112 +0,0 @@ -import { useEffect, useState } from "react"; - -import { KnowledgeToFeed } from "@/app/chat/[chatId]/components/ActionsBar/components"; -import { useFromConnectionsContext } from "@/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FromConnectionsProvider/hooks/useFromConnectionContext"; -import { OpenedConnection } from "@/lib/api/sync/types"; -import { QuivrButton } from "@/lib/components/ui/QuivrButton/QuivrButton"; -import { useKnowledgeToFeedContext } from 
"@/lib/context/KnowledgeToFeedProvider/hooks/useKnowledgeToFeedContext"; -import { createHandleGetButtonProps } from "@/lib/helpers/handleConnectionButtons"; -import { useUserData } from "@/lib/hooks/useUserData"; - -import styles from "./FeedBrainStep.module.scss"; - -import { useBrainCreationSteps } from "../../hooks/useBrainCreationSteps"; - -export const FeedBrainStep = (): JSX.Element => { - const { currentStepIndex, goToPreviousStep, goToNextStep } = - useBrainCreationSteps(); - const { userIdentityData } = useUserData(); - const { - currentSyncId, - setCurrentSyncId, - openedConnections, - setOpenedConnections, - } = useFromConnectionsContext(); - const [currentConnection, setCurrentConnection] = useState< - OpenedConnection | undefined - >(undefined); - const { knowledgeToFeed } = useKnowledgeToFeedContext(); - - useEffect(() => { - setCurrentConnection( - openedConnections.find( - (connection) => connection.user_sync_id === currentSyncId - ) - ); - }, [currentSyncId]); - - const getButtonProps = createHandleGetButtonProps( - currentConnection, - openedConnections, - setOpenedConnections, - currentSyncId, - setCurrentSyncId - ); - - const renderFeedBrain = () => ( - <> -
- Feed your brain - -
- - ); - - const renderButtons = () => { - const buttonProps = getButtonProps(); - - return ( -
- {currentSyncId ? ( - setCurrentSyncId(undefined)} - /> - ) : ( - - )} - {currentSyncId ? ( - - ) : ( - - )} -
- ); - }; - - if (currentStepIndex !== 1) { - return <>; - } - - return ( -
- {renderFeedBrain()} - {renderButtons()} -
- ); -}; diff --git a/frontend/lib/components/AddBrainModal/components/FeedBrainStep/hooks/useBrainCreationApi.ts b/frontend/lib/components/AddBrainModal/components/FeedBrainStep/hooks/useBrainCreationApi.ts deleted file mode 100644 index 92b5a6e76376..000000000000 --- a/frontend/lib/components/AddBrainModal/components/FeedBrainStep/hooks/useBrainCreationApi.ts +++ /dev/null @@ -1,144 +0,0 @@ -import { useMutation, useQueryClient } from "@tanstack/react-query"; -import { AxiosError } from "axios"; -import { UUID } from "crypto"; -import { useState } from "react"; -import { useFormContext } from "react-hook-form"; -import { useTranslation } from "react-i18next"; -import { v4 as uuidv4 } from "uuid"; - -import { useFromConnectionsContext } from "@/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FromConnectionsProvider/hooks/useFromConnectionContext"; -import { PUBLIC_BRAINS_KEY } from "@/lib/api/brain/config"; -import { IntegrationSettings } from "@/lib/api/brain/types"; -import { useSync } from "@/lib/api/sync/useSync"; -import { CreateBrainProps } from "@/lib/components/AddBrainModal/types/types"; -import { useKnowledgeToFeedInput } from "@/lib/components/KnowledgeToFeedInput/hooks/useKnowledgeToFeedInput.ts"; -import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext"; -import { useKnowledgeToFeedContext } from "@/lib/context/KnowledgeToFeedProvider/hooks/useKnowledgeToFeedContext"; -import { useToast } from "@/lib/hooks"; -import { useKnowledgeToFeedFilesAndUrls } from "@/lib/hooks/useKnowledgeToFeed"; - -import { useBrainCreationContext } from "../../../brainCreation-provider"; - -// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types -export const useBrainCreationApi = () => { - const queryClient = useQueryClient(); - const { publish } = useToast(); - const { t } = useTranslation(["brain", "config"]); - const { files, urls } = useKnowledgeToFeedFilesAndUrls(); - const 
{ getValues, reset } = useFormContext(); - const { setKnowledgeToFeed } = useKnowledgeToFeedContext(); - const { createBrain: createBrainApi, setCurrentBrainId } = useBrainContext(); - const { crawlWebsiteHandler, uploadFileHandler } = useKnowledgeToFeedInput(); - const { - setIsBrainCreationModalOpened, - setCreating, - currentSelectedBrain, - snippetColor, - snippetEmoji, - } = useBrainCreationContext(); - const { setOpenedConnections } = useFromConnectionsContext(); - const [fields, setFields] = useState< - { name: string; type: string; value: string }[] - >([]); - const { syncFiles } = useSync(); - const { openedConnections } = useFromConnectionsContext(); - - const handleFeedBrain = async (brainId: UUID): Promise => { - const crawlBulkId: UUID = uuidv4().toString() as UUID; - const uploadBulkId: UUID = uuidv4().toString() as UUID; - - const uploadPromises = files.map((file) => - uploadFileHandler(file, brainId, uploadBulkId) - ); - - const crawlPromises = urls.map((url) => - crawlWebsiteHandler(url, brainId, crawlBulkId) - ); - - await Promise.all([...uploadPromises, ...crawlPromises]); - await Promise.all( - openedConnections - .filter((connection) => connection.selectedFiles.files.length) - .map(async (openedConnection) => { - await syncFiles(openedConnection, brainId); - }) - ); - setKnowledgeToFeed([]); - }; - - const createBrain = async (): Promise => { - const { name, description } = getValues(); - let integrationSettings: IntegrationSettings | undefined = undefined; - - if (currentSelectedBrain) { - integrationSettings = { - integration_id: currentSelectedBrain.id, - settings: fields.reduce((acc, field) => { - acc[field.name] = field.value; - - return acc; - }, {} as { [key: string]: string }), - }; - } - - const createdBrainId = await createBrainApi({ - brain_type: currentSelectedBrain ? 
"integration" : "doc", - name, - description, - integration: integrationSettings, - snippet_color: snippetColor, - snippet_emoji: snippetEmoji, - }); - - if (createdBrainId === undefined) { - publish({ - variant: "danger", - text: t("errorCreatingBrain", { ns: "brain" }), - }); - - return; - } - - void handleFeedBrain(createdBrainId); - - setCurrentBrainId(createdBrainId); - setIsBrainCreationModalOpened(false); - setCreating(false); - setOpenedConnections([]); - reset(); - - void queryClient.invalidateQueries({ - queryKey: [PUBLIC_BRAINS_KEY], - }); - }; - - const { mutate, isPending: isBrainCreationPending } = useMutation({ - mutationFn: createBrain, - onSuccess: () => { - publish({ - variant: "success", - text: t("brainCreated", { ns: "brain" }), - }); - }, - onError: (error: AxiosError) => { - if (error.response && error.response.status === 429) { - publish({ - variant: "danger", - text: "You have reached your maximum amount of brains. Upgrade your plan to create more.", - }); - } else { - publish({ - variant: "danger", - text: t("errorCreatingBrain", { ns: "brain" }), - }); - } - }, - }); - - return { - createBrain: mutate, - isBrainCreationPending, - fields, - setFields, - }; -}; diff --git a/frontend/lib/components/AddBrainModal/components/Stepper/Stepper.module.scss b/frontend/lib/components/AddBrainModal/components/Stepper/Stepper.module.scss deleted file mode 100644 index f1563704e27f..000000000000 --- a/frontend/lib/components/AddBrainModal/components/Stepper/Stepper.module.scss +++ /dev/null @@ -1,118 +0,0 @@ -@use "styles/Radius.module.scss"; -@use "styles/Spacings.module.scss"; -@use "styles/Typography.module.scss"; - -.stepper_wrapper { - display: flex; - width: 100%; - justify-content: space-between; - overflow: visible; - - .step { - display: flex; - flex-direction: column; - border-radius: Radius.$circle; - position: relative; - - .circle { - width: 1.75rem; - height: 1.75rem; - background-color: var(--primary-0); - border-radius: Radius.$circle; 
- display: flex; - justify-content: center; - align-items: center; - - .inside_circle { - width: 100%; - height: 100%; - border-radius: Radius.$circle; - display: flex; - justify-content: center; - align-items: center; - } - } - - .step_info { - margin-top: Spacings.$spacing03; - display: flex; - flex-direction: column; - font-size: Typography.$tiny; - width: 1.75rem; - align-items: center; - - .step_index { - white-space: nowrap; - color: var(--text-1); - } - } - - &.done_step { - .circle { - background-color: var(--success); - } - - .step_info { - .step_status { - color: var(--success); - } - } - } - - &.current_step { - .circle { - background-color: var(--background-0); - border: 1px solid var(--primary-0); - } - - .inside_circle { - background-color: var(--primary-0); - width: 70%; - height: 70%; - } - - .step_info { - .step_status { - color: var(--primary-0); - } - } - } - - &.pending_step { - .circle { - background-color: var(--primary-1); - } - - .step_info { - .step_status { - color: var(--text-1); - } - } - } - - &:first-child { - .step_info { - align-items: start; - } - } - - &:last-child { - .step_info { - align-items: end; - } - } - } - - .bar { - flex-grow: 1; - height: 4px; - border-radius: Radius.$big; - background-color: var(--primary-1); - margin: 0 8px; - margin-top: Spacings.$spacing04; - - &.done { - background-color: var(--success); - } - } -} diff --git a/frontend/lib/components/AddBrainModal/components/Stepper/Stepper.tsx b/frontend/lib/components/AddBrainModal/components/Stepper/Stepper.tsx deleted file mode 100644 index 1ae3c8717218..000000000000 --- a/frontend/lib/components/AddBrainModal/components/Stepper/Stepper.tsx +++ /dev/null @@ -1,63 +0,0 @@ -import React from "react"; - -import { Icon } from "@/lib/components/ui/Icon/Icon"; - -import styles from "./Stepper.module.scss"; - -import { StepValue } from "../../types/types"; - -interface StepperProps { - currentStep: StepValue; - steps: { value: string; label: string }[]; -} - -export 
const Stepper = ({ currentStep, steps }: StepperProps): JSX.Element => { - const currentStepIndex = steps.findIndex( - (step) => step.value === currentStep - ); - - return ( -
- {steps.map((step, index) => ( - -
-
-
- {index < currentStepIndex && ( - - )} -
-
-
- STEP {index + 1} - - {index === currentStepIndex - ? "Progress" - : index < currentStepIndex - ? "Completed" - : "Pending"} - -
-
- {index < steps.length - 1 && ( -
- )} -
- ))} -
- ); -}; diff --git a/frontend/lib/components/AddBrainModal/components/hooks/useBrainCreationApi.ts b/frontend/lib/components/AddBrainModal/components/hooks/useBrainCreationApi.ts new file mode 100644 index 000000000000..8b660a00926f --- /dev/null +++ b/frontend/lib/components/AddBrainModal/components/hooks/useBrainCreationApi.ts @@ -0,0 +1,97 @@ +import { useMutation, useQueryClient } from "@tanstack/react-query"; +import { AxiosError } from "axios"; +import { useState } from "react"; +import { useTranslation } from "react-i18next"; + +import { PUBLIC_BRAINS_KEY } from "@/lib/api/brain/config"; +import { IntegrationSettings } from "@/lib/api/brain/types"; +import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext"; +import { useToast } from "@/lib/hooks"; + +import { useBrainCreationContext } from "../../brainCreation-provider"; + +// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types +export const useBrainCreationApi = () => { + const queryClient = useQueryClient(); + const { publish } = useToast(); + const { t } = useTranslation(["brain", "config"]); + const { createBrain: createBrainApi, setCurrentBrainId } = useBrainContext(); + const { + setIsBrainCreationModalOpened, + setCreating, + snippetColor, + snippetEmoji, + } = useBrainCreationContext(); + const [fields, setFields] = useState< + { name: string; type: string; value: string }[] + >([]); + + const createBrain = async ( + name: string, + description: string + ): Promise => { + const integrationSettings: IntegrationSettings | undefined = undefined; + + const createdBrainId = await createBrainApi({ + brain_type: "doc", + name, + description, + integration: integrationSettings, + snippet_color: snippetColor, + snippet_emoji: snippetEmoji, + }); + + if (createdBrainId === undefined) { + publish({ + variant: "danger", + text: t("errorCreatingBrain", { ns: "brain" }), + }); + + return; + } + + setCurrentBrainId(createdBrainId); + 
setIsBrainCreationModalOpened(false); + setCreating(false); + + void queryClient.invalidateQueries({ + queryKey: [PUBLIC_BRAINS_KEY], + }); + }; + + const { mutate, isPending: isBrainCreationPending } = useMutation({ + mutationFn: ({ + name, + description, + }: { + name: string; + description: string; + }) => createBrain(name, description), + onSuccess: () => { + publish({ + variant: "success", + text: t("brainCreated", { ns: "brain" }), + }); + }, + onError: (error: AxiosError) => { + if (error.response && error.response.status === 429) { + publish({ + variant: "danger", + text: "You have reached your maximum amount of brains. Upgrade your plan to create more.", + }); + } else { + publish({ + variant: "danger", + text: t("errorCreatingBrain", { ns: "brain" }), + }); + } + }, + }); + + return { + createBrain: mutate, + isBrainCreationPending, + fields, + setFields, + }; +}; diff --git a/frontend/lib/components/AddBrainModal/hooks/useBrainCreationSteps.ts b/frontend/lib/components/AddBrainModal/hooks/useBrainCreationSteps.ts deleted file mode 100644 index 66075576d2fe..000000000000 --- a/frontend/lib/components/AddBrainModal/hooks/useBrainCreationSteps.ts +++ /dev/null @@ -1,66 +0,0 @@ -import { useEffect } from "react"; -import { useTranslation } from "react-i18next"; - -import { Step } from "@/lib/types/Modal"; - -import { useBrainCreationContext } from "../brainCreation-provider"; - -// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types -export const useBrainCreationSteps = () => { - const { t } = useTranslation("brain"); - const { isBrainCreationModalOpened, currentStep, setCurrentStep } = - useBrainCreationContext(); - - const steps: Step[] = [ - { - label: t("brain_type"), - value: "FIRST_STEP", - }, - { - label: t("brain_params"), - value: "SECOND_STEP", - }, - { - label: t("resources"), - value: "THIRD_STEP", - }, - ]; - - const currentStepIndex = steps.findIndex( - (step) => step.value === currentStep - ); - - useEffect(() => { - 
goToFirstStep(); - }, [isBrainCreationModalOpened]); - - const goToNextStep = () => { - if (currentStepIndex === -1 || currentStepIndex === steps.length - 1) { - return; - } - const nextStep = steps[currentStepIndex + 1]; - - return setCurrentStep(nextStep.value); - }; - - const goToPreviousStep = () => { - if (currentStepIndex === -1 || currentStepIndex === 0) { - return; - } - const previousStep = steps[currentStepIndex - 1]; - - return setCurrentStep(previousStep.value); - }; - - const goToFirstStep = () => { - return setCurrentStep(steps[0].value); - }; - - return { - currentStep, - steps, - goToNextStep, - goToPreviousStep, - currentStepIndex, - }; -}; diff --git a/frontend/lib/components/AddBrainModal/types/types.ts b/frontend/lib/components/AddBrainModal/types/types.ts index 6831c00c77f9..33290fe3477c 100644 --- a/frontend/lib/components/AddBrainModal/types/types.ts +++ b/frontend/lib/components/AddBrainModal/types/types.ts @@ -1,15 +1,5 @@ -import { CreateBrainInput } from "@/lib/api/brain/types"; import { iconList } from "@/lib/helpers/iconList"; -const steps = ["FIRST_STEP", "SECOND_STEP", "THIRD_STEP"] as const; - -export type StepValue = (typeof steps)[number]; - -export type CreateBrainProps = CreateBrainInput & { - setDefault: boolean; - brainCreationStep: StepValue; -}; - export interface BrainType { name: string; description: string; diff --git a/frontend/lib/components/ConnectionCards/ConnectionCards.tsx b/frontend/lib/components/ConnectionCards/ConnectionCards.tsx index ae81175a8302..8ed439229585 100644 --- a/frontend/lib/components/ConnectionCards/ConnectionCards.tsx +++ b/frontend/lib/components/ConnectionCards/ConnectionCards.tsx @@ -10,38 +10,34 @@ interface ConnectionCardsProps { export const ConnectionCards = ({ fromAddKnowledge, }: ConnectionCardsProps): JSX.Element => { - const { syncGoogleDrive, syncSharepoint, syncDropbox } = - useSync(); + const { syncGoogleDrive, syncSharepoint, syncDropbox } = useSync(); return (
syncDropbox(name)} - fromAddKnowledge={fromAddKnowledge} /> syncGoogleDrive(name)} - fromAddKnowledge={fromAddKnowledge} /> {/* syncNotion(name)} - fromAddKnowledge={fromAddKnowledge} oneAccountLimitation={true} /> */} syncSharepoint(name)} - fromAddKnowledge={fromAddKnowledge} />
); diff --git a/frontend/lib/components/ConnectionCards/ConnectionSection/ConnectionLine/ConnectionLine.tsx b/frontend/lib/components/ConnectionCards/ConnectionSection/ConnectionLine/ConnectionLine.tsx index 9fc1c57e1a96..7250801eb107 100644 --- a/frontend/lib/components/ConnectionCards/ConnectionSection/ConnectionLine/ConnectionLine.tsx +++ b/frontend/lib/components/ConnectionCards/ConnectionSection/ConnectionLine/ConnectionLine.tsx @@ -1,11 +1,10 @@ import { useState } from "react"; -import { useFromConnectionsContext } from "@/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FromConnectionsProvider/hooks/useFromConnectionContext"; import { useSync } from "@/lib/api/sync/useSync"; import { ConnectionIcon } from "@/lib/components/ui/ConnectionIcon/ConnectionIcon"; import { Icon } from "@/lib/components/ui/Icon/Icon"; import { Modal } from "@/lib/components/ui/Modal/Modal"; -import QuivrButton from "@/lib/components/ui/QuivrButton/QuivrButton"; +import { QuivrButton } from "@/lib/components/ui/QuivrButton/QuivrButton"; import styles from "./ConnectionLine.module.scss"; @@ -26,7 +25,6 @@ export const ConnectionLine = ({ const [deleteModalOpened, setDeleteModalOpened] = useState(false); const { deleteUserSync } = useSync(); - const { setHasToReload } = useFromConnectionsContext(); return ( <> @@ -46,7 +44,6 @@ export const ConnectionLine = ({ setDeleteModalOpened(true); } else { await deleteUserSync(id); - setHasToReload(true); } }} /> @@ -85,7 +82,6 @@ export const ConnectionLine = ({ setDeleteLoading(true); await deleteUserSync(id); setDeleteLoading(false); - setHasToReload(true); setDeleteModalOpened(false); }} /> diff --git a/frontend/lib/components/ConnectionCards/ConnectionSection/ConnectionSection.tsx b/frontend/lib/components/ConnectionCards/ConnectionSection/ConnectionSection.tsx index 3d9b7ae986f5..dcf0cd736a7f 100644 --- a/frontend/lib/components/ConnectionCards/ConnectionSection/ConnectionSection.tsx +++ 
b/frontend/lib/components/ConnectionCards/ConnectionSection/ConnectionSection.tsx @@ -1,49 +1,34 @@ import Image from "next/image"; import { useEffect, useState } from "react"; -import { useFromConnectionsContext } from "@/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FromConnectionsProvider/hooks/useFromConnectionContext"; -import { OpenedConnection, Provider, Sync } from "@/lib/api/sync/types"; +import { Provider, Sync } from "@/lib/api/sync/types"; import { useSync } from "@/lib/api/sync/useSync"; import { QuivrButton } from "@/lib/components/ui/QuivrButton/QuivrButton"; import { iconList } from "@/lib/helpers/iconList"; -import { ConnectionButton } from "./ConnectionButton/ConnectionButton"; import { ConnectionLine } from "./ConnectionLine/ConnectionLine"; import styles from "./ConnectionSection.module.scss"; import { ConnectionIcon } from "../../ui/ConnectionIcon/ConnectionIcon"; import { Icon } from "../../ui/Icon/Icon"; -import { TextButton } from "../../ui/TextButton/TextButton"; import Tooltip from "../../ui/Tooltip/Tooltip"; interface ConnectionSectionProps { label: string; provider: Provider; callback: (name: string) => Promise<{ authorization_url: string }>; - fromAddKnowledge?: boolean; oneAccountLimitation?: boolean; } export const ConnectionSection = ({ label, provider, - fromAddKnowledge, callback, oneAccountLimitation, }: ConnectionSectionProps): JSX.Element => { - const { providerIconUrls, getUserSyncs, getSyncFiles } = useSync(); - const { - setCurrentSyncElements, - setCurrentSyncId, - setOpenedConnections, - openedConnections, - hasToReload, - setHasToReload, - setLoadingFirstList, - setCurrentProvider, - } = useFromConnectionsContext(); + const { providerIconUrls, getUserSyncs } = useSync(); const [existingConnections, setExistingConnections] = useState([]); - const [folded, setFolded] = useState(!fromAddKnowledge); + const [folded, setFolded] = useState(true); const fetchUserSyncs = async () => 
{ try { @@ -52,7 +37,7 @@ export const ConnectionSection = ({ res.filter( (sync) => Object.keys(sync.credentials).length !== 0 && - sync.provider === provider + sync.provider.toLowerCase() === provider.toLowerCase() ) ); } catch (error) { @@ -94,48 +79,6 @@ export const ConnectionSection = ({ }; }, []); - useEffect(() => { - if (hasToReload) { - void fetchUserSyncs(); - setHasToReload(false); - } - }, [hasToReload]); - - const handleOpenedConnections = (userSyncId: number) => { - const existingConnection = openedConnections.find( - (connection) => connection.user_sync_id === userSyncId - ); - - if (!existingConnection) { - const newConnection: OpenedConnection = { - name: - existingConnections.find((connection) => connection.id === userSyncId) - ?.name ?? "", - user_sync_id: userSyncId, - id: undefined, - provider: provider, - submitted: false, - selectedFiles: { files: [] }, - last_synced: "", - }; - - setOpenedConnections([...openedConnections, newConnection]); - } - }; - - const handleGetSyncFiles = async (userSyncId: number) => { - try { - setLoadingFirstList(true); - const res = await getSyncFiles(userSyncId); - setLoadingFirstList(false); - setCurrentSyncElements(res); - setCurrentSyncId(userSyncId); - handleOpenedConnections(userSyncId); - } catch (error) { - console.error("Failed to get sync files:", error); - } - }; - const connect = async () => { const res = await callback( Math.random().toString(36).substring(2, 15) + @@ -186,45 +129,21 @@ export const ConnectionSection = ({ return null; } - if (!fromAddKnowledge) { - return ( -
-
- Connected accounts - setFolded(!folded)} - /> -
- {renderConnectionLines(activeConnections, folded)} -
- ); - } else { - return ( -
- {activeConnections.map((connection, index) => ( - - openedConnection.name === connection.name && - openedConnection.submitted - )} - onClick={() => { - void handleGetSyncFiles(connection.id); - setCurrentProvider(connection.provider); - }} - sync={connection} - /> - ))} + return ( +
+
+ Connected accounts + setFolded(!folded)} + />
- ); - } + {renderConnectionLines(activeConnections, folded)} +
+ ); }; return ( @@ -233,15 +152,14 @@ export const ConnectionSection = ({
{label} {label}
- {!fromAddKnowledge && - (!oneAccountLimitation || existingConnections.length === 0) ? ( + {!oneAccountLimitation || existingConnections.length === 0 ? ( ) : null} - - {fromAddKnowledge && - (!oneAccountLimitation || existingConnections.length === 0) && ( - - )}
{renderExistingConnections()}
diff --git a/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/CurrentFolderExplorer.module.scss b/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/CurrentFolderExplorer.module.scss new file mode 100644 index 000000000000..ed8fe13f17da --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/CurrentFolderExplorer.module.scss @@ -0,0 +1,7 @@ +.current_folder_explorer_container { + display: flex; + width: 100%; + height: 100%; + background-color: var(--background-1); + overflow: auto; +} diff --git a/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/CurrentFolderExplorer.tsx b/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/CurrentFolderExplorer.tsx new file mode 100644 index 000000000000..483a01f19ced --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/CurrentFolderExplorer.tsx @@ -0,0 +1,34 @@ +import { useEffect } from "react"; + +import styles from "./CurrentFolderExplorer.module.scss"; +import ProviderCurrentFolder from "./ProviderCurrentFolder/ProviderCurrentFolder"; +import QuivrCurrentFolder from "./QuivrCurrentFolder/QuivrCurrentFolder"; + +import { useKnowledgeContext } from "../../KnowledgeProvider/hooks/useKnowledgeContext"; + +interface CurrentFolderExplorerProps { + fromBrainStudio?: boolean; +} + +const CurrentFolderExplorer = ({ + fromBrainStudio, +}: CurrentFolderExplorerProps): JSX.Element => { + const { exploringQuivr, exploredProvider, setExploringQuivr } = + useKnowledgeContext(); + + useEffect(() => { + setExploringQuivr(true); + }, []); + + return ( +
+ {exploredProvider || !exploringQuivr ? ( + + ) : ( + + )} +
+ ); +}; + +export default CurrentFolderExplorer; diff --git a/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/ProviderCurrentFolder/ProviderAccount/ProviderAccount.module.scss b/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/ProviderCurrentFolder/ProviderAccount/ProviderAccount.module.scss new file mode 100644 index 000000000000..f2ee160906de --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/ProviderCurrentFolder/ProviderAccount/ProviderAccount.module.scss @@ -0,0 +1,25 @@ +@use "styles/Spacings.module.scss"; +@use "styles/Typography.module.scss"; + +.main_container { + display: flex; + align-items: center; + padding-inline: Spacings.$spacing06; + padding-block: Spacings.$spacing03; + border-bottom: 1px solid var(--border-1); + font-size: Typography.$small; + display: flex; + justify-content: space-between; + align-items: center; + + .left { + display: flex; + align-items: center; + gap: Spacings.$spacing03; + } + + &:hover { + cursor: pointer; + background-color: var(--background-2); + } +} diff --git a/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/ProviderCurrentFolder/ProviderAccount/ProviderAccount.tsx b/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/ProviderCurrentFolder/ProviderAccount/ProviderAccount.tsx new file mode 100644 index 000000000000..8698b76f7b17 --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/ProviderCurrentFolder/ProviderAccount/ProviderAccount.tsx @@ -0,0 +1,38 @@ +import { Sync } from "@/lib/api/sync/types"; +import { useKnowledgeContext } from "@/lib/components/KnowledgeManagementSystem/KnowledgeProvider/hooks/useKnowledgeContext"; +import { ConnectionIcon } from "@/lib/components/ui/ConnectionIcon/ConnectionIcon"; +import { Icon } from "@/lib/components/ui/Icon/Icon"; + +import styles from 
"./ProviderAccount.module.scss"; + +interface ProviderAccountProps { + sync: Sync; + index: number; +} + +const ProviderAccount = ({ + sync, + index, +}: ProviderAccountProps): JSX.Element => { + const { setExploredSpecificAccount } = useKnowledgeContext(); + + return ( +
setExploredSpecificAccount(sync)} + > +
+ + {sync.email} +
+ +
+ ); +}; + +export default ProviderAccount; diff --git a/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/ProviderCurrentFolder/ProviderCurrentFolder.module.scss b/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/ProviderCurrentFolder/ProviderCurrentFolder.module.scss new file mode 100644 index 000000000000..c86a2a69b00e --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/ProviderCurrentFolder/ProviderCurrentFolder.module.scss @@ -0,0 +1,29 @@ +@use "styles/Spacings.module.scss"; + +.main_container { + height: 100%; + width: 100%; + + .current_folder_content { + display: flex; + flex-direction: column; + overflow: scroll; + + .content_header { + border-bottom: 1px solid var(--border-1); + padding-inline: Spacings.$spacing06; + padding-block: Spacings.$spacing05; + display: flex; + justify-content: space-between; + width: 100%; + } + + .loading_icon { + width: 100%; + display: flex; + align-items: center; + justify-content: center; + padding-top: Spacings.$spacing06; + } + } +} diff --git a/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/ProviderCurrentFolder/ProviderCurrentFolder.tsx b/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/ProviderCurrentFolder/ProviderCurrentFolder.tsx new file mode 100644 index 000000000000..da654de777d3 --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/ProviderCurrentFolder/ProviderCurrentFolder.tsx @@ -0,0 +1,136 @@ +import { useEffect, useState } from "react"; + +import { KMSElement, Sync } from "@/lib/api/sync/types"; +import { useSync } from "@/lib/api/sync/useSync"; +import { useKnowledgeContext } from "@/lib/components/KnowledgeManagementSystem/KnowledgeProvider/hooks/useKnowledgeContext"; +import { LoaderIcon } from "@/lib/components/ui/LoaderIcon/LoaderIcon"; +import { QuivrButton } from 
"@/lib/components/ui/QuivrButton/QuivrButton"; + +import ProviderAccount from "./ProviderAccount/ProviderAccount"; +import styles from "./ProviderCurrentFolder.module.scss"; + +import AddToBrainsModal from "../../shared/CurrentFolderExplorerLine/ConnectedBrains/AddToBrainsModal/AddToBrainsModal"; +import CurrentFolderExplorerLine from "../../shared/CurrentFolderExplorerLine/CurrentFolderExplorerLine"; +import FolderExplorerHeader from "../../shared/FolderExplorerHeader/FolderExplorerHeader"; + +interface ProviderCurrentFolderExplorerProps { + fromBrainStudio?: boolean; +} + +const ProviderCurrentFolder = ({ + fromBrainStudio, +}: ProviderCurrentFolderExplorerProps): JSX.Element => { + const [providerRootElements, setproviderRootElements] = + useState(); + const [loading, setLoading] = useState(false); + const [showAddToBrainsModal, setShowAddToBrainsModal] = + useState(false); + const { + exploredProvider, + currentFolder, + exploredSpecificAccount, + selectedKnowledges, + setSelectedKnowledges, + } = useKnowledgeContext(); + const { getSyncFiles } = useSync(); + + const fetchCurrentFolderElements = (sync: Sync) => { + setLoading(true); + void (async () => { + try { + const res = await getSyncFiles( + sync.id, + currentFolder?.sync_file_id ?? 
undefined + ); + setproviderRootElements(res); + setLoading(false); + setSelectedKnowledges([]); + } catch (error) { + console.error("Failed to get sync files:", error); + } + })(); + }; + + useEffect(() => { + if (!showAddToBrainsModal && exploredProvider) { + if (exploredProvider.syncs.length === 1) { + void fetchCurrentFolderElements(exploredProvider.syncs[0]); + } else if (exploredSpecificAccount) { + void fetchCurrentFolderElements(exploredSpecificAccount); + } + } + }, [showAddToBrainsModal]); + + useEffect(() => { + if (exploredProvider) { + if (exploredProvider.syncs.length === 1) { + void fetchCurrentFolderElements(exploredProvider.syncs[0]); + } else if (exploredSpecificAccount) { + void fetchCurrentFolderElements(exploredSpecificAccount); + } + } + }, [currentFolder, exploredProvider, exploredSpecificAccount]); + + return ( +
+ +
+ {exploredProvider?.syncs && + !exploredSpecificAccount && + exploredProvider.syncs.length > 1 ? ( + exploredProvider.syncs.map((sync, index) => ( +
+ +
+ )) + ) : loading ? ( +
+ +
+ ) : ( + <> + {!fromBrainStudio && ( +
+ setShowAddToBrainsModal(true)} + small={true} + disabled={!selectedKnowledges.length} + /> +
+ )} +
+ {providerRootElements + ?.sort((a, b) => Number(b.is_folder) - Number(a.is_folder)) + .map((element, index) => ( + + ))} +
+ + )} +
+ {showAddToBrainsModal && ( +
e.stopPropagation()} + > + setShowAddToBrainsModal(false)} + knowledges={selectedKnowledges} + /> +
+ )} +
+ ); +}; + +export default ProviderCurrentFolder; diff --git a/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/QuivrCurrentFolder/AddFolderModal/AddFolderModal.module.scss b/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/QuivrCurrentFolder/AddFolderModal/AddFolderModal.module.scss new file mode 100644 index 000000000000..fbceb2127b4f --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/QuivrCurrentFolder/AddFolderModal/AddFolderModal.module.scss @@ -0,0 +1,13 @@ +@use "styles/Spacings.module.scss"; + +.modal_content { + display: flex; + flex-direction: column; + gap: Spacings.$spacing05; + padding-top: Spacings.$spacing05; + + .buttons_wrapper { + display: flex; + justify-content: space-between; + } +} diff --git a/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/QuivrCurrentFolder/AddFolderModal/AddFolderModal.tsx b/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/QuivrCurrentFolder/AddFolderModal/AddFolderModal.tsx new file mode 100644 index 000000000000..194ce25e534d --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/QuivrCurrentFolder/AddFolderModal/AddFolderModal.tsx @@ -0,0 +1,92 @@ +import { useState } from "react"; + +import { useKnowledgeApi } from "@/lib/api/knowledge/useKnowledgeApi"; +import { useKnowledgeContext } from "@/lib/components/KnowledgeManagementSystem/KnowledgeProvider/hooks/useKnowledgeContext"; +import { Modal } from "@/lib/components/ui/Modal/Modal"; +import { QuivrButton } from "@/lib/components/ui/QuivrButton/QuivrButton"; +import { TextInput } from "@/lib/components/ui/TextInput/TextInput"; + +import styles from "./AddFolderModal.module.scss"; + +interface AddFolderModalProps { + isOpen: boolean; + setIsOpen: (isOpen: boolean) => void; +} + +const AddFolderModal = ({ + isOpen, + setIsOpen, +}: 
AddFolderModalProps): JSX.Element => { + const [folderName, setFolderName] = useState(""); + const [loading, setLoading] = useState(false); + + const { currentFolder, setRefetchFolderMenu } = useKnowledgeContext(); + const { addFolder } = useKnowledgeApi(); + + const handleKeyDown = async (event: React.KeyboardEvent) => { + if (event.key === "Enter" && folderName !== "") { + await createFolder(); + } + }; + + const createFolder = async () => { + if (folderName !== "") { + setLoading(true); + await addFolder({ + parent_id: currentFolder?.id ?? null, + file_name: folderName, + is_folder: true, + }); + setRefetchFolderMenu(true); + setFolderName(""); + setIsOpen(false); + setLoading(false); + } + }; + + const handleCancel = () => { + setFolderName(""); + setIsOpen(false); + }; + + return ( +
+ } + CloseTrigger={
} + > +
+ void handleKeyDown(event)} + /> +
+ + +
+
+ +
+ ); +}; + +export default AddFolderModal; diff --git a/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/QuivrCurrentFolder/AddKnowledgeModal/AddKnowledgeModal.module.scss b/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/QuivrCurrentFolder/AddKnowledgeModal/AddKnowledgeModal.module.scss new file mode 100644 index 000000000000..7e7b3443e56e --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/QuivrCurrentFolder/AddKnowledgeModal/AddKnowledgeModal.module.scss @@ -0,0 +1,60 @@ +@use "styles/Radius.module.scss"; +@use "styles/Spacings.module.scss"; +@use "styles/Typography.module.scss"; +@use "styles/Variables.module.scss"; + +.modal_content { + max-height: calc(100% - Spacings.$spacing07); + min-height: calc(100% - Spacings.$spacing07); + display: flex; + flex-direction: column; + justify-content: space-between; + gap: Spacings.$spacing05; + overflow: hidden; + + .top { + display: flex; + flex-direction: column; + gap: Spacings.$spacing03; + padding-top: Spacings.$spacing05; + flex: 1; + max-height: 100%; + overflow: hidden; + + .inputs_wrapper { + padding-top: Spacings.$spacing04; + height: Variables.$fileInputHeight; + } + + .list_header { + margin-top: Spacings.$spacing06; + display: flex; + justify-content: flex-end; + } + + .file_list { + border-top: 1px solid var(--border-0); + overflow: auto; + max-height: 100%; + + &.empty { + border: none; + } + + .file_item { + font-size: Typography.$tiny; + border-bottom: 1px solid var(--border-0); + padding-block: Spacings.$spacing03; + padding-inline: Spacings.$spacing05; + display: flex; + align-items: center; + gap: Spacings.$spacing03; + } + } + } + + .buttons_wrapper { + display: flex; + justify-content: space-between; + } +} diff --git a/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/QuivrCurrentFolder/AddKnowledgeModal/AddKnowledgeModal.tsx 
b/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/QuivrCurrentFolder/AddKnowledgeModal/AddKnowledgeModal.tsx new file mode 100644 index 000000000000..bc88d33720d8 --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/QuivrCurrentFolder/AddKnowledgeModal/AddKnowledgeModal.tsx @@ -0,0 +1,233 @@ +import { useEffect, useState } from "react"; + +import { useKnowledgeApi } from "@/lib/api/knowledge/useKnowledgeApi"; +import { useKnowledgeContext } from "@/lib/components/KnowledgeManagementSystem/KnowledgeProvider/hooks/useKnowledgeContext"; +import { Checkbox } from "@/lib/components/ui/Checkbox/Checkbox"; +import { FileInput } from "@/lib/components/ui/FileInput/FileInput"; +import { Modal } from "@/lib/components/ui/Modal/Modal"; +import { QuivrButton } from "@/lib/components/ui/QuivrButton/QuivrButton"; +import { Tabs } from "@/lib/components/ui/Tabs/Tabs"; +import { TextInput } from "@/lib/components/ui/TextInput/TextInput"; +import { + AddKnowledgeFileData, + AddKnowledgeUrlData, +} from "@/lib/types/Knowledge"; +import { Tab } from "@/lib/types/Tab"; + +import styles from "./AddKnowledgeModal.module.scss"; + +interface AddKnowledgeModalProps { + isOpen: boolean; + setIsOpen: (isOpen: boolean) => void; +} + +const AddKnowledgeModal = ({ + isOpen, + setIsOpen, +}: AddKnowledgeModalProps): JSX.Element => { + const [loading, setLoading] = useState(false); + const [currentUrl, setCurrentUrl] = useState(""); + const [files, setFiles] = useState([]); + const [urls, setUrls] = useState([]); + const [selectedTab, setSelectedTab] = useState("Documents"); + const [selectedKnowledges, setSelectedKnowledges] = useState< + (File | string)[] + >([]); + const { addKnowledgeFile, addKnowledgeUrl } = useKnowledgeApi(); + const { currentFolder, setRefetchFolderExplorer } = useKnowledgeContext(); + + const FILE_TYPES = ["pdf", "docx", "doc", "txt"]; + + const tabs: Tab[] = [ + { + label: "Documents", + 
isSelected: selectedTab === "Documents", + onClick: () => setSelectedTab("Documents"), + iconName: "file", + }, + { + label: "Websites' pages", + isSelected: selectedTab === "Websites", + onClick: () => setSelectedTab("Websites"), + iconName: "link", + }, + ]; + + const handleAddKnowledge = async () => { + setLoading(true); + try { + await Promise.all( + files.map(async (file) => { + try { + await addKnowledgeFile( + { + file_name: file.name, + parent_id: currentFolder?.id ?? null, + is_folder: false, + } as AddKnowledgeFileData, + file + ); + } catch (error) { + console.error("Failed to add knowledge:", error); + } + }) + ); + + await Promise.all( + urls.map(async (url) => { + try { + await addKnowledgeUrl({ + url: url, + parent_id: currentFolder?.id ?? null, + is_folder: false, + } as AddKnowledgeUrlData); + } catch (error) { + console.error("Failed to add knowledge from URL:", error); + } + }) + ); + } catch (error) { + console.error("Failed to add all knowledges:", error); + } finally { + setLoading(false); + setIsOpen(false); + setFiles([]); + setUrls([]); + setSelectedKnowledges([]); + setCurrentUrl(""); + setRefetchFolderExplorer(true); + } + }; + + const handleCancel = () => { + setIsOpen(false); + }; + + const handleFileChange = (newFiles: File[]) => { + setFiles((prevFiles) => [...prevFiles, ...newFiles]); + }; + + const handleCheckboxChange = (item: File | string, checked: boolean) => { + if (checked) { + setSelectedKnowledges([...selectedKnowledges, item]); + } else { + setSelectedKnowledges(selectedKnowledges.filter((f) => f !== item)); + } + }; + + const handleRemoveSelectedItems = () => { + setFiles(files.filter((file) => !selectedKnowledges.includes(file))); + setUrls(urls.filter((url) => !selectedKnowledges.includes(url))); + setSelectedKnowledges([]); + }; + + useEffect(() => { + if (!isOpen) { + setFiles([]); + setUrls([]); + setSelectedKnowledges([]); + setCurrentUrl(""); + } + }, [isOpen]); + + return ( +
+ } + CloseTrigger={
} + > +
+
+ +
+ {selectedTab === "Documents" && ( + + )} + {selectedTab === "Websites" && ( +
+ { + setCurrentUrl(""); + setUrls((prevUrls) => [...prevUrls, currentUrl]); + }} + /> +
+ )} +
+ {(!!files.length || !!urls.length) && ( +
+ +
+ )} +
+ {files.map((file, index) => ( +
+ + handleCheckboxChange(file, checked) + } + /> + {file.name} +
+ ))} + {urls.map((url, index) => ( +
+ handleCheckboxChange(url, checked)} + /> + {url} +
+ ))} +
+
+
+ + +
+
+ +
+ ); +}; + +export default AddKnowledgeModal; diff --git a/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/QuivrCurrentFolder/QuivrCurrentFolder.module.scss b/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/QuivrCurrentFolder/QuivrCurrentFolder.module.scss new file mode 100644 index 000000000000..519de1c530cc --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/QuivrCurrentFolder/QuivrCurrentFolder.module.scss @@ -0,0 +1,41 @@ +@use "styles/Spacings.module.scss"; + +.main_container { + height: 100%; + width: 100%; + + .current_folder_content { + display: flex; + flex-direction: column; + overflow: scroll; + + .content_header { + border-bottom: 1px solid var(--border-1); + padding-inline: Spacings.$spacing06; + padding-block: Spacings.$spacing05; + display: flex; + justify-content: space-between; + width: 100%; + + .right { + display: flex; + align-items: center; + gap: Spacings.$spacing03; + } + + .left { + display: flex; + align-items: center; + gap: Spacings.$spacing03; + } + } + + .loading_icon { + width: 100%; + display: flex; + align-items: center; + justify-content: center; + padding-top: Spacings.$spacing06; + } + } +} diff --git a/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/QuivrCurrentFolder/QuivrCurrentFolder.tsx b/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/QuivrCurrentFolder/QuivrCurrentFolder.tsx new file mode 100644 index 000000000000..e89496bf3463 --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Explorer/CurrentFolderExplorer/QuivrCurrentFolder/QuivrCurrentFolder.tsx @@ -0,0 +1,235 @@ +import { UUID } from "crypto"; +import { useEffect, useState } from "react"; + +import { useKnowledgeApi } from "@/lib/api/knowledge/useKnowledgeApi"; +import { KMSElement } from "@/lib/api/sync/types"; +import { LoaderIcon } from 
"@/lib/components/ui/LoaderIcon/LoaderIcon"; +import { QuivrButton } from "@/lib/components/ui/QuivrButton/QuivrButton"; +import { handleDragOver, handleDrop } from "@/lib/helpers/kms"; + +import AddFolderModal from "./AddFolderModal/AddFolderModal"; +import AddKnowledgeModal from "./AddKnowledgeModal/AddKnowledgeModal"; +import styles from "./QuivrCurrentFolder.module.scss"; + +import { useKnowledgeContext } from "../../../KnowledgeProvider/hooks/useKnowledgeContext"; +import AddToBrainsModal from "../../shared/CurrentFolderExplorerLine/ConnectedBrains/AddToBrainsModal/AddToBrainsModal"; +import CurrentFolderExplorerLine from "../../shared/CurrentFolderExplorerLine/CurrentFolderExplorerLine"; +import FolderExplorerHeader from "../../shared/FolderExplorerHeader/FolderExplorerHeader"; + +interface QuivrCurrentFolderExplorerProps { + fromBrainStudio?: boolean; +} + +const QuivrCurrentFolder = ({ + fromBrainStudio, +}: QuivrCurrentFolderExplorerProps): JSX.Element => { + const [loading, setLoading] = useState(false); + const [deleteLoading, setDeleteLoading] = useState(false); + const [addFolderModalOpened, setAddFolderModalOpened] = useState(false); + const [addKnowledgeModalOpened, setAddKnowledgeModalOpened] = useState(false); + const [quivrElements, setQuivrElements] = useState(); + const [showAddToBrainsModal, setShowAddToBrainsModal] = + useState(false); + const { + currentFolder, + exploringQuivr, + selectedKnowledges, + setSelectedKnowledges, + setRefetchFolderMenu, + setRefetchFolderExplorer, + refetchFolderExplorer, + } = useKnowledgeContext(); + const { getFiles, deleteKnowledge, patchKnowledge } = useKnowledgeApi(); + + const fetchQuivrFiles = async (folderId: UUID | null) => { + setLoading(true); + try { + const res = await getFiles(folderId); + setQuivrElements(res); + } catch (error) { + console.error("Failed to get sync files:", error); + } finally { + setLoading(false); + } + }; + + const deleteKnowledges = async () => { + setDeleteLoading(true); + 
try { + await Promise.all( + selectedKnowledges.map((knowledge) => + deleteKnowledge({ knowledgeId: knowledge.id }) + ) + ); + await fetchQuivrFiles(currentFolder?.id ?? null); + setSelectedKnowledges([]); + } catch (error) { + console.error("Failed to delete knowledges:", error); + } finally { + setDeleteLoading(false); + setRefetchFolderMenu(true); + } + }; + + useEffect(() => { + if (!showAddToBrainsModal) { + setSelectedKnowledges([]); + } + }, [showAddToBrainsModal]); + + useEffect(() => { + if (exploringQuivr) { + void fetchQuivrFiles(currentFolder?.id ?? null); + setSelectedKnowledges([]); + } + }, [currentFolder]); + + useEffect(() => { + if (!addFolderModalOpened) { + void fetchQuivrFiles(currentFolder?.id ?? null); + } + }, [addFolderModalOpened]); + + useEffect(() => { + if (refetchFolderExplorer) { + void fetchQuivrFiles(currentFolder?.id ?? null); + setRefetchFolderExplorer(false); + } + }, [refetchFolderExplorer]); + + useEffect(() => { + const handleFetchQuivrFilesMissing = (event: CustomEvent) => { + void fetchQuivrFiles( + (event.detail as { draggedElement: KMSElement }).draggedElement + .parentKMSElement?.id ?? null + ); + }; + + window.addEventListener( + "needToFetch", + handleFetchQuivrFilesMissing as EventListener + ); + + return () => { + window.removeEventListener( + "needToFetch", + handleFetchQuivrFilesMissing as EventListener + ); + }; + }, []); + + const handleDragStart = ( + event: React.DragEvent, + element: KMSElement + ) => { + event.dataTransfer.setData("application/json", JSON.stringify(element)); + }; + + return ( + <> +
+ +
+ {loading ? ( +
+ +
+ ) : ( + <> + {!fromBrainStudio && ( +
+
+ void deleteKnowledges()} + small={true} + isLoading={deleteLoading} + disabled={!selectedKnowledges.length} + /> + setShowAddToBrainsModal(true)} + small={true} + disabled={!selectedKnowledges.length} + /> +
+
+ setAddFolderModalOpened(true)} + small={true} + /> + setAddKnowledgeModalOpened(true)} + small={true} + /> +
+
+ )} + {quivrElements + ?.sort((a, b) => Number(b.is_folder) - Number(a.is_folder)) + .map((element, index) => ( +
+ + void handleDrop({ + event, + targetElement: element, + patchKnowledge, + setRefetchFolderMenu, + fetchQuivrFiles, + currentFolder, + }) + : undefined + } + onDragOver={ + element.is_folder ? handleDragOver : undefined + } + /> +
+ ))} + + )} +
+
+ + + {showAddToBrainsModal && ( +
e.stopPropagation()} + > + setShowAddToBrainsModal(false)} + knowledges={selectedKnowledges} + /> +
+ )} + + ); +}; + +export default QuivrCurrentFolder; diff --git a/frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/CurrentFolderExplorerLine/ConnectedBrains/AddToBrainsModal/AddToBrainsModal.module.scss b/frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/CurrentFolderExplorerLine/ConnectedBrains/AddToBrainsModal/AddToBrainsModal.module.scss new file mode 100644 index 000000000000..4d679e87de71 --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/CurrentFolderExplorerLine/ConnectedBrains/AddToBrainsModal/AddToBrainsModal.module.scss @@ -0,0 +1,43 @@ +@use "styles/Radius.module.scss"; +@use "styles/Spacings.module.scss"; +@use "styles/Typography.module.scss"; + +.content_wrapper { + display: flex; + flex-direction: column; + gap: Spacings.$spacing05; + + .brains_list { + display: flex; + flex-direction: column; + margin-top: Spacings.$spacing05; + border-top: 1px solid var(--border-0); + + .brain_line { + display: flex; + align-items: center; + gap: Spacings.$spacing03; + font-size: Typography.$small; + padding-block: Spacings.$spacing03; + border-bottom: 1px solid var(--border-0); + + .sample_wrapper { + display: flex; + justify-content: center; + align-items: center; + border-radius: Radius.$normal; + min-width: 18px; + min-height: 18px; + max-width: 18px; + max-height: 18px; + border-radius: Radius.$normal; + font-size: Typography.$very_tiny; + } + } + } + + .button { + display: flex; + justify-content: flex-end; + } +} diff --git a/frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/CurrentFolderExplorerLine/ConnectedBrains/AddToBrainsModal/AddToBrainsModal.tsx b/frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/CurrentFolderExplorerLine/ConnectedBrains/AddToBrainsModal/AddToBrainsModal.tsx new file mode 100644 index 000000000000..0ddc734f70b5 --- /dev/null +++ 
b/frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/CurrentFolderExplorerLine/ConnectedBrains/AddToBrainsModal/AddToBrainsModal.tsx @@ -0,0 +1,166 @@ +import { useEffect, useState } from "react"; + +import { useKnowledgeApi } from "@/lib/api/knowledge/useKnowledgeApi"; +import { KMSElement } from "@/lib/api/sync/types"; +import { useKnowledgeContext } from "@/lib/components/KnowledgeManagementSystem/KnowledgeProvider/hooks/useKnowledgeContext"; +import { Checkbox } from "@/lib/components/ui/Checkbox/Checkbox"; +import { Modal } from "@/lib/components/ui/Modal/Modal"; +import { QuivrButton } from "@/lib/components/ui/QuivrButton/QuivrButton"; +import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext"; +import { Brain } from "@/lib/context/BrainProvider/types"; + +import styles from "./AddToBrainsModal.module.scss"; + +interface AddToBrainsModalProps { + isOpen: boolean; + setIsOpen: (isOpen: boolean) => void; + knowledges?: KMSElement[]; +} + +const AddToBrainsModal = ({ + isOpen, + setIsOpen, + knowledges, +}: AddToBrainsModalProps): JSX.Element => { + const [selectedBrains, setSelectedBrains] = useState([]); + const [initialBrains, setInitialBrains] = useState([]); + const [saveLoading, setSaveLoading] = useState(false); + + const { allBrains } = useBrainContext(); + const { linkKnowledgeToBrains, unlinkKnowledgeFromBrains } = + useKnowledgeApi(); + const { setRefetchFolderExplorer } = useKnowledgeContext(); + + useEffect(() => { + if (knowledges?.length === 1) { + const initialSelectedBrains = allBrains.filter((brain) => + knowledges[0].brains.some((kb) => kb.brain_id === brain.id) + ); + setSelectedBrains(initialSelectedBrains); + setInitialBrains(initialSelectedBrains); + } + }, [knowledges, allBrains]); + + const handleCheckboxChange = (brain: Brain, checked: boolean) => { + setSelectedBrains((prevSelectedBrains) => { + return checked + ? 
[...prevSelectedBrains, brain] + : prevSelectedBrains.filter((b) => b.id !== brain.id); + }); + }; + + const hasChanges = () => { + if (selectedBrains.length !== initialBrains.length) { + return true; + } + const selectedBrainIds = selectedBrains.map((b) => b.id).sort(); + const initialBrainIds = initialBrains.map((b) => b.id).sort(); + + return !selectedBrainIds.every( + (id, index) => id === initialBrainIds[index] + ); + }; + + const connectBrains = async (knowledgesToConnect: KMSElement[]) => { + const brainIdsToLink = selectedBrains.map((brain) => brain.id); + + try { + setSaveLoading(true); + await Promise.all( + knowledgesToConnect.map((knowledgeToConnect) => + linkKnowledgeToBrains(knowledgeToConnect, brainIdsToLink) + ) + ); + } catch (error) { + console.error("Failed to connect brains to knowledge", error); + } finally { + setSaveLoading(false); + setIsOpen(false); + setRefetchFolderExplorer(true); + } + }; + + const updateConnectedBrains = async (knowledge: KMSElement) => { + const knowledgeId = knowledge.id; + const brainIdsToLink = selectedBrains.map((brain) => brain.id); + const brainIdsToUnlink = initialBrains + .filter((brain) => !selectedBrains.some((b) => b.id === brain.id)) + .map((brain) => brain.id); + + try { + setSaveLoading(true); + if (brainIdsToLink.length > 0) { + await linkKnowledgeToBrains(knowledge, brainIdsToLink); + } + if (brainIdsToUnlink.length > 0) { + await unlinkKnowledgeFromBrains(knowledgeId, brainIdsToUnlink); + } + } catch (error) { + console.error("Failed to update knowledge to brains", error); + } finally { + setSaveLoading(false); + setIsOpen(false); + setRefetchFolderExplorer(true); + } + }; + + return ( +
event.stopPropagation()} + > + } + CloseTrigger={
} + > +
+
+ {allBrains + .filter((brain) => brain.brain_type !== "model") + .map((brain) => ( +
+
+ b.id === brain.id)} + setChecked={(checked) => + handleCheckboxChange(brain, checked) + } + /> +
+ {brain.snippet_emoji} +
+ {brain.name} +
+
+ ))} +
+
+ + knowledges && + (knowledges.length === 1 + ? updateConnectedBrains(knowledges[0]) + : connectBrains(knowledges)) + } + isLoading={saveLoading} + /> +
+
+ +
+ ); +}; + +export default AddToBrainsModal; diff --git a/frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/CurrentFolderExplorerLine/ConnectedBrains/ConnectedBrains.module.scss b/frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/CurrentFolderExplorerLine/ConnectedBrains/ConnectedBrains.module.scss new file mode 100644 index 000000000000..589b4e522293 --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/CurrentFolderExplorerLine/ConnectedBrains/ConnectedBrains.module.scss @@ -0,0 +1,73 @@ +@use "styles/Radius.module.scss"; +@use "styles/Spacings.module.scss"; +@use "styles/Typography.module.scss"; + +@mixin brain-container { + cursor: pointer; + position: relative; + + .sample_wrapper { + display: flex; + justify-content: center; + align-items: center; + border-radius: Radius.$normal; + min-width: 18px; + min-height: 18px; + max-width: 18px; + max-height: 18px; + font-size: Typography.$very_tiny; + + &.waiting { + opacity: 0.4; + } + } +} + +@mixin waiting-icon { + min-width: 18px; + min-height: 18px; + max-width: 18px; + max-height: 18px; + position: absolute; + display: flex; + justify-content: center; + align-items: center; + top: 0; + left: 0; + opacity: 1; +} + +.main_container { + display: flex; + align-items: center; + gap: Spacings.$spacing02; + + .brain_container { + @include brain-container; + } + + .waiting_icon { + @include waiting-icon; + } + + .more_brains { + cursor: pointer; + } +} + +.remaining_brains_tooltip { + display: flex; + gap: Spacings.$spacing02; + + .brain_container { + @include brain-container; + } + + .waiting_icon { + @include waiting-icon; + } +} + +.modal_content { + display: none; +} diff --git a/frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/CurrentFolderExplorerLine/ConnectedBrains/ConnectedBrains.tsx b/frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/CurrentFolderExplorerLine/ConnectedBrains/ConnectedBrains.tsx new file mode 
100644 index 000000000000..bf97b401fccd --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/CurrentFolderExplorerLine/ConnectedBrains/ConnectedBrains.tsx @@ -0,0 +1,171 @@ +import { UUID } from "crypto"; +import { useRouter } from "next/navigation"; +import { useState } from "react"; + +import { KMSElement, KnowledgeStatus } from "@/lib/api/sync/types"; +import { Icon } from "@/lib/components/ui/Icon/Icon"; +import { LoaderIcon } from "@/lib/components/ui/LoaderIcon/LoaderIcon"; +import Tooltip from "@/lib/components/ui/Tooltip/Tooltip"; +import { Brain } from "@/lib/context/BrainProvider/types"; + +import AddToBrainsModal from "./AddToBrainsModal/AddToBrainsModal"; +import styles from "./ConnectedBrains.module.scss"; + +interface ConnectedbrainsProps { + connectedBrains: Brain[]; + knowledge?: KMSElement; +} + +interface RemainingBrainsTooltipProps { + remainingBrains: Brain[]; + navigateToBrain: (brainId: UUID) => void; + isKnowledgeStatusWaiting: (status?: KnowledgeStatus) => boolean; + knowledgeStatus?: KnowledgeStatus; +} + +const RemainingBrainsTooltip = ({ + remainingBrains, + navigateToBrain, + isKnowledgeStatusWaiting, + knowledgeStatus, +}: RemainingBrainsTooltipProps): JSX.Element => { + return ( +
+ {remainingBrains.map((brain) => ( + +
{ + navigateToBrain(brain.brain_id ?? brain.id); + }} + > +
+ {brain.snippet_emoji} +
+
+
+ ))} +
+ ); +}; + +const ConnectedBrains = ({ + connectedBrains, + knowledge, +}: ConnectedbrainsProps): JSX.Element => { + const [showAddToBrainModal, setShowAddToBrainModal] = + useState(false); + const [showRemainingBrains, setShowRemainingBrains] = + useState(false); + const router = useRouter(); + + const navigateToBrain = (brainId: UUID) => { + router.push(`/studio/${brainId}`); + }; + + const isKnowledgeStatusWaiting = (status?: KnowledgeStatus): boolean => { + return status === "RESERVED" || status === "PROCESSING"; + }; + + const handleAddClick = (event: React.MouseEvent) => { + event.stopPropagation(); + event.preventDefault(); + setShowAddToBrainModal(true); + }; + + const handleModalClose = () => { + setShowAddToBrainModal(false); + }; + + const brainsToShow = connectedBrains.slice(0, 5); + const remainingBrains = connectedBrains.slice(5); + const showMore = connectedBrains.length > 5; + + return ( + <> +
+ {brainsToShow.map((brain) => ( + + <> +
{ + navigateToBrain(brain.brain_id ?? brain.id); + }} + > +
+ {brain.snippet_emoji} +
+ {isKnowledgeStatusWaiting(knowledge?.status) && ( +
+ +
+ )} + {knowledge?.status === "ERROR" && ( +
+ +
+ )} +
+ +
+ ))} + {showMore && ( + <> + {showRemainingBrains && ( +
+ +
+ )} +
setShowRemainingBrains(!showRemainingBrains)} + > + ... +
+ + )} + +
+ +
+
+
+ {showAddToBrainModal && ( +
e.stopPropagation()} + > + +
+ )} + + ); +}; + +export default ConnectedBrains; diff --git a/frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/CurrentFolderExplorerLine/CurrentFolderExplorerLine.module.scss b/frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/CurrentFolderExplorerLine/CurrentFolderExplorerLine.module.scss new file mode 100644 index 000000000000..28e935842945 --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/CurrentFolderExplorerLine/CurrentFolderExplorerLine.module.scss @@ -0,0 +1,74 @@ +@use "styles/Spacings.module.scss"; +@use "styles/Typography.module.scss"; + +.folder_explorer_line_wrapper { + display: flex; + align-items: center; + padding-inline: Spacings.$spacing06; + padding-block: Spacings.$spacing03; + font-size: Typography.$small; + display: flex; + justify-content: space-between; + align-items: center; + border: 1px solid transparent; + border-bottom: 1px solid var(--border-1); + width: 100%; + overflow: hidden; + gap: Spacings.$spacing05; + + &.dragged { + border-color: var(--accent); + } + + .left { + display: flex; + flex: 1; + align-items: center; + gap: Spacings.$spacing03; + overflow: hidden; + + .checkbox { + padding-right: Spacings.$spacing03; + } + + .file_icon { + height: 12px; + width: 12px; + } + + .name { + @include Typography.EllipsisOverflow; + flex: 1; + + &.url { + &:hover { + cursor: pointer; + color: var(--primary-0); + } + } + } + } + + .right { + display: flex; + align-items: center; + gap: Spacings.$spacing03; + + .visible { + visibility: visible; + } + + .hidden { + visibility: hidden; + } + } + + &.folder { + font-weight: 550; + + &:hover { + cursor: pointer; + background-color: var(--background-2); + } + } +} diff --git a/frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/CurrentFolderExplorerLine/CurrentFolderExplorerLine.tsx b/frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/CurrentFolderExplorerLine/CurrentFolderExplorerLine.tsx new file 
mode 100644 index 000000000000..c390f9249572 --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/CurrentFolderExplorerLine/CurrentFolderExplorerLine.tsx @@ -0,0 +1,170 @@ +import { useState } from "react"; + +import { KMSElement } from "@/lib/api/sync/types"; +import { Checkbox } from "@/lib/components/ui/Checkbox/Checkbox"; +import { Icon } from "@/lib/components/ui/Icon/Icon"; +import { iconList } from "@/lib/helpers/iconList"; + +import ConnectedBrains from "./ConnectedBrains/ConnectedBrains"; +import styles from "./CurrentFolderExplorerLine.module.scss"; + +import { useKnowledgeContext } from "../../../KnowledgeProvider/hooks/useKnowledgeContext"; + +interface CurrentFolderExplorerLineProps { + element: KMSElement; + onDragStart?: ( + event: React.DragEvent, + element: KMSElement + ) => void; + onDrop?: ( + event: React.DragEvent, + element: KMSElement + ) => void; + onDragOver?: (event: React.DragEvent) => void; +} + +const getFileType = (fileName?: string): string => { + return fileName?.includes(".") + ? fileName.split(".").pop()?.toLowerCase() ?? "default" + : "default"; +}; + +const getIconColor = (fileType: string): string => { + const iconColors: { [key: string]: string } = { + pdf: "#E44A4D", + + csv: "#4EB35E", + xlsx: "#4EB35E", + xls: "#4EB35E", + + docx: "#47A8EF", + doc: "#47A8EF", + docm: "#47A8EF", + + png: "#A36BAD", + jpg: "#A36BAD", + + pptx: "#F07114", + ppt: "#F07114", + + mp3: "#FFC220", + mp4: "#FFC220", + wav: "#FFC220", + + html: "#F16529", + py: "#F16529", + }; + + return iconColors[fileType.toLowerCase()] ?? "#B1B9BE"; +}; + +const getIconName = (element: KMSElement, fileType: string): string => { + return element.url + ? "link" + : element.is_folder + ? "folder" + : fileType !== "default" + ? iconList[fileType.toLocaleLowerCase()] + ? 
fileType.toLowerCase() + : "file" + : "file"; +}; + +const CurrentFolderExplorerLine = ({ + element, + onDragStart, + onDrop, + onDragOver, +}: CurrentFolderExplorerLineProps): JSX.Element => { + const { setCurrentFolder } = useKnowledgeContext(); + const { selectedKnowledges, setSelectedKnowledges } = useKnowledgeContext(); + const [isDraggedOver, setIsDraggedOver] = useState(false); + + const fileType = getFileType(element.file_name); + + const handleCheckboxChange = (checked: boolean) => { + if (checked) { + setSelectedKnowledges([...selectedKnowledges, element]); + } else { + setSelectedKnowledges( + selectedKnowledges.filter((knowledge) => knowledge.id !== element.id) + ); + } + }; + + const handleClick = () => { + if (element.is_folder) { + setCurrentFolder({ + ...element, + parentKMSElement: element.parentKMSElement, + }); + } + }; + + const handleDrop = (event: React.DragEvent) => { + onDrop?.(event, element); + setIsDraggedOver(false); + }; + + const handleDragOver = (event: React.DragEvent) => { + onDragOver?.(event); + setIsDraggedOver(true); + }; + + return ( +
onDragStart?.(event, element) + : undefined + } + onDrop={element.source === "local" ? handleDrop : undefined} + onDragOver={element.source === "local" ? handleDragOver : undefined} + onDragLeave={() => setIsDraggedOver(false)} + > +
+
+ handleCheckboxChange(checked)} + /> +
+ + { + if (element.url) { + event.stopPropagation(); + window.open(element.url, "_blank"); + } + }} + > + {element.file_name ?? element.url} + +
+
+ +
+ +
+
+
+ ); +}; + +export default CurrentFolderExplorerLine; diff --git a/frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/FolderExplorerHeader/FolderExplorerHeader.module.scss b/frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/FolderExplorerHeader/FolderExplorerHeader.module.scss new file mode 100644 index 000000000000..52eefbeab055 --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/FolderExplorerHeader/FolderExplorerHeader.module.scss @@ -0,0 +1,80 @@ +@use "styles/BoxShadow.module.scss"; +@use "styles/Radius.module.scss"; +@use "styles/Spacings.module.scss"; + +.header_wrapper { + padding-inline: Spacings.$spacing06; + padding-block: Spacings.$spacing05; + border-bottom: 1px solid var(--border-0); + display: flex; + align-items: center; + gap: Spacings.$spacing03; + font-weight: 550; + background-color: var(--background-0); + height: 64px; + + .name { + display: flex; + align-items: center; + gap: Spacings.$spacing03; + + &.hoverable { + &:hover { + cursor: pointer; + color: var(--primary-0); + } + } + + .hoverable { + &:hover { + cursor: pointer; + color: var(--primary-0); + } + } + + .selected { + background-color: var(--background-6); + padding-block: Spacings.$spacing02; + padding-inline: Spacings.$spacing03; + border-radius: Radius.$normal; + } + } + + .parent_folder { + display: flex; + align-items: center; + gap: Spacings.$spacing03; + + .icon { + min-width: 18px; + max-width: 18px; + min-height: 18px; + max-height: 18px; + display: flex; + align-items: center; + justify-content: center; + } + + .name { + &:hover { + cursor: pointer; + color: var(--primary-0); + } + } + } + + .current_folder { + display: flex; + align-items: center; + gap: Spacings.$spacing03; + + .name { + &.selected { + background-color: var(--background-6); + padding-block: Spacings.$spacing02; + padding-inline: Spacings.$spacing03; + border-radius: Radius.$normal; + } + } + } +} diff --git 
a/frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/FolderExplorerHeader/FolderExplorerHeader.tsx b/frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/FolderExplorerHeader/FolderExplorerHeader.tsx new file mode 100644 index 000000000000..fd59f2223a5d --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Explorer/shared/FolderExplorerHeader/FolderExplorerHeader.tsx @@ -0,0 +1,185 @@ +"use client"; + +import { KMSElement, Sync } from "@/lib/api/sync/types"; +import { Icon } from "@/lib/components/ui/Icon/Icon"; +import { transformConnectionLabel } from "@/lib/helpers/providers"; + +import styles from "./FolderExplorerHeader.module.scss"; + +import { useKnowledgeContext } from "../../../KnowledgeProvider/hooks/useKnowledgeContext"; + +interface QuivrHeaderProps { + currentFolder: KMSElement | undefined; + loadRoot: () => void; +} + +const QuivrHeader = ({ currentFolder, loadRoot }: QuivrHeaderProps) => ( + <> + void loadRoot()} + > + Quivr + + {currentFolder && } + +); + +interface ProviderHeaderProps { + currentFolder: KMSElement | undefined; + exploredSpecificAccount: Sync | undefined; + loadRoot: () => void; + exploredProvider: { provider: string } | undefined; + setCurrentFolder: (folder: KMSElement | undefined) => void; +} + +const ProviderHeader = ({ + currentFolder, + exploredProvider, + loadRoot, + exploredSpecificAccount, + setCurrentFolder, +}: ProviderHeaderProps) => ( +
+ {!currentFolder ? ( + <> + loadRoot()} + > + {transformConnectionLabel(exploredProvider?.provider ?? "")} + + {!!exploredSpecificAccount && ( + + )} + {exploredSpecificAccount && ( + setCurrentFolder(undefined)} + > + {exploredSpecificAccount.email} + + )} + + ) : ( + <> + {exploredSpecificAccount && ( + setCurrentFolder(undefined)}> + {exploredSpecificAccount.email} + + )} + {!currentFolder.parentKMSElement && ( + <> + {!exploredSpecificAccount && ( + loadRoot()}> + {transformConnectionLabel(exploredProvider?.provider ?? "")} + + )} + + + )} + + )} +
+); + +interface ParentFolderHeaderProps { + currentFolder: KMSElement; + loadParentFolder: () => void; +} + +const ParentFolderHeader = ({ + currentFolder, + loadParentFolder, +}: ParentFolderHeaderProps) => ( +
+ void loadParentFolder()}> + {currentFolder.parentKMSElement?.file_name?.replace(/(\..+)$/, "")} + + +
+); + +interface CurrentFolderHeaderProps { + currentFolder: KMSElement | undefined; + exploringQuivr: boolean; + exploredSpecificAccount?: Sync; +} + +const CurrentFolderHeader = ({ + currentFolder, + exploringQuivr, + exploredSpecificAccount, +}: CurrentFolderHeaderProps) => ( +
+ {currentFolder?.icon && ( +
{currentFolder.icon}
+ )} + + {currentFolder?.file_name?.replace(/(\..+)$/, "")} + +
+); + +const FolderExplorerHeader = (): JSX.Element => { + const { + currentFolder, + setCurrentFolder, + exploringQuivr, + exploredProvider, + setExploredSpecificAccount, + exploredSpecificAccount, + } = useKnowledgeContext(); + + const loadParentFolder = () => { + if (currentFolder?.parentKMSElement) { + setCurrentFolder({ + ...currentFolder.parentKMSElement, + parentKMSElement: currentFolder.parentKMSElement.parentKMSElement, + }); + } + }; + + const loadRoot = () => { + setCurrentFolder(undefined); + setExploredSpecificAccount(undefined); + }; + + return ( +
+ {exploringQuivr && !currentFolder?.parentKMSElement ? ( + + ) : (exploredProvider || exploredSpecificAccount) && + !currentFolder?.parentKMSElement ? ( + + ) : ( + currentFolder?.parentKMSElement && ( + + ) + )} + +
+ ); +}; + +export default FolderExplorerHeader; diff --git a/frontend/lib/components/KnowledgeManagementSystem/KnowledgeManagementSystem.module.scss b/frontend/lib/components/KnowledgeManagementSystem/KnowledgeManagementSystem.module.scss new file mode 100644 index 000000000000..5619f7e60fd2 --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/KnowledgeManagementSystem.module.scss @@ -0,0 +1,66 @@ +@use "styles/BoxShadow.module.scss"; +@use "styles/Spacings.module.scss"; +@use "styles/Typography.module.scss"; + +.content_wrapper { + display: flex; + min-height: 100%; + flex: 1; + + &.from_brain_studio { + border-top: 1px solid var(--border-0); + } + + .folders_wrapper { + padding-inline: Spacings.$spacing05; + padding-block: Spacings.$spacing05; + display: flex; + flex-direction: column; + gap: Spacings.$spacing06; + box-shadow: BoxShadow.$large; + + .folders { + display: flex; + flex-direction: column; + gap: Spacings.$spacing04; + overflow-x: hidden; + + .quivr_folder { + padding-top: Spacings.$spacing01; + padding-bottom: calc(Spacings.$spacing05 + 3px); + border-bottom: 1px solid var(--border-0); + } + } + } + + $resizeWidth: 8px; + + .resize_wrapper { + min-height: 100%; + width: $resizeWidth; + cursor: col-resize; + background-color: var(--background-0); + + .resize_handle { + height: 100%; + margin-left: calc(#{$resizeWidth} - 1px); + width: 1px; + cursor: col-resize; + background-color: var(--border-0); + } + + &:hover { + background-color: var(--background-3); + + .resize_handle { + background-color: var(--background-3); + } + } + } + + .folder_content { + flex: 1; + overflow-y: hidden; + background-color: red; + } +} diff --git a/frontend/lib/components/KnowledgeManagementSystem/KnowledgeManagementSystem.tsx b/frontend/lib/components/KnowledgeManagementSystem/KnowledgeManagementSystem.tsx new file mode 100644 index 000000000000..f70024dd597e --- /dev/null +++ 
b/frontend/lib/components/KnowledgeManagementSystem/KnowledgeManagementSystem.tsx @@ -0,0 +1,93 @@ +"use client"; + +import { useEffect, useRef, useState } from "react"; + +import CurrentFolderExplorer from "./Explorer/CurrentFolderExplorer/CurrentFolderExplorer"; +import styles from "./KnowledgeManagementSystem.module.scss"; +import { KnowledgeProvider } from "./KnowledgeProvider/knowledge-provider"; +import ConnectionsKnowledges from "./Menu/ConnectionsKnowledge/ConnectionsKnowledges"; +import QuivrKnowledges from "./Menu/QuivrKnowledge/QuivrKnowledges"; + +interface KnowledgeManagementSystemProps { + fromBrainStudio?: boolean; +} + +const KnowledgeManagementSystem = ({ + fromBrainStudio, +}: KnowledgeManagementSystemProps): JSX.Element => { + const [isResizing, setIsResizing] = useState(false); + const [foldersWidth, setFoldersWidth] = useState(400); + const [initialMouseX, setInitialMouseX] = useState(0); + const [initialWidth, setInitialWidth] = useState(400); + const resizeHandleRef = useRef(null); + const foldersRef = useRef(null); + + useEffect(() => { + const handleMouseMove = (e: MouseEvent) => { + if (isResizing && foldersRef.current) { + const newWidth = initialWidth + (e.clientX - initialMouseX); + setFoldersWidth(newWidth < 200 ? 200 : newWidth > 900 ? 900 : newWidth); + } + }; + + const handleMouseUp = () => { + setIsResizing(false); + document.body.style.userSelect = ""; + }; + + if (isResizing) { + window.addEventListener("mousemove", handleMouseMove); + window.addEventListener("mouseup", handleMouseUp); + document.body.style.userSelect = "none"; + } + + return () => { + window.removeEventListener("mousemove", handleMouseMove); + window.removeEventListener("mouseup", handleMouseUp); + document.body.style.userSelect = ""; + }; + }, [isResizing, initialMouseX, initialWidth]); + + return ( + +
+
+
+
+ +
+ +
+
+
{ + setIsResizing(true); + setInitialMouseX(e.clientX); + setInitialWidth(foldersWidth); + }} + > +
+
+
+ +
+
+
+ ); +}; + +export default KnowledgeManagementSystem; diff --git a/frontend/lib/components/KnowledgeManagementSystem/KnowledgeProvider/hooks/useKnowledgeContext.tsx b/frontend/lib/components/KnowledgeManagementSystem/KnowledgeProvider/hooks/useKnowledgeContext.tsx new file mode 100644 index 000000000000..3813d71cd7aa --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/KnowledgeProvider/hooks/useKnowledgeContext.tsx @@ -0,0 +1,15 @@ +import { useContext } from "react"; + +import { KnowledgeContext } from "../knowledge-provider"; + +// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types +export const useKnowledgeContext = () => { + const context = useContext(KnowledgeContext); + if (context === undefined) { + throw new Error( + "useKnowledgeContext must be used within a KnowledgeProvider" + ); + } + + return context; +}; diff --git a/frontend/lib/components/KnowledgeManagementSystem/KnowledgeProvider/knowledge-provider.tsx b/frontend/lib/components/KnowledgeManagementSystem/KnowledgeProvider/knowledge-provider.tsx new file mode 100644 index 000000000000..4739b0d515ba --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/KnowledgeProvider/knowledge-provider.tsx @@ -0,0 +1,76 @@ +import { createContext, useState } from "react"; + +import { KMSElement, Sync, SyncsByProvider } from "@/lib/api/sync/types"; + +type KnowledgeContextType = { + currentFolder: KMSElement | undefined; + setCurrentFolder: React.Dispatch< + React.SetStateAction + >; + exploringQuivr: boolean; + setExploringQuivr: React.Dispatch>; + exploredProvider: SyncsByProvider | undefined; + setExploredProvider: React.Dispatch< + React.SetStateAction + >; + exploredSpecificAccount: Sync | undefined; + setExploredSpecificAccount: React.Dispatch< + React.SetStateAction + >; + selectedKnowledges: KMSElement[]; + setSelectedKnowledges: React.Dispatch>; + refetchFolderMenu: boolean; + setRefetchFolderMenu: React.Dispatch>; + refetchFolderExplorer: boolean; 
+ setRefetchFolderExplorer: React.Dispatch>; +}; + +export const KnowledgeContext = createContext( + undefined +); + +export const KnowledgeProvider = ({ + children, +}: { + children: React.ReactNode; +}): JSX.Element => { + const [currentFolder, setCurrentFolder] = useState( + undefined + ); + const [exploringQuivr, setExploringQuivr] = useState(false); + const [exploredProvider, setExploredProvider] = useState< + SyncsByProvider | undefined + >(undefined); + const [exploredSpecificAccount, setExploredSpecificAccount] = useState< + Sync | undefined + >(undefined); + const [selectedKnowledges, setSelectedKnowledges] = useState( + [] + ); + const [refetchFolderMenu, setRefetchFolderMenu] = useState(false); + const [refetchFolderExplorer, setRefetchFolderExplorer] = + useState(false); + + return ( + + {children} + + ); +}; diff --git a/frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionKnowledges/ConnectionAccount/ConnectionAccount.module.scss b/frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionKnowledges/ConnectionAccount/ConnectionAccount.module.scss new file mode 100644 index 000000000000..fc194aedf6be --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionKnowledges/ConnectionAccount/ConnectionAccount.module.scss @@ -0,0 +1,58 @@ +@use "styles/Radius.module.scss"; +@use "styles/Spacings.module.scss"; +@use "styles/Typography.module.scss"; + +.account_section_wrapper { + display: flex; + flex-direction: column; + gap: Spacings.$spacing03; + + .account_line_wrapper { + display: flex; + align-items: center; + gap: Spacings.$spacing02; + padding-left: Spacings.$spacing03; + font-size: Typography.$small; + font-weight: 500; + + .hoverable { + display: flex; + align-items: center; + gap: Spacings.$spacing03; + padding-inline: Spacings.$spacing02; + width: 100%; + + .name { + @include Typography.EllipsisOverflow; + } + + &:hover { + background-color: 
var(--background-3); + cursor: pointer; + border-radius: Radius.$small; + } + } + } + + .loader_icon { + padding-left: Spacings.$spacing06; + } + + .sync_elements_wrapper { + display: flex; + flex-direction: column; + gap: Spacings.$spacing02; + font-size: Typography.$small; + margin-left: Spacings.$spacing05; + border-left: 0.5px solid var(--border-0); + + &.empty { + height: 0; + } + + &.single_account { + border-left: none; + margin-left: 0; + } + } +} diff --git a/frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionKnowledges/ConnectionAccount/ConnectionAccount.tsx b/frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionKnowledges/ConnectionAccount/ConnectionAccount.tsx new file mode 100644 index 000000000000..80f9bb070d06 --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionKnowledges/ConnectionAccount/ConnectionAccount.tsx @@ -0,0 +1,115 @@ +import { useEffect, useState } from "react"; + +import { KMSElement, Sync, SyncsByProvider } from "@/lib/api/sync/types"; // Assurez-vous que KMSElement est bien importé +import { useSync } from "@/lib/api/sync/useSync"; +import { useKnowledgeContext } from "@/lib/components/KnowledgeManagementSystem/KnowledgeProvider/hooks/useKnowledgeContext"; +import { ConnectionIcon } from "@/lib/components/ui/ConnectionIcon/ConnectionIcon"; +import { Icon } from "@/lib/components/ui/Icon/Icon"; +import { LoaderIcon } from "@/lib/components/ui/LoaderIcon/LoaderIcon"; + +import styles from "./ConnectionAccount.module.scss"; + +import SyncFolder from "../SyncFolder/SyncFolder"; + +interface ConnectionAccountProps { + sync: Sync; + index: number; + singleAccount?: boolean; + providerGroup?: SyncsByProvider; + parentFolded?: boolean; +} + +const ConnectionAccount = ({ + sync, + index, + singleAccount, + providerGroup, + parentFolded, +}: ConnectionAccountProps): JSX.Element => { + const [loading, setLoading] = 
useState(false); + const [syncElements, setKMSElements] = useState(); + const [folded, setFolded] = useState(true); + const { getSyncFiles } = useSync(); + const { + setExploringQuivr, + setCurrentFolder, + setExploredProvider, + setExploredSpecificAccount, + } = useKnowledgeContext(); + + const getFiles = () => { + setLoading(true); + void (async () => { + try { + const res = await getSyncFiles(sync.id); + setKMSElements(res); + setLoading(false); + } catch (error) { + console.error("Failed to get sync files:", error); + } + })(); + }; + + const chooseAccount = () => { + setExploredSpecificAccount(sync); + setCurrentFolder(undefined); + setExploredProvider(providerGroup); + setExploringQuivr(false); + }; + + useEffect(() => { + if (!folded || (singleAccount && !parentFolded)) { + getFiles(); + } + }, [folded, parentFolded]); + + return ( +
+ {!singleAccount && ( +
+ setFolded(!folded)} + /> +
chooseAccount()}> + + {sync.email} +
+
+ )} + {(!singleAccount && !folded) || singleAccount ? ( + loading ? ( +
+ +
+ ) : ( +
file.is_folder).length + ? styles.empty + : "" + } ${singleAccount ? styles.single_account : ""}`} + > + {syncElements + ?.filter((file) => file.is_folder) + .map((element, id) => ( +
+ +
+ ))} +
+ ) + ) : null} +
+ ); +}; + +export default ConnectionAccount; diff --git a/frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionKnowledges/ConnectionKnowledges.module.scss b/frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionKnowledges/ConnectionKnowledges.module.scss new file mode 100644 index 000000000000..9842935be698 --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionKnowledges/ConnectionKnowledges.module.scss @@ -0,0 +1,44 @@ +@use "styles/Radius.module.scss"; +@use "styles/Spacings.module.scss"; + +.connection_knowledges_wrapper { + display: flex; + flex-direction: column; + gap: Spacings.$spacing03; + + .provider_line_wrapper { + display: flex; + align-items: center; + gap: Spacings.$spacing02; + + .hoverable { + display: flex; + align-items: center; + gap: Spacings.$spacing03; + padding-inline: Spacings.$spacing02; + width: 100%; + + .provider_title { + font-weight: 500; + } + + &:hover { + background-color: var(--background-3); + cursor: pointer; + border-radius: Radius.$small; + } + } + } + + .accounts { + display: flex; + flex-direction: column; + gap: Spacings.$spacing03; + border-left: 0.5px solid var(--border-0); + margin-left: Spacings.$spacing03; + + &.folded { + display: none; + } + } +} diff --git a/frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionKnowledges/ConnectionKnowledges.tsx b/frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionKnowledges/ConnectionKnowledges.tsx new file mode 100644 index 000000000000..f2526e3c5bc3 --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionKnowledges/ConnectionKnowledges.tsx @@ -0,0 +1,75 @@ +import Image from "next/image"; +import { useState } from "react"; + +import { SyncsByProvider } from "@/lib/api/sync/types"; +import { useSync } from "@/lib/api/sync/useSync"; +import { Icon } 
from "@/lib/components/ui/Icon/Icon"; +import { transformConnectionLabel } from "@/lib/helpers/providers"; + +import ConnectionAccount from "./ConnectionAccount/ConnectionAccount"; +import styles from "./ConnectionKnowledges.module.scss"; + +import { useKnowledgeContext } from "../../../KnowledgeProvider/hooks/useKnowledgeContext"; + +interface ConnectionKnowledgeProps { + providerGroup: SyncsByProvider; +} + +const ConnectionKnowledges = ({ + providerGroup, +}: ConnectionKnowledgeProps): JSX.Element => { + const [folded, setFolded] = useState(true); + const { providerIconUrls } = useSync(); + const { + setExploringQuivr, + setCurrentFolder, + setExploredProvider, + setExploredSpecificAccount, + } = useKnowledgeContext(); + + const selectProvider = () => { + setCurrentFolder(undefined); + setExploredSpecificAccount(undefined); + setExploringQuivr(false); + setExploredProvider(providerGroup); + }; + + return ( +
+
+ setFolded(!folded)} + /> +
selectProvider()}> + {providerGroup.provider} + + {transformConnectionLabel(providerGroup.provider)} + +
+
+
+ {providerGroup.syncs.map((sync, index) => ( + + ))} +
+
+ ); +}; + +export default ConnectionKnowledges; diff --git a/frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionKnowledges/SyncFolder/SyncFolder.module.scss b/frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionKnowledges/SyncFolder/SyncFolder.module.scss new file mode 100644 index 000000000000..f0ea9634a3ee --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionKnowledges/SyncFolder/SyncFolder.module.scss @@ -0,0 +1,46 @@ +@use "styles/Radius.module.scss"; +@use "styles/Spacings.module.scss"; +@use "styles/Typography.module.scss"; + +.folder_wrapper { + display: flex; + flex-direction: column; + gap: Spacings.$spacing02; + margin-left: Spacings.$spacing03; + + &.empty { + gap: 0; + } + + .folder_line_wrapper { + display: flex; + align-items: center; + font-size: Typography.$tiny; + + .name { + @include Typography.EllipsisOverflow; + padding: Spacings.$spacing01; + padding-inline: Spacings.$spacing03; + + &:hover, + &.selected { + background-color: var(--background-6); + cursor: pointer; + border-radius: Radius.$normal; + } + } + } + + .loader_icon { + padding-left: Spacings.$spacing06; + } + + .sync_elements_wrapper { + display: flex; + flex-direction: column; + gap: Spacings.$spacing02; + border-left: 0.5px solid var(--border-0); + margin-left: Spacings.$spacing03; + font-size: Typography.$small; + } +} diff --git a/frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionKnowledges/SyncFolder/SyncFolder.tsx b/frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionKnowledges/SyncFolder/SyncFolder.tsx new file mode 100644 index 000000000000..65a3ad500d85 --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionKnowledges/SyncFolder/SyncFolder.tsx @@ -0,0 +1,116 @@ +import { useEffect, useState } from "react"; + +import { KMSElement } from 
"@/lib/api/sync/types"; +import { useSync } from "@/lib/api/sync/useSync"; +import { useKnowledgeContext } from "@/lib/components/KnowledgeManagementSystem/KnowledgeProvider/hooks/useKnowledgeContext"; +import { Icon } from "@/lib/components/ui/Icon/Icon"; +import { LoaderIcon } from "@/lib/components/ui/LoaderIcon/LoaderIcon"; + +import styles from "./SyncFolder.module.scss"; + +interface SyncFolderProps { + element: KMSElement; +} + +const SyncFolder = ({ element }: SyncFolderProps): JSX.Element => { + const [folded, setFolded] = useState(true); + const [loading, setLoading] = useState(false); + const { getSyncFiles } = useSync(); + const [syncElements, setKMSElements] = useState(); + const [selectedFolder, setSelectedFolder] = useState(false); + + const { + currentFolder, + setCurrentFolder, + setExploringQuivr, + setExploredProvider, + } = useKnowledgeContext(); + + useEffect(() => { + setSelectedFolder(currentFolder?.sync_file_id === element.sync_file_id); + if (currentFolder) { + if (currentFolder.sync_file_id) { + setExploringQuivr(false); + if (currentFolder.fromProvider) { + setExploredProvider(currentFolder.fromProvider); + } + } + } + }, [currentFolder]); + + useEffect(() => { + if (!folded && element.sync_id !== null) { + setLoading(true); + void (async () => { + try { + if (element.sync_id === null || element.sync_file_id === null) { + throw new Error("sync_id is null"); + } + const res = await getSyncFiles(element.sync_id, element.sync_file_id); + setKMSElements(res); + setLoading(false); + } catch (error) { + console.error("Failed to get sync files:", error); + setLoading(false); + } + })(); + } + }, [folded]); + + return ( +
file.is_folder).length && !loading + ? styles.empty + : "" + }`} + > +
+ setFolded(!folded)} + /> + { + setCurrentFolder({ + ...element, + parentKMSElement: element.parentKMSElement, + }); + }} + > + {element.file_name?.includes(".") + ? element.file_name.split(".").slice(0, -1).join(".") + : element.file_name} + +
+ {!folded && + (loading ? ( +
+ +
+ ) : ( +
+ {syncElements + ?.filter((file) => file.is_folder) + .map((folder, id) => ( +
+ +
+ ))} +
+ ))} +
+ ); +}; + +export default SyncFolder; diff --git a/frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionsKnowledge.module.scss b/frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionsKnowledge.module.scss new file mode 100644 index 000000000000..1f95f15dbb48 --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionsKnowledge.module.scss @@ -0,0 +1,12 @@ +@use "styles/Spacings.module.scss"; + +.connections_knowledge_container { + display: flex; + flex-direction: column; + gap: Spacings.$spacing04; + overflow-y: hidden; + + &:hover { + overflow-y: auto; + } +} diff --git a/frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionsKnowledges.tsx b/frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionsKnowledges.tsx new file mode 100644 index 000000000000..0b8085e48a93 --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Menu/ConnectionsKnowledge/ConnectionsKnowledges.tsx @@ -0,0 +1,58 @@ +import { useEffect, useState } from "react"; + +import { Provider, Sync, SyncsByProvider } from "@/lib/api/sync/types"; +import { useSync } from "@/lib/api/sync/useSync"; + +import ConnectionKnowledges from "./ConnectionKnowledges/ConnectionKnowledges"; +import styles from "./ConnectionsKnowledge.module.scss"; + +const ConnectionsKnowledges = (): JSX.Element => { + const [syncsByProvider, setSyncsByProvider] = useState([]); + const { getUserSyncs } = useSync(); + + const fetchUserSyncs = async () => { + try { + const res: Sync[] = await getUserSyncs(); + const groupedByProvider: { [key: string]: Sync[] } = {}; + + res + .filter( + (sync) => sync.credentials.token || sync.credentials.access_token + ) + .forEach((sync) => { + const providerLowerCase = sync.provider.toLowerCase(); + if (!groupedByProvider[providerLowerCase]) { + groupedByProvider[providerLowerCase] = []; + } + 
groupedByProvider[providerLowerCase].push(sync); + }); + + const syncsByProviderArray: SyncsByProvider[] = Object.keys( + groupedByProvider + ).map((provider) => ({ + provider: provider as Provider, + syncs: groupedByProvider[provider], + })); + + setSyncsByProvider(syncsByProviderArray); + } catch (error) { + console.error(error); + } + }; + + useEffect(() => { + void fetchUserSyncs(); + }, []); + + return ( +
+ {syncsByProvider.map((providerGroup) => ( +
+ +
+ ))} +
+ ); +}; + +export default ConnectionsKnowledges; diff --git a/frontend/lib/components/KnowledgeManagementSystem/Menu/QuivrKnowledge/QuivrFolder/QuivrFolder.module.scss b/frontend/lib/components/KnowledgeManagementSystem/Menu/QuivrKnowledge/QuivrFolder/QuivrFolder.module.scss new file mode 100644 index 000000000000..f673cf8cb634 --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Menu/QuivrKnowledge/QuivrFolder/QuivrFolder.module.scss @@ -0,0 +1,51 @@ +@use "styles/Radius.module.scss"; +@use "styles/Spacings.module.scss"; +@use "styles/Typography.module.scss"; + +.folder_wrapper { + display: flex; + flex-direction: column; + gap: Spacings.$spacing02; + margin-left: Spacings.$spacing03; + border: 1px solid transparent; + + &.dragged { + border: 1px solid var(--accent); + } + + &.empty { + gap: 0; + } + + .folder_line_wrapper { + display: flex; + align-items: center; + font-size: Typography.$tiny; + + .name { + @include Typography.EllipsisOverflow; + padding: Spacings.$spacing01; + padding-inline: Spacings.$spacing03; + + &:hover, + &.selected { + background-color: var(--background-6); + cursor: pointer; + border-radius: Radius.$normal; + } + } + } + + .loader_icon { + padding-left: Spacings.$spacing06; + } + + .kms_elements_wrapper { + display: flex; + flex-direction: column; + gap: Spacings.$spacing02; + border-left: 0.5px solid var(--border-0); + margin-left: Spacings.$spacing03; + font-size: Typography.$small; + } +} diff --git a/frontend/lib/components/KnowledgeManagementSystem/Menu/QuivrKnowledge/QuivrFolder/QuivrFolder.tsx b/frontend/lib/components/KnowledgeManagementSystem/Menu/QuivrKnowledge/QuivrFolder/QuivrFolder.tsx new file mode 100644 index 000000000000..57bb75697b1d --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Menu/QuivrKnowledge/QuivrFolder/QuivrFolder.tsx @@ -0,0 +1,131 @@ +import { useEffect, useState } from "react"; + +import { useKnowledgeApi } from "@/lib/api/knowledge/useKnowledgeApi"; +import { 
KMSElement } from "@/lib/api/sync/types"; +import { useKnowledgeContext } from "@/lib/components/KnowledgeManagementSystem/KnowledgeProvider/hooks/useKnowledgeContext"; +import { Icon } from "@/lib/components/ui/Icon/Icon"; +import { LoaderIcon } from "@/lib/components/ui/LoaderIcon/LoaderIcon"; +import { handleDragOver, handleDrop } from "@/lib/helpers/kms"; + +import styles from "./QuivrFolder.module.scss"; + +interface QuivrFolderProps { + element: KMSElement; +} + +const QuivrFolder = ({ element }: QuivrFolderProps): JSX.Element => { + const [folded, setFolded] = useState(true); + const [loading, setLoading] = useState(false); + const [isDraggedOver, setIsDraggedOver] = useState(false); + const { getFiles } = useKnowledgeApi(); + const [kmsElements, setKMSElements] = useState(); + const [selectedFolder, setSelectedFolder] = useState(false); + + const { + currentFolder, + setCurrentFolder, + setExploringQuivr, + setExploredProvider, + setRefetchFolderMenu, + } = useKnowledgeContext(); + const { patchKnowledge } = useKnowledgeApi(); + + useEffect(() => { + setSelectedFolder(currentFolder?.id === element.id); + if (currentFolder?.source === "local") { + setExploringQuivr(true); + setExploredProvider(undefined); + } + }, [currentFolder]); + + useEffect(() => { + if (!folded) { + setLoading(true); + void (async () => { + try { + const res = await getFiles(element.id); + setKMSElements(res); + setLoading(false); + } catch (error) { + console.error("Failed to get sync files:", error); + setLoading(false); + } + })(); + } + }, [folded]); + + return ( +
file.is_folder).length && !loading + ? styles.empty + : "" + } ${isDraggedOver ? styles.dragged : ""}`} + onDrop={ + element.is_folder + ? (event) => + void handleDrop({ + event, + targetElement: element, + patchKnowledge, + setRefetchFolderMenu, + currentFolder, + }) + : undefined + } + onDragOver={(event) => { + if (element.is_folder) { + handleDragOver(event); + setIsDraggedOver(true); + } + }} + onDragLeave={() => setIsDraggedOver(false)} + > +
+ setFolded(!folded)} + /> + { + setCurrentFolder({ + ...element, + parentKMSElement: element.parentKMSElement, + }); + }} + > + {element.file_name?.includes(".") + ? element.file_name.split(".").slice(0, -1).join(".") + : element.file_name} + +
+ {!folded && + (loading ? ( +
+ +
+ ) : ( +
+ {kmsElements + ?.filter((file) => file.is_folder) + .map((folder, id) => ( +
+ +
+ ))} +
+ ))} +
+ ); +}; + +export default QuivrFolder; diff --git a/frontend/lib/components/KnowledgeManagementSystem/Menu/QuivrKnowledge/QuivrKnowledges.module.scss b/frontend/lib/components/KnowledgeManagementSystem/Menu/QuivrKnowledge/QuivrKnowledges.module.scss new file mode 100644 index 000000000000..b24f2fc29400 --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Menu/QuivrKnowledge/QuivrKnowledges.module.scss @@ -0,0 +1,52 @@ +@use "styles/Radius.module.scss"; +@use "styles/Spacings.module.scss"; +@use "styles/Typography.module.scss"; + +.main_container { + display: flex; + flex-direction: column; + gap: Spacings.$spacing03; + + .header_section_wrapper { + display: flex; + align-items: center; + gap: Spacings.$spacing02; + font-weight: 500; + border: 1px solid transparent; + + &.dragged { + border-color: var(--accent); + } + + .hoverable { + display: flex; + align-items: center; + gap: Spacings.$spacing03; + width: 100%; + padding-inline: Spacings.$spacing02; + + &:hover { + background-color: var(--background-3); + cursor: pointer; + border-radius: Radius.$small; + } + } + } + + .loader_icon { + padding-left: Spacings.$spacing06; + } + + .sync_elements_wrapper { + display: flex; + flex-direction: column; + gap: Spacings.$spacing02; + font-size: Typography.$small; + margin-left: Spacings.$spacing03; + border-left: 0.5px solid var(--border-0); + + &.empty { + height: 0; + } + } +} diff --git a/frontend/lib/components/KnowledgeManagementSystem/Menu/QuivrKnowledge/QuivrKnowledges.tsx b/frontend/lib/components/KnowledgeManagementSystem/Menu/QuivrKnowledge/QuivrKnowledges.tsx new file mode 100644 index 000000000000..1f8457817550 --- /dev/null +++ b/frontend/lib/components/KnowledgeManagementSystem/Menu/QuivrKnowledge/QuivrKnowledges.tsx @@ -0,0 +1,126 @@ +import { useEffect, useState } from "react"; + +import { useKnowledgeApi } from "@/lib/api/knowledge/useKnowledgeApi"; +import { KMSElement } from "@/lib/api/sync/types"; +import { QuivrLogo } from 
"@/lib/assets/QuivrLogo"; +import { Icon } from "@/lib/components/ui/Icon/Icon"; +import { LoaderIcon } from "@/lib/components/ui/LoaderIcon/LoaderIcon"; +import { useUserSettingsContext } from "@/lib/context/UserSettingsProvider/hooks/useUserSettingsContext"; +import { handleDrop } from "@/lib/helpers/kms"; + +import QuivrFolder from "./QuivrFolder/QuivrFolder"; +import styles from "./QuivrKnowledges.module.scss"; + +import { useKnowledgeContext } from "../../KnowledgeProvider/hooks/useKnowledgeContext"; + +const QuivrKnowledges = (): JSX.Element => { + const [folded, setFolded] = useState(true); + const [kmsElements, setKMSElements] = useState(); + const [loading, setLoading] = useState(false); + const [isDraggedOver, setIsDraggedOver] = useState(false); + const { isDarkMode } = useUserSettingsContext(); + const { + setExploringQuivr, + setCurrentFolder, + setExploredProvider, + setRefetchFolderMenu, + refetchFolderMenu, + currentFolder, + } = useKnowledgeContext(); + + const { getFiles, patchKnowledge } = useKnowledgeApi(); + + const chooseQuivrRoot = () => { + setCurrentFolder(undefined); + setExploredProvider(undefined); + setExploringQuivr(true); + }; + + const fetchFiles = async () => { + setLoading(true); + try { + const res = await getFiles(null); + setKMSElements(res); + } catch (error) { + console.error("Failed to get sync files:", error); + } finally { + setLoading(false); + } + }; + + useEffect(() => { + void fetchFiles(); + }, []); + + useEffect(() => { + if (refetchFolderMenu) { + void fetchFiles(); + setRefetchFolderMenu(false); + } + }, [refetchFolderMenu]); + + return ( +
+
{ + void handleDrop({ + event, + targetElement: null, + patchKnowledge, + setRefetchFolderMenu, + currentFolder, + }); + setIsDraggedOver(false); + }} + onDragOver={(event) => { + event.preventDefault(); + setIsDraggedOver(true); + }} + onDragLeave={() => setIsDraggedOver(false)} + > + setFolded(!folded)} + /> +
void chooseQuivrRoot()} + > + + Quivr +
+
+ {!folded ? ( + loading ? ( +
+ +
+ ) : ( +
file.is_folder).length + ? styles.empty + : "" + } `} + > + {kmsElements + ?.filter((file) => file.is_folder) + .map((element, id) => ( +
+ +
+ ))} +
+ ) + ) : null} +
+ ); +}; + +export default QuivrKnowledges; diff --git a/frontend/lib/components/KnowledgeToFeedInput/KnowledgeToFeedInput.tsx b/frontend/lib/components/KnowledgeToFeedInput/KnowledgeToFeedInput.tsx deleted file mode 100644 index 0880a43d4187..000000000000 --- a/frontend/lib/components/KnowledgeToFeedInput/KnowledgeToFeedInput.tsx +++ /dev/null @@ -1,43 +0,0 @@ -import { useTranslation } from "react-i18next"; - -import Button from "@/lib/components/ui/Button"; -import { useKnowledgeToFeedContext } from "@/lib/context/KnowledgeToFeedProvider/hooks/useKnowledgeToFeedContext"; - -import { FeedItems } from "./components"; -import { Crawler } from "./components/Crawler"; -import { FileUploader } from "./components/FileUploader"; - -export const KnowledgeToFeedInput = ({ - feedBrain, -}: { - feedBrain: () => void; -}): JSX.Element => { - const { t } = useTranslation(["translation", "upload"]); - - const { knowledgeToFeed } = useKnowledgeToFeedContext(); - - return ( -
-
- - - {`${t("and", { ns: "translation" })} / ${t("or", { - ns: "translation", - })}`} - - -
- -
- -
-
- ); -}; diff --git a/frontend/lib/components/KnowledgeToFeedInput/components/Crawler/helpers/isValidUrl.ts b/frontend/lib/components/KnowledgeToFeedInput/components/Crawler/helpers/isValidUrl.ts deleted file mode 100644 index b41a47ef44f7..000000000000 --- a/frontend/lib/components/KnowledgeToFeedInput/components/Crawler/helpers/isValidUrl.ts +++ /dev/null @@ -1,9 +0,0 @@ -export const isValidUrl = (url: string): boolean => { - try { - new URL(url); - - return true; - } catch (err) { - return false; - } -}; diff --git a/frontend/lib/components/KnowledgeToFeedInput/components/Crawler/hooks/useCrawler.ts b/frontend/lib/components/KnowledgeToFeedInput/components/Crawler/hooks/useCrawler.ts deleted file mode 100644 index 63885f89abed..000000000000 --- a/frontend/lib/components/KnowledgeToFeedInput/components/Crawler/hooks/useCrawler.ts +++ /dev/null @@ -1,62 +0,0 @@ -"use client"; -import { useRef, useState } from "react"; -import { useTranslation } from "react-i18next"; - -import { useKnowledgeToFeedContext } from "@/lib/context/KnowledgeToFeedProvider/hooks/useKnowledgeToFeedContext"; -import { useSupabase } from "@/lib/context/SupabaseProvider"; -import { useToast } from "@/lib/hooks"; -import { useOnboarding } from "@/lib/hooks/useOnboarding"; -import { useOnboardingTracker } from "@/lib/hooks/useOnboardingTracker"; -import { redirectToLogin } from "@/lib/router/redirectToLogin"; -import { useEventTracking } from "@/services/analytics/june/useEventTracking"; - -import { isValidUrl } from "../helpers/isValidUrl"; - -// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types -export const useCrawler = () => { - const { addKnowledgeToFeed } = useKnowledgeToFeedContext(); - const urlInputRef = useRef(null); - const { session } = useSupabase(); - const { publish } = useToast(); - const { t } = useTranslation(["translation", "upload"]); - const [urlToCrawl, setUrlToCrawl] = useState(""); - const { track } = useEventTracking(); - const { 
trackOnboardingEvent } = useOnboardingTracker(); - const { isOnboarding } = useOnboarding(); - - if (session === null) { - redirectToLogin(); - } - - const handleSubmit = () => { - if (urlToCrawl === "") { - return; - } - if (!isValidUrl(urlToCrawl)) { - void track("URL_INVALID"); - publish({ - variant: "danger", - text: t("invalidUrl"), - }); - - return; - } - if (isOnboarding) { - void trackOnboardingEvent("URL_CRAWLED"); - } else { - void track("URL_CRAWLED"); - } - addKnowledgeToFeed({ - source: "crawl", - url: urlToCrawl, - }); - setUrlToCrawl(""); - }; - - return { - urlInputRef, - urlToCrawl, - setUrlToCrawl, - handleSubmit, - }; -}; diff --git a/frontend/lib/components/KnowledgeToFeedInput/components/Crawler/index.tsx b/frontend/lib/components/KnowledgeToFeedInput/components/Crawler/index.tsx deleted file mode 100644 index ec5c265be072..000000000000 --- a/frontend/lib/components/KnowledgeToFeedInput/components/Crawler/index.tsx +++ /dev/null @@ -1,43 +0,0 @@ -"use client"; -import { useTranslation } from "react-i18next"; -import { MdSend } from "react-icons/md"; - -import Button from "@/lib/components/ui/Button"; -import Field from "@/lib/components/ui/Field"; - -import { useCrawler } from "./hooks/useCrawler"; - -export const Crawler = (): JSX.Element => { - const { urlInputRef, urlToCrawl, handleSubmit, setUrlToCrawl } = useCrawler(); - const { t } = useTranslation(["translation", "upload"]); - - return ( -
-
-
{ - e.preventDefault(); - handleSubmit(); - }} - className="w-full" - > - setUrlToCrawl(e.target.value)} - icon={ - - } - /> - -
-
- ); -}; diff --git a/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/FeedItems.tsx b/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/FeedItems.tsx deleted file mode 100644 index d8c78d5e5197..000000000000 --- a/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/FeedItems.tsx +++ /dev/null @@ -1,34 +0,0 @@ -import { Fragment } from "react"; - -import { useKnowledgeToFeedContext } from "@/lib/context/KnowledgeToFeedProvider/hooks/useKnowledgeToFeedContext"; - -import { CrawlFeedItem } from "./components/CrawlFeedItem"; -import { FileFeedItem } from "./components/FileFeedItem"; - -export const FeedItems = (): JSX.Element => { - const { knowledgeToFeed, removeKnowledgeToFeed } = - useKnowledgeToFeedContext(); - if (knowledgeToFeed.length === 0) { - return ; - } - - return ( -
- {knowledgeToFeed.map((item, index) => - item.source === "crawl" ? ( - removeKnowledgeToFeed(index)} - /> - ) : ( - removeKnowledgeToFeed(index)} - /> - ) - )} -
- ); -}; diff --git a/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/components/CrawlFeedItem.tsx b/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/components/CrawlFeedItem.tsx deleted file mode 100644 index 4f97f9bd0814..000000000000 --- a/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/components/CrawlFeedItem.tsx +++ /dev/null @@ -1,34 +0,0 @@ -import { IoMdCloseCircle } from "react-icons/io"; -import { MdLink } from "react-icons/md"; - -import { FeedTitleDisplayer } from "./FeedTitleDisplayer"; - -import { StyledFeedItemDiv } from "../styles/StyledFeedItemDiv"; - -type CrawlFeedItemProps = { - url: string; - onRemove: () => void; -}; -export const CrawlFeedItem = ({ - url, - onRemove, -}: CrawlFeedItemProps): JSX.Element => { - return ( - -
-
- -
-
- -
-
-
- -
-
- ); -}; diff --git a/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/components/FeedTitleDisplayer/FeedTitleDisplayer.tsx b/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/components/FeedTitleDisplayer/FeedTitleDisplayer.tsx deleted file mode 100644 index 702077c65e82..000000000000 --- a/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/components/FeedTitleDisplayer/FeedTitleDisplayer.tsx +++ /dev/null @@ -1,24 +0,0 @@ -import Tooltip from "@/lib/components/ui/Tooltip/Tooltip"; - -import { enhanceUrlDisplay } from "./utils/enhanceUrlDisplay"; -import { removeFileExtension } from "./utils/removeFileExtension"; - -type FeedTitleDisplayerProps = { - title: string; - isUrl?: boolean; -}; - -export const FeedTitleDisplayer = ({ - title, - isUrl = false, -}: FeedTitleDisplayerProps): JSX.Element => { - return ( -
- -

- {isUrl ? enhanceUrlDisplay(title) : removeFileExtension(title)} -

-
-
- ); -}; diff --git a/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/components/FeedTitleDisplayer/index.ts b/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/components/FeedTitleDisplayer/index.ts deleted file mode 100644 index 7b4426407e83..000000000000 --- a/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/components/FeedTitleDisplayer/index.ts +++ /dev/null @@ -1 +0,0 @@ -export * from "./FeedTitleDisplayer"; diff --git a/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/components/FeedTitleDisplayer/utils/enhanceUrlDisplay.ts b/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/components/FeedTitleDisplayer/utils/enhanceUrlDisplay.ts deleted file mode 100644 index ef491720f574..000000000000 --- a/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/components/FeedTitleDisplayer/utils/enhanceUrlDisplay.ts +++ /dev/null @@ -1,24 +0,0 @@ -export const enhanceUrlDisplay = (url: string): string => { - const parts = url.split("/"); - - // Check if the URL has at least 3 parts (protocol, domain, and one more segment) - if (parts.length >= 3) { - const domain = parts[2]; - const path = parts.slice(3).join("/"); - - // Split the domain by "." to check for subdomains and remove "www" - const domainParts = domain.split("."); - if (domainParts[0] === "www") { - domainParts.shift(); // Remove "www" - } - - // Combine the beginning (subdomain/domain) and the end (trimmed path) - const beginning = domainParts.join("."); - const trimmedPath = path.slice(0, 5) + "..." 
+ path.slice(-5); // Display the beginning and end of the path - - return `${beginning}/${trimmedPath}`; - } - - // If the URL doesn't have enough parts, return it as is - return url; -}; diff --git a/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/components/FeedTitleDisplayer/utils/removeFileExtension.ts b/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/components/FeedTitleDisplayer/utils/removeFileExtension.ts deleted file mode 100644 index b8328dbe5582..000000000000 --- a/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/components/FeedTitleDisplayer/utils/removeFileExtension.ts +++ /dev/null @@ -1,8 +0,0 @@ -export const removeFileExtension = (fileName: string): string => { - const lastDotIndex = fileName.lastIndexOf("."); - if (lastDotIndex !== -1) { - return fileName.substring(0, lastDotIndex); - } - - return fileName; -}; diff --git a/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/components/FileFeedItem.tsx b/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/components/FileFeedItem.tsx deleted file mode 100644 index 01d9b4c07a4c..000000000000 --- a/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/components/FileFeedItem.tsx +++ /dev/null @@ -1,36 +0,0 @@ -import { IoMdCloseCircle } from "react-icons/io"; - -import { getFileIcon } from "@/lib/helpers/getFileIcon"; - -import { FeedTitleDisplayer } from "./FeedTitleDisplayer"; - -import { StyledFeedItemDiv } from "../styles/StyledFeedItemDiv"; - -type FileFeedItemProps = { - file: File; - onRemove: () => void; -}; - -export const FileFeedItem = ({ - file, - onRemove, -}: FileFeedItemProps): JSX.Element => { - const icon = getFileIcon(file.name); - - return ( - -
-
{icon}
-
- -
-
-
- -
-
- ); -}; diff --git a/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/components/index.ts b/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/components/index.ts deleted file mode 100644 index 7b4426407e83..000000000000 --- a/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/components/index.ts +++ /dev/null @@ -1 +0,0 @@ -export * from "./FeedTitleDisplayer"; diff --git a/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/index.ts b/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/index.ts deleted file mode 100644 index 0096f92ff372..000000000000 --- a/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/index.ts +++ /dev/null @@ -1 +0,0 @@ -export * from "./FeedItems"; diff --git a/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/styles/StyledFeedItemDiv.tsx b/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/styles/StyledFeedItemDiv.tsx deleted file mode 100644 index 31cf0d628b95..000000000000 --- a/frontend/lib/components/KnowledgeToFeedInput/components/FeedItems/styles/StyledFeedItemDiv.tsx +++ /dev/null @@ -1,17 +0,0 @@ -import { HtmlHTMLAttributes } from "react"; - -import { cn } from "@/lib/utils"; - -type StyledFeedItemDivProps = HtmlHTMLAttributes; -export const StyledFeedItemDiv = ({ - className, - ...propsWithoutClassname -}: StyledFeedItemDivProps): JSX.Element => ( -
-); diff --git a/frontend/lib/components/KnowledgeToFeedInput/components/FileUploader/index.tsx b/frontend/lib/components/KnowledgeToFeedInput/components/FileUploader/index.tsx deleted file mode 100644 index 1991e65a55a4..000000000000 --- a/frontend/lib/components/KnowledgeToFeedInput/components/FileUploader/index.tsx +++ /dev/null @@ -1,38 +0,0 @@ -"use client"; -import { useTranslation } from "react-i18next"; -import { IoCloudUploadOutline } from "react-icons/io5"; - -import Card from "@/lib/components/ui/Card"; -import { useSupabase } from "@/lib/context/SupabaseProvider"; -import { useCustomDropzone } from "@/lib/hooks/useDropzone"; -import { redirectToLogin } from "@/lib/router/redirectToLogin"; - -export const FileUploader = (): JSX.Element => { - const { session } = useSupabase(); - const { getInputProps, isDragActive, open } = useCustomDropzone(); - - if (session === null) { - redirectToLogin(); - } - - const { t } = useTranslation(["translation", "upload"]); - - return ( -
-
-
- - - - {isDragActive && ( -

{t("drop", { ns: "upload" })}

- )} -
-
-
-
- ); -}; diff --git a/frontend/lib/components/KnowledgeToFeedInput/components/index.ts b/frontend/lib/components/KnowledgeToFeedInput/components/index.ts deleted file mode 100644 index f2d053e24414..000000000000 --- a/frontend/lib/components/KnowledgeToFeedInput/components/index.ts +++ /dev/null @@ -1,3 +0,0 @@ -export * from "./Crawler"; -export * from "./FeedItems"; -export * from "./FileUploader"; diff --git a/frontend/lib/components/KnowledgeToFeedInput/hooks/useKnowledgeToFeedInput.ts.ts b/frontend/lib/components/KnowledgeToFeedInput/hooks/useKnowledgeToFeedInput.ts.ts deleted file mode 100644 index 220eea40f337..000000000000 --- a/frontend/lib/components/KnowledgeToFeedInput/hooks/useKnowledgeToFeedInput.ts.ts +++ /dev/null @@ -1,94 +0,0 @@ -import { UUID } from "crypto"; -import { useCallback } from "react"; -import { useTranslation } from "react-i18next"; - -import { useCrawlApi } from "@/lib/api/crawl/useCrawlApi"; -import { useUploadApi } from "@/lib/api/upload/useUploadApi"; -import { getAxiosErrorParams } from "@/lib/helpers/getAxiosErrorParams"; -import { useToast } from "@/lib/hooks"; - -// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types -export const useKnowledgeToFeedInput = () => { - const { publish } = useToast(); - const { uploadFile } = useUploadApi(); - const { t } = useTranslation(["upload"]); - const { crawlWebsiteUrl } = useCrawlApi(); - - const crawlWebsiteHandler = useCallback( - async (url: string, brainId: UUID, bulk_id: UUID, chat_id?: UUID) => { - // Configure parameters - const config = { - url: url, - js: false, - depth: 1, - max_pages: 100, - max_time: 60, - }; - - try { - await crawlWebsiteUrl({ - brainId, - config, - chat_id, - bulk_id, - }); - } catch (error: unknown) { - const errorParams = getAxiosErrorParams(error); - if (errorParams !== undefined) { - publish({ - variant: "danger", - text: t("crawlFailed", { - message: JSON.stringify(errorParams.message), - }), - }); - } else { - publish({ - 
variant: "danger", - text: t("crawlFailed", { - message: JSON.stringify(error), - }), - }); - } - } - }, - [crawlWebsiteUrl, publish, t] - ); - - const uploadFileHandler = useCallback( - async (file: File, brainId: UUID, bulk_id: UUID, chat_id?: UUID) => { - const formData = new FormData(); - formData.append("uploadFile", file); - try { - await uploadFile({ - brainId, - formData, - chat_id, - bulk_id, - }); - } catch (e: unknown) { - const errorParams = getAxiosErrorParams(e); - if (errorParams !== undefined) { - publish({ - variant: "danger", - text: t("uploadFailed", { - message: JSON.stringify(errorParams.message), - }), - }); - } else { - publish({ - variant: "danger", - text: t("uploadFailed", { - message: JSON.stringify(e), - }), - }); - } - } - }, - [publish, t, uploadFile] - ); - - return { - crawlWebsiteHandler, - uploadFileHandler, - }; -}; diff --git a/frontend/lib/components/KnowledgeToFeedInput/index.ts b/frontend/lib/components/KnowledgeToFeedInput/index.ts deleted file mode 100644 index 8f1207bd3010..000000000000 --- a/frontend/lib/components/KnowledgeToFeedInput/index.ts +++ /dev/null @@ -1 +0,0 @@ -export * from "./KnowledgeToFeedInput"; diff --git a/frontend/lib/components/Menu/Menu.tsx b/frontend/lib/components/Menu/Menu.tsx index 33667f00f475..250a5ea47b4e 100644 --- a/frontend/lib/components/Menu/Menu.tsx +++ b/frontend/lib/components/Menu/Menu.tsx @@ -1,6 +1,6 @@ import { MotionConfig } from "framer-motion"; import { usePathname, useRouter } from "next/navigation"; -import { useFeatureFlagEnabled } from 'posthog-js/react'; +import { useFeatureFlagEnabled } from "posthog-js/react"; import { useState } from "react"; import { MenuControlButton } from "@/app/chat/[chatId]/components/ActionsBar/components/ChatInput/components/MenuControlButton/MenuControlButton"; @@ -15,6 +15,7 @@ import styles from "./Menu.module.scss"; import { AnimatedDiv } from "./components/AnimationDiv"; import { DiscussionButton } from 
"./components/DiscussionButton/DiscussionButton"; import { HomeButton } from "./components/HomeButton/HomeButton"; +import { KnowledgeButton } from "./components/KnowledgeButton/KnowledgeButton"; import { Notifications } from "./components/Notifications/Notifications"; import { NotificationsButton } from "./components/NotificationsButton/NotificationsButton"; import { ProfileButton } from "./components/ProfileButton/ProfileButton"; @@ -24,7 +25,6 @@ import { StudioButton } from "./components/StudioButton/StudioButton"; import { ThreadsButton } from "./components/ThreadsButton/ThreadsButton"; import { UpgradeToPlusButton } from "./components/UpgradeToPlusButton/UpgradeToPlusButton"; - const showUpgradeButton = process.env.NEXT_PUBLIC_SHOW_TOKENS === "true"; export const Menu = (): JSX.Element => { @@ -34,8 +34,7 @@ export const Menu = (): JSX.Element => { const pathname = usePathname() ?? ""; const [isLogoHovered, setIsLogoHovered] = useState(false); const { isDarkMode } = useUserSettingsContext(); - const flagEnabled = useFeatureFlagEnabled('show-quality-assistant') - + const flagEnabled = useFeatureFlagEnabled("show-quality-assistant"); useChatsList(); @@ -46,6 +45,7 @@ export const Menu = (): JSX.Element => { const displayedOnPages = [ "/assistants", "/chat", + "/knowledge", "/library", "/search", "studio", @@ -65,8 +65,9 @@ export const Menu = (): JSX.Element => {
@@ -91,6 +92,7 @@ export const Menu = (): JSX.Element => { {flagEnabled && } +
diff --git a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FileLine/FileLine.module.scss b/frontend/lib/components/Menu/components/KnowledgeButton/KnowledgeButton.module.scss similarity index 100% rename from frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FileLine/FileLine.module.scss rename to frontend/lib/components/Menu/components/KnowledgeButton/KnowledgeButton.module.scss diff --git a/frontend/lib/components/Menu/components/KnowledgeButton/KnowledgeButton.tsx b/frontend/lib/components/Menu/components/KnowledgeButton/KnowledgeButton.tsx new file mode 100644 index 000000000000..c5943e6d865a --- /dev/null +++ b/frontend/lib/components/Menu/components/KnowledgeButton/KnowledgeButton.tsx @@ -0,0 +1,21 @@ +import Link from "next/link"; +import { usePathname } from "next/navigation"; + +import { MenuButton } from "@/lib/components/Menu/components/MenuButton/MenuButton"; + +export const KnowledgeButton = (): JSX.Element => { + const pathname = usePathname() ?? 
""; + const isSelected = pathname.includes("/knowledge"); + + return ( + + + + ); +}; diff --git a/frontend/lib/components/Menu/components/Notifications/GenericNotification/GenericNotification.tsx b/frontend/lib/components/Menu/components/Notifications/GenericNotification/GenericNotification.tsx index 2ae06385a1fa..ed782be179f0 100644 --- a/frontend/lib/components/Menu/components/Notifications/GenericNotification/GenericNotification.tsx +++ b/frontend/lib/components/Menu/components/Notifications/GenericNotification/GenericNotification.tsx @@ -23,7 +23,7 @@ export const GenericNotification = ({ const { updateNotifications } = useNotificationsContext(); const navigateToBrain = () => { - router.push(`/studio/${bulkNotification.brain_id}`); // Naviguer vers l'URL + router.push(`/studio/${bulkNotification.brain_id}`); }; const deleteNotification = async () => { diff --git a/frontend/lib/components/UploadDocumentModal/UploadDocumentModal.module.scss b/frontend/lib/components/UploadDocumentModal/UploadDocumentModal.module.scss deleted file mode 100644 index 15e78a1a08b7..000000000000 --- a/frontend/lib/components/UploadDocumentModal/UploadDocumentModal.module.scss +++ /dev/null @@ -1,22 +0,0 @@ -@use "styles/Spacings.module.scss"; -@use "styles/ZIndexes.module.scss"; - -.knowledge_modal { - position: relative; - display: flex; - flex-direction: column; - justify-content: space-between; - background-color: var(--background-0); - width: 100%; - flex: 1; - overflow: hidden; - - .buttons { - display: flex; - justify-content: space-between; - - &.standalone { - justify-content: flex-end; - } - } -} \ No newline at end of file diff --git a/frontend/lib/components/UploadDocumentModal/UploadDocumentModal.tsx b/frontend/lib/components/UploadDocumentModal/UploadDocumentModal.tsx deleted file mode 100644 index 1236ed442567..000000000000 --- a/frontend/lib/components/UploadDocumentModal/UploadDocumentModal.tsx +++ /dev/null @@ -1,127 +0,0 @@ -import { useEffect, useMemo, useState } 
from "react"; -import { useTranslation } from "react-i18next"; - -import { KnowledgeToFeed } from "@/app/chat/[chatId]/components/ActionsBar/components"; -import { useFromConnectionsContext } from "@/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FromConnectionsProvider/hooks/useFromConnectionContext"; -import { OpenedConnection } from "@/lib/api/sync/types"; -import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext"; -import { useKnowledgeToFeedContext } from "@/lib/context/KnowledgeToFeedProvider/hooks/useKnowledgeToFeedContext"; -import { createHandleGetButtonProps } from "@/lib/helpers/handleConnectionButtons"; - -import styles from "./UploadDocumentModal.module.scss"; -import { useAddKnowledge } from "./hooks/useAddKnowledge"; - -import { Modal } from "../ui/Modal/Modal"; -import { QuivrButton } from "../ui/QuivrButton/QuivrButton"; - -export const UploadDocumentModal = (): JSX.Element => { - const { shouldDisplayFeedCard, setShouldDisplayFeedCard, knowledgeToFeed } = - useKnowledgeToFeedContext(); - const { currentBrain } = useBrainContext(); - const { feedBrain } = useAddKnowledge(); - const [feeding, setFeeding] = useState(false); - const { - currentSyncId, - setCurrentSyncId, - openedConnections, - setOpenedConnections, - } = useFromConnectionsContext(); - const [currentConnection, setCurrentConnection] = useState< - OpenedConnection | undefined - >(undefined); - - useKnowledgeToFeedContext(); - const { t } = useTranslation(["knowledge"]); - - const disabled = useMemo(() => { - return ( - (knowledgeToFeed.length === 0 && - openedConnections.filter((connection) => { - return connection.submitted || !!connection.last_synced; - }).length === 0) || - !currentBrain - ); - }, [knowledgeToFeed, openedConnections, currentBrain, currentSyncId]); - - const handleFeedBrain = async () => { - setFeeding(true); - await feedBrain(); - setFeeding(false); - setShouldDisplayFeedCard(false); - }; - - 
const getButtonProps = createHandleGetButtonProps( - currentConnection, - openedConnections, - setOpenedConnections, - currentSyncId, - setCurrentSyncId - ); - const buttonProps = getButtonProps(); - - useEffect(() => { - setCurrentConnection( - openedConnections.find( - (connection) => connection.user_sync_id === currentSyncId - ) - ); - }, [currentSyncId]); - - if (!shouldDisplayFeedCard) { - return <>; - } - - return ( - } - > -
- -
- {!!currentSyncId && ( - { - setCurrentSyncId(undefined); - }} - /> - )} - {currentSyncId ? ( - - ) : ( - { - setOpenedConnections([]); - void handleFeedBrain(); - }} - disabled={disabled} - isLoading={feeding} - important={true} - /> - )} -
-
-
- ); -}; diff --git a/frontend/lib/components/UploadDocumentModal/hooks/useAddKnowledge.ts b/frontend/lib/components/UploadDocumentModal/hooks/useAddKnowledge.ts deleted file mode 100644 index 2c2ff3b7f697..000000000000 --- a/frontend/lib/components/UploadDocumentModal/hooks/useAddKnowledge.ts +++ /dev/null @@ -1,31 +0,0 @@ -import { useEffect } from "react"; - -import { useKnowledge } from "@/app/studio/[brainId]/BrainManagementTabs/components/KnowledgeTab/hooks/useKnowledge"; -import { useUrlBrain } from "@/lib/hooks/useBrainIdFromUrl"; - -import { useFeedBrain } from "./useFeedBrain"; - -// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types -export const useAddKnowledge = () => { - const { brainId } = useUrlBrain(); - const { invalidateKnowledgeDataKey } = useKnowledge({ - brainId, - }); - - const { feedBrain, hasPendingRequests, setHasPendingRequests } = useFeedBrain( - { - dispatchHasPendingRequests: () => setHasPendingRequests(true), - } - ); - - useEffect(() => { - if (!hasPendingRequests) { - invalidateKnowledgeDataKey(); - } - }, [hasPendingRequests, invalidateKnowledgeDataKey]); - - return { - feedBrain, - hasPendingRequests, - }; -}; diff --git a/frontend/lib/components/UploadDocumentModal/hooks/useFeedBrain.ts b/frontend/lib/components/UploadDocumentModal/hooks/useFeedBrain.ts deleted file mode 100644 index d3d442eb22f3..000000000000 --- a/frontend/lib/components/UploadDocumentModal/hooks/useFeedBrain.ts +++ /dev/null @@ -1,82 +0,0 @@ -import { useState } from "react"; -import { useTranslation } from "react-i18next"; - -import { useFromConnectionsContext } from "@/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FromConnectionsProvider/hooks/useFromConnectionContext"; -import { useChatApi } from "@/lib/api/chat/useChatApi"; -import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext"; -import { useKnowledgeToFeedContext } from 
"@/lib/context/KnowledgeToFeedProvider/hooks/useKnowledgeToFeedContext"; -import { useToast } from "@/lib/hooks"; -import { useUrlBrain } from "@/lib/hooks/useBrainIdFromUrl"; - -import { useFeedBrainHandler } from "./useFeedBrainHandler"; - -// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types -export const useFeedBrain = ({ - dispatchHasPendingRequests, - closeFeedInput, -}: { - dispatchHasPendingRequests?: () => void; - closeFeedInput?: () => void; -}) => { - const { publish } = useToast(); - const { t } = useTranslation(["upload"]); - let { brainId } = useUrlBrain(); - const { currentBrainId } = useBrainContext(); - const { setKnowledgeToFeed, knowledgeToFeed, setShouldDisplayFeedCard } = - useKnowledgeToFeedContext(); - const [hasPendingRequests, setHasPendingRequests] = useState(false); - const { handleFeedBrain } = useFeedBrainHandler(); - const { openedConnections } = useFromConnectionsContext(); - - const { createChat, deleteChat } = useChatApi(); - - const feedBrain = async (): Promise => { - brainId ??= currentBrainId ?? 
undefined; - if (brainId === undefined) { - publish({ - variant: "danger", - text: t("selectBrainFirst"), - }); - - return; - } - - if (knowledgeToFeed.length === 0 && !openedConnections.length) { - publish({ - variant: "danger", - text: t("addFiles"), - }); - - return; - } - - //TODO: Modify backend archi to avoid creating a chat for each feed action - const currentChatId = (await createChat("New Chat")).chat_id; - - try { - dispatchHasPendingRequests?.(); - closeFeedInput?.(); - setHasPendingRequests(true); - await handleFeedBrain({ - brainId, - chatId: currentChatId, - }); - setShouldDisplayFeedCard(false); - setKnowledgeToFeed([]); - } catch (e) { - publish({ - variant: "danger", - text: JSON.stringify(e), - }); - } finally { - setHasPendingRequests(false); - await deleteChat(currentChatId); - } - }; - - return { - feedBrain, - hasPendingRequests, - setHasPendingRequests, - }; -}; diff --git a/frontend/lib/components/UploadDocumentModal/hooks/useFeedBrainHandler.ts b/frontend/lib/components/UploadDocumentModal/hooks/useFeedBrainHandler.ts deleted file mode 100644 index 914a2df0e165..000000000000 --- a/frontend/lib/components/UploadDocumentModal/hooks/useFeedBrainHandler.ts +++ /dev/null @@ -1,77 +0,0 @@ -import { UUID } from "crypto"; - -import { useFromConnectionsContext } from "@/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/components/FromConnections/FromConnectionsProvider/hooks/useFromConnectionContext"; -import { useSync } from "@/lib/api/sync/useSync"; -import { useKnowledgeToFeedInput } from "@/lib/components/KnowledgeToFeedInput/hooks/useKnowledgeToFeedInput.ts"; -import { useKnowledgeToFeedFilesAndUrls } from "@/lib/hooks/useKnowledgeToFeed"; -import { useOnboarding } from "@/lib/hooks/useOnboarding"; - -type FeedBrainProps = { - brainId: UUID; - chatId: UUID; -}; -// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types -export const useFeedBrainHandler = () => { - const { files, urls } = 
useKnowledgeToFeedFilesAndUrls(); - const { crawlWebsiteHandler, uploadFileHandler } = useKnowledgeToFeedInput(); - const { updateOnboarding, onboarding } = useOnboarding(); - const { - syncFiles, - getActiveSyncsForBrain, - deleteActiveSync, - updateActiveSync, - } = useSync(); - const { openedConnections } = useFromConnectionsContext(); - - const updateOnboardingA = async () => { - if (onboarding.onboarding_a) { - await updateOnboarding({ - onboarding_a: false, - }); - } - }; - - const handleFeedBrain = async ({ - brainId, - chatId, - }: FeedBrainProps): Promise => { - const uploadPromises = files.map((file) => - uploadFileHandler(file, brainId, chatId) - ); - const crawlPromises = urls.map((url) => - crawlWebsiteHandler(url, brainId, chatId) - ); - - const existingConnections = await getActiveSyncsForBrain(brainId); - - await Promise.all( - openedConnections - .filter((connection) => connection.selectedFiles.files.length) - .map(async (openedConnection) => { - const existingConnectionIds = existingConnections.map( - (connection) => connection.id - ); - if ( - !openedConnection.id || - !existingConnectionIds.includes(openedConnection.id) - ) { - await syncFiles(openedConnection, brainId); - } else if (!openedConnection.selectedFiles.files.length) { - await deleteActiveSync(openedConnection.id); - } else { - await updateActiveSync(openedConnection); - } - }) - ); - - await Promise.all([ - ...uploadPromises, - ...crawlPromises, - updateOnboardingA(), - ]); - }; - - return { - handleFeedBrain, - }; -}; diff --git a/frontend/lib/components/ui/FileInput/FileInput.module.scss b/frontend/lib/components/ui/FileInput/FileInput.module.scss index 93d8a34b01d4..73902a83a259 100644 --- a/frontend/lib/components/ui/FileInput/FileInput.module.scss +++ b/frontend/lib/components/ui/FileInput/FileInput.module.scss @@ -2,10 +2,11 @@ @use "styles/ScreenSizes.module.scss"; @use "styles/Spacings.module.scss"; @use "styles/Typography.module.scss"; +@use "styles/Variables.module.scss"; 
.file_input_wrapper { width: 100%; - height: 200px; + height: Variables.$fileInputHeight; &.drag_active { .header_wrapper { diff --git a/frontend/lib/components/ui/FileInput/FileInput.tsx b/frontend/lib/components/ui/FileInput/FileInput.tsx index 69543b15ad96..13f9b096d436 100644 --- a/frontend/lib/components/ui/FileInput/FileInput.tsx +++ b/frontend/lib/components/ui/FileInput/FileInput.tsx @@ -7,20 +7,27 @@ import { Icon } from "../Icon/Icon"; interface FileInputProps { label: string; - onFileChange: (file: File) => void; + onFileChange: (files: File[]) => void; acceptedFileTypes?: string[]; + hideFileName?: boolean; + handleMultipleFiles?: boolean; } export const FileInput = (props: FileInputProps): JSX.Element => { - const [currentFile, setCurrentFile] = useState(null); + const [currentFiles, setCurrentFiles] = useState([]); const [errorMessage, setErrorMessage] = useState(""); const fileInputRef = useRef(null); - const handleFileChange = (file: File) => { - const fileExtension = file.name.split(".").pop(); - if (props.acceptedFileTypes?.includes(fileExtension || "")) { - props.onFileChange(file); - setCurrentFile(file); + const handleFileChange = (files: File[]) => { + const validFiles = files.filter((file) => { + const fileExtension = file.name.split(".").pop(); + + return props.acceptedFileTypes?.includes(fileExtension ?? ""); + }); + + if (validFiles.length > 0) { + props.onFileChange(validFiles); + setCurrentFiles(validFiles); setErrorMessage(""); } else { setErrorMessage("Wrong extension"); @@ -28,9 +35,9 @@ export const FileInput = (props: FileInputProps): JSX.Element => { }; const handleInputChange = (event: React.ChangeEvent) => { - const file = event.target.files?.[0]; - if (file) { - handleFileChange(file); + const files = Array.from(event.target.files ?? 
[]); + if (files.length > 0) { + handleFileChange(files); } }; @@ -65,12 +72,10 @@ export const FileInput = (props: FileInputProps): JSX.Element => { const { getRootProps, getInputProps, isDragActive } = useDropzone({ onDrop: (acceptedFiles) => { - const file = acceptedFiles[0]; - if (file) { - handleFileChange(file); - } + handleFileChange(acceptedFiles); }, accept, + multiple: props.handleMultipleFiles, }); return ( @@ -89,8 +94,12 @@ export const FileInput = (props: FileInputProps): JSX.Element => {
or drag it here
- {currentFile && ( - {currentFile.name} + {currentFiles.length > 0 && !props.hideFileName && ( +
+ {currentFiles.map((file, index) => ( + {file.name} + ))} +
)}
@@ -101,6 +110,7 @@ export const FileInput = (props: FileInputProps): JSX.Element => { className={styles.file_input} onChange={handleInputChange} style={{ display: "none" }} + multiple={props.handleMultipleFiles} /> {errorMessage !== "" && ( {errorMessage} diff --git a/frontend/lib/components/ui/Icon/Icon.module.scss b/frontend/lib/components/ui/Icon/Icon.module.scss index c0fccd57b20a..677faabfe881 100644 --- a/frontend/lib/components/ui/Icon/Icon.module.scss +++ b/frontend/lib/components/ui/Icon/Icon.module.scss @@ -44,7 +44,11 @@ } .dark-grey { - color: var(--icon-2); + color: var(--icon-4); + + &.hovered { + color: var(--primary-0); + } } .grey { diff --git a/frontend/lib/components/ui/Icon/Icon.tsx b/frontend/lib/components/ui/Icon/Icon.tsx index 077b4822ad4a..802c4926c733 100644 --- a/frontend/lib/components/ui/Icon/Icon.tsx +++ b/frontend/lib/components/ui/Icon/Icon.tsx @@ -10,7 +10,8 @@ import styles from "./Icon.module.scss"; interface IconProps { name: keyof typeof iconList; size: IconSize; - color: Color; + color?: Color; + customColor?: string; disabled?: boolean; classname?: string; hovered?: boolean; @@ -22,8 +23,9 @@ export const Icon = ({ name, size, color, - disabled, + customColor, classname, + disabled, hovered, handleHover, onClick, @@ -56,10 +58,11 @@ export const Icon = ({ className={` ${classname} ${styles[size]} - ${styles[color]} + ${!customColor && color ? styles[color] : ""} ${disabled ? styles.disabled : ""} ${iconHovered || hovered ? styles.hovered : ""} `} + style={{ color: customColor }} onMouseEnter={handleMouseEnter} onMouseLeave={handleMouseLeave} onClick={onClick} diff --git a/frontend/lib/components/ui/QuivrButton/QuivrButton.tsx b/frontend/lib/components/ui/QuivrButton/QuivrButton.tsx index c1e1fd9f6f01..16320a49198a 100644 --- a/frontend/lib/components/ui/QuivrButton/QuivrButton.tsx +++ b/frontend/lib/components/ui/QuivrButton/QuivrButton.tsx @@ -77,7 +77,7 @@ export const QuivrButton = ({
); - return disabled ? ( + return disabled && tooltip ? ( {ButtonContent} ) : ( ButtonContent diff --git a/frontend/lib/components/ui/TextInput/TextInput.module.scss b/frontend/lib/components/ui/TextInput/TextInput.module.scss index 3ce7ba9ae8fc..2a0f99413ea1 100644 --- a/frontend/lib/components/ui/TextInput/TextInput.module.scss +++ b/frontend/lib/components/ui/TextInput/TextInput.module.scss @@ -52,4 +52,9 @@ box-shadow: none; } } + + .warning { + font-size: Typography.$tiny; + color: var(--dangerous); + } } diff --git a/frontend/lib/components/ui/TextInput/TextInput.tsx b/frontend/lib/components/ui/TextInput/TextInput.tsx index e0527b847733..112ac6a884a2 100644 --- a/frontend/lib/components/ui/TextInput/TextInput.tsx +++ b/frontend/lib/components/ui/TextInput/TextInput.tsx @@ -1,3 +1,5 @@ +import { useState } from "react"; + import styles from "./TextInput.module.scss"; import { Icon } from "../Icon/Icon"; @@ -13,6 +15,7 @@ type TextInputProps = { crypted?: boolean; onKeyDown?: (event: React.KeyboardEvent) => void; small?: boolean; + url?: boolean; }; export const TextInput = ({ @@ -26,7 +29,51 @@ export const TextInput = ({ crypted, onKeyDown, small, + url, }: TextInputProps): JSX.Element => { + const [warning, setWarning] = useState(""); + + const isValidUrl = (value: string): boolean => { + try { + new URL(value); + + return true; + } catch (_) { + return false; + } + }; + + const handleSubmit = () => { + if (!url || isValidUrl(inputValue)) { + setWarning(""); + onSubmit?.(); + } else { + setWarning("Please enter a valid URL."); + } + }; + + const handleInputChange = (e: React.ChangeEvent) => { + setInputValue?.(e.target.value); + }; + + const handleKeyDown = (e: React.KeyboardEvent) => { + if (e.key === "Enter") { + handleSubmit(); + } + onKeyDown?.(e); + }; + + const getIconColor = () => { + if (!inputValue) { + return "grey"; + } + if (url) { + return isValidUrl(inputValue) ? "accent" : "grey"; + } + + return "accent"; + }; + return (
setInputValue?.(e.target.value)} + onChange={handleInputChange} placeholder={label} - onKeyDown={(e) => { - if (e.key === "Enter" && onSubmit) { - onSubmit(); - } - onKeyDown?.(e); - }} + onKeyDown={handleKeyDown} /> + {warning && !isValidUrl(inputValue) && ( +
{warning}
+ )} {!simple && iconName && ( )}
diff --git a/frontend/lib/components/ui/Tooltip/Tooltip.tsx b/frontend/lib/components/ui/Tooltip/Tooltip.tsx index bbba703b739e..f2fb1e9d85d3 100644 --- a/frontend/lib/components/ui/Tooltip/Tooltip.tsx +++ b/frontend/lib/components/ui/Tooltip/Tooltip.tsx @@ -1,7 +1,7 @@ "use client"; import * as TooltipPrimitive from "@radix-ui/react-tooltip"; import { AnimatePresence, motion } from "framer-motion"; -import { ReactNode, useState } from "react"; +import { ReactNode, useEffect, useState } from "react"; import styles from "./Tooltip.module.scss"; @@ -10,6 +10,8 @@ interface TooltipProps { tooltip?: ReactNode; small?: boolean; type?: "default" | "dangerous" | "success"; + delayDuration?: number; + open?: boolean; // Optional boolean prop } const Tooltip = ({ @@ -17,15 +19,23 @@ const Tooltip = ({ tooltip, small, type, + delayDuration = 0, + open: controlledOpen, // Renamed to avoid conflict with state variable }: TooltipProps): JSX.Element => { const [open, setOpen] = useState(false); + useEffect(() => { + if (controlledOpen !== undefined) { + setOpen(controlledOpen); + } + }, [controlledOpen]); + return ( {children} diff --git a/frontend/lib/context/BrainProvider/types.ts b/frontend/lib/context/BrainProvider/types.ts index 1a450818f523..6955674a3985 100644 --- a/frontend/lib/context/BrainProvider/types.ts +++ b/frontend/lib/context/BrainProvider/types.ts @@ -40,6 +40,7 @@ export type Brain = { display_name?: string; snippet_emoji?: string; snippet_color?: string; + brain_id?: UUID; }; export type MinimalBrainForUser = { diff --git a/frontend/lib/context/KnowledgeToFeedProvider/hooks/useKnowledgeToFeedContext.tsx b/frontend/lib/context/KnowledgeToFeedProvider/hooks/useKnowledgeToFeedContext.tsx deleted file mode 100644 index f29ce7389eff..000000000000 --- a/frontend/lib/context/KnowledgeToFeedProvider/hooks/useKnowledgeToFeedContext.tsx +++ /dev/null @@ -1,41 +0,0 @@ -import { useContext } from "react"; - -import { FeedItemType } from 
"@/app/chat/[chatId]/components/ActionsBar/types"; - -import { KnowledgeToFeedContext } from "../knowledgeToFeed-provider"; - -// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types -export const useKnowledgeToFeedContext = () => { - const context = useContext(KnowledgeToFeedContext); - - const addKnowledgeToFeed = (knowledge: FeedItemType) => { - context?.setKnowledgeToFeed((prevKnowledge) => [ - ...prevKnowledge, - knowledge, - ]); - }; - - const removeKnowledgeToFeed = (index: number) => { - context?.setKnowledgeToFeed((prevKnowledge) => { - const newKnowledge = [...prevKnowledge]; - newKnowledge.splice(index, 1); - - return newKnowledge; - }); - }; - - const removeAllKnowledgeToFeed = () => { - context?.setKnowledgeToFeed([]); - }; - - if (context === undefined) { - throw new Error("useKnowledge must be used inside KnowledgeToFeedProvider"); - } - - return { - ...context, - addKnowledgeToFeed, - removeKnowledgeToFeed, - removeAllKnowledgeToFeed, - }; -}; diff --git a/frontend/lib/context/KnowledgeToFeedProvider/index.ts b/frontend/lib/context/KnowledgeToFeedProvider/index.ts deleted file mode 100644 index 7877648f02f7..000000000000 --- a/frontend/lib/context/KnowledgeToFeedProvider/index.ts +++ /dev/null @@ -1 +0,0 @@ -export * from "./knowledgeToFeed-provider"; diff --git a/frontend/lib/context/KnowledgeToFeedProvider/knowledgeToFeed-provider.tsx b/frontend/lib/context/KnowledgeToFeedProvider/knowledgeToFeed-provider.tsx deleted file mode 100644 index cb06437f7c24..000000000000 --- a/frontend/lib/context/KnowledgeToFeedProvider/knowledgeToFeed-provider.tsx +++ /dev/null @@ -1,38 +0,0 @@ -"use client"; - -import { createContext, useState } from "react"; - -import { FeedItemType } from "@/app/chat/[chatId]/components/ActionsBar/types"; - -type KnowledgeToFeedContextType = { - knowledgeToFeed: FeedItemType[]; - setKnowledgeToFeed: React.Dispatch>; - shouldDisplayFeedCard: boolean; - setShouldDisplayFeedCard: React.Dispatch>; -}; - -export 
const KnowledgeToFeedContext = createContext< - KnowledgeToFeedContextType | undefined ->(undefined); - -export const KnowledgeToFeedProvider = ({ - children, -}: { - children: React.ReactNode; -}): JSX.Element => { - const [knowledgeToFeed, setKnowledgeToFeed] = useState([]); - const [shouldDisplayFeedCard, setShouldDisplayFeedCard] = useState(false); - - return ( - - {children} - - ); -}; diff --git a/frontend/lib/context/index.tsx b/frontend/lib/context/index.tsx index e20130cc0fa8..c883a13e4b5e 100644 --- a/frontend/lib/context/index.tsx +++ b/frontend/lib/context/index.tsx @@ -1,3 +1,2 @@ export * from "./BrainProvider"; export * from "./ChatProvider"; -export * from "./KnowledgeToFeedProvider"; diff --git a/frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/utils/formatMinimalBrainsToSelectComponentInput.ts b/frontend/lib/helpers/formatBrains.ts similarity index 100% rename from frontend/app/chat/[chatId]/components/ActionsBar/components/KnowledgeToFeed/utils/formatMinimalBrainsToSelectComponentInput.ts rename to frontend/lib/helpers/formatBrains.ts diff --git a/frontend/lib/helpers/handleConnectionButtons.ts b/frontend/lib/helpers/handleConnectionButtons.ts deleted file mode 100644 index ada5ccac882c..000000000000 --- a/frontend/lib/helpers/handleConnectionButtons.ts +++ /dev/null @@ -1,182 +0,0 @@ -import { SetStateAction } from "react"; - -import { OpenedConnection } from "../api/sync/types"; - -const isRemoveAll = ( - matchingOpenedConnection: OpenedConnection, - currentConnection: OpenedConnection | undefined -): boolean => { - return !!( - currentConnection?.submitted && - matchingOpenedConnection.selectedFiles.files.length === 0 && - !currentConnection.cleaned - ); -}; - -const arraysAreEqual = (arr1: string[], arr2: string[]): boolean => { - if (arr1.length !== arr2.length) { - return false; - } - - for (let i = 0; i < arr1.length; i++) { - if (arr1[i] !== arr2[i]) { - return false; - } - } - - return true; -}; - -export const 
handleGetButtonProps = ( - currentConnection: OpenedConnection | undefined, - openedConnections: OpenedConnection[], - setOpenedConnections: React.Dispatch< - React.SetStateAction - >, - currentSyncId: number | undefined, - setCurrentSyncId: React.Dispatch> -): { - label: string; - type: "dangerous" | "primary"; - disabled: boolean; - callback: () => void; -} => { - const matchingOpenedConnection = - currentConnection && - openedConnections.find( - (conn) => conn.user_sync_id === currentConnection.user_sync_id - ); - - if (matchingOpenedConnection) { - if (isRemoveAll(matchingOpenedConnection, currentConnection)) { - return { - label: "Remove All", - type: "dangerous", - disabled: false, - callback: () => - removeConnection( - setOpenedConnections, - currentSyncId, - setCurrentSyncId - ), - }; - } else if (currentConnection.submitted) { - const matchingSelectedFileIds = - matchingOpenedConnection.selectedFiles.files - .map((file) => file.id) - .sort(); - - const currentSelectedFileIds = currentConnection.selectedFiles.files - .map((file) => file.id) - .sort(); - - const isDisabled = arraysAreEqual( - matchingSelectedFileIds, - currentSelectedFileIds - ); - - return { - label: "Update added files", - type: "primary", - disabled: - !matchingOpenedConnection.selectedFiles.files.length || isDisabled, - callback: () => - addConnection(setOpenedConnections, currentSyncId, setCurrentSyncId), - }; - } - } - - return { - label: "Add specific files", - type: "primary", - disabled: !matchingOpenedConnection?.selectedFiles.files.length, - callback: () => - addConnection(setOpenedConnections, currentSyncId, setCurrentSyncId), - }; -}; - -const addConnection = ( - setOpenedConnections: React.Dispatch< - React.SetStateAction - >, - currentSyncId: number | undefined, - setCurrentSyncId: React.Dispatch> -): void => { - setOpenedConnections((prevConnections) => { - const connectionIndex = prevConnections.findIndex( - (connection) => connection.user_sync_id === currentSyncId - ); - - 
if (connectionIndex !== -1) { - const newConnections = [...prevConnections]; - newConnections[connectionIndex] = { - ...newConnections[connectionIndex], - submitted: true, - cleaned: false, - }; - - return newConnections; - } - - return prevConnections; - }); - - setCurrentSyncId(undefined); -}; - -const removeConnection = ( - setOpenedConnections: React.Dispatch< - React.SetStateAction - >, - currentSyncId: number | undefined, - setCurrentSyncId: React.Dispatch> -): void => { - setOpenedConnections((prevConnections) => - prevConnections - .filter((connection) => { - return ( - connection.user_sync_id === currentSyncId || !!connection.last_synced - ); - }) - .map((connection) => { - if ( - connection.user_sync_id === currentSyncId && - !!connection.last_synced - ) { - return { ...connection, cleaned: true }; - } else { - return connection; - } - }) - ); - - setCurrentSyncId(undefined); -}; - -export const createHandleGetButtonProps = - ( - currentConnection: OpenedConnection | undefined, - openedConnections: OpenedConnection[], - setOpenedConnections: { - (value: SetStateAction): void; - (value: SetStateAction): void; - }, - currentSyncId: number | undefined, - setCurrentSyncId: { - (value: SetStateAction): void; - (value: SetStateAction): void; - } - ) => - (): { - label: string; - type: "primary" | "dangerous"; - disabled: boolean; - callback: () => void; - } => - handleGetButtonProps( - currentConnection, - openedConnections, - setOpenedConnections, - currentSyncId, - setCurrentSyncId - ); diff --git a/frontend/lib/helpers/iconList.ts b/frontend/lib/helpers/iconList.ts index cdc389263f6e..73f1b2bb7ac2 100644 --- a/frontend/lib/helpers/iconList.ts +++ b/frontend/lib/helpers/iconList.ts @@ -2,18 +2,11 @@ import { AiOutlineLoading3Quarters } from "react-icons/ai"; import { BiCoin } from "react-icons/bi"; import { BsArrowRightShort, - BsFiletypeCsv, BsFiletypeDocx, BsFiletypeHtml, BsFiletypeMd, - BsFiletypeMp3, - BsFiletypeMp4, - BsFiletypePdf, - BsFiletypePptx, 
BsFiletypePy, BsFiletypeTxt, - BsFiletypeXls, - BsFiletypeXlsx, BsTextParagraph, } from "react-icons/bs"; import { CgSoftwareDownload } from "react-icons/cg"; @@ -25,11 +18,14 @@ import { FaDiscord, FaFile, FaFileAlt, + FaFileCsv, FaGithub, FaLinkedin, FaQuestionCircle, FaRegFileAlt, FaRegFileAudio, + FaRegFileExcel, + FaRegFilePdf, FaRegFolder, FaRegKeyboard, FaRegStar, @@ -40,9 +36,9 @@ import { FaTwitter, FaUnlock, } from "react-icons/fa"; -import { FaInfo } from "react-icons/fa6"; +import { FaInfo, FaRegFilePowerpoint, FaRegFileWord } from "react-icons/fa6"; import { FiUpload } from "react-icons/fi"; -import { GoLightBulb } from "react-icons/go"; +import { GoLightBulb, GoTable } from "react-icons/go"; import { HiBuildingOffice } from "react-icons/hi2"; import { IoIosAdd, @@ -56,11 +52,13 @@ import { import { IoArrowUpCircleOutline, IoBookOutline, + IoCameraOutline, IoChatbubbleEllipsesOutline, IoCloudDownloadOutline, IoFootsteps, IoHelp, IoHomeOutline, + IoMusicalNote, IoShareSocial, IoWarningOutline, } from "react-icons/io5"; @@ -68,6 +66,7 @@ import { LiaFileVideo, LiaRobotSolid } from "react-icons/lia"; import { IconType } from "react-icons/lib"; import { LuArrowLeftFromLine, + LuBook, LuBrain, LuBrainCircuit, LuChevronDown, @@ -77,12 +76,12 @@ import { LuExternalLink, LuGoal, LuPlusCircle, + LuPresentation, LuSearch, } from "react-icons/lu"; import { MdDarkMode, MdDashboardCustomize, - MdDeleteOutline, MdDynamicFeed, MdHistory, MdLink, @@ -101,6 +100,7 @@ import { import { PiOfficeChairFill } from "react-icons/pi"; import { RiDeleteBackLine, + RiDeleteBin6Line, RiHashtag, RiNotification2Line, } from "react-icons/ri"; @@ -128,11 +128,12 @@ export const iconList: { [name: string]: IconType } = { coin: BiCoin, color: MdOutlinePalette, copy: LuCopy, + csv: FaFileCsv, custom: MdDashboardCustomize, - delete: MdDeleteOutline, + delete: RiDeleteBin6Line, discord: FaDiscord, - doc: BsFiletypeDocx, - docx: BsFiletypeDocx, + doc: FaRegFileAlt, + docx: FaRegFileWord, 
download: IoCloudDownloadOutline, edit: MdOutlineModeEditOutline, email: MdOutlineMail, @@ -159,7 +160,9 @@ export const iconList: { [name: string]: IconType } = { info: FaInfo, invite: IoIosSend, ipynb: BsFiletypePy, + jpg: IoCameraOutline, key: MdOutlineVpnKey, + knowledge: LuBook, link: MdLink, linkedin: FaLinkedin, loader: AiOutlineLoading3Quarters, @@ -168,8 +171,8 @@ export const iconList: { [name: string]: IconType } = { markdown: BsFiletypeMd, md: BsFiletypeMd, moon: MdDarkMode, - mp3: BsFiletypeMp3, - mp4: BsFiletypeMp4, + mp3: IoMusicalNote, + mp4: IoMusicalNote, mpga: FaRegFileAudio, mpeg: LiaFileVideo, notifications: RiNotification2Line, @@ -177,7 +180,10 @@ export const iconList: { [name: string]: IconType } = { odt: BsFiletypeDocx, options: SlOptionsVertical, paragraph: BsTextParagraph, - pptx: BsFiletypePptx, + png: IoCameraOutline, + pdf: FaRegFilePdf, + ppt: LuPresentation, + pptx: FaRegFilePowerpoint, prompt: FaRegKeyboard, py: BsFiletypePy, question: FaQuestionCircle, @@ -199,6 +205,7 @@ export const iconList: { [name: string]: IconType } = { thumbsDown: FaRegThumbsDown, thumbsUp: FaRegThumbsUp, twitter: FaTwitter, + txt: BsFiletypeTxt, unfold: MdUnfoldMore, unlock: FaUnlock, unread: MdMarkEmailUnread, @@ -210,9 +217,6 @@ export const iconList: { [name: string]: IconType } = { wav: FaRegFileAudio, webm: LiaFileVideo, website: TbNetwork, - xls: BsFiletypeXls, - xlsx: BsFiletypeXlsx, - txt: BsFiletypeTxt, - csv: BsFiletypeCsv, - pdf: BsFiletypePdf, + xls: GoTable, + xlsx: FaRegFileExcel, }; diff --git a/frontend/lib/helpers/kms.ts b/frontend/lib/helpers/kms.ts new file mode 100644 index 000000000000..c4be94f52777 --- /dev/null +++ b/frontend/lib/helpers/kms.ts @@ -0,0 +1,51 @@ +import { UUID } from "crypto"; + +import { KMSElement } from "../api/sync/types"; + +interface HandleDropParams { + event: React.DragEvent; + targetElement: KMSElement | null; + patchKnowledge: ( + knowledgeId: UUID, + parent_id: UUID | null + ) => Promise; + 
setRefetchFolderMenu: (value: boolean) => void; + fetchQuivrFiles?: (parentId: UUID | null) => Promise; + currentFolder: KMSElement | undefined; +} + +export const handleDrop = async ({ + event, + targetElement, + patchKnowledge, + setRefetchFolderMenu, + fetchQuivrFiles, + currentFolder, +}: HandleDropParams): Promise => { + event.preventDefault(); + const draggedElement = JSON.parse( + event.dataTransfer.getData("application/json") + ) as KMSElement; + if (draggedElement.id !== targetElement?.id) { + try { + await patchKnowledge(draggedElement.id, targetElement?.id ?? null); + setRefetchFolderMenu(true); + if (fetchQuivrFiles) { + await fetchQuivrFiles(currentFolder?.id ?? null); + } else { + const customEvent = new CustomEvent("needToFetch", { + detail: { draggedElement, targetElement }, + }); + window.dispatchEvent(customEvent); + } + } catch (error) { + console.error("Failed to patch knowledge:", error); + } + } +}; + +export const handleDragOver = ( + event: React.DragEvent +): void => { + event.preventDefault(); +}; diff --git a/frontend/lib/helpers/providers.ts b/frontend/lib/helpers/providers.ts new file mode 100644 index 000000000000..1454d6a6da77 --- /dev/null +++ b/frontend/lib/helpers/providers.ts @@ -0,0 +1,12 @@ +export const transformConnectionLabel = (label: string): string => { + switch (label.toLowerCase()) { + case "google": + return "Google Drive"; + case "azure": + return "Sharepoint"; + case "dropbox": + return "Dropbox"; + default: + return label; + } +}; diff --git a/frontend/lib/hooks/useDropzone.ts b/frontend/lib/hooks/useDropzone.ts deleted file mode 100644 index b0f59db872b8..000000000000 --- a/frontend/lib/hooks/useDropzone.ts +++ /dev/null @@ -1,88 +0,0 @@ -import { FileRejection, useDropzone } from "react-dropzone"; -import { useTranslation } from "react-i18next"; - -import { FeedItemUploadType } from "@/app/chat/[chatId]/components/ActionsBar/types"; -import { useEventTracking } from "@/services/analytics/june/useEventTracking"; - 
-import { useOnboarding } from "./useOnboarding"; -import { useOnboardingTracker } from "./useOnboardingTracker"; -import { useToast } from "./useToast"; - -import { useBrainCreationContext } from "../components/AddBrainModal/brainCreation-provider"; -import { useKnowledgeToFeedContext } from "../context/KnowledgeToFeedProvider/hooks/useKnowledgeToFeedContext"; -import { acceptedFormats } from "../helpers/acceptedFormats"; -import { cloneFileWithSanitizedName } from "../helpers/cloneFileWithSanitizedName"; - -// eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types -export const useCustomDropzone = () => { - const { knowledgeToFeed, addKnowledgeToFeed, setShouldDisplayFeedCard } = - useKnowledgeToFeedContext(); - const { isBrainCreationModalOpened } = useBrainCreationContext(); - const { isOnboarding } = useOnboarding(); - const { trackOnboardingEvent } = useOnboardingTracker(); - const files: File[] = ( - knowledgeToFeed.filter((c) => c.source === "upload") as FeedItemUploadType[] - ).map((c) => c.file); - - const { publish } = useToast(); - const { track } = useEventTracking(); - - const { t } = useTranslation(["upload"]); - - const onDrop = (acceptedFiles: File[], fileRejections: FileRejection[]) => { - if (!isBrainCreationModalOpened) { - setShouldDisplayFeedCard(true); - } - if (fileRejections.length > 0) { - const firstRejection = fileRejections[0]; - - if (firstRejection.errors[0].code === "file-invalid-type") { - publish({ variant: "danger", text: t("invalidFileType") }); - } else { - publish({ - variant: "danger", - text: t("maxSizeError", { ns: "upload" }), - }); - } - - return; - } - - for (const file of acceptedFiles) { - const isAlreadyInFiles = - files.filter((f) => f.name === file.name && f.size === file.size) - .length > 0; - if (isAlreadyInFiles) { - publish({ - variant: "warning", - text: t("alreadyAdded", { fileName: file.name, ns: "upload" }), - }); - } else { - if (isOnboarding) { - void 
trackOnboardingEvent("FILE_UPLOADED"); - } else { - void track("FILE_UPLOADED"); - } - - addKnowledgeToFeed({ - source: "upload", - file: cloneFileWithSanitizedName(file), - }); - } - } - }; - - const { getInputProps, getRootProps, isDragActive, open } = useDropzone({ - onDrop, - noClick: true, - maxSize: 100000000, // 1 MB - accept: acceptedFormats, - }); - - return { - getInputProps, - getRootProps, - isDragActive, - open, - }; -}; diff --git a/frontend/lib/hooks/useKnowledgeToFeed.ts b/frontend/lib/hooks/useKnowledgeToFeed.ts deleted file mode 100644 index a5e0c6eb717c..000000000000 --- a/frontend/lib/hooks/useKnowledgeToFeed.ts +++ /dev/null @@ -1,26 +0,0 @@ -import { - FeedItemCrawlType, - FeedItemUploadType, -} from "@/app/chat/[chatId]/components/ActionsBar/types"; -import { useKnowledgeToFeedContext } from "@/lib/context/KnowledgeToFeedProvider/hooks/useKnowledgeToFeedContext"; - -type UseKnowledgeToFeed = { - files: File[]; - urls: string[]; -}; -export const useKnowledgeToFeedFilesAndUrls = (): UseKnowledgeToFeed => { - const { knowledgeToFeed } = useKnowledgeToFeedContext(); - - const files: File[] = ( - knowledgeToFeed.filter((c) => c.source === "upload") as FeedItemUploadType[] - ).map((c) => c.file); - - const urls: string[] = ( - knowledgeToFeed.filter((c) => c.source === "crawl") as FeedItemCrawlType[] - ).map((c) => c.url); - - return { - files, - urls, - }; -}; diff --git a/frontend/lib/types/Knowledge.ts b/frontend/lib/types/Knowledge.ts index 43a85d4f7326..9772413fca2b 100644 --- a/frontend/lib/types/Knowledge.ts +++ b/frontend/lib/types/Knowledge.ts @@ -1,27 +1,19 @@ import { UUID } from "crypto"; -export type Knowledge = UploadedKnowledge | CrawledKnowledge; +export interface AddFolderData { + parent_id: UUID | null; + file_name: string; + is_folder: boolean; +} -export interface UploadedKnowledge { - id: UUID; - brainId: UUID; - fileName: string; - extension: string; - status: "UPLOADED" | "PROCESSING" | "ERROR"; - integration: string; - 
integration_link: string; +export interface AddKnowledgeFileData { + parent_id: UUID | null; + file_name: string; + is_folder: boolean; } -export interface CrawledKnowledge { - id: UUID; - brainId: UUID; +export interface AddKnowledgeUrlData { + parent_id: UUID | null; + is_folder: boolean; url: string; - extension: string; - status: "UPLOADED" | "PROCESSING" | "ERROR"; } - -export const isUploadedKnowledge = ( - knowledge: Knowledge -): knowledge is UploadedKnowledge => { - return "fileName" in knowledge && !("url" in knowledge); -}; diff --git a/frontend/lib/types/Modal.ts b/frontend/lib/types/Modal.ts deleted file mode 100644 index a15358fcdfb0..000000000000 --- a/frontend/lib/types/Modal.ts +++ /dev/null @@ -1,6 +0,0 @@ -import { StepValue } from "../components/AddBrainModal/types/types"; - -export type Step = { - label: string; - value: StepValue; -}; diff --git a/frontend/styles/_Variables.module.scss b/frontend/styles/_Variables.module.scss index 075cc8c8ed76..f028988a5324 100644 --- a/frontend/styles/_Variables.module.scss +++ b/frontend/styles/_Variables.module.scss @@ -4,3 +4,4 @@ $menuWidth: 230px; $brainButtonHeight: 105px; $menuSectionWidth: 175px; $assistantInputWidth: 300px; +$fileInputHeight: 200px; diff --git a/frontend/yarn.lock b/frontend/yarn.lock index b993d6c9f699..610f5180e9d4 100644 --- a/frontend/yarn.lock +++ b/frontend/yarn.lock @@ -38,7 +38,7 @@ resolved "https://registry.npmjs.org/@babel/compat-data/-/compat-data-7.22.9.tgz" integrity sha512-5UamI7xkUcJ3i9qVDS+KFDEK8/7oJ55/sJMB1Ge7IEapr7KfdfV/HErR+koZwOfd+SgtFKOKRhRakdg++DcJpQ== -"@babel/core@^7.22.9": +"@babel/core@^7.0.0", "@babel/core@^7.0.0-0", "@babel/core@^7.22.9": version "7.22.9" resolved "https://registry.npmjs.org/@babel/core/-/core-7.22.9.tgz" integrity sha512-G2EgeufBcYw27U4hhoIwFcgc1XU7TlXJ3mv04oOv1WCuo900U/anZSPzEqNjwdjgffkk2Gs0AN0dW1CKVLcG7w== @@ -282,17 +282,17 @@ dependencies: "@emotion/memoize" "^0.8.1" -"@emotion/memoize@0.7.4": - version "0.7.4" - resolved 
"https://registry.npmjs.org/@emotion/memoize/-/memoize-0.7.4.tgz" - integrity sha512-Ja/Vfqe3HpuzRsG1oBtWTHk2PGZ7GR+2Vz5iYGelAw8dx32K0y7PjVuxK6z1nMpZOqAFsRUPCkK1YjJ56qJlgw== - "@emotion/memoize@^0.8.1": version "0.8.1" resolved "https://registry.npmjs.org/@emotion/memoize/-/memoize-0.8.1.tgz" integrity sha512-W2P2c/VRW1/1tLox0mVUalvnWXxavmv/Oum2aPsRcoDJuob75FC3Y8FbpfLwUegRcxINtGUMPq0tFCvYNTBXNA== -"@emotion/react@11.11.1": +"@emotion/memoize@0.7.4": + version "0.7.4" + resolved "https://registry.npmjs.org/@emotion/memoize/-/memoize-0.7.4.tgz" + integrity sha512-Ja/Vfqe3HpuzRsG1oBtWTHk2PGZ7GR+2Vz5iYGelAw8dx32K0y7PjVuxK6z1nMpZOqAFsRUPCkK1YjJ56qJlgw== + +"@emotion/react@^11.0.0-rc.0", "@emotion/react@11.11.1": version "11.11.1" resolved "https://registry.npmjs.org/@emotion/react/-/react-11.11.1.tgz" integrity sha512-5mlW1DquU5HaxjLkfkGN1GA/fvVGdyHURRiX/0FHl2cfIfRxSOfmxEH5YS43edp0OldZrZ+dkBKbngxcNCdZvA== @@ -354,116 +354,11 @@ resolved "https://registry.npmjs.org/@emotion/weak-memoize/-/weak-memoize-0.3.1.tgz" integrity sha512-EsBwpc7hBUJWAsNPBmJy4hxWx12v6bshQsldrVmjxJoc3isbxhOrF2IcCpaXxfvq03NwkI7sbsOLXbYuqF/8Ww== -"@esbuild/android-arm64@0.18.18": - version "0.18.18" - resolved "https://registry.yarnpkg.com/@esbuild/android-arm64/-/android-arm64-0.18.18.tgz#a52e0a1276065b1bf6b2de45b482cf36b6b945bd" - integrity sha512-dkAPYzRHq3dNXIzOyAknYOzsx8o3KWaNiuu56B2rP9IFPmFWMS58WQcTlUQi6iloku8ZyHHMluCe5sTWhKq/Yw== - -"@esbuild/android-arm@0.18.18": - version "0.18.18" - resolved "https://registry.yarnpkg.com/@esbuild/android-arm/-/android-arm-0.18.18.tgz#ffd591b956ced1c96e1224edfbed1001adadf2ae" - integrity sha512-oBymf7ZwplAawSxmiSlBCf+FMcY0f4bs5QP2jn43JKUf0M9DnrUTjqa5RvFPl1elw+sMfcpfBRPK+rb+E1q7zg== - -"@esbuild/android-x64@0.18.18": - version "0.18.18" - resolved "https://registry.yarnpkg.com/@esbuild/android-x64/-/android-x64-0.18.18.tgz#6e8a7b41fc80265849e0a1de928fe162b27990c7" - integrity 
sha512-r7/pVcrUQMYkjvtE/1/n6BxhWM+/9tvLxDG1ev1ce4z3YsqoxMK9bbOM6bFcj0BowMeGQvOZWcBV182lFFKmrw== - "@esbuild/darwin-arm64@0.18.18": version "0.18.18" resolved "https://registry.npmjs.org/@esbuild/darwin-arm64/-/darwin-arm64-0.18.18.tgz" integrity sha512-MSe2iV9MAH3wfP0g+vzN9bp36rtPPuCSk+bT5E2vv/d8krvW5uB/Pi/Q5+txUZuxsG3GcO8dhygjnFq0wJU9hQ== -"@esbuild/darwin-x64@0.18.18": - version "0.18.18" - resolved "https://registry.yarnpkg.com/@esbuild/darwin-x64/-/darwin-x64-0.18.18.tgz#8aa691d0cbd3fb67f9f9083375c0c72e0463b8b2" - integrity sha512-ARFYISOWkaifjcr48YtO70gcDNeOf1H2RnmOj6ip3xHIj66f3dAbhcd5Nph5np6oHI7DhHIcr9MWO18RvUL1bw== - -"@esbuild/freebsd-arm64@0.18.18": - version "0.18.18" - resolved "https://registry.yarnpkg.com/@esbuild/freebsd-arm64/-/freebsd-arm64-0.18.18.tgz#0aafde382df508d7863360950d5f491c07024806" - integrity sha512-BHnXmexzEWRU2ZySJosU0Ts0NRnJnNrMB6t4EiIaOSel73I8iLsNiTPLH0rJulAh19cYZutsB5XHK6N8fi5eMg== - -"@esbuild/freebsd-x64@0.18.18": - version "0.18.18" - resolved "https://registry.yarnpkg.com/@esbuild/freebsd-x64/-/freebsd-x64-0.18.18.tgz#f00e54a3b65824ac3c749173bec9cd56d95fe73b" - integrity sha512-n823w35wm0ZOobbuE//0sJjuz1Qj619+AwjgOcAJMN2pomZhH9BONCtn+KlfrmM/NWZ+27yB/eGVFzUIWLeh3w== - -"@esbuild/linux-arm64@0.18.18": - version "0.18.18" - resolved "https://registry.yarnpkg.com/@esbuild/linux-arm64/-/linux-arm64-0.18.18.tgz#e04203429670257126a1bfee79bbd56448b24f5e" - integrity sha512-zANxnwF0sCinDcAqoMohGoWBK9QaFJ65Vgh0ZE+RXtURaMwx+RfmfLElqtnn7X8OYNckMoIXSg7u+tZ3tqTlrA== - -"@esbuild/linux-arm@0.18.18": - version "0.18.18" - resolved "https://registry.yarnpkg.com/@esbuild/linux-arm/-/linux-arm-0.18.18.tgz#863236dc47df2269f860001ca5c5ff50931e9933" - integrity sha512-Kck3jxPLQU4VeAGwe8Q4NU+IWIx+suULYOFUI9T0C2J1+UQlOHJ08ITN+MaJJ+2youzJOmKmcphH/t3SJxQ1Tw== - -"@esbuild/linux-ia32@0.18.18": - version "0.18.18" - resolved "https://registry.yarnpkg.com/@esbuild/linux-ia32/-/linux-ia32-0.18.18.tgz#9ef6c7eeb8c86c5c1b7234a9684c6f45cbc2ed57" - integrity 
sha512-+VHz2sIRlY5u8IlaLJpdf5TL2kM76yx186pW7bpTB+vLWpzcFQVP04L842ZB2Ty13A1VXUvy3DbU1jV65P2skg== - -"@esbuild/linux-loong64@0.18.18": - version "0.18.18" - resolved "https://registry.yarnpkg.com/@esbuild/linux-loong64/-/linux-loong64-0.18.18.tgz#dca8624674924ac92c9e56399af160479283f130" - integrity sha512-fXPEPdeGBvguo/1+Na8OIWz3667BN1cwbGtTEZWTd0qdyTsk5gGf9jVX8MblElbDb/Cpw6y5JiaQuL96YmvBwQ== - -"@esbuild/linux-mips64el@0.18.18": - version "0.18.18" - resolved "https://registry.yarnpkg.com/@esbuild/linux-mips64el/-/linux-mips64el-0.18.18.tgz#e6525b60ae9d8c3bdc652a773e6ebf66caa3fdd3" - integrity sha512-dLvRB87pIBIRnEIC32LIcgwK1JzlIuADIRjLKdUIpxauKwMuS/xMpN+cFl+0nN4RHNYOZ57DmXFFmQAcdlFOmw== - -"@esbuild/linux-ppc64@0.18.18": - version "0.18.18" - resolved "https://registry.yarnpkg.com/@esbuild/linux-ppc64/-/linux-ppc64-0.18.18.tgz#2ea6a4e0c6b0db21770d2c3c1525623dceadfe46" - integrity sha512-fRChqIJZ7hLkXSKfBLYgsX9Ssb5OGCjk3dzCETF5QSS1qjTgayLv0ALUdJDB9QOh/nbWwp+qfLZU6md4XcjL7w== - -"@esbuild/linux-riscv64@0.18.18": - version "0.18.18" - resolved "https://registry.yarnpkg.com/@esbuild/linux-riscv64/-/linux-riscv64-0.18.18.tgz#296c25d5bdeb3bab9ca79ad5279a8cc0a42fbeea" - integrity sha512-ALK/BT3u7Hoa/vHjow6W6+MKF0ohYcVcVA1EpskI4bkBPVuDLrUDqt2YFifg5UcZc8qup0CwQqWmFUd6VMNgaA== - -"@esbuild/linux-s390x@0.18.18": - version "0.18.18" - resolved "https://registry.yarnpkg.com/@esbuild/linux-s390x/-/linux-s390x-0.18.18.tgz#bec4e9c982e778c51deaa754e1ed3f0546705647" - integrity sha512-crT7jtOXd9iirY65B+mJQ6W0HWdNy8dtkZqKGWNcBnunpLcTCfne5y5bKic9bhyYzKpQEsO+C/VBPD8iF0RhRw== - -"@esbuild/linux-x64@0.18.18": - version "0.18.18" - resolved "https://registry.yarnpkg.com/@esbuild/linux-x64/-/linux-x64-0.18.18.tgz#22c9666920d3b7ef453289516ccff1c3ecbfdddd" - integrity sha512-/NSgghjBOW9ELqjXDYxOCCIsvQUZpvua1/6NdnA9Vnrp9UzEydyDdFXljUjMMS9p5KxMzbMO9frjHYGVHBfCHg== - -"@esbuild/netbsd-x64@0.18.18": - version "0.18.18" - resolved 
"https://registry.yarnpkg.com/@esbuild/netbsd-x64/-/netbsd-x64-0.18.18.tgz#99b6125868c5ba8f0131bacc3f2bd05918245f45" - integrity sha512-8Otf05Vx5sZjLLDulgr5QS5lsWXMplKZEyHMArH9/S4olLlhzmdhQBPhzhJTNwaL2FJNdWcUPNGAcoD5zDTfUA== - -"@esbuild/openbsd-x64@0.18.18": - version "0.18.18" - resolved "https://registry.yarnpkg.com/@esbuild/openbsd-x64/-/openbsd-x64-0.18.18.tgz#c2685bdd1e5aa11be1e212db371f474812a9b158" - integrity sha512-tFiFF4kT5L5qhVrWJUNxEXWvvX8nK/UX9ZrB7apuTwY3f6+Xy4aFMBPwAVrBYtBd5MOUuyOVHK6HBZCAHkwUlw== - -"@esbuild/sunos-x64@0.18.18": - version "0.18.18" - resolved "https://registry.yarnpkg.com/@esbuild/sunos-x64/-/sunos-x64-0.18.18.tgz#277b2f5727119fe3004e673eb9f6ead0b4ff0738" - integrity sha512-MPogVV8Bzh8os4OM+YDGGsSzCzmNRiyKGtHoJyZLtI4BMmd6EcxmGlcEGK1uM46h1BiOyi7Z7teUtzzQhvkC+w== - -"@esbuild/win32-arm64@0.18.18": - version "0.18.18" - resolved "https://registry.yarnpkg.com/@esbuild/win32-arm64/-/win32-arm64-0.18.18.tgz#e94d9e6d058e0ccb92d858badd4a6aa74772150e" - integrity sha512-YKD6LF/XXY9REu+ZL5RAsusiG48n602qxsMVh/E8FFD9hp4OyTQaL9fpE1ovxwQXqFio+tT0ITUGjDSSSPN13w== - -"@esbuild/win32-ia32@0.18.18": - version "0.18.18" - resolved "https://registry.yarnpkg.com/@esbuild/win32-ia32/-/win32-ia32-0.18.18.tgz#454916b1d0b85d2f82252192ae7bd5ea65c98ea1" - integrity sha512-NjSBmBsyZBTsZB6ga6rA6PfG/RHnwruUz/9YEVXcm4STGauFWvhYhOMhEyw1yU5NVgYYm8CH5AltCm77TS21/Q== - -"@esbuild/win32-x64@0.18.18": - version "0.18.18" - resolved "https://registry.yarnpkg.com/@esbuild/win32-x64/-/win32-x64-0.18.18.tgz#914c007ab1dbd28ca84e79ee666adeee6ccf92b4" - integrity sha512-eTSg/gC3p3tdjj4roDhe5xu94l1s2jMazP8u2FsYO8SEKvSpPOO71EucprDn/IuErDPvTFUhV9lTw5z5WJCRKQ== - "@eslint-community/eslint-utils@^4.2.0": version "4.4.0" resolved "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.4.0.tgz" @@ -606,16 +501,16 @@ resolved "https://registry.npmjs.org/@jridgewell/set-array/-/set-array-1.1.2.tgz" integrity 
sha512-xnkseuNADM0gt2bs+BvhO0p78Mk762YnZdsuzFV018NoG1Sj1SCQvpSqa7XUaTam5vAGasABV9qXASMKnFMwMw== -"@jridgewell/sourcemap-codec@1.4.14": - version "1.4.14" - resolved "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.4.14.tgz" - integrity sha512-XPSJHWmi394fuUuzDnGz1wiKqWfo1yXecHQMRf2l6hztTO+nPru658AyDngaBe7isIxEkRsPR3FZh+s7iVa4Uw== - "@jridgewell/sourcemap-codec@^1.4.10", "@jridgewell/sourcemap-codec@^1.4.13", "@jridgewell/sourcemap-codec@^1.4.15": version "1.4.15" resolved "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.4.15.tgz" integrity sha512-eF2rxCRulEKXHTRiDrDy6erMYWqNw4LPdQ8UQA4huuxaQsVeRPFl2oM8oDGxMFhJUWZf9McpLtJasDDZb/Bpeg== +"@jridgewell/sourcemap-codec@1.4.14": + version "1.4.14" + resolved "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.4.14.tgz" + integrity sha512-XPSJHWmi394fuUuzDnGz1wiKqWfo1yXecHQMRf2l6hztTO+nPru658AyDngaBe7isIxEkRsPR3FZh+s7iVa4Uw== + "@jridgewell/trace-mapping@^0.3.17", "@jridgewell/trace-mapping@^0.3.9": version "0.3.18" resolved "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.18.tgz" @@ -659,16 +554,16 @@ dependencies: "@lukeed/csprng" "^1.1.0" -"@next/env@14.1.0": - version "14.1.0" - resolved "https://registry.npmjs.org/@next/env/-/env-14.1.0.tgz" - integrity sha512-Py8zIo+02ht82brwwhTg36iogzFqGLPXlRGKQw5s+qP/kMNc4MAyDeEwBKDijk6zTIbegEgu8Qy7C1LboslQAw== - "@next/env@^13.4.3": version "13.5.6" resolved "https://registry.npmjs.org/@next/env/-/env-13.5.6.tgz" integrity sha512-Yac/bV5sBGkkEXmAX5FWPS9Mmo2rthrOPRQQNfycJPkjUAUclomCPH7QFVCDQ4Mp2k2K1SSM6m0zrxYrOwtFQw== +"@next/env@14.1.0": + version "14.1.0" + resolved "https://registry.npmjs.org/@next/env/-/env-14.1.0.tgz" + integrity sha512-Py8zIo+02ht82brwwhTg36iogzFqGLPXlRGKQw5s+qP/kMNc4MAyDeEwBKDijk6zTIbegEgu8Qy7C1LboslQAw== + "@next/eslint-plugin-next@14.1.0": version "14.1.0" resolved "https://registry.npmjs.org/@next/eslint-plugin-next/-/eslint-plugin-next-14.1.0.tgz" @@ 
-681,46 +576,6 @@ resolved "https://registry.npmjs.org/@next/swc-darwin-arm64/-/swc-darwin-arm64-14.1.0.tgz" integrity sha512-nUDn7TOGcIeyQni6lZHfzNoo9S0euXnu0jhsbMOmMJUBfgsnESdjN97kM7cBqQxZa8L/bM9om/S5/1dzCrW6wQ== -"@next/swc-darwin-x64@14.1.0": - version "14.1.0" - resolved "https://registry.yarnpkg.com/@next/swc-darwin-x64/-/swc-darwin-x64-14.1.0.tgz#0863a22feae1540e83c249384b539069fef054e9" - integrity sha512-1jgudN5haWxiAl3O1ljUS2GfupPmcftu2RYJqZiMJmmbBT5M1XDffjUtRUzP4W3cBHsrvkfOFdQ71hAreNQP6g== - -"@next/swc-linux-arm64-gnu@14.1.0": - version "14.1.0" - resolved "https://registry.yarnpkg.com/@next/swc-linux-arm64-gnu/-/swc-linux-arm64-gnu-14.1.0.tgz#893da533d3fce4aec7116fe772d4f9b95232423c" - integrity sha512-RHo7Tcj+jllXUbK7xk2NyIDod3YcCPDZxj1WLIYxd709BQ7WuRYl3OWUNG+WUfqeQBds6kvZYlc42NJJTNi4tQ== - -"@next/swc-linux-arm64-musl@14.1.0": - version "14.1.0" - resolved "https://registry.yarnpkg.com/@next/swc-linux-arm64-musl/-/swc-linux-arm64-musl-14.1.0.tgz#d81ddcf95916310b8b0e4ad32b637406564244c0" - integrity sha512-v6kP8sHYxjO8RwHmWMJSq7VZP2nYCkRVQ0qolh2l6xroe9QjbgV8siTbduED4u0hlk0+tjS6/Tuy4n5XCp+l6g== - -"@next/swc-linux-x64-gnu@14.1.0": - version "14.1.0" - resolved "https://registry.yarnpkg.com/@next/swc-linux-x64-gnu/-/swc-linux-x64-gnu-14.1.0.tgz#18967f100ec19938354332dcb0268393cbacf581" - integrity sha512-zJ2pnoFYB1F4vmEVlb/eSe+VH679zT1VdXlZKX+pE66grOgjmKJHKacf82g/sWE4MQ4Rk2FMBCRnX+l6/TVYzQ== - -"@next/swc-linux-x64-musl@14.1.0": - version "14.1.0" - resolved "https://registry.yarnpkg.com/@next/swc-linux-x64-musl/-/swc-linux-x64-musl-14.1.0.tgz#77077cd4ba8dda8f349dc7ceb6230e68ee3293cf" - integrity sha512-rbaIYFt2X9YZBSbH/CwGAjbBG2/MrACCVu2X0+kSykHzHnYH5FjHxwXLkcoJ10cX0aWCEynpu+rP76x0914atg== - -"@next/swc-win32-arm64-msvc@14.1.0": - version "14.1.0" - resolved "https://registry.yarnpkg.com/@next/swc-win32-arm64-msvc/-/swc-win32-arm64-msvc-14.1.0.tgz#5f0b8cf955644104621e6d7cc923cad3a4c5365a" - integrity 
sha512-o1N5TsYc8f/HpGt39OUQpQ9AKIGApd3QLueu7hXk//2xq5Z9OxmV6sQfNp8C7qYmiOlHYODOGqNNa0e9jvchGQ== - -"@next/swc-win32-ia32-msvc@14.1.0": - version "14.1.0" - resolved "https://registry.yarnpkg.com/@next/swc-win32-ia32-msvc/-/swc-win32-ia32-msvc-14.1.0.tgz#21f4de1293ac5e5a168a412b139db5d3420a89d0" - integrity sha512-XXIuB1DBRCFwNO6EEzCTMHT5pauwaSj4SWs7CYnME57eaReAKBXCnkUE80p/pAZcewm7hs+vGvNqDPacEXHVkw== - -"@next/swc-win32-x64-msvc@14.1.0": - version "14.1.0" - resolved "https://registry.yarnpkg.com/@next/swc-win32-x64-msvc/-/swc-win32-x64-msvc-14.1.0.tgz#e561fb330466d41807123d932b365cf3d33ceba2" - integrity sha512-9WEbVRRAqJ3YFVqEZIxUqkiO8l1nool1LmNxygr5HWF8AcSYsEpneUDhmjUVJEzO2A04+oPtZdombzzPPkTtgg== - "@nodelib/fs.scandir@2.1.5": version "2.1.5" resolved "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz" @@ -729,7 +584,7 @@ "@nodelib/fs.stat" "2.0.5" run-parallel "^1.1.9" -"@nodelib/fs.stat@2.0.5", "@nodelib/fs.stat@^2.0.2": +"@nodelib/fs.stat@^2.0.2", "@nodelib/fs.stat@2.0.5": version "2.0.5" resolved "https://registry.npmjs.org/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz" integrity sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A== @@ -1291,7 +1146,7 @@ dependencies: "@segment/isodate" "^1.0.3" -"@segment/isodate@1.0.3", "@segment/isodate@^1.0.3": +"@segment/isodate@^1.0.3", "@segment/isodate@1.0.3": version "1.0.3" resolved "https://registry.npmjs.org/@segment/isodate/-/isodate-1.0.3.tgz" integrity sha512-BtanDuvJqnACFkeeYje7pWULVv8RgZaqKHWwGFnL/g/TH/CcZjkIVTfGDp/MAxmilYHUkrX70SqwnYSTNEaN7A== @@ -2382,7 +2237,7 @@ dependencies: cross-fetch "^3.1.5" -"@supabase/supabase-js@2.32.0": +"@supabase/supabase-js@^2.0.4", "@supabase/supabase-js@^2.21.0", "@supabase/supabase-js@2.32.0": version "2.32.0" resolved "https://registry.npmjs.org/@supabase/supabase-js/-/supabase-js-2.32.0.tgz" integrity sha512-1ShFhuOI5Du7604nlCelBsRD61daXk2O0qwjumoz35bqrYThnSPPtpJqZOHw6Mg6o7mLjIInYLh/DBlh8UvzRg== @@ -2430,7 +2285,7 @@ 
dependencies: "@tanstack/query-core" "5.4.3" -"@testing-library/dom@^9.0.0": +"@testing-library/dom@^9.0.0", "@testing-library/dom@>=7.21.4": version "9.3.4" resolved "https://registry.npmjs.org/@testing-library/dom/-/dom-9.3.4.tgz" integrity sha512-FlS4ZWlp97iiNWig0Muq8p+3rVDjRiYE+YKGbAqXOu9nwJFFOdL00kFpz42M+4huzYi86vAK1sOOfyOG45muIQ== @@ -2472,7 +2327,7 @@ resolved "https://registry.npmjs.org/@testing-library/user-event/-/user-event-14.5.1.tgz" integrity sha512-UCcUKrUYGj7ClomOo2SpNVvx4/fkd/2BbIHDCle8A0ax+P3bU7yJwDBDrS6ZwdTMARWTGODX1hEsCcO+7beJjg== -"@tiptap/core@2.1.12", "@tiptap/core@^2.1.12": +"@tiptap/core@^2.0.0", "@tiptap/core@^2.1.12", "@tiptap/core@2.1.12": version "2.1.12" resolved "https://registry.npmjs.org/@tiptap/core/-/core-2.1.12.tgz" integrity sha512-ZGc3xrBJA9KY8kln5AYTj8y+GDrKxi7u95xIl2eccrqTY5CQeRu6HRNM1yT4mAjuSaG9jmazyjGRlQuhyxCKxQ== @@ -2509,7 +2364,7 @@ resolved "https://registry.npmjs.org/@tiptap/extension-code/-/extension-code-2.1.12.tgz" integrity sha512-CRiRq5OTC1lFgSx6IMrECqmtb93a0ZZKujEnaRhzWliPBjLIi66va05f/P1vnV6/tHaC3yfXys6dxB5A4J8jxw== -"@tiptap/extension-document@2.1.12", "@tiptap/extension-document@^2.1.12": +"@tiptap/extension-document@^2.1.12", "@tiptap/extension-document@2.1.12": version "2.1.12" resolved "https://registry.npmjs.org/@tiptap/extension-document/-/extension-document-2.1.12.tgz" integrity sha512-0QNfAkCcFlB9O8cUNSwTSIQMV9TmoEhfEaLz/GvbjwEq4skXK3bU+OQX7Ih07waCDVXIGAZ7YAZogbvrn/WbOw== @@ -2531,7 +2386,7 @@ resolved "https://registry.npmjs.org/@tiptap/extension-gapcursor/-/extension-gapcursor-2.1.12.tgz" integrity sha512-zFYdZCqPgpwoB7whyuwpc8EYLYjUE5QYKb8vICvc+FraBUDM51ujYhFSgJC3rhs8EjI+8GcK8ShLbSMIn49YOQ== -"@tiptap/extension-hard-break@2.1.12", "@tiptap/extension-hard-break@^2.1.12": +"@tiptap/extension-hard-break@^2.1.12", "@tiptap/extension-hard-break@2.1.12": version "2.1.12" resolved "https://registry.npmjs.org/@tiptap/extension-hard-break/-/extension-hard-break-2.1.12.tgz" integrity 
sha512-nqKcAYGEOafg9D+2cy1E4gHNGuL12LerVa0eS2SQOb+PT8vSel9OTKU1RyZldsWSQJ5rq/w4uIjmLnrSR2w6Yw== @@ -2571,7 +2426,7 @@ resolved "https://registry.npmjs.org/@tiptap/extension-ordered-list/-/extension-ordered-list-2.1.12.tgz" integrity sha512-tF6VGl+D2avCgn9U/2YLJ8qVmV6sPE/iEzVAFZuOSe6L0Pj7SQw4K6AO640QBob/d8VrqqJFHCb6l10amJOnXA== -"@tiptap/extension-paragraph@2.1.12", "@tiptap/extension-paragraph@^2.1.12": +"@tiptap/extension-paragraph@^2.1.12", "@tiptap/extension-paragraph@2.1.12": version "2.1.12" resolved "https://registry.npmjs.org/@tiptap/extension-paragraph/-/extension-paragraph-2.1.12.tgz" integrity sha512-hoH/uWPX+KKnNAZagudlsrr4Xu57nusGekkJWBcrb5MCDE91BS+DN2xifuhwXiTHxnwOMVFjluc0bPzQbkArsw== @@ -2586,12 +2441,12 @@ resolved "https://registry.npmjs.org/@tiptap/extension-strike/-/extension-strike-2.1.12.tgz" integrity sha512-HlhrzIjYUT8oCH9nYzEL2QTTn8d1ECnVhKvzAe6x41xk31PjLMHTUy8aYjeQEkWZOWZ34tiTmslV1ce6R3Dt8g== -"@tiptap/extension-text@2.1.12", "@tiptap/extension-text@^2.1.12": +"@tiptap/extension-text@^2.1.12", "@tiptap/extension-text@2.1.12": version "2.1.12" resolved "https://registry.npmjs.org/@tiptap/extension-text/-/extension-text-2.1.12.tgz" integrity sha512-rCNUd505p/PXwU9Jgxo4ZJv4A3cIBAyAqlx/dtcY6cjztCQuXJhuQILPhjGhBTOLEEL4kW2wQtqzCmb7O8i2jg== -"@tiptap/pm@2.1.12": +"@tiptap/pm@^2.0.0", "@tiptap/pm@2.1.12": version "2.1.12" resolved "https://registry.npmjs.org/@tiptap/pm/-/pm-2.1.12.tgz" integrity sha512-Q3MXXQABG4CZBesSp82yV84uhJh/W0Gag6KPm2HRWPimSFELM09Z9/5WK9RItAYE0aLhe4Krnyiczn9AAa1tQQ== @@ -2648,7 +2503,7 @@ "@tiptap/extension-strike" "^2.1.12" "@tiptap/extension-text" "^2.1.12" -"@tiptap/suggestion@2.1.12": +"@tiptap/suggestion@^2.0.0", "@tiptap/suggestion@2.1.12": version "2.1.12" resolved "https://registry.npmjs.org/@tiptap/suggestion/-/suggestion-2.1.12.tgz" integrity sha512-rhlLWwVkOodBGRMK0mAmE34l2a+BqM2Y7q1ViuQRBhs/6sZ8d83O4hARHKVwqT5stY4i1l7d7PoemV3uAGI6+g== @@ -2801,7 +2656,7 @@ resolved 
"https://registry.npmjs.org/@types/ms/-/ms-0.7.34.tgz" integrity sha512-nG96G3Wp6acyAgJqGasjODb+acrI7KltPiRxzHPXnP3NgI28bpQDRv53olbqGXbfcgF5aiiHmO3xpwEpS5Ld9g== -"@types/node@*", "@types/node@20.10.6": +"@types/node@*", "@types/node@>= 14", "@types/node@20.10.6": version "20.10.6" resolved "https://registry.npmjs.org/@types/node/-/node-20.10.6.tgz" integrity sha512-Vac8H+NlRNNlAmDfGUP7b5h/KA+AtWIzuXy0E6OyP8f1tCLYAtPvKRRDJjAPqhpCb0t6U2j7/xqAuLEebW2kiw== @@ -2838,7 +2693,7 @@ resolved "https://registry.npmjs.org/@types/prop-types/-/prop-types-15.7.11.tgz" integrity sha512-ga8y9v9uyeiLdpKddhxYQkxNDrfvuPrlFb0N1qnZZByvcElJaXthF1UhvCh9TLWJBEHeNtdnbysW7Y6Uq8CVng== -"@types/react-dom@18.2.7", "@types/react-dom@^18.0.0": +"@types/react-dom@*", "@types/react-dom@^18.0.0", "@types/react-dom@18.2.7": version "18.2.7" resolved "https://registry.npmjs.org/@types/react-dom/-/react-dom-18.2.7.tgz" integrity sha512-GRaAEriuT4zp9N4p1i8BDBYmEyfo+xQ3yHjJU4eiK5NDa1RmUZG+unZABUTK4/Ox/M+GaHwb6Ow8rUITrtjszA== @@ -2859,7 +2714,7 @@ dependencies: "@types/react" "*" -"@types/react@*", "@types/react@18.2.18": +"@types/react@*", "@types/react@^16.8.0 || ^17.0.0 || ^18.0.0", "@types/react@^16.9.0 || ^17.0.0 || ^18.0.0", "@types/react@>=18", "@types/react@18.2.18": version "18.2.18" resolved "https://registry.npmjs.org/@types/react/-/react-18.2.18.tgz" integrity sha512-da4NTSeBv/P34xoZPhtcLkmZuJ+oYaCxHmyHzwaDQo9RQPBeXV+06gEk2FpqEcsX9XrnNLvRpVh6bdavDSjtiQ== @@ -2921,7 +2776,7 @@ semver "^7.3.7" tsutils "^3.21.0" -"@typescript-eslint/parser@^5.4.2 || ^6.0.0": +"@typescript-eslint/parser@^5.0.0", "@typescript-eslint/parser@^5.4.2 || ^6.0.0": version "5.62.0" resolved "https://registry.npmjs.org/@typescript-eslint/parser/-/parser-5.62.0.tgz" integrity sha512-VlJEV0fOQ7BExOsHYAGrgbEiZoi8D+Bl2+f6V2RrXerRSylnp+ZBHmPvaIa8cz0Ajx7WO7Z5RqfgYg7ED1nRhA== @@ -3077,18 +2932,11 @@ acorn-walk@^8.2.0: resolved "https://registry.npmjs.org/acorn-walk/-/acorn-walk-8.2.0.tgz" integrity 
sha512-k+iyHEuPgSw6SbuDpGQM+06HQUa04DZ3o+F6CSzXMvvI5KMvnaEqXe+YVe555R9nn6GPt404fos4wcgpw12SDA== -acorn@^8.10.0, acorn@^8.9.0: +"acorn@^6.0.0 || ^7.0.0 || ^8.0.0", acorn@^8.10.0, acorn@^8.9.0: version "8.10.0" resolved "https://registry.npmjs.org/acorn/-/acorn-8.10.0.tgz" integrity sha512-F0SAmZ8iUtS//m8DmCTA0jlh6TDKkHQyK6xc6V4KDTyZKA9dnvX9/3sRTVQrWm79glUAZbnmmNcdYwUIHWVybw== -agent-base@6: - version "6.0.2" - resolved "https://registry.npmjs.org/agent-base/-/agent-base-6.0.2.tgz" - integrity sha512-RZNwNclF7+MS/8bDg70amg32dyeZGZxiDuQmZxKLAlQjr3jGyLx+4Kkk58UO7D2QdgFIQCovuSuZESne6RG6XQ== - dependencies: - debug "4" - agent-base@^7.0.2, agent-base@^7.1.0: version "7.1.0" resolved "https://registry.npmjs.org/agent-base/-/agent-base-7.1.0.tgz" @@ -3096,6 +2944,13 @@ agent-base@^7.0.2, agent-base@^7.1.0: dependencies: debug "^4.3.4" +agent-base@6: + version "6.0.2" + resolved "https://registry.npmjs.org/agent-base/-/agent-base-6.0.2.tgz" + integrity sha512-RZNwNclF7+MS/8bDg70amg32dyeZGZxiDuQmZxKLAlQjr3jGyLx+4Kkk58UO7D2QdgFIQCovuSuZESne6RG6XQ== + dependencies: + debug "4" + aggregate-error@^3.0.0: version "3.1.0" resolved "https://registry.npmjs.org/aggregate-error/-/aggregate-error-3.1.0.tgz" @@ -3185,7 +3040,7 @@ aria-hidden@^1.1.1: dependencies: tslib "^2.0.0" -aria-query@5.1.3, aria-query@^5.0.0, aria-query@^5.1.3: +aria-query@^5.0.0, aria-query@^5.1.3, aria-query@5.1.3: version "5.1.3" resolved "https://registry.npmjs.org/aria-query/-/aria-query-5.1.3.tgz" integrity sha512-R5iJ5lkuHybztUfuOAznmboyjWq8O6sqNqtK7CLOqdydi54VNbORp49mb14KbWgG1QD3JFO9hJdZ+y4KutfdOQ== @@ -3412,7 +3267,7 @@ braces@^3.0.2, braces@~3.0.2: dependencies: fill-range "^7.0.1" -browserslist@^4.21.10, browserslist@^4.21.9: +browserslist@^4.21.10, browserslist@^4.21.9, "browserslist@>= 4.21.0": version "4.21.10" resolved "https://registry.npmjs.org/browserslist/-/browserslist-4.21.10.tgz" integrity sha512-bipEBdZfVH5/pwrvqc+Ub0kUPVfGUhlKxbvfD+z1BDnPEO/X98ruXGA1WP5ASpAFKan7Qr6j736IacbZQuAlKQ== @@ 
-3520,14 +3375,6 @@ chai@^4.3.7: pathval "^1.1.1" type-detect "^4.0.5" -chalk@3.0.0, chalk@^3.0.0: - version "3.0.0" - resolved "https://registry.npmjs.org/chalk/-/chalk-3.0.0.tgz" - integrity sha512-4D3B6Wf41KOYRFdszmDqMCGq5VV/uMAB273JILmO+3jAlh8X4qDtdtgCR3fxtbLEMzSx22QdhnDcJvu2u1fVwg== - dependencies: - ansi-styles "^4.1.0" - supports-color "^7.1.0" - chalk@^2.4.2: version "2.4.2" resolved "https://registry.npmjs.org/chalk/-/chalk-2.4.2.tgz" @@ -3537,7 +3384,23 @@ chalk@^2.4.2: escape-string-regexp "^1.0.5" supports-color "^5.3.0" -chalk@^4.0.0, chalk@^4.1.0: +chalk@^3.0.0, chalk@3.0.0: + version "3.0.0" + resolved "https://registry.npmjs.org/chalk/-/chalk-3.0.0.tgz" + integrity sha512-4D3B6Wf41KOYRFdszmDqMCGq5VV/uMAB273JILmO+3jAlh8X4qDtdtgCR3fxtbLEMzSx22QdhnDcJvu2u1fVwg== + dependencies: + ansi-styles "^4.1.0" + supports-color "^7.1.0" + +chalk@^4.0.0: + version "4.1.2" + resolved "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz" + integrity sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA== + dependencies: + ansi-styles "^4.1.0" + supports-color "^7.1.0" + +chalk@^4.1.0: version "4.1.2" resolved "https://registry.npmjs.org/chalk/-/chalk-4.1.2.tgz" integrity sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA== @@ -3570,7 +3433,7 @@ character-reference-invalid@^2.0.0: resolved "https://registry.npmjs.org/character-reference-invalid/-/character-reference-invalid-2.0.1.tgz" integrity sha512-iBZ4F4wRbyORVsu0jPV7gXkOsGYjGHPmAyv+HiHG8gi5PtC9KI2j1+v8/tlibRvjoWX027ypmG/n0HtO5t7unw== -chart.js@^4.4.2: +chart.js@^4.1.1, chart.js@^4.4.2: version "4.4.2" resolved "https://registry.npmjs.org/chart.js/-/chart.js-4.4.2.tgz" integrity sha512-6GD7iKwFpP5kbSD4MeRRRlTnQvxfQREy36uEtm1hzHzcOqwWx0YEHuspuoNlslu+nciLIB7fjjsHkUv/FzFcOg== @@ -3582,7 +3445,7 @@ check-error@^1.0.2: resolved "https://registry.npmjs.org/check-error/-/check-error-1.0.2.tgz" integrity 
sha512-BrgHpW9NURQgzoNyjfq0Wu6VFO6D7IZEmJNdtgNqpzGG8RuNFHt2jQxWlAs4HMe119chBnv+34syEZtc6IhLtA== -"chokidar@>=3.0.0 <4.0.0", chokidar@^3.5.3: +chokidar@^3.5.3, "chokidar@>=3.0.0 <4.0.0": version "3.5.3" resolved "https://registry.npmjs.org/chokidar/-/chokidar-3.5.3.tgz" integrity sha512-Dr3sfKRP6oTcjf2JmUmFJfeVMvXBdegxB0iVQ5eb2V10uFJUCAS8OByZdVAyVb8xXNz3GjjTgj9kLWsZTqE6kw== @@ -3619,7 +3482,7 @@ clean-stack@^2.0.0: resolved "https://registry.npmjs.org/clean-stack/-/clean-stack-2.2.0.tgz" integrity sha512-4diC9HaTE+KRAMWhDhrGOECgWZxoevMc5TlkObMqNSsVU62PYzXZ/SMTjzyGAFF1YusgxGcSWTEXBhp0CPwQ1A== -client-only@0.0.1, client-only@^0.0.1: +client-only@^0.0.1, client-only@0.0.1: version "0.0.1" resolved "https://registry.npmjs.org/client-only/-/client-only-0.0.1.tgz" integrity sha512-IV3Ou0jSMzZrd3pZ48nLkT9DA7Ag1pnPzaiQhpW7c3RbcqqzvzzVu+L8gfqMp/8IM2MQtSiqaCxrrcfu8I8rMA== @@ -3648,16 +3511,16 @@ color-convert@^2.0.1: dependencies: color-name "~1.1.4" -color-name@1.1.3: - version "1.1.3" - resolved "https://registry.npmjs.org/color-name/-/color-name-1.1.3.tgz" - integrity sha512-72fSenhMw2HZMTVHeCA9KCmpEIbzWiQsjN+BHcBbS9vr1mtt+vJjPdksIBNUmKAW8TFUDPJK5SUU3QhE9NEXDw== - color-name@^1.0.0, color-name@~1.1.4: version "1.1.4" resolved "https://registry.npmjs.org/color-name/-/color-name-1.1.4.tgz" integrity sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA== +color-name@1.1.3: + version "1.1.3" + resolved "https://registry.npmjs.org/color-name/-/color-name-1.1.3.tgz" + integrity sha512-72fSenhMw2HZMTVHeCA9KCmpEIbzWiQsjN+BHcBbS9vr1mtt+vJjPdksIBNUmKAW8TFUDPJK5SUU3QhE9NEXDw== + color-string@^1.9.0: version "1.9.1" resolved "https://registry.npmjs.org/color-string/-/color-string-1.9.1.tgz" @@ -3674,6 +3537,11 @@ color@^4.2.3: color-convert "^2.0.1" color-string "^1.9.0" +colord@^2.9.3: + version "2.9.3" + resolved "https://registry.npmjs.org/colord/-/colord-2.9.3.tgz" + integrity 
sha512-jeC1axXpnb0/2nn/Y1LPuLdgXBLH7aDcHu4KEKfqw3CUhX7ZpfBSlPKyqXE6btIgEzfWtrX3/tyBCaCvXvMkOw== + combined-stream@^1.0.8: version "1.0.8" resolved "https://registry.npmjs.org/combined-stream/-/combined-stream-1.0.8.tgz" @@ -3734,13 +3602,6 @@ crelt@^1.0.0: resolved "https://registry.npmjs.org/crelt/-/crelt-1.0.6.tgz" integrity sha512-VQ2MBenTq1fWZUH9DJNGti7kKv6EeAuYr3cLwxUWhIu1baTaXh4Ib5W2CqHVqib4/MqbYGJqiL3Zb8GJZr3l4g== -cross-fetch@3.1.6: - version "3.1.6" - resolved "https://registry.npmjs.org/cross-fetch/-/cross-fetch-3.1.6.tgz" - integrity sha512-riRvo06crlE8HiqOwIpQhxwdOk4fOeR7FVM/wXoxchFEqMNUjvbs3bfo4OTgMEMHzppd4DxFBDbyySj8Cv781g== - dependencies: - node-fetch "^2.6.11" - cross-fetch@^3.1.5: version "3.1.8" resolved "https://registry.npmjs.org/cross-fetch/-/cross-fetch-3.1.8.tgz" @@ -3748,6 +3609,13 @@ cross-fetch@^3.1.5: dependencies: node-fetch "^2.6.12" +cross-fetch@3.1.6: + version "3.1.6" + resolved "https://registry.npmjs.org/cross-fetch/-/cross-fetch-3.1.6.tgz" + integrity sha512-riRvo06crlE8HiqOwIpQhxwdOk4fOeR7FVM/wXoxchFEqMNUjvbs3bfo4OTgMEMHzppd4DxFBDbyySj8Cv781g== + dependencies: + node-fetch "^2.6.11" + cross-spawn@^7.0.0, cross-spawn@^7.0.2, cross-spawn@^7.0.3: version "7.0.3" resolved "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz" @@ -3794,7 +3662,15 @@ csstype@^3.0.2, csstype@^3.0.6: resolved "https://registry.npmjs.org/csstype/-/csstype-3.1.2.tgz" integrity sha512-I7K1Uu0MBPzaFKg4nI5Q7Vs2t+3gWWW648spaF+Rg7pI9ds18Ugn+lvg4SHczUdKlHI5LWBXyqfS8+DufyBsgQ== -"d3-array@2 - 3", "d3-array@2.10.0 - 3", d3-array@^3.1.6: +d@^1.0.1, d@1: + version "1.0.1" + resolved "https://registry.npmjs.org/d/-/d-1.0.1.tgz" + integrity sha512-m62ShEObQ39CfralilEQRjH6oAMtNCV1xJyEx5LpRYUVN+EviphDgUc/F3hnYbADmkiNs67Y+3ylmlG7Lnu+FA== + dependencies: + es5-ext "^0.10.50" + type "^1.0.1" + +d3-array@^3.1.6, "d3-array@2 - 3", "d3-array@2.10.0 - 3": version "3.2.4" resolved "https://registry.npmjs.org/d3-array/-/d3-array-3.2.4.tgz" integrity 
sha512-tdQAmyA18i4J7wprpYq8ClcxZy3SC31QMeByyCFyRt7BVHdREQZ5lpzoe5mFEYZUWe+oq8HBvk9JjpibyEV4Jg== @@ -3816,7 +3692,7 @@ d3-ease@^3.0.1: resolved "https://registry.npmjs.org/d3-format/-/d3-format-3.1.0.tgz" integrity sha512-YyUI6AEuY/Wpt8KWLgZHsIU86atmikuoOmCfommt0LYHiQSPjvX2AcFc38PX0CBpr2RCyZhjex+NS/LPOv6YqA== -"d3-interpolate@1.2.0 - 3", d3-interpolate@^3.0.1: +d3-interpolate@^3.0.1, "d3-interpolate@1.2.0 - 3": version "3.0.1" resolved "https://registry.npmjs.org/d3-interpolate/-/d3-interpolate-3.0.1.tgz" integrity sha512-3bYs1rOD33uo8aqJfKP3JWPAibgw8Zm2+L9vBKEHJ2Rg+viTR7o5Mmv5mZcieN+FRYaAOWX5SJATX6k1PWz72g== @@ -3853,7 +3729,7 @@ d3-shape@^3.1.0: dependencies: d3-time "1 - 3" -"d3-time@1 - 3", "d3-time@2.1.1 - 3", d3-time@^3.0.0: +d3-time@^3.0.0, "d3-time@1 - 3", "d3-time@2.1.1 - 3": version "3.1.0" resolved "https://registry.npmjs.org/d3-time/-/d3-time-3.1.0.tgz" integrity sha512-VqKjzBLejbSMT4IgbmVgDjpkYrNWUYJnbCGo874u7MMKIWsILRX+OpX/gTk8MqjpT1A/c6HY2dCA77ZN0lkQ2Q== @@ -3870,14 +3746,6 @@ d3-voronoi@^1.1.4: resolved "https://registry.npmjs.org/d3-voronoi/-/d3-voronoi-1.1.4.tgz" integrity sha512-dArJ32hchFsrQ8uMiTBLq256MpnZjeuBtdHpaDlYuQyjU0CVzCJl/BVW+SkszaAeH95D/8gxqAhgx0ouAWAfRg== -d@1, d@^1.0.1: - version "1.0.1" - resolved "https://registry.npmjs.org/d/-/d-1.0.1.tgz" - integrity sha512-m62ShEObQ39CfralilEQRjH6oAMtNCV1xJyEx5LpRYUVN+EviphDgUc/F3hnYbADmkiNs67Y+3ylmlG7Lnu+FA== - dependencies: - es5-ext "^0.10.50" - type "^1.0.1" - damerau-levenshtein@^1.0.8: version "1.0.8" resolved "https://registry.npmjs.org/damerau-levenshtein/-/damerau-levenshtein-1.0.8.tgz" @@ -3904,13 +3772,6 @@ date-fns@2.30.0: dependencies: "@babel/runtime" "^7.21.0" -debug@4, debug@^4.1.0, debug@^4.1.1, debug@^4.3.1, debug@^4.3.2, debug@^4.3.4: - version "4.3.4" - resolved "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz" - integrity sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ== - dependencies: - ms "2.1.2" - debug@^2.2.0, debug@^2.6.9: 
version "2.6.9" resolved "https://registry.npmjs.org/debug/-/debug-2.6.9.tgz" @@ -3932,6 +3793,41 @@ debug@^4.0.0: dependencies: ms "2.1.2" +debug@^4.1.0: + version "4.3.4" + resolved "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz" + integrity sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ== + dependencies: + ms "2.1.2" + +debug@^4.1.1: + version "4.3.4" + resolved "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz" + integrity sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ== + dependencies: + ms "2.1.2" + +debug@^4.3.1: + version "4.3.4" + resolved "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz" + integrity sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ== + dependencies: + ms "2.1.2" + +debug@^4.3.2: + version "4.3.4" + resolved "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz" + integrity sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ== + dependencies: + ms "2.1.2" + +debug@^4.3.4, debug@4: + version "4.3.4" + resolved "https://registry.npmjs.org/debug/-/debug-4.3.4.tgz" + integrity sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ== + dependencies: + ms "2.1.2" + decimal.js@^10.4.3: version "10.4.3" resolved "https://registry.npmjs.org/decimal.js/-/decimal.js-10.4.3.tgz" @@ -4167,7 +4063,7 @@ emoji-regex@^9.2.2: resolved "https://registry.npmjs.org/emoji-regex/-/emoji-regex-9.2.2.tgz" integrity sha512-L18DaJsXSUk2+42pv8mLs5jJT2hqFkFE4j21wOmgbUqsZ2hL72NsUU785g9RXgo3s0ZNgVl42TiHp3ZtOv/Vyg== -encoding@0.1.13, encoding@^0.1.13: +encoding@^0.1.0, encoding@^0.1.13, encoding@0.1.13: version "0.1.13" resolved "https://registry.npmjs.org/encoding/-/encoding-0.1.13.tgz" integrity sha512-ETBauow1T35Y/WZMkio9jiM0Z5xjHHmJ4XmjZOq1l/dXz3lr2sRn87nJy20RupqSh1F2m3HHPSp8ShIPQJrJ3A== @@ -4447,7 +4343,7 @@ eslint-module-utils@^2.7.4, 
eslint-module-utils@^2.8.0: dependencies: debug "^3.2.7" -eslint-plugin-import@^2.28.1: +eslint-plugin-import@*, eslint-plugin-import@^2.28.1: version "2.29.1" resolved "https://registry.npmjs.org/eslint-plugin-import/-/eslint-plugin-import-2.29.1.tgz" integrity sha512-BbPC0cuExzhiMo4Ff1BTVwHpjjv28C5R+btTOGaCRC7UEz801up0JadwkeSk5Ued6TG34uaczuVuH6qyy5YUxw== @@ -4545,7 +4441,7 @@ eslint-visitor-keys@^3.3.0, eslint-visitor-keys@^3.4.1, eslint-visitor-keys@^3.4 resolved "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-3.4.2.tgz" integrity sha512-8drBzUEyZ2llkpCA67iYrgEssKDUu68V8ChqqOfFupIaG/LCVPUT+CoGJpT77zJprs4T/W7p07LP7zAIMuweVw== -eslint@8.46.0: +eslint@*, "eslint@^2 || ^3 || ^4 || ^5 || ^6 || ^7.2.0 || ^8", "eslint@^3 || ^4 || ^5 || ^6 || ^7 || ^8", "eslint@^3.0.0 || ^4.0.0 || ^5.0.0 || ^6.0.0 || ^7.0.0 || ^8.0.0-0", "eslint@^6.0.0 || ^7.0.0 || ^8.0.0", "eslint@^6.0.0 || ^7.0.0 || >=8.0.0", "eslint@^7.23.0 || ^8.0.0", eslint@>=2.0.0, eslint@8.46.0: version "8.46.0" resolved "https://registry.npmjs.org/eslint/-/eslint-8.46.0.tgz" integrity sha512-cIO74PvbW0qU8e0mIvk5IV3ToWdCq5FYG6gWPHHkx6gNdjlbAYvtfHmlCMXxjcoVaIdwy/IAt3+mDkZkfvb2Dg== @@ -4894,7 +4790,7 @@ fs.realpath@^1.0.0: resolved "https://registry.npmjs.org/fs.realpath/-/fs.realpath-1.0.0.tgz" integrity sha512-OO0pH2lK6a0hZnAdau5ItzHPI6pUlvI7jMVnxUQRtw4owF2wk8lOSabtGDCTP4Ggrg2MbGnWO9X8K1t4+fGMDw== -fsevents@2.3.2, fsevents@~2.3.2: +fsevents@~2.3.2, fsevents@2.3.2: version "2.3.2" resolved "https://registry.npmjs.org/fsevents/-/fsevents-2.3.2.tgz" integrity sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA== @@ -4969,7 +4865,7 @@ github-from-package@0.0.0: resolved "https://registry.npmjs.org/github-from-package/-/github-from-package-0.0.0.tgz" integrity sha512-SyHy3T1v2NUXn29OsWdxmK6RwHD+vkj3v8en8AOBZ1wBQ/hCAQ5bAQTD02kW4W9tUp/3Qh6J8r9EvntiyCmOOw== -glob-parent@^5.1.2, glob-parent@~5.1.2: +glob-parent@^5.1.2: version "5.1.2" resolved 
"https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz" integrity sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow== @@ -4983,7 +4879,14 @@ glob-parent@^6.0.2: dependencies: is-glob "^4.0.3" -glob@10.3.10, glob@^10.2.2, glob@^10.3.10: +glob-parent@~5.1.2: + version "5.1.2" + resolved "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz" + integrity sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow== + dependencies: + is-glob "^4.0.1" + +glob@^10.2.2: version "10.3.10" resolved "https://registry.npmjs.org/glob/-/glob-10.3.10.tgz" integrity sha512-fa46+tv1Ak0UPK1TOy/pZrIybNNt4HCv7SDzwyfiOZkvZLEbjsZkJBPtDHVshZjbecAoAGSC20MjLDG/qr679g== @@ -4994,17 +4897,16 @@ glob@10.3.10, glob@^10.2.2, glob@^10.3.10: minipass "^5.0.0 || ^6.0.2 || ^7.0.0" path-scurry "^1.10.1" -glob@7.1.6: - version "7.1.6" - resolved "https://registry.npmjs.org/glob/-/glob-7.1.6.tgz" - integrity sha512-LwaxwyZ72Lk7vZINtNNrywX0ZuLyStrdDtabefZKAY5ZGJhVtgdznluResxNmPitE0SAO+O26sWTHeKSI2wMBA== +glob@^10.3.10: + version "10.3.10" + resolved "https://registry.npmjs.org/glob/-/glob-10.3.10.tgz" + integrity sha512-fa46+tv1Ak0UPK1TOy/pZrIybNNt4HCv7SDzwyfiOZkvZLEbjsZkJBPtDHVshZjbecAoAGSC20MjLDG/qr679g== dependencies: - fs.realpath "^1.0.0" - inflight "^1.0.4" - inherits "2" - minimatch "^3.0.4" - once "^1.3.0" - path-is-absolute "^1.0.0" + foreground-child "^3.1.0" + jackspeak "^2.3.5" + minimatch "^9.0.1" + minipass "^5.0.0 || ^6.0.2 || ^7.0.0" + path-scurry "^1.10.1" glob@^7.1.3: version "7.2.3" @@ -5029,6 +4931,29 @@ glob@^8.0.3: minimatch "^5.0.1" once "^1.3.0" +glob@10.3.10: + version "10.3.10" + resolved "https://registry.npmjs.org/glob/-/glob-10.3.10.tgz" + integrity sha512-fa46+tv1Ak0UPK1TOy/pZrIybNNt4HCv7SDzwyfiOZkvZLEbjsZkJBPtDHVshZjbecAoAGSC20MjLDG/qr679g== + dependencies: + foreground-child "^3.1.0" + jackspeak "^2.3.5" + minimatch "^9.0.1" + minipass "^5.0.0 || ^6.0.2 || ^7.0.0" + 
path-scurry "^1.10.1" + +glob@7.1.6: + version "7.1.6" + resolved "https://registry.npmjs.org/glob/-/glob-7.1.6.tgz" + integrity sha512-LwaxwyZ72Lk7vZINtNNrywX0ZuLyStrdDtabefZKAY5ZGJhVtgdznluResxNmPitE0SAO+O26sWTHeKSI2wMBA== + dependencies: + fs.realpath "^1.0.0" + inflight "^1.0.4" + inherits "2" + minimatch "^3.0.4" + once "^1.3.0" + path-is-absolute "^1.0.0" + globals@^11.1.0: version "11.12.0" resolved "https://registry.npmjs.org/globals/-/globals-11.12.0.tgz" @@ -5362,14 +5287,14 @@ i18next-http-backend@2.2.1: dependencies: cross-fetch "3.1.6" -i18next@23.4.2: +"i18next@>= 23.2.3", i18next@23.4.2: version "23.4.2" resolved "https://registry.npmjs.org/i18next/-/i18next-23.4.2.tgz" integrity sha512-hkVPHKFLtn9iewdqHDiU+MGVIBk+bVFn5usw7CIeCn/SBcVKGTItGdjNPm2B8Lnz42CeHUlnSOTgsr5vbITjhA== dependencies: "@babel/runtime" "^7.22.5" -iconv-lite@0.6.3, iconv-lite@^0.6.2: +iconv-lite@^0.6.2, iconv-lite@0.6.3: version "0.6.3" resolved "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.6.3.tgz" integrity sha512-4fCk79wshMdzMp2rH06qWrJE4iolqLhCUH+OiuIgU++RB0+94NlDL81atO7GX55uUKueo0txHNtvEyI6D7WdMw== @@ -5422,7 +5347,7 @@ inflight@^1.0.4: once "^1.3.0" wrappy "1" -inherits@2, inherits@^2.0.3, inherits@^2.0.4: +inherits@^2.0.3, inherits@^2.0.4, inherits@2: version "2.0.4" resolved "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz" integrity sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ== @@ -5818,16 +5743,16 @@ jose@^4.14.0: resolved "https://registry.npmjs.org/jose/-/jose-4.14.4.tgz" integrity sha512-j8GhLiKmUAh+dsFXlX1aJCbt5KMibuKb+d7j1JaOJG6s2UjX1PQlW+OKB/sD4a/5ZYF4RcmYmLSndOoU3Lt/3g== -js-cookie@3.0.1: - version "3.0.1" - resolved "https://registry.npmjs.org/js-cookie/-/js-cookie-3.0.1.tgz" - integrity sha512-+0rgsUXZu4ncpPxRL+lNEptWMOWl9etvPHc/koSRp6MPwpRYAhmk0dUG00J4bxVV3r9uUzfo24wW0knS07SKSw== - js-cookie@^2.2.1: version "2.2.1" resolved "https://registry.npmjs.org/js-cookie/-/js-cookie-2.2.1.tgz" integrity 
sha512-HvdH2LzI/EAZcUwA8+0nKNtWHqS+ZmijLA30RwZA0bo7ToCckjK5MkGhjED9KoRcXO6BaGI3I9UIzSA1FKFPOQ== +js-cookie@3.0.1: + version "3.0.1" + resolved "https://registry.npmjs.org/js-cookie/-/js-cookie-3.0.1.tgz" + integrity sha512-+0rgsUXZu4ncpPxRL+lNEptWMOWl9etvPHc/koSRp6MPwpRYAhmk0dUG00J4bxVV3r9uUzfo24wW0knS07SKSw== + "js-tokens@^3.0.0 || ^4.0.0", js-tokens@^4.0.0: version "4.0.0" resolved "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz" @@ -5848,7 +5773,7 @@ js-yaml@^4.1.0: dependencies: argparse "^2.0.1" -jsdom@22.1.0: +jsdom@*, jsdom@22.1.0: version "22.1.0" resolved "https://registry.npmjs.org/jsdom/-/jsdom-22.1.0.tgz" integrity sha512-/9AVW7xNbsBv6GfWho4TTNjEo9fe6Zhf9O7s0Fhhr3u+awPwAJMKwAMXnkk5vBxflqLW9hTHX/0cs+P3gW+cQw== @@ -6014,7 +5939,7 @@ lodash.merge@^4.6.2: resolved "https://registry.npmjs.org/lodash.merge/-/lodash.merge-4.6.2.tgz" integrity sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ== -lodash@4.17.21, lodash@^4.17.15, lodash@^4.17.19, lodash@^4.17.21: +lodash@^4.17.15, lodash@^4.17.19, lodash@^4.17.21, lodash@4.17.21: version "4.17.21" resolved "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz" integrity sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg== @@ -6047,7 +5972,7 @@ lowlight@^2.0.0: fault "^2.0.0" highlight.js "~11.8.0" -lru-cache@^10.0.1, "lru-cache@^9.1.1 || ^10.0.0": +lru-cache@^10.0.1: version "10.0.3" resolved "https://registry.npmjs.org/lru-cache/-/lru-cache-10.0.3.tgz" integrity sha512-B7gr+F6MkqB3uzINHXNctGieGsRTMwIBgxkp0yq/5BwcuDzD4A8wQpHQW6vDAm1uKSLQghmRdD9sKqf2vJ1cEg== @@ -6066,6 +5991,11 @@ lru-cache@^6.0.0: dependencies: yallist "^4.0.0" +"lru-cache@^9.1.1 || ^10.0.0": + version "10.0.3" + resolved "https://registry.npmjs.org/lru-cache/-/lru-cache-10.0.3.tgz" + integrity sha512-B7gr+F6MkqB3uzINHXNctGieGsRTMwIBgxkp0yq/5BwcuDzD4A8wQpHQW6vDAm1uKSLQghmRdD9sKqf2vJ1cEg== + lz-string@^1.5.0: version "1.5.0" resolved 
"https://registry.npmjs.org/lz-string/-/lz-string-1.5.0.tgz" @@ -6717,16 +6647,16 @@ minipass@^3.0.0: dependencies: yallist "^4.0.0" -minipass@^5.0.0: - version "5.0.0" - resolved "https://registry.npmjs.org/minipass/-/minipass-5.0.0.tgz" - integrity sha512-3FnjYuehv9k6ovOEbyOswadCDPX1piCfhV8ncmYtHOjuPwylVWsghTLo7rabjC3Rx5xD4HDx8Wm1xnMF7S5qFQ== - "minipass@^5.0.0 || ^6.0.2 || ^7.0.0", minipass@^7.0.2, minipass@^7.0.3: version "7.0.4" resolved "https://registry.npmjs.org/minipass/-/minipass-7.0.4.tgz" integrity sha512-jYofLM5Dam9279rdkWzqHozUo4ybjdZmCsDHePy5V/PbBcVMiSZR97gmAy45aqi8CK1lG2ECd356FU86avfwUQ== +minipass@^5.0.0: + version "5.0.0" + resolved "https://registry.npmjs.org/minipass/-/minipass-5.0.0.tgz" + integrity sha512-3FnjYuehv9k6ovOEbyOswadCDPX1piCfhV8ncmYtHOjuPwylVWsghTLo7rabjC3Rx5xD4HDx8Wm1xnMF7S5qFQ== + minizlib@^2.1.1, minizlib@^2.1.2: version "2.1.2" resolved "https://registry.npmjs.org/minizlib/-/minizlib-2.1.2.tgz" @@ -6762,6 +6692,11 @@ mlly@^1.2.0, mlly@^1.4.0: pkg-types "^1.0.3" ufo "^1.1.2" +ms@^2.1.1: + version "2.1.3" + resolved "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz" + integrity sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA== + ms@2.0.0: version "2.0.0" resolved "https://registry.npmjs.org/ms/-/ms-2.0.0.tgz" @@ -6772,11 +6707,6 @@ ms@2.1.2: resolved "https://registry.npmjs.org/ms/-/ms-2.1.2.tgz" integrity sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w== -ms@^2.1.1: - version "2.1.3" - resolved "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz" - integrity sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA== - mz@^2.7.0: version "2.7.0" resolved "https://registry.npmjs.org/mz/-/mz-2.7.0.tgz" @@ -6847,7 +6777,7 @@ next-tick@^1.1.0: resolved "https://registry.npmjs.org/next-tick/-/next-tick-1.1.0.tgz" integrity sha512-CXdUiJembsNjuToQvxayPZF9Vqht7hewsvy2sOWafLvi2awflj9mOC6bHIg50orX8IJvWKY9wYQ/zB2kogPslQ== 
-next@^14.1.0: +next@*, "next@^10.0.8 || ^11.0 || ^12.0 || ^13.0 || ^14.0", next@^14.1.0: version "14.1.0" resolved "https://registry.npmjs.org/next/-/next-14.1.0.tgz" integrity sha512-wlzrsbfeSU48YQBjZhDzOwhWhGsy+uQycR8bHAOt1LY1bn3zZEcDyHQOEoN3aWzQ8LHCAJ1nqrWCc9XF2+O45Q== @@ -7259,7 +7189,7 @@ playwright-core@1.38.0: resolved "https://registry.npmjs.org/playwright-core/-/playwright-core-1.38.0.tgz" integrity sha512-f8z1y8J9zvmHoEhKgspmCvOExF2XdcxMW8jNRuX4vkQFrzV4MlZ55iwb5QeyiFQgOFCUolXiRHgpjSEnqvO48g== -playwright@1.38.0: +playwright@*, playwright@1.38.0: version "1.38.0" resolved "https://registry.npmjs.org/playwright/-/playwright-1.38.0.tgz" integrity sha512-fJGw+HO0YY+fU/F1N57DMO+TmXHTrmr905J05zwAQE9xkuwP/QLDk63rVhmyxh03dYnEhnRbsdbH9B0UVVRB3A== @@ -7299,14 +7229,6 @@ postcss-nested@^6.0.1: dependencies: postcss-selector-parser "^6.0.11" -postcss-selector-parser@6.0.10: - version "6.0.10" - resolved "https://registry.npmjs.org/postcss-selector-parser/-/postcss-selector-parser-6.0.10.tgz" - integrity sha512-IQ7TZdoaqbT+LCpShg46jnZVlhWD2w6iQYAcYXfHARZ7X1t/UGhhceQDs5X0cGqKvYlHNOuv7Oa1xmb0oQuA3w== - dependencies: - cssesc "^3.0.0" - util-deprecate "^1.0.2" - postcss-selector-parser@^6.0.11: version "6.0.13" resolved "https://registry.npmjs.org/postcss-selector-parser/-/postcss-selector-parser-6.0.13.tgz" @@ -7315,11 +7237,28 @@ postcss-selector-parser@^6.0.11: cssesc "^3.0.0" util-deprecate "^1.0.2" +postcss-selector-parser@6.0.10: + version "6.0.10" + resolved "https://registry.npmjs.org/postcss-selector-parser/-/postcss-selector-parser-6.0.10.tgz" + integrity sha512-IQ7TZdoaqbT+LCpShg46jnZVlhWD2w6iQYAcYXfHARZ7X1t/UGhhceQDs5X0cGqKvYlHNOuv7Oa1xmb0oQuA3w== + dependencies: + cssesc "^3.0.0" + util-deprecate "^1.0.2" + postcss-value-parser@^4.0.0, postcss-value-parser@^4.2.0: version "4.2.0" resolved "https://registry.npmjs.org/postcss-value-parser/-/postcss-value-parser-4.2.0.tgz" integrity 
sha512-1NNCs6uurfkVbeXG4S8JFT9t19m45ICnif8zWLd5oPSZ50QnwMfK+H3jv408d4jw/7Bttv5axS5IiHoLaVNHeQ== +postcss@^8.0.0, postcss@^8.1.0, postcss@^8.2.14, postcss@^8.4.21, postcss@^8.4.23, postcss@^8.4.27, postcss@>=8.0.9, postcss@8.4.32: + version "8.4.32" + resolved "https://registry.npmjs.org/postcss/-/postcss-8.4.32.tgz" + integrity sha512-D/kj5JNu6oo2EIy+XL/26JEDTlIbB8hw85G8StOE6L74RQAVVP5rej6wxCNqyMbR4RkPfqvezVbPw81Ngd6Kcw== + dependencies: + nanoid "^3.3.7" + picocolors "^1.0.0" + source-map-js "^1.0.2" + postcss@8.4.31: version "8.4.31" resolved "https://registry.npmjs.org/postcss/-/postcss-8.4.31.tgz" @@ -7329,15 +7268,6 @@ postcss@8.4.31: picocolors "^1.0.0" source-map-js "^1.0.2" -postcss@8.4.32, postcss@^8.4.23, postcss@^8.4.27: - version "8.4.32" - resolved "https://registry.npmjs.org/postcss/-/postcss-8.4.32.tgz" - integrity sha512-D/kj5JNu6oo2EIy+XL/26JEDTlIbB8hw85G8StOE6L74RQAVVP5rej6wxCNqyMbR4RkPfqvezVbPw81Ngd6Kcw== - dependencies: - nanoid "^3.3.7" - picocolors "^1.0.0" - source-map-js "^1.0.2" - posthog-js@1.96.1: version "1.96.1" resolved "https://registry.npmjs.org/posthog-js/-/posthog-js-1.96.1.tgz" @@ -7547,7 +7477,7 @@ prosemirror-schema-list@^1.2.2: prosemirror-state "^1.0.0" prosemirror-transform "^1.7.3" -prosemirror-state@^1.0.0, prosemirror-state@^1.2.2, prosemirror-state@^1.3.1, prosemirror-state@^1.4.1: +prosemirror-state@^1.0.0, prosemirror-state@^1.2.2, prosemirror-state@^1.3.1, prosemirror-state@^1.4.1, prosemirror-state@^1.4.2: version "1.4.3" resolved "https://registry.npmjs.org/prosemirror-state/-/prosemirror-state-1.4.3.tgz" integrity sha512-goFKORVbvPuAQaXhpbemJFRKJ2aixr+AZMGiquiqKxaucC6hlpHNZHWgz5R7dS4roHiwq9vDctE//CZ++o0W1Q== @@ -7583,7 +7513,7 @@ prosemirror-transform@^1.0.0, prosemirror-transform@^1.1.0, prosemirror-transfor dependencies: prosemirror-model "^1.0.0" -prosemirror-view@^1.0.0, prosemirror-view@^1.1.0, prosemirror-view@^1.13.3, prosemirror-view@^1.27.0, prosemirror-view@^1.28.2, prosemirror-view@^1.31.0: 
+prosemirror-view@^1.0.0, prosemirror-view@^1.1.0, prosemirror-view@^1.13.3, prosemirror-view@^1.27.0, prosemirror-view@^1.28.2, prosemirror-view@^1.31.0, prosemirror-view@^1.31.2: version "1.32.4" resolved "https://registry.npmjs.org/prosemirror-view/-/prosemirror-view-1.32.4.tgz" integrity sha512-WoT+ZYePp0WQvp5coABAysheZg9WttW3TSEUNgsfDQXmVOJlnjkbFbXicKPvWFLiC0ZjKt1ykbyoVKqhVnCiSQ== @@ -7650,7 +7580,7 @@ react-colorful@^5.6.1: resolved "https://registry.npmjs.org/react-colorful/-/react-colorful-5.6.1.tgz" integrity sha512-1exovf0uGTGyq5mXQT0zgQ80uvj2PCwvF8zY1RN9/vbJVSjSo3fsB/4L3ObbF7u70NduSiK4xu4Y6q1MHoUGEw== -react-dom@^18.2.0: +react-dom@*, "react-dom@^16 || ^17 || ^18", "react-dom@^16.8 || ^17.0 || ^18.0", "react-dom@^16.8.0 || ^17.0.0 || ^18.0.0", "react-dom@^17.0.0 || ^18.0.0", react-dom@^18.0.0, "react-dom@^18.0.0 || ^17.0.0 || ^16.2.0", react-dom@^18.2.0, react-dom@>=16.8.0: version "18.2.0" resolved "https://registry.npmjs.org/react-dom/-/react-dom-18.2.0.tgz" integrity sha512-6IMTriUmvsjHUjNtEDudZfuDQUoWXVxKHhlEGSk81n4YFS+r/Kl99wXiwlVXtPBtJenozv2P+hxDsw9eA7Xo6g== @@ -7792,7 +7722,7 @@ react-use@17.4.0: ts-easing "^0.2.0" tslib "^2.1.0" -react@^18.2.0: +react@*, "react@^16 || ^17 || ^18", "react@^16.8 || ^17.0 || ^18.0", "react@^16.8.0 || ^17.0.0 || ^18.0.0", "react@^16.8.0 || ^17 || ^18", "react@^16.8.0 || ^17.0.0 || ^18.0.0", "react@^16.8.0-0 || ^17.0.0-0 || ^18.0.0-0", "react@^17.0.0 || ^18.0.0", react@^18.0.0, "react@^18.0.0 || ^17.0.0 || ^16.2.0", react@^18.2.0, "react@>= 16", "react@>= 16.8 || 18.0.0", "react@>= 16.8.0", "react@>= 16.8.0 || 17.x.x || ^18.0.0-0", "react@>=15.3.2 <=18", react@>=16, react@>=16.6.0, react@>=16.8.0, react@>=18, "react@15.x || 16.x || 17.x || 18.x", "react@16.x || 17.x || 18.x": version "18.2.0" resolved "https://registry.npmjs.org/react/-/react-18.2.0.tgz" integrity sha512-/3IjMdb2L9QbBdWiW5e3P2/npwMBaU9mHCSCUzNln0ZCYbcfTsGbTJrU/kGemdH2IWmB2ioZ+zkxtmq6g09fGQ== @@ -7938,7 +7868,7 @@ resolve-pkg-maps@^1.0.0: resolved 
"https://registry.npmjs.org/resolve-pkg-maps/-/resolve-pkg-maps-1.0.0.tgz" integrity sha512-seS2Tj26TBVOC2NIc2rOe2y2ZO7efxITtLZcGSOnHHNOQ7CkiUBfw0Iw2ck6xkIhPwLhKNLS8BO+hEpngQlqzw== -resolve@1.22.8, resolve@^1.1.7, resolve@^1.19.0, resolve@^1.22.2, resolve@^1.22.4: +resolve@^1.1.7, resolve@^1.19.0, resolve@^1.22.2, resolve@^1.22.4, resolve@1.22.8: version "1.22.8" resolved "https://registry.npmjs.org/resolve/-/resolve-1.22.8.tgz" integrity sha512-oKWePCxqpd6FlLvGV1VU0x7bkPmmCNolxzjMf4NczoDnQcIWrAF+cPtZn5i6n+RfD2d9i0tzpKnG6Yk168yIyw== @@ -7973,7 +7903,7 @@ rimraf@^3.0.2: dependencies: glob "^7.1.3" -rollup@2.78.0: +rollup@^1.20.0||^2.0.0||^3.0.0, rollup@^2.68.0||^3.0.0, rollup@2.78.0: version "2.78.0" resolved "https://registry.npmjs.org/rollup/-/rollup-2.78.0.tgz" integrity sha512-4+YfbQC9QEVvKTanHhIAFVUFSRsezvQF8vFOJwtGfb9Bb+r014S+qryr9PSmw8x6sMnPkmFBGAvIFVQxvJxjtg== @@ -8047,7 +7977,7 @@ safe-regex-test@^1.0.0: resolved "https://registry.npmjs.org/safer-buffer/-/safer-buffer-2.1.2.tgz" integrity sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg== -sass@^1.70.0: +sass@*, sass@^1.3.0, sass@^1.70.0: version "1.70.0" resolved "https://registry.npmjs.org/sass/-/sass-1.70.0.tgz" integrity sha512-uUxNQ3zAHeAx5nRFskBnrWzDUJrrvpCPD5FNAoRvTi0WwremlheES3tg+56PaVtCs5QDRX5CBLxxKMDJMEa1WQ== @@ -8075,7 +8005,12 @@ screenfull@^5.1.0: resolved "https://registry.npmjs.org/screenfull/-/screenfull-5.2.0.tgz" integrity sha512-9BakfsO2aUQN2K9Fdbj87RJIEZ82Q9IGim7FqM5OsebfoFC6ZHXgDq/KvniuLTPdeM8wY2o6Dj3WQ7KeQCj3cA== -semver@^6.3.0, semver@^6.3.1: +semver@^6.3.0: + version "6.3.1" + resolved "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz" + integrity sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA== + +semver@^6.3.1: version "6.3.1" resolved "https://registry.npmjs.org/semver/-/semver-6.3.1.tgz" integrity 
sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA== @@ -8204,16 +8139,11 @@ socks@^2.7.1: ip "^2.0.0" smart-buffer "^4.2.0" -"source-map-js@>=0.6.2 <2.0.0", source-map-js@^1.0.2: +source-map-js@^1.0.2, "source-map-js@>=0.6.2 <2.0.0": version "1.0.2" resolved "https://registry.npmjs.org/source-map-js/-/source-map-js-1.0.2.tgz" integrity sha512-R0XvVJ9WusLiqTCEiGCmICCMplcCkIwwR11mOSD9CR5u+IXYdiseeEuXCVAjS54zqwkLcPNnmU4OeJ6tUrWhDw== -source-map@0.5.6: - version "0.5.6" - resolved "https://registry.npmjs.org/source-map/-/source-map-0.5.6.tgz" - integrity sha512-MjZkVp0NHr5+TPihLcadqnlVoGIoWo4IBHptutGh9wI3ttUYvCG26HkSuDi+K6lsZ25syXJXcctwgyVCt//xqA== - source-map@^0.5.7: version "0.5.7" resolved "https://registry.npmjs.org/source-map/-/source-map-0.5.7.tgz" @@ -8224,6 +8154,11 @@ source-map@^0.6.1: resolved "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz" integrity sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g== +source-map@0.5.6: + version "0.5.6" + resolved "https://registry.npmjs.org/source-map/-/source-map-0.5.6.tgz" + integrity sha512-MjZkVp0NHr5+TPihLcadqnlVoGIoWo4IBHptutGh9wI3ttUYvCG26HkSuDi+K6lsZ25syXJXcctwgyVCt//xqA== + sourcemap-codec@^1.4.8: version "1.4.8" resolved "https://registry.npmjs.org/sourcemap-codec/-/sourcemap-codec-1.4.8.tgz" @@ -8317,6 +8252,13 @@ streamx@^2.15.0: fast-fifo "^1.1.0" queue-tick "^1.0.1" +string_decoder@^1.1.1: + version "1.3.0" + resolved "https://registry.npmjs.org/string_decoder/-/string_decoder-1.3.0.tgz" + integrity sha512-hkRX8U1WjJFd8LsDJ2yQ/wWWxaopEsABU1XfkM8A+j0+85JAGppt16cr1Whg6KIbb4okU6Mql6BOj+uup/wKeA== + dependencies: + safe-buffer "~5.2.0" + "string-width-cjs@npm:string-width@^4.2.0": version "4.2.3" resolved "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz" @@ -8386,13 +8328,6 @@ string.prototype.trimstart@^1.0.6: define-properties "^1.1.4" es-abstract "^1.20.4" -string_decoder@^1.1.1: - version 
"1.3.0" - resolved "https://registry.npmjs.org/string_decoder/-/string_decoder-1.3.0.tgz" - integrity sha512-hkRX8U1WjJFd8LsDJ2yQ/wWWxaopEsABU1XfkM8A+j0+85JAGppt16cr1Whg6KIbb4okU6Mql6BOj+uup/wKeA== - dependencies: - safe-buffer "~5.2.0" - stringify-entities@^4.0.0: version "4.0.4" resolved "https://registry.npmjs.org/stringify-entities/-/stringify-entities-4.0.4.tgz" @@ -8475,7 +8410,7 @@ styled-jsx@5.1.1: dependencies: client-only "0.0.1" -stylis@4.2.0, stylis@^4.0.6: +stylis@^4.0.6, stylis@4.2.0: version "4.2.0" resolved "https://registry.npmjs.org/stylis/-/stylis-4.2.0.tgz" integrity sha512-Orov6g6BB1sDfYgzWfTHDOxamtX1bE/zo104Dh9e6fqJ3PooipYyfJ0pUmrZO2wAvO8YbEyeFrkV91XTsGMSrw== @@ -8530,7 +8465,7 @@ tailwind-merge@1.14.0: resolved "https://registry.npmjs.org/tailwind-merge/-/tailwind-merge-1.14.0.tgz" integrity sha512-3mFKyCo/MBcgyOTlrY8T7odzZFx+w+qKSMAmdFzRvqBfLlSigU6TZnlFHK0lkMwj9Bj8OYU+9yW9lmGuS0QEnQ== -tailwindcss@3.4.0: +"tailwindcss@>=3.0.0 || >= 3.0.0-alpha.1", "tailwindcss@>=3.0.0 || insiders", tailwindcss@3.4.0: version "3.4.0" resolved "https://registry.npmjs.org/tailwindcss/-/tailwindcss-3.4.0.tgz" integrity sha512-VigzymniH77knD1dryXbyxR+ePHihHociZbXnLZHUyzf2MMs2ZVqlUrZ3FvpXP8pno9JzmILt1sZPD19M3IxtA== @@ -8658,7 +8593,7 @@ tinyspy@^2.1.1: resolved "https://registry.npmjs.org/tinyspy/-/tinyspy-2.1.1.tgz" integrity sha512-XPJL2uSzcOyBMky6OFrusqWlzfFrXtE0hPuMgW8A2HmaqrPo4ZQHRN/V0QXN3FSjKxpsbRrFc5LI7KOwBsT1/w== -tippy.js@6.3.7, tippy.js@^6.3.7: +tippy.js@^6.3.7, tippy.js@6.3.7: version "6.3.7" resolved "https://registry.npmjs.org/tippy.js/-/tippy.js-6.3.7.tgz" integrity sha512-E1d3oP2emgJ9dRQZdf3Kkn0qJgI6ZLpyS5z6ZkY1DF3kaQaBsGZsndEpHwx+eC+tYM41HaSNvNtLx8tU57FzTQ== @@ -8739,16 +8674,16 @@ tsconfig-paths@^3.15.0: minimist "^1.2.6" strip-bom "^3.0.0" +tslib@*, tslib@^2.0.0, tslib@^2.1.0, tslib@^2.4.0, tslib@^2.4.1, tslib@^2.5.0, tslib@^2.6.0: + version "2.6.1" + resolved "https://registry.npmjs.org/tslib/-/tslib-2.6.1.tgz" + integrity 
sha512-t0hLfiEKfMUoqhG+U1oid7Pva4bbDPHYfJNiB7BiIjRkj1pyC++4N3huJfqY6aRH6VTB0rvtzQwjM4K6qpfOig== + tslib@^1.8.1: version "1.14.1" resolved "https://registry.npmjs.org/tslib/-/tslib-1.14.1.tgz" integrity sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg== -tslib@^2.0.0, tslib@^2.1.0, tslib@^2.4.0, tslib@^2.4.1, tslib@^2.5.0, tslib@^2.6.0: - version "2.6.1" - resolved "https://registry.npmjs.org/tslib/-/tslib-2.6.1.tgz" - integrity sha512-t0hLfiEKfMUoqhG+U1oid7Pva4bbDPHYfJNiB7BiIjRkj1pyC++4N3huJfqY6aRH6VTB0rvtzQwjM4K6qpfOig== - tsutils@^3.21.0: version "3.21.0" resolved "https://registry.npmjs.org/tsutils/-/tsutils-3.21.0.tgz" @@ -8846,16 +8781,16 @@ typedarray-to-buffer@^3.1.5: dependencies: is-typedarray "^1.0.0" -typescript@5.1.6: - version "5.1.6" - resolved "https://registry.npmjs.org/typescript/-/typescript-5.1.6.tgz" - integrity sha512-zaWCozRZ6DLEWAWFrVDz1H6FVXzUSfTy5FUMWsQlU8Ym5JP9eO4xkTIROFCQvhQf61z6O/G6ugw3SgAnvvm+HA== - typescript@^4.9.5: version "4.9.5" resolved "https://registry.npmjs.org/typescript/-/typescript-4.9.5.tgz" integrity sha512-1FXk9E2Hm+QzZQ7z+McJiHL4NW1F2EzMu9Nq9i3zAaGqibafqYwCVU6WyWAuyQRRzOlxou8xZSyXLEN8oKj24g== +"typescript@>=2.8.0 || >= 3.2.0-dev || >= 3.3.0-dev || >= 3.4.0-dev || >= 3.5.0-dev || >= 3.6.0-dev || >= 3.6.0-beta || >= 3.7.0-dev || >= 3.7.0-beta", typescript@>=3.3.1, typescript@5.1.6: + version "5.1.6" + resolved "https://registry.npmjs.org/typescript/-/typescript-5.1.6.tgz" + integrity sha512-zaWCozRZ6DLEWAWFrVDz1H6FVXzUSfTy5FUMWsQlU8Ym5JP9eO4xkTIROFCQvhQf61z6O/G6ugw3SgAnvvm+HA== + uc.micro@^1.0.1, uc.micro@^1.0.5: version "1.0.6" resolved "https://registry.npmjs.org/uc.micro/-/uc.micro-1.0.6.tgz" @@ -9064,7 +8999,7 @@ use-sidecar@^1.1.2: detect-node-es "^1.1.0" tslib "^2.0.0" -utf-8-validate@^5.0.2: +utf-8-validate@^5.0.2, utf-8-validate@>=5.0.2: version "5.0.10" resolved "https://registry.npmjs.org/utf-8-validate/-/utf-8-validate-5.0.10.tgz" integrity 
sha512-Z6czzLq4u8fPOyx7TU6X3dvUZVvoJmxSQ+IcrlmagKhilxlhZgxPK6C5Jqbkw1IDUmFTM+cz9QDnnLTwDz/2gQ== @@ -9462,7 +9397,7 @@ vite-node@0.32.4: picocolors "^1.0.0" vite "^3.0.0 || ^4.0.0" -"vite@^3.0.0 || ^4.0.0": +"vite@^3.0.0 || ^4.0.0", vite@^4.2.0: version "4.5.2" resolved "https://registry.npmjs.org/vite/-/vite-4.5.2.tgz" integrity sha512-tBCZBNSBbHQkaGyhGCDUGqeo2ph8Fstyp6FMSvTtsXeZSPpSMGlviAOav2hxVTqFcx8Hj/twtWKsMJXNY0xI8w== @@ -9473,7 +9408,7 @@ vite-node@0.32.4: optionalDependencies: fsevents "~2.3.2" -vitest@0.32.4: +"vitest@>= 0.32", vitest@0.32.4: version "0.32.4" resolved "https://registry.npmjs.org/vitest/-/vitest-0.32.4.tgz" integrity sha512-3czFm8RnrsWwIzVDu/Ca48Y/M+qh3vOnF16czJm98Q/AN1y3B6PBsyV8Re91Ty5s7txKNjEhpgtGPcfdbh2MZg==