Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Google
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't this be handled as part of the flask configuration? https://docs.authlib.org/en/latest/client/flask.html#configuration

Authlib Flask OAuth registry can load the configuration from Flask app.config automatically.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this file is not used anymore?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then why is it still here?

GOOGLE_CLIENT_ID=your_google_client_id
GOOGLE_CLIENT_SECRET=your_google_client_secret

# GitHub
GITHUB_CLIENT_ID=your_github_client_id
GITHUB_CLIENT_SECRET=your_github_client_secret

# Microsoft
MICROSOFT_CLIENT_ID=your_microsoft_client_id
MICROSOFT_CLIENT_SECRET=your_microsoft_client_secret

# General
SECRET_KEY=supersecret
BASE_URL=http://localhost:5000
56 changes: 32 additions & 24 deletions gramps_webapi/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,28 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You cannot remove the copyright notice!! (I know ChatGPT likes to do that...)

"""Flask web app providing a REST API to a Gramps family tree."""

import logging
import os
import warnings
from typing import Any, Dict, Optional


from flask import Flask, abort, g, send_from_directory
from flask_compress import Compress
from flask_cors import CORS
from flask_jwt_extended import JWTManager
from gramps.gen.config import config as gramps_config
from gramps.gen.config import set as setconfig


from .api import api_blueprint
from .api.cache import request_cache, thumbnail_cache
from .api.cache import thumbnail_cache, request_cache
from .api.ratelimiter import limiter
from .api.search.embeddings import load_model
from .api.util import close_db
from .auth import user_db
from .auth.oidc import configure_oauth, oidc_bp
from .config import DefaultConfig, DefaultConfigJWT
from .const import API_PREFIX, ENV_CONFIG_FILE, TREE_MULTI
from .dbmanager import WebDbManager
Expand All @@ -49,20 +51,10 @@ def deprecated_config_from_env(app):
This function will be removed eventually!
"""
options = [
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this reformatted?

"TREE",
"SECRET_KEY",
"USER_DB_URI",
"POSTGRES_USER",
"POSTGRES_PASSWORD",
"MEDIA_BASE_DIR",
"SEARCH_INDEX_DIR",
"EMAIL_HOST",
"EMAIL_PORT",
"EMAIL_HOST_USER",
"EMAIL_HOST_PASSWORD",
"DEFAULT_FROM_EMAIL",
"BASE_URL",
"STATIC_PATH",
"TREE", "SECRET_KEY", "USER_DB_URI", "POSTGRES_USER", "POSTGRES_PASSWORD",
"MEDIA_BASE_DIR", "SEARCH_INDEX_DIR", "EMAIL_HOST", "EMAIL_PORT",
"EMAIL_HOST_USER", "EMAIL_HOST_PASSWORD", "DEFAULT_FROM_EMAIL",
"BASE_URL", "STATIC_PATH",
]
for option in options:
value = os.getenv(option)
Expand All @@ -79,7 +71,6 @@ def deprecated_config_from_env(app):
def create_app(config: Optional[Dict[str, Any]] = None, config_from_env: bool = True):
"""Flask application factory."""
app = Flask(__name__)

app.logger.setLevel(logging.INFO)

# load default config
Expand All @@ -89,7 +80,7 @@ def create_app(config: Optional[Dict[str, Any]] = None, config_from_env: bool =
if os.getenv(ENV_CONFIG_FILE):
app.config.from_envvar(ENV_CONFIG_FILE)

# use unprefixed environment variables if exist - deprecated!
# use unprefixed environment variables if exist - deprecated!
deprecated_config_from_env(app)

# use prefixed environment variables if exist
Expand All @@ -111,9 +102,11 @@ def create_app(config: Optional[Dict[str, Any]] = None, config_from_env: bool =
if db_path := os.getenv("GRAMPS_DATABASE_PATH"):
setconfig("database.path", db_path)


if app.config.get("LOG_LEVEL"):
app.logger.setLevel(app.config["LOG_LEVEL"])


if app.config["TREE"] != TREE_MULTI:
# create database if missing (only in single-tree mode)
WebDbManager(
Expand All @@ -122,31 +115,39 @@ def create_app(config: Optional[Dict[str, Any]] = None, config_from_env: bool =
ignore_lock=app.config["IGNORE_DB_LOCK"],
)


Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please don't add white space changes in unmodified code - it clutters the diff

if app.config["TREE"] == TREE_MULTI and not app.config["MEDIA_PREFIX_TREE"]:
warnings.warn(
"You have enabled multi-tree support, but `MEDIA_PREFIX_TREE` is "
"set to `False`. This is strongly discouraged as it exposes media "
"files to users belonging to different trees!"
)


if app.config["TREE"] == TREE_MULTI and app.config["NEW_DB_BACKEND"] != "sqlite":
# needed in case a new postgres tree is to be created
# needed in case a new postgres tree is to be created
gramps_config.set("database.host", app.config["POSTGRES_HOST"])
gramps_config.set("database.port", str(app.config["POSTGRES_PORT"]))

# load JWT default settings
app.config.from_object(DefaultConfigJWT)

# instantiate JWT manager
JWTManager(app)


app.config["SQLALCHEMY_DATABASE_URI"] = app.config["USER_DB_URI"]
user_db.init_app(app)


request_cache.init_app(app, config=app.config["REQUEST_CACHE_CONFIG"])


configure_oauth(app)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will register the providers.

Does this involve an HTTP request to the provider endpoint? If so, what if that request fails? Will the whole flask app fail to start?

app.register_blueprint(oidc_bp)


thumbnail_cache.init_app(app, config=app.config["THUMBNAIL_CACHE_CONFIG"])

# enable CORS for /api/... resources
# enable CORS for /api/... resources
if app.config.get("CORS_ORIGINS"):
CORS(
app,
Expand All @@ -156,13 +157,15 @@ def create_app(config: Optional[Dict[str, Any]] = None, config_from_env: bool =
# enable gzip compression
Compress(app)


static_path = app.config.get("STATIC_PATH")

# routes for static hosting (e.g. SPA frontend)
# routes for static hosting (e.g. SPA frontend)
@app.route("/", methods=["GET", "POST"])
def send_index():
return send_from_directory(static_path, "index.html")


@app.route("/<path:path>", methods=["GET", "POST"])
def send_static(path):
if path.startswith(API_PREFIX[1:]):
Expand All @@ -176,10 +179,11 @@ def send_static(path):
# register the API blueprint
app.register_blueprint(api_blueprint)
limiter.init_app(app)

# instantiate celery
create_celery(app)


@app.teardown_appcontext
def close_db_connection(exception) -> None:
"""Close the Gramps database after every request."""
Expand All @@ -190,6 +194,7 @@ def close_db_connection(exception) -> None:
if db_write:
close_db(db_write)


@app.teardown_request
def close_user_db_connection(exception) -> None:
"""Close the user database after every request."""
Expand All @@ -198,13 +203,16 @@ def close_user_db_connection(exception) -> None:
user_db.session.close() # pylint: disable=no-member
user_db.session.remove() # pylint: disable=no-member


if app.config.get("VECTOR_EMBEDDING_MODEL"):
app.config["_INITIALIZED_VECTOR_EMBEDDING_MODEL"] = load_model(
app.config["VECTOR_EMBEDDING_MODEL"]
)


@app.route("/ready", methods=["GET"])
def ready():
return {"status": "ready"}, 200


return app
156 changes: 156 additions & 0 deletions gramps_webapi/auth/oidc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import uuid
from authlib.integrations.flask_client import OAuth
from flask import Blueprint, redirect, url_for, request, current_app, jsonify
from flask_jwt_extended import create_access_token, create_refresh_token
from sqlalchemy.exc import IntegrityError


from . import user_db, User, add_user
from .const import ROLE_USER
from ..api.util import get_tree_id, abort_with_message, tree_exists
from ..api.auth import get_permissions
from ..const import TREE_MULTI


oauth = OAuth()
oidc_bp = Blueprint("oidc", __name__, url_prefix="/auth")


def configure_oauth(app):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The frontend needs a way of knowing which of the providers have been registered in the end, based on the configuration, otherwise the frontend doesn't know which buttons to show.

I think the best solution is to add it to a list of registered providers to the /api/metadata endpoint under the server key.

if not app.config.get("OAUTH_ENABLED", False):
return


oauth.init_app(app)


if app.config.get("OAUTH_GOOGLE_CLIENT_ID") and app.config.get("OAUTH_GOOGLE_CLIENT_SECRET"):
oauth.register(
name="google",
client_id=app.config["OAUTH_GOOGLE_CLIENT_ID"],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why this contradicts the docs which asks for {name}_CLIENT_ID etc., see https://docs.authlib.org/en/latest/client/flask.html#configuration

client_secret=app.config["OAUTH_GOOGLE_CLIENT_SECRET"],
access_token_url="https://oauth2.googleapis.com/token",
authorize_url="https://accounts.google.com/o/oauth2/auth",
api_base_url="https://www.googleapis.com/oauth2/v1/",
client_kwargs={"scope": "openid email profile"},
)


if app.config.get("OAUTH_GITHUB_CLIENT_ID") and app.config.get("OAUTH_GITHUB_CLIENT_SECRET"):
oauth.register(
name="github",
client_id=app.config["OAUTH_GITHUB_CLIENT_ID"],
client_secret=app.config["OAUTH_GITHUB_CLIENT_SECRET"],
access_token_url="https://github.com/login/oauth/access_token",
authorize_url="https://github.com/login/oauth/authorize",
api_base_url="https://api.github.com/",
client_kwargs={"scope": "read:user user:email"},
)


if app.config.get("OAUTH_MICROSOFT_CLIENT_ID") and app.config.get("OAUTH_MICROSOFT_CLIENT_SECRET"):
oauth.register(
name="microsoft",
client_id=app.config["OAUTH_MICROSOFT_CLIENT_ID"],
client_secret=app.config["OAUTH_MICROSOFT_CLIENT_SECRET"],
access_token_url="https://login.microsoftonline.com/common/oauth2/v2.0/token",
authorize_url="https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
api_base_url="https://graph.microsoft.com/v1.0/",
client_kwargs={"scope": "openid email profile"},
)


@oidc_bp.route("/login/<provider>")
def login(provider):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doc string please

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And please add an example to the doc string, I see there is a redirect_uri argument expected in the request? What should it be?

if not current_app.config.get("OAUTH_ENABLED", False):
abort_with_message(403, "OAuth is not enabled")


redirect_uri = request.args.get("redirect_uri", url_for("oidc.authorize", provider=provider, _external=True))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is oidc.authorize correct? Why not gramps_webapi.auth.oidc.authorize?

return oauth.create_client(provider).authorize_redirect(redirect_uri)


@oidc_bp.route("/callback/<provider>")
def authorize(provider):
if not current_app.config.get("OAUTH_ENABLED", False):
abort_with_message(403, "OAuth is not enabled")

client = oauth.create_client(provider)
token = client.authorize_access_token()

if provider == "google":
user_info = client.parse_id_token(token)
else:
user_info = client.get("user").json()

email = user_info.get("email")
if not email:
abort_with_message(400, "No email provided")

user = user_db.session.query(User).filter_by(email=email).first()
if not user:
# Enforce OIDC registration enable/disable
if not current_app.config.get("ALLOW_OIDC_REGISTRATION", True):
abort_with_message(403, "User registration is disabled")

# Determine the tree to use
if current_app.config["TREE"] == TREE_MULTI:
# In multi-tree, OIDC must know which tree to assign!
abort_with_message(422, "tree is required for OIDC registration in multi-tree mode")
tree_id = current_app.config["TREE"]
if not tree_exists(tree_id):
abort_with_message(422, "Tree does not exist")

try:
username = email.split("@", 1)[0]
dummy_password = uuid.uuid4().hex # Satisfy non-empty password check
add_user(
name=username,
password=dummy_password,
email=email,
fullname=user_info.get("name", ""),
role=ROLE_USER,
tree=tree_id
)
user = user_db.session.query(User).filter_by(email=email).first()
if not user:
abort_with_message(500, "Failed to create user")
except ValueError as exc:
# Consistent error codes: 409 for conflict, 400 otherwise
msg = str(exc)
code = 409 if "exists" in msg.lower() else 400
abort_with_message(code, msg)

# Now continue as usual
tree_id = get_tree_id(str(user.id))
permissions = get_permissions(username=user.name, tree=tree_id)

access_token = create_access_token(
identity=str(user.id),
additional_claims={
"permissions": list(permissions),
"tree": tree_id
}
)
refresh_token = create_refresh_token(identity=str(user.id))

frontend_redirect = request.args.get("state") or current_app.config.get("FRONTEND_URL", "/")

response = {
"access_token": access_token,
"refresh_token": refresh_token,
"token_type": "Bearer",
"expires_in": current_app.config.get("JWT_ACCESS_TOKEN_EXPIRES", 900),
"user": {
"name": user.name,
"email": user.email,
"full_name": user.fullname,
"role": user.role,
"tree": tree_id
}
}

if request.headers.get("Accept") == "application/json":
return jsonify(response)

return redirect(f"{frontend_redirect}?access_token={access_token}&refresh_token={refresh_token}")
Loading