diff --git a/.env.example b/.env.example new file mode 100644 index 00000000..d71942da --- /dev/null +++ b/.env.example @@ -0,0 +1,15 @@ +# Google +GOOGLE_CLIENT_ID=your_google_client_id +GOOGLE_CLIENT_SECRET=your_google_client_secret + +# GitHub +GITHUB_CLIENT_ID=your_github_client_id +GITHUB_CLIENT_SECRET=your_github_client_secret + +# Microsoft +MICROSOFT_CLIENT_ID=your_microsoft_client_id +MICROSOFT_CLIENT_SECRET=your_microsoft_client_secret + +# General +SECRET_KEY=supersecret +BASE_URL=http://localhost:5000 diff --git a/gramps_webapi/app.py b/gramps_webapi/app.py index 2041c2a1..86107e4b 100644 --- a/gramps_webapi/app.py +++ b/gramps_webapi/app.py @@ -17,13 +17,13 @@ # along with this program. If not, see . # -"""Flask web app providing a REST API to a Gramps family tree.""" import logging import os import warnings from typing import Any, Dict, Optional + from flask import Flask, abort, g, send_from_directory from flask_compress import Compress from flask_cors import CORS @@ -31,12 +31,14 @@ from gramps.gen.config import config as gramps_config from gramps.gen.config import set as setconfig + from .api import api_blueprint -from .api.cache import request_cache, thumbnail_cache +from .api.cache import thumbnail_cache, request_cache from .api.ratelimiter import limiter from .api.search.embeddings import load_model from .api.util import close_db from .auth import user_db +from .auth.oidc import configure_oauth, oidc_bp from .config import DefaultConfig, DefaultConfigJWT from .const import API_PREFIX, ENV_CONFIG_FILE, TREE_MULTI from .dbmanager import WebDbManager @@ -49,20 +51,10 @@ def deprecated_config_from_env(app): This function will be removed eventually! 
""" options = [ - "TREE", - "SECRET_KEY", - "USER_DB_URI", - "POSTGRES_USER", - "POSTGRES_PASSWORD", - "MEDIA_BASE_DIR", - "SEARCH_INDEX_DIR", - "EMAIL_HOST", - "EMAIL_PORT", - "EMAIL_HOST_USER", - "EMAIL_HOST_PASSWORD", - "DEFAULT_FROM_EMAIL", - "BASE_URL", - "STATIC_PATH", + "TREE", "SECRET_KEY", "USER_DB_URI", "POSTGRES_USER", "POSTGRES_PASSWORD", + "MEDIA_BASE_DIR", "SEARCH_INDEX_DIR", "EMAIL_HOST", "EMAIL_PORT", + "EMAIL_HOST_USER", "EMAIL_HOST_PASSWORD", "DEFAULT_FROM_EMAIL", + "BASE_URL", "STATIC_PATH", ] for option in options: value = os.getenv(option) @@ -79,7 +71,6 @@ def deprecated_config_from_env(app): def create_app(config: Optional[Dict[str, Any]] = None, config_from_env: bool = True): """Flask application factory.""" app = Flask(__name__) - app.logger.setLevel(logging.INFO) # load default config @@ -89,7 +80,7 @@ def create_app(config: Optional[Dict[str, Any]] = None, config_from_env: bool = if os.getenv(ENV_CONFIG_FILE): app.config.from_envvar(ENV_CONFIG_FILE) - # use unprefixed environment variables if exist - deprecated! + # use unprefixed environment variables if exist - deprecated! 
deprecated_config_from_env(app) # use prefixed environment variables if exist @@ -111,9 +102,11 @@ def create_app(config: Optional[Dict[str, Any]] = None, config_from_env: bool = if db_path := os.getenv("GRAMPS_DATABASE_PATH"): setconfig("database.path", db_path) + if app.config.get("LOG_LEVEL"): app.logger.setLevel(app.config["LOG_LEVEL"]) + if app.config["TREE"] != TREE_MULTI: # create database if missing (only in single-tree mode) WebDbManager( @@ -122,6 +115,7 @@ def create_app(config: Optional[Dict[str, Any]] = None, config_from_env: bool = ignore_lock=app.config["IGNORE_DB_LOCK"], ) + if app.config["TREE"] == TREE_MULTI and not app.config["MEDIA_PREFIX_TREE"]: warnings.warn( "You have enabled multi-tree support, but `MEDIA_PREFIX_TREE` is " @@ -129,24 +123,31 @@ def create_app(config: Optional[Dict[str, Any]] = None, config_from_env: bool = "files to users belonging to different trees!" ) + if app.config["TREE"] == TREE_MULTI and app.config["NEW_DB_BACKEND"] != "sqlite": - # needed in case a new postgres tree is to be created + # needed in case a new postgres tree is to be created gramps_config.set("database.host", app.config["POSTGRES_HOST"]) gramps_config.set("database.port", str(app.config["POSTGRES_PORT"])) # load JWT default settings app.config.from_object(DefaultConfigJWT) - - # instantiate JWT manager JWTManager(app) + app.config["SQLALCHEMY_DATABASE_URI"] = app.config["USER_DB_URI"] user_db.init_app(app) + request_cache.init_app(app, config=app.config["REQUEST_CACHE_CONFIG"]) + + + configure_oauth(app) + app.register_blueprint(oidc_bp) + + thumbnail_cache.init_app(app, config=app.config["THUMBNAIL_CACHE_CONFIG"]) - # enable CORS for /api/... resources +# enable CORS for /api/... 
resources if app.config.get("CORS_ORIGINS"): CORS( app, @@ -156,13 +157,15 @@ def create_app(config: Optional[Dict[str, Any]] = None, config_from_env: bool = # enable gzip compression Compress(app) + static_path = app.config.get("STATIC_PATH") - # routes for static hosting (e.g. SPA frontend) + # routes for static hosting (e.g. SPA frontend) @app.route("/", methods=["GET", "POST"]) def send_index(): return send_from_directory(static_path, "index.html") + @app.route("/", methods=["GET", "POST"]) def send_static(path): if path.startswith(API_PREFIX[1:]): @@ -176,10 +179,11 @@ def send_static(path): # register the API blueprint app.register_blueprint(api_blueprint) limiter.init_app(app) - + # instantiate celery create_celery(app) + @app.teardown_appcontext def close_db_connection(exception) -> None: """Close the Gramps database after every request.""" @@ -190,6 +194,7 @@ def close_db_connection(exception) -> None: if db_write: close_db(db_write) + @app.teardown_request def close_user_db_connection(exception) -> None: """Close the user database after every request.""" @@ -198,13 +203,16 @@ def close_user_db_connection(exception) -> None: user_db.session.close() # pylint: disable=no-member user_db.session.remove() # pylint: disable=no-member + if app.config.get("VECTOR_EMBEDDING_MODEL"): app.config["_INITIALIZED_VECTOR_EMBEDDING_MODEL"] = load_model( app.config["VECTOR_EMBEDDING_MODEL"] ) + @app.route("/ready", methods=["GET"]) def ready(): return {"status": "ready"}, 200 + return app diff --git a/gramps_webapi/auth/oidc.py b/gramps_webapi/auth/oidc.py new file mode 100644 index 00000000..1a4db580 --- /dev/null +++ b/gramps_webapi/auth/oidc.py @@ -0,0 +1,156 @@ +import uuid +from authlib.integrations.flask_client import OAuth +from flask import Blueprint, redirect, url_for, request, current_app, jsonify +from flask_jwt_extended import create_access_token, create_refresh_token +from sqlalchemy.exc import IntegrityError + + +from . 
import user_db, User, add_user
+from .const import ROLE_USER
+from ..api.util import get_tree_id, abort_with_message, tree_exists
+from ..api.auth import get_permissions
+from ..const import TREE_MULTI
+
+
+oauth = OAuth()
+oidc_bp = Blueprint("oidc", __name__, url_prefix="/auth")
+
+
+def configure_oauth(app):
+    """Register the configured OAuth/OIDC providers on the shared ``oauth`` registry.
+
+    Does nothing unless ``OAUTH_ENABLED`` is truthy. A provider is registered
+    only when both its client ID and client secret are set in ``app.config``.
+    """
+    if not app.config.get("OAUTH_ENABLED", False):
+        return
+
+    oauth.init_app(app)
+
+    if app.config.get("OAUTH_GOOGLE_CLIENT_ID") and app.config.get(
+        "OAUTH_GOOGLE_CLIENT_SECRET"
+    ):
+        oauth.register(
+            name="google",
+            client_id=app.config["OAUTH_GOOGLE_CLIENT_ID"],
+            client_secret=app.config["OAUTH_GOOGLE_CLIENT_SECRET"],
+            # The discovery document supplies the JWKS and userinfo endpoints
+            # that are needed to validate the ID token requested by "openid".
+            server_metadata_url=(
+                "https://accounts.google.com/.well-known/openid-configuration"
+            ),
+            access_token_url="https://oauth2.googleapis.com/token",
+            authorize_url="https://accounts.google.com/o/oauth2/auth",
+            api_base_url="https://www.googleapis.com/oauth2/v1/",
+            client_kwargs={"scope": "openid email profile"},
+        )
+
+    if app.config.get("OAUTH_GITHUB_CLIENT_ID") and app.config.get(
+        "OAUTH_GITHUB_CLIENT_SECRET"
+    ):
+        oauth.register(
+            name="github",
+            client_id=app.config["OAUTH_GITHUB_CLIENT_ID"],
+            client_secret=app.config["OAUTH_GITHUB_CLIENT_SECRET"],
+            access_token_url="https://github.com/login/oauth/access_token",
+            authorize_url="https://github.com/login/oauth/authorize",
+            api_base_url="https://api.github.com/",
+            client_kwargs={"scope": "read:user user:email"},
+        )
+
+    if app.config.get("OAUTH_MICROSOFT_CLIENT_ID") and app.config.get(
+        "OAUTH_MICROSOFT_CLIENT_SECRET"
+    ):
+        oauth.register(
+            name="microsoft",
+            client_id=app.config["OAUTH_MICROSOFT_CLIENT_ID"],
+            client_secret=app.config["OAUTH_MICROSOFT_CLIENT_SECRET"],
+            access_token_url="https://login.microsoftonline.com/common/oauth2/v2.0/token",
+            authorize_url="https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
+            # Key set used to verify the ID token returned for the "openid" scope.
+            # NOTE(review): no issuer is configured because the "common" endpoint
+            # issues tokens with a per-tenant issuer - confirm this is acceptable.
+            jwks_uri="https://login.microsoftonline.com/common/discovery/v2.0/keys",
+            api_base_url="https://graph.microsoft.com/v1.0/",
+            client_kwargs={"scope": "openid email profile"},
+        )
+
+
+def _get_client_or_abort(provider):
+    """Return the OAuth client for ``provider`` or abort with 404 if unknown."""
+    client = oauth.create_client(provider)
+    if client is None:
+        # create_client returns None for names that were never registered
+        abort_with_message(404, f"Unknown OAuth provider: {provider}")
+    return client
+
+
+def _fetch_user_info(provider, client, token):
+    """Return a mapping with ``email`` and ``name`` for the authenticated user.
+
+    For OIDC providers the validated ID-token claims that Authlib stores under
+    ``token["userinfo"]`` are preferred; otherwise the provider API is queried.
+    """
+    if provider == "github":
+        profile = client.get("user", token=token).json()
+        email = profile.get("email")
+        if not email:
+            # GitHub reports a null e-mail when the address is private;
+            # the "user:email" scope allows listing the addresses instead.
+            listing = client.get("user/emails", token=token).json()
+            verified = []
+            if isinstance(listing, list):
+                verified = [
+                    entry
+                    for entry in listing
+                    if isinstance(entry, dict) and entry.get("verified")
+                ]
+            # primary address first (False sorts before True)
+            verified.sort(key=lambda entry: not entry.get("primary"))
+            email = verified[0].get("email") if verified else None
+        return {"email": email, "name": profile.get("name") or ""}
+
+    claims = token.get("userinfo") or {}
+
+    if provider == "microsoft":
+        # NOTE(review): Microsoft e-mail claims are not guaranteed to be
+        # verified - confirm before relying on them to match existing accounts.
+        if claims.get("email"):
+            return {"email": claims["email"], "name": claims.get("name") or ""}
+        # Microsoft Graph exposes the signed-in user at "me" (there is no "user").
+        # NOTE(review): Graph may require the User.Read scope - confirm.
+        profile = client.get("me", token=token).json()
+        return {
+            "email": profile.get("mail") or profile.get("userPrincipalName"),
+            "name": profile.get("displayName") or "",
+        }
+
+    if not claims:
+        # No parsed ID token available: ask the provider's userinfo endpoint.
+        claims = client.userinfo(token=token)
+    return claims
+
+
+def _register_oidc_user(email, fullname):
+    """Create a local account for a first-time OIDC user and return it.
+
+    Aborts with 403 if registration is disabled, 422 in multi-tree mode or if
+    the configured tree does not exist, 409/400 if ``add_user`` rejects the data.
+    """
+    if not current_app.config.get("ALLOW_OIDC_REGISTRATION", True):
+        abort_with_message(403, "User registration is disabled")
+
+    if current_app.config["TREE"] == TREE_MULTI:
+        # In multi-tree mode there is no way to know which tree to assign.
+        abort_with_message(
+            422, "tree is required for OIDC registration in multi-tree mode"
+        )
+    tree_id = current_app.config["TREE"]
+    if not tree_exists(tree_id):
+        abort_with_message(422, "Tree does not exist")
+
+    try:
+        add_user(
+            name=email.split("@", 1)[0],
+            # random throw-away password, only to satisfy the non-empty check
+            password=uuid.uuid4().hex,
+            email=email,
+            fullname=fullname,
+            role=ROLE_USER,
+            tree=tree_id,
+        )
+    except ValueError as exc:
+        # 409 for conflicts with an existing user, 400 for anything else
+        msg = str(exc)
+        abort_with_message(409 if "exists" in msg.lower() else 400, msg)
+
+    user = user_db.session.query(User).filter_by(email=email).first()
+    if not user:
+        abort_with_message(500, "Failed to create user")
+    return user
+
+
+@oidc_bp.route("/login/<provider>")
+def login(provider):
+    """Redirect the browser to the authorization page of ``provider``."""
+    if not current_app.config.get("OAUTH_ENABLED", False):
+        abort_with_message(403, "OAuth is not enabled")
+
+    client = _get_client_or_abort(provider)
+    redirect_uri = request.args.get(
+        "redirect_uri",
+        url_for("oidc.authorize", provider=provider, _external=True),
+    )
+    return client.authorize_redirect(redirect_uri)
+
+
+@oidc_bp.route("/callback/<provider>")
+def authorize(provider):
+    """Complete the OAuth flow, find or create the user and issue JWTs.
+
+    Returns the tokens as JSON when the client accepts ``application/json``,
+    otherwise redirects to ``FRONTEND_URL`` with the tokens as query parameters.
+    """
+    if not current_app.config.get("OAUTH_ENABLED", False):
+        abort_with_message(403, "OAuth is not enabled")
+
+    client = _get_client_or_abort(provider)
+    token = client.authorize_access_token()
+    user_info = _fetch_user_info(provider, client, token)
+
+    email = user_info.get("email")
+    if not email:
+        abort_with_message(400, "No email provided")
+    if user_info.get("email_verified") is False:
+        # Never match or create an account on an address the provider
+        # explicitly reports as unverified.
+        abort_with_message(403, "Email address is not verified")
+
+    user = user_db.session.query(User).filter_by(email=email).first()
+    if not user:
+        user = _register_oidc_user(email, user_info.get("name") or "")
+
+    tree_id = get_tree_id(str(user.id))
+    permissions = get_permissions(username=user.name, tree=tree_id)
+
+    access_token = create_access_token(
+        identity=str(user.id),
+        additional_claims={"permissions": list(permissions), "tree": tree_id},
+    )
+    refresh_token = create_refresh_token(identity=str(user.id))
+
+    # JWT_ACCESS_TOKEN_EXPIRES is a timedelta by default, which is not JSON
+    # serialisable; report the lifetime in seconds.
+    expires_in = current_app.config.get("JWT_ACCESS_TOKEN_EXPIRES", 900)
+    if hasattr(expires_in, "total_seconds"):
+        expires_in = int(expires_in.total_seconds())
+
+    if "application/json" in request.headers.get("Accept", ""):
+        return jsonify(
+            {
+                "access_token": access_token,
+                "refresh_token": refresh_token,
+                "token_type": "Bearer",
+                "expires_in": expires_in,
+                "user": {
+                    "name": user.name,
+                    "email": user.email,
+                    "full_name": user.fullname,
+                    "role": user.role,
+                    "tree": tree_id,
+                },
+            }
+        )
+
+    # The "state" query parameter is the OAuth CSRF value, not a URL, so the
+    # redirect target comes from configuration only.
+    frontend_redirect = current_app.config.get("FRONTEND_URL", "/")
+    separator = "&" if "?" in frontend_redirect else "?"
+    return redirect(
+        f"{frontend_redirect}{separator}"
+        f"access_token={access_token}&refresh_token={refresh_token}"
+    )
diff --git a/gramps_webapi/config.py b/gramps_webapi/config.py
index fd64a0ab..fc9b0c28 100644
--- a/gramps_webapi/config.py
+++ b/gramps_webapi/config.py
@@ -17,16 +17,21 @@
 # along with this program. If not, see .
# + """Default configuration settings.""" + import datetime from pathlib import Path from typing import Dict + + class DefaultConfig(object): """Default configuration object.""" + PROPAGATE_EXCEPTIONS = True SEARCH_INDEX_DIR = "indexdir" # deprecated! SEARCH_INDEX_DB_URI = "" @@ -51,6 +56,12 @@ class DefaultConfig(object): "CACHE_THRESHOLD": 1000, "CACHE_DEFAULT_TIMEOUT": 0, } + PERSISTENT_CACHE_CONFIG = { + "CACHE_TYPE": "FileSystemCache", + "CACHE_DIR": str(Path.cwd() / "persistent_cache"), + "CACHE_THRESHOLD": 0, + "CACHE_DEFAULT_TIMEOUT": 0, + } POSTGRES_USER = None POSTGRES_PASSWORD = None POSTGRES_HOST = "localhost" @@ -69,11 +80,25 @@ class DefaultConfig(object): LLM_MODEL = "" LLM_MAX_CONTEXT_LENGTH = 50000 VECTOR_EMBEDDING_MODEL = "" + DISABLE_TELEMETRY = False + + + # OAuth configuration + OAUTH_GOOGLE_CLIENT_ID = "" + OAUTH_GOOGLE_CLIENT_SECRET = "" + OAUTH_GITHUB_CLIENT_ID = "" + OAUTH_GITHUB_CLIENT_SECRET = "" + OAUTH_MICROSOFT_CLIENT_ID = "" + OAUTH_MICROSOFT_CLIENT_SECRET = "" + OAUTH_ENABLED = False # Master switch for OAuth functionality + + class DefaultConfigJWT(object): """Default configuration for JWT auth.""" + JWT_TOKEN_LOCATION = ["headers", "query_string"] JWT_ACCESS_TOKEN_EXPIRES = datetime.timedelta(minutes=15) JWT_REFRESH_TOKEN_EXPIRES = False diff --git a/pyproject.toml b/pyproject.toml index 38ad8801..1785e881 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ dependencies = [ "gramps-ql>=0.4.0", "object-ql>=0.1.3", "sifts>=0.8.3", + "authlib", ] [project.optional-dependencies] diff --git a/requirements-dev.txt b/requirements-dev.txt index ea5aeadf..169f3b47 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -7,4 +7,4 @@ pydocstyle pre-commit celery[pytest] moto[s3]<5.0.0 -PyYAML +PyYAML \ No newline at end of file