|
29 | 29 |
|
30 | 30 | import pluggy |
31 | 31 | from packaging.version import Version |
32 | | -from sqlalchemy import create_engine, exc, text |
| 32 | +from sqlalchemy import create_engine |
33 | 33 | from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession as SAAsyncSession, create_async_engine |
34 | 34 | from sqlalchemy.orm import scoped_session, sessionmaker |
35 | 35 | from sqlalchemy.pool import NullPool |
|
46 | 46 |
|
47 | 47 | if TYPE_CHECKING: |
48 | 48 | from sqlalchemy.engine import Engine |
49 | | - from sqlalchemy.orm import Session as SASession |
50 | 49 |
|
51 | 50 | log = logging.getLogger(__name__) |
52 | 51 |
|
|
101 | 100 | """ |
102 | 101 |
|
103 | 102 | engine: Engine |
104 | | -Session: Callable[..., SASession] |
| 103 | +Session: scoped_session |
105 | 104 | # NonScopedSession creates global sessions and is not safe to use in multi-threaded environment without |
106 | 105 | # additional precautions. The only use case is when the session lifecycle needs |
107 | 106 | # custom handling. Most of the time we only want one unique thread local session object, |
108 | 107 | # this is achieved by the Session factory above. |
109 | | -NonScopedSession: Callable[..., SASession] |
| 108 | +NonScopedSession: sessionmaker |
110 | 109 | async_engine: AsyncEngine |
111 | 110 | AsyncSession: Callable[..., SAAsyncSession] |
112 | 111 |
|
@@ -389,6 +388,12 @@ def _session_maker(_engine): |
389 | 388 | NonScopedSession = _session_maker(engine) |
390 | 389 | Session = scoped_session(NonScopedSession) |
391 | 390 |
|
| 391 | + from sqlalchemy.orm.session import close_all_sessions |
| 392 | + |
| 393 | + os.register_at_fork(after_in_child=close_all_sessions) |
| 394 | + # https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork |
| 395 | + os.register_at_fork(after_in_child=lambda: engine.dispose(close=False)) |
| 396 | + |
392 | 397 |
|
393 | 398 | DEFAULT_ENGINE_ARGS = { |
394 | 399 | "postgresql": { |
@@ -479,14 +484,23 @@ def prepare_engine_args(disable_connection_pool=False, pool_class=None): |
479 | 484 |
|
480 | 485 | def dispose_orm(): |
481 | 486 | """Properly close pooled database connections.""" |
| 487 | + global Session, engine, NonScopedSession |
| 488 | + |
| 489 | + _globals = globals() |
| 490 | + if "engine" not in _globals and "Session" not in _globals: |
| 491 | + return |
| 492 | + |
482 | 493 | log.debug("Disposing DB connection pool (PID %s)", os.getpid()) |
483 | | - global engine |
484 | | - global Session |
485 | 494 |
|
486 | | - if Session is not None: # type: ignore[truthy-function] |
| 495 | + if "Session" in _globals and Session is not None: |
| 496 | + from sqlalchemy.orm.session import close_all_sessions |
| 497 | + |
487 | 498 | Session.remove() |
488 | 499 | Session = None |
489 | | - if engine: |
| 500 | + NonScopedSession = None |
| 501 | + close_all_sessions() |
| 502 | + |
| 503 | + if "engine" in _globals: |
490 | 504 | engine.dispose() |
491 | 505 | engine = None |
492 | 506 |
|
@@ -529,26 +543,6 @@ def configure_adapters(): |
529 | 543 | pass |
530 | 544 |
|
531 | 545 |
|
532 | | -def validate_session(): |
533 | | - """Validate ORM Session.""" |
534 | | - global engine |
535 | | - |
536 | | - worker_precheck = conf.getboolean("celery", "worker_precheck") |
537 | | - if not worker_precheck: |
538 | | - return True |
539 | | - else: |
540 | | - check_session = sessionmaker(bind=engine) |
541 | | - session = check_session() |
542 | | - try: |
543 | | - session.execute(text("select 1")) |
544 | | - conn_status = True |
545 | | - except exc.DBAPIError as err: |
546 | | - log.error(err) |
547 | | - conn_status = False |
548 | | - session.close() |
549 | | - return conn_status |
550 | | - |
551 | | - |
552 | 546 | def configure_action_logging() -> None: |
553 | 547 | """Any additional configuration (register callback) for airflow.utils.action_loggers module.""" |
554 | 548 |
|
|
0 commit comments