diff --git a/airflow/providers/fab/CHANGELOG.rst b/airflow/providers/fab/CHANGELOG.rst index 9c9e29412793d..37f4c6ca0f3ed 100644 --- a/airflow/providers/fab/CHANGELOG.rst +++ b/airflow/providers/fab/CHANGELOG.rst @@ -20,6 +20,151 @@ Changelog --------- +1.5.4 +..... + +* ``[providers-fab/v1-5] Update dependencies for FAB provider to not be conflicting with 2.11.1 (#53029)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + + +1.5.3 +..... + +Bug Fixes +~~~~~~~~~ + +* ``[providers-fab/v1-5] Use different default algorithms for different werkzeug versions (#46384) (#46392)`` + +Misc +~~~~ + +* ``[providers-fab/v1-5] Upgrade to FAB 4.5.3 (#45874) (#45918)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + +1.5.2 +..... + +Misc +~~~~ + +* ``Correctly import isabs from os.path (#45178)`` +* ``[providers-fab/v1-5] Invalidate user session on password reset (#45139)`` + + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + +1.5.1 +..... + +Bug Fixes +~~~~~~~~~ + +* ``fab_auth_manager: allow get_user method to return the user authenticated via Kerberos (#43662)`` + + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Expand and improve the kerberos api authentication documentation (#43682)`` + +1.5.0 +..... 
+ +Features +~~~~~~~~ + +* ``feat(providers/fab): Use asset in common provider (#43112)`` + +Bug Fixes +~~~~~~~~~ + +* ``fix revoke Dag stale permission on airflow < 2.10 (#42844)`` +* ``fix(providers/fab): alias is_authorized_dataset to is_authorized_asset (#43469)`` +* ``fix: Change CustomSecurityManager method name (#43034)`` + +Misc +~~~~ + +* ``Upgrade Flask-AppBuilder to 4.5.2 (#43309)`` +* ``Upgrade Flask-AppBuilder to 4.5.1 (#43251)`` +* ``Move user and roles schemas to fab provider (#42869)`` +* ``Move the session auth backend to FAB auth manager (#42878)`` +* ``Add logging to the migration commands (#43516)`` +* ``DOC fix documentation error in 'apache-airflow-providers-fab/access-control.rst' (#43495)`` +* ``Rename dataset as asset in UI (#43073)`` + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Split providers out of the main "airflow/" tree into a UV workspace project (#42505)`` + * ``Start porting DAG definition code to the Task SDK (#43076)`` + * ``Prepare docs for Oct 2nd wave of providers (#43409)`` + * ``Prepare docs for Oct 2nd wave of providers RC2 (#43540)`` + +1.4.1 +..... + +Misc +~~~~ + +* ``Update Rest API tests to no longer rely on FAB auth manager. Move tests specific to FAB permissions to FAB provider (#42523)`` +* ``Rename dataset related python variable names to asset (#41348)`` +* ``Simplify expression for get_permitted_dag_ids query (#42484)`` + + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + +1.4.0 +..... 
+ +Features +~~~~~~~~ + +* ``Add FAB migration commands (#41804)`` +* ``Separate FAB migration from Core Airflow migration (#41437)`` + +Misc +~~~~ + +* ``Deprecated kerberos auth removed (#41693)`` +* ``Deprecated configuration removed (#42129)`` +* ``Move 'is_active' user property to FAB auth manager (#42042)`` +* ``Move 'register_views' to auth manager interface (#41777)`` +* ``Revert "Provider fab auth manager deprecated methods removed (#41720)" (#41960)`` +* ``Provider fab auth manager deprecated methods removed (#41720)`` +* ``Make kerberos an optional and devel dependency for impala and fab (#41616)`` + + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + * ``Add TODOs in providers code for Subdag code removal (#41963)`` + * ``Add fixes by breeze/precommit-lint static checks (#41604) (#41618)`` + +.. Review and move the new changes to one of the sections above: + * ``Fix pre-commit for auto update of fab migration versions (#42382)`` + * ``Handle 'AUTH_ROLE_PUBLIC' in FAB auth manager (#42280)`` + +1.3.0 +..... + +Features +~~~~~~~~ + +* ``Feature: Allow set Dag Run resource into Dag Level permission (#40703)`` + +Misc +~~~~ + +* ``Remove deprecated SubDags (#41390)`` + + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + 1.2.2 ..... 
diff --git a/airflow/providers/fab/__init__.py b/airflow/providers/fab/__init__.py index c59168fb92ad2..60fe1b0f9e98e 100644 --- a/airflow/providers/fab/__init__.py +++ b/airflow/providers/fab/__init__.py @@ -29,7 +29,7 @@ __all__ = ["__version__"] -__version__ = "1.2.2" +__version__ = "1.5.4" if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse( "2.9.0" diff --git a/airflow/providers/fab/alembic.ini b/airflow/providers/fab/alembic.ini new file mode 100644 index 0000000000000..75d42ee16d3b9 --- /dev/null +++ b/airflow/providers/fab/alembic.ini @@ -0,0 +1,133 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# A generic, single database configuration. 
+ +[alembic] +# path to migration scripts +# Use forward slashes (/) also on windows to provide an os agnostic path +script_location = %(here)s/migrations + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python>=3.9 or backports.zoneinfo library. +# Any required deps can installed by adding `alembic[tz]` to the pip requirements +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to alembic/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. 
+# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = scheme://localhost/airflow + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the exec runner, execute a binary +# hooks = ruff +# ruff.type = exec +# ruff.executable = %(here)s/.venv/bin/ruff +# ruff.options = --fix REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/airflow/providers/fab/auth_manager/api/auth/backend/basic_auth.py b/airflow/providers/fab/auth_manager/api/auth/backend/basic_auth.py index 
ff7c2cc3b3742..1ba532cd5198d 100644 --- a/airflow/providers/fab/auth_manager/api/auth/backend/basic_auth.py +++ b/airflow/providers/fab/auth_manager/api/auth/backend/basic_auth.py @@ -21,7 +21,7 @@ from functools import wraps from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast -from flask import Response, request +from flask import Response, current_app, request from flask_appbuilder.const import AUTH_LDAP from flask_login import login_user @@ -62,9 +62,23 @@ def requires_authentication(function: T): @wraps(function) def decorated(*args, **kwargs): - if auth_current_user() is not None: + # Try to authenticate the user + user = auth_current_user() + if user is not None: return function(*args, **kwargs) - else: + + # Authentication failed - check if Authorization header was provided + auth_header = request.headers.get("Authorization") + if auth_header: + # Authorization header was present but authentication failed + # This includes malformed headers that Flask couldn't parse return Response("Unauthorized", 401, {"WWW-Authenticate": "Basic"}) + # No Authorization header - check if public access is allowed + if current_app.config.get("AUTH_ROLE_PUBLIC", None): + return function(*args, **kwargs) + + # No auth header and no public access + return Response("Unauthorized", 401, {"WWW-Authenticate": "Basic"}) + return cast(T, decorated) diff --git a/airflow/providers/fab/auth_manager/api/auth/backend/kerberos_auth.py b/airflow/providers/fab/auth_manager/api/auth/backend/kerberos_auth.py index ac50aed5f02dc..f2038b27597c1 100644 --- a/airflow/providers/fab/auth_manager/api/auth/backend/kerberos_auth.py +++ b/airflow/providers/fab/auth_manager/api/auth/backend/kerberos_auth.py @@ -18,27 +18,129 @@ from __future__ import annotations import logging -from functools import partial -from typing import Any, cast +import os +from functools import wraps +from typing import TYPE_CHECKING, Any, Callable, NamedTuple, TypeVar, cast +import kerberos +from flask import Response, 
current_app, g, make_response, request from requests_kerberos import HTTPKerberosAuth -from airflow.api.auth.backend.kerberos_auth import ( - init_app as base_init_app, - requires_authentication as base_requires_authentication, -) +from airflow.configuration import conf from airflow.providers.fab.auth_manager.security_manager.override import FabAirflowSecurityManagerOverride +from airflow.utils.net import getfqdn from airflow.www.extensions.init_auth_manager import get_auth_manager +if TYPE_CHECKING: + from airflow.auth.managers.models.base_user import BaseUser + log = logging.getLogger(__name__) CLIENT_AUTH: tuple[str, str] | Any | None = HTTPKerberosAuth(service="airflow") +class KerberosService: + """Class to keep information about the Kerberos Service initialized.""" + + def __init__(self): + self.service_name = None + + +class _KerberosAuth(NamedTuple): + return_code: int | None + user: str = "" + token: str | None = None + + +# Stores currently initialized Kerberos Service +_KERBEROS_SERVICE = KerberosService() + + +def init_app(app): + """Initialize application with kerberos.""" + hostname = app.config.get("SERVER_NAME") + if not hostname: + hostname = getfqdn() + log.info("Kerberos: hostname %s", hostname) + + service = "airflow" + + _KERBEROS_SERVICE.service_name = f"{service}@{hostname}" + + if "KRB5_KTNAME" not in os.environ: + os.environ["KRB5_KTNAME"] = conf.get("kerberos", "keytab") + + try: + log.info("Kerberos init: %s %s", service, hostname) + principal = kerberos.getServerPrincipalDetails(service, hostname) + except kerberos.KrbError as err: + log.warning("Kerberos: %s", err) + else: + log.info("Kerberos API: server is %s", principal) + + +def _unauthorized(): + """Indicate that authorization is required.""" + return Response("Unauthorized", 401, {"WWW-Authenticate": "Negotiate"}) + + +def _forbidden(): + return Response("Forbidden", 403) + + +def _gssapi_authenticate(token) -> _KerberosAuth | None: + state = None + try: + return_code, state = 
kerberos.authGSSServerInit(_KERBEROS_SERVICE.service_name) + if return_code != kerberos.AUTH_GSS_COMPLETE: + return _KerberosAuth(return_code=None) + + if (return_code := kerberos.authGSSServerStep(state, token)) == kerberos.AUTH_GSS_COMPLETE: + return _KerberosAuth( + return_code=return_code, + user=kerberos.authGSSServerUserName(state), + token=kerberos.authGSSServerResponse(state), + ) + elif return_code == kerberos.AUTH_GSS_CONTINUE: + return _KerberosAuth(return_code=return_code) + return _KerberosAuth(return_code=return_code) + except kerberos.GSSError: + return _KerberosAuth(return_code=None) + finally: + if state: + kerberos.authGSSServerClean(state) + + +T = TypeVar("T", bound=Callable) + + def find_user(username=None, email=None): security_manager = cast(FabAirflowSecurityManagerOverride, get_auth_manager().security_manager) return security_manager.find_user(username=username, email=email) -init_app = base_init_app -requires_authentication = partial(base_requires_authentication, find_user=find_user) +def requires_authentication(function: T, find_user: Callable[[str], BaseUser] | None = find_user): + """Decorate functions that require authentication with Kerberos.""" + + @wraps(function) + def decorated(*args, **kwargs): + if current_app.config.get("AUTH_ROLE_PUBLIC", None): + response = function(*args, **kwargs) + return make_response(response) + + header = request.headers.get("Authorization") + if header: + token = "".join(header.split()[1:]) + auth = _gssapi_authenticate(token) + if auth.return_code == kerberos.AUTH_GSS_COMPLETE: + g.user = find_user(auth.user) + response = function(*args, **kwargs) + response = make_response(response) + if auth.token is not None: + response.headers["WWW-Authenticate"] = f"negotiate {auth.token}" + return response + elif auth.return_code != kerberos.AUTH_GSS_CONTINUE: + return _forbidden() + return _unauthorized() + + return cast(T, decorated) diff --git a/airflow/providers/fab/auth_manager/api/auth/backend/session.py 
b/airflow/providers/fab/auth_manager/api/auth/backend/session.py new file mode 100644 index 0000000000000..d51f7bf1cf4c9 --- /dev/null +++ b/airflow/providers/fab/auth_manager/api/auth/backend/session.py @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Session authentication backend.""" + +from __future__ import annotations + +from functools import wraps +from typing import Any, Callable, TypeVar, cast + +from flask import Response + +from airflow.www.extensions.init_auth_manager import get_auth_manager + +CLIENT_AUTH: tuple[str, str] | Any | None = None + + +def init_app(_): + """Initialize authentication backend.""" + + +T = TypeVar("T", bound=Callable) + + +def requires_authentication(function: T): + """Decorate functions that require authentication.""" + + @wraps(function) + def decorated(*args, **kwargs): + if not get_auth_manager().is_logged_in(): + return Response("Unauthorized", 401, {}) + return function(*args, **kwargs) + + return cast(T, decorated) diff --git a/airflow/providers/fab/auth_manager/api_endpoints/role_and_permission_endpoint.py b/airflow/providers/fab/auth_manager/api_endpoints/role_and_permission_endpoint.py index ed42f91163982..121a88be28587 100644 --- a/airflow/providers/fab/auth_manager/api_endpoints/role_and_permission_endpoint.py +++ b/airflow/providers/fab/auth_manager/api_endpoints/role_and_permission_endpoint.py @@ -26,15 +26,15 @@ from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound from airflow.api_connexion.parameters import check_limit, format_parameters -from airflow.api_connexion.schemas.role_and_permission_schema import ( +from airflow.api_connexion.security import requires_access_custom_view +from airflow.providers.fab.auth_manager.models import Action, Role +from airflow.providers.fab.auth_manager.schemas.role_and_permission_schema import ( ActionCollection, RoleCollection, action_collection_schema, role_collection_schema, role_schema, ) -from airflow.api_connexion.security import requires_access_custom_view -from airflow.providers.fab.auth_manager.models import Action, Role from airflow.providers.fab.auth_manager.security_manager.override import FabAirflowSecurityManagerOverride from airflow.security import permissions from 
airflow.www.extensions.init_auth_manager import get_auth_manager diff --git a/airflow/providers/fab/auth_manager/api_endpoints/user_endpoint.py b/airflow/providers/fab/auth_manager/api_endpoints/user_endpoint.py index 665b7f52d896f..43464a23d365e 100644 --- a/airflow/providers/fab/auth_manager/api_endpoints/user_endpoint.py +++ b/airflow/providers/fab/auth_manager/api_endpoints/user_endpoint.py @@ -27,14 +27,14 @@ from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound, Unknown from airflow.api_connexion.parameters import check_limit, format_parameters -from airflow.api_connexion.schemas.user_schema import ( +from airflow.api_connexion.security import requires_access_custom_view +from airflow.providers.fab.auth_manager.models import User +from airflow.providers.fab.auth_manager.schemas.user_schema import ( UserCollection, user_collection_item_schema, user_collection_schema, user_schema, ) -from airflow.api_connexion.security import requires_access_custom_view -from airflow.providers.fab.auth_manager.models import User from airflow.providers.fab.auth_manager.security_manager.override import FabAirflowSecurityManagerOverride from airflow.security import permissions from airflow.www.extensions.init_auth_manager import get_auth_manager diff --git a/airflow/providers/fab/auth_manager/cli_commands/db_command.py b/airflow/providers/fab/auth_manager/cli_commands/db_command.py new file mode 100644 index 0000000000000..8b41cf4216c87 --- /dev/null +++ b/airflow/providers/fab/auth_manager/cli_commands/db_command.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow import settings +from airflow.cli.commands.db_command import run_db_downgrade_command, run_db_migrate_command +from airflow.providers.fab.auth_manager.models.db import _REVISION_HEADS_MAP, FABDBManager +from airflow.utils import cli as cli_utils +from airflow.utils.providers_configuration_loader import providers_configuration_loaded + + +@providers_configuration_loaded +def resetdb(args): + """Reset the metadata database.""" + print(f"DB: {settings.engine.url!r}") + if not (args.yes or input("This will drop existing tables if they exist. Proceed? 
(y/n)").upper() == "Y"): + raise SystemExit("Cancelled") + FABDBManager(settings.Session()).resetdb(skip_init=args.skip_init) + + +@cli_utils.action_cli(check_db=False) +@providers_configuration_loaded +def migratedb(args): + """Migrates the metadata database.""" + session = settings.Session() + upgrade_command = FABDBManager(session).upgradedb + run_db_migrate_command( + args, upgrade_command, revision_heads_map=_REVISION_HEADS_MAP, reserialize_dags=False + ) + + +@cli_utils.action_cli(check_db=False) +@providers_configuration_loaded +def downgrade(args): + """Downgrades the metadata database.""" + session = settings.Session() + downgrade_command = FABDBManager(session).downgrade + run_db_downgrade_command(args, downgrade_command, revision_heads_map=_REVISION_HEADS_MAP) diff --git a/airflow/providers/fab/auth_manager/cli_commands/definition.py b/airflow/providers/fab/auth_manager/cli_commands/definition.py index c7be5270d58fe..7f8f1e84e2798 100644 --- a/airflow/providers/fab/auth_manager/cli_commands/definition.py +++ b/airflow/providers/fab/auth_manager/cli_commands/definition.py @@ -19,8 +19,17 @@ import textwrap from airflow.cli.cli_config import ( + ARG_DB_FROM_REVISION, + ARG_DB_FROM_VERSION, + ARG_DB_REVISION__DOWNGRADE, + ARG_DB_REVISION__UPGRADE, + ARG_DB_SKIP_INIT, + ARG_DB_SQL_ONLY, + ARG_DB_VERSION__DOWNGRADE, + ARG_DB_VERSION__UPGRADE, ARG_OUTPUT, ARG_VERBOSE, + ARG_YES, ActionCommand, Arg, lazy_load_command, @@ -243,3 +252,55 @@ func=lazy_load_command("airflow.providers.fab.auth_manager.cli_commands.sync_perm_command.sync_perm"), args=(ARG_INCLUDE_DAGS, ARG_VERBOSE), ) + +DB_COMMANDS = ( + ActionCommand( + name="migrate", + help="Migrates the FAB metadata database to the latest version", + description=( + "Migrate the schema of the FAB metadata database. " + "Create the database if it does not exist. " + "To print but not execute commands, use option ``--show-sql-only``. 
" + "If using options ``--from-revision`` or ``--from-version``, you must also use " + "``--show-sql-only``, because if actually *running* migrations, we should only " + "migrate from the *current* Alembic revision." + ), + func=lazy_load_command("airflow.providers.fab.auth_manager.cli_commands.db_command.migratedb"), + args=( + ARG_DB_REVISION__UPGRADE, + ARG_DB_VERSION__UPGRADE, + ARG_DB_SQL_ONLY, + ARG_DB_FROM_REVISION, + ARG_DB_FROM_VERSION, + ARG_VERBOSE, + ), + ), + ActionCommand( + name="downgrade", + help="Downgrade the schema of the FAB metadata database.", + description=( + "Downgrade the schema of the FAB metadata database. " + "You must provide either `--to-revision` or `--to-version`. " + "To print but not execute commands, use option `--show-sql-only`. " + "If using options `--from-revision` or `--from-version`, you must also use `--show-sql-only`, " + "because if actually *running* migrations, we should only migrate from the *current* Alembic " + "revision." + ), + func=lazy_load_command("airflow.providers.fab.auth_manager.cli_commands.db_command.downgrade"), + args=( + ARG_DB_REVISION__DOWNGRADE, + ARG_DB_VERSION__DOWNGRADE, + ARG_DB_SQL_ONLY, + ARG_YES, + ARG_DB_FROM_REVISION, + ARG_DB_FROM_VERSION, + ARG_VERBOSE, + ), + ), + ActionCommand( + name="reset", + help="Burn down and rebuild the FAB metadata database", + func=lazy_load_command("airflow.providers.fab.auth_manager.cli_commands.db_command.resetdb"), + args=(ARG_YES, ARG_DB_SKIP_INIT, ARG_VERBOSE), + ), +) diff --git a/airflow/providers/fab/auth_manager/cli_commands/user_command.py b/airflow/providers/fab/auth_manager/cli_commands/user_command.py index e877397bcf2e0..5853dcf1a63de 100644 --- a/airflow/providers/fab/auth_manager/cli_commands/user_command.py +++ b/airflow/providers/fab/auth_manager/cli_commands/user_command.py @@ -212,10 +212,12 @@ def users_import(args): users_created, users_updated = _import_users(users_list) if users_created: - print("Created the following 
users:\n\t{'\\n\\t'.join(users_created)}") + users_created_str = "\n\t".join(users_created) + print(f"Created the following users:\n\t{users_created_str}") if users_updated: - print("Updated the following users:\n\t{'\\n\\t.join(users_updated)}") + users_updated_str = "\n\t".join(users_updated) + print(f"Updated the following users:\n\t{users_updated_str}") def _import_users(users_list: list[dict[str, Any]]): @@ -231,7 +233,8 @@ def _import_users(users_list: list[dict[str, Any]]): msg.append(f"[Item {row_num}]") for key, value in failure.items(): msg.append(f"\t{key}: {value}") - raise SystemExit("Error: Input file didn't pass validation. See below:\n{'\\n'.join(msg)}") + msg_str = "\n".join(msg) + raise SystemExit(f"Error: Input file didn't pass validation. See below:\n{msg_str}") for user in users_list: roles = [] diff --git a/airflow/providers/fab/auth_manager/cli_commands/utils.py b/airflow/providers/fab/auth_manager/cli_commands/utils.py index 78403e24079f1..e848c2094ce5b 100644 --- a/airflow/providers/fab/auth_manager/cli_commands/utils.py +++ b/airflow/providers/fab/auth_manager/cli_commands/utils.py @@ -20,13 +20,17 @@ import os from contextlib import contextmanager from functools import lru_cache +from os.path import isabs from typing import TYPE_CHECKING, Generator from flask import Flask import airflow from airflow.configuration import conf +from airflow.exceptions import AirflowConfigException +from airflow.www.app import make_url from airflow.www.extensions.init_appbuilder import init_appbuilder +from airflow.www.extensions.init_session import init_airflow_session_interface from airflow.www.extensions.init_views import init_plugins if TYPE_CHECKING: @@ -38,6 +42,7 @@ def _return_appbuilder(app: Flask) -> AirflowAppBuilder: """Return an appbuilder instance for the given app.""" init_appbuilder(app) init_plugins(app) + init_airflow_session_interface(app) return app.appbuilder # type: ignore[attr-defined] @@ -49,4 +54,12 @@ def get_application_builder() 
-> Generator[AirflowAppBuilder, None, None]: with flask_app.app_context(): # Enable customizations in webserver_config.py to be applied via Flask.current_app. flask_app.config.from_pyfile(webserver_config, silent=True) + flask_app.config["SQLALCHEMY_DATABASE_URI"] = conf.get("database", "SQL_ALCHEMY_CONN") + url = make_url(flask_app.config["SQLALCHEMY_DATABASE_URI"]) + if url.drivername == "sqlite" and url.database and not isabs(url.database): + raise AirflowConfigException( + f'Cannot use relative path: `{conf.get("database", "SQL_ALCHEMY_CONN")}` to connect to sqlite. ' + "Please use absolute path such as `sqlite:////tmp/airflow.db`." + ) + flask_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False yield _return_appbuilder(flask_app) diff --git a/airflow/providers/fab/auth_manager/fab_auth_manager.py b/airflow/providers/fab/auth_manager/fab_auth_manager.py index 344df7588de7d..e93e440f5ddfe 100644 --- a/airflow/providers/fab/auth_manager/fab_auth_manager.py +++ b/airflow/providers/fab/auth_manager/fab_auth_manager.py @@ -18,15 +18,19 @@ from __future__ import annotations import argparse +import warnings from functools import cached_property from pathlib import Path from typing import TYPE_CHECKING, Container +import packaging.version from connexion import FlaskApi -from flask import Blueprint, url_for +from flask import Blueprint, g, url_for +from packaging.version import Version from sqlalchemy import select from sqlalchemy.orm import Session, joinedload +from airflow import __version__ as airflow_version from airflow.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod from airflow.auth.managers.models.resource_details import ( AccessView, @@ -34,7 +38,6 @@ ConnectionDetails, DagAccessEntity, DagDetails, - DatasetDetails, PoolDetails, VariableDetails, ) @@ -44,9 +47,10 @@ GroupCommand, ) from airflow.configuration import conf -from airflow.exceptions import AirflowConfigException, AirflowException +from airflow.exceptions import 
AirflowConfigException, AirflowException, AirflowProviderDeprecationWarning from airflow.models import DagModel from airflow.providers.fab.auth_manager.cli_commands.definition import ( + DB_COMMANDS, ROLES_COMMANDS, SYNC_PERM_COMMAND, USERS_COMMANDS, @@ -63,7 +67,6 @@ RESOURCE_DAG_DEPENDENCIES, RESOURCE_DAG_RUN, RESOURCE_DAG_WARNING, - RESOURCE_DATASET, RESOURCE_DOCS, RESOURCE_IMPORT_ERROR, RESOURCE_JOB, @@ -81,6 +84,7 @@ ) from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.yaml import safe_load +from airflow.version import version from airflow.www.constants import SWAGGER_BUNDLE, SWAGGER_ENABLED from airflow.www.extensions.init_views import _CustomErrorRequestBodyValidator, _LazyResolver @@ -89,7 +93,12 @@ from airflow.cli.cli_config import ( CLICommand, ) + from airflow.providers.common.compat.assets import AssetDetails from airflow.providers.fab.auth_manager.security_manager.override import FabAirflowSecurityManagerOverride + from airflow.security.permissions import RESOURCE_ASSET +else: + from airflow.providers.common.compat.security.permissions import RESOURCE_ASSET + _MAP_DAG_ACCESS_ENTITY_TO_FAB_RESOURCE_TYPE: dict[DagAccessEntity, tuple[str, ...]] = { DagAccessEntity.AUDIT_LOG: (RESOURCE_AUDIT_LOG,), @@ -132,7 +141,7 @@ class FabAuthManager(BaseAuthManager): @staticmethod def get_cli_commands() -> list[CLICommand]: """Vends CLI commands to be included in Airflow CLI.""" - return [ + commands: list[CLICommand] = [ GroupCommand( name="users", help="Manage users", @@ -145,6 +154,12 @@ def get_cli_commands() -> list[CLICommand]: ), SYNC_PERM_COMMAND, # not in a command group ] + # If Airflow version is 3.0.0 or higher, add the fab-db command group + if packaging.version.parse( + packaging.version.parse(airflow_version).base_version + ) >= packaging.version.parse("3.0.0"): + commands.append(GroupCommand(name="fab-db", help="Manage FAB", subcommands=DB_COMMANDS)) + return commands def get_api_endpoints(self) -> None | Blueprint: 
folder = Path(__file__).parents[0].resolve() # this is airflow/auth/managers/fab/ @@ -168,9 +183,20 @@ def get_user_display_name(self) -> str: return f"{first_name} {last_name}".strip() def get_user(self) -> User: - """Return the user associated to the user in session.""" + """ + Return the user associated to the user in session. + + Attempt to find the current user in g.user, as defined by the kerberos authentication backend. + If no such user is found, return the `current_user` local proxy object, linked to the user session. + + """ from flask_login import current_user + # If a user has gone through the Kerberos dance, the kerberos authentication manager + # has linked it with a User model, stored in g.user, and not the session. + if current_user.is_anonymous and getattr(g, "user", None) is not None and not g.user.is_anonymous: + return g.user + return current_user def init(self) -> None: @@ -179,7 +205,13 @@ def init(self) -> None: def is_logged_in(self) -> bool: """Return whether the user is logged in.""" - return not self.get_user().is_anonymous + user = self.get_user() + if Version(Version(version).base_version) < Version("3.0.0"): + return not user.is_anonymous and user.is_active + else: + return self.appbuilder.get_app.config.get("AUTH_ROLE_PUBLIC", None) or ( + not user.is_anonymous and user.is_active + ) def is_authorized_configuration( self, @@ -246,10 +278,20 @@ def is_authorized_dag( for resource_type in resource_types ) + def is_authorized_asset( + self, *, method: ResourceMethod, details: AssetDetails | None = None, user: BaseUser | None = None + ) -> bool: + return self._is_authorized(method=method, resource_type=RESOURCE_ASSET, user=user) + def is_authorized_dataset( - self, *, method: ResourceMethod, details: DatasetDetails | None = None, user: BaseUser | None = None + self, *, method: ResourceMethod, details: AssetDetails | None = None, user: BaseUser | None = None ) -> bool: - return self._is_authorized(method=method, 
resource_type=RESOURCE_DATASET, user=user) + warnings.warn( + "is_authorized_dataset will be renamed as is_authorized_asset in Airflow 3 and will be removed when the minimum Airflow version is set to 3.0 for the fab provider", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + return self.is_authorized_asset(method=method, user=user) def is_authorized_pool( self, *, method: ResourceMethod, details: PoolDetails | None = None, user: BaseUser | None = None @@ -325,10 +367,7 @@ def get_permitted_dag_ids( resources.add(resource[len(permissions.RESOURCE_DAG_PREFIX) :]) else: resources.add(resource) - return { - dag.dag_id - for dag in session.execute(select(DagModel.dag_id).where(DagModel.dag_id.in_(resources))) - } + return set(session.scalars(select(DagModel.dag_id).where(DagModel.dag_id.in_(resources)))) @cached_property def security_manager(self) -> FabAirflowSecurityManagerOverride: @@ -364,10 +403,15 @@ def get_url_logout(self): def get_url_user_profile(self) -> str | None: """Return the url to a page displaying info about the current user.""" - if not self.security_manager.user_view: + if not self.security_manager.user_view or self.appbuilder.get_app.config.get( + "AUTH_ROLE_PUBLIC", None + ): return None return url_for(f"{self.security_manager.user_view.endpoint}.userinfo") + def register_views(self) -> None: + self.security_manager.register_views() + def _is_authorized( self, *, @@ -503,7 +547,7 @@ def _get_root_dag_id(self, dag_id: str) -> str: :meta private: """ - if "." in dag_id: + if "." in dag_id and hasattr(DagModel, "root_dag_id"): return self.appbuilder.get_session.scalar( select(DagModel.dag_id, DagModel.root_dag_id).where(DagModel.dag_id == dag_id).limit(1) ) @@ -519,9 +563,11 @@ def _sync_appbuilder_roles(self): # Otherwise, when the name of a view or menu is changed, the framework # will add the new Views and Menus names to the backend, but will not # delete the old ones. 
- if conf.getboolean( - "fab", "UPDATE_FAB_PERMS", fallback=conf.getboolean("webserver", "UPDATE_FAB_PERMS") - ): + if Version(Version(version).base_version) >= Version("3.0.0"): + fallback = None + else: + fallback = conf.getboolean("webserver", "UPDATE_FAB_PERMS") + if conf.getboolean("fab", "UPDATE_FAB_PERMS", fallback=fallback): self.security_manager.sync_roles() diff --git a/airflow/providers/fab/auth_manager/models/__init__.py b/airflow/providers/fab/auth_manager/models/__init__.py index bf4e43f275fab..2587d7034d04c 100644 --- a/airflow/providers/fab/auth_manager/models/__init__.py +++ b/airflow/providers/fab/auth_manager/models/__init__.py @@ -23,6 +23,7 @@ # Copyright 2013, Daniel Vaz Gaspar from typing import TYPE_CHECKING +import packaging.version from flask import current_app, g from flask_appbuilder.models.sqla import Model from sqlalchemy import ( @@ -32,6 +33,7 @@ ForeignKey, Index, Integer, + MetaData, String, Table, UniqueConstraint, @@ -39,16 +41,11 @@ func, select, ) -from sqlalchemy.orm import backref, declared_attr, relationship +from sqlalchemy.orm import backref, declared_attr, registry, relationship +from airflow import __version__ as airflow_version from airflow.auth.managers.models.base_user import BaseUser -from airflow.models.base import Base - -""" -Compatibility note: The models in this file are duplicated from Flask AppBuilder. -""" -# Use airflow metadata to create the tables -Model.metadata = Base.metadata +from airflow.models.base import _get_schema, naming_convention if TYPE_CHECKING: try: @@ -56,6 +53,22 @@ except Exception: Identity = None +""" +Compatibility note: The models in this file are duplicated from Flask AppBuilder. 
+""" + +metadata = MetaData(schema=_get_schema(), naming_convention=naming_convention) +mapper_registry = registry(metadata=metadata) + +if packaging.version.parse(packaging.version.parse(airflow_version).base_version) >= packaging.version.parse( + "3.0.0" +): + Model.metadata = metadata +else: + from airflow.models.base import Base + + Model.metadata = Base.metadata + class Action(Model): """Represents permission actions such as `can_read`.""" diff --git a/airflow/providers/fab/auth_manager/models/anonymous_user.py b/airflow/providers/fab/auth_manager/models/anonymous_user.py index ba75de0d3c6e3..9afb2cdff635f 100644 --- a/airflow/providers/fab/auth_manager/models/anonymous_user.py +++ b/airflow/providers/fab/auth_manager/models/anonymous_user.py @@ -29,10 +29,13 @@ class AnonymousUser(AnonymousUserMixin, BaseUser): _roles: set[tuple[str, str]] = set() _perms: set[tuple[str, str]] = set() + first_name = "Anonymous" + last_name = "" + @property def roles(self): if not self._roles: - public_role = current_app.appbuilder.get_app.config["AUTH_ROLE_PUBLIC"] + public_role = current_app.config.get("AUTH_ROLE_PUBLIC", None) self._roles = {current_app.appbuilder.sm.find_role(public_role)} if public_role else set() return self._roles @@ -48,3 +51,6 @@ def perms(self): (perm.action.name, perm.resource.name) for role in self.roles for perm in role.permissions } return self._perms + + def get_name(self) -> str: + return "Anonymous" diff --git a/airflow/providers/fab/auth_manager/models/db.py b/airflow/providers/fab/auth_manager/models/db.py new file mode 100644 index 0000000000000..ce0efef55a1cd --- /dev/null +++ b/airflow/providers/fab/auth_manager/models/db.py @@ -0,0 +1,106 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from pathlib import Path + +from airflow import settings +from airflow.exceptions import AirflowException +from airflow.providers.fab.auth_manager.models import metadata +from airflow.utils.db import _offline_migration, print_happy_cat +from airflow.utils.db_manager import BaseDBManager + +PACKAGE_DIR = Path(__file__).parents[2] + +_REVISION_HEADS_MAP: dict[str, str] = { + "1.4.0": "6709f7a774b9", +} + + +class FABDBManager(BaseDBManager): + """Manages FAB database.""" + + metadata = metadata + version_table_name = "alembic_version_fab" + migration_dir = (PACKAGE_DIR / "migrations").as_posix() + alembic_file = (PACKAGE_DIR / "alembic.ini").as_posix() + supports_table_dropping = True + + def upgradedb(self, to_revision=None, from_revision=None, show_sql_only=False): + """Upgrade the database.""" + if from_revision and not show_sql_only: + raise AirflowException("`from_revision` only supported with `sql_only=True`.") + + # alembic adds significant import time, so we import it lazily + if not settings.SQL_ALCHEMY_CONN: + raise RuntimeError("The settings.SQL_ALCHEMY_CONN not set. 
This is a critical assertion.") + from alembic import command + + config = self.get_alembic_config() + + if show_sql_only: + if settings.engine.dialect.name == "sqlite": + raise SystemExit("Offline migration not supported for SQLite.") + if not from_revision: + from_revision = self.get_current_revision() + + if not to_revision: + script = self.get_script_object(config) + to_revision = script.get_current_head() + + if to_revision == from_revision: + print_happy_cat("No migrations to apply; nothing to do.") + return + _offline_migration(command.upgrade, config, f"{from_revision}:{to_revision}") + return # only running sql; our job is done + + if not self.get_current_revision(): + # New DB; initialize and exit + self.initdb() + return + + command.upgrade(config, revision=to_revision or "heads") + + def downgrade(self, to_revision, from_revision=None, show_sql_only=False): + if from_revision and not show_sql_only: + raise ValueError( + "`from_revision` can't be combined with `show_sql_only=False`. When actually " + "applying a downgrade (instead of just generating sql), we always " + "downgrade from current revision." 
+ ) + + if not settings.SQL_ALCHEMY_CONN: + raise RuntimeError("The settings.SQL_ALCHEMY_CONN not set.") + + # alembic adds significant import time, so we import it lazily + from alembic import command + + self.log.info("Attempting downgrade of FAB migration to revision %s", to_revision) + config = self.get_alembic_config() + + if show_sql_only: + self.log.warning("Generating sql scripts for manual migration.") + if not from_revision: + from_revision = self.get_current_revision() + if from_revision is None: + self.log.info("No revision found") + return + revision_range = f"{from_revision}:{to_revision}" + _offline_migration(command.downgrade, config=config, revision=revision_range) + else: + self.log.info("Applying FAB downgrade migrations.") + command.downgrade(config, revision=to_revision, sql=show_sql_only) diff --git a/airflow/providers/fab/auth_manager/schemas/__init__.py b/airflow/providers/fab/auth_manager/schemas/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/providers/fab/auth_manager/schemas/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/airflow/providers/fab/auth_manager/schemas/role_and_permission_schema.py b/airflow/providers/fab/auth_manager/schemas/role_and_permission_schema.py new file mode 100644 index 0000000000000..756d8de6f5914 --- /dev/null +++ b/airflow/providers/fab/auth_manager/schemas/role_and_permission_schema.py @@ -0,0 +1,103 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from typing import NamedTuple + +from marshmallow import Schema, fields +from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field + +from airflow.providers.fab.auth_manager.models import Action, Permission, Resource, Role + + +class ActionSchema(SQLAlchemySchema): + """Action Schema.""" + + class Meta: + """Meta.""" + + model = Action + + name = auto_field() + + +class ResourceSchema(SQLAlchemySchema): + """View menu Schema.""" + + class Meta: + """Meta.""" + + model = Resource + + name = auto_field() + + +class ActionCollection(NamedTuple): + """Action Collection.""" + + actions: list[Action] + total_entries: int + + +class ActionCollectionSchema(Schema): + """Permissions list schema.""" + + actions = fields.List(fields.Nested(ActionSchema)) + total_entries = fields.Int() + + +class ActionResourceSchema(SQLAlchemySchema): + """Action View Schema.""" + + class Meta: + """Meta.""" + + model = Permission + + action = fields.Nested(ActionSchema, data_key="action") + resource = fields.Nested(ResourceSchema, data_key="resource") + + +class RoleSchema(SQLAlchemySchema): + """Role item schema.""" + + class Meta: + """Meta.""" + + model = Role + + name = auto_field() + permissions = fields.List(fields.Nested(ActionResourceSchema), data_key="actions") + + +class RoleCollection(NamedTuple): + """List of roles.""" + + roles: list[Role] + total_entries: int + + +class RoleCollectionSchema(Schema): + """List of roles.""" + + roles = fields.List(fields.Nested(RoleSchema)) + total_entries = fields.Int() + + +role_schema = RoleSchema() +role_collection_schema = RoleCollectionSchema() +action_collection_schema = ActionCollectionSchema() diff --git a/airflow/providers/fab/auth_manager/schemas/user_schema.py b/airflow/providers/fab/auth_manager/schemas/user_schema.py new file mode 100644 index 0000000000000..4155667d56766 --- /dev/null +++ b/airflow/providers/fab/auth_manager/schemas/user_schema.py @@ -0,0 +1,73 @@ +# Licensed to the 
Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import NamedTuple + +from marshmallow import Schema, fields +from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field + +from airflow.api_connexion.parameters import validate_istimezone +from airflow.providers.fab.auth_manager.models import User +from airflow.providers.fab.auth_manager.schemas.role_and_permission_schema import RoleSchema + + +class UserCollectionItemSchema(SQLAlchemySchema): + """user collection item schema.""" + + class Meta: + """Meta.""" + + model = User + dateformat = "iso" + + first_name = auto_field() + last_name = auto_field() + username = auto_field() + active = auto_field(dump_only=True) + email = auto_field() + last_login = auto_field(dump_only=True) + login_count = auto_field(dump_only=True) + fail_login_count = auto_field(dump_only=True) + roles = fields.List(fields.Nested(RoleSchema, only=("name",))) + created_on = auto_field(validate=validate_istimezone, dump_only=True) + changed_on = auto_field(validate=validate_istimezone, dump_only=True) + + +class UserSchema(UserCollectionItemSchema): + """User schema.""" + + password = auto_field(load_only=True) + + +class UserCollection(NamedTuple): + """User collection.""" + + 
users: list[User] + total_entries: int + + +class UserCollectionSchema(Schema): + """User collection schema.""" + + users = fields.List(fields.Nested(UserCollectionItemSchema)) + total_entries = fields.Int() + + +user_collection_item_schema = UserCollectionItemSchema() +user_schema = UserSchema() +user_collection_schema = UserCollectionSchema() diff --git a/airflow/providers/fab/auth_manager/security_manager/override.py b/airflow/providers/fab/auth_manager/security_manager/override.py index e2208e5fb409f..73b04e0a3ab5e 100644 --- a/airflow/providers/fab/auth_manager/security_manager/override.py +++ b/airflow/providers/fab/auth_manager/security_manager/override.py @@ -17,14 +17,16 @@ # under the License. from __future__ import annotations +import copy import datetime +import importlib import itertools import logging import os import random import uuid import warnings -from typing import TYPE_CHECKING, Any, Callable, Collection, Container, Iterable, Sequence +from typing import TYPE_CHECKING, Any, Callable, Collection, Container, Iterable, Mapping, Sequence import jwt import packaging.version @@ -67,6 +69,7 @@ from flask_login import LoginManager from itsdangerous import want_bytes from markupsafe import Markup +from packaging.version import Version from sqlalchemy import and_, func, inspect, literal, or_, select from sqlalchemy.exc import MultipleResultsFound from sqlalchemy.orm import Session, joinedload @@ -115,6 +118,9 @@ if TYPE_CHECKING: from airflow.auth.managers.base_auth_manager import ResourceMethod + from airflow.security.permissions import RESOURCE_ASSET # type: ignore[attr-defined] +else: + from airflow.providers.common.compat.security.permissions import RESOURCE_ASSET log = logging.getLogger(__name__) @@ -234,7 +240,7 @@ class FabAirflowSecurityManagerOverride(AirflowSecurityManagerV2): (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_DEPENDENCIES), (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE), (permissions.ACTION_CAN_READ, 
permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET), + (permissions.ACTION_CAN_READ, RESOURCE_ASSET), (permissions.ACTION_CAN_READ, permissions.RESOURCE_CLUSTER_ACTIVITY), (permissions.ACTION_CAN_READ, permissions.RESOURCE_POOL), (permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR), @@ -253,7 +259,7 @@ class FabAirflowSecurityManagerOverride(AirflowSecurityManagerV2): (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DAG), (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DAG_DEPENDENCIES), (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DATASET), + (permissions.ACTION_CAN_ACCESS_MENU, RESOURCE_ASSET), (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_CLUSTER_ACTIVITY), (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DOCS), (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DOCS_MENU), @@ -273,7 +279,7 @@ class FabAirflowSecurityManagerOverride(AirflowSecurityManagerV2): (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DATASET), + (permissions.ACTION_CAN_CREATE, RESOURCE_ASSET), ] # [END security_user_perms] @@ -302,8 +308,8 @@ class FabAirflowSecurityManagerOverride(AirflowSecurityManagerV2): (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_VARIABLE), (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_VARIABLE), (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_XCOM), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DATASET), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DATASET), + (permissions.ACTION_CAN_DELETE, RESOURCE_ASSET), + (permissions.ACTION_CAN_CREATE, RESOURCE_ASSET), ] # [END security_op_perms] @@ -552,7 +558,7 @@ def reset_password(self, userid: int, password: str) -> 
bool: def reset_user_sessions(self, user: User) -> None: if isinstance(self.appbuilder.get_app.session_interface, AirflowDatabaseSessionInterface): interface = self.appbuilder.get_app.session_interface - session = interface.db.session + session = interface.client.session user_session_model = interface.sql_session_model num_sessions = session.query(user_session_model).count() if num_sessions > MAX_NUM_DATABASE_USER_SESSIONS: @@ -572,6 +578,7 @@ def reset_user_sessions(self, user: User) -> None: session_details = interface.serializer.loads(want_bytes(s.data)) if session_details.get("_user_id") == user.id: session.delete(s) + session.commit() else: self._cli_safe_flash( "Since you are using `securecookie` session backend mechanism, we cannot prevent " @@ -609,7 +616,7 @@ def auth_rate_limit(self) -> str: @property def auth_role_public(self): """Get the public role.""" - return self.appbuilder.get_app.config["AUTH_ROLE_PUBLIC"] + return self.appbuilder.get_app.config.get("AUTH_ROLE_PUBLIC", None) @property def oauth_providers(self): @@ -843,6 +850,22 @@ def _init_config(self): app.config.setdefault("AUTH_ROLES_SYNC_AT_LOGIN", False) app.config.setdefault("AUTH_API_LOGIN_ALLOW_MULTIPLE_PROVIDERS", False) + # Werkzeug prior to 3.0.0 does not support scrypt + parsed_werkzeug_version = Version(importlib.metadata.version("werkzeug")) + if parsed_werkzeug_version < Version("3.0.0"): + app.config.setdefault( + "AUTH_DB_FAKE_PASSWORD_HASH_CHECK", + "pbkdf2:sha256:150000$Z3t6fmj2$22da622d94a1f8118" + "c0976a03d2f18f680bfff877c9a965db9eedc51bc0be87c", + ) + else: + app.config.setdefault( + "AUTH_DB_FAKE_PASSWORD_HASH_CHECK", + "scrypt:32768:8:1$wiDa0ruWlIPhp9LM$6e409d093e62ad54df2af895d0e125b05ff6cf6414" + "8350189ffc4bcc71286edf1b8ad94a442c00f890224bf2b32153d0750c89ee9" + "401e62f9dcee5399065e4e5", + ) + # LDAP Config if self.auth_type == AUTH_LDAP: if "AUTH_LDAP_SERVER" not in app.config: @@ -955,7 +978,8 @@ def create_db(self): self.add_role(role_name) if self.auth_role_admin 
not in self._builtin_roles: self.add_role(self.auth_role_admin) - self.add_role(self.auth_role_public) + if self.auth_role_public: + self.add_role(self.auth_role_public) if self.count_users() == 0 and self.auth_role_public != self.auth_role_admin: log.warning(const.LOGMSG_WAR_SEC_NO_USER) except Exception: @@ -1073,7 +1097,8 @@ def create_dag_specific_permissions(self) -> None: dags = dagbag.dags.values() for dag in dags: - root_dag_id = dag.parent_dag.dag_id if dag.parent_dag else dag.dag_id + # TODO: Remove this when the minimum version of Airflow is bumped to 3.0 + root_dag_id = (getattr(dag, "parent_dag", None) or dag).dag_id for resource_name, resource_values in self.RESOURCE_DETAILS_MAP.items(): dag_resource_name = self._resource_name(root_dag_id, resource_name) for action_name in resource_values["actions"]: @@ -1103,7 +1128,7 @@ def is_dag_resource(self, resource_name: str) -> bool: def sync_perm_for_dag( self, dag_id: str, - access_control: dict[str, dict[str, Collection[str]]] | None = None, + access_control: Mapping[str, Mapping[str, Collection[str]] | Collection[str]] | None = None, ) -> None: """ Sync permissions for given dag id. @@ -1124,7 +1149,7 @@ def sync_perm_for_dag( if access_control is not None: self.log.debug("Syncing DAG-level permissions for DAG '%s'", dag_id) - self._sync_dag_view_permissions(dag_id, access_control.copy()) + self._sync_dag_view_permissions(dag_id, copy.copy(access_control)) else: self.log.debug( "Not syncing DAG-level permissions for DAG '%s' as access control is unset.", @@ -1145,7 +1170,7 @@ def _resource_name(self, dag_id: str, resource_name: str) -> str: def _sync_dag_view_permissions( self, dag_id: str, - access_control: dict[str, dict[str, Collection[str]]], + access_control: Mapping[str, Mapping[str, Collection[str]] | Collection[str]], ) -> None: """ Set the access policy on the given DAG's ViewModel. 
@@ -1171,7 +1196,13 @@ def _get_or_create_dag_permission(action_name: str, dag_resource_name: str) -> P for perm in existing_dag_perms: non_admin_roles = [role for role in perm.role if role.name != "Admin"] for role in non_admin_roles: - target_perms_for_role = access_control.get(role.name, {}).get(resource_name, set()) + access_control_role = access_control.get(role.name) + target_perms_for_role = set() + if access_control_role: + if isinstance(access_control_role, set): + target_perms_for_role = access_control_role + elif isinstance(access_control_role, dict): + target_perms_for_role = access_control_role.get(resource_name, set()) if perm.action.name not in target_perms_for_role: self.log.info( "Revoking '%s' on DAG '%s' for role '%s'", @@ -1190,7 +1221,7 @@ def _get_or_create_dag_permission(action_name: str, dag_resource_name: str) -> P f"'{rolename}', but that role does not exist" ) - if isinstance(resource_actions, (set, list)): + if not isinstance(resource_actions, dict): # Support for old-style access_control where only the actions are specified resource_actions = {permissions.RESOURCE_DAG: set(resource_actions)} @@ -2196,8 +2227,7 @@ def auth_user_db(self, username, password): if user is None or (not user.is_active): # Balance failure and success check_password_hash( - "pbkdf2:sha256:150000$Z3t6fmj2$22da622d94a1f8118" - "c0976a03d2f18f680bfff877c9a965db9eedc51bc0be87c", + self.appbuilder.get_app.config["AUTH_DB_FAKE_PASSWORD_HASH_CHECK"], "password", ) log.info(LOGMSG_WAR_SEC_LOGIN_FAILED, username) @@ -2828,7 +2858,8 @@ def filter_roles_by_perm_with_action(self, action_name: str, role_ids: list[int] ).all() def _get_root_dag_id(self, dag_id: str) -> str: - if "." in dag_id: + # TODO: The "root_dag_id" check can be removed when the minimum version of Airflow is bumped to 3.0 + if "." 
in dag_id and hasattr(DagModel, "root_dag_id"): dm = self.appbuilder.get_session.execute( select(DagModel.dag_id, DagModel.root_dag_id).where(DagModel.dag_id == dag_id) ).one() diff --git a/airflow/providers/fab/migrations/README b/airflow/providers/fab/migrations/README new file mode 100644 index 0000000000000..2500aa1bcf726 --- /dev/null +++ b/airflow/providers/fab/migrations/README @@ -0,0 +1 @@ +Generic single-database configuration. diff --git a/airflow/providers/fab/migrations/__init__.py b/airflow/providers/fab/migrations/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/providers/fab/migrations/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/fab/migrations/env.py b/airflow/providers/fab/migrations/env.py new file mode 100644 index 0000000000000..903057ba60208 --- /dev/null +++ b/airflow/providers/fab/migrations/env.py @@ -0,0 +1,126 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import contextlib +from logging import getLogger +from logging.config import fileConfig + +from alembic import context + +from airflow import settings +from airflow.providers.fab.auth_manager.models.db import FABDBManager + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +version_table = FABDBManager.version_table_name + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if not getLogger().handlers and config.config_file_name: + fileConfig(config.config_file_name, disable_existing_loggers=False) + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +target_metadata = FABDBManager.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def include_object(_, name, type_, *args): + if type_ == "table" and name not in target_metadata.tables: + return False + return True + + +def run_migrations_offline(): + """ + Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. 
By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + context.configure( + url=settings.SQL_ALCHEMY_CONN, + target_metadata=target_metadata, + literal_binds=True, + compare_type=True, + compare_server_default=True, + render_as_batch=True, + version_table=version_table, + include_object=include_object, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online(): + """ + Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + + def process_revision_directives(context, revision, directives): + if getattr(config.cmd_opts, "autogenerate", False): + script = directives[0] + if script.upgrade_ops and script.upgrade_ops.is_empty(): + directives[:] = [] + print("No change detected in ORM schema, skipping revision.") + + with contextlib.ExitStack() as stack: + connection = config.attributes.get("connection", None) + + if not connection: + connection = stack.push(settings.engine.connect()) + + context.configure( + connection=connection, + transaction_per_migration=True, + target_metadata=target_metadata, + compare_type=True, + compare_server_default=True, + include_object=include_object, + render_as_batch=True, + process_revision_directives=process_revision_directives, + version_table=version_table, + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/airflow/providers/fab/migrations/script.py.mako b/airflow/providers/fab/migrations/script.py.mako new file mode 100644 index 0000000000000..c0193ce2b0471 --- /dev/null +++ b/airflow/providers/fab/migrations/script.py.mako @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} +fab_version = None + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/airflow/providers/fab/migrations/versions/0001_1_4_0_placeholder_migration.py b/airflow/providers/fab/migrations/versions/0001_1_4_0_placeholder_migration.py new file mode 100644 index 0000000000000..722c39198a185 --- /dev/null +++ b/airflow/providers/fab/migrations/versions/0001_1_4_0_placeholder_migration.py @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +placeholder migration. + +Revision ID: 6709f7a774b9 +Revises: +Create Date: 2024-09-03 17:06:38.040510 + +Note: This is a placeholder migration used to stamp the migration +when we create the migration from the ORM. Otherwise, it will run +without stamping the migration, leading to subsequent changes to +the tables not being migrated. +""" + +from __future__ import annotations + +# revision identifiers, used by Alembic. +revision = "6709f7a774b9" +down_revision = None +branch_labels = None +depends_on = None +fab_version = "1.4.0" + + +def upgrade() -> None: ... + + +def downgrade() -> None: ... diff --git a/airflow/providers/fab/migrations/versions/__init__.py b/airflow/providers/fab/migrations/versions/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/providers/fab/migrations/versions/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/fab/provider.yaml b/airflow/providers/fab/provider.yaml index 52c6dbe053b11..f36519ff278fa 100644 --- a/airflow/providers/fab/provider.yaml +++ b/airflow/providers/fab/provider.yaml @@ -28,10 +28,18 @@ description: | # For providers until we think it should be released. state: ready -source-date-epoch: 1722149665 +source-date-epoch: 1738677661 # note that those versions are maintained by release manager - do not update them manually versions: + - 1.5.4 + - 1.5.3 + - 1.5.2 + - 1.5.1 + - 1.5.0 + - 1.4.1 + - 1.4.0 + - 1.3.0 - 1.2.2 - 1.2.1 - 1.2.0 @@ -45,17 +53,27 @@ versions: dependencies: - apache-airflow>=2.9.0 - - flask>=2.2,<2.3 + - apache-airflow-providers-common-compat>=1.2.1 + - flask-login>=0.6.3 + - flask-session>=0.8.0 + - flask>=2.2,<3 # We are tightly coupled with FAB version as we vendored-in part of FAB code related to security manager # This is done as part of preparation to removing FAB as dependency, but we are not ready for it yet # Every time we update FAB version here, please make sure that you review the classes and models in # `airflow/providers/fab/auth_manager/security_manager/override.py` with their upstream counterparts. # In particular, make sure any breaking changes, for example any new methods, are accounted for. 
- - flask-appbuilder==4.5.2 - - flask-login>=0.6.2 + - flask-appbuilder==4.5.3 - google-re2>=1.0 - jmespath>=0.7.0 +additional-extras: + - name: kerberos + dependencies: + - kerberos>=1.3.0 + +devel-dependencies: + - kerberos>=1.3.0 + config: fab: description: This section contains configs specific to FAB provider. diff --git a/tests/providers/fab/auth_manager/api/auth/backend/test_basic_auth.py b/tests/providers/fab/auth_manager/api/auth/backend/test_basic_auth.py index 1f64b3181576d..3b218e03dcec1 100644 --- a/tests/providers/fab/auth_manager/api/auth/backend/test_basic_auth.py +++ b/tests/providers/fab/auth_manager/api/auth/backend/test_basic_auth.py @@ -22,7 +22,7 @@ from flask import Response from flask_appbuilder.const import AUTH_LDAP -from airflow.api.auth.backend.basic_auth import requires_authentication +from airflow.providers.fab.auth_manager.api.auth.backend.basic_auth import requires_authentication from airflow.www import app as application from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS @@ -33,7 +33,9 @@ @pytest.fixture def app(): - return application.create_app(testing=True) + _app = application.create_app(testing=True) + _app.config["AUTH_ROLE_PUBLIC"] = None + return _app @pytest.fixture diff --git a/tests/providers/fab/auth_manager/api/auth/backend/test_kerberos_auth.py b/tests/providers/fab/auth_manager/api/auth/backend/test_kerberos_auth.py index a49709c335d87..e57f34ce4b033 100644 --- a/tests/providers/fab/auth_manager/api/auth/backend/test_kerberos_auth.py +++ b/tests/providers/fab/auth_manager/api/auth/backend/test_kerberos_auth.py @@ -16,10 +16,7 @@ # under the License. 
from __future__ import annotations -from airflow.api.auth.backend.kerberos_auth import ( - init_app as base_init_app, -) -from tests.test_utils.compat import ignore_provider_compatibility_error +from tests_common.test_utils.compat import ignore_provider_compatibility_error with ignore_provider_compatibility_error("2.9.0+", __file__): from airflow.providers.fab.auth_manager.api.auth.backend.kerberos_auth import init_app @@ -27,4 +24,4 @@ class TestKerberosAuth: def test_init_app(self): - assert init_app == base_init_app + init_app diff --git a/tests/providers/fab/auth_manager/api/auth/backend/test_session.py b/tests/providers/fab/auth_manager/api/auth/backend/test_session.py new file mode 100644 index 0000000000000..405eafe11dfc4 --- /dev/null +++ b/tests/providers/fab/auth_manager/api/auth/backend/test_session.py @@ -0,0 +1,72 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest.mock import Mock, patch + +import pytest +from flask import Response +from tests_common.test_utils.compat import AIRFLOW_V_2_9_PLUS + +from airflow.providers.fab.auth_manager.api.auth.backend.session import requires_authentication +from airflow.www import app as application + +pytestmark = [ + pytest.mark.skipif(not AIRFLOW_V_2_9_PLUS, reason="Tests for Airflow 2.9.0+ only"), +] + + +@pytest.fixture +def app(): + return application.create_app(testing=True) + + +mock_call = Mock() + + +@requires_authentication +def function_decorated(): + mock_call() + + +@pytest.mark.db_test +class TestSessionAuth: + def setup_method(self) -> None: + mock_call.reset_mock() + + @patch("airflow.providers.fab.auth_manager.api.auth.backend.session.get_auth_manager") + def test_requires_authentication_when_not_authenticated(self, mock_get_auth_manager, app): + auth_manager = Mock() + auth_manager.is_logged_in.return_value = False + mock_get_auth_manager.return_value = auth_manager + with app.test_request_context() as mock_context: + mock_context.request.authorization = None + result = function_decorated() + + assert type(result) is Response + assert result.status_code == 401 + + @patch("airflow.providers.fab.auth_manager.api.auth.backend.session.get_auth_manager") + def test_requires_authentication_when_authenticated(self, mock_get_auth_manager, app): + auth_manager = Mock() + auth_manager.is_logged_in.return_value = True + mock_get_auth_manager.return_value = auth_manager + with app.test_request_context() as mock_context: + mock_context.request.authorization = None + function_decorated() + + mock_call.assert_called_once() diff --git a/tests/providers/fab/auth_manager/api_endpoints/api_connexion_utils.py b/tests/providers/fab/auth_manager/api_endpoints/api_connexion_utils.py new file mode 100644 index 0000000000000..b208b845096b9 --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/api_connexion_utils.py @@ -0,0 +1,116 
@@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from contextlib import contextmanager + +from tests_common.test_utils.compat import ignore_provider_compatibility_error + +with ignore_provider_compatibility_error("2.9.0+", __file__): + from airflow.providers.fab.auth_manager.security_manager.override import EXISTING_ROLES + + +@contextmanager +def create_test_client(app, user_name, role_name, permissions): + """ + Helper function to create a client with a temporary user which will be deleted once done + """ + client = app.test_client() + with create_user_scope(app, username=user_name, role_name=role_name, permissions=permissions) as _: + resp = client.post("/login/", data={"username": user_name, "password": user_name}) + assert resp.status_code == 302 + yield client + + +@contextmanager +def create_user_scope(app, username, **kwargs): + """ + Helper function designed to be used with pytest fixture mainly. 
+ It will create a user and provide it for the fixture via YIELD (generator) + then will tidy up once test is complete + """ + test_user = create_user(app, username, **kwargs) + + try: + yield test_user + finally: + delete_user(app, username) + + +def create_user(app, username, role_name=None, email=None, permissions=None): + appbuilder = app.appbuilder + + # Removes user and role so each test has isolated test data. + delete_user(app, username) + role = None + if role_name: + delete_role(app, role_name) + role = create_role(app, role_name, permissions) + else: + role = [] + + return appbuilder.sm.add_user( + username=username, + first_name=username, + last_name=username, + email=email or f"{username}@example.org", + role=role, + password=username, + ) + + +def create_role(app, name, permissions=None): + appbuilder = app.appbuilder + role = appbuilder.sm.find_role(name) + if not role: + role = appbuilder.sm.add_role(name) + if not permissions: + permissions = [] + for permission in permissions: + perm_object = appbuilder.sm.get_permission(*permission) + appbuilder.sm.add_permission_to_role(role, perm_object) + return role + + +def set_user_single_role(app, user, role_name): + role = create_role(app, role_name) + if role not in user.roles: + user.roles = [role] + app.appbuilder.sm.update_user(user) + user._perms = None + + +def delete_role(app, name): + if name not in EXISTING_ROLES: + if app.appbuilder.sm.find_role(name): + app.appbuilder.sm.delete_role(name) + + +def delete_roles(app): + for role in app.appbuilder.sm.get_all_roles(): + delete_role(app, role.name) + + +def delete_user(app, username): + appbuilder = app.appbuilder + for user in appbuilder.sm.get_all_users(): + if user.username == username: + _ = [ + delete_role(app, role.name) for role in user.roles if role and role.name not in EXISTING_ROLES + ] + appbuilder.sm.del_register_user(user) + break diff --git a/tests/providers/fab/auth_manager/api_endpoints/remote_user_api_auth_backend.py 
b/tests/providers/fab/auth_manager/api_endpoints/remote_user_api_auth_backend.py new file mode 100644 index 0000000000000..b7714e5192e6a --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/remote_user_api_auth_backend.py @@ -0,0 +1,81 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Default authentication backend - everything is allowed""" + +from __future__ import annotations + +import logging +from functools import wraps +from typing import TYPE_CHECKING, Callable, TypeVar, cast + +from flask import Response, request +from flask_login import login_user + +from airflow.utils.airflow_flask_app import get_airflow_app + +if TYPE_CHECKING: + from requests.auth import AuthBase + +log = logging.getLogger(__name__) + +CLIENT_AUTH: tuple[str, str] | AuthBase | None = None + + +def init_app(_): + """Initializes authentication backend""" + + +T = TypeVar("T", bound=Callable) + + +def _lookup_user(user_email_or_username: str): + security_manager = get_airflow_app().appbuilder.sm + user = security_manager.find_user(email=user_email_or_username) or security_manager.find_user( + username=user_email_or_username + ) + if not user: + return None + + if not user.is_active: + return None + + return user + + +def requires_authentication(function: T): + """Decorator for functions that require authentication""" + + @wraps(function) + def decorated(*args, **kwargs): + user_id = request.remote_user + if not user_id: + log.debug("Missing REMOTE_USER.") + return Response("Forbidden", 403) + + log.debug("Looking for user: %s", user_id) + + user = _lookup_user(user_id) + if not user: + return Response("Forbidden", 403) + + log.debug("Found user: %s", user) + + login_user(user, remember=False) + return function(*args, **kwargs) + + return cast(T, decorated) diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_asset_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_asset_endpoint.py new file mode 100644 index 0000000000000..4cd76aa2b4a54 --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_asset_endpoint.py @@ -0,0 +1,327 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Generator + +import pytest +import time_machine +from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.db import clear_db_assets, clear_db_runs +from tests_common.test_utils.www import _check_last_log + +from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP +from airflow.security import permissions +from airflow.utils import timezone + +try: + from airflow.models.asset import AssetDagRunQueue, AssetModel +except ImportError: + if AIRFLOW_V_3_0_PLUS: + raise + else: + pass + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + create_user( + app, + username="test_queued_event", + role_name="TestQueuedEvent", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET), + (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_ASSET), + ], + ) + + yield app + + delete_user(app, 
username="test_queued_event") + + +class TestAssetEndpoint: + default_time = "2020-06-11T18:00:00+00:00" + + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + self.app = configured_app + self.client = self.app.test_client() + clear_db_assets() + clear_db_runs() + + def teardown_method(self) -> None: + clear_db_assets() + clear_db_runs() + + def _create_asset(self, session): + asset_model = AssetModel( + id=1, + uri="s3://bucket/key", + extra={"foo": "bar"}, + created_at=timezone.parse(self.default_time), + updated_at=timezone.parse(self.default_time), + ) + session.add(asset_model) + session.commit() + return asset_model + + +class TestQueuedEventEndpoint(TestAssetEndpoint): + @pytest.fixture + def time_freezer(self) -> Generator: + freezer = time_machine.travel(self.default_time, tick=False) + freezer.start() + + yield + + freezer.stop() + + def _create_asset_dag_run_queues(self, dag_id, asset_id, session): + ddrq = AssetDagRunQueue(target_dag_id=dag_id, asset_id=asset_id) + session.add(ddrq) + session.commit() + return ddrq + + +class TestGetDagAssetQueuedEvent(TestQueuedEventEndpoint): + @pytest.mark.usefixtures("time_freezer") + def test_should_respond_200(self, session, create_dummy_dag): + dag, _ = create_dummy_dag() + dag_id = dag.dag_id + asset_id = self._create_asset(session).id + self._create_asset_dag_run_queues(dag_id, asset_id, session) + asset_uri = "s3://bucket/key" + + response = self.client.get( + f"/api/v1/dags/{dag_id}/assets/queuedEvent/{asset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 200 + assert response.json == { + "created_at": self.default_time, + "uri": "s3://bucket/key", + "dag_id": "dag", + } + + def test_should_respond_404(self): + dag_id = "not_exists" + asset_uri = "not_exists" + + response = self.client.get( + f"/api/v1/dags/{dag_id}/assets/queuedEvent/{asset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert 
response.status_code == 404 + assert { + "detail": "Queue event with dag_id: `not_exists` and asset uri: `not_exists` was not found", + "status": 404, + "title": "Queue event not found", + "type": EXCEPTIONS_LINK_MAP[404], + } == response.json + + +class TestDeleteDagAssetQueuedEvent(TestAssetEndpoint): + def test_delete_should_respond_204(self, session, create_dummy_dag): + dag, _ = create_dummy_dag() + dag_id = dag.dag_id + asset_uri = "s3://bucket/key" + asset_id = self._create_asset(session).id + + ddrq = AssetDagRunQueue(target_dag_id=dag_id, asset_id=asset_id) + session.add(ddrq) + session.commit() + conn = session.query(AssetDagRunQueue).all() + assert len(conn) == 1 + + response = self.client.delete( + f"/api/v1/dags/{dag_id}/assets/queuedEvent/{asset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 204 + conn = session.query(AssetDagRunQueue).all() + assert len(conn) == 0 + _check_last_log( + session, dag_id=dag_id, event="api.delete_dag_asset_queued_event", execution_date=None + ) + + def test_should_respond_404(self): + dag_id = "not_exists" + asset_uri = "not_exists" + + response = self.client.delete( + f"/api/v1/dags/{dag_id}/assets/queuedEvent/{asset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 404 + assert { + "detail": "Queue event with dag_id: `not_exists` and asset uri: `not_exists` was not found", + "status": 404, + "title": "Queue event not found", + "type": EXCEPTIONS_LINK_MAP[404], + } == response.json + + +class TestGetDagAssetQueuedEvents(TestQueuedEventEndpoint): + @pytest.mark.usefixtures("time_freezer") + def test_should_respond_200(self, session, create_dummy_dag): + dag, _ = create_dummy_dag() + dag_id = dag.dag_id + asset_id = self._create_asset(session).id + self._create_asset_dag_run_queues(dag_id, asset_id, session) + + response = self.client.get( + f"/api/v1/dags/{dag_id}/assets/queuedEvent", + 
environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 200 + assert response.json == { + "queued_events": [ + { + "created_at": self.default_time, + "uri": "s3://bucket/key", + "dag_id": "dag", + } + ], + "total_entries": 1, + } + + def test_should_respond_404(self): + dag_id = "not_exists" + + response = self.client.get( + f"/api/v1/dags/{dag_id}/assets/queuedEvent", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 404 + assert { + "detail": "Queue event with dag_id: `not_exists` was not found", + "status": 404, + "title": "Queue event not found", + "type": EXCEPTIONS_LINK_MAP[404], + } == response.json + + +class TestDeleteDagDatasetQueuedEvents(TestAssetEndpoint): + def test_should_respond_404(self): + dag_id = "not_exists" + + response = self.client.delete( + f"/api/v1/dags/{dag_id}/assets/queuedEvent", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 404 + assert { + "detail": "Queue event with dag_id: `not_exists` was not found", + "status": 404, + "title": "Queue event not found", + "type": EXCEPTIONS_LINK_MAP[404], + } == response.json + + +class TestGetDatasetQueuedEvents(TestQueuedEventEndpoint): + @pytest.mark.usefixtures("time_freezer") + def test_should_respond_200(self, session, create_dummy_dag): + dag, _ = create_dummy_dag() + dag_id = dag.dag_id + asset_id = self._create_asset(session).id + self._create_asset_dag_run_queues(dag_id, asset_id, session) + asset_uri = "s3://bucket/key" + + response = self.client.get( + f"/api/v1/assets/queuedEvent/{asset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 200 + assert response.json == { + "queued_events": [ + { + "created_at": self.default_time, + "uri": "s3://bucket/key", + "dag_id": "dag", + } + ], + "total_entries": 1, + } + + def test_should_respond_404(self): + asset_uri = "not_exists" + + response = 
self.client.get( + f"/api/v1/assets/queuedEvent/{asset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 404 + assert { + "detail": "Queue event with asset uri: `not_exists` was not found", + "status": 404, + "title": "Queue event not found", + "type": EXCEPTIONS_LINK_MAP[404], + } == response.json + + +class TestDeleteDatasetQueuedEvents(TestQueuedEventEndpoint): + def test_delete_should_respond_204(self, session, create_dummy_dag): + dag, _ = create_dummy_dag() + dag_id = dag.dag_id + asset_id = self._create_asset(session).id + self._create_asset_dag_run_queues(dag_id, asset_id, session) + asset_uri = "s3://bucket/key" + + response = self.client.delete( + f"/api/v1/assets/queuedEvent/{asset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 204 + conn = session.query(AssetDagRunQueue).all() + assert len(conn) == 0 + _check_last_log(session, dag_id=None, event="api.delete_asset_queued_events", execution_date=None) + + def test_should_respond_404(self): + asset_uri = "not_exists" + + response = self.client.delete( + f"/api/v1/assets/queuedEvent/{asset_uri}", + environ_overrides={"REMOTE_USER": "test_queued_event"}, + ) + + assert response.status_code == 404 + assert { + "detail": "Queue event with asset uri: `not_exists` was not found", + "status": 404, + "title": "Queue event not found", + "type": EXCEPTIONS_LINK_MAP[404], + } == response.json diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_auth.py b/tests/providers/fab/auth_manager/api_endpoints/test_auth.py new file mode 100644 index 0000000000000..6a90b7ec4b30e --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_auth.py @@ -0,0 +1,175 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from base64 import b64encode + +import pytest +from flask_login import current_user +from tests_common.test_utils.api_connexion_utils import assert_401 +from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.db import clear_db_pools +from tests_common.test_utils.www import client_with_login + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +class BaseTestAuth: + @pytest.fixture(autouse=True) + def set_attrs(self, minimal_app_for_auth_api): + self.app = minimal_app_for_auth_api + + sm = self.app.appbuilder.sm + tester = sm.find_user(username="test") + if not tester: + role_admin = sm.find_role("Admin") + sm.add_user( + username="test", + first_name="test", + last_name="test", + email="test@fab.org", + role=role_admin, + password="test", + ) + + +class TestBasicAuth(BaseTestAuth): + @pytest.fixture(autouse=True, scope="class") + def with_basic_auth_backend(self, minimal_app_for_auth_api): + from airflow.www.extensions.init_security import init_api_auth + + old_auth = getattr(minimal_app_for_auth_api, "api_auth") + + try: + with conf_vars( + {("api", "auth_backends"): "airflow.providers.fab.auth_manager.api.auth.backend.basic_auth"} + ): + 
init_api_auth(minimal_app_for_auth_api) + yield + finally: + setattr(minimal_app_for_auth_api, "api_auth", old_auth) + + def test_success(self): + token = "Basic " + b64encode(b"test:test").decode() + clear_db_pools() + + with self.app.test_client() as test_client: + response = test_client.get("/api/v1/pools", headers={"Authorization": token}) + assert current_user.email == "test@fab.org" + + assert response.status_code == 200 + assert response.json == { + "pools": [ + { + "name": "default_pool", + "slots": 128, + "occupied_slots": 0, + "running_slots": 0, + "queued_slots": 0, + "scheduled_slots": 0, + "deferred_slots": 0, + "open_slots": 128, + "description": "Default pool", + "include_deferred": False, + }, + ], + "total_entries": 1, + } + + @pytest.mark.parametrize( + "token", + [ + "basic", + "basic ", + "bearer", + "test:test", + b64encode(b"test:test").decode(), + "bearer ", + "basic: ", + "basic 123", + ], + ) + def test_malformed_headers(self, token): + with self.app.test_client() as test_client: + response = test_client.get("/api/v1/pools", headers={"Authorization": token}) + assert response.status_code == 401 + assert response.headers["Content-Type"] == "application/problem+json" + assert response.headers["WWW-Authenticate"] == "Basic" + assert_401(response) + + @pytest.mark.parametrize( + "token", + [ + "basic " + b64encode(b"test").decode(), + "basic " + b64encode(b"test:").decode(), + "basic " + b64encode(b"test:123").decode(), + "basic " + b64encode(b"test test").decode(), + ], + ) + def test_invalid_auth_header(self, token): + with self.app.test_client() as test_client: + response = test_client.get("/api/v1/pools", headers={"Authorization": token}) + assert response.status_code == 401 + assert response.headers["Content-Type"] == "application/problem+json" + assert response.headers["WWW-Authenticate"] == "Basic" + assert_401(response) + + +class TestSessionWithBasicAuthFallback(BaseTestAuth): + @pytest.fixture(autouse=True, scope="class") + def 
with_basic_auth_backend(self, minimal_app_for_auth_api): + from airflow.www.extensions.init_security import init_api_auth + + old_auth = getattr(minimal_app_for_auth_api, "api_auth") + + try: + with conf_vars( + { + ( + "api", + "auth_backends", + ): "airflow.providers.fab.auth_manager.api.auth.backend.session,airflow.providers.fab.auth_manager.api.auth.backend.basic_auth" + } + ): + init_api_auth(minimal_app_for_auth_api) + yield + finally: + setattr(minimal_app_for_auth_api, "api_auth", old_auth) + + def test_basic_auth_fallback(self): + token = "Basic " + b64encode(b"test:test").decode() + clear_db_pools() + + # request uses session + admin_user = client_with_login(self.app, username="test", password="test") + response = admin_user.get("/api/v1/pools") + assert response.status_code == 200 + + # request uses basic auth + with self.app.test_client() as test_client: + response = test_client.get("/api/v1/pools", headers={"Authorization": token}) + assert response.status_code == 200 + + # request without session or basic auth header + with self.app.test_client() as test_client: + response = test_client.get("/api/v1/pools") + assert response.status_code == 401 diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_cors.py b/tests/providers/fab/auth_manager/api_endpoints/test_cors.py new file mode 100644 index 0000000000000..3741d71fb8b96 --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_cors.py @@ -0,0 +1,155 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from base64 import b64encode + +import pytest +from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.config import conf_vars +from tests_common.test_utils.db import clear_db_pools + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +class BaseTestAuth: + @pytest.fixture(autouse=True) + def set_attrs(self, minimal_app_for_auth_api): + self.app = minimal_app_for_auth_api + + sm = self.app.appbuilder.sm + tester = sm.find_user(username="test") + if not tester: + role_admin = sm.find_role("Admin") + sm.add_user( + username="test", + first_name="test", + last_name="test", + email="test@fab.org", + role=role_admin, + password="test", + ) + + +class TestEmptyCors(BaseTestAuth): + @pytest.fixture(autouse=True, scope="class") + def with_basic_auth_backend(self, minimal_app_for_auth_api): + from airflow.www.extensions.init_security import init_api_auth + + old_auth = getattr(minimal_app_for_auth_api, "api_auth") + + try: + with conf_vars( + {("api", "auth_backends"): "airflow.providers.fab.auth_manager.api.auth.backend.basic_auth"} + ): + init_api_auth(minimal_app_for_auth_api) + yield + finally: + setattr(minimal_app_for_auth_api, "api_auth", old_auth) + + def test_empty_cors_headers(self): + token = "Basic " + b64encode(b"test:test").decode() + clear_db_pools() + + with self.app.test_client() as test_client: + response = test_client.get("/api/v1/pools", 
headers={"Authorization": token}) + assert response.status_code == 200 + assert "Access-Control-Allow-Headers" not in response.headers + assert "Access-Control-Allow-Methods" not in response.headers + assert "Access-Control-Allow-Origin" not in response.headers + + +class TestCorsOrigin(BaseTestAuth): + @pytest.fixture(autouse=True, scope="class") + def with_basic_auth_backend(self, minimal_app_for_auth_api): + from airflow.www.extensions.init_security import init_api_auth + + old_auth = getattr(minimal_app_for_auth_api, "api_auth") + + try: + with conf_vars( + { + ( + "api", + "auth_backends", + ): "airflow.providers.fab.auth_manager.api.auth.backend.basic_auth", + ("api", "access_control_allow_origins"): "http://apache.org http://example.com", + } + ): + init_api_auth(minimal_app_for_auth_api) + yield + finally: + setattr(minimal_app_for_auth_api, "api_auth", old_auth) + + def test_cors_origin_reflection(self): + token = "Basic " + b64encode(b"test:test").decode() + clear_db_pools() + + with self.app.test_client() as test_client: + response = test_client.get("/api/v1/pools", headers={"Authorization": token}) + assert response.status_code == 200 + assert response.headers["Access-Control-Allow-Origin"] == "http://apache.org" + + response = test_client.get( + "/api/v1/pools", headers={"Authorization": token, "Origin": "http://apache.org"} + ) + assert response.status_code == 200 + assert response.headers["Access-Control-Allow-Origin"] == "http://apache.org" + + response = test_client.get( + "/api/v1/pools", headers={"Authorization": token, "Origin": "http://example.com"} + ) + assert response.status_code == 200 + assert response.headers["Access-Control-Allow-Origin"] == "http://example.com" + + +class TestCorsWildcard(BaseTestAuth): + @pytest.fixture(autouse=True, scope="class") + def with_basic_auth_backend(self, minimal_app_for_auth_api): + from airflow.www.extensions.init_security import init_api_auth + + old_auth = getattr(minimal_app_for_auth_api, "api_auth") + 
+ try: + with conf_vars( + { + ( + "api", + "auth_backends", + ): "airflow.providers.fab.auth_manager.api.auth.backend.basic_auth", + ("api", "access_control_allow_origins"): "*", + } + ): + init_api_auth(minimal_app_for_auth_api) + yield + finally: + setattr(minimal_app_for_auth_api, "api_auth", old_auth) + + def test_cors_origin_reflection(self): + token = "Basic " + b64encode(b"test:test").decode() + clear_db_pools() + + with self.app.test_client() as test_client: + response = test_client.get( + "/api/v1/pools", headers={"Authorization": token, "Origin": "http://example.com"} + ) + assert response.status_code == 200 + assert response.headers["Access-Control-Allow-Origin"] == "*" diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_dag_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_dag_endpoint.py new file mode 100644 index 0000000000000..198f213aa25ff --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_dag_endpoint.py @@ -0,0 +1,252 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import os +from datetime import datetime + +import pendulum +import pytest +from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags +from tests_common.test_utils.www import _check_last_log + +from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP +from airflow.models import DagBag, DagModel +from airflow.models.dag import DAG +from airflow.operators.empty import EmptyOperator +from airflow.security import permissions +from airflow.utils.session import provide_session + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +@pytest.fixture +def current_file_token(url_safe_serializer) -> str: + return url_safe_serializer.dumps(__file__) + + +DAG_ID = "test_dag" +TASK_ID = "op1" +DAG2_ID = "test_dag2" +DAG3_ID = "test_dag3" +UTC_JSON_REPR = "UTC" if pendulum.__version__.startswith("3") else "Timezone('UTC')" + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + + create_user(app, username="test_granular_permissions", role_name="TestGranularDag") + app.appbuilder.sm.sync_perm_for_dag( + "TEST_DAG_1", + access_control={ + "TestGranularDag": { + permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ} + }, + }, + ) + app.appbuilder.sm.sync_perm_for_dag( + "TEST_DAG_1", + access_control={ + "TestGranularDag": { + permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ} + }, + }, + ) + + with DAG( + DAG_ID, + schedule=None, + start_date=datetime(2020, 6, 15), + doc_md="details", + params={"foo": 1}, + tags=["example"], + ) as dag: + EmptyOperator(task_id=TASK_ID) + + with DAG(DAG2_ID, 
schedule=None, start_date=datetime(2020, 6, 15)) as dag2: # no doc_md + EmptyOperator(task_id=TASK_ID) + + with DAG(DAG3_ID, schedule=None) as dag3: # DAG start_date set to None + EmptyOperator(task_id=TASK_ID, start_date=datetime(2019, 6, 12)) + + dag_bag = DagBag(os.devnull, include_examples=False) + dag_bag.dags = {dag.dag_id: dag, dag2.dag_id: dag2, dag3.dag_id: dag3} + + app.dag_bag = dag_bag + + yield app + + delete_user(app, username="test_granular_permissions") + + +class TestDagEndpoint: + @staticmethod + def clean_db(): + clear_db_runs() + clear_db_dags() + clear_db_serialized_dags() + + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + self.clean_db() + self.app = configured_app + self.client = self.app.test_client() # type:ignore + self.dag_id = DAG_ID + self.dag2_id = DAG2_ID + self.dag3_id = DAG3_ID + + def teardown_method(self) -> None: + self.clean_db() + + @provide_session + def _create_dag_models(self, count, dag_id_prefix="TEST_DAG", is_paused=False, session=None): + for num in range(1, count + 1): + dag_model = DagModel( + dag_id=f"{dag_id_prefix}_{num}", + fileloc=f"/tmp/dag_{num}.py", + timetable_summary="2 2 * * *", + is_active=True, + is_paused=is_paused, + ) + session.add(dag_model) + + @provide_session + def _create_dag_model_for_details_endpoint(self, dag_id, session=None): + dag_model = DagModel( + dag_id=dag_id, + fileloc="/tmp/dag.py", + timetable_summary="2 2 * * *", + is_active=True, + is_paused=False, + ) + session.add(dag_model) + + @provide_session + def _create_dag_model_for_details_endpoint_with_asset_expression(self, dag_id, session=None): + dag_model = DagModel( + dag_id=dag_id, + fileloc="/tmp/dag.py", + timetable_summary="2 2 * * *", + is_active=True, + is_paused=False, + asset_expression={ + "any": [ + "s3://dag1/output_1.txt", + {"all": ["s3://dag2/output_1.txt", "s3://dag3/output_3.txt"]}, + ] + }, + ) + session.add(dag_model) + + @provide_session + def _create_deactivated_dag(self, 
session=None): + dag_model = DagModel( + dag_id="TEST_DAG_DELETED_1", + fileloc="/tmp/dag_del_1.py", + timetable_summary="2 2 * * *", + is_active=False, + ) + session.add(dag_model) + + +class TestGetDag(TestDagEndpoint): + def test_should_respond_200_with_granular_dag_access(self): + self._create_dag_models(1) + response = self.client.get( + "/api/v1/dags/TEST_DAG_1", environ_overrides={"REMOTE_USER": "test_granular_permissions"} + ) + assert response.status_code == 200 + + def test_should_respond_403_with_granular_access_for_different_dag(self): + self._create_dag_models(3) + response = self.client.get( + "/api/v1/dags/TEST_DAG_2", environ_overrides={"REMOTE_USER": "test_granular_permissions"} + ) + assert response.status_code == 403 + + +class TestGetDags(TestDagEndpoint): + def test_should_respond_200_with_granular_dag_access(self): + self._create_dag_models(3) + response = self.client.get( + "/api/v1/dags", environ_overrides={"REMOTE_USER": "test_granular_permissions"} + ) + assert response.status_code == 200 + assert len(response.json["dags"]) == 1 + assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1" + + +class TestPatchDag(TestDagEndpoint): + @provide_session + def _create_dag_model(self, session=None): + dag_model = DagModel( + dag_id="TEST_DAG_1", fileloc="/tmp/dag_1.py", timetable_summary="2 2 * * *", is_paused=True + ) + session.add(dag_model) + return dag_model + + def test_should_respond_200_on_patch_with_granular_dag_access(self, session): + self._create_dag_models(1) + response = self.client.patch( + "/api/v1/dags/TEST_DAG_1", + json={ + "is_paused": False, + }, + environ_overrides={"REMOTE_USER": "test_granular_permissions"}, + ) + assert response.status_code == 200 + _check_last_log(session, dag_id="TEST_DAG_1", event="api.patch_dag", execution_date=None) + + def test_validation_error_raises_400(self): + patch_body = { + "ispaused": True, + } + dag_model = self._create_dag_model() + response = self.client.patch( + 
f"/api/v1/dags/{dag_model.dag_id}", + json=patch_body, + environ_overrides={"REMOTE_USER": "test_granular_permissions"}, + ) + assert response.status_code == 400 + assert response.json == { + "detail": "{'ispaused': ['Unknown field.']}", + "status": 400, + "title": "Bad Request", + "type": EXCEPTIONS_LINK_MAP[400], + } + + +class TestPatchDags(TestDagEndpoint): + def test_should_respond_200_with_granular_dag_access(self): + self._create_dag_models(3) + response = self.client.patch( + "api/v1/dags?dag_id_pattern=~", + json={ + "is_paused": False, + }, + environ_overrides={"REMOTE_USER": "test_granular_permissions"}, + ) + assert response.status_code == 200 + assert len(response.json["dags"]) == 1 + assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1" diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py new file mode 100644 index 0000000000000..9eab53aaa0cd4 --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py @@ -0,0 +1,144 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import ast +import os +from typing import TYPE_CHECKING + +import pytest +from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.db import clear_db_dag_code, clear_db_dags, clear_db_serialized_dags + +from airflow.models import DagBag +from airflow.security import permissions + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + +if TYPE_CHECKING: + from airflow.models.dag import DAG + +ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)) +EXAMPLE_DAG_FILE = os.path.join("airflow", "example_dags", "example_bash_operator.py") +EXAMPLE_DAG_ID = "example_bash_operator" +TEST_DAG_ID = "latest_only" +NOT_READABLE_DAG_ID = "latest_only_with_trigger" +TEST_MULTIPLE_DAGS_ID = "asset_produces_1" + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + create_user( + app, + username="test", + role_name="Test", + permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE)], + ) + app.appbuilder.sm.sync_perm_for_dag( + TEST_DAG_ID, + access_control={"Test": [permissions.ACTION_CAN_READ]}, + ) + app.appbuilder.sm.sync_perm_for_dag( + EXAMPLE_DAG_ID, + access_control={"Test": [permissions.ACTION_CAN_READ]}, + ) + app.appbuilder.sm.sync_perm_for_dag( + TEST_MULTIPLE_DAGS_ID, + access_control={"Test": [permissions.ACTION_CAN_READ]}, + ) + + yield app + + delete_user(app, username="test") + + +class TestGetSource: + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + self.app = configured_app + self.client = self.app.test_client() # type:ignore + self.clear_db() + + def teardown_method(self) -> None: + self.clear_db() + + @staticmethod + def 
clear_db(): + clear_db_dags() + clear_db_serialized_dags() + clear_db_dag_code() + + @staticmethod + def _get_dag_file_docstring(fileloc: str) -> str | None: + with open(fileloc) as f: + file_contents = f.read() + module = ast.parse(file_contents) + docstring = ast.get_docstring(module) + return docstring + + def test_should_respond_406(self, url_safe_serializer): + dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) + dagbag.sync_to_db() + test_dag: DAG = dagbag.dags[TEST_DAG_ID] + + url = f"/api/v1/dagSources/{url_safe_serializer.dumps(test_dag.fileloc)}" + response = self.client.get( + url, headers={"Accept": "image/webp"}, environ_overrides={"REMOTE_USER": "test"} + ) + + assert 406 == response.status_code + + def test_should_respond_403_not_readable(self, url_safe_serializer): + dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) + dagbag.sync_to_db() + dag: DAG = dagbag.dags[NOT_READABLE_DAG_ID] + + response = self.client.get( + f"/api/v1/dagSources/{url_safe_serializer.dumps(dag.fileloc)}", + headers={"Accept": "text/plain"}, + environ_overrides={"REMOTE_USER": "test"}, + ) + read_dag = self.client.get( + f"/api/v1/dags/{NOT_READABLE_DAG_ID}", + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 403 + assert read_dag.status_code == 403 + + def test_should_respond_403_some_dags_not_readable_in_the_file(self, url_safe_serializer): + dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) + dagbag.sync_to_db() + dag: DAG = dagbag.dags[TEST_MULTIPLE_DAGS_ID] + + response = self.client.get( + f"/api/v1/dagSources/{url_safe_serializer.dumps(dag.fileloc)}", + headers={"Accept": "text/plain"}, + environ_overrides={"REMOTE_USER": "test"}, + ) + + read_dag = self.client.get( + f"/api/v1/dags/{TEST_MULTIPLE_DAGS_ID}", + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 403 + assert read_dag.status_code == 200 diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_dag_warning_endpoint.py 
b/tests/providers/fab/auth_manager/api_endpoints/test_dag_warning_endpoint.py new file mode 100644 index 0000000000000..b42d92d9cacbd --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_dag_warning_endpoint.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
from __future__ import annotations

import pytest
from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS
from tests_common.test_utils.db import clear_db_dag_warnings, clear_db_dags

from airflow.models.dag import DagModel
from airflow.models.dagwarning import DagWarning
from airflow.security import permissions
from airflow.utils.session import create_session

pytestmark = [
    pytest.mark.db_test,
    pytest.mark.skip_if_database_isolation_mode,
    pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"),
]


@pytest.fixture(scope="module")
def configured_app(minimal_app_for_auth_api):
    """App plus a user who may read DAG warnings, but only for DAG "dag2"."""
    app = minimal_app_for_auth_api
    create_user(
        app,  # type:ignore
        username="test_with_dag2_read",
        role_name="TestWithDag2Read",
        permissions=[
            (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING),
            (permissions.ACTION_CAN_READ, f"{permissions.RESOURCE_DAG_PREFIX}dag2"),
        ],
    )

    yield app

    delete_user(app, username="test_with_dag2_read")


class TestBaseDagWarning:
    """Shared client setup and DB cleanup for the dagWarnings endpoint tests."""

    timestamp = "2020-06-10T12:00"

    @pytest.fixture(autouse=True)
    def setup_attrs(self, configured_app) -> None:
        self.app = configured_app
        self.client = self.app.test_client()  # type:ignore

    def teardown_method(self) -> None:
        clear_db_dag_warnings()
        clear_db_dags()


class TestGetDagWarningEndpoint(TestBaseDagWarning):
    def setup_class(self):
        clear_db_dag_warnings()
        clear_db_dags()

    def setup_method(self):
        # Seed one DAG ("dag1") carrying a single warning; the test user has
        # read access only on "dag2", so "dag1" warnings must be off-limits.
        with create_session() as session:
            session.add(DagModel(dag_id="dag1"))
            session.add(DagWarning("dag1", "non-existent pool", "test message"))
            session.commit()

    def test_should_raise_403_forbidden_when_user_has_no_dag_read_permission(self):
        response = self.client.get(
            "/api/v1/dagWarnings",
            environ_overrides={"REMOTE_USER": "test_with_dag2_read"},
            query_string={"dag_id": "dag1"},
        )
        assert response.status_code == 403
from __future__ import annotations

import pytest
from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS
from tests_common.test_utils.db import clear_db_logs

from airflow.models import Log
from airflow.security import permissions
from airflow.utils import timezone

pytestmark = [
    pytest.mark.db_test,
    pytest.mark.skip_if_database_isolation_mode,
    pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"),
]


@pytest.fixture(scope="module")
def configured_app(minimal_app_for_auth_api):
    """App plus a user with audit-log read access and per-DAG read on two DAG ids."""
    app = minimal_app_for_auth_api
    create_user(
        app,
        username="test_granular",
        role_name="TestGranular",
        permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)],
    )
    for readable_dag_id in ("TEST_DAG_ID_1", "TEST_DAG_ID_2"):
        app.appbuilder.sm.sync_perm_for_dag(
            readable_dag_id,
            access_control={"TestGranular": [permissions.ACTION_CAN_READ]},
        )

    yield app

    delete_user(app, username="test_granular")


@pytest.fixture
def task_instance(session, create_task_instance, request):
    """One task instance that Log rows in these tests can attach to."""
    return create_task_instance(
        session=session,
        dag_id="TEST_DAG_ID",
        task_id="TEST_TASK_ID",
        run_id="TEST_RUN_ID",
        execution_date=request.instance.default_time,
    )


@pytest.fixture
def create_log_model(create_task_instance, task_instance, session, request):
    """Factory fixture: build and flush a Log row with an explicit dttm."""

    def maker(event, when, **kwargs):
        log_entry = Log(
            event=event,
            task_instance=task_instance,
            **kwargs,
        )
        log_entry.dttm = when
        session.add(log_entry)
        session.flush()
        return log_entry

    return maker


class TestEventLogEndpoint:
    @pytest.fixture(autouse=True)
    def setup_attrs(self, configured_app) -> None:
        self.app = configured_app
        self.client = self.app.test_client()  # type:ignore
        clear_db_logs()
        self.default_time = timezone.parse("2020-06-10T20:00:00+00:00")
        self.default_time_2 = timezone.parse("2020-06-11T07:00:00+00:00")

    def teardown_method(self) -> None:
        clear_db_logs()


class TestGetEventLogs(TestEventLogEndpoint):
    def test_should_filter_eventlogs_by_allowed_attributes(self, create_log_model, session):
        first_log = create_log_model(
            event="TEST_EVENT_1",
            dag_id="TEST_DAG_ID_1",
            task_id="TEST_TASK_ID_1",
            owner="TEST_OWNER_1",
            when=self.default_time,
        )
        second_log = create_log_model(
            event="TEST_EVENT_2",
            dag_id="TEST_DAG_ID_2",
            task_id="TEST_TASK_ID_2",
            owner="TEST_OWNER_2",
            when=self.default_time_2,
        )
        session.add_all([first_log, second_log])
        session.commit()
        # Filtering on each attribute individually must return exactly the first row.
        for attr in ["dag_id", "task_id", "owner", "event"]:
            attr_value = f"TEST_{attr}_1".upper()
            response = self.client.get(
                f"/api/v1/eventLogs?{attr}={attr_value}", environ_overrides={"REMOTE_USER": "test_granular"}
            )
            assert response.status_code == 200
            assert response.json["total_entries"] == 1
            assert len(response.json["event_logs"]) == 1
            assert response.json["event_logs"][0][attr] == attr_value

    def test_should_filter_eventlogs_by_included_events(self, create_log_model):
        for event in ["TEST_EVENT_1", "TEST_EVENT_2", "cli_scheduler"]:
            create_log_model(event=event, when=self.default_time)
        response = self.client.get(
            "/api/v1/eventLogs?included_events=TEST_EVENT_1,TEST_EVENT_2",
            environ_overrides={"REMOTE_USER": "test_granular"},
        )
        assert response.status_code == 200
        payload = response.json
        assert len(payload["event_logs"]) == 2
        assert payload["total_entries"] == 2
        assert {entry["event"] for entry in payload["event_logs"]} == {"TEST_EVENT_1", "TEST_EVENT_2"}

    def test_should_filter_eventlogs_by_excluded_events(self, create_log_model):
        for event in ["TEST_EVENT_1", "TEST_EVENT_2", "cli_scheduler"]:
            create_log_model(event=event, when=self.default_time)
        response = self.client.get(
            "/api/v1/eventLogs?excluded_events=TEST_EVENT_1,TEST_EVENT_2",
            environ_overrides={"REMOTE_USER": "test_granular"},
        )
        assert response.status_code == 200
        payload = response.json
        assert len(payload["event_logs"]) == 1
        assert payload["total_entries"] == 1
        assert {entry["event"] for entry in payload["event_logs"]} == {"cli_scheduler"}
from __future__ import annotations

import pytest
from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user
from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS, ParseImportError
from tests_common.test_utils.db import clear_db_dags, clear_db_import_errors
from tests_common.test_utils.permissions import _resource_name

from airflow.models.dag import DagModel
from airflow.security import permissions
from airflow.utils import timezone

pytestmark = [
    pytest.mark.db_test,
    pytest.mark.skip_if_database_isolation_mode,
    pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"),
]

TEST_DAG_IDS = ["test_dag", "test_dag2"]


@pytest.fixture(scope="module")
def configured_app(minimal_app_for_auth_api):
    """App plus a user who may read import errors but only the first test DAG."""
    app = minimal_app_for_auth_api
    create_user(
        app,
        username="test_single_dag",
        role_name="TestSingleDAG",
        permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)],
    )
    # For some reason, DAG level permissions are not synced when in the above list of perms,
    # so do it manually here:
    app.appbuilder.sm.bulk_sync_roles(
        [
            {
                "role": "TestSingleDAG",
                "perms": [
                    (
                        permissions.ACTION_CAN_READ,
                        _resource_name(TEST_DAG_IDS[0], permissions.RESOURCE_DAG),
                    )
                ],
            }
        ]
    )

    yield app

    delete_user(app, username="test_single_dag")


class TestBaseImportError:
    """Shared client setup, cleanup, and normalization helper for importErrors tests."""

    timestamp = "2020-06-10T12:00"

    @pytest.fixture(autouse=True)
    def setup_attrs(self, configured_app) -> None:
        self.app = configured_app
        self.client = self.app.test_client()  # type:ignore

        clear_db_import_errors()
        clear_db_dags()

    def teardown_method(self) -> None:
        clear_db_import_errors()
        clear_db_dags()

    @staticmethod
    def _normalize_import_errors(import_errors):
        # Replace DB-generated ids with a stable 1-based sequence for comparison.
        for i, import_error in enumerate(import_errors, 1):
            import_error["import_error_id"] = i


class TestGetImportErrorEndpoint(TestBaseImportError):
    def test_should_raise_403_forbidden_without_dag_read(self, session):
        # An import error whose file matches no readable DAG is invisible.
        parse_error = ParseImportError(
            filename="Lorem_ipsum.py",
            stacktrace="Lorem ipsum",
            timestamp=timezone.parse(self.timestamp, timezone="UTC"),
        )
        session.add(parse_error)
        session.commit()

        response = self.client.get(
            f"/api/v1/importErrors/{parse_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"}
        )

        assert response.status_code == 403

    def test_should_return_200_with_single_dag_read(self, session):
        # The file holds exactly one DAG, and the user can read it.
        session.add(DagModel(dag_id=TEST_DAG_IDS[0], fileloc="Lorem_ipsum.py"))
        parse_error = ParseImportError(
            filename="Lorem_ipsum.py",
            stacktrace="Lorem ipsum",
            timestamp=timezone.parse(self.timestamp, timezone="UTC"),
        )
        session.add(parse_error)
        session.commit()

        response = self.client.get(
            f"/api/v1/importErrors/{parse_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"}
        )

        assert response.status_code == 200
        response_data = response.json
        response_data["import_error_id"] = 1
        assert response_data == {
            "filename": "Lorem_ipsum.py",
            "import_error_id": 1,
            "stack_trace": "Lorem ipsum",
            "timestamp": "2020-06-10T12:00:00+00:00",
        }

    def test_should_return_200_redacted_with_single_dag_read_in_dagfile(self, session):
        # The file holds two DAGs but the user can read only one, so the
        # stack trace is redacted.
        for dag_id in TEST_DAG_IDS:
            session.add(DagModel(dag_id=dag_id, fileloc="Lorem_ipsum.py"))
        parse_error = ParseImportError(
            filename="Lorem_ipsum.py",
            stacktrace="Lorem ipsum",
            timestamp=timezone.parse(self.timestamp, timezone="UTC"),
        )
        session.add(parse_error)
        session.commit()

        response = self.client.get(
            f"/api/v1/importErrors/{parse_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"}
        )

        assert response.status_code == 200
        response_data = response.json
        response_data["import_error_id"] = 1
        assert response_data == {
            "filename": "Lorem_ipsum.py",
            "import_error_id": 1,
            "stack_trace": "REDACTED - you do not have read permission on all DAGs in the file",
            "timestamp": "2020-06-10T12:00:00+00:00",
        }


class TestGetImportErrorsEndpoint(TestBaseImportError):
    def test_get_import_errors_single_dag(self, session):
        # One import error per DAG file; only the readable DAG's error is listed.
        for dag_id in TEST_DAG_IDS:
            fake_filename = f"/tmp/{dag_id}.py"
            session.add(DagModel(dag_id=dag_id, fileloc=fake_filename))
            session.add(
                ParseImportError(
                    filename=fake_filename,
                    stacktrace="Lorem ipsum",
                    timestamp=timezone.parse(self.timestamp, timezone="UTC"),
                )
            )
        session.commit()

        response = self.client.get(
            "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"}
        )

        assert response.status_code == 200
        response_data = response.json
        self._normalize_import_errors(response_data["import_errors"])
        assert response_data == {
            "import_errors": [
                {
                    "filename": "/tmp/test_dag.py",
                    "import_error_id": 1,
                    "stack_trace": "Lorem ipsum",
                    "timestamp": "2020-06-10T12:00:00+00:00",
                },
            ],
            "total_entries": 1,
        }

    def test_get_import_errors_single_dag_in_dagfile(self, session):
        # Both DAGs share one file; with partial read access the single
        # listed error has its stack trace redacted.
        for dag_id in TEST_DAG_IDS:
            session.add(DagModel(dag_id=dag_id, fileloc="/tmp/all_in_one.py"))

        session.add(
            ParseImportError(
                filename="/tmp/all_in_one.py",
                stacktrace="Lorem ipsum",
                timestamp=timezone.parse(self.timestamp, timezone="UTC"),
            )
        )
        session.commit()

        response = self.client.get(
            "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"}
        )

        assert response.status_code == 200
        response_data = response.json
        self._normalize_import_errors(response_data["import_errors"])
        assert response_data == {
            "import_errors": [
                {
                    "filename": "/tmp/all_in_one.py",
                    "import_error_id": 1,
                    "stack_trace": "REDACTED - you do not have read permission on all DAGs in the file",
                    "timestamp": "2020-06-10T12:00:00+00:00",
                },
            ],
            "total_entries": 1,
        }
--git a/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py index a91a434412d9f..da3e9a06565d7 100644 --- a/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py +++ b/tests/providers/fab/auth_manager/api_endpoints/test_role_and_permission_endpoint.py @@ -17,24 +17,22 @@ from __future__ import annotations import pytest +from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import ( + create_role, + create_user, + delete_role, + delete_user, +) +from tests_common.test_utils.api_connexion_utils import assert_401 +from tests_common.test_utils.compat import ignore_provider_compatibility_error from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP -from tests.test_utils.compat import ignore_provider_compatibility_error +from airflow.security import permissions with ignore_provider_compatibility_error("2.9.0+", __file__): from airflow.providers.fab.auth_manager.models import Role from airflow.providers.fab.auth_manager.security_manager.override import EXISTING_ROLES - -from airflow.security import permissions -from tests.test_utils.api_connexion_utils import ( - assert_401, - create_role, - create_user, - delete_role, - delete_user, -) - pytestmark = pytest.mark.db_test @@ -42,7 +40,7 @@ def configured_app(minimal_app_for_auth_api): app = minimal_app_for_auth_api create_user( - app, # type: ignore + app, username="test", role_name="Test", permissions=[ @@ -53,11 +51,11 @@ def configured_app(minimal_app_for_auth_api): (permissions.ACTION_CAN_READ, permissions.RESOURCE_ACTION), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name="TestNoPermissions") yield app - delete_user(app, username="test") # type: ignore - delete_user(app, username="test_no_permissions") # type: ignore + 
delete_user(app, username="test") + delete_user(app, username="test_no_permissions") class TestRoleEndpoint: @@ -108,13 +106,13 @@ def test_should_raise_403_forbidden(self): assert response.status_code == 403 @pytest.mark.parametrize( - "set_auto_role_public, expected_status_code", + "set_auth_role_public, expected_status_code", (("Public", 403), ("Admin", 200)), - indirect=["set_auto_role_public"], + indirect=["set_auth_role_public"], ) - def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code): + def test_with_auth_role_public_set(self, set_auth_role_public, expected_status_code): response = self.client.get("/auth/fab/v1/roles/Admin") - assert response.status_code == expected_status_code + assert response.status_code == expected_status_code, response.json class TestGetRolesEndpoint(TestRoleEndpoint): @@ -146,13 +144,13 @@ def test_should_raise_403_forbidden(self): assert response.status_code == 403 @pytest.mark.parametrize( - "set_auto_role_public, expected_status_code", + "set_auth_role_public, expected_status_code", (("Public", 403), ("Admin", 200)), - indirect=["set_auto_role_public"], + indirect=["set_auth_role_public"], ) - def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code): + def test_with_auth_role_public_set(self, set_auth_role_public, expected_status_code): response = self.client.get("/auth/fab/v1/roles") - assert response.status_code == expected_status_code + assert response.status_code == expected_status_code, response.json class TestGetRolesEndpointPaginationandFilter(TestRoleEndpoint): @@ -208,13 +206,13 @@ def test_should_raise_403_forbidden(self): assert response.status_code == 403 @pytest.mark.parametrize( - "set_auto_role_public, expected_status_code", + "set_auth_role_public, expected_status_code", (("Public", 403), ("Admin", 200)), - indirect=["set_auto_role_public"], + indirect=["set_auth_role_public"], ) - def test_with_auth_role_public_set(self, set_auto_role_public, 
expected_status_code): + def test_with_auth_role_public_set(self, set_auth_role_public, expected_status_code): response = self.client.get("/auth/fab/v1/permissions") - assert response.status_code == expected_status_code + assert response.status_code == expected_status_code, response.json class TestPostRole(TestRoleEndpoint): @@ -346,17 +344,17 @@ def test_should_raise_403_forbidden(self): assert response.status_code == 403 @pytest.mark.parametrize( - "set_auto_role_public, expected_status_code", + "set_auth_role_public, expected_status_code", (("Public", 403), ("Admin", 200)), - indirect=["set_auto_role_public"], + indirect=["set_auth_role_public"], ) - def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code): + def test_with_auth_role_public_set(self, set_auth_role_public, expected_status_code): payload = { "name": "Test2", "actions": [{"resource": {"name": "Connections"}, "action": {"name": "can_create"}}], } response = self.client.post("/auth/fab/v1/roles", json=payload) - assert response.status_code == expected_status_code + assert response.status_code == expected_status_code, response.json class TestDeleteRole(TestRoleEndpoint): @@ -393,14 +391,14 @@ def test_should_raise_403_forbidden(self): assert response.status_code == 403 @pytest.mark.parametrize( - "set_auto_role_public, expected_status_code", + "set_auth_role_public, expected_status_code", (("Public", 403), ("Admin", 204)), - indirect=["set_auto_role_public"], + indirect=["set_auth_role_public"], ) - def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code): + def test_with_auth_role_public_set(self, set_auth_role_public, expected_status_code): role = create_role(self.app, "mytestrole") response = self.client.delete(f"/auth/fab/v1/roles/{role.name}") - assert response.status_code == expected_status_code + assert response.status_code == expected_status_code, response.location class TestPatchRole(TestRoleEndpoint): @@ -579,14 +577,14 @@ def 
test_should_raise_403_forbidden(self): assert response.status_code == 403 @pytest.mark.parametrize( - "set_auto_role_public, expected_status_code", + "set_auth_role_public, expected_status_code", (("Public", 403), ("Admin", 200)), - indirect=["set_auto_role_public"], + indirect=["set_auth_role_public"], ) - def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code): + def test_with_auth_role_public_set(self, set_auth_role_public, expected_status_code): role = create_role(self.app, "mytestrole") response = self.client.patch( f"/auth/fab/v1/roles/{role.name}", json={"name": "mytest"}, ) - assert response.status_code == expected_status_code + assert response.status_code == expected_status_code, response.json diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py new file mode 100644 index 0000000000000..aaae228998e6e --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py @@ -0,0 +1,426 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import datetime as dt +import urllib + +import pytest +from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import ( + create_user, + delete_roles, + delete_user, +) +from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.db import clear_db_runs, clear_rendered_ti_fields + +from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP +from airflow.models import DagRun, TaskInstance +from airflow.security import permissions +from airflow.utils.session import provide_session +from airflow.utils.state import State +from airflow.utils.timezone import datetime +from airflow.utils.types import DagRunType + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + +DEFAULT_DATETIME_1 = datetime(2020, 1, 1) +DEFAULT_DATETIME_STR_1 = "2020-01-01T00:00:00+00:00" +DEFAULT_DATETIME_STR_2 = "2020-01-02T00:00:00+00:00" + +QUOTED_DEFAULT_DATETIME_STR_1 = urllib.parse.quote(DEFAULT_DATETIME_STR_1) +QUOTED_DEFAULT_DATETIME_STR_2 = urllib.parse.quote(DEFAULT_DATETIME_STR_2) + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + create_user( + app, + username="test_dag_read_only", + role_name="TestDagReadOnly", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), + ], + ) + create_user( + app, + username="test_task_read_only", + role_name="TestTaskReadOnly", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_READ, 
permissions.RESOURCE_TASK_INSTANCE), + ], + ) + create_user( + app, + username="test_read_only_one_dag", + role_name="TestReadOnlyOneDag", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), + (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), + ], + ) + # For some reason, "DAG:example_python_operator" is not synced when in the above list of perms, + # so do it manually here: + app.appbuilder.sm.bulk_sync_roles( + [ + { + "role": "TestReadOnlyOneDag", + "perms": [(permissions.ACTION_CAN_READ, "DAG:example_python_operator")], + } + ] + ) + + yield app + + delete_user(app, username="test_dag_read_only") + delete_user(app, username="test_task_read_only") + delete_user(app, username="test_read_only_one_dag") + delete_roles(app) + + +class TestTaskInstanceEndpoint: + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app, dagbag) -> None: + self.default_time = DEFAULT_DATETIME_1 + self.ti_init = { + "execution_date": self.default_time, + "state": State.RUNNING, + } + self.ti_extras = { + "start_date": self.default_time + dt.timedelta(days=1), + "end_date": self.default_time + dt.timedelta(days=2), + "pid": 100, + "duration": 10000, + "pool": "default_pool", + "queue": "default_queue", + "job_id": 0, + } + self.app = configured_app + self.client = self.app.test_client() # type:ignore + clear_db_runs() + clear_rendered_ti_fields() + self.dagbag = dagbag + + def create_task_instances( + self, + session, + dag_id: str = "example_python_operator", + update_extras: bool = True, + task_instances=None, + dag_run_state=State.RUNNING, + with_ti_history=False, + ): + """Method to create task instances using kwargs and default arguments""" + + dag = self.dagbag.get_dag(dag_id) + tasks = dag.tasks + counter = len(tasks) + if task_instances is not None: + counter = min(len(task_instances), counter) + + run_id = "TEST_DAG_RUN_ID" + execution_date = self.ti_init.pop("execution_date", self.default_time) + dr = None + + tis = [] + 
for i in range(counter): + if task_instances is None: + pass + elif update_extras: + self.ti_extras.update(task_instances[i]) + else: + self.ti_init.update(task_instances[i]) + + if "execution_date" in self.ti_init: + run_id = f"TEST_DAG_RUN_ID_{i}" + execution_date = self.ti_init.pop("execution_date") + dr = None + + if not dr: + dr = DagRun( + run_id=run_id, + dag_id=dag_id, + execution_date=execution_date, + run_type=DagRunType.MANUAL, + state=dag_run_state, + ) + session.add(dr) + ti = TaskInstance(task=tasks[i], **self.ti_init) + session.add(ti) + ti.dag_run = dr + ti.note = "placeholder-note" + + for key, value in self.ti_extras.items(): + setattr(ti, key, value) + tis.append(ti) + + session.commit() + if with_ti_history: + for ti in tis: + ti.try_number = 1 + session.merge(ti) + session.commit() + dag.clear() + for ti in tis: + ti.try_number = 2 + ti.queue = "default_queue" + session.merge(ti) + session.commit() + return tis + + +class TestGetTaskInstance(TestTaskInstanceEndpoint): + def setup_method(self): + clear_db_runs() + + def teardown_method(self): + clear_db_runs() + + @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"]) + @provide_session + def test_should_respond_200(self, username, session): + self.create_task_instances(session) + # Update ti and set operator to None to + # test that operator field is nullable. 
+ # This prevents issue when users upgrade to 2.0+ + # from 1.10.x + # https://github.com/apache/airflow/issues/14421 + session.query(TaskInstance).update({TaskInstance.operator: None}, synchronize_session="fetch") + session.commit() + response = self.client.get( + "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", + environ_overrides={"REMOTE_USER": username}, + ) + assert response.status_code == 200 + + +class TestGetTaskInstances(TestTaskInstanceEndpoint): + @pytest.mark.parametrize( + "task_instances, user, expected_ti", + [ + pytest.param( + { + "example_python_operator": 2, + "example_skip_dag": 1, + }, + "test_read_only_one_dag", + 2, + ), + pytest.param( + { + "example_python_operator": 1, + "example_skip_dag": 2, + }, + "test_read_only_one_dag", + 1, + ), + ], + ) + def test_return_TI_only_from_readable_dags(self, task_instances, user, expected_ti, session): + for dag_id in task_instances: + self.create_task_instances( + session, + task_instances=[ + {"execution_date": DEFAULT_DATETIME_1 + dt.timedelta(days=i)} + for i in range(task_instances[dag_id]) + ], + dag_id=dag_id, + ) + response = self.client.get( + "/api/v1/dags/~/dagRuns/~/taskInstances", environ_overrides={"REMOTE_USER": user} + ) + assert response.status_code == 200 + assert response.json["total_entries"] == expected_ti + assert len(response.json["task_instances"]) == expected_ti + + +class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint): + @pytest.mark.parametrize( + "task_instances, update_extras, payload, expected_ti_count, username", + [ + pytest.param( + [ + {"pool": "test_pool_1"}, + {"pool": "test_pool_2"}, + {"pool": "test_pool_3"}, + ], + True, + {"pool": ["test_pool_1", "test_pool_2"]}, + 2, + "test_dag_read_only", + id="test pool filter", + ), + pytest.param( + [ + {"state": State.RUNNING}, + {"state": State.QUEUED}, + {"state": State.SUCCESS}, + {"state": State.NONE}, + ], + False, + {"state": ["running", "queued", "none"]}, + 3, + 
"test_task_read_only", + id="test state filter", + ), + pytest.param( + [ + {"state": State.NONE}, + {"state": State.NONE}, + {"state": State.NONE}, + {"state": State.NONE}, + ], + False, + {}, + 4, + "test_task_read_only", + id="test dag with null states", + ), + pytest.param( + [ + {"end_date": DEFAULT_DATETIME_1}, + {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)}, + {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)}, + ], + True, + { + "end_date_gte": DEFAULT_DATETIME_STR_1, + "end_date_lte": DEFAULT_DATETIME_STR_2, + }, + 2, + "test_task_read_only", + id="test end date filter", + ), + pytest.param( + [ + {"start_date": DEFAULT_DATETIME_1}, + {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)}, + {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)}, + ], + True, + { + "start_date_gte": DEFAULT_DATETIME_STR_1, + "start_date_lte": DEFAULT_DATETIME_STR_2, + }, + 2, + "test_dag_read_only", + id="test start date filter", + ), + ], + ) + def test_should_respond_200( + self, task_instances, update_extras, payload, expected_ti_count, username, session + ): + self.create_task_instances( + session, + update_extras=update_extras, + task_instances=task_instances, + ) + response = self.client.post( + "/api/v1/dags/~/dagRuns/~/taskInstances/list", + environ_overrides={"REMOTE_USER": username}, + json=payload, + ) + assert response.status_code == 200, response.json + assert expected_ti_count == response.json["total_entries"] + assert expected_ti_count == len(response.json["task_instances"]) + + def test_returns_403_forbidden_when_user_has_access_to_only_some_dags(self, session): + self.create_task_instances(session=session) + self.create_task_instances(session=session, dag_id="example_skip_dag") + payload = {"dag_ids": ["example_python_operator", "example_skip_dag"]} + + response = self.client.post( + "/api/v1/dags/~/dagRuns/~/taskInstances/list", + environ_overrides={"REMOTE_USER": "test_read_only_one_dag"}, + json=payload, + ) + assert 
response.status_code == 403 + assert response.json == { + "detail": "User not allowed to access some of these DAGs: ['example_python_operator', 'example_skip_dag']", + "status": 403, + "title": "Forbidden", + "type": EXCEPTIONS_LINK_MAP[403], + } + + +class TestPostSetTaskInstanceState(TestTaskInstanceEndpoint): + @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"]) + def test_should_raise_403_forbidden(self, username): + response = self.client.post( + "/api/v1/dags/example_python_operator/updateTaskInstancesState", + environ_overrides={"REMOTE_USER": username}, + json={ + "dry_run": True, + "task_id": "print_the_context", + "execution_date": DEFAULT_DATETIME_1.isoformat(), + "include_upstream": True, + "include_downstream": True, + "include_future": True, + "include_past": True, + "new_state": "failed", + }, + ) + assert response.status_code == 403 + + +class TestPatchTaskInstance(TestTaskInstanceEndpoint): + ENDPOINT_URL = ( + "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context" + ) + + @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"]) + def test_should_raise_403_forbidden(self, username): + response = self.client.patch( + self.ENDPOINT_URL, + environ_overrides={"REMOTE_USER": username}, + json={ + "dry_run": True, + "new_state": "failed", + }, + ) + assert response.status_code == 403 + + +class TestGetTaskInstanceTry(TestTaskInstanceEndpoint): + def setup_method(self): + clear_db_runs() + + def teardown_method(self): + clear_db_runs() + + @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"]) + @provide_session + def test_should_respond_200(self, username, session): + self.create_task_instances(session, task_instances=[{"state": State.SUCCESS}], with_ti_history=True) + + response = self.client.get( + "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/tries/1", + 
environ_overrides={"REMOTE_USER": username}, + ) + assert response.status_code == 200 diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_user_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_user_endpoint.py index e83d9fcf83736..eb1226d714e61 100644 --- a/tests/providers/fab/auth_manager/api_endpoints/test_user_endpoint.py +++ b/tests/providers/fab/auth_manager/api_endpoints/test_user_endpoint.py @@ -19,19 +19,24 @@ import unittest.mock import pytest +from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import ( + create_user, + delete_role, + delete_user, +) from sqlalchemy.sql.functions import count +from tests_common.test_utils.api_connexion_utils import assert_401 +from tests_common.test_utils.compat import ignore_provider_compatibility_error +from tests_common.test_utils.config import conf_vars from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.security import permissions from airflow.utils import timezone from airflow.utils.session import create_session -from tests.test_utils.compat import ignore_provider_compatibility_error with ignore_provider_compatibility_error("2.9.0+", __file__): from airflow.providers.fab.auth_manager.models import User -from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_role, delete_user -from tests.test_utils.config import conf_vars pytestmark = pytest.mark.db_test @@ -43,7 +48,7 @@ def configured_app(minimal_app_for_auth_api): app = minimal_app_for_auth_api create_user( - app, # type: ignore + app, username="test", role_name="Test", permissions=[ @@ -53,12 +58,12 @@ def configured_app(minimal_app_for_auth_api): (permissions.ACTION_CAN_READ, permissions.RESOURCE_USER), ], ) - create_user(app, username="test_no_permissions", role_name="TestNoPermissions") # type: ignore + create_user(app, username="test_no_permissions", role_name="TestNoPermissions") yield app - delete_user(app, username="test") # type: ignore - delete_user(app, 
username="test_no_permissions") # type: ignore + delete_user(app, username="test") + delete_user(app, username="test_no_permissions") delete_role(app, name="TestNoPermissions") @@ -425,6 +430,7 @@ def autoclean_admin_user(configured_app, autoclean_user_payload): class TestPostUser(TestUserEndpoint): def test_with_default_role(self, autoclean_username, autoclean_user_payload): + self.client.application.config["AUTH_USER_REGISTRATION_ROLE"] = "Public" response = self.client.post( "/auth/fab/v1/users", json=autoclean_user_payload, diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_variable_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_variable_endpoint.py new file mode 100644 index 0000000000000..920869d39e346 --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_variable_endpoint.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import pytest +from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.db import clear_db_variables + +from airflow.models import Variable +from airflow.security import permissions + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + + create_user( + app, + username="test_read_only", + role_name="TestReadOnly", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE), + ], + ) + create_user( + app, + username="test_delete_only", + role_name="TestDeleteOnly", + permissions=[ + (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_VARIABLE), + ], + ) + + yield app + + delete_user(app, username="test_read_only") + delete_user(app, username="test_delete_only") + + +class TestVariableEndpoint: + @pytest.fixture(autouse=True) + def setup_method(self, configured_app) -> None: + self.app = configured_app + self.client = self.app.test_client() # type:ignore + clear_db_variables() + + def teardown_method(self) -> None: + clear_db_variables() + + +class TestGetVariable(TestVariableEndpoint): + @pytest.mark.parametrize( + "user, expected_status_code", + [ + ("test_read_only", 200), + ("test_delete_only", 403), + ], + ) + def test_read_variable(self, user, expected_status_code): + expected_value = '{"foo": 1}' + Variable.set("TEST_VARIABLE_KEY", expected_value) + response = self.client.get( + "/api/v1/variables/TEST_VARIABLE_KEY", environ_overrides={"REMOTE_USER": user} + ) + assert response.status_code == expected_status_code + if expected_status_code == 200: + assert response.json == {"key": "TEST_VARIABLE_KEY", "value": expected_value, 
"description": None} diff --git a/tests/providers/fab/auth_manager/api_endpoints/test_xcom_endpoint.py b/tests/providers/fab/auth_manager/api_endpoints/test_xcom_endpoint.py new file mode 100644 index 0000000000000..f0ec606038e1d --- /dev/null +++ b/tests/providers/fab/auth_manager/api_endpoints/test_xcom_endpoint.py @@ -0,0 +1,230 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from datetime import timedelta + +import pytest +from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from tests_common.test_utils.compat import AIRFLOW_V_3_0_PLUS +from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_db_xcom + +from airflow.models.dag import DagModel +from airflow.models.dagrun import DagRun +from airflow.models.taskinstance import TaskInstance +from airflow.models.xcom import BaseXCom, XCom +from airflow.operators.empty import EmptyOperator +from airflow.security import permissions +from airflow.utils import timezone +from airflow.utils.session import create_session +from airflow.utils.types import DagRunType + +pytestmark = [ + pytest.mark.db_test, + pytest.mark.skip_if_database_isolation_mode, + pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), +] + + +class CustomXCom(BaseXCom): + @classmethod + def deserialize_value(cls, xcom: XCom): + return f"real deserialized {super().deserialize_value(xcom)}" + + def orm_deserialize_value(self): + return f"orm deserialized {super().orm_deserialize_value()}" + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + + create_user( + app, + username="test_granular_permissions", + role_name="TestGranularDag", + permissions=[ + (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM), + ], + ) + app.appbuilder.sm.sync_perm_for_dag( + "test-dag-id-1", + access_control={"TestGranularDag": [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ]}, + ) + + yield app + + delete_user(app, username="test_granular_permissions") + + +def _compare_xcom_collections(collection1: dict, collection_2: dict): + assert collection1.get("total_entries") == collection_2.get("total_entries") + + def sort_key(record): + return ( + record.get("dag_id"), + record.get("task_id"), + record.get("execution_date"), + 
record.get("map_index"), + record.get("key"), + ) + + assert sorted(collection1.get("xcom_entries", []), key=sort_key) == sorted( + collection_2.get("xcom_entries", []), key=sort_key + ) + + +class TestXComEndpoint: + @staticmethod + def clean_db(): + clear_db_dags() + clear_db_runs() + clear_db_xcom() + + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + """ + Setup For XCom endpoint TC + """ + self.app = configured_app + self.client = self.app.test_client() # type:ignore + # clear existing xcoms + self.clean_db() + + def teardown_method(self) -> None: + """ + Clear Hanging XComs + """ + self.clean_db() + + +class TestGetXComEntries(TestXComEndpoint): + def test_should_respond_200_with_tilde_and_granular_dag_access(self): + dag_id_1 = "test-dag-id-1" + task_id_1 = "test-task-id-1" + execution_date = "2005-04-02T00:00:00+00:00" + execution_date_parsed = timezone.parse(execution_date) + dag_run_id_1 = DagRun.generate_run_id(DagRunType.MANUAL, execution_date_parsed) + self._create_xcom_entries(dag_id_1, dag_run_id_1, execution_date_parsed, task_id_1) + + dag_id_2 = "test-dag-id-2" + task_id_2 = "test-task-id-2" + run_id_2 = DagRun.generate_run_id(DagRunType.MANUAL, execution_date_parsed) + self._create_xcom_entries(dag_id_2, run_id_2, execution_date_parsed, task_id_2) + self._create_invalid_xcom_entries(execution_date_parsed) + response = self.client.get( + "/api/v1/dags/~/dagRuns/~/taskInstances/~/xcomEntries", + environ_overrides={"REMOTE_USER": "test_granular_permissions"}, + ) + + assert 200 == response.status_code + response_data = response.json + for xcom_entry in response_data["xcom_entries"]: + xcom_entry["timestamp"] = "TIMESTAMP" + _compare_xcom_collections( + response_data, + { + "xcom_entries": [ + { + "dag_id": dag_id_1, + "execution_date": execution_date, + "key": "test-xcom-key-1", + "task_id": task_id_1, + "timestamp": "TIMESTAMP", + "map_index": -1, + }, + { + "dag_id": dag_id_1, + "execution_date": execution_date, + 
"key": "test-xcom-key-2", + "task_id": task_id_1, + "timestamp": "TIMESTAMP", + "map_index": -1, + }, + ], + "total_entries": 2, + }, + ) + + def _create_xcom_entries(self, dag_id, run_id, execution_date, task_id, mapped_ti=False): + with create_session() as session: + dag = DagModel(dag_id=dag_id) + session.add(dag) + dagrun = DagRun( + dag_id=dag_id, + run_id=run_id, + execution_date=execution_date, + start_date=execution_date, + run_type=DagRunType.MANUAL, + ) + session.add(dagrun) + if mapped_ti: + for i in [0, 1]: + ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id, map_index=i) + ti.dag_id = dag_id + session.add(ti) + else: + ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id) + ti.dag_id = dag_id + session.add(ti) + + for i in [1, 2]: + if mapped_ti: + key = "test-xcom-key" + map_index = i - 1 + else: + key = f"test-xcom-key-{i}" + map_index = -1 + + XCom.set( + key=key, value="TEST", run_id=run_id, task_id=task_id, dag_id=dag_id, map_index=map_index + ) + + def _create_invalid_xcom_entries(self, execution_date): + """ + Invalid XCom entries to test join query + """ + with create_session() as session: + dag = DagModel(dag_id="invalid_dag") + session.add(dag) + dagrun = DagRun( + dag_id="invalid_dag", + run_id="invalid_run_id", + execution_date=execution_date + timedelta(days=1), + start_date=execution_date, + run_type=DagRunType.MANUAL, + ) + session.add(dagrun) + dagrun1 = DagRun( + dag_id="invalid_dag", + run_id="not_this_run_id", + execution_date=execution_date, + start_date=execution_date, + run_type=DagRunType.MANUAL, + ) + session.add(dagrun1) + ti = TaskInstance(EmptyOperator(task_id="invalid_task"), run_id="not_this_run_id") + ti.dag_id = "invalid_dag" + session.add(ti) + for i in [1, 2]: + XCom.set( + key=f"invalid-xcom-key-{i}", + value="TEST", + run_id="not_this_run_id", + task_id="invalid_task", + dag_id="invalid_dag", + ) diff --git a/tests/providers/fab/auth_manager/cli_commands/test_db_command.py 
b/tests/providers/fab/auth_manager/cli_commands/test_db_command.py new file mode 100644 index 0000000000000..030b251a55ad6 --- /dev/null +++ b/tests/providers/fab/auth_manager/cli_commands/test_db_command.py @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.cli import cli_parser + +pytestmark = [pytest.mark.db_test] +try: + from airflow.providers.fab.auth_manager.cli_commands import db_command + from airflow.providers.fab.auth_manager.models.db import FABDBManager + + class TestFABCLiDB: + @classmethod + def setup_class(cls): + cls.parser = cli_parser.get_parser() + + @mock.patch.object(FABDBManager, "resetdb") + def test_cli_resetdb(self, mock_resetdb): + db_command.resetdb(self.parser.parse_args(["fab-db", "reset", "--yes"])) + + mock_resetdb.assert_called_once_with(skip_init=False) + + @mock.patch.object(FABDBManager, "resetdb") + def test_cli_resetdb_skip_init(self, mock_resetdb): + db_command.resetdb(self.parser.parse_args(["fab-db", "reset", "--yes", "--skip-init"])) + mock_resetdb.assert_called_once_with(skip_init=True) + + @pytest.mark.parametrize( + "args, called_with", + [ + ( + [], + dict( + to_revision=None, + from_revision=None, + show_sql_only=False, + ), + ), + ( + ["--show-sql-only"], + dict( + to_revision=None, + from_revision=None, + show_sql_only=True, + ), + ), + ( + ["--to-revision", "abc"], + dict( + to_revision="abc", + from_revision=None, + show_sql_only=False, + ), + ), + ( + ["--to-revision", "abc", "--show-sql-only"], + dict(to_revision="abc", from_revision=None, show_sql_only=True), + ), + ( + ["--to-revision", "abc", "--from-revision", "abc123", "--show-sql-only"], + dict( + to_revision="abc", + from_revision="abc123", + show_sql_only=True, + ), + ), + ], + ) + @mock.patch.object(FABDBManager, "upgradedb") + def test_cli_upgrade_success(self, mock_upgradedb, args, called_with): + db_command.migratedb(self.parser.parse_args(["fab-db", "migrate", *args])) + mock_upgradedb.assert_called_once_with(**called_with) + + @pytest.mark.parametrize( + "args, pattern", + [ + pytest.param( + ["--to-revision", "abc", "--to-version", "1.3.0"], + "Cannot supply both", + id="to both version and revision", + ), 
+ pytest.param( + ["--from-revision", "abc", "--from-version", "1.3.0"], + "Cannot supply both", + id="from both version and revision", + ), + pytest.param(["--to-version", "1.2.0"], "Unknown version '1.2.0'", id="unknown to version"), + pytest.param(["--to-version", "abc"], "Invalid version 'abc'", id="invalid to version"), + pytest.param( + ["--to-revision", "abc", "--from-revision", "abc123"], + "used with `--show-sql-only`", + id="requires offline", + ), + pytest.param( + ["--to-revision", "abc", "--from-version", "1.3.0"], + "used with `--show-sql-only`", + id="requires offline", + ), + pytest.param( + ["--to-revision", "abc", "--from-version", "1.1.25", "--show-sql-only"], + "Unknown version '1.1.25'", + id="unknown from version", + ), + pytest.param( + ["--to-revision", "adaf", "--from-version", "abc", "--show-sql-only"], + "Invalid version 'abc'", + id="invalid from version", + ), + ], + ) + @mock.patch.object(FABDBManager, "upgradedb") + def test_cli_migratedb_failure(self, mock_upgradedb, args, pattern): + with pytest.raises(SystemExit, match=pattern): + db_command.migratedb(self.parser.parse_args(["fab-db", "migrate", *args])) +except (ModuleNotFoundError, ImportError): + pass diff --git a/tests/providers/fab/auth_manager/cli_commands/test_definition.py b/tests/providers/fab/auth_manager/cli_commands/test_definition.py index 2db5d352ecc19..572cbee05e3db 100644 --- a/tests/providers/fab/auth_manager/cli_commands/test_definition.py +++ b/tests/providers/fab/auth_manager/cli_commands/test_definition.py @@ -16,7 +16,7 @@ # under the License. 
from __future__ import annotations -from tests.test_utils.compat import ignore_provider_compatibility_error +from tests_common.test_utils.compat import ignore_provider_compatibility_error with ignore_provider_compatibility_error("2.9.0+", __file__): from airflow.providers.fab.auth_manager.cli_commands.definition import ( diff --git a/tests/providers/fab/auth_manager/cli_commands/test_role_command.py b/tests/providers/fab/auth_manager/cli_commands/test_role_command.py index 9c9be088e87df..d07cfc61242fe 100644 --- a/tests/providers/fab/auth_manager/cli_commands/test_role_command.py +++ b/tests/providers/fab/auth_manager/cli_commands/test_role_command.py @@ -23,9 +23,10 @@ from typing import TYPE_CHECKING import pytest +from tests_common.test_utils.compat import ignore_provider_compatibility_error +from tests_common.test_utils.config import conf_vars from airflow.cli import cli_parser -from tests.test_utils.compat import ignore_provider_compatibility_error with ignore_provider_compatibility_error("2.9.0+", __file__): from airflow.providers.fab.auth_manager.cli_commands import role_command @@ -47,11 +48,12 @@ class TestCliRoles: @pytest.fixture(autouse=True) def _set_attrs(self): self.parser = cli_parser.get_parser() - with get_application_builder() as appbuilder: - self.appbuilder = appbuilder - self.clear_users_and_roles() - yield - self.clear_users_and_roles() + with conf_vars({("fab", "UPDATE_FAB_PERMS"): "False"}): + with get_application_builder() as appbuilder: + self.appbuilder = appbuilder + self.clear_users_and_roles() + yield + self.clear_users_and_roles() def clear_users_and_roles(self): session = self.appbuilder.get_session diff --git a/tests/providers/fab/auth_manager/cli_commands/test_sync_perm_command.py b/tests/providers/fab/auth_manager/cli_commands/test_sync_perm_command.py index 9e1817bd5617c..a0345909dc2df 100644 --- a/tests/providers/fab/auth_manager/cli_commands/test_sync_perm_command.py +++ 
b/tests/providers/fab/auth_manager/cli_commands/test_sync_perm_command.py @@ -20,9 +20,9 @@ from unittest import mock import pytest +from tests_common.test_utils.compat import ignore_provider_compatibility_error from airflow.cli import cli_parser -from tests.test_utils.compat import ignore_provider_compatibility_error with ignore_provider_compatibility_error("2.9.0+", __file__): from airflow.providers.fab.auth_manager.cli_commands import sync_perm_command diff --git a/tests/providers/fab/auth_manager/cli_commands/test_user_command.py b/tests/providers/fab/auth_manager/cli_commands/test_user_command.py index b8ce2f48d6c03..6ccd4c99716ab 100644 --- a/tests/providers/fab/auth_manager/cli_commands/test_user_command.py +++ b/tests/providers/fab/auth_manager/cli_commands/test_user_command.py @@ -24,9 +24,9 @@ from io import StringIO import pytest +from tests_common.test_utils.compat import ignore_provider_compatibility_error from airflow.cli import cli_parser -from tests.test_utils.compat import ignore_provider_compatibility_error with ignore_provider_compatibility_error("2.9.0+", __file__): from airflow.providers.fab.auth_manager.cli_commands import user_command diff --git a/tests/providers/fab/auth_manager/cli_commands/test_utils.py b/tests/providers/fab/auth_manager/cli_commands/test_utils.py index fd8b1dfd50c89..394beb88290ed 100644 --- a/tests/providers/fab/auth_manager/cli_commands/test_utils.py +++ b/tests/providers/fab/auth_manager/cli_commands/test_utils.py @@ -16,19 +16,68 @@ # under the License. 
from __future__ import annotations +import os + import pytest +from tests_common.test_utils.compat import ignore_provider_compatibility_error +from tests_common.test_utils.config import conf_vars -from tests.test_utils.compat import ignore_provider_compatibility_error +import airflow +from airflow.configuration import conf +from airflow.exceptions import AirflowConfigException +from airflow.www.extensions.init_appbuilder import AirflowAppBuilder +from airflow.www.session import AirflowDatabaseSessionInterface with ignore_provider_compatibility_error("2.9.0+", __file__): from airflow.providers.fab.auth_manager.cli_commands.utils import get_application_builder -from airflow.www.extensions.init_appbuilder import AirflowAppBuilder - pytestmark = pytest.mark.db_test +@pytest.fixture +def flask_app(): + """Fixture to set up the Flask app with the necessary configuration.""" + # Get the webserver config file path + webserver_config = conf.get_mandatory_value("webserver", "config_file") + + with get_application_builder() as appbuilder: + flask_app = appbuilder.app + + # Load webserver configuration + flask_app.config.from_pyfile(webserver_config, silent=True) + + yield flask_app + + class TestCliUtils: def test_get_application_builder(self): + """Test that get_application_builder returns an AirflowAppBuilder instance.""" with get_application_builder() as appbuilder: assert isinstance(appbuilder, AirflowAppBuilder) + + def test_sqlalchemy_uri_configured(self, flask_app): + """Test that the SQLALCHEMY_DATABASE_URI is correctly set in the Flask app.""" + sqlalchemy_uri = conf.get("database", "SQL_ALCHEMY_CONN") + + # Assert that the SQLAlchemy URI is correctly set + assert sqlalchemy_uri == flask_app.config["SQLALCHEMY_DATABASE_URI"] + + def test_relative_path_sqlite_raises_exception(self): + """Test that a relative SQLite path raises an AirflowConfigException.""" + # Directly simulate the configuration for relative SQLite path + with conf_vars({("database", 
"SQL_ALCHEMY_CONN"): "sqlite://relative/path"}): + with pytest.raises(AirflowConfigException, match="Cannot use relative path"): + with get_application_builder(): + pass + + def test_static_folder_exists(self, flask_app): + """Test that the static folder is correctly configured in the Flask app.""" + static_folder = os.path.join(os.path.dirname(airflow.__file__), "www", "static") + assert flask_app.static_folder == static_folder + + def test_database_auth_backend_in_session(self, flask_app): + """Test that the database is used for session management when AUTH_BACKEND is set to 'database'.""" + with get_application_builder() as appbuilder: + flask_app = appbuilder.app + # Ensure that the correct session interface is set (for 'database' auth backend) + assert isinstance(flask_app.session_interface, AirflowDatabaseSessionInterface) diff --git a/tests/providers/fab/auth_manager/conftest.py b/tests/providers/fab/auth_manager/conftest.py index 6b4feb143f4b5..d3e14c6a520cd 100644 --- a/tests/providers/fab/auth_manager/conftest.py +++ b/tests/providers/fab/auth_manager/conftest.py @@ -28,13 +28,27 @@ def minimal_app_for_auth_api(): @dont_initialize_flask_app_submodules( skip_all_except=[ "init_appbuilder", - "init_api_experimental_auth", + "init_api_auth", "init_api_auth_provider", + "init_api_connexion", "init_api_error_handlers", + "init_airflow_session_interface", + "init_appbuilder_views", ] ) def factory(): - with conf_vars({("api", "auth_backends"): "tests.test_utils.remote_user_api_auth_backend"}): + with conf_vars( + { + ( + "api", + "auth_backends", + ): "providers.tests.fab.auth_manager.api_endpoints.remote_user_api_auth_backend,airflow.providers.fab.auth_manager.api.auth.backend.session", + ( + "core", + "auth_manager", + ): "airflow.providers.fab.auth_manager.fab_auth_manager.FabAuthManager", + } + ): _app = app.create_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore _app.config["AUTH_ROLE_PUBLIC"] = None return _app @@ -43,7 +57,7 @@ def 
factory(): @pytest.fixture -def set_auto_role_public(request): +def set_auth_role_public(request): app = request.getfixturevalue("minimal_app_for_auth_api") auto_role_public = app.config["AUTH_ROLE_PUBLIC"] app.config["AUTH_ROLE_PUBLIC"] = request.param @@ -51,3 +65,11 @@ def set_auto_role_public(request): yield app.config["AUTH_ROLE_PUBLIC"] = auto_role_public + + +@pytest.fixture(scope="module") +def dagbag(): + from airflow.models import DagBag + + DagBag(include_examples=True, read_dags_from_db=False).sync_to_db() + return DagBag(include_examples=True, read_dags_from_db=True) diff --git a/tests/providers/fab/auth_manager/decorators/test_auth.py b/tests/providers/fab/auth_manager/decorators/test_auth.py index 98f77a4f34271..d6f60bf42f1fb 100644 --- a/tests/providers/fab/auth_manager/decorators/test_auth.py +++ b/tests/providers/fab/auth_manager/decorators/test_auth.py @@ -19,9 +19,9 @@ from unittest.mock import Mock, patch import pytest +from tests_common.test_utils.compat import ignore_provider_compatibility_error from airflow.security.permissions import ACTION_CAN_READ, RESOURCE_DAG -from tests.test_utils.compat import ignore_provider_compatibility_error permissions = [(ACTION_CAN_READ, RESOURCE_DAG)] diff --git a/tests/providers/fab/auth_manager/models/test_anonymous_user.py b/tests/providers/fab/auth_manager/models/test_anonymous_user.py index 4e365e3c8b705..eaf6b357f9264 100644 --- a/tests/providers/fab/auth_manager/models/test_anonymous_user.py +++ b/tests/providers/fab/auth_manager/models/test_anonymous_user.py @@ -17,7 +17,7 @@ # under the License. 
from __future__ import annotations -from tests.test_utils.compat import ignore_provider_compatibility_error +from tests_common.test_utils.compat import ignore_provider_compatibility_error with ignore_provider_compatibility_error("2.9.0+", __file__): from airflow.providers.fab.auth_manager.models.anonymous_user import AnonymousUser diff --git a/tests/providers/fab/auth_manager/models/test_db.py b/tests/providers/fab/auth_manager/models/test_db.py new file mode 100644 index 0000000000000..3af94ceed7b18 --- /dev/null +++ b/tests/providers/fab/auth_manager/models/test_db.py @@ -0,0 +1,132 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import re +from unittest import mock + +import pytest +from alembic.autogenerate import compare_metadata +from alembic.migration import MigrationContext +from sqlalchemy import MetaData + +import airflow.providers +from airflow.settings import engine +from airflow.utils.db import ( + compare_server_default, + compare_type, +) + +pytestmark = [pytest.mark.db_test] +try: + from airflow.providers.fab.auth_manager.models.db import FABDBManager + + class TestFABDBManager: + def setup_method(self): + self.providers_dir: str = airflow.providers.__path__[0] + + def test_version_table_name_set(self, session): + assert FABDBManager(session=session).version_table_name == "alembic_version_fab" + + def test_migration_dir_set(self, session): + assert FABDBManager(session=session).migration_dir == f"{self.providers_dir}/fab/migrations" + + def test_alembic_file_set(self, session): + assert FABDBManager(session=session).alembic_file == f"{self.providers_dir}/fab/alembic.ini" + + def test_supports_table_dropping_set(self, session): + assert FABDBManager(session=session).supports_table_dropping is True + + def test_database_schema_and_sqlalchemy_model_are_in_sync(self, session): + def include_object(_, name, type_, *args): + if type_ == "table" and name not in FABDBManager(session=session).metadata.tables: + return False + return True + + all_meta_data = MetaData() + for table_name, table in FABDBManager(session=session).metadata.tables.items(): + all_meta_data._add_table(table_name, table.schema, table) + # create diff between database schema and SQLAlchemy model + mctx = MigrationContext.configure( + engine.connect(), + opts={ + "compare_type": compare_type, + "compare_server_default": compare_server_default, + "include_object": include_object, + }, + ) + diff = compare_metadata(mctx, all_meta_data) + + assert not diff, "Database schema and SQLAlchemy model are not in sync: " + str(diff) + + 
@mock.patch("airflow.providers.fab.auth_manager.models.db._offline_migration") + def test_downgrade_sql_no_from(self, mock_om, session, caplog): + FABDBManager(session=session).downgrade(to_revision="abc", show_sql_only=True, from_revision=None) + actual = mock_om.call_args.kwargs["revision"] + assert re.match(r"[a-z0-9]+:abc", actual) is not None + + @mock.patch("airflow.providers.fab.auth_manager.models.db._offline_migration") + def test_downgrade_sql_with_from(self, mock_om, session): + FABDBManager(session=session).downgrade( + to_revision="abc", show_sql_only=True, from_revision="123" + ) + actual = mock_om.call_args.kwargs["revision"] + assert actual == "123:abc" + + @mock.patch("alembic.command.downgrade") + def test_downgrade_invalid_combo(self, mock_om, session): + """can't combine `sql=False` and `from_revision`""" + with pytest.raises(ValueError, match="can't be combined"): + FABDBManager(session=session).downgrade(to_revision="abc", from_revision="123") + + @mock.patch("alembic.command.downgrade") + def test_downgrade_with_from(self, mock_om, session): + FABDBManager(session=session).downgrade(to_revision="abc") + actual = mock_om.call_args.kwargs["revision"] + assert actual == "abc" + + @mock.patch.object(FABDBManager, "get_current_revision") + def test_sqlite_offline_upgrade_raises_with_revision(self, mock_gcr, session): + with mock.patch( + "airflow.providers.fab.auth_manager.models.db.settings.engine.dialect" + ) as dialect: + dialect.name = "sqlite" + with pytest.raises(SystemExit, match="Offline migration not supported for SQLite"): + FABDBManager(session).upgradedb(from_revision=None, to_revision=None, show_sql_only=True) + + @mock.patch("airflow.utils.db_manager.inspect") + @mock.patch.object(FABDBManager, "metadata") + def test_drop_tables(self, mock_metadata, mock_inspect, session): + manager = FABDBManager(session) + connection = mock.MagicMock() + manager.drop_tables(connection) + mock_metadata.drop_all.assert_called_once_with(connection) + 
+ @pytest.mark.parametrize("skip_init", [True, False]) + @mock.patch.object(FABDBManager, "drop_tables") + @mock.patch.object(FABDBManager, "initdb") + @mock.patch("airflow.utils.db.create_global_lock", new=mock.MagicMock) + def test_resetdb(self, mock_initdb, mock_drop_tables, session, skip_init): + manager = FABDBManager(session) + manager.resetdb(skip_init=skip_init) + mock_drop_tables.assert_called_once() + if skip_init: + mock_initdb.assert_not_called() + else: + mock_initdb.assert_called_once() +except ModuleNotFoundError: + pass diff --git a/tests/providers/fab/auth_manager/schemas/__init__.py b/tests/providers/fab/auth_manager/schemas/__init__.py new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/tests/providers/fab/auth_manager/schemas/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/tests/providers/fab/auth_manager/schemas/test_role_and_permission_schema.py b/tests/providers/fab/auth_manager/schemas/test_role_and_permission_schema.py new file mode 100644 index 0000000000000..f8364a2e47222 --- /dev/null +++ b/tests/providers/fab/auth_manager/schemas/test_role_and_permission_schema.py @@ -0,0 +1,106 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import pytest +from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_role, delete_role + +from airflow.providers.fab.auth_manager.schemas.role_and_permission_schema import ( + RoleCollection, + role_collection_schema, + role_schema, +) +from airflow.security import permissions + +pytestmark = [pytest.mark.db_test, pytest.mark.skip_if_database_isolation_mode] + + +class TestRoleCollectionItemSchema: + @pytest.fixture(scope="class") + def role(self, minimal_app_for_auth_api): + yield create_role( + minimal_app_for_auth_api, # type: ignore + name="Test", + permissions=[ + (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION), + ], + ) + delete_role(minimal_app_for_auth_api, "Test") + + @pytest.fixture(autouse=True) + def _set_attrs(self, minimal_app_for_auth_api, role): + self.app = minimal_app_for_auth_api + self.role = role + + def test_serialize(self): + deserialized_role = role_schema.dump(self.role) + assert deserialized_role == { + "name": "Test", + "actions": [{"resource": {"name": "Connections"}, "action": {"name": "can_create"}}], + } + + def test_deserialize(self): + role = { + "name": "Test", + "actions": [{"resource": {"name": "Connections"}, "action": {"name": "can_create"}}], + } + role_obj = role_schema.load(role) + assert role_obj == { + "name": "Test", + "permissions": [{"resource": {"name": "Connections"}, "action": {"name": "can_create"}}], + } + + +class TestRoleCollectionSchema: + @pytest.fixture(scope="class") + def role1(self, minimal_app_for_auth_api): + yield create_role( + minimal_app_for_auth_api, # type: ignore + name="Test1", + permissions=[ + (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION), + ], + ) + delete_role(minimal_app_for_auth_api, "Test1") + + @pytest.fixture(scope="class") + def role2(self, minimal_app_for_auth_api): + yield create_role( + minimal_app_for_auth_api, # type: ignore + name="Test2", + permissions=[ + 
(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), + ], + ) + delete_role(minimal_app_for_auth_api, "Test2") + + def test_serialize(self, role1, role2): + instance = RoleCollection([role1, role2], total_entries=2) + deserialized = role_collection_schema.dump(instance) + assert deserialized == { + "roles": [ + { + "name": "Test1", + "actions": [{"resource": {"name": "Connections"}, "action": {"name": "can_create"}}], + }, + { + "name": "Test2", + "actions": [{"resource": {"name": "DAGs"}, "action": {"name": "can_edit"}}], + }, + ], + "total_entries": 2, + } diff --git a/tests/providers/fab/auth_manager/schemas/test_user_schema.py b/tests/providers/fab/auth_manager/schemas/test_user_schema.py new file mode 100644 index 0000000000000..f6b07327c09ce --- /dev/null +++ b/tests/providers/fab/auth_manager/schemas/test_user_schema.py @@ -0,0 +1,147 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import pytest +from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_role, delete_role +from tests_common.test_utils.compat import ignore_provider_compatibility_error + +from airflow.utils import timezone + +with ignore_provider_compatibility_error("2.9.0+", __file__): + from airflow.providers.fab.auth_manager.models import User + from airflow.providers.fab.auth_manager.schemas.user_schema import ( + user_collection_item_schema, + user_schema, + ) + + +TEST_EMAIL = "test@example.org" + +DEFAULT_TIME = "2021-01-09T13:59:56.336000+00:00" + +pytestmark = pytest.mark.db_test + + +@pytest.fixture(scope="module") +def configured_app(minimal_app_for_auth_api): + app = minimal_app_for_auth_api + create_role( + app, + name="TestRole", + permissions=[], + ) + yield app + + delete_role(app, "TestRole") # type:ignore + + +class TestUserBase: + @pytest.fixture(autouse=True) + def setup_attrs(self, configured_app) -> None: + self.app = configured_app + self.client = self.app.test_client() # type:ignore + self.role = self.app.appbuilder.sm.find_role("TestRole") + self.session = self.app.appbuilder.get_session + + def teardown_method(self): + user = self.session.query(User).filter(User.email == TEST_EMAIL).first() + if user: + self.session.delete(user) + self.session.commit() + + +class TestUserCollectionItemSchema(TestUserBase): + def test_serialize(self): + user_model = User( + first_name="Foo", + last_name="Bar", + username="test", + password="test", + email=TEST_EMAIL, + created_on=timezone.parse(DEFAULT_TIME), + changed_on=timezone.parse(DEFAULT_TIME), + ) + self.session.add(user_model) + user_model.roles = [self.role] + self.session.commit() + user = self.session.query(User).filter(User.email == TEST_EMAIL).first() + deserialized_user = user_collection_item_schema.dump(user) + # No user_id and password in dump + assert deserialized_user == { + "created_on": DEFAULT_TIME, + "email": "test@example.org", + 
"changed_on": DEFAULT_TIME, + "active": True, + "last_login": None, + "last_name": "Bar", + "fail_login_count": None, + "first_name": "Foo", + "username": "test", + "login_count": None, + "roles": [{"name": "TestRole"}], + } + + +class TestUserSchema(TestUserBase): + def test_serialize(self): + user_model = User( + first_name="Foo", + last_name="Bar", + username="test", + password="test", + email=TEST_EMAIL, + created_on=timezone.parse(DEFAULT_TIME), + changed_on=timezone.parse(DEFAULT_TIME), + ) + self.session.add(user_model) + self.session.commit() + user = self.session.query(User).filter(User.email == TEST_EMAIL).first() + deserialized_user = user_schema.dump(user) + # No user_id and password in dump + assert deserialized_user == { + "roles": [], + "created_on": DEFAULT_TIME, + "email": "test@example.org", + "changed_on": DEFAULT_TIME, + "active": True, + "last_login": None, + "last_name": "Bar", + "fail_login_count": None, + "first_name": "Foo", + "username": "test", + "login_count": None, + } + + def test_deserialize_user(self): + user_dump = { + "roles": [{"name": "TestRole"}], + "email": "test@example.org", + "last_name": "Bar", + "first_name": "Foo", + "username": "test", + "password": "test", # loads password + } + result = user_schema.load(user_dump) + assert result == { + "roles": [{"name": "TestRole"}], + "email": "test@example.org", + "last_name": "Bar", + "first_name": "Foo", + "username": "test", + "password": "test", # Password loaded + } diff --git a/tests/providers/fab/auth_manager/security_manager/test_constants.py b/tests/providers/fab/auth_manager/security_manager/test_constants.py index 5a718eee4b639..dbe592c59d747 100644 --- a/tests/providers/fab/auth_manager/security_manager/test_constants.py +++ b/tests/providers/fab/auth_manager/security_manager/test_constants.py @@ -16,7 +16,7 @@ # under the License. 
from __future__ import annotations -from tests.test_utils.compat import ignore_provider_compatibility_error +from tests_common.test_utils.compat import ignore_provider_compatibility_error with ignore_provider_compatibility_error("2.9.0+", __file__): from airflow.providers.fab.auth_manager.security_manager.constants import EXISTING_ROLES diff --git a/tests/providers/fab/auth_manager/security_manager/test_override.py b/tests/providers/fab/auth_manager/security_manager/test_override.py index 6d85c0319dc44..6ba1ccda292cd 100644 --- a/tests/providers/fab/auth_manager/security_manager/test_override.py +++ b/tests/providers/fab/auth_manager/security_manager/test_override.py @@ -19,7 +19,7 @@ from unittest import mock from unittest.mock import Mock -from tests.test_utils.compat import ignore_provider_compatibility_error +from tests_common.test_utils.compat import ignore_provider_compatibility_error with ignore_provider_compatibility_error("2.9.0+", __file__): from airflow.providers.fab.auth_manager.security_manager.override import FabAirflowSecurityManagerOverride diff --git a/tests/providers/fab/auth_manager/test_fab_auth_manager.py b/tests/providers/fab/auth_manager/test_fab_auth_manager.py index 3b0949d551d88..d298f7667eaaf 100644 --- a/tests/providers/fab/auth_manager/test_fab_auth_manager.py +++ b/tests/providers/fab/auth_manager/test_fab_auth_manager.py @@ -16,13 +16,14 @@ # under the License. 
from __future__ import annotations +from contextlib import contextmanager from itertools import chain from typing import TYPE_CHECKING from unittest import mock from unittest.mock import Mock import pytest -from flask import Flask +from flask import Flask, g from airflow.exceptions import AirflowConfigException, AirflowException @@ -31,13 +32,14 @@ except ImportError: pass -from tests.test_utils.compat import ignore_provider_compatibility_error +from tests_common.test_utils.compat import ignore_provider_compatibility_error with ignore_provider_compatibility_error("2.9.0+", __file__): from airflow.providers.fab.auth_manager.fab_auth_manager import FabAuthManager from airflow.providers.fab.auth_manager.models import User from airflow.providers.fab.auth_manager.security_manager.override import FabAirflowSecurityManagerOverride +from airflow.providers.common.compat.security.permissions import RESOURCE_ASSET from airflow.security.permissions import ( ACTION_CAN_ACCESS_MENU, ACTION_CAN_CREATE, @@ -48,7 +50,6 @@ RESOURCE_CONNECTION, RESOURCE_DAG, RESOURCE_DAG_RUN, - RESOURCE_DATASET, RESOURCE_DOCS, RESOURCE_JOB, RESOURCE_PLUGIN, @@ -63,14 +64,22 @@ if TYPE_CHECKING: from airflow.auth.managers.base_auth_manager import ResourceMethod + IS_AUTHORIZED_METHODS_SIMPLE = { "is_authorized_configuration": RESOURCE_CONFIG, "is_authorized_connection": RESOURCE_CONNECTION, - "is_authorized_dataset": RESOURCE_DATASET, + "is_authorized_asset": RESOURCE_ASSET, "is_authorized_variable": RESOURCE_VARIABLE, } +@contextmanager +def user_set(app, user): + g.user = user + yield + g.user = None + + @pytest.fixture def auth_manager(): return FabAuthManager(None) @@ -113,20 +122,43 @@ def test_get_user_display_name( assert auth_manager.get_user_display_name() == expected @mock.patch("flask_login.utils._get_user") - def test_get_user(self, mock_current_user, auth_manager): + def test_get_user(self, mock_current_user, minimal_app_for_auth_api, auth_manager): user = Mock() 
user.is_anonymous.return_value = True mock_current_user.return_value = user + with minimal_app_for_auth_api.app_context(): + assert auth_manager.get_user() == user - assert auth_manager.get_user() == user + @mock.patch("flask_login.utils._get_user") + def test_get_user_from_flask_g(self, mock_current_user, minimal_app_for_auth_api, auth_manager): + session_user = Mock() + session_user.is_anonymous = True + mock_current_user.return_value = session_user + + flask_g_user = Mock() + flask_g_user.is_anonymous = False + with minimal_app_for_auth_api.app_context(): + with user_set(minimal_app_for_auth_api, flask_g_user): + assert auth_manager.get_user() == flask_g_user + @pytest.mark.db_test @mock.patch.object(FabAuthManager, "get_user") - def test_is_logged_in(self, mock_get_user, auth_manager): + def test_is_logged_in(self, mock_get_user, auth_manager_with_appbuilder): user = Mock() user.is_anonymous.return_value = True mock_get_user.return_value = user - assert auth_manager.is_logged_in() is False + assert auth_manager_with_appbuilder.is_logged_in() is False + + @pytest.mark.db_test + @mock.patch.object(FabAuthManager, "get_user") + def test_is_logged_in_with_inactive_user(self, mock_get_user, auth_manager_with_appbuilder): + user = Mock() + user.is_anonymous.return_value = False + user.is_active.return_value = True + mock_get_user.return_value = user + + assert auth_manager_with_appbuilder.is_logged_in() is False @pytest.mark.parametrize( "api_name, method, user_permissions, expected_result", diff --git a/tests/providers/fab/auth_manager/test_models.py b/tests/providers/fab/auth_manager/test_models.py index 30677d7095753..76b69c3dea734 100644 --- a/tests/providers/fab/auth_manager/test_models.py +++ b/tests/providers/fab/auth_manager/test_models.py @@ -19,8 +19,7 @@ from unittest import mock from sqlalchemy import Column, MetaData, String, Table - -from tests.test_utils.compat import ignore_provider_compatibility_error +from tests_common.test_utils.compat import 
ignore_provider_compatibility_error with ignore_provider_compatibility_error("2.9.0+", __file__): from airflow.providers.fab.auth_manager.models import ( diff --git a/tests/providers/fab/auth_manager/test_security.py b/tests/providers/fab/auth_manager/test_security.py index ac89018c995f2..eb99daf9b9b53 100644 --- a/tests/providers/fab/auth_manager/test_security.py +++ b/tests/providers/fab/auth_manager/test_security.py @@ -30,36 +30,36 @@ from flask_appbuilder import SQLA, Model, expose, has_access from flask_appbuilder.views import BaseView, ModelView from sqlalchemy import Column, Date, Float, Integer, String +from tests_common.test_utils.compat import ignore_provider_compatibility_error from airflow.configuration import initialize_config from airflow.exceptions import AirflowException from airflow.models import DagModel -from airflow.models.base import Base from airflow.models.dag import DAG -from tests.test_utils.compat import ignore_provider_compatibility_error with ignore_provider_compatibility_error("2.9.0+", __file__): from airflow.providers.fab.auth_manager.fab_auth_manager import FabAuthManager - from airflow.providers.fab.auth_manager.models import User, assoc_permission_role + from airflow.providers.fab.auth_manager.models import assoc_permission_role from airflow.providers.fab.auth_manager.models.anonymous_user import AnonymousUser -from airflow.security import permissions -from airflow.security.permissions import ACTION_CAN_READ -from airflow.www import app as application -from airflow.www.auth import get_access_denied_message -from airflow.www.extensions.init_auth_manager import get_auth_manager -from airflow.www.utils import CustomSQLAInterface -from tests.test_utils.api_connexion_utils import ( +from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import ( create_user, create_user_scope, delete_role, delete_user, set_user_single_role, ) -from tests.test_utils.asserts import assert_queries_count -from tests.test_utils.db import 
clear_db_dags, clear_db_runs -from tests.test_utils.mock_security_manager import MockSecurityManager -from tests.test_utils.permissions import _resource_name +from tests_common.test_utils.asserts import assert_queries_count +from tests_common.test_utils.db import clear_db_dags, clear_db_runs +from tests_common.test_utils.mock_security_manager import MockSecurityManager +from tests_common.test_utils.permissions import _resource_name + +from airflow.security import permissions +from airflow.security.permissions import ACTION_CAN_READ +from airflow.www import app as application +from airflow.www.auth import get_access_denied_message +from airflow.www.extensions.init_auth_manager import get_auth_manager +from airflow.www.utils import CustomSQLAInterface pytestmark = pytest.mark.db_test @@ -514,7 +514,10 @@ def test_get_accessible_dag_ids(mock_is_logged_in, app, security_manager, sessio ], ) as user: mock_is_logged_in.return_value = True - dag_model = DagModel(dag_id=dag_id, fileloc="/tmp/dag_.py", schedule_interval="2 2 * * *") + if hasattr(DagModel, "schedule_interval"): # Airflow 2 compat. + dag_model = DagModel(dag_id=dag_id, fileloc="/tmp/dag_.py", schedule_interval="2 2 * * *") + else: # Airflow 3. + dag_model = DagModel(dag_id=dag_id, fileloc="/tmp/dag_.py", timetable_summary="2 2 * * *") session.add(dag_model) session.commit() @@ -545,7 +548,10 @@ def test_dont_get_inaccessible_dag_ids_for_dag_resource_permission( ], ) as user: mock_is_logged_in.return_value = True - dag_model = DagModel(dag_id=dag_id, fileloc="/tmp/dag_.py", schedule_interval="2 2 * * *") + if hasattr(DagModel, "schedule_interval"): # Airflow 2 compat. + dag_model = DagModel(dag_id=dag_id, fileloc="/tmp/dag_.py", schedule_interval="2 2 * * *") + else: # Airflow 3. 
+ dag_model = DagModel(dag_id=dag_id, fileloc="/tmp/dag_.py", timetable_summary="2 2 * * *") session.add(dag_model) session.commit() @@ -851,11 +857,22 @@ def test_access_control_is_set_on_init( ) +@pytest.mark.parametrize( + "access_control_before, access_control_after", + [ + (READ_WRITE, READ_ONLY), + # old access control format + ({permissions.ACTION_CAN_READ, permissions.ACTION_CAN_EDIT}, {permissions.ACTION_CAN_READ}), + ], + ids=["new_access_control_format", "old_access_control_format"], +) def test_access_control_stale_perms_are_revoked( app, security_manager, assert_user_has_dag_perms, assert_user_does_not_have_dag_perms, + access_control_before, + access_control_after, ): username = "access_control_stale_perms_are_revoked" role_name = "team-a" @@ -868,12 +885,12 @@ def test_access_control_stale_perms_are_revoked( ) as user: set_user_single_role(app, user, role_name="team-a") security_manager._sync_dag_view_permissions( - "access_control_test", access_control={"team-a": READ_WRITE} + "access_control_test", access_control={"team-a": access_control_before} ) assert_user_has_dag_perms(perms=["GET", "PUT"], dag_id="access_control_test", user=user) security_manager._sync_dag_view_permissions( - "access_control_test", access_control={"team-a": READ_ONLY} + "access_control_test", access_control={"team-a": access_control_after} ) # Clear the cache, to make it pick up new rol perms user._perms = None @@ -927,7 +944,7 @@ def test_create_dag_specific_permissions(session, security_manager, monkeypatch, dagbag_mock.collect_dags_from_db = collect_dags_from_db_mock dagbag_class_mock = mock.Mock() dagbag_class_mock.return_value = dagbag_mock - import airflow.www.security + import airflow.providers.fab.auth_manager.security_manager monkeypatch.setitem( airflow.providers.fab.auth_manager.security_manager.override.__dict__, "DagBag", dagbag_class_mock @@ -1008,46 +1025,6 @@ def test_prefixed_dag_id_is_deprecated(security_manager): security_manager.prefixed_dag_id("hello") 
-def test_parent_dag_access_applies_to_subdag(app, security_manager, assert_user_has_dag_perms, session): - username = "dag_permission_user" - role_name = "dag_permission_role" - parent_dag_name = "parent_dag" - subdag_name = parent_dag_name + ".subdag" - subsubdag_name = parent_dag_name + ".subdag.subsubdag" - with app.app_context(): - mock_roles = [ - { - "role": role_name, - "perms": [ - (permissions.ACTION_CAN_READ, f"DAG:{parent_dag_name}"), - (permissions.ACTION_CAN_EDIT, f"DAG:{parent_dag_name}"), - ], - } - ] - with create_user_scope( - app, - username=username, - role_name=role_name, - ) as user: - dag1 = DagModel(dag_id=parent_dag_name) - dag2 = DagModel(dag_id=subdag_name, is_subdag=True, root_dag_id=parent_dag_name) - dag3 = DagModel(dag_id=subsubdag_name, is_subdag=True, root_dag_id=parent_dag_name) - session.add_all([dag1, dag2, dag3]) - session.commit() - security_manager.bulk_sync_roles(mock_roles) - for _ in [dag1, dag2, dag3]: - security_manager._sync_dag_view_permissions( - parent_dag_name, access_control={role_name: READ_WRITE} - ) - - assert_user_has_dag_perms(perms=["GET", "PUT"], dag_id=parent_dag_name, user=user) - assert_user_has_dag_perms(perms=["GET", "PUT"], dag_id=parent_dag_name + ".subdag", user=user) - assert_user_has_dag_perms( - perms=["GET", "PUT"], dag_id=parent_dag_name + ".subdag.subsubdag", user=user - ) - session.query(DagModel).delete() - - def test_permissions_work_for_dags_with_dot_in_dagname( app, security_manager, assert_user_has_dag_perms, assert_user_does_not_have_dag_perms, session ): @@ -1082,12 +1059,6 @@ def test_permissions_work_for_dags_with_dot_in_dagname( session.query(DagModel).delete() -def test_fab_models_use_airflow_base_meta(): - # TODO: move this test to appropriate place when we have more tests for FAB models - user = User() - assert user.metadata is Base.metadata - - @pytest.fixture def mock_security_manager(app_builder): mocked_security_manager = MockSecurityManager(appbuilder=app_builder) diff --git 
a/tests/providers/fab/auth_manager/views/__init__.py b/tests/providers/fab/auth_manager/views/__init__.py index 217e5db960782..a1e80d332f9bb 100644 --- a/tests/providers/fab/auth_manager/views/__init__.py +++ b/tests/providers/fab/auth_manager/views/__init__.py @@ -15,3 +15,20 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + +from airflow import __version__ as airflow_version +from airflow.exceptions import AirflowProviderDeprecationWarning + + +def _assert_dataset_deprecation_warning(recwarn) -> None: + if airflow_version.startswith("2"): + warning = recwarn.pop(AirflowProviderDeprecationWarning) + assert warning.category == AirflowProviderDeprecationWarning + assert ( + str(warning.message) + == "is_authorized_dataset will be renamed as is_authorized_asset in Airflow 3 and will be removed when the minimum Airflow version is set to 3.0 for the fab provider" + ) + + +__all__ = ["_assert_dataset_deprecation_warning"] diff --git a/tests/providers/fab/auth_manager/views/test_permissions.py b/tests/providers/fab/auth_manager/views/test_permissions.py index 0b1073df287fa..d80ad66e6e366 100644 --- a/tests/providers/fab/auth_manager/views/test_permissions.py +++ b/tests/providers/fab/auth_manager/views/test_permissions.py @@ -18,12 +18,13 @@ from __future__ import annotations import pytest +from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from providers.tests.fab.auth_manager.views import _assert_dataset_deprecation_warning +from tests_common.test_utils.compat import AIRFLOW_V_2_9_PLUS +from tests_common.test_utils.www import client_with_login from airflow.security import permissions from airflow.www import app as application -from tests.test_utils.api_connexion_utils import create_user, delete_user -from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS -from tests.test_utils.www import 
client_with_login pytestmark = [ pytest.mark.skipif(not AIRFLOW_V_2_9_PLUS, reason="Tests for Airflow 2.9.0+ only"), @@ -64,14 +65,20 @@ def client_permissions_reader(fab_app, user_permissions_reader): @pytest.mark.db_test class TestPermissionsView: - def test_action_model_view(self, client_permissions_reader): + def test_action_model_view(self, client_permissions_reader, recwarn): resp = client_permissions_reader.get("/actions/list/", follow_redirects=True) + + _assert_dataset_deprecation_warning(recwarn) assert resp.status_code == 200 - def test_permission_pair_model_view(self, client_permissions_reader): + def test_permission_pair_model_view(self, client_permissions_reader, recwarn): resp = client_permissions_reader.get("/permissions/list/", follow_redirects=True) + + _assert_dataset_deprecation_warning(recwarn) assert resp.status_code == 200 - def test_resource_model_view(self, client_permissions_reader): + def test_resource_model_view(self, client_permissions_reader, recwarn): resp = client_permissions_reader.get("/resources/list/", follow_redirects=True) + + _assert_dataset_deprecation_warning(recwarn) assert resp.status_code == 200 diff --git a/tests/providers/fab/auth_manager/views/test_roles_list.py b/tests/providers/fab/auth_manager/views/test_roles_list.py index 156f07df41209..79b11b55fa5b1 100644 --- a/tests/providers/fab/auth_manager/views/test_roles_list.py +++ b/tests/providers/fab/auth_manager/views/test_roles_list.py @@ -18,12 +18,13 @@ from __future__ import annotations import pytest +from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from providers.tests.fab.auth_manager.views import _assert_dataset_deprecation_warning +from tests_common.test_utils.compat import AIRFLOW_V_2_9_PLUS +from tests_common.test_utils.www import client_with_login from airflow.security import permissions from airflow.www import app as application -from tests.test_utils.api_connexion_utils import create_user, 
delete_user -from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS -from tests.test_utils.www import client_with_login pytestmark = [ pytest.mark.skipif(not AIRFLOW_V_2_9_PLUS, reason="Tests for Airflow 2.9.0+ only"), @@ -62,6 +63,8 @@ def client_roles_reader(fab_app, user_roles_reader): @pytest.mark.db_test class TestRolesListView: - def test_role_model_view(self, client_roles_reader): + def test_role_model_view(self, client_roles_reader, recwarn): resp = client_roles_reader.get("/roles/list/", follow_redirects=True) + + _assert_dataset_deprecation_warning(recwarn) assert resp.status_code == 200 diff --git a/tests/providers/fab/auth_manager/views/test_user.py b/tests/providers/fab/auth_manager/views/test_user.py index 6660ab926d886..1f33d14cac2d0 100644 --- a/tests/providers/fab/auth_manager/views/test_user.py +++ b/tests/providers/fab/auth_manager/views/test_user.py @@ -18,12 +18,13 @@ from __future__ import annotations import pytest +from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from providers.tests.fab.auth_manager.views import _assert_dataset_deprecation_warning +from tests_common.test_utils.compat import AIRFLOW_V_2_9_PLUS +from tests_common.test_utils.www import client_with_login from airflow.security import permissions from airflow.www import app as application -from tests.test_utils.api_connexion_utils import create_user, delete_user -from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS -from tests.test_utils.www import client_with_login pytestmark = [ pytest.mark.skipif(not AIRFLOW_V_2_9_PLUS, reason="Tests for Airflow 2.9.0+ only"), @@ -62,6 +63,8 @@ def client_user_reader(fab_app, user_user_reader): @pytest.mark.db_test class TestUserView: - def test_user_model_view(self, client_user_reader): + def test_user_model_view(self, client_user_reader, recwarn): resp = client_user_reader.get("/users/list/", follow_redirects=True) + + _assert_dataset_deprecation_warning(recwarn) assert 
resp.status_code == 200 diff --git a/tests/providers/fab/auth_manager/views/test_user_edit.py b/tests/providers/fab/auth_manager/views/test_user_edit.py index 65937b6f83d33..7279ed0b2c601 100644 --- a/tests/providers/fab/auth_manager/views/test_user_edit.py +++ b/tests/providers/fab/auth_manager/views/test_user_edit.py @@ -18,12 +18,13 @@ from __future__ import annotations import pytest +from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from providers.tests.fab.auth_manager.views import _assert_dataset_deprecation_warning +from tests_common.test_utils.compat import AIRFLOW_V_2_9_PLUS +from tests_common.test_utils.www import client_with_login from airflow.security import permissions from airflow.www import app as application -from tests.test_utils.api_connexion_utils import create_user, delete_user -from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS -from tests.test_utils.www import client_with_login pytestmark = [ pytest.mark.skipif(not AIRFLOW_V_2_9_PLUS, reason="Tests for Airflow 2.9.0+ only"), @@ -62,6 +63,7 @@ def client_user_reader(fab_app, user_user_reader): @pytest.mark.db_test class TestUserEditView: - def test_reset_my_password_view(self, client_user_reader): + def test_reset_my_password_view(self, client_user_reader, recwarn): resp = client_user_reader.get("/resetmypassword/form", follow_redirects=True) + _assert_dataset_deprecation_warning(recwarn) assert resp.status_code == 200 diff --git a/tests/providers/fab/auth_manager/views/test_user_stats.py b/tests/providers/fab/auth_manager/views/test_user_stats.py index 8cb260fcf1ec4..382a8f10984ac 100644 --- a/tests/providers/fab/auth_manager/views/test_user_stats.py +++ b/tests/providers/fab/auth_manager/views/test_user_stats.py @@ -18,12 +18,13 @@ from __future__ import annotations import pytest +from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user +from providers.tests.fab.auth_manager.views import 
_assert_dataset_deprecation_warning +from tests_common.test_utils.compat import AIRFLOW_V_2_9_PLUS +from tests_common.test_utils.www import client_with_login from airflow.security import permissions from airflow.www import app as application -from tests.test_utils.api_connexion_utils import create_user, delete_user -from tests.test_utils.compat import AIRFLOW_V_2_9_PLUS -from tests.test_utils.www import client_with_login pytestmark = [ pytest.mark.skipif(not AIRFLOW_V_2_9_PLUS, reason="Tests for Airflow 2.9.0+ only"), @@ -62,6 +63,7 @@ def client_user_stats_reader(fab_app, user_user_stats_reader): @pytest.mark.db_test class TestUserStats: - def test_user_stats(self, client_user_stats_reader): + def test_user_stats(self, client_user_stats_reader, recwarn): resp = client_user_stats_reader.get("/userstatschartview/chart", follow_redirects=True) + _assert_dataset_deprecation_warning(recwarn) assert resp.status_code == 200