Skip to content

Commit c440959

Browse files
authored
Disable ORM access from Tasks, DAG processing and Triggers (apache#47320)
All of these use the Workload supervisor from the TaskSDK and the main paths (XCom, Variables and Secrets) have all been ported to use the Execution API, so it's about time we disabled DB access.
1 parent e81b19e commit c440959

File tree

11 files changed

+104
-93
lines changed

11 files changed

+104
-93
lines changed

airflow/dag_processing/processor.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,8 @@ def _parse_file_entrypoint():
6363
import structlog
6464

6565
from airflow.sdk.execution_time import task_runner
66-
from airflow.settings import configure_orm
6766

6867
# Parse DAG file, send JSON back up!
69-
70-
# We need to reconfigure the orm here, as DagFileProcessorManager does db queries for bundles, and
71-
# the session across forks blows things up.
72-
configure_orm()
73-
7468
comms_decoder = task_runner.CommsDecoder[ToDagProcessor, ToManager](
7569
input=sys.stdin,
7670
decoder=TypeAdapter[ToDagProcessor](ToDagProcessor),

airflow/settings.py

Lines changed: 22 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
import pluggy
3131
from packaging.version import Version
32-
from sqlalchemy import create_engine, exc, text
32+
from sqlalchemy import create_engine
3333
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession as SAAsyncSession, create_async_engine
3434
from sqlalchemy.orm import scoped_session, sessionmaker
3535
from sqlalchemy.pool import NullPool
@@ -46,7 +46,6 @@
4646

4747
if TYPE_CHECKING:
4848
from sqlalchemy.engine import Engine
49-
from sqlalchemy.orm import Session as SASession
5049

5150
log = logging.getLogger(__name__)
5251

@@ -101,12 +100,12 @@
101100
"""
102101

103102
engine: Engine
104-
Session: Callable[..., SASession]
103+
Session: scoped_session
105104
# NonScopedSession creates global sessions and is not safe to use in multi-threaded environment without
106105
# additional precautions. The only use case is when the session lifecycle needs
107106
# custom handling. Most of the time we only want one unique thread local session object,
108107
# this is achieved by the Session factory above.
109-
NonScopedSession: Callable[..., SASession]
108+
NonScopedSession: sessionmaker
110109
async_engine: AsyncEngine
111110
AsyncSession: Callable[..., SAAsyncSession]
112111

@@ -389,6 +388,12 @@ def _session_maker(_engine):
389388
NonScopedSession = _session_maker(engine)
390389
Session = scoped_session(NonScopedSession)
391390

391+
from sqlalchemy.orm.session import close_all_sessions
392+
393+
os.register_at_fork(after_in_child=close_all_sessions)
394+
# https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork
395+
os.register_at_fork(after_in_child=lambda: engine.dispose(close=False))
396+
392397

393398
DEFAULT_ENGINE_ARGS = {
394399
"postgresql": {
@@ -479,14 +484,23 @@ def prepare_engine_args(disable_connection_pool=False, pool_class=None):
479484

480485
def dispose_orm():
481486
"""Properly close pooled database connections."""
487+
global Session, engine, NonScopedSession
488+
489+
_globals = globals()
490+
if "engine" not in _globals and "Session" not in _globals:
491+
return
492+
482493
log.debug("Disposing DB connection pool (PID %s)", os.getpid())
483-
global engine
484-
global Session
485494

486-
if Session is not None: # type: ignore[truthy-function]
495+
if "Session" in _globals and Session is not None:
496+
from sqlalchemy.orm.session import close_all_sessions
497+
487498
Session.remove()
488499
Session = None
489-
if engine:
500+
NonScopedSession = None
501+
close_all_sessions()
502+
503+
if "engine" in _globals:
490504
engine.dispose()
491505
engine = None
492506

@@ -529,26 +543,6 @@ def configure_adapters():
529543
pass
530544

531545

532-
def validate_session():
533-
"""Validate ORM Session."""
534-
global engine
535-
536-
worker_precheck = conf.getboolean("celery", "worker_precheck")
537-
if not worker_precheck:
538-
return True
539-
else:
540-
check_session = sessionmaker(bind=engine)
541-
session = check_session()
542-
try:
543-
session.execute(text("select 1"))
544-
conn_status = True
545-
except exc.DBAPIError as err:
546-
log.error(err)
547-
conn_status = False
548-
session.close()
549-
return conn_status
550-
551-
552546
def configure_action_logging() -> None:
553547
"""Any additional configuration (register callback) for airflow.utils.action_loggers module."""
554548

providers/celery/provider.yaml

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -308,13 +308,6 @@ config:
308308
type: integer
309309
example: ~
310310
default: "3"
311-
worker_precheck:
312-
description: |
313-
Worker initialisation check to validate Metadata Database connection
314-
version_added: ~
315-
type: string
316-
example: ~
317-
default: "False"
318311
extra_celery_config:
319312
description: |
320313
Extra celery configs to include in the celery worker.

providers/celery/src/airflow/providers/celery/cli/celery_command.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -197,11 +197,11 @@ def worker(args):
197197
from airflow.sdk.log import configure_logging
198198

199199
configure_logging(output=sys.stdout.buffer)
200-
201-
# Disable connection pool so that celery worker does not hold an unnecessary db connection
202-
settings.reconfigure_orm(disable_connection_pool=True)
203-
if not settings.validate_session():
204-
raise SystemExit("Worker exiting, database connection precheck failed.")
200+
else:
201+
# Disable connection pool so that celery worker does not hold an unnecessary db connection
202+
settings.reconfigure_orm(disable_connection_pool=True)
203+
if not settings.validate_session():
204+
raise SystemExit("Worker exiting, database connection precheck failed.")
205205

206206
autoscale = args.autoscale
207207
skip_serve_logs = args.skip_serve_logs

providers/celery/src/airflow/providers/celery/get_provider_info.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -266,13 +266,6 @@ def get_provider_info():
266266
"example": None,
267267
"default": "3",
268268
},
269-
"worker_precheck": {
270-
"description": "Worker initialisation check to validate Metadata Database connection\n",
271-
"version_added": None,
272-
"type": "string",
273-
"example": None,
274-
"default": "False",
275-
},
276269
"extra_celery_config": {
277270
"description": 'Extra celery configs to include in the celery worker.\nAny of the celery config can be added to this config and it\nwill be applied while starting the celery worker. e.g. {"worker_max_tasks_per_child": 10}\nSee also:\nhttps://docs.celeryq.dev/en/stable/userguide/configuration.html#configuration-and-defaults\n',
278271
"version_added": None,

providers/celery/tests/unit/celery/cli/test_celery_command.py

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,11 @@
1919

2020
import importlib
2121
import os
22-
from argparse import Namespace
2322
from unittest import mock
2423
from unittest.mock import patch
2524

2625
import pytest
27-
import sqlalchemy
2826

29-
import airflow
3027
from airflow.cli import cli_parser
3128
from airflow.configuration import conf
3229
from airflow.executors import executor_loader
@@ -39,37 +36,6 @@
3936
pytestmark = pytest.mark.db_test
4037

4138

42-
@conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): 0})
43-
class TestWorkerPrecheck:
44-
@mock.patch("airflow.settings.validate_session")
45-
def test_error(self, mock_validate_session):
46-
"""
47-
Test to verify the exit mechanism of airflow-worker cli
48-
by mocking validate_session method
49-
"""
50-
mock_validate_session.return_value = False
51-
with pytest.raises(SystemExit) as ctx, conf_vars({("core", "executor"): "CeleryExecutor"}):
52-
celery_command.worker(Namespace(queues=1, concurrency=1))
53-
assert str(ctx.value) == "Worker exiting, database connection precheck failed."
54-
55-
@conf_vars({("celery", "worker_precheck"): "False"})
56-
def test_worker_precheck_exception(self):
57-
"""
58-
Test to check the behaviour of validate_session method
59-
when worker_precheck is absent in airflow configuration
60-
"""
61-
assert airflow.settings.validate_session()
62-
63-
@mock.patch("sqlalchemy.orm.session.Session.execute")
64-
@conf_vars({("celery", "worker_precheck"): "True"})
65-
def test_validate_session_dbapi_exception(self, mock_session):
66-
"""
67-
Test to validate connection failure scenario on SELECT 1 query
68-
"""
69-
mock_session.side_effect = sqlalchemy.exc.OperationalError("m1", "m2", "m3", "m4")
70-
assert airflow.settings.validate_session() is False
71-
72-
7339
@pytest.mark.backend("mysql", "postgres")
7440
@conf_vars({("dag_processor", "stale_bundle_cleanup_interval"): 0})
7541
class TestCeleryStopCommand:

task_sdk/src/airflow/sdk/execution_time/supervisor.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,62 @@ def _get_last_chance_stderr() -> TextIO:
206206
return stream
207207

208208

209+
class BlockedDBSession:
210+
""":meta private:""" # noqa: D400
211+
212+
def __init__(self):
213+
raise RuntimeError("Direct database access via the ORM is not allowed in Airflow 3.0")
214+
215+
def remove(*args, **kwargs):
216+
pass
217+
218+
def get_bind(
219+
self,
220+
mapper=None,
221+
clause=None,
222+
bind=None,
223+
_sa_skip_events=None,
224+
_sa_skip_for_implicit_returning=False,
225+
):
226+
pass
227+
228+
229+
def block_orm_access():
230+
"""
231+
Disable direct DB access as best as possible from task code.
232+
233+
While we still don't have 100% code separation between TaskSDK and "core" Airflow, it is still possible to
234+
import the models and use them. This does what it can to disable that if it is not blocked at the network
235+
level.
236+
"""
237+
# A fake URL schema that might give users some clue what's going on. Hopefully
238+
conn = "airflow-db-not-allowed:///"
239+
if "airflow.settings" in sys.modules:
240+
from airflow import settings
241+
from airflow.configuration import conf
242+
243+
settings.dispose_orm()
244+
245+
for attr in ("engine", "async_engine", "Session", "AsyncSession", "NonScopedSession"):
246+
if hasattr(settings, attr):
247+
delattr(settings, attr)
248+
249+
def configure_orm(*args, **kwargs):
250+
raise RuntimeError("Database access is disabled from DAGs and Triggers")
251+
252+
settings.configure_orm = configure_orm
253+
settings.Session = BlockedDBSession
254+
if conf.has_section("database"):
255+
conf.set("database", "sql_alchemy_conn", conn)
256+
conf.set("database", "sql_alchemy_conn_cmd", "/bin/false")
257+
conf.set("database", "sql_alchemy_conn_secret", "db-access-blocked")
258+
259+
settings.SQL_ALCHEMY_CONN = conn
260+
settings.SQL_ALCHEMY_CONN_ASYNC = conn
261+
262+
os.environ["AIRFLOW__DATABASE__SQL_ALCHEMY_CONN"] = conn
263+
264+
209265
def _fork_main(
210266
child_stdin: socket,
211267
child_stdout: socket,
@@ -261,6 +317,8 @@ def exit(n: int) -> NoReturn:
261317
base_exit(n)
262318

263319
try:
320+
block_orm_access()
321+
264322
target()
265323
exit(0)
266324
except SystemExit as e:

task_sdk/tests/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
# Task SDK does not need access to the Airflow database
3030
os.environ["_AIRFLOW_SKIP_DB_TESTS"] = "true"
31+
os.environ["_AIRFLOW__AS_LIBRARY"] = "true"
3132

3233
if TYPE_CHECKING:
3334
from datetime import datetime
@@ -56,6 +57,10 @@ def pytest_configure(config: pytest.Config) -> None:
5657
# Always skip looking for tests in these folders!
5758
config.addinivalue_line("norecursedirs", "tests/test_dags")
5859

60+
import airflow.settings
61+
62+
airflow.settings.configure_policy_plugin_manager()
63+
5964

6065
@pytest.hookimpl(tryfirst=True)
6166
def pytest_runtest_setup(item):

task_sdk/tests/execution_time/test_supervisor.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import selectors
2525
import signal
2626
import sys
27+
import time
2728
from io import BytesIO
2829
from operator import attrgetter
2930
from pathlib import Path
@@ -850,7 +851,9 @@ def _handler(sig, frame):
850851
client=MagicMock(spec=sdk_client.Client),
851852
target=subprocess_main,
852853
)
854+
853855
# Ensure we get one normal run, to give the proc time to register its custom sighandler
856+
time.sleep(0.1)
854857
proc._service_subprocess(max_wait_time=1)
855858
proc.kill(signal_to_send=signal_to_send, escalation_delay=0.5, force=True)
856859

tests/dag_processing/test_manager.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,16 +165,18 @@ def test_remove_file_clears_import_error(self, tmp_path, configure_testing_dag_b
165165
processor_timeout=365 * 86_400,
166166
)
167167

168-
with create_session() as session:
169-
manager.run()
168+
manager.run()
170169

170+
with create_session() as session:
171171
import_errors = session.query(ParseImportError).all()
172172
assert len(import_errors) == 1
173173

174174
path_to_parse.unlink()
175175

176-
# Rerun the parser once the dag file has been removed
177-
manager.run()
176+
# Rerun the parser once the dag file has been removed
177+
manager.run()
178+
179+
with create_session() as session:
178180
import_errors = session.query(ParseImportError).all()
179181

180182
assert len(import_errors) == 0
@@ -658,6 +660,7 @@ def test_refresh_dags_dir_deactivates_deleted_zipped_dags(
658660
shutil.copy(source_location, zip_dag_path)
659661

660662
with configure_testing_dag_bundle(bundle_path):
663+
session.commit()
661664
manager = DagFileProcessorManager(max_runs=1)
662665
manager.run()
663666

0 commit comments

Comments
 (0)