diff --git a/.env.example b/.env.example
new file mode 100644
index 00000000..d71942da
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,15 @@
+# Google
+GOOGLE_CLIENT_ID=your_google_client_id
+GOOGLE_CLIENT_SECRET=your_google_client_secret
+
+# GitHub
+GITHUB_CLIENT_ID=your_github_client_id
+GITHUB_CLIENT_SECRET=your_github_client_secret
+
+# Microsoft
+MICROSOFT_CLIENT_ID=your_microsoft_client_id
+MICROSOFT_CLIENT_SECRET=your_microsoft_client_secret
+
+# General
+SECRET_KEY=supersecret
+BASE_URL=http://localhost:5000
diff --git a/gramps_webapi/app.py b/gramps_webapi/app.py
index 2041c2a1..86107e4b 100644
--- a/gramps_webapi/app.py
+++ b/gramps_webapi/app.py
@@ -17,13 +17,13 @@
# along with this program. If not, see .
#
-"""Flask web app providing a REST API to a Gramps family tree."""
import logging
import os
import warnings
from typing import Any, Dict, Optional
+
from flask import Flask, abort, g, send_from_directory
from flask_compress import Compress
from flask_cors import CORS
@@ -31,12 +31,14 @@
from gramps.gen.config import config as gramps_config
from gramps.gen.config import set as setconfig
+
from .api import api_blueprint
-from .api.cache import request_cache, thumbnail_cache
+from .api.cache import thumbnail_cache, request_cache
from .api.ratelimiter import limiter
from .api.search.embeddings import load_model
from .api.util import close_db
from .auth import user_db
+from .auth.oidc import configure_oauth, oidc_bp
from .config import DefaultConfig, DefaultConfigJWT
from .const import API_PREFIX, ENV_CONFIG_FILE, TREE_MULTI
from .dbmanager import WebDbManager
@@ -49,20 +51,10 @@ def deprecated_config_from_env(app):
This function will be removed eventually!
"""
options = [
- "TREE",
- "SECRET_KEY",
- "USER_DB_URI",
- "POSTGRES_USER",
- "POSTGRES_PASSWORD",
- "MEDIA_BASE_DIR",
- "SEARCH_INDEX_DIR",
- "EMAIL_HOST",
- "EMAIL_PORT",
- "EMAIL_HOST_USER",
- "EMAIL_HOST_PASSWORD",
- "DEFAULT_FROM_EMAIL",
- "BASE_URL",
- "STATIC_PATH",
+ "TREE", "SECRET_KEY", "USER_DB_URI", "POSTGRES_USER", "POSTGRES_PASSWORD",
+ "MEDIA_BASE_DIR", "SEARCH_INDEX_DIR", "EMAIL_HOST", "EMAIL_PORT",
+ "EMAIL_HOST_USER", "EMAIL_HOST_PASSWORD", "DEFAULT_FROM_EMAIL",
+ "BASE_URL", "STATIC_PATH",
]
for option in options:
value = os.getenv(option)
@@ -79,7 +71,6 @@ def deprecated_config_from_env(app):
def create_app(config: Optional[Dict[str, Any]] = None, config_from_env: bool = True):
"""Flask application factory."""
app = Flask(__name__)
-
app.logger.setLevel(logging.INFO)
# load default config
@@ -89,7 +80,7 @@ def create_app(config: Optional[Dict[str, Any]] = None, config_from_env: bool =
if os.getenv(ENV_CONFIG_FILE):
app.config.from_envvar(ENV_CONFIG_FILE)
- # use unprefixed environment variables if exist - deprecated!
+ # use unprefixed environment variables if exist - deprecated!
deprecated_config_from_env(app)
# use prefixed environment variables if exist
@@ -111,9 +102,11 @@ def create_app(config: Optional[Dict[str, Any]] = None, config_from_env: bool =
if db_path := os.getenv("GRAMPS_DATABASE_PATH"):
setconfig("database.path", db_path)
+
if app.config.get("LOG_LEVEL"):
app.logger.setLevel(app.config["LOG_LEVEL"])
+
if app.config["TREE"] != TREE_MULTI:
# create database if missing (only in single-tree mode)
WebDbManager(
@@ -122,6 +115,7 @@ def create_app(config: Optional[Dict[str, Any]] = None, config_from_env: bool =
ignore_lock=app.config["IGNORE_DB_LOCK"],
)
+
if app.config["TREE"] == TREE_MULTI and not app.config["MEDIA_PREFIX_TREE"]:
warnings.warn(
"You have enabled multi-tree support, but `MEDIA_PREFIX_TREE` is "
@@ -129,24 +123,31 @@ def create_app(config: Optional[Dict[str, Any]] = None, config_from_env: bool =
"files to users belonging to different trees!"
)
+
if app.config["TREE"] == TREE_MULTI and app.config["NEW_DB_BACKEND"] != "sqlite":
- # needed in case a new postgres tree is to be created
+ # needed in case a new postgres tree is to be created
gramps_config.set("database.host", app.config["POSTGRES_HOST"])
gramps_config.set("database.port", str(app.config["POSTGRES_PORT"]))
# load JWT default settings
app.config.from_object(DefaultConfigJWT)
-
- # instantiate JWT manager
JWTManager(app)
+
app.config["SQLALCHEMY_DATABASE_URI"] = app.config["USER_DB_URI"]
user_db.init_app(app)
+
request_cache.init_app(app, config=app.config["REQUEST_CACHE_CONFIG"])
+
+
+ configure_oauth(app)
+ app.register_blueprint(oidc_bp)
+
+
thumbnail_cache.init_app(app, config=app.config["THUMBNAIL_CACHE_CONFIG"])
- # enable CORS for /api/... resources
+# enable CORS for /api/... resources
if app.config.get("CORS_ORIGINS"):
CORS(
app,
@@ -156,13 +157,15 @@ def create_app(config: Optional[Dict[str, Any]] = None, config_from_env: bool =
# enable gzip compression
Compress(app)
+
static_path = app.config.get("STATIC_PATH")
- # routes for static hosting (e.g. SPA frontend)
+ # routes for static hosting (e.g. SPA frontend)
@app.route("/", methods=["GET", "POST"])
def send_index():
return send_from_directory(static_path, "index.html")
+
@app.route("/", methods=["GET", "POST"])
def send_static(path):
if path.startswith(API_PREFIX[1:]):
@@ -176,10 +179,11 @@ def send_static(path):
# register the API blueprint
app.register_blueprint(api_blueprint)
limiter.init_app(app)
-
+
# instantiate celery
create_celery(app)
+
@app.teardown_appcontext
def close_db_connection(exception) -> None:
"""Close the Gramps database after every request."""
@@ -190,6 +194,7 @@ def close_db_connection(exception) -> None:
if db_write:
close_db(db_write)
+
@app.teardown_request
def close_user_db_connection(exception) -> None:
"""Close the user database after every request."""
@@ -198,13 +203,16 @@ def close_user_db_connection(exception) -> None:
user_db.session.close() # pylint: disable=no-member
user_db.session.remove() # pylint: disable=no-member
+
if app.config.get("VECTOR_EMBEDDING_MODEL"):
app.config["_INITIALIZED_VECTOR_EMBEDDING_MODEL"] = load_model(
app.config["VECTOR_EMBEDDING_MODEL"]
)
+
@app.route("/ready", methods=["GET"])
def ready():
return {"status": "ready"}, 200
+
return app
diff --git a/gramps_webapi/auth/oidc.py b/gramps_webapi/auth/oidc.py
new file mode 100644
index 00000000..1a4db580
--- /dev/null
+++ b/gramps_webapi/auth/oidc.py
@@ -0,0 +1,156 @@
+import uuid
+from authlib.integrations.flask_client import OAuth
+from flask import Blueprint, redirect, url_for, request, current_app, jsonify
+from flask_jwt_extended import create_access_token, create_refresh_token
+from sqlalchemy.exc import IntegrityError
+
+
+from . import user_db, User, add_user
+from .const import ROLE_USER
+from ..api.util import get_tree_id, abort_with_message, tree_exists
+from ..api.auth import get_permissions
+from ..const import TREE_MULTI
+
+
+oauth = OAuth()
+oidc_bp = Blueprint("oidc", __name__, url_prefix="/auth")
+
+
+def configure_oauth(app):
+ if not app.config.get("OAUTH_ENABLED", False):
+ return
+
+
+ oauth.init_app(app)
+
+
+ if app.config.get("OAUTH_GOOGLE_CLIENT_ID") and app.config.get("OAUTH_GOOGLE_CLIENT_SECRET"):
+ oauth.register(
+ name="google",
+ client_id=app.config["OAUTH_GOOGLE_CLIENT_ID"],
+ client_secret=app.config["OAUTH_GOOGLE_CLIENT_SECRET"],
+ access_token_url="https://oauth2.googleapis.com/token",
+ authorize_url="https://accounts.google.com/o/oauth2/auth",
+ api_base_url="https://www.googleapis.com/oauth2/v1/",
+ client_kwargs={"scope": "openid email profile"},
+ )
+
+
+ if app.config.get("OAUTH_GITHUB_CLIENT_ID") and app.config.get("OAUTH_GITHUB_CLIENT_SECRET"):
+ oauth.register(
+ name="github",
+ client_id=app.config["OAUTH_GITHUB_CLIENT_ID"],
+ client_secret=app.config["OAUTH_GITHUB_CLIENT_SECRET"],
+ access_token_url="https://github.com/login/oauth/access_token",
+ authorize_url="https://github.com/login/oauth/authorize",
+ api_base_url="https://api.github.com/",
+ client_kwargs={"scope": "read:user user:email"},
+ )
+
+
+ if app.config.get("OAUTH_MICROSOFT_CLIENT_ID") and app.config.get("OAUTH_MICROSOFT_CLIENT_SECRET"):
+ oauth.register(
+ name="microsoft",
+ client_id=app.config["OAUTH_MICROSOFT_CLIENT_ID"],
+ client_secret=app.config["OAUTH_MICROSOFT_CLIENT_SECRET"],
+ access_token_url="https://login.microsoftonline.com/common/oauth2/v2.0/token",
+ authorize_url="https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
+ api_base_url="https://graph.microsoft.com/v1.0/",
+ client_kwargs={"scope": "openid email profile"},
+ )
+
+
+@oidc_bp.route("/login/")
+def login(provider):
+ if not current_app.config.get("OAUTH_ENABLED", False):
+ abort_with_message(403, "OAuth is not enabled")
+
+
+ redirect_uri = request.args.get("redirect_uri", url_for("oidc.authorize", provider=provider, _external=True))
+ return oauth.create_client(provider).authorize_redirect(redirect_uri)
+
+
+@oidc_bp.route("/callback/")
+def authorize(provider):
+ if not current_app.config.get("OAUTH_ENABLED", False):
+ abort_with_message(403, "OAuth is not enabled")
+
+ client = oauth.create_client(provider)
+ token = client.authorize_access_token()
+
+ if provider == "google":
+ user_info = client.parse_id_token(token)
+ else:
+ user_info = client.get("user").json()
+
+ email = user_info.get("email")
+ if not email:
+ abort_with_message(400, "No email provided")
+
+ user = user_db.session.query(User).filter_by(email=email).first()
+ if not user:
+ # Enforce OIDC registration enable/disable
+ if not current_app.config.get("ALLOW_OIDC_REGISTRATION", True):
+ abort_with_message(403, "User registration is disabled")
+
+ # Determine the tree to use
+ if current_app.config["TREE"] == TREE_MULTI:
+ # In multi-tree, OIDC must know which tree to assign!
+ abort_with_message(422, "tree is required for OIDC registration in multi-tree mode")
+ tree_id = current_app.config["TREE"]
+ if not tree_exists(tree_id):
+ abort_with_message(422, "Tree does not exist")
+
+ try:
+ username = email.split("@", 1)[0]
+ dummy_password = uuid.uuid4().hex # Satisfy non-empty password check
+ add_user(
+ name=username,
+ password=dummy_password,
+ email=email,
+ fullname=user_info.get("name", ""),
+ role=ROLE_USER,
+ tree=tree_id
+ )
+ user = user_db.session.query(User).filter_by(email=email).first()
+ if not user:
+ abort_with_message(500, "Failed to create user")
+ except ValueError as exc:
+ # Consistent error codes: 409 for conflict, 400 otherwise
+ msg = str(exc)
+ code = 409 if "exists" in msg.lower() else 400
+ abort_with_message(code, msg)
+
+ # Now continue as usual
+ tree_id = get_tree_id(str(user.id))
+ permissions = get_permissions(username=user.name, tree=tree_id)
+
+ access_token = create_access_token(
+ identity=str(user.id),
+ additional_claims={
+ "permissions": list(permissions),
+ "tree": tree_id
+ }
+ )
+ refresh_token = create_refresh_token(identity=str(user.id))
+
+ frontend_redirect = request.args.get("state") or current_app.config.get("FRONTEND_URL", "/")
+
+ response = {
+ "access_token": access_token,
+ "refresh_token": refresh_token,
+ "token_type": "Bearer",
+ "expires_in": current_app.config.get("JWT_ACCESS_TOKEN_EXPIRES", 900),
+ "user": {
+ "name": user.name,
+ "email": user.email,
+ "full_name": user.fullname,
+ "role": user.role,
+ "tree": tree_id
+ }
+ }
+
+ if request.headers.get("Accept") == "application/json":
+ return jsonify(response)
+
+ return redirect(f"{frontend_redirect}?access_token={access_token}&refresh_token={refresh_token}")
diff --git a/gramps_webapi/config.py b/gramps_webapi/config.py
index fd64a0ab..fc9b0c28 100644
--- a/gramps_webapi/config.py
+++ b/gramps_webapi/config.py
@@ -17,16 +17,21 @@
# along with this program. If not, see .
#
+
"""Default configuration settings."""
+
import datetime
from pathlib import Path
from typing import Dict
+
+
class DefaultConfig(object):
"""Default configuration object."""
+
PROPAGATE_EXCEPTIONS = True
SEARCH_INDEX_DIR = "indexdir" # deprecated!
SEARCH_INDEX_DB_URI = ""
@@ -51,6 +56,12 @@ class DefaultConfig(object):
"CACHE_THRESHOLD": 1000,
"CACHE_DEFAULT_TIMEOUT": 0,
}
+ PERSISTENT_CACHE_CONFIG = {
+ "CACHE_TYPE": "FileSystemCache",
+ "CACHE_DIR": str(Path.cwd() / "persistent_cache"),
+ "CACHE_THRESHOLD": 0,
+ "CACHE_DEFAULT_TIMEOUT": 0,
+ }
POSTGRES_USER = None
POSTGRES_PASSWORD = None
POSTGRES_HOST = "localhost"
@@ -69,11 +80,25 @@ class DefaultConfig(object):
LLM_MODEL = ""
LLM_MAX_CONTEXT_LENGTH = 50000
VECTOR_EMBEDDING_MODEL = ""
+ DISABLE_TELEMETRY = False
+
+
+ # OAuth configuration
+ OAUTH_GOOGLE_CLIENT_ID = ""
+ OAUTH_GOOGLE_CLIENT_SECRET = ""
+ OAUTH_GITHUB_CLIENT_ID = ""
+ OAUTH_GITHUB_CLIENT_SECRET = ""
+ OAUTH_MICROSOFT_CLIENT_ID = ""
+ OAUTH_MICROSOFT_CLIENT_SECRET = ""
+ OAUTH_ENABLED = False # Master switch for OAuth functionality
+
+
class DefaultConfigJWT(object):
"""Default configuration for JWT auth."""
+
JWT_TOKEN_LOCATION = ["headers", "query_string"]
JWT_ACCESS_TOKEN_EXPIRES = datetime.timedelta(minutes=15)
JWT_REFRESH_TOKEN_EXPIRES = False
diff --git a/pyproject.toml b/pyproject.toml
index 38ad8801..1785e881 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,6 +44,7 @@ dependencies = [
"gramps-ql>=0.4.0",
"object-ql>=0.1.3",
"sifts>=0.8.3",
+ "authlib",
]
[project.optional-dependencies]
diff --git a/requirements-dev.txt b/requirements-dev.txt
index ea5aeadf..169f3b47 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -7,4 +7,4 @@ pydocstyle
pre-commit
celery[pytest]
moto[s3]<5.0.0
-PyYAML
+PyYAML
\ No newline at end of file