diff --git a/src/backend/crud/deployment.py b/src/backend/crud/deployment.py index 09668c3cc6..a85f356825 100644 --- a/src/backend/crud/deployment.py +++ b/src/backend/crud/deployment.py @@ -113,7 +113,7 @@ def update_deployment( db: Session, deployment: Deployment, new_deployment: DeploymentUpdate ) -> Deployment: """ - Update a deployment by ID. + Update a deployment. Args: db (Session): Database session. @@ -125,8 +125,10 @@ def update_deployment( """ for attr, value in new_deployment.model_dump(exclude_none=True).items(): setattr(deployment, attr, value) + db.commit() db.refresh(deployment) + return deployment diff --git a/src/backend/main.py b/src/backend/main.py index 4cbb3c9e02..d0a560eaab 100644 --- a/src/backend/main.py +++ b/src/backend/main.py @@ -17,6 +17,7 @@ ) from backend.config.routers import ROUTER_DEPENDENCIES, RouterName from backend.config.settings import Settings +from backend.database_models.database import get_session from backend.exceptions import DeploymentNotFoundError from backend.routers.agent import router as agent_router from backend.routers.auth import router as auth_router @@ -31,6 +32,7 @@ from backend.routers.snapshot import router as snapshot_router from backend.routers.tool import router as tool_router from backend.routers.user import router as user_router +from backend.services import deployment as deployment_service from backend.services.context import ContextMiddleware, get_context from backend.services.logger.middleware import LoggingMiddleware @@ -108,6 +110,9 @@ def create_app() -> FastAPI: app.add_middleware(ContextMiddleware) # This should be the first middleware app.add_exception_handler(SCIMException, scim_exception_handler) # pyright: ignore + # Update Deployments config + deployment_service.update_db_config_from_env(next(get_session())) + return app diff --git a/src/backend/routers/deployment.py b/src/backend/routers/deployment.py index 8037f6d19c..66ea0f5e09 100644 --- a/src/backend/routers/deployment.py +++ b/src/backend/routers/deployment.py @@ -144,7 +144,7 @@ async def delete_deployment( @router.post("/{deployment_id}/update_config", response_model=DeploymentDefinition) -async def update_config( +async def update_db_config( *, deployment_id: DeploymentIdPathParam, env_vars: UpdateDeploymentEnv, @@ -155,7 +155,7 @@ async def update_config( Set environment variables for the deployment. """ return mask_deployment_secrets( - deployment_service.update_config(session, deployment_id, valid_env_vars) + deployment_service.update_db_config(session, deployment_id, valid_env_vars) ) diff --git a/src/backend/services/deployment.py b/src/backend/services/deployment.py index aac7b77b6f..a5de4b5193 100644 --- a/src/backend/services/deployment.py +++ b/src/backend/services/deployment.py @@ -113,7 +113,7 @@ def get_deployment_definitions(session: DBSessionDep) -> list[DeploymentDefiniti return [*db_deployments.values(), *installed_deployments] -def update_config(session: DBSessionDep, deployment_id: str, env_vars: dict[str, str]) -> DeploymentDefinition: +def update_db_config(session: DBSessionDep, deployment_id: str, env_vars: dict[str, str]) -> DeploymentDefinition: logger.debug(event="update_config", deployment_id=deployment_id, env_vars=env_vars) db_deployment = deployment_crud.get_deployment(session, deployment_id) @@ -128,3 +128,31 @@ def update_config(session: DBSessionDep, deployment_id: str, env_vars: dict[str, updated_deployment = get_deployment_definition(session, deployment_id) return updated_deployment + +def update_db_config_from_env(session: DBSessionDep): + try: + for deployment_name, deployment in AVAILABLE_MODEL_DEPLOYMENTS.items(): + # Fetch local config + env_config = deployment.config() + # Fetch DB entity + db_deployment = deployment_crud.get_deployment_by_name(session, deployment_name) + + # Skip to next if no config or no DB deployment found + if not env_config or not db_deployment: + logger.debug(event="Updating DB deployment config, no config or no DB deployment found.") + continue + + db_config = dict(db_deployment.default_deployment_config) + + for key, value in env_config.items(): + db_config[key] = value + + deployment_crud.update_deployment( + session, + db_deployment, + DeploymentUpdate( + default_deployment_config=db_config + ) + ) + except Exception as e: + logger.error(event=f"Error while updating DB deployment config: {e}") diff --git a/src/backend/tests/unit/services/test_deployment.py b/src/backend/tests/unit/services/test_deployment.py index d7eae458c0..a399fe4acb 100644 --- a/src/backend/tests/unit/services/test_deployment.py +++ b/src/backend/tests/unit/services/test_deployment.py @@ -123,12 +123,12 @@ def test_get_deployment_definitions_with_db_deployments(session, mock_available_ assert any(d.id == "db-mock-cohere-platform-id" for d in definitions) def test_update_config_db(session, db_deployment) -> None: - deployment_service.update_config(session, db_deployment.id, {"COHERE_API_KEY": "new-db-test-api-key"}) + deployment_service.update_db_config(session, db_deployment.id, {"COHERE_API_KEY": "new-db-test-api-key"}) updated_deployment = session.query(Deployment).get("db-mock-cohere-platform-id") assert updated_deployment.default_deployment_config == {"COHERE_API_KEY": "new-db-test-api-key"} def test_update_config_no_db_deployments(session, mock_available_model_deployments, clear_db_deployments) -> None: with patch("backend.services.deployment.update_env_file") as mock_update_env_file: with patch("backend.services.deployment.get_deployment_definition", return_value=MockCohereDeployment.to_deployment_definition()): - deployment_service.update_config(session, "some-deployment-id", {"API_KEY": "new-api-key"}) + deployment_service.update_db_config(session, "some-deployment-id", {"API_KEY": "new-api-key"}) mock_update_env_file.assert_called_with({"API_KEY": "new-api-key"})